fix: functional induction: preseve order of cases better (#3762)

by passing an explicit array of metavariable around, instead of relying
on `getMVarsNoDelayed`, which may return them in unexpected order.
This commit is contained in:
Joachim Breitner 2024-03-25 12:59:29 +01:00 committed by GitHub
parent 3dd811f9ad
commit e0c6c5d226
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 105 additions and 67 deletions

View file

@ -176,16 +176,32 @@ def arrowDomainsN (n : Nat) (type : Expr) : MetaM (Array Expr) := do
type := β
return ts
private def withUserNamesImpl {α} (fvars : Array Expr) (names : Array Name) (k : MetaM α) : MetaM α := do
let lctx := (Array.zip fvars names).foldl (init := ← (getLCtx)) fun lctx (fvar, name) =>
lctx.setUserName fvar.fvarId! name
withTheReader Meta.Context (fun ctx => { ctx with lctx }) k
/--
Sets the user name of the FVars in the local context according to the given array of names.
If they differ in size the shorter size wins.
-/
def withUserNames {α} (fvars : Array Expr) (names : Array Name) (k : MetaM α ) : MetaM α := do
let lctx := (Array.zip fvars names).foldl (init := ← (getLCtx)) fun lctx (fvar, name) =>
lctx.setUserName fvar.fvarId! name
withTheReader Meta.Context (fun ctx => { ctx with lctx }) k
def withUserNames {n} [MonadControlT MetaM n] [Monad n]
{α} (fvars : Array Expr) (names : Array Name) (k : n α) : n α := do
mapMetaM (withUserNamesImpl fvars names) k
/-
`Match.forallAltTelescope` lifted to a monad transformer
(and only passing those arguments that we care about below)
-/
private def forallAltTelescope'
{n} [Monad n] [MonadControlT MetaM n]
{α} (origAltType : Expr) (numParams numDiscrEqs : Nat)
(k : Array Expr → Array Expr → n α) : n α := do
map2MetaM (fun k =>
Match.forallAltTelescope origAltType (numParams - numDiscrEqs) 0
fun ys _eqs args _mask _bodyType => k ys args
) k
/--
Performs a possibly type-changing transformation to a `MatcherApp`.
@ -208,14 +224,17 @@ This function works even if the the type of alternatives do *not* fit the inferr
allows you to post-process the `MatcherApp` with `MatcherApp.inferMatchType`, which will
infer a type, given all the alternatives.
-/
def transform (matcherApp : MatcherApp)
def transform
{n} [MonadLiftT MetaM n] [MonadControlT MetaM n] [Monad n] [MonadError n] [MonadEnv n] [MonadLog n]
[AddMessageContext n] [MonadOptions n]
(matcherApp : MatcherApp)
(useSplitter := false)
(addEqualities : Array Bool := mkArray matcherApp.discrs.size false)
(onParams : Expr → MetaM Expr := pure)
(onMotive : Array Expr → Expr → MetaM Expr := fun _ e => pure e)
(onAlt : Expr → Expr → MetaM Expr := fun _ e => pure e)
(onRemaining : Array Expr → MetaM (Array Expr) := pure) :
MetaM MatcherApp := do
(onParams : Expr → n Expr := pure)
(onMotive : Array Expr → Expr → n Expr := fun _ e => pure e)
(onAlt : Expr → Expr → n Expr := fun _ e => pure e)
(onRemaining : Array Expr → n (Array Expr) := pure) :
n MatcherApp := do
if addEqualities.size != matcherApp.discrs.size then
throwError "MatcherApp.transform: addEqualities has wrong size"
@ -247,7 +266,7 @@ def transform (matcherApp : MatcherApp)
-- Prepend (x = e) → to the motive when an equality is requested
for arg in motiveArgs, discr in discrs', b in addEqualities do if b then
motiveBody' ← mkArrow (← mkEq discr arg) motiveBody'
motiveBody' ← liftMetaM <| mkArrow (← mkEq discr arg) motiveBody'
return (← mkLambdaFVars motiveArgs motiveBody', ← getLevel motiveBody')
@ -290,7 +309,7 @@ def transform (matcherApp : MatcherApp)
splitterNumParams in matchEqns.splitterAltNumParams,
origAltType in origAltTypes,
altType in altTypes do
let alt' ← Match.forallAltTelescope origAltType (numParams - numDiscrEqs) 0 fun ys _eqs args _mask _bodyType => do
let alt' ← forallAltTelescope' origAltType (numParams - numDiscrEqs) 0 fun ys args => do
let altType ← instantiateForall altType ys
-- The splitter inserts its extra paramters after the first ys.size parameters, before
-- the parameters for the numDiscrEqs

View file

@ -166,7 +166,7 @@ open Lean Elab Meta
This is used when replacing parameters with different expressions.
This way it will not be picked up by metavariables.
-/
def removeLamda {α} (e : Expr) (k : FVarId → Expr → MetaM α) : MetaM α := do
def removeLamda {n} [MonadLiftT MetaM n] [MonadError n] [MonadNameGenerator n] [Monad n] {α} (e : Expr) (k : FVarId → Expr → n α) : n α := do
let .lam _n _d b _bi := ← whnfD e | throwError "removeLamda: expected lambda, got {e}"
let x ← mkFreshFVarId
let b := b.instantiate1 (.fvar x)
@ -386,10 +386,15 @@ def substVarAfter (mvarId : MVarId) (x : FVarId) : MetaM MVarId := do
mvarId ← trySubstVar mvarId localDecl.fvarId
return mvarId
/--
Helper monad to traverse the function body, collecting the cases as mvars
-/
abbrev M α := StateT (Array MVarId) MetaM α
/-- Base case of `buildInductionBody`: Construct a case for the final induction hypthesis. -/
def buildInductionCase (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve : Array FVarId)
(goal : Expr) (IHs : Array Expr) (e : Expr) : MetaM Expr := do
(goal : Expr) (IHs : Array Expr) (e : Expr) : M Expr := do
let IHs := IHs ++ (← collectIHs fn oldIH newIH e)
let IHs ← deduplicateIHs IHs
@ -398,7 +403,8 @@ def buildInductionCase (fn : Expr) (oldIH newIH : FVarId) (toClear toPreserve :
mvarId ← assertIHs IHs mvarId
for fvarId in toClear do
mvarId ← mvarId.clear fvarId
_ ← mvarId.cleanup (toPreserve := toPreserve)
mvarId ← mvarId.cleanup (toPreserve := toPreserve)
modify (·.push mvarId)
let mvar ← instantiateMVars mvar
pure mvar
@ -437,8 +443,14 @@ def maskArray {α} (mask : Array Bool) (xs : Array α) : Array α := Id.run do
if b then ys := ys.push x
return ys
/--
Builds an expression of type `goal` by replicating the expression `e` into its tail-call-positions,
where it calls `buildInductionCase`. Collects the cases of the final induction hypothesis
as `MVars` as it goes.
-/
partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId)
(goal : Expr) (oldIH newIH : FVarId) (IHs : Array Expr) (e : Expr) : MetaM Expr := do
(goal : Expr) (oldIH newIH : FVarId) (IHs : Array Expr) (e : Expr) : M Expr := do
-- logInfo m!"buildInductionBody {e}"
if e.isDIte then
let #[_α, c, h, t, f] := e.getAppArgs | unreachable!
@ -459,7 +471,7 @@ partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId)
if let some matcherApp ← matchMatcherApp? e (alsoCasesOn := true) then
-- Collect IHs from the parameters and discrs of the matcher
let paramsAndDiscrs := matcherApp.params ++ matcherApp.discrs
let IHs := IHs ++ (← paramsAndDiscrs.concatMapM (collectIHs fn oldIH newIH))
let IHs := IHs ++ (← paramsAndDiscrs.concatMapM (collectIHs fn oldIH newIH ·))
-- Calculate motive
let eType ← newIH.getType
@ -471,7 +483,7 @@ partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId)
if matcherApp.remaining.size == 1 && matcherApp.remaining[0]!.isFVarOf oldIH then
let matcherApp' ← matcherApp.transform (useSplitter := true)
(addEqualities := mask.map not)
(onParams := foldCalls fn oldIH)
(onParams := (foldCalls fn oldIH ·))
(onMotive := fun xs _body => pure (absMotiveBody.beta (maskArray mask xs)))
(onAlt := fun expAltType alt => do
removeLamda alt fun oldIH' alt => do
@ -490,7 +502,7 @@ partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId)
let matcherApp' ← matcherApp.transform (useSplitter := true)
(addEqualities := mask.map not)
(onParams := foldCalls fn oldIH)
(onParams := (foldCalls fn oldIH ·))
(onMotive := fun xs _body => pure (absMotiveBody.beta (maskArray mask xs)))
(onAlt := fun expAltType alt => do
buildInductionBody fn toClear toPreserve expAltType oldIH newIH IHs alt)
@ -514,6 +526,35 @@ partial def buildInductionBody (fn : Expr) (toClear toPreserve : Array FVarId)
buildInductionCase fn oldIH newIH toClear toPreserve goal IHs e
/--
Given an expression `e` with metavariables
* collects all these meta-variables,
* lifts them to the current context by reverting all local declarations up to `x`
* introducing a local variable for each of the meta variable
* assigning that local variable to the mvar
* and finally lambda-abstracting over these new local variables.
This operation only works if the metavariables are independent from each other.
The resulting meta variable assignment is no longer valid (mentions out-of-scope
variables), so after this operations, terms that still mention these meta variables must not
be used anymore.
We are not using `mkLambdaFVars` on mvars directly, nor `abstractMVars`, as these at the moment
do not handle delayed assignemnts correctly.
-/
def abstractIndependentMVars (mvars : Array MVarId) (x : FVarId) (e : Expr) : MetaM Expr := do
let mvars ← mvars.mapM fun mvar => do
let mvar ← substVarAfter mvar x
let (_, mvar) ← mvar.revertAfter x
pure mvar
let decls := mvars.mapIdx fun i mvar =>
(.mkSimple s!"case{i.val+1}", .default, (fun _ => mvar.getType))
Meta.withLocalDecls decls fun xs => do
for mvar in mvars, x in xs do
mvar.assign x
mkLambdaFVars xs (← instantiateMVars e)
partial def findFixF {α} (name : Name) (e : Expr) (k : Array Expr → Expr → MetaM α) : MetaM α := do
lambdaTelescope e fun params body => do
if body.isAppOf ``WellFounded.fixF then
@ -552,7 +593,7 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do
let e' := mkApp3 (.const ``WellFounded.fixF [argLevel, levelZero]) argType rel motive
let fn := mkAppN (.const name (info.levelParams.map mkLevelParam)) params.pop
let body' ← forallTelescope (← inferType e').bindingDomain! fun xs _ => do
let (body', mvars) ← StateT.run (s := {}) <| forallTelescope (← inferType e').bindingDomain! fun xs _ => do
let #[param, genIH] := xs | unreachable!
-- open body with the same arg
let body ← instantiateLambda body #[param]
@ -565,27 +606,8 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do
let e' := mkApp3 e' body' arg acc
let e' ← mkLambdaFVars #[params.back] e'
let mvars ← getMVarsNoDelayed e'
let mvars ← mvars.mapM fun mvar => do
let mvar ← substVarAfter mvar motive.fvarId!
let (_, mvar) ← mvar.revertAfter motive.fvarId!
pure mvar
-- Using `mkLambdaFVars` on mvars directly does not reliably replace
-- the mvars with the parameter, in the presence of delayed assignemnts.
-- Also `abstractMVars` does not handle delayed assignments correctly (as of now).
-- So instead we bring suitable fvars into scope and use `assign`; this handles
-- delayed assignemnts correctly.
-- NB: This idiom only works because
-- * we know that the `MVars` have the right local context (thanks to `mvarId.revertAfter`)
-- * the MVars are independent (so we dont need to reorder them)
-- * we do no need the mvars in their unassigned form later
let e' ← Meta.withLocalDecls
(mvars.mapIdx (fun i mv => (.mkSimple s!"case{i.val+1}", .default, (fun _ => mv.getType))))
fun xs => do
for mvar in mvars, x in xs do
mvar.assign x
let e' ← instantiateMVars e'
mkLambdaFVars xs e'
let e' ← abstractIndependentMVars mvars motive.fvarId! e'
let e' ← mkLambdaFVars #[motive] e'
-- We could pass (usedOnly := true) below, and get nicer induction principles that
-- do do not mention odd unused parameters.
@ -593,7 +615,7 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do
-- that derives them from an function application in the goal) is harder, as
-- one would have to infer or keep track of which parameters to pass.
-- So for now lets just keep them around.
let e' ← mkLambdaFVars (binderInfoForMVars := .default) (params.pop ++ #[motive]) e'
let e' ← mkLambdaFVars (binderInfoForMVars := .default) params.pop e'
let e' ← instantiateMVars e'
let eTyp ← inferType e'

View file

@ -40,9 +40,8 @@ end
derive_functional_induction foo
/--
info: Mutual.foo.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0)
(case3 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) :
∀ (a : Nat), motive1 a
info: Mutual.foo.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : ∀ (n : Nat), motive2 n → motive1 n.succ)
(case3 : motive2 0) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : ∀ (a : Nat), motive1 a
-/
#guard_msgs in
#check foo.induct
@ -50,8 +49,8 @@ info: Mutual.foo.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (ca
example : foo n = .unit := by
induction n using foo.induct (motive2 := fun n => bar n = .unit) with
| case1 => unfold foo; rfl
| case2 => unfold bar; rfl
| case3 n ih => unfold foo; exact ih
| case2 n ih => unfold foo; exact ih
| case3 => unfold bar; rfl
| case4 n ih => unfold bar; exact ih
end Mutual

View file

@ -27,10 +27,10 @@ end
derive_functional_induction replaceConst
/--
info: Term.replaceConst.induct (a b : String) (motive1 : Term → Prop) (motive2 : List Term → Prop) (case1 : motive2 [])
(case2 : ∀ (a_1 : String), (a == a_1) = true → motive1 (const a_1))
(case3 : ∀ (a_1 : String), ¬(a == a_1) = true → motive1 (const a_1))
(case4 : ∀ (a : String) (cs : List Term), motive2 cs → motive1 (app a cs))
info: Term.replaceConst.induct (a b : String) (motive1 : Term → Prop) (motive2 : List Term → Prop)
(case1 : ∀ (a_1 : String), (a == a_1) = true → motive1 (const a_1))
(case2 : ∀ (a_1 : String), ¬(a == a_1) = true → motive1 (const a_1))
(case3 : ∀ (a : String) (cs : List Term), motive2 cs → motive1 (app a cs)) (case4 : motive2 [])
(case5 : ∀ (c : Term) (cs : List Term), motive1 c → motive2 cs → motive2 (c :: cs)) : ∀ (a : Term), motive1 a
-/
#guard_msgs in
@ -40,13 +40,13 @@ theorem numConsts_replaceConst (a b : String) (e : Term) : numConsts (replaceCon
apply replaceConst.induct
(motive1 := fun e => numConsts (replaceConst a b e) = numConsts e)
(motive2 := fun es => numConstsLst (replaceConstLst a b es) = numConstsLst es)
case case1 => simp [replaceConstLst, numConstsLst, *]
case case2 => intro c h; guard_hyp h :ₛ (a == c) = true; simp [replaceConst, numConsts, *]
case case3 => intro c h; guard_hyp h :ₛ ¬(a == c) = true; simp [replaceConst, numConsts, *]
case case4 =>
case case1 => intro c h; guard_hyp h :ₛ (a == c) = true; simp [replaceConst, numConsts, *]
case case2 => intro c h; guard_hyp h :ₛ ¬(a == c) = true; simp [replaceConst, numConsts, *]
case case3 =>
intros f cs ih
guard_hyp ih :ₛnumConstsLst (replaceConstLst a b cs) = numConstsLst cs
simp [replaceConst, numConsts, *]
case case4 => simp [replaceConstLst, numConstsLst, *]
case case5 =>
intro c cs ih₁ ih₂
guard_hyp ih₁ :ₛ numConsts (replaceConst a b c) = numConsts c

View file

@ -587,7 +587,7 @@ info: RecCallInDisrs.foo.induct (motive : Nat → Prop) (case1 : motive 0)
def bar : Nat → Nat
| 0 => 0
| n+1 => match h : n, bar n with
| n+1 => match _h : n, bar n with
| 0, 0 => 0
| 0, _ => 1
| m+1, _ => bar m
@ -619,17 +619,15 @@ end
derive_functional_induction even
/--
info: EvenOdd.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0)
(case3 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) :
∀ (a : Nat), motive1 a
info: EvenOdd.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : ∀ (n : Nat), motive2 n → motive1 n.succ)
(case3 : motive2 0) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : ∀ (a : Nat), motive1 a
-/
#guard_msgs in
#check even.induct
/--
info: EvenOdd.odd.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0)
(case3 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) :
∀ (a : Nat), motive2 a
info: EvenOdd.odd.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : ∀ (n : Nat), motive2 n → motive1 n.succ)
(case3 : motive2 0) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : ∀ (a : Nat), motive2 a
-/
#guard_msgs in
#check odd.induct
@ -773,7 +771,7 @@ derive_functional_induction even._mutual
/--
info: CommandIdempotence.even._mutual.induct (motive : Nat ⊕' Nat → Prop) (case1 : motive (PSum.inl 0))
(case2 : motive (PSum.inr 0)) (case3 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl n.succ))
(case2 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl n.succ)) (case3 : motive (PSum.inr 0))
(case4 : ∀ (n : Nat), motive (PSum.inl n) → motive (PSum.inr n.succ)) (x : Nat ⊕' Nat) : motive x
-/
#guard_msgs in
@ -787,16 +785,16 @@ derive_functional_induction even
/--
info: CommandIdempotence.even._mutual.induct (motive : Nat ⊕' Nat → Prop) (case1 : motive (PSum.inl 0))
(case2 : motive (PSum.inr 0)) (case3 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl n.succ))
(case2 : ∀ (n : Nat), motive (PSum.inr n) → motive (PSum.inl n.succ)) (case3 : motive (PSum.inr 0))
(case4 : ∀ (n : Nat), motive (PSum.inl n) → motive (PSum.inr n.succ)) (x : Nat ⊕' Nat) : motive x
-/
#guard_msgs in
#check even._mutual.induct
/--
info: CommandIdempotence.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0) (case2 : motive2 0)
(case3 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) :
∀ (a : Nat), motive1 a
info: CommandIdempotence.even.induct (motive1 motive2 : Nat → Prop) (case1 : motive1 0)
(case2 : ∀ (n : Nat), motive2 n → motive1 n.succ) (case3 : motive2 0)
(case4 : ∀ (n : Nat), motive1 n → motive2 n.succ) : ∀ (a : Nat), motive1 a
-/
#guard_msgs in
#check even.induct