refactor: forallAltTelescope to take altNumParams (#3230)

this way this function does not have to peek at the `altType` to see
when there are no more arguments, which makes it a bit more explicit,
and also a bit more robust should one apply this function to the type of
an alternative with the motive already instantiated.

It seems this uncovered a variable shadow bug, where the counter `i` was
accidentially reset after removing the `i`’th entry in `ys`.
This commit is contained in:
Joachim Breitner 2024-01-31 12:03:03 +01:00 committed by GitHub
parent 456e435fe0
commit 279607f5f8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -80,25 +80,26 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do
We use the `mask` to build the splitter proof. See `mkSplitterProof`.
-/
partial def forallAltTelescope (altType : Expr) (numNonEqParams : Nat)
partial def forallAltTelescope (altType : Expr) (altNumParams numDiscrEqs : Nat)
(k : (ys : Array Expr) → (eqs : Array Expr) → (args : Array Expr) → (mask : Array Bool) → (type : Expr) → MetaM α)
: MetaM α := do
go #[] #[] #[] #[] 0 altType
where
go (ys : Array Expr) (eqs : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do
let type ← whnfForall type
match type with
| Expr.forallE n d b .. =>
if i < numNonEqParams then
if i < altNumParams then
let Expr.forallE n d b .. := type
| throwError "expecting {altNumParams} parameters, including {numDiscrEqs} equalities, but found type{indentExpr altType}"
if i < altNumParams - numDiscrEqs then
let d ← unfoldNamedPattern d
withLocalDeclD n d fun y => do
let typeNew := b.instantiate1 y
if let some (_, lhs, rhs) ← matchEq? d then
if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then
let some i := ys.getIdx? lhs | unreachable!
let ys := ys.eraseIdx i
let some j := args.getIdx? lhs | unreachable!
let mask := mask.set! j false
let some j := ys.getIdx? lhs | unreachable!
let ys := ys.eraseIdx j
let some k := args.getIdx? lhs | unreachable!
let mask := mask.set! k false
let args := args.map fun arg => if arg == lhs then rhs else arg
let args := args.push (← mkEqRefl rhs)
let typeNew := typeNew.replaceFVar lhs rhs
@ -114,7 +115,7 @@ where
withLocalDeclD n d fun eq => do
let typeNew := b.instantiate1 eq
go ys (eqs.push eq) (args.push arg) (mask.push false) (i+1) typeNew
| _ =>
else
let type ← unfoldNamedPattern type
/- Recall that alternatives that do not have variables have a `Unit` parameter to ensure
they are not eagerly evaluated. -/
@ -628,10 +629,11 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns :=
let mut altArgMasks := #[] -- masks produced by `forallAltTelescope`
for i in [:alts.size] do
let altNumParams := matchInfo.altNumParams[i]!
let altNonEqNumParams := altNumParams - numDiscrEqs
let thmName := baseName ++ ((`eq).appendIndexAfter idx)
eqnNames := eqnNames.push thmName
let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← forallAltTelescope (← inferType alts[i]!) altNonEqNumParams fun ys eqs rhsArgs argMask altResultType => do
let (notAlt, splitterAltType, splitterAltNumParam, argMask) ←
forallAltTelescope (← inferType alts[i]!) altNumParams numDiscrEqs
fun ys eqs rhsArgs argMask altResultType => do
let patterns := altResultType.getAppArgs
let mut hs := #[]
for notAlt in notAlts do