diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index 50b026574c..73b7395ef2 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -804,22 +804,25 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := unusedAltIdxs := unusedAltIdxs.reverse, addMatcher } -def withMkMatcherInput (matcherName : Name) (k : MkMatcherInput → MetaM α) : MetaM α := do +def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput := do + let matcherName := matcherApp.matcherName let some matcherInfo ← getMatcherInfo? matcherName | throwError "not a matcher: {matcherName}" let matcherConst ← getConstInfo matcherName - forallBoundedTelescope matcherConst.type (some matcherInfo.arity) fun xs t => do - let params := xs[:matcherInfo.numParams] - let motive := xs[matcherInfo.numParams] - let discrs := xs[matcherInfo.numParams + 1:matcherInfo.numParams + 1 + matcherInfo.numDiscrs] - let alts := xs[matcherInfo.numParams + 1 + matcherInfo.numDiscrs:] - let u := - if let some idx := matcherInfo.uElimPos? - then mkLevelParam matcherConst.levelParams.toArray[idx] - else levelZero - let matchType ← mkForallFVars discrs (mkConst ``PUnit [u]) - let lhss ← alts.toArray.mapIdxM fun idx t => do - let ty ← inferType t - forallTelescope ty fun xs body => do + let matcherType ← instantiateForall matcherConst.type $ matcherApp.params ++ #[matcherApp.motive] + let matchType ← do + let u := + if let some idx := matcherInfo.uElimPos? + then mkLevelParam matcherConst.levelParams.toArray[idx] + else levelZero + + forallBoundedTelescope matcherType (some matcherInfo.numDiscrs) fun discrs t => do + mkForallFVars discrs (mkConst ``PUnit [u]) + + let matcherType ← instantiateForall matcherType matcherApp.discrs + let lhss ← Array.toList $ ←forallBoundedTelescope matcherType (some matcherApp.alts.size) fun alts _ => + alts.mapM fun alt => do + let ty ← inferType alt + forallTelescope ty fun xs body => do let xs ← xs.filterM fun x => dependsOn body x.fvarId! body.withApp fun f args => do let ctx ← getLCtx @@ -829,7 +832,21 @@ def withMkMatcherInput (matcherName : Name) (k : MkMatcherInput → MetaM α) : ref := Syntax.missing fvarDecls := localDecls.toList patterns := patterns.toList : Match.AltLHS } - k { matcherName, matchType, numDiscrs := matcherInfo.numDiscrs, lhss := lhss.toList } + + return { matcherName, matchType, numDiscrs := matcherApp.discrs.size, lhss } + + +def withMkMatcherInput + (matcherName : Name) + (k : MkMatcherInput → MetaM α) : MetaM α := do + let some matcherInfo ← getMatcherInfo? matcherName | throwError "not a matcher: {matcherName}" + let matcherConst ← getConstInfo matcherName + forallBoundedTelescope matcherConst.type (some matcherInfo.arity) fun xs t => do + let matcherApp ← mkConstWithLevelParams matcherConst.name + let matcherApp := mkAppN matcherApp xs + let some matcherApp ← matchMatcherApp? matcherApp | throwError "not a matcher app: {matcherApp}" + let mkMatcherInput ← getMkMatcherInputInContext matcherApp + k mkMatcherInput end Match