diff --git a/src/Lean/Meta/Match/MatchEqs.lean b/src/Lean/Meta/Match/MatchEqs.lean index 909edfe87c..f9212a366f 100644 --- a/src/Lean/Meta/Match/MatchEqs.lean +++ b/src/Lean/Meta/Match/MatchEqs.lean @@ -80,25 +80,26 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do We use the `mask` to build the splitter proof. See `mkSplitterProof`. -/ -partial def forallAltTelescope (altType : Expr) (numNonEqParams : Nat) +partial def forallAltTelescope (altType : Expr) (altNumParams numDiscrEqs : Nat) (k : (ys : Array Expr) → (eqs : Array Expr) → (args : Array Expr) → (mask : Array Bool) → (type : Expr) → MetaM α) : MetaM α := do go #[] #[] #[] #[] 0 altType where go (ys : Array Expr) (eqs : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do let type ← whnfForall type - match type with - | Expr.forallE n d b .. => - if i < numNonEqParams then + if i < altNumParams then + let Expr.forallE n d b .. := type + | throwError "expecting {altNumParams} parameters, including {numDiscrEqs} equalities, but found type{indentExpr altType}" + if i < altNumParams - numDiscrEqs then let d ← unfoldNamedPattern d withLocalDeclD n d fun y => do let typeNew := b.instantiate1 y if let some (_, lhs, rhs) ← matchEq? d then if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then - let some i := ys.getIdx? lhs | unreachable! - let ys := ys.eraseIdx i - let some j := args.getIdx? lhs | unreachable! - let mask := mask.set! j false + let some j := ys.getIdx? lhs | unreachable! + let ys := ys.eraseIdx j + let some k := args.getIdx? lhs | unreachable! + let mask := mask.set! k false let args := args.map fun arg => if arg == lhs then rhs else arg let args := args.push (← mkEqRefl rhs) let typeNew := typeNew.replaceFVar lhs rhs @@ -114,7 +115,7 @@ where withLocalDeclD n d fun eq => do let typeNew := b.instantiate1 eq go ys (eqs.push eq) (args.push arg) (mask.push false) (i+1) typeNew - | _ => + else let type ← unfoldNamedPattern type /- Recall that alternatives that do not have variables have a `Unit` parameter to ensure they are not eagerly evaluated. -/ @@ -628,10 +629,11 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := let mut altArgMasks := #[] -- masks produced by `forallAltTelescope` for i in [:alts.size] do let altNumParams := matchInfo.altNumParams[i]! - let altNonEqNumParams := altNumParams - numDiscrEqs let thmName := baseName ++ ((`eq).appendIndexAfter idx) eqnNames := eqnNames.push thmName - let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← forallAltTelescope (← inferType alts[i]!) altNonEqNumParams fun ys eqs rhsArgs argMask altResultType => do + let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← + forallAltTelescope (← inferType alts[i]!) altNumParams numDiscrEqs + fun ys eqs rhsArgs argMask altResultType => do let patterns := altResultType.getAppArgs let mut hs := #[] for notAlt in notAlts do