refactor: use match compilation to generate splitter (#11220)

This PR changes how match splitters are generated: Rather than rewriting
the match statement, the match compilation pipeline is used again.


The benefits are:

* Re-doing the match compilation means we can do more intelligent book
keeping, e.g. prove overlap assumptions only once and re-use the proof,
or prune the context of the MVar to speed up `contradiction`. This may
have allowed a different solution than #11200.
 
* It would unblock #11105, as the existing splitter implementation would
have trouble dealing with the matchers produced that way.
 
* It provides the necessary machinery also for source-exposed “none of
the above” bindings, a feature that we probably want at some point (and
we mostly need to find good syntax for, see #3136, although maybe I
should open a dedicated RFC).

* It allows us to skip costly things during matcher creation that would
only be useful for the splitter, and thus allows performance
improvements like #11508.
 
 * We can drop the existing implementation.
 
It’s not entirely free:

* We have to run `simpH` twice, once for the match equations and once
for the splitter.
This commit is contained in:
Joachim Breitner 2025-12-04 16:03:13 +01:00 committed by GitHub
parent 31d629cb67
commit af6d2077a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 178 additions and 314 deletions

View file

@ -8,10 +8,12 @@ module
prelude
public import Lean.Elab.PreDefinition.FixedParams
import Lean.Elab.PreDefinition.EqnsUtils
import Lean.Meta.Tactic.Split
import Lean.Meta.Tactic.CasesOnStuckLHS
import Lean.Meta.Tactic.Delta
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.Tactic.Delta
import Lean.Meta.Tactic.CasesOnStuckLHS
import Lean.Meta.Tactic.Split
namespace Lean.Elab
open Meta

View file

@ -213,7 +213,9 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit :=
numDiscrs := info.numIndices + 3
altInfos
uElimPos? := some 0
discrInfos := #[{}, {}, {}]}
discrInfos := #[{}, {}, {}]
overlaps := {}
}
-- Compare attributes with `mkMatcherAuxDefinition`
withExporting (isExporting := !isPrivateName declName) do

View file

@ -319,7 +319,7 @@ public partial def mkBelowMatcher (matcherApp : MatcherApp) (belowParams : Array
(ctx : RecursionContext) (transformAlt : RecursionContext → Expr → MetaM Expr) :
MetaM (Option (Expr × MetaM Unit)) :=
withTraceNode `Meta.IndPredBelow.match (return m!"{exceptEmoji ·} {matcherApp.toExpr} and {belowParams}") do
let mut input ← getMkMatcherInputInContext matcherApp
let mut input ← getMkMatcherInputInContext matcherApp (unfoldNamed := false)
let mut discrs := matcherApp.discrs
let mut matchTypeAdd := #[] -- #[(discrIdx, ), ...]
let mut i := discrs.size

View file

@ -150,6 +150,11 @@ structure Alt where
After we perform additional case analysis, their types become definitionally equal.
-/
cnstrs : List (Expr × Expr)
/--
Indices of previous alternatives that this alternative expects a not-that-proofs.
(When producing a splitter, and in the future also for source-level overlap hypotheses.)
-/
notAltIdxs : Array Nat
deriving Inhabited
namespace Alt

View file

@ -12,7 +12,11 @@ public import Lean.Meta.GeneralizeTelescope
public import Lean.Meta.Match.Basic
public import Lean.Meta.Match.MatcherApp.Basic
public import Lean.Meta.Match.MVarRenaming
public import Lean.Meta.Match.MVarRenaming
import Lean.Meta.Match.SimpH
import Lean.Meta.Match.SolveOverlap
import Lean.Meta.HasNotBit
import Lean.Meta.Match.NamedPatterns
public section
@ -92,34 +96,62 @@ where
/-- Given a list of `AltLHS`, create a minor premise for each one, convert them into `Alt`, and then execute `k` -/
private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo)
(lhss : List AltLHS) (k : List Alt → Array Expr → Array AltParamInfo → MetaM α) : MetaM α :=
loop lhss [] #[] #[]
(lhss : List AltLHS) (isSplitter : Option Overlaps)
(k : List Alt → Array Expr → Array AltParamInfo → MetaM α) : MetaM α :=
loop lhss [] #[] #[] #[]
where
mkMinorType (xs : Array Expr) (lhs : AltLHS) : MetaM Expr :=
mkSplitterHyps (idx : Nat) (lhs : AltLHS) (notAlts : Array Expr) : MetaM (Array Expr × Array Nat) := do
withExistingLocalDecls lhs.fvarDecls do
let patterns ← lhs.patterns.toArray.mapM (Pattern.toExpr · (annotate := true))
let mut hs := #[]
let mut notAltIdxs := #[]
for overlappingIdx in isSplitter.get!.overlapping idx do
let notAlt := notAlts[overlappingIdx]!
let h ← instantiateForall notAlt patterns
if let some h ← simpH? h patterns.size then
notAltIdxs := notAltIdxs.push overlappingIdx
hs := hs.push h
trace[Meta.Match.debug] "hs for {lhs.ref}: {hs}"
return (hs, notAltIdxs)
mkMinorType (xs : Array Expr) (lhs : AltLHS) (notAltHs : Array Expr): MetaM Expr :=
withExistingLocalDecls lhs.fvarDecls do
let args ← lhs.patterns.toArray.mapM (Pattern.toExpr · (annotate := true))
let minorType := mkAppN motive args
withEqs discrs args discrInfos fun eqs => do
mkForallFVars (xs ++ eqs) minorType
let minorType ← mkForallFVars eqs minorType
let minorType ← mkArrowN notAltHs minorType
mkForallFVars xs minorType
loop (lhss : List AltLHS) (alts : List Alt) (minors : Array Expr) (altInfos : Array AltParamInfo) : MetaM α := do
mkNotAlt (xs : Array Expr) (lhs : AltLHS) : MetaM Expr := do
withExistingLocalDecls lhs.fvarDecls do
let mut notAlt := mkConst ``False
for discr in discrs.reverse, pattern in lhs.patterns.reverse do
notAlt ← mkArrow (← mkEqHEq discr (← pattern.toExpr)) notAlt
notAlt ← mkForallFVars (discrs ++ xs) notAlt
return notAlt
loop (lhss : List AltLHS) (alts : List Alt) (minors : Array Expr) (altInfos : Array AltParamInfo) (notAlts : Array Expr) : MetaM α := do
match lhss with
| [] => k alts.reverse minors altInfos
| lhs::lhss =>
let xs := lhs.fvarDecls.toArray.map LocalDecl.toExpr
let minorType ← mkMinorType xs lhs
let hasParams := !xs.isEmpty || discrInfos.any fun info => info.hName?.isSome
let minorType := if hasParams then minorType else mkSimpleThunkType minorType
let idx := alts.length
let xs := lhs.fvarDecls.toArray.map LocalDecl.toExpr
let (notAltHs, notAltIdxs) ← if isSplitter.isSome then mkSplitterHyps idx lhs notAlts else pure (#[], #[])
let minorType ← mkMinorType xs lhs notAltHs
let notAlt ← mkNotAlt xs lhs
let hasParams := !xs.isEmpty || !notAltHs.isEmpty || discrInfos.any fun info => info.hName?.isSome
let minorType := if hasParams then minorType else mkSimpleThunkType minorType
let minorName := (`h).appendIndexAfter (idx+1)
trace[Meta.Match.debug] "minor premise {minorName} : {minorType}"
withLocalDeclD minorName minorType fun minor => do
let rhs := if hasParams then mkAppN minor xs else mkApp minor (mkConst `Unit.unit)
let minors := minors.push minor
let altInfos := altInfos.push { numFields := xs.size, numOverlaps := 0, hasUnitThunk := !hasParams }
let altInfos := altInfos.push { numFields := xs.size, numOverlaps := notAltHs.size, hasUnitThunk := !hasParams }
let fvarDecls ← lhs.fvarDecls.mapM instantiateLocalDeclMVars
let alts := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns, cnstrs := [] } :: alts
loop lhss alts minors altInfos
let alt := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns, cnstrs := [], notAltIdxs := notAltIdxs }
let alts := alt :: alts
loop lhss alts minors altInfos (notAlts.push notAlt)
structure State where
/-- Used alternatives -/
@ -338,7 +370,7 @@ where
return (p, (lhs, rhs) :: cnstrs)
/--
Solve pending alternative constraints.
Solve pending alternative constraints and overlap assumptions.
If all constraints can be solved perform assignment `mvarId := alt.rhs`, else throw error.
-/
private partial def solveCnstrs (mvarId : MVarId) (alt : Alt) : StateRefT State MetaM Unit := do
@ -350,13 +382,19 @@ where
| none =>
let alt ← filterTrivialCnstrs alt
if alt.cnstrs.isEmpty then
let eType ← inferType alt.rhs
let targetType ← mvarId.getType
unless (← isDefEqGuarded targetType eType) do
trace[Meta.Match.match] "assignGoalOf failed {eType} =?= {targetType}"
throwErrorAt alt.ref "Dependent elimination failed: Type mismatch when solving this alternative: it {← mkHasTypeButIsExpectedMsg eType targetType}"
mvarId.assign alt.rhs
modify fun s => { s with used := s.used.insert alt.idx }
mvarId.withContext do
let eType ← inferType alt.rhs
let (notAltsMVarIds, _, eType) ← forallMetaBoundedTelescope eType alt.notAltIdxs.size
unless notAltsMVarIds.size = alt.notAltIdxs.size do
throwErrorAt alt.ref "Incorrect number of overlap hypotheses in the right-hand-side, expected {alt.notAltIdxs.size}:{indentExpr eType}"
let targetType ← mvarId.getType
unless (← isDefEqGuarded targetType eType) do
trace[Meta.Match.match] "assignGoalOf failed {eType} =?= {targetType}"
throwErrorAt alt.ref "Dependent elimination failed: Type mismatch when solving this alternative: it {← mkHasTypeButIsExpectedMsg eType targetType}"
for notAltMVarId in notAltsMVarIds do
solveOverlap notAltMVarId.mvarId!
mvarId.assign (mkAppN alt.rhs notAltsMVarIds)
modify fun s => { s with used := s.used.insert alt.idx }
else
trace[Meta.Match.match] "alt has unsolved cnstrs:\n{← alt.toMessageData}"
let mut msg := m!"Dependent match elimination failed: Could not solve constraints"
@ -636,7 +674,7 @@ private def processConstructor (p : Problem) : MetaM (Array Problem) := do
| .var _ :: _ => expandVarIntoCtor alt ctorName
| .inaccessible _ :: _ => processInaccessibleAsCtor alt ctorName
| _ => unreachable!
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
return { p with mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
else
-- A catch-all case
let subst := subgoal.subst
@ -647,7 +685,7 @@ private def processConstructor (p : Problem) : MetaM (Array Problem) := do
| .ctor .. :: _ => false
| _ => true
let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
return { mvarId := subgoal.mvarId, alts := newAlts, vars := newVars, examples := examples }
return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := newVars, examples := examples }
private def processNonVariable (p : Problem) : MetaM Problem := withGoalOf p do
let x :: xs := p.vars | unreachable!
@ -708,7 +746,7 @@ private def processValue (p : Problem) : MetaM (Array Problem) := do
alt.replaceFVarId fvarId value
| _ => unreachable!
let newVars := xs.map fun x => x.applyFVarSubst subst
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
return { p with mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
else
-- else branch for value
let newAlts := p.alts.filter isFirstPatternVar
@ -764,7 +802,7 @@ private def processArrayLit (p : Problem) : MetaM (Array Problem) := do
let α ← getArrayArgType <| subst.apply x
expandVarIntoArrayLit { alt with patterns := ps } fvarId α size
| _ => unreachable!
return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
return { p with mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
else
-- else branch
let newAlts := p.alts.filter isFirstPatternVar
@ -1018,7 +1056,7 @@ private builtin_initialize matcherExt : EnvExtension (PHashMap MatcherKey Name)
/-- Similar to `mkAuxDefinition`, but uses the cache `matcherExt`.
It also returns an Boolean that indicates whether a new matcher function was added to the environment or not. -/
def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (Expr × Option (MatcherInfo → MetaM Unit)) := do
def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) (isSplitter : Bool) : MetaM (Expr × Option (MatcherInfo → MetaM Unit)) := do
trace[Meta.Match.debug] "{name} : {type} := {value}"
let compile := bootstrap.genMatcherCode.get (← getOptions)
let result ← Closure.mkValueTypeClosure type value (zetaDelta := false)
@ -1026,10 +1064,12 @@ def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (E
let mkMatcherConst name :=
mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
let key := { value := result.value, compile, isPrivate := env.header.isModule && isPrivateName name }
let mut nameNew? := (matcherExt.getState env).find? key
if nameNew?.isNone && key.isPrivate then
-- private contexts may reuse public matchers
nameNew? := (matcherExt.getState env).find? { key with isPrivate := false }
let mut nameNew? := none
unless isSplitter do
nameNew? := (matcherExt.getState env).find? key
if nameNew?.isNone && key.isPrivate then
-- private contexts may reuse public matchers
nameNew? := (matcherExt.getState env).find? { key with isPrivate := false }
match nameNew? with
| some nameNew => return (mkMatcherConst nameNew, none)
| none =>
@ -1040,8 +1080,9 @@ def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (E
-- matcher bodies should always be exported, if not private anyway
withExporting do
addDecl decl
modifyEnv fun env => matcherExt.modifyState env fun s => s.insert key name
addMatcherInfo name mi
unless isSplitter do
modifyEnv fun env => matcherExt.modifyState env fun s => s.insert key name
addMatcherInfo name mi
setInlineAttribute name
enableRealizationsForConst name
if compile then
@ -1053,6 +1094,7 @@ structure MkMatcherInput where
matchType : Expr
discrInfos : Array DiscrInfo
lhss : List AltLHS
isSplitter : Option Overlaps := none
def MkMatcherInput.numDiscrs (m : MkMatcherInput) :=
m.discrInfos.size
@ -1093,7 +1135,7 @@ The generated matcher has the structure described at `MatcherInfo`. The motive a
where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition.
-/
def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor input do
let ⟨matcherName, matchType, discrInfos, lhss⟩ := input
let {matcherName, matchType, discrInfos, lhss, isSplitter} := input
let numDiscrs := discrInfos.size
checkNumPatterns numDiscrs lhss
forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do
@ -1116,7 +1158,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
| negSucc n => succ n
```
which is defined **before** `Int.decLt` -/
let (matcher, addMatcher) ← mkMatcherAuxDefinition matcherName type val
let (matcher, addMatcher) ← mkMatcherAuxDefinition matcherName type val (isSplitter := input.isSplitter.isSome)
trace[Meta.Match.debug] "matcher levels: {matcher.getAppFn.constLevels!}, uElim: {uElimGen}"
let uElimPos? ← getUElimPos? matcher.getAppFn.constLevels! uElimGen
discard <| isLevelDefEq uElimGen uElim
@ -1152,7 +1194,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
let isEqMask ← eqs.mapM fun eq => return (← inferType eq).isEq
return (mvarType, isEqMask)
trace[Meta.Match.debug] "target: {mvarType}"
withAlts motive discrs discrInfos lhss fun alts minors altInfos => do
withAlts motive discrs discrInfos lhss isSplitter fun alts minors altInfos => do
let mvar ← mkFreshExprMVar mvarType
trace[Meta.Match.debug] "goal\n{mvar.mvarId!}"
let examples := discrs'.toList.map fun discr => Example.var discr.fvarId!
@ -1176,7 +1218,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
else
let mvarType := mkAppN motive discrs
trace[Meta.Match.debug] "target: {mvarType}"
withAlts motive discrs discrInfos lhss fun alts minors altInfos => do
withAlts motive discrs discrInfos lhss isSplitter fun alts minors altInfos => do
let mvar ← mkFreshExprMVar mvarType
let examples := discrs.toList.map fun discr => Example.var discr.fvarId!
let (_, s) ← (process { mvarId := mvar.mvarId!, vars := discrs.toList, alts := alts, examples := examples }).run {}
@ -1185,7 +1227,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
let val ← mkLambdaFVars args mvar
mkMatcher type val altInfos s
def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput := do
def getMkMatcherInputInContext (matcherApp : MatcherApp) (unfoldNamed : Bool) : MetaM MkMatcherInput := do
let matcherName := matcherApp.matcherName
let some matcherInfo ← getMatcherInfo? matcherName
| throwError "Internal error during match expression elaboration: Could not find a matcher named `{matcherName}`"
@ -1204,6 +1246,7 @@ def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput
let lhss ← forallBoundedTelescope matcherType (some matcherApp.alts.size) fun alts _ =>
alts.mapM fun alt => do
let ty ← inferType alt
let ty ← if unfoldNamed then unfoldNamedPattern ty else pure ty
forallTelescope ty fun xs body => do
let xs ← xs.filterM fun x => dependsOn body x.fvarId!
body.withApp fun _ args => do
@ -1217,18 +1260,17 @@ def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput
return { matcherName, matchType, discrInfos := matcherInfo.discrInfos, lhss := lhss.toList }
/-- This function is only used for testing purposes -/
def withMkMatcherInput (matcherName : Name) (k : MkMatcherInput → MetaM α) : MetaM α := do
def withMkMatcherInput (matcherName : Name) (unfoldNamed : Bool) (k : MkMatcherInput → MetaM α) : MetaM α := do
let some matcherInfo ← getMatcherInfo? matcherName
| throwError "Internal error during match expression elaboration: Could not find a matcher named `{matcherName}`"
| throwError "withMkMatcherInput: {.ofConstName matcherName} is not a matcher"
let matcherConst ← getConstInfo matcherName
forallBoundedTelescope matcherConst.type (some matcherInfo.arity) fun xs _ => do
let matcherApp ← mkConstWithLevelParams matcherConst.name
let matcherApp := mkAppN matcherApp xs
let some matcherApp ← matchMatcherApp? matcherApp
| throwError "Internal error during match expression elaboration: Could not find a matcher app named `{matcherApp}`"
let mkMatcherInput ← getMkMatcherInputInContext matcherApp
k mkMatcherInput
forallBoundedTelescope matcherConst.type matcherInfo.arity fun xs _ => do
let matcherApp ← mkConstWithLevelParams matcherConst.name
let matcherApp := mkAppN matcherApp xs
let some matcherApp ← matchMatcherApp? matcherApp
| throwError "withMkMatcherInput: {.ofConstName matcherName} does not produce a matcher application"
let mkMatcherInput ← getMkMatcherInputInContext matcherApp unfoldNamed
k mkMatcherInput
end Match

View file

@ -110,220 +110,6 @@ where
(throwError "failed to generate equality theorems for `match` expression `{matchDeclName}`\n{MessageData.ofGoal mvarId}")
subgoals.forM (go · (depth+1))
/-- Construct new local declarations `xs` with types `altTypes`, and then execute `f xs` -/
private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr → MetaM α) : MetaM α := do
let rec go (i : Nat) (xs : Array Expr) : MetaM α := do
if h : i < altTypes.size then
let hName := (`h).appendIndexAfter (i+1)
withLocalDeclD hName altTypes[i] fun x =>
go (i+1) (xs.push x)
else
f xs
go 0 #[]
private abbrev ConvertM := ReaderT (FVarIdMap (Expr × AltParamInfo × Array Bool)) $ StateRefT (Array MVarId) MetaM
/--
Construct a proof for the splitter generated by `mkEquationsFor`.
The proof uses the definition of the `match`-declaration as a template (argument `template`).
- `alts` are free variables corresponding to alternatives of the `match` auxiliary declaration being processed.
- `altNews` are the new free variables which contains additional hypotheses that ensure they are only used
when the previous overlapping alternatives are not applicable.
- `altInfos` refers to the splitter -/
private partial def mkSplitterProof (matchDeclName : Name) (template : Expr) (alts altsNew : Array Expr)
(altInfos : Array AltParamInfo) (altArgMasks : Array (Array Bool)) : MetaM Expr := do
trace[Meta.Match.matchEqs] "proof template: {template}"
let map := mkMap
let (proof, mvarIds) ← convertTemplate template |>.run map |>.run #[]
trace[Meta.Match.matchEqs] "splitter proof: {proof}"
for mvarId in mvarIds do
let mvarId ← mvarId.tryClearMany (alts.map (·.fvarId!))
solveOverlap mvarId
instantiateMVars proof
where
mkMap : FVarIdMap (Expr × AltParamInfo × Array Bool) := Id.run do
let mut m := {}
for alt in alts, altNew in altsNew, altInfo in altInfos, argMask in altArgMasks do
m := m.insert alt.fvarId! (altNew, altInfo, argMask)
return m
trimFalseTrail (argMask : Array Bool) : Array Bool :=
if argMask.isEmpty then
argMask
else if !argMask.back! then
trimFalseTrail argMask.pop
else
argMask
/--
Auxiliary function used at `convertTemplate` to decide whether to use `convertCastEqRec`.
See `convertCastEqRec`. -/
isCastEqRec (e : Expr) : ConvertM Bool := do
-- TODO: we do not handle `Eq.rec` since we never found an example that needed it.
-- If we find one we must extend `convertCastEqRec`.
unless e.isAppOf ``Eq.ndrec do return false
unless e.getAppNumArgs > 6 do return false
for arg in e.getAppArgs[6...*] do
if arg.isFVar && (← read).contains arg.fvarId! then
return true
return true
/--
Auxiliary function used at `convertTemplate`. It is needed when the auxiliary `match` declaration had to refine the type of its
minor premises during dependent pattern match. For an example, consider
```
inductive Foo : Nat → Type _
| nil : Foo 0
| cons (t: Foo l): Foo l
def Foo.bar (t₁: Foo l₁): Foo l₂ → Bool
| cons s₁ => t₁.bar s₁
| _ => false
attribute [simp] Foo.bar
```
The auxiliary `Foo.bar.match_1` is of the form
```
def Foo.bar.match_1.{u_1} : {l₂ : Nat} →
(t₂ : Foo l₂) →
(motive : Foo l₂ → Sort u_1) →
(t₂ : Foo l₂) → ((s₁ : Foo l₂) → motive (Foo.cons s₁)) → ((x : Foo l₂) → motive x) → motive t₂ :=
fun {l₂} t₂ motive t₂_1 h_1 h_2 =>
(fun t₂_2 =>
Foo.casesOn (motive := fun a x => l₂ = a → t₂_1 ≍ x → motive t₂_1) t₂_2
(fun h =>
Eq.ndrec (motive := fun {l₂} =>
(t₂ t₂ : Foo l₂) →
(motive : Foo l₂ → Sort u_1) →
((s₁ : Foo l₂) → motive (Foo.cons s₁)) → ((x : Foo l₂) → motive x) → t₂ ≍ Foo.nil → motive t₂)
(fun t₂ t₂ motive h_1 h_2 h => Eq.symm (eq_of_heq h) ▸ h_2 Foo.nil) (Eq.symm h) t₂ t₂_1 motive h_1 h_2) --- HERE
fun {l} t h =>
Eq.ndrec (motive := fun {l} => (t : Foo l) → t₂_1 ≍ Foo.cons t → motive t₂_1)
(fun t h => Eq.symm (eq_of_heq h) ▸ h_1 t) h t)
t₂_1 (Eq.refl l₂) (HEq.refl t₂_1)
```
The `HERE` comment marks the place where the type of `Foo.bar.match_1` minor premises `h_1` and `h_2` is being "refined"
using `Eq.ndrec`.
This function will adjust the motive and minor premise of the `Eq.ndrec` to reflect the new minor premises used in the
corresponding splitter theorem.
We may have to extend this function to handle `Eq.rec` too.
This function was added to address issue #1179
-/
convertCastEqRec (e : Expr) : ConvertM Expr := do
assert! (← isCastEqRec e)
e.withApp fun f args => do
let mut argsNew := args
let mut isAlt := #[]
for i in 6...args.size do
let arg := argsNew[i]!
if arg.isFVar then
match (← read).get? arg.fvarId! with
| some (altNew, _, _) =>
argsNew := argsNew.set! i altNew
trace[Meta.Match.matchEqs] "arg: {arg} : {← inferType arg}, altNew: {altNew} : {← inferType altNew}"
isAlt := isAlt.push true
| none =>
argsNew := argsNew.set! i (← convertTemplate arg)
isAlt := isAlt.push false
else
argsNew := argsNew.set! i (← convertTemplate arg)
isAlt := isAlt.push false
assert! isAlt.size == args.size - 6
let rhs := args[4]!
let motive := args[2]!
-- Construct new motive using the splitter theorem minor premise types.
let motiveNew ← lambdaTelescope motive fun motiveArgs body => do
unless motiveArgs.size == 1 do
throwError "unexpected `Eq.ndrec` motive while creating splitter/eliminator theorem for `{matchDeclName}`, expected lambda with 1 binder{indentExpr motive}"
let x := motiveArgs[0]!
forallTelescopeReducing body fun motiveTypeArgs resultType => do
unless motiveTypeArgs.size >= isAlt.size do
throwError "unexpected `Eq.ndrec` motive while creating splitter/eliminator theorem for `{matchDeclName}`, expected arrow with at least #{isAlt.size} binders{indentExpr body}"
let rec go (i : Nat) (motiveTypeArgsNew : Array Expr) : ConvertM Expr := do
assert! motiveTypeArgsNew.size == i
if h : i < motiveTypeArgs.size then
let motiveTypeArg := motiveTypeArgs[i]
if i < isAlt.size && isAlt[i]! then
let altNew := argsNew[6+i]! -- Recall that `Eq.ndrec` has 6 arguments
let altTypeNew ← inferType altNew
trace[Meta.Match.matchEqs] "altNew: {altNew} : {altTypeNew}"
-- Replace `rhs` with `x` (the lambda binder in the motive)
let mut altTypeNewAbst := (← kabstract altTypeNew rhs).instantiate1 x
-- Replace args[6...(6+i)] with `motiveTypeArgsNew`
for j in *...i do
altTypeNewAbst := (← kabstract altTypeNewAbst argsNew[6+j]!).instantiate1 motiveTypeArgsNew[j]!
let localDecl ← motiveTypeArg.fvarId!.getDecl
withLocalDecl localDecl.userName localDecl.binderInfo altTypeNewAbst fun motiveTypeArgNew =>
go (i+1) (motiveTypeArgsNew.push motiveTypeArgNew)
else
go (i+1) (motiveTypeArgsNew.push motiveTypeArg)
else
mkLambdaFVars motiveArgs (← mkForallFVars motiveTypeArgsNew resultType)
go 0 #[]
trace[Meta.Match.matchEqs] "new motive: {motiveNew}"
unless (← isTypeCorrect motiveNew) do
throwError "failed to construct new type correct motive for `Eq.ndrec` while creating splitter/eliminator theorem for `{matchDeclName}`{indentExpr motiveNew}"
argsNew := argsNew.set! 2 motiveNew
-- Construct the new minor premise for the `Eq.ndrec` application.
-- First, we use `eqRecNewPrefix` to infer the new minor premise binders for `Eq.ndrec`
let eqRecNewPrefix := mkAppN f argsNew[*...3] -- `Eq.ndrec` minor premise is the fourth argument.
let .forallE _ minorTypeNew .. ← whnf (← inferType eqRecNewPrefix) | unreachable!
trace[Meta.Match.matchEqs] "new minor type: {minorTypeNew}"
let minor := args[3]!
let minorNew ← forallBoundedTelescope minorTypeNew isAlt.size fun minorArgsNew _ => do
let mut minorBodyNew := minor
-- We have to extend the mapping to make sure `convertTemplate` can "fix" occurrences of the refined minor premises
let mut m ← read
for h : i in *...isAlt.size do
if isAlt[i] then
-- `convertTemplate` will correct occurrences of the alternative
let alt := args[6+i]! -- Recall that `Eq.ndrec` has 6 arguments
let some (_, numParams, argMask) := m.get? alt.fvarId! | unreachable!
-- We add a new entry to `m` to make sure `convertTemplate` will correct the occurrences of the alternative
m := m.insert minorArgsNew[i]!.fvarId! (minorArgsNew[i]!, numParams, argMask)
unless minorBodyNew.isLambda do
throwError "unexpected `Eq.ndrec` minor premise while creating splitter/eliminator theorem for `{matchDeclName}`, expected lambda with at least #{isAlt.size} binders{indentExpr minor}"
minorBodyNew := minorBodyNew.bindingBody!
minorBodyNew := minorBodyNew.instantiateRev minorArgsNew
trace[Meta.Match.matchEqs] "minor premise new body before convertTemplate:{indentExpr minorBodyNew}"
minorBodyNew ← withReader (fun _ => m) <| convertTemplate minorBodyNew
trace[Meta.Match.matchEqs] "minor premise new body after convertTemplate:{indentExpr minorBodyNew}"
mkLambdaFVars minorArgsNew minorBodyNew
unless (← isTypeCorrect minorNew) do
throwError "failed to construct new type correct minor premise for `Eq.ndrec` while creating splitter/eliminator theorem for `{matchDeclName}`{indentExpr minorNew}"
argsNew := argsNew.set! 3 minorNew
-- trace[Meta.Match.matchEqs] "argsNew: {argsNew}"
trace[Meta.Match.matchEqs] "found cast target {e}"
return mkAppN f argsNew
convertTemplate (e : Expr) : ConvertM Expr :=
transform e fun e => do
if (← isCastEqRec e) then
return .done (← convertCastEqRec e)
else
let Expr.fvar fvarId .. := e.getAppFn | return .continue
let some (altNew, altParamInfo, argMask) := (← read).get? fvarId | return .continue
trace[Meta.Match.matchEqs] ">> argMask: {argMask}, altParamInfo: {repr altParamInfo}, e: {e}, alsNew: {altNew}, "
if altParamInfo.hasUnitThunk then
let eNew := mkApp altNew (mkConst ``Unit.unit)
return TransformStep.done eNew
let mut newArgs := #[]
let argMask := trimFalseTrail argMask
unless e.getAppNumArgs ≥ argMask.size do
throwError "unexpected occurrence of `match`-expression alternative (aka minor premise) while creating splitter/eliminator theorem for `{matchDeclName}`, minor premise is partially applied{indentExpr e}\npossible solution if you are matching on inductive families: add its indices as additional discriminants"
for arg in e.getAppArgs, includeArg in argMask do
if includeArg then
newArgs := newArgs.push arg
let eNew := mkAppN altNew newArgs
let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType eNew) altParamInfo.numOverlaps (kind := MetavarKind.syntheticOpaque)
modify fun s => s ++ (mvars.map (·.mvarId!))
let eNew := mkAppN eNew mvars
return TransformStep.done eNew
/--
Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
Recall that `alts` depends on `discrs` when `numDiscrEqs > 0`, where `numDiscrEqs` is the number of discriminants
@ -364,13 +150,15 @@ def getEquationsForImpl (matchDeclName : Name) : MetaM MatchEqns := do
-- `realizeConst` as well as for looking up the resultant environment extension state via
-- `getState`.
realizeConst matchDeclName splitterName (go baseName splitterName)
return matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := splitterName) (← getEnv) |>.map.find! matchDeclName
match matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := splitterName) (← getEnv) |>.map.find? matchDeclName with
| some eqns => return eqns
| none => throwError "failed to retrieve match equations for `{matchDeclName}` after realization"
where go baseName splitterName := withConfig (fun c => { c with etaStruct := .none }) do
let constInfo ← getConstInfo matchDeclName
let us := constInfo.levelParams.map mkLevelParam
let some matchInfo ← getMatcherInfo? matchDeclName | throwError "`{matchDeclName}` is not a matcher function"
let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos
forallTelescopeReducing constInfo.type fun xs matchResultType => do
forallTelescopeReducing constInfo.type fun xs _matchResultType => do
let mut eqnNames := #[]
let params := xs[*...matchInfo.numParams]
let motive := xs[matchInfo.getMotivePos]!
@ -379,16 +167,15 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
let discrs := xs[firstDiscrIdx...(firstDiscrIdx + matchInfo.numDiscrs)]
let mut notAlts := #[]
let mut idx := 1
let mut splitterAltTypes := #[]
let mut splitterAltInfos := #[]
let mut altArgMasks := #[] -- masks produced by `forallAltTelescope`
for i in *...alts.size do
let altInfo := matchInfo.altInfos[i]!
let thmName := Name.str baseName eqnThmSuffixBase |>.appendIndexAfter idx
eqnNames := eqnNames.push thmName
let (notAlt, splitterAltType, splitterAltInfo, argMask) ←
let (notAlt, splitterAltInfo, argMask) ←
forallAltTelescope (← inferType alts[i]!) altInfo numDiscrEqs
fun ys eqs rhsArgs argMask altResultType => do
fun ys _eqs rhsArgs argMask altResultType => do
let patterns := altResultType.getAppArgs
let mut hs := #[]
for overlappedBy in matchInfo.overlaps.overlapping i do
@ -397,15 +184,7 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
if let some h ← simpH? h patterns.size then
hs := hs.push h
trace[Meta.Match.matchEqs] "hs: {hs}"
let splitterAltType ← mkForallFVars eqs altResultType
let splitterAltType ← mkArrowN hs splitterAltType
let splitterAltType ← mkForallFVars ys splitterAltType
let hasUnitThunk := splitterAltType == altResultType
let splitterAltType ← if hasUnitThunk then
mkArrow (mkConst ``Unit) splitterAltType
else
pure splitterAltType
let splitterAltType ← unfoldNamedPattern splitterAltType
let hasUnitThunk := ys.isEmpty && hs.isEmpty && numDiscrEqs = 0
let splitterAltInfo := { numFields := ys.size, numOverlaps := hs.size, hasUnitThunk }
-- Create a proposition for representing terms that do not match `patterns`
let mut notAlt := mkConst ``False
@ -429,38 +208,38 @@ where go baseName splitterName := withConfig (fun c => { c with etaStruct := .no
type := thmType
value := thmVal
}
return (notAlt, splitterAltType, splitterAltInfo, argMask)
return (notAlt, splitterAltInfo, argMask)
notAlts := notAlts.push notAlt
splitterAltTypes := splitterAltTypes.push splitterAltType
splitterAltInfos := splitterAltInfos.push splitterAltInfo
altArgMasks := altArgMasks.push argMask
trace[Meta.Match.matchEqs] "splitterAltType: {splitterAltType}"
idx := idx + 1
-- Define splitter with conditional/refined alternatives
withSplitterAlts splitterAltTypes fun altsNew => do
let splitterParams := params.toArray ++ #[motive] ++ discrs.toArray ++ altsNew
let splitterType ← mkForallFVars splitterParams matchResultType
trace[Meta.Match.matchEqs] "splitterType: {splitterType}"
let splitterVal ←
if (← isDefEq splitterType constInfo.type) then
pure <| mkConst constInfo.name us
else
let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts)
let template ← deltaExpand template (· == constInfo.name)
let template := template.headBeta
mkLambdaFVars splitterParams (← mkSplitterProof matchDeclName template alts altsNew splitterAltInfos altArgMasks)
let splitterMatchInfo : MatcherInfo := { matchInfo with altInfos := splitterAltInfos }
let needsSplitter := !matchInfo.overlaps.isEmpty || (constInfo.type.find? (isNamedPattern )).isSome
if needsSplitter then
withMkMatcherInput matchDeclName (unfoldNamed := true) fun matcherInput => do
let matcherInput := { matcherInput with
matcherName := splitterName
isSplitter := some matchInfo.overlaps
}
let res ← Match.mkMatcher matcherInput
res.addMatcher -- TODO: Do not set matcherinfo for the splitter!
else
assert! matchInfo.altInfos == splitterAltInfos
-- This match statement does not need a splitter, we can use itself for that.
-- (We still have to generate a declaration to satisfy the realizable constant)
addAndCompile <| Declaration.defnDecl {
name := splitterName
levelParams := constInfo.levelParams
type := splitterType
value := splitterVal
type := constInfo.type
value := mkConst matchDeclName us
hints := .abbrev
safety := .safe
}
setInlineAttribute splitterName
let splitterMatchInfo := { matchInfo with altInfos := splitterAltInfos }
let result := { eqnNames, splitterName, splitterMatchInfo }
registerMatchEqns matchDeclName result
let result := { eqnNames, splitterName, splitterMatchInfo }
registerMatchEqns matchDeclName result
/- We generate the equations and splitter on demand, and do not save them on .olean files. -/
builtin_initialize matchCongrEqnsExt : EnvExtension (PHashMap Name (Array Name)) ←

View file

@ -67,6 +67,7 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases
matcherName := declName
matcherLevels := declLevels.toArray
uElimPos?, discrInfos, params, motive, discrs, alts, remaining, altInfos
overlaps := {} -- CasesOn constructor have no overlaps
}
return none

View file

@ -23,6 +23,9 @@ structure Overlaps where
map : Std.HashMap Nat (Std.TreeSet Nat) := {}
deriving Inhabited, Repr
def Overlaps.isEmpty (o : Overlaps) : Bool :=
o.map.isEmpty
def Overlaps.insert (o : Overlaps) (overlapping overlapped : Nat) : Overlaps where
map := o.map.alter overlapped fun s? => some ((s?.getD {}).insert overlapping)
@ -41,29 +44,32 @@ structure AltParamInfo where
numOverlaps : Nat
/-- Whether this alternatie has an artifcial `Unit` parameter -/
hasUnitThunk : Bool
deriving Inhabited, Repr
deriving Inhabited, Repr, BEq
/--
A "matcher" auxiliary declaration has the following structure:
- `numParams` parameters
- motive
- `numDiscrs` discriminators (aka major premises)
- `altInfos.size` alternatives (aka minor premises) with parameter structure information
- `uElimPos?` is `some pos` when the matcher can eliminate in different universe levels, and
`pos` is the position of the universe level parameter that specifies the elimination universe.
It is `none` if the matcher only eliminates into `Prop`.
- `overlaps` indicates which alternatives may overlap another
Information about the structure of a matcher declaration
-/
structure MatcherInfo where
/-- Number of parameters -/
numParams : Nat
/-- Number of discriminants -/
numDiscrs : Nat
/-- Parameter structure information for each alternative -/
altInfos : Array AltParamInfo
/--
`uElimPos?` is `some pos` when the matcher can eliminate in different universe levels, and
`pos` is the position of the universe level parameter that specifies the elimination universe.
It is `none` if the matcher only eliminates into `Prop`.
-/
uElimPos? : Option Nat
/--
`discrInfos[i] = { hName? := some h }` if the i-th discriminant was annotated with `h :`.
`discrInfos[i] = { hName? := some h }` if the i-th discriminant was annotated with `h :`.
-/
discrInfos : Array DiscrInfo
overlaps : Overlaps := {}
/--
(Conservative approximation of) which alternatives may overlap another.
-/
overlaps : Overlaps
deriving Inhabited, Repr
@[expose] def MatcherInfo.numAlts (info : MatcherInfo) : Nat :=

View file

@ -1,5 +1,7 @@
#include "util/options.h"
// please update stage0
namespace lean {
options get_default_options() {
options opts;

View file

@ -38,8 +38,8 @@ info: Vec.match_on_same_ctor.{u_1, u} {α : Type u}
/--
info: Vec.match_on_same_ctor.splitter.{u_1, u} {α : Type u}
{motive : {a : Nat} → (t t_1 : Vec α a) → t.ctorIdx = t_1.ctorIdx → Sort u_1} {a✝ : Nat} (t t✝ : Vec α a✝)
(h : t.ctorIdx = t✝.ctorIdx) (h_1 : Unit → motive nil nil ⋯)
(h_2 : (a : α) → (n : Nat) → (a_1 : Vec α n) → (a' : α) → (a'_1 : Vec α n) → motive (cons a a_1) (cons a' a'_1) ⋯) :
(h : t.ctorIdx = t✝.ctorIdx) (nil : Unit → motive nil nil ⋯)
(cons : (a : α) → {n : Nat} → (a_1 : Vec α n) → (a' : α) → (a'_1 : Vec α n) → motive (cons a a_1) (cons a' a'_1) ⋯) :
motive t t✝ h
-/
#guard_msgs in

View file

@ -11,7 +11,7 @@ info: private def myTest.match_1.splitter.{u_1} : (motive : List Bool → Sort u
(x : List Bool) →
((x_1 : Bool) → (xs : List Bool) → x = x_1 :: xs → motive (x_1 :: xs)) → (x = [] → motive []) → motive x :=
fun motive x h_1 h_2 =>
List.casesOn (motive := fun x_1 => x = x_1 → motive x_1) x h_2 (fun head tail => h_1 head tail)
(fun x_1 => List.casesOn (motive := fun x_2 => x = x_2 → motive x_2) x_1 h_2 fun head tail => h_1 head tail) x
-/
#guard_msgs in
#print myTest.match_1.splitter

View file

@ -1,7 +1,9 @@
import Lean
set_option linter.unusedVariables false
def checkWithMkMatcherInput (matcher : Lean.Name) : Lean.MetaM Unit :=
Lean.Meta.Match.withMkMatcherInput matcher fun input => do
Lean.Meta.Match.withMkMatcherInput matcher (unfoldNamed := false) fun input => do
let res ← Lean.Meta.Match.mkMatcher input
let origMatcher ← Lean.getConstInfo matcher
if not <| input.matcherName == matcher then

View file

@ -9,6 +9,25 @@ def simple : Lean.Expr → Bool
| .sort _ => true
| _ => false
/--
info: def simple.match_1.{u_1} : (motive : Expr → Sort u_1) →
(x : Expr) → ((u : Level) → motive (sort u)) → ((x : Expr) → motive x) → motive x :=
fun motive x h_1 h_2 => simple._sparseCasesOn_1 x (fun u => h_1 u) fun h => h_2 x
-/
#guard_msgs in
#print simple.match_1
-- Check that the splitter re-uses the sparseCasesOn generated for the matcher:
/--
info: private def simple.match_1.splitter.{u_1} : (motive : Expr → Sort u_1) →
(x : Expr) →
((u : Level) → motive (sort u)) → ((x : Expr) → (∀ (u : Level), x = sort u → False) → motive x) → motive x :=
fun motive x h_1 h_2 => simple._sparseCasesOn_1 x (fun u => h_1 u) fun h => h_2 x ⋯
-/
#guard_msgs in
#print simple.match_1.splitter
def expensive : Lean.Expr → Lean.Expr → Bool
| .app (.app (.sort 1) (.sort 1)) (.sort 1), .app (.app (.sort 1) (.sort 1)) (.sort 1) => false
| _, _ => true
@ -49,6 +68,7 @@ info: expensive.match_1.splitter.{u_1} (motive : Expr → Expr → Sort u_1) (x
-/
#guard_msgs in
#check expensive.match_1.splitter
/--
info: expensive.match_1.eq_1.{u_1} (motive : Expr → Expr → Sort u_1)
(h_1 :

View file

@ -1,3 +1,6 @@
-- set_option trace.Meta.Match.match true
-- set_option trace.Meta.Match.matchEqs true
def f (xs : List Nat) : Nat :=
match xs with
| [] => 1