diff --git a/src/Lean/Meta/Tactic/FunInd.lean b/src/Lean/Meta/Tactic/FunInd.lean index f20ba206d7..58313d3598 100644 --- a/src/Lean/Meta/Tactic/FunInd.lean +++ b/src/Lean/Meta/Tactic/FunInd.lean @@ -371,6 +371,22 @@ def assertIHs (vals : Array Expr) (mvarid : MVarId) : MetaM MVarId := do mvarid ← mvarid.assert s!"ih{i+1}" (← inferType v) v return mvarid + +/-- +Substitutes equations, but makes sure to only substitute variables introduced after the motive +as the motive could depend on anything before, and `substVar` would happily drop equations +about these fixed parameters. +-/ +def substVarAfter (mvarId : MVarId) (x : FVarId) : MetaM MVarId := do + mvarId.withContext do + let mut mvarId := mvarId + let index := (← x.getDecl).index + for localDecl in (← getLCtx) do + if localDecl.index > index then + mvarId ← trySubstVar mvarId localDecl.fvarId + return mvarId + + /-- Base case of `buildInductionBody`: Construct a case for the final induction hypthesis. -/ def buildInductionCase (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve : Array FVarId) (goal : Expr) (IHs : Array Expr) (e : Expr) : MetaM Expr := do @@ -383,7 +399,6 @@ def buildInductionCase (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve : for fvarId in toClear do mvarId ← mvarId.clear fvarId mvarId ← mvarId.cleanup (toPreserve := toPreserve) - mvarId ← substVars mvarId let mvar ← instantiateMVars mvar pure mvar @@ -552,6 +567,7 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do let e' ← mkLambdaFVars #[params.back] e' let mvars ← getMVarsNoDelayed e' let mvars ← mvars.mapM fun mvar => do + let mvar ← substVarAfter mvar motive.fvarId! let (_, mvar) ← mvar.revertAfter motive.fvarId! pure mvar -- Using `mkLambdaFVars` on mvars directly does not reliably replace diff --git a/src/Lean/Meta/Tactic/Subst.lean b/src/Lean/Meta/Tactic/Subst.lean index 8b3f484b34..683b8d24fb 100644 --- a/src/Lean/Meta/Tactic/Subst.lean +++ b/src/Lean/Meta/Tactic/Subst.lean @@ -138,6 +138,35 @@ def heqToEq (mvarId : MVarId) (fvarId : FVarId) (tryToClear : Bool := true) : Me else return (fvarId, mvarId) +/-- +Given `x`, try to find an equation of the form `heq : x = rhs` or `heq : lhs = x`, +and runs `substCore` on it. Throws an expection if no such equation is found. +-/ +partial def substVar (mvarId : MVarId) (x : FVarId) : MetaM MVarId := + mvarId.withContext do + let localDecl ← x.getDecl + if localDecl.isLet then + throwTacticEx `subst mvarId m!"variable '{mkFVar x}' is a let-declaration" + let lctx ← getLCtx + let some (fvarId, symm) ← lctx.findDeclM? fun localDecl => do + if localDecl.isImplementationDetail then + return none + else + match (← matchEq? localDecl.type) with + | some (_, lhs, rhs) => + let lhs ← instantiateMVars lhs + let rhs ← instantiateMVars rhs + if rhs.isFVar && rhs.fvarId! == x then + if !(← exprDependsOn lhs x) then + return some (localDecl.fvarId, true) + if lhs.isFVar && lhs.fvarId! == x then + if !(← exprDependsOn rhs x) then + return some (localDecl.fvarId, false) + return none + | _ => return none + | throwTacticEx `subst mvarId m!"did not find equation for eliminating '{mkFVar x}'" + return (← substCore mvarId fvarId (symm := symm) (tryToSkip := true)).2 + partial def subst (mvarId : MVarId) (h : FVarId) : MetaM MVarId := mvarId.withContext do let type ← h.getType @@ -147,10 +176,10 @@ partial def subst (mvarId : MVarId) (h : FVarId) : MetaM MVarId := | some _ => let (h', mvarId') ← heqToEq mvarId h if mvarId == mvarId' then - findEq mvarId h + substVar mvarId h else subst mvarId' h' - | none => findEq mvarId h + | none => substVar mvarId h where /-- Give `h : Eq α a b`, try to apply `substCore` -/ substEq (mvarId : MVarId) (h : FVarId) : MetaM MVarId := mvarId.withContext do @@ -177,30 +206,12 @@ where else do throwTacticEx `subst mvarId m!"invalid equality proof, it is not of the form (x = t) or (t = x){indentExpr localDecl.type}" - /-- Try to find an equation of the form `heq : h = rhs` or `heq : lhs = h` -/ - findEq (mvarId : MVarId) (h : FVarId) : MetaM MVarId := mvarId.withContext do - let localDecl ← h.getDecl - if localDecl.isLet then - throwTacticEx `subst mvarId m!"variable '{mkFVar h}' is a let-declaration" - let lctx ← getLCtx - let some (fvarId, symm) ← lctx.findDeclM? fun localDecl => do - if localDecl.isImplementationDetail then - return none - else - match (← matchEq? localDecl.type) with - | some (_, lhs, rhs) => - let lhs ← instantiateMVars lhs - let rhs ← instantiateMVars rhs - if rhs.isFVar && rhs.fvarId! == h then - if !(← exprDependsOn lhs h) then - return some (localDecl.fvarId, true) - if lhs.isFVar && lhs.fvarId! == h then - if !(← exprDependsOn rhs h) then - return some (localDecl.fvarId, false) - return none - | _ => return none - | throwTacticEx `subst mvarId m!"did not find equation for eliminating '{mkFVar h}'" - return (← substCore mvarId fvarId (symm := symm) (tryToSkip := true)).2 +/-- +Given `x`, try to find an equation of the form `heq : x = rhs` or `heq : lhs = x`, +and runs `substCore` on it. Returns `none` if no such equation is found, or if `substCore` fails. +-/ +def substVar? (mvarId : MVarId) (hFVarId : FVarId) : MetaM (Option MVarId) := + observing? (substVar mvarId hFVarId) def subst? (mvarId : MVarId) (hFVarId : FVarId) : MetaM (Option MVarId) := observing? (subst mvarId hFVarId) @@ -208,10 +219,11 @@ def subst? (mvarId : MVarId) (hFVarId : FVarId) : MetaM (Option MVarId) := def substCore? (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst : FVarSubst := {}) (clearH := true) (tryToSkip := false) : MetaM (Option (FVarSubst × MVarId)) := observing? (substCore mvarId hFVarId symm fvarSubst clearH tryToSkip) +def trySubstVar (mvarId : MVarId) (hFVarId : FVarId) : MetaM MVarId := do + return (← substVar? mvarId hFVarId).getD mvarId + def trySubst (mvarId : MVarId) (hFVarId : FVarId) : MetaM MVarId := do - match (← subst? mvarId hFVarId) with - | some mvarId => return mvarId - | none => return mvarId + return (← subst? mvarId hFVarId).getD mvarId def substSomeVar? (mvarId : MVarId) : MetaM (Option MVarId) := mvarId.withContext do for localDecl in (← getLCtx) do diff --git a/tests/lean/run/funind_tests.lean b/tests/lean/run/funind_tests.lean index fe0f08128c..3e82eb99cf 100644 --- a/tests/lean/run/funind_tests.lean +++ b/tests/lean/run/funind_tests.lean @@ -837,3 +837,30 @@ derive_functional_induction takeWhile -- Cryptic error message derive_functional_induction takeWhile.foo end Errors + +namespace PreserveParams + +/- +Tests that cleaning up the goal state does not throw away useful equalties +relating varying parameters to fixed ones. +-/ + +def foo (a : Nat) : Nat → Nat + | 0 => 0 + | n+1 => + if a = 23 then 23 else + if a = n then 42 else + foo a n +termination_by n => n +derive_functional_induction foo + +/-- +info: PreserveParams.foo.induct (a : Nat) (motive : Nat → Prop) (case1 : motive 0) + (case2 : ∀ (n : Nat), a = 23 → motive (Nat.succ n)) (case3 : ¬a = 23 → motive (Nat.succ a)) + (case4 : ∀ (n : Nat), ¬a = 23 → ¬a = n → motive n → motive (Nat.succ n)) (x : Nat) : motive x +-/ +#guard_msgs in +#check foo.induct + + +end PreserveParams