diff --git a/src/Lean/Meta/Match/MatcherApp/Transform.lean b/src/Lean/Meta/Match/MatcherApp/Transform.lean index 28183b8138..6715f05559 100644 --- a/src/Lean/Meta/Match/MatcherApp/Transform.lean +++ b/src/Lean/Meta/Match/MatcherApp/Transform.lean @@ -176,16 +176,32 @@ def arrowDomainsN (n : Nat) (type : Expr) : MetaM (Array Expr) := do type := β return ts +private def withUserNamesImpl {α} (fvars : Array Expr) (names : Array Name) (k : MetaM α) : MetaM α := do + let lctx := (Array.zip fvars names).foldl (init := ← (getLCtx)) fun lctx (fvar, name) => + lctx.setUserName fvar.fvarId! name + withTheReader Meta.Context (fun ctx => { ctx with lctx }) k + /-- Sets the user name of the FVars in the local context according to the given array of names. If they differ in size the shorter size wins. -/ -def withUserNames {α} (fvars : Array Expr) (names : Array Name) (k : MetaM α ) : MetaM α := do - let lctx := (Array.zip fvars names).foldl (init := ← (getLCtx)) fun lctx (fvar, name) => - lctx.setUserName fvar.fvarId! name - withTheReader Meta.Context (fun ctx => { ctx with lctx }) k +def withUserNames {n} [MonadControlT MetaM n] [Monad n] + {α} (fvars : Array Expr) (names : Array Name) (k : n α) : n α := do + mapMetaM (withUserNamesImpl fvars names) k +/- +`Match.forallAltTelescope` lifted to a monad transformer +(and only passing those arguments that we care about below) +-/ +private def forallAltTelescope' + {n} [Monad n] [MonadControlT MetaM n] + {α} (origAltType : Expr) (numParams numDiscrEqs : Nat) + (k : Array Expr → Array Expr → n α) : n α := do + map2MetaM (fun k => + Match.forallAltTelescope origAltType (numParams - numDiscrEqs) 0 + fun ys _eqs args _mask _bodyType => k ys args + ) k /-- Performs a possibly type-changing transformation to a `MatcherApp`. @@ -208,14 +224,17 @@ This function works even if the the type of alternatives do *not* fit the inferr allows you to post-process the `MatcherApp` with `MatcherApp.inferMatchType`, which will infer a type, given all the alternatives. -/ -def transform (matcherApp : MatcherApp) +def transform + {n} [MonadLiftT MetaM n] [MonadControlT MetaM n] [Monad n] [MonadError n] [MonadEnv n] [MonadLog n] + [AddMessageContext n] [MonadOptions n] + (matcherApp : MatcherApp) (useSplitter := false) (addEqualities : Array Bool := mkArray matcherApp.discrs.size false) - (onParams : Expr → MetaM Expr := pure) - (onMotive : Array Expr → Expr → MetaM Expr := fun _ e => pure e) - (onAlt : Expr → Expr → MetaM Expr := fun _ e => pure e) - (onRemaining : Array Expr → MetaM (Array Expr) := pure) : - MetaM MatcherApp := do + (onParams : Expr → n Expr := pure) + (onMotive : Array Expr → Expr → n Expr := fun _ e => pure e) + (onAlt : Expr → Expr → n Expr := fun _ e => pure e) + (onRemaining : Array Expr → n (Array Expr) := pure) : + n MatcherApp := do if addEqualities.size != matcherApp.discrs.size then throwError "MatcherApp.transform: addEqualities has wrong size" @@ -247,7 +266,7 @@ def transform (matcherApp : MatcherApp) -- Prepend (x = e) → to the motive when an equality is requested for arg in motiveArgs, discr in discrs', b in addEqualities do if b then - motiveBody' ← mkArrow (← mkEq discr arg) motiveBody' + motiveBody' ← liftMetaM <| mkArrow (← mkEq discr arg) motiveBody' return (← mkLambdaFVars motiveArgs motiveBody', ← getLevel motiveBody') @@ -290,7 +309,7 @@ def transform (matcherApp : MatcherApp) splitterNumParams in matchEqns.splitterAltNumParams, origAltType in origAltTypes, altType in altTypes do - let alt' ← Match.forallAltTelescope origAltType (numParams - numDiscrEqs) 0 fun ys _eqs args _mask _bodyType => do + let alt' ← forallAltTelescope' origAltType (numParams - numDiscrEqs) 0 fun ys args => do let altType ← instantiateForall altType ys -- The splitter inserts its extra paramters after the first ys.size parameters, before -- the parameters for the numDiscrEqs diff --git a/src/Lean/Meta/Tactic/FunInd.lean b/src/Lean/Meta/Tactic/FunInd.lean index 31c6a871c9..8db1be3f10 100644 --- a/src/Lean/Meta/Tactic/FunInd.lean +++ b/src/Lean/Meta/Tactic/FunInd.lean @@ -166,7 +166,7 @@ open Lean Elab Meta This is used when replacing parameters with different expressions. This way it will not be picked up by metavariables. -/ -def removeLamda {α} (e : Expr) (k : FVarId → Expr → MetaM α) : MetaM α := do +def removeLamda {n} [MonadLiftT MetaM n] [MonadError n] [MonadNameGenerator n] [Monad n] {α} (e : Expr) (k : FVarId → Expr → n α) : n α := do let .lam _n _d b _bi := ← whnfD e | throwError "removeLamda: expected lambda, got {e}" let x ← mkFreshFVarId let b := b.instantiate1 (.fvar x) @@ -386,10 +386,15 @@ def substVarAfter (mvarId : MVarId) (x : FVarId) : MetaM MVarId := do mvarId ← trySubstVar mvarId localDecl.fvarId return mvarId +/-- +Helper monad to traverse the function body, collecting the cases as mvars +-/ +abbrev M α := StateT (Array MVarId) MetaM α + /-- Base case of `buildInductionBody`: Construct a case for the final induction hypthesis. -/ def buildInductionCase (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve : Array FVarId) - (goal : Expr) (IHs : Array Expr) (e : Expr) : MetaM Expr := do + (goal : Expr) (IHs : Array Expr) (e : Expr) : M Expr := do let IHs := IHs ++ (← collectIHs fn oldIH newIH e) let IHs ← deduplicateIHs IHs @@ -398,7 +403,8 @@ def buildInductionCase (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve : mvarId ← assertIHs IHs mvarId for fvarId in toClear do mvarId ← mvarId.clear fvarId - _ ← mvarId.cleanup (toPreserve := toPreserve) + mvarId ← mvarId.cleanup (toPreserve := toPreserve) + modify (·.push mvarId) let mvar ← instantiateMVars mvar pure mvar @@ -437,8 +443,14 @@ def maskArray {α} (mask : Array Bool) (xs : Array α) : Array α := Id.run do if b then ys := ys.push x return ys +/-- +Builds an expression of type `goal` by replicating the expression `e` into its tail-call-positions, +where it calls `buildInductionCase`. Collects the cases of the final induction hypothesis +as `MVars` as it goes. +-/ partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId) - (goal : Expr) (oldIH newIH : FVarId) (IHs : Array Expr) (e : Expr) : MetaM Expr := do + (goal : Expr) (oldIH newIH : FVarId) (IHs : Array Expr) (e : Expr) : M Expr := do + -- logInfo m!"buildInductionBody {e}" if e.isDIte then let #[_α, c, h, t, f] := e.getAppArgs | unreachable! @@ -459,7 +471,7 @@ partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId) if let some matcherApp ← matchMatcherApp? e (alsoCasesOn := true) then -- Collect IHs from the parameters and discrs of the matcher let paramsAndDiscrs := matcherApp.params ++ matcherApp.discrs - let IHs := IHs ++ (← paramsAndDiscrs.concatMapM (collectIHs fn oldIH newIH)) + let IHs := IHs ++ (← paramsAndDiscrs.concatMapM (collectIHs fn oldIH newIH ·)) -- Calculate motive let eType ← newIH.getType @@ -471,7 +483,7 @@ partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId) if matcherApp.remaining.size == 1 && matcherApp.remaining[0]!.isFVarOf oldIH then let matcherApp' ← matcherApp.transform (useSplitter := true) (addEqualities := mask.map not) - (onParams := foldCalls fn oldIH) + (onParams := (foldCalls fn oldIH ·)) (onMotive := fun xs _body => pure (absMotiveBody.beta (maskArray mask xs))) (onAlt := fun expAltType alt => do removeLamda alt fun oldIH' alt => do @@ -490,7 +502,7 @@ partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId) let matcherApp' ← matcherApp.transform (useSplitter := true) (addEqualities := mask.map not) - (onParams := foldCalls fn oldIH) + (onParams := (foldCalls fn oldIH ·)) (onMotive := fun xs _body => pure (absMotiveBody.beta (maskArray mask xs))) (onAlt := fun expAltType alt => do buildInductionBody fn toClear toPreserve expAltType oldIH newIH IHs alt) @@ -514,6 +526,35 @@ partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId) buildInductionCase fn oldIH newIH toClear toPreserve goal IHs e +/-- +Given an expression `e` with metavariables +* collects all these meta-variables, +* lifts them to the current context by reverting all local declarations up to `x` +* introducing a local variable for each of the meta variable +* assigning that local variable to the mvar +* and finally lambda-abstracting over these new local variables. + +This operation only works if the metavariables are independent from each other. + +The resulting meta variable assignment is no longer valid (mentions out-of-scope +variables), so after this operations, terms that still mention these meta variables must not +be used anymore. + +We are not using `mkLambdaFVars` on mvars directly, nor `abstractMVars`, as these at the moment +do not handle delayed assignemnts correctly. +-/ +def abstractIndependentMVars (mvars : Array MVarId) (x : FVarId) (e : Expr) : MetaM Expr := do + let mvars ← mvars.mapM fun mvar => do + let mvar ← substVarAfter mvar x + let (_, mvar) ← mvar.revertAfter x + pure mvar + let decls := mvars.mapIdx fun i mvar => + (.mkSimple s!"case{i.val+1}", .default, (fun _ => mvar.getType)) + Meta.withLocalDecls decls fun xs => do + for mvar in mvars, x in xs do + mvar.assign x + mkLambdaFVars xs (← instantiateMVars e) + partial def findFixF {α} (name : Name) (e : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α := do lambdaTelescope e fun params body => do if body.isAppOf ``WellFounded.fixF then @@ -552,7 +593,7 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do let e' := mkApp3 (.const ``WellFounded.fixF [argLevel, levelZero]) argType rel motive let fn := mkAppN (.const name (info.levelParams.map mkLevelParam)) params.pop - let body' ← forallTelescope (← inferType e').bindingDomain! fun xs _ => do + let (body', mvars) ← StateT.run (s := {}) <| forallTelescope (← inferType e').bindingDomain! fun xs _ => do let #[param, genIH] := xs | unreachable! -- open body with the same arg let body ← instantiateLambda body #[param] @@ -565,27 +606,8 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do let e' := mkApp3 e' body' arg acc let e' ← mkLambdaFVars #[params.back] e' - let mvars ← getMVarsNoDelayed e' - let mvars ← mvars.mapM fun mvar => do - let mvar ← substVarAfter mvar motive.fvarId! - let (_, mvar) ← mvar.revertAfter motive.fvarId! - pure mvar - -- Using `mkLambdaFVars` on mvars directly does not reliably replace - -- the mvars with the parameter, in the presence of delayed assignemnts. - -- Also `abstractMVars` does not handle delayed assignments correctly (as of now). - -- So instead we bring suitable fvars into scope and use `assign`; this handles - -- delayed assignemnts correctly. - -- NB: This idiom only works because - -- * we know that the `MVars` have the right local context (thanks to `mvarId.revertAfter`) - -- * the MVars are independent (so we don’t need to reorder them) - -- * we do no need the mvars in their unassigned form later - let e' ← Meta.withLocalDecls - (mvars.mapIdx (fun i mv => (.mkSimple s!"case{i.val+1}", .default, (fun _ => mv.getType)))) - fun xs => do - for mvar in mvars, x in xs do - mvar.assign x - let e' ← instantiateMVars e' - mkLambdaFVars xs e' + let e' ← abstractIndependentMVars mvars motive.fvarId! e' + let e' ← mkLambdaFVars #[motive] e' -- We could pass (usedOnly := true) below, and get nicer induction principles that -- do do not mention odd unused parameters. @@ -593,7 +615,7 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do -- that derives them from an function application in the goal) is harder, as -- one would have to infer or keep track of which parameters to pass. -- So for now lets just keep them around. - let e' ← mkLambdaFVars (binderInfoForMVars := .default) (params.pop ++ #[motive]) e' + let e' ← mkLambdaFVars (binderInfoForMVars := .default) params.pop e' let e' ← instantiateMVars e' let eTyp ← inferType e' diff --git a/tests/lean/run/funind_fewer_levels.lean b/tests/lean/run/funind_fewer_levels.lean index d1dbcd762c..111ee45f4a 100644 --- a/tests/lean/run/funind_fewer_levels.lean +++ b/tests/lean/run/funind_fewer_levels.lean @@ -40,9 +40,8 @@ end derive_functional_induction foo /-- -info: Mutual.foo.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0) - (case3 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : - ∀ (a : Nat), motive1 a +info: Mutual.foo.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : ∀ (n : Nat), motive2 n → motive1 n.succ) + (case3 : motive2 0) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : ∀ (a : Nat), motive1 a -/ #guard_msgs in #check foo.induct @@ -50,8 +49,8 @@ info: Mutual.foo.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (ca example : foo n = .unit := by induction n using foo.induct (motive2 := fun n => bar n = .unit) with | case1 => unfold foo; rfl - | case2 => unfold bar; rfl - | case3 n ih => unfold foo; exact ih + | case2 n ih => unfold foo; exact ih + | case3 => unfold bar; rfl | case4 n ih => unfold bar; exact ih end Mutual diff --git a/tests/lean/run/funind_proof.lean b/tests/lean/run/funind_proof.lean index 028f04f579..4d32d47ae9 100644 --- a/tests/lean/run/funind_proof.lean +++ b/tests/lean/run/funind_proof.lean @@ -27,10 +27,10 @@ end derive_functional_induction replaceConst /-- -info: Term.replaceConst.induct (a b : String) (motive1 : Term → Prop) (motive2 : List Term → Prop) (case1 : motive2 []) - (case2 : ∀ (a_1 : String), (a == a_1) = true → motive1 (const a_1)) - (case3 : ∀ (a_1 : String), ¬(a == a_1) = true → motive1 (const a_1)) - (case4 : ∀ (a : String) (cs : List Term), motive2 cs → motive1 (app a cs)) +info: Term.replaceConst.induct (a b : String) (motive1 : Term → Prop) (motive2 : List Term → Prop) + (case1 : ∀ (a_1 : String), (a == a_1) = true → motive1 (const a_1)) + (case2 : ∀ (a_1 : String), ¬(a == a_1) = true → motive1 (const a_1)) + (case3 : ∀ (a : String) (cs : List Term), motive2 cs → motive1 (app a cs)) (case4 : motive2 []) (case5 : ∀ (c : Term) (cs : List Term), motive1 c → motive2 cs → motive2 (c :: cs)) : ∀ (a : Term), motive1 a -/ #guard_msgs in @@ -40,13 +40,13 @@ theorem numConsts_replaceConst (a b : String) (e : Term) : numConsts (replaceCon apply replaceConst.induct (motive1 := fun e => numConsts (replaceConst a b e) = numConsts e) (motive2 := fun es => numConstsLst (replaceConstLst a b es) = numConstsLst es) - case case1 => simp [replaceConstLst, numConstsLst, *] - case case2 => intro c h; guard_hyp h :ₛ (a == c) = true; simp [replaceConst, numConsts, *] - case case3 => intro c h; guard_hyp h :ₛ ¬(a == c) = true; simp [replaceConst, numConsts, *] - case case4 => + case case1 => intro c h; guard_hyp h :ₛ (a == c) = true; simp [replaceConst, numConsts, *] + case case2 => intro c h; guard_hyp h :ₛ ¬(a == c) = true; simp [replaceConst, numConsts, *] + case case3 => intros f cs ih guard_hyp ih :ₛnumConstsLst (replaceConstLst a b cs) = numConstsLst cs simp [replaceConst, numConsts, *] + case case4 => simp [replaceConstLst, numConstsLst, *] case case5 => intro c cs ih₁ ih₂ guard_hyp ih₁ :ₛ numConsts (replaceConst a b c) = numConsts c diff --git a/tests/lean/run/funind_tests.lean b/tests/lean/run/funind_tests.lean index 5253ef5357..88344d4cbc 100644 --- a/tests/lean/run/funind_tests.lean +++ b/tests/lean/run/funind_tests.lean @@ -587,7 +587,7 @@ info: RecCallInDisrs.foo.induct (motive : Nat → Prop) (case1 : motive 0) def bar : Nat → Nat | 0 => 0 - | n+1 => match h₁ : n, bar n with + | n+1 => match _h : n, bar n with | 0, 0 => 0 | 0, _ => 1 | m+1, _ => bar m @@ -619,17 +619,15 @@ end derive_functional_induction even /-- -info: EvenOdd.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0) - (case3 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : - ∀ (a : Nat), motive1 a +info: EvenOdd.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : ∀ (n : Nat), motive2 n → motive1 n.succ) + (case3 : motive2 0) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : ∀ (a : Nat), motive1 a -/ #guard_msgs in #check even.induct /-- -info: EvenOdd.odd.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0) - (case3 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : - ∀ (a : Nat), motive2 a +info: EvenOdd.odd.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : ∀ (n : Nat), motive2 n → motive1 n.succ) + (case3 : motive2 0) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : ∀ (a : Nat), motive2 a -/ #guard_msgs in #check odd.induct @@ -773,7 +771,7 @@ derive_functional_induction even._mutual /-- info: CommandIdempotence.even._mutual.induct (motive : Nat ⊕' Nat → Prop) (case1 : motive (PSum.inl 0)) - (case2 : motive (PSum.inr 0)) (case3 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl n.succ)) + (case2 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl n.succ)) (case3 : motive (PSum.inr 0)) (case4 : ∀ (n : Nat), motive (PSum.inl n) → motive (PSum.inr n.succ)) (x : Nat ⊕' Nat) : motive x -/ #guard_msgs in @@ -787,16 +785,16 @@ derive_functional_induction even /-- info: CommandIdempotence.even._mutual.induct (motive : Nat ⊕' Nat → Prop) (case1 : motive (PSum.inl 0)) - (case2 : motive (PSum.inr 0)) (case3 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl n.succ)) + (case2 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl n.succ)) (case3 : motive (PSum.inr 0)) (case4 : ∀ (n : Nat), motive (PSum.inl n) → motive (PSum.inr n.succ)) (x : Nat ⊕' Nat) : motive x -/ #guard_msgs in #check even._mutual.induct /-- -info: CommandIdempotence.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0) - (case3 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : - ∀ (a : Nat), motive1 a +info: CommandIdempotence.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) + (case2 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case3 : motive2 0) + (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : ∀ (a : Nat), motive1 a -/ #guard_msgs in #check even.induct