fix: substVars in functional inductions removed valuable information (#3695)

using the `substVars` tactic on the goal can remove too much
information, as it does not take into account that the `motive` may
depend on the fixed parameters.

This is fixed by etracting `substVar` from `subst` which expects the
`x`, not the `h : x = rhs`, and then using this tactic on the local
declarations _after_ the `motive` exclusively.
This commit is contained in:
Joachim Breitner 2024-03-16 15:55:31 +01:00 committed by GitHub
parent 4c57da4b0f
commit 0b01ceb3bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 85 additions and 30 deletions

View file

@ -371,6 +371,22 @@ def assertIHs (vals : Array Expr) (mvarid : MVarId) : MetaM MVarId := do
mvarid ← mvarid.assert s!"ih{i+1}" (← inferType v) v
return mvarid
/--
Substitutes equations, but makes sure to only substitute variables introduced after the motive
as the motive could depend on anything before, and `substVar` would happily drop equations
about these fixed parameters.
-/
def substVarAfter (mvarId : MVarId) (x : FVarId) : MetaM MVarId := do
mvarId.withContext do
let mut mvarId := mvarId
let index := (← x.getDecl).index
for localDecl in (← getLCtx) do
if localDecl.index > index then
mvarId ← trySubstVar mvarId localDecl.fvarId
return mvarId
/-- Base case of `buildInductionBody`: Construct a case for the final induction hypthesis. -/
def buildInductionCase (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve : Array FVarId)
(goal : Expr) (IHs : Array Expr) (e : Expr) : MetaM Expr := do
@ -383,7 +399,6 @@ def buildInductionCase (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve :
for fvarId in toClear do
mvarId ← mvarId.clear fvarId
mvarId ← mvarId.cleanup (toPreserve := toPreserve)
mvarId ← substVars mvarId
let mvar ← instantiateMVars mvar
pure mvar
@ -552,6 +567,7 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do
let e' ← mkLambdaFVars #[params.back] e'
let mvars ← getMVarsNoDelayed e'
let mvars ← mvars.mapM fun mvar => do
let mvar ← substVarAfter mvar motive.fvarId!
let (_, mvar) ← mvar.revertAfter motive.fvarId!
pure mvar
-- Using `mkLambdaFVars` on mvars directly does not reliably replace

View file

@ -138,6 +138,35 @@ def heqToEq (mvarId : MVarId) (fvarId : FVarId) (tryToClear : Bool := true) : Me
else
return (fvarId, mvarId)
/--
Given `x`, try to find an equation of the form `heq : x = rhs` or `heq : lhs = x`,
and runs `substCore` on it. Throws an expection if no such equation is found.
-/
partial def substVar (mvarId : MVarId) (x : FVarId) : MetaM MVarId :=
mvarId.withContext do
let localDecl ← x.getDecl
if localDecl.isLet then
throwTacticEx `subst mvarId m!"variable '{mkFVar x}' is a let-declaration"
let lctx ← getLCtx
let some (fvarId, symm) ← lctx.findDeclM? fun localDecl => do
if localDecl.isImplementationDetail then
return none
else
match (← matchEq? localDecl.type) with
| some (_, lhs, rhs) =>
let lhs ← instantiateMVars lhs
let rhs ← instantiateMVars rhs
if rhs.isFVar && rhs.fvarId! == x then
if !(← exprDependsOn lhs x) then
return some (localDecl.fvarId, true)
if lhs.isFVar && lhs.fvarId! == x then
if !(← exprDependsOn rhs x) then
return some (localDecl.fvarId, false)
return none
| _ => return none
| throwTacticEx `subst mvarId m!"did not find equation for eliminating '{mkFVar x}'"
return (← substCore mvarId fvarId (symm := symm) (tryToSkip := true)).2
partial def subst (mvarId : MVarId) (h : FVarId) : MetaM MVarId :=
mvarId.withContext do
let type ← h.getType
@ -147,10 +176,10 @@ partial def subst (mvarId : MVarId) (h : FVarId) : MetaM MVarId :=
| some _ =>
let (h', mvarId') ← heqToEq mvarId h
if mvarId == mvarId' then
findEq mvarId h
substVar mvarId h
else
subst mvarId' h'
| none => findEq mvarId h
| none => substVar mvarId h
where
/-- Give `h : Eq α a b`, try to apply `substCore` -/
substEq (mvarId : MVarId) (h : FVarId) : MetaM MVarId := mvarId.withContext do
@ -177,30 +206,12 @@ where
else do
throwTacticEx `subst mvarId m!"invalid equality proof, it is not of the form (x = t) or (t = x){indentExpr localDecl.type}"
/-- Try to find an equation of the form `heq : h = rhs` or `heq : lhs = h` -/
findEq (mvarId : MVarId) (h : FVarId) : MetaM MVarId := mvarId.withContext do
let localDecl ← h.getDecl
if localDecl.isLet then
throwTacticEx `subst mvarId m!"variable '{mkFVar h}' is a let-declaration"
let lctx ← getLCtx
let some (fvarId, symm) ← lctx.findDeclM? fun localDecl => do
if localDecl.isImplementationDetail then
return none
else
match (← matchEq? localDecl.type) with
| some (_, lhs, rhs) =>
let lhs ← instantiateMVars lhs
let rhs ← instantiateMVars rhs
if rhs.isFVar && rhs.fvarId! == h then
if !(← exprDependsOn lhs h) then
return some (localDecl.fvarId, true)
if lhs.isFVar && lhs.fvarId! == h then
if !(← exprDependsOn rhs h) then
return some (localDecl.fvarId, false)
return none
| _ => return none
| throwTacticEx `subst mvarId m!"did not find equation for eliminating '{mkFVar h}'"
return (← substCore mvarId fvarId (symm := symm) (tryToSkip := true)).2
/--
Given `x`, try to find an equation of the form `heq : x = rhs` or `heq : lhs = x`,
and runs `substCore` on it. Returns `none` if no such equation is found, or if `substCore` fails.
-/
def substVar? (mvarId : MVarId) (hFVarId : FVarId) : MetaM (Option MVarId) :=
observing? (substVar mvarId hFVarId)
def subst? (mvarId : MVarId) (hFVarId : FVarId) : MetaM (Option MVarId) :=
observing? (subst mvarId hFVarId)
@ -208,10 +219,11 @@ def subst? (mvarId : MVarId) (hFVarId : FVarId) : MetaM (Option MVarId) :=
def substCore? (mvarId : MVarId) (hFVarId : FVarId) (symm := false) (fvarSubst : FVarSubst := {}) (clearH := true) (tryToSkip := false) : MetaM (Option (FVarSubst × MVarId)) :=
observing? (substCore mvarId hFVarId symm fvarSubst clearH tryToSkip)
def trySubstVar (mvarId : MVarId) (hFVarId : FVarId) : MetaM MVarId := do
return (← substVar? mvarId hFVarId).getD mvarId
def trySubst (mvarId : MVarId) (hFVarId : FVarId) : MetaM MVarId := do
match (← subst? mvarId hFVarId) with
| some mvarId => return mvarId
| none => return mvarId
return (← subst? mvarId hFVarId).getD mvarId
def substSomeVar? (mvarId : MVarId) : MetaM (Option MVarId) := mvarId.withContext do
for localDecl in (← getLCtx) do

View file

@ -837,3 +837,30 @@ derive_functional_induction takeWhile -- Cryptic error message
derive_functional_induction takeWhile.foo
end Errors
namespace PreserveParams
/-
Tests that cleaning up the goal state does not throw away useful equalties
relating varying parameters to fixed ones.
-/
def foo (a : Nat) : Nat → Nat
| 0 => 0
| n+1 =>
if a = 23 then 23 else
if a = n then 42 else
foo a n
termination_by n => n
derive_functional_induction foo
/--
info: PreserveParams.foo.induct (a : Nat) (motive : Nat → Prop) (case1 : motive 0)
(case2 : ∀ (n : Nat), a = 23 → motive (Nat.succ n)) (case3 : ¬a = 23 → motive (Nat.succ a))
(case4 : ∀ (n : Nat), ¬a = 23 → ¬a = n → motive n → motive (Nat.succ n)) (x : Nat) : motive x
-/
#guard_msgs in
#check foo.induct
end PreserveParams