diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index c5543df8d8..edcfecfa5e 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -779,7 +779,7 @@ where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition. - def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := do let ⟨matcherName, matchType, discrInfos, lhss⟩ := input let numDiscrs := discrInfos.size - let numEqs := (discrInfos.filter fun info => info.hName?.isSome).size + let numEqs := getNumEqsFromDiscrInfos discrInfos checkNumPatterns numDiscrs lhss forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do /- We generate an matcher that can eliminate using different motives with different universe levels. diff --git a/src/Lean/Meta/Match/MatchEqs.lean b/src/Lean/Meta/Match/MatchEqs.lean index 1f57f9ad09..f8f935a396 100644 --- a/src/Lean/Meta/Match/MatchEqs.lean +++ b/src/Lean/Meta/Match/MatchEqs.lean @@ -62,8 +62,15 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do Meta.transform e (pre := visit) /-- - Similar to `forallTelescopeReducing`, but eliminates arguments for named parameters and the associated - equation proofs. The continuation `k` takes four arguments `ys args mask type`. + Similar to `forallTelescopeReducing`, but + + 1. Eliminates arguments for named parameters and the associated equation proofs. + + 2. Equality parameters associated with the `h : discr` notation are replaced with `rfl` proofs. + Recall that this kind of parameter always occurs after the parameters correspoting to pattern variables. + `numNonEqParams` is the size of the prefix. + + The continuation `k` takes four arguments `ys args mask type`. - `ys` are variables for the hypotheses that have not been eliminated. - `args` are the arguments for the alternative `alt` that has type `altType`. `ys.size <= args.size` - `mask[i]` is true if the hypotheses has not been eliminated. `mask.size == args.size`. @@ -71,26 +78,31 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do We use the `mask` to build the splitter proof. See `mkSplitterProof`. -/ -partial def forallAltTelescope (altType : Expr) (k : Array Expr → Array Expr → Array Bool → Expr → MetaM α) : MetaM α := do - go #[] #[] #[] altType +partial def forallAltTelescope (altType : Expr) (numNonEqParams : Nat) (k : Array Expr → Array Expr → Array Bool → Expr → MetaM α) : MetaM α := do + go #[] #[] #[] 0 altType where - go (ys : Array Expr) (args : Array Expr) (mask : Array Bool) (type : Expr) : MetaM α := do + go (ys : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do let type ← whnfForall type match type with | Expr.forallE n d b .. => - let d ← unfoldNamedPattern d - withLocalDeclD n d fun y => do - let typeNew := b.instantiate1 y - if let some (_, lhs, rhs) ← matchEq? d then - if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then - let some i := ys.getIdx? lhs | unreachable! - let ys := ys.eraseIdx i - let mask := mask.set! i false - let args := args.map fun arg => if arg == lhs then rhs else arg - let args := args.push (← mkEqRefl rhs) - let typeNew := typeNew.replaceFVar lhs rhs - return (← go ys args (mask.push false) typeNew) - go (ys.push y) (args.push y) (mask.push true) typeNew + if i < numNonEqParams then + let d ← unfoldNamedPattern d + withLocalDeclD n d fun y => do + let typeNew := b.instantiate1 y + if let some (_, lhs, rhs) ← matchEq? d then + if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then + let some i := ys.getIdx? lhs | unreachable! + let ys := ys.eraseIdx i + let mask := mask.set! i false + let args := args.map fun arg => if arg == lhs then rhs else arg + let args := args.push (← mkEqRefl rhs) + let typeNew := typeNew.replaceFVar lhs rhs + return (← go ys args (mask.push false) (i+1) typeNew) + go (ys.push y) (args.push y) (mask.push true) (i+1) typeNew + else + let some (_, _, rhs) ← matchEq? d | throwError "unexpected match alternative type{indentExpr altType}" + let arg ← mkEqRefl rhs + go ys (args.push arg) (mask.push false) (i+1) (b.instantiate1 arg) | _ => let type ← unfoldNamedPattern type /- Recall that alternatives that do not have variables have a `Unit` parameter to ensure @@ -258,16 +270,17 @@ private def substSomeVar (mvarId : MVarId) : MetaM (Array MVarId) := withMVarCon /-- Helper method for proving a conditional equational theorem associated with an alternative of the `match`-eliminator `matchDeclName`. `type` contains the type of the theorem. -/ -partial def proveCondEqThm (matchDeclName : Name) (type : Expr) : MetaM Expr := do +partial def proveCondEqThm (matchDeclName : Name) (type : Expr) : MetaM Expr := withLCtx {} {} do let type ← instantiateMVars type forallTelescope type fun ys target => do let mvar0 ← mkFreshExprSyntheticOpaqueMVar target + trace[Meta.Match.matchEqs] "proveCondEqThm {mvar0.mvarId!}" let mvarId ← deltaTarget mvar0.mvarId! (· == matchDeclName) - trace[Meta.Match.matchEqs] "{MessageData.ofGoal mvarId}" withDefault <| go mvarId 0 mkLambdaFVars ys (← instantiateMVars mvar0) where go (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do + trace[Meta.Match.matchEqs] "proveCondEqThm.go {mvarId}" let mvarId' ← modifyTargetEqLHS mvarId whnfCore let mvarId := mvarId' let subgoals ← @@ -395,6 +408,27 @@ where let mvarId ← tryClearMany mvarId (alts.map (·.fvarId!)) proveSubgoalLoop mvarId +/-- + Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`. + Recall that `alts` depends on `discrs` when `numDiscrEqs > 0`, where `numDiscrEqs` is the number of discriminants + annotated with `h : discr`. +-/ +private partial def withNewAlts (numDiscrEqs : Nat) (discrs : Array Expr) (patterns : Array Expr) (alts : Array Expr) (k : Array Expr → MetaM α) : MetaM α := + if numDiscrEqs == 0 then + k alts + else + go 0 #[] +where + go (i : Nat) (altsNew : Array Expr) : MetaM α := do + if h : i < alts.size then + let alt := alts.get ⟨i, h⟩ + let altLocalDecl ← getFVarLocalDecl alt + let typeNew := altLocalDecl.type.replaceFVars discrs patterns + withLocalDecl altLocalDecl.userName altLocalDecl.binderInfo typeNew fun altNew => + go (i+1) (altsNew.push altNew) + else + k altsNew + /-- Create conditional equations and splitter for the given match auxiliary declaration. -/ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := withLCtx {} {} do @@ -404,6 +438,7 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := let constInfo ← getConstInfo matchDeclName let us := constInfo.levelParams.map mkLevelParam let some matchInfo ← getMatcherInfo? matchDeclName | throwError "'{matchDeclName}' is not a matcher function" + let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos forallTelescopeReducing constInfo.type fun xs matchResultType => do let mut eqnNames := #[] let params := xs[:matchInfo.numParams] @@ -416,10 +451,12 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := let mut splitterAltTypes := #[] let mut splitterAltNumParams := #[] let mut altArgMasks := #[] -- masks produced by `forallAltTelescope` - for alt in alts do + for i in [:alts.size] do + let altNumParams := matchInfo.altNumParams[i] + let altNonEqNumParams := altNumParams - numDiscrEqs let thmName := baseName ++ ((`eq).appendIndexAfter idx) eqnNames := eqnNames.push thmName - let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← forallAltTelescope (← inferType alt) fun ys rhsArgs argMask altResultType => do + let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← forallAltTelescope (← inferType alts[i]) altNonEqNumParams fun ys rhsArgs argMask altResultType => do let patterns := altResultType.getAppArgs let mut hs := #[] for notAlt in notAlts do @@ -437,20 +474,24 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := else notAlt ← mkArrow (← mkHEq discr pattern) notAlt notAlt ← mkForallFVars (discrs ++ ys) notAlt - let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts) - let rhs := mkAppN alt rhsArgs - let thmType ← mkEq lhs rhs - let thmType ← hs.foldrM (init := thmType) mkArrow - let thmType ← mkForallFVars (params ++ #[motive] ++ alts ++ ys) thmType - let thmType ← unfoldNamedPattern thmType - let thmVal ← proveCondEqThm matchDeclName thmType - addDecl <| Declaration.thmDecl { - name := thmName - levelParams := constInfo.levelParams - type := thmType - value := thmVal - } - return (notAlt, splitterAltType, splitterAltNumParam, argMask) + /- Recall that when we use the `h : discr`, the alternative type depends on the discriminant. + Thus, we need to create new `alts`. -/ + withNewAlts numDiscrEqs discrs patterns alts fun alts => do + let alt := alts[i] + let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts) + let rhs := mkAppN alt rhsArgs + let thmType ← mkEq lhs rhs + let thmType ← hs.foldrM (init := thmType) mkArrow + let thmType ← mkForallFVars (params ++ #[motive] ++ ys ++ alts) thmType + let thmType ← unfoldNamedPattern thmType + let thmVal ← proveCondEqThm matchDeclName thmType + addDecl <| Declaration.thmDecl { + name := thmName + levelParams := constInfo.levelParams + type := thmType + value := thmVal + } + return (notAlt, splitterAltType, splitterAltNumParam, argMask) notAlts := notAlts.push notAlt splitterAltTypes := splitterAltTypes.push splitterAltType splitterAltNumParams := splitterAltNumParams.push splitterAltNumParam diff --git a/src/Lean/Meta/Match/MatcherInfo.lean b/src/Lean/Meta/Match/MatcherInfo.lean index da30ab3c0a..8d592fabe9 100644 --- a/src/Lean/Meta/Match/MatcherInfo.lean +++ b/src/Lean/Meta/Match/MatcherInfo.lean @@ -47,6 +47,16 @@ def MatcherInfo.getFirstAltPos (info : MatcherInfo) : Nat := def MatcherInfo.getMotivePos (info : MatcherInfo) : Nat := info.numParams +def getNumEqsFromDiscrInfos (infos : Array DiscrInfo) : Nat := Id.run do + let mut r := 0 + for info in infos do + if info.hName?.isSome then + r := r + 1 + return r + +def MatcherInfo.getNumDiscrEqs (info : MatcherInfo) : Nat := + getNumEqsFromDiscrInfos info.discrInfos + namespace Extension structure Entry where