refactor: introduce Match.altInfos (#11256)

This PR replaces `MatcherInfo.numAltParams` with a more detailed data
structure that allows us, in particular, to distinguish between an
alternative for a constructor with a `Unit` field and the alternative
for a nullary constructor, where an artificial `Unit` argument is
introduced.
This commit is contained in:
Joachim Breitner 2025-11-19 16:09:17 +01:00 committed by GitHub
parent 75342961fc
commit 63bd0b5e77
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 71 additions and 58 deletions

View file

@ -49,10 +49,9 @@ partial def inlineMatchers (e : Expr) : CoreM Expr :=
return .visit (← Meta.mkLambdaFVars xs (mkAppN e xs))
else
let mut args := e.getAppArgs
let numAlts := info.numAlts
let altNumParams := info.altNumParams
let rec inlineMatcher (i : Nat) (args : Array Expr) (letFVars : Array Expr) : MetaM Expr := do
if h : i < numAlts then
if h : i < altNumParams.size then
let altIdx := i + info.getFirstAltPos
let numParams := altNumParams[i]
let alt ← normalizeAlt args[altIdx]! numParams

View file

@ -187,11 +187,11 @@ private partial def replaceRecApps (recArgInfos : Array RecArgInfo) (positions :
trace[Elab.definition.structural] "below before matcherApp.addArg: {below} : {← inferType below}"
if let some matcherApp ← matcherApp.addArg? below then
let altsNew ← matcherApp.alts.zipWithM (bs := matcherApp.altNumParams) fun alt numParams =>
lambdaBoundedTelescope alt numParams fun xs altBody => do
lambdaBoundedTelescope alt (numParams + 1) fun xs altBody => do
trace[Elab.definition.structural] "altNumParams: {numParams}, xs: {xs}"
unless xs.size = numParams do
unless xs.size = numParams + 1 do
throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
let belowForAlt := xs[numParams - 1]!
let belowForAlt := xs[numParams]!
mkLambdaFVars xs (← loop belowForAlt altBody)
pure { matcherApp with alts := altsNew }.toExpr
else

View file

@ -101,10 +101,10 @@ where
| some matcherApp =>
if let some matcherApp ← matcherApp.addArg? F then
let altsNew ← matcherApp.alts.zipWithM (bs := matcherApp.altNumParams) fun alt numParams =>
lambdaBoundedTelescope alt numParams fun xs altBody => do
unless xs.size = numParams do
lambdaBoundedTelescope alt (numParams + 1) fun xs altBody => do
unless xs.size = (numParams + 1) do
throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}"
let FAlt := xs[numParams - 1]!
let FAlt := xs[numParams]!
let altBody' ← loop FAlt altBody
mkLambdaFVars xs altBody'
return { matcherApp with alts := altsNew, discrs := (← matcherApp.discrs.mapM (loop F)) }.toExpr

View file

@ -153,7 +153,7 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit :=
let motiveType ← mkForallFVars (is ++ #[x1,x2,heq]) (mkSort v)
withLocalDecl `motive .implicit motiveType fun motive => do
let altTypes ← info.ctors.toArray.mapIdxM fun i ctorName => do
let (altTypes, altInfos) ← Array.unzip <$> info.ctors.toArray.mapIdxM fun i ctorName => do
let ctor := mkAppN (mkConst ctorName us) params
withSharedCtorIndices ctor fun zs12 is fields1 fields2 => do
let ctorApp1 := mkAppN ctor fields1
@ -164,7 +164,8 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit :=
let name := match ctorName with
| Name.str _ s => Name.mkSimple s
| _ => Name.mkSimple s!"alt{i+1}"
return (name, e)
let altInfo := { numFields := zs12.size, numOverlaps := 0, hasUnitThunk := zs12.isEmpty : Match.AltParamInfo}
return ((name, e), altInfo)
withLocalDeclsDND altTypes fun alts => do
forallBoundedTelescope t0 (some (info.numIndices + 1)) fun ism1' _ =>
forallBoundedTelescope t0 (some (info.numIndices + 1)) fun ism2' _ => do
@ -210,7 +211,7 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit :=
let matcherInfo : MatcherInfo := {
numParams := info.numParams
numDiscrs := info.numIndices + 3
altNumParams := altTypes.map (·.2.getNumHeadForalls)
altInfos
uElimPos? := some 0
discrInfos := #[{}, {}, {}]}

View file

@ -91,8 +91,9 @@ where
k hs
/-- Given a list of `AltLHS`, create a minor premise for each one, convert them into `Alt`, and then execute `k` -/
private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo) (lhss : List AltLHS) (k : List Alt → Array (Expr × Nat) → MetaM α) : MetaM α :=
loop lhss [] #[]
private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo)
(lhss : List AltLHS) (k : List Alt → Array Expr → Array AltParamInfo → MetaM α) : MetaM α :=
loop lhss [] #[] #[]
where
mkMinorType (xs : Array Expr) (lhs : AltLHS) : MetaM Expr :=
withExistingLocalDecls lhs.fvarDecls do
@ -101,23 +102,24 @@ where
withEqs discrs args discrInfos fun eqs => do
mkForallFVars (xs ++ eqs) minorType
loop (lhss : List AltLHS) (alts : List Alt) (minors : Array (Expr × Nat)) : MetaM α := do
loop (lhss : List AltLHS) (alts : List Alt) (minors : Array Expr) (altInfos : Array AltParamInfo) : MetaM α := do
match lhss with
| [] => k alts.reverse minors
| [] => k alts.reverse minors altInfos
| lhs::lhss =>
let xs := lhs.fvarDecls.toArray.map LocalDecl.toExpr
let minorType ← mkMinorType xs lhs
let hasParams := !xs.isEmpty || discrInfos.any fun info => info.hName?.isSome
let (minorType, minorNumParams) := if hasParams then (minorType, xs.size) else (mkSimpleThunkType minorType, 1)
let minorType := if hasParams then minorType else mkSimpleThunkType minorType
let idx := alts.length
let minorName := (`h).appendIndexAfter (idx+1)
trace[Meta.Match.debug] "minor premise {minorName} : {minorType}"
withLocalDeclD minorName minorType fun minor => do
let rhs := if hasParams then mkAppN minor xs else mkApp minor (mkConst `Unit.unit)
let minors := minors.push (minor, minorNumParams)
let minors := minors.push minor
let altInfos := altInfos.push { numFields := xs.size, numOverlaps := 0, hasUnitThunk := !hasParams }
let fvarDecls ← lhs.fvarDecls.mapM instantiateLocalDeclMVars
let alts := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns, cnstrs := [] } :: alts
loop lhss alts minors
loop lhss alts minors altInfos
structure State where
/-- Used alternatives -/
@ -1094,7 +1096,6 @@ where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition.
def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor input do
let ⟨matcherName, matchType, discrInfos, lhss⟩ := input
let numDiscrs := discrInfos.size
let numEqs := getNumEqsFromDiscrInfos discrInfos
checkNumPatterns numDiscrs lhss
forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do
/- We generate an matcher that can eliminate using different motives with different universe levels.
@ -1103,9 +1104,8 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
This is useful for implementing `MatcherApp.addArg` because it may have to change the universe level. -/
let uElim ← getLevel matchTypeBody
let uElimGen ← if uElim == levelZero then pure levelZero else mkFreshLevelMVar
let mkMatcher (type val : Expr) (minors : Array (Expr × Nat)) (s : State) : MetaM MatcherResult := do
let mkMatcher (type val : Expr) (altInfos : Array AltParamInfo) (s : State) : MetaM MatcherResult := do
trace[Meta.Match.debug] "matcher value: {val}\ntype: {type}"
trace[Meta.Match.debug] "minors num params: {minors.map (·.2)}"
/- The option `bootstrap.gen_matcher_code` is a helper hack. It is useful, for example,
for compiling `src/Init/Data/Int`. It is needed because the compiler uses `Int.decLt`
for generating code for `Int.casesOn` applications, but `Int.casesOn` is used to
@ -1125,7 +1125,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
match addMatcher with
| some addMatcher => addMatcher <|
{ numParams := matcher.getAppNumArgs
altNumParams := minors.map fun minor => minor.2 + numEqs
altInfos
discrInfos
numDiscrs
uElimPos?
@ -1153,7 +1153,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
let isEqMask ← eqs.mapM fun eq => return (← inferType eq).isEq
return (mvarType, isEqMask)
trace[Meta.Match.debug] "target: {mvarType}"
withAlts motive discrs discrInfos lhss fun alts minors => do
withAlts motive discrs discrInfos lhss fun alts minors altInfos => do
let mvar ← mkFreshExprMVar mvarType
trace[Meta.Match.debug] "goal\n{mvar.mvarId!}"
let examples := discrs'.toList.map fun discr => Example.var discr.fvarId!
@ -1170,21 +1170,21 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor
rfls := rfls.push (← mkHEqRefl discr)
isEqMaskIdx := isEqMaskIdx + 1
let val := mkAppN (mkAppN val discrs) rfls
let args := #[motive] ++ discrs ++ minors.map Prod.fst
let args := #[motive] ++ discrs ++ minors
let val ← mkLambdaFVars args val
let type ← mkForallFVars args (mkAppN motive discrs)
mkMatcher type val minors s
mkMatcher type val altInfos s
else
let mvarType := mkAppN motive discrs
trace[Meta.Match.debug] "target: {mvarType}"
withAlts motive discrs discrInfos lhss fun alts minors => do
withAlts motive discrs discrInfos lhss fun alts minors altInfos => do
let mvar ← mkFreshExprMVar mvarType
let examples := discrs.toList.map fun discr => Example.var discr.fvarId!
let (_, s) ← (process { mvarId := mvar.mvarId!, vars := discrs.toList, alts := alts, examples := examples }).run {}
let args := #[motive] ++ discrs ++ minors.map Prod.fst
let args := #[motive] ++ discrs ++ minors
let type ← mkForallFVars args mvarType
let val ← mkLambdaFVars args mvar
mkMatcher type val minors s
mkMatcher type val altInfos s
def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput := do
let matcherName := matcherApp.matcherName

View file

@ -12,15 +12,12 @@ public section
namespace Lean.Meta
structure MatcherApp where
structure MatcherApp extends Match.MatcherInfo where
matcherName : Name
matcherLevels : Array Level
uElimPos? : Option Nat
discrInfos : Array Match.DiscrInfo
params : Array Expr
motive : Expr
discrs : Array Expr
altNumParams : Array Nat
alts : Array Expr
remaining : Array Expr
@ -39,14 +36,12 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases
if args.size < info.arity then
return none
return some {
info with
matcherName := declName
matcherLevels := declLevels.toArray
uElimPos? := info.uElimPos?
discrInfos := info.discrInfos
params := args.extract 0 info.numParams
motive := args[info.getMotivePos]!
discrs := args[(info.numParams + 1)...(info.numParams + 1 + info.numDiscrs)]
altNumParams := info.altNumParams
alts := args[(info.numParams + 1 + info.numDiscrs)...(info.numParams + 1 + info.numDiscrs + info.numAlts)]
remaining := args[(info.numParams + 1 + info.numDiscrs + info.numAlts)...args.size]
}
@ -63,24 +58,20 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases
let alts := args[(info.numParams + 1 + info.numIndices + 1)...(info.numParams + 1 + info.numIndices + 1 + info.numCtors)]
let remaining := args[(info.numParams + 1 + info.numIndices + 1 + info.numCtors)...*]
let uElimPos? := if info.levelParams.length == declLevels.length then none else some 0
let mut altNumParams := #[]
for ctor in info.ctors do
let .ctorInfo ctorInfo ← getConstInfo ctor | unreachable!
altNumParams := altNumParams.push ctorInfo.numFields
let altInfos ← info.ctors.toArray.mapM fun ctor => do
let .ctorInfo ctorInfo ← getConstInfo ctor | panic! "expected constructor"
return { numFields := ctorInfo.numFields, numOverlaps := 0, hasUnitThunk := false : Match.AltParamInfo}
return some {
numParams := params.size
numDiscrs := discrs.size
matcherName := declName
matcherLevels := declLevels.toArray
uElimPos?, discrInfos, params, motive, discrs, alts, remaining, altNumParams
uElimPos?, discrInfos, params, motive, discrs, alts, remaining, altInfos
}
return none
def MatcherApp.toMatcherInfo (matcherApp : MatcherApp) : MatcherInfo where
uElimPos? := matcherApp.uElimPos?
discrInfos := matcherApp.discrInfos
numParams := matcherApp.params.size
numDiscrs := matcherApp.discrs.size
altNumParams := matcherApp.altNumParams
def MatcherApp.altNumParams (matcherApp : MatcherApp) := matcherApp.toMatcherInfo.altNumParams
def MatcherApp.toExpr (matcherApp : MatcherApp) : Expr :=
let result := mkAppN (mkConst matcherApp.matcherName matcherApp.matcherLevels.toList) matcherApp.params

View file

@ -15,7 +15,7 @@ public section
namespace Lean.Meta.MatcherApp
/-- Auxiliary function for MatcherApp.addArg -/
private partial def updateAlts (unrefinedArgType : Expr) (typeNew : Expr) (altNumParams : Array Nat) (alts : Array Expr) (refined : Bool) (i : Nat) : MetaM (Array Nat × Array Expr) := do
private partial def updateAlts (unrefinedArgType : Expr) (typeNew : Expr) (altNumParams : Array Nat) (alts : Array Expr) (refined : Bool) (i : Nat) : MetaM (Array Expr) := do
if h : i < alts.size then
let alt := alts[i]
let numParams := altNumParams[i]!
@ -33,11 +33,11 @@ private partial def updateAlts (unrefinedArgType : Expr) (typeNew : Expr) (altNu
else
pure <| !(← isDefEq unrefinedArgType (← inferType x[0]!))
return (← mkLambdaFVars xs alt, refined)
updateAlts unrefinedArgType (b.instantiate1 alt) (altNumParams.set! i (numParams+1)) (alts.set i alt) refined (i+1)
updateAlts unrefinedArgType (b.instantiate1 alt) altNumParams (alts.set i alt) refined (i+1)
| _ => throwError "unexpected type at MatcherApp.addArg"
else
if refined then
return (altNumParams, alts)
return alts
else
throwError "failed to add argument to matcher application, argument type was not refined by `casesOn`"
@ -91,12 +91,11 @@ def addArg (matcherApp : MatcherApp) (e : Expr) : MetaM MatcherApp :=
unless (← isTypeCorrect aux) do
throwError "failed to add argument to matcher application, type error when constructing the new motive"
let auxType ← inferType aux
let (altNumParams, alts) ← updateAlts eType auxType matcherApp.altNumParams matcherApp.alts false 0
let alts ← updateAlts eType auxType matcherApp.altNumParams matcherApp.alts false 0
return { matcherApp with
matcherLevels := matcherLevels,
motive := motive,
alts := alts,
altNumParams := altNumParams,
remaining := #[e] ++ matcherApp.remaining
}
@ -245,7 +244,7 @@ def transform
let params' ← matcherApp.params.mapM onParams
let discrs' ← matcherApp.discrs.mapM onParams
let (motive', uElim, addHEqualities) ← lambdaTelescope matcherApp.motive fun motiveArgs motiveBody => do
let (motive', uElim, addHEqualities, discrInfos') ← lambdaTelescope matcherApp.motive fun motiveArgs motiveBody => do
unless motiveArgs.size == matcherApp.discrs.size do
throwError "unexpected matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments"
let mut motiveBody' ← onMotive motiveArgs motiveBody
@ -253,18 +252,22 @@ def transform
-- Prepend `(x = e) →` or `(x ≍ e) → ` to the motive when an equality is requested
-- and not already present, and remember whether we added an Eq or a HEq
let mut addHEqualities : Array (Option Bool) := #[]
let mut discrInfos' := #[]
for arg in motiveArgs, discr in discrs', di in matcherApp.discrInfos do
if addEqualities && di.hName?.isNone then
if ← isProof arg then
addHEqualities := addHEqualities.push none
discrInfos' := discrInfos'.push di
else
let heq ← mkEqHEq discr arg
motiveBody' ← liftMetaM <| mkArrow heq motiveBody'
addHEqualities := addHEqualities.push heq.isHEq
discrInfos' := discrInfos'.push { hName? := some .anonymous }
else
addHEqualities := addHEqualities.push none
discrInfos' := discrInfos'.push di
return (← mkLambdaFVars motiveArgs motiveBody', ← getLevel motiveBody', addHEqualities)
return (← mkLambdaFVars motiveArgs motiveBody', ← getLevel motiveBody', addHEqualities, discrInfos')
let matcherLevels ← match matcherApp.uElimPos? with
| none => pure matcherApp.matcherLevels
@ -342,7 +345,7 @@ def transform
params := params'
motive := motive'
discrs := discrs'
altNumParams := matchEqns.splitterAltNumParams.map (· + extraEqualities)
discrInfos := discrInfos'
alts := alts'
remaining := remaining'
}
@ -377,7 +380,7 @@ def transform
params := params'
motive := motive'
discrs := discrs'
altNumParams := matcherApp.altNumParams.map (· + extraEqualities)
discrInfos := discrInfos'
alts := alts'
remaining := remaining'
}

View file

@ -30,12 +30,24 @@ def Overlaps.overlapping (o : Overlaps) (overlapped : Nat) : Array Nat :=
| some s => s.toArray
| none => #[]
/--
Informatino about the parameter structure for the alternative of a matcher or splitter.
-/
structure AltParamInfo where
/-- Actual fields (not incuding discr eqns) -/
numFields : Nat
/-- Overlap assumption (for splitters only) -/
numOverlaps : Nat
/-- Whether this alternatie has an artifcial `Unit` parameter -/
hasUnitThunk : Bool
deriving Inhabited
/--
A "matcher" auxiliary declaration has the following structure:
- `numParams` parameters
- motive
- `numDiscrs` discriminators (aka major premises)
- `altNumParams.size` alternatives (aka minor premises) where alternative `i` has `altNumParams[i]` parameters
- `altInfos.size` alternatives (aka minor premises) with parameter structure information
- `uElimPos?` is `some pos` when the matcher can eliminate in different universe levels, and
`pos` is the position of the universe level parameter that specifies the elimination universe.
It is `none` if the matcher only eliminates into `Prop`.
@ -44,7 +56,7 @@ A "matcher" auxiliary declaration has the following structure:
structure MatcherInfo where
numParams : Nat
numDiscrs : Nat
altNumParams : Array Nat
altInfos : Array AltParamInfo
uElimPos? : Option Nat
/--
`discrInfos[i] = { hName? := some h }` if the i-th discriminant was annotated with `h :`.
@ -53,7 +65,7 @@ structure MatcherInfo where
overlaps : Overlaps := {}
@[expose] def MatcherInfo.numAlts (info : MatcherInfo) : Nat :=
info.altNumParams.size
info.altInfos.size
def MatcherInfo.arity (info : MatcherInfo) : Nat :=
info.numParams + 1 + info.numDiscrs + info.numAlts
@ -83,6 +95,11 @@ def getNumEqsFromDiscrInfos (infos : Array DiscrInfo) : Nat := Id.run do
def MatcherInfo.getNumDiscrEqs (info : MatcherInfo) : Nat :=
getNumEqsFromDiscrInfos info.discrInfos
def MatcherInfo.altNumParams (info : MatcherInfo) : Array Nat :=
info.altInfos.map fun {numFields, numOverlaps, hasUnitThunk} =>
numFields + numOverlaps + (if hasUnitThunk then 1 else 0) + info.getNumDiscrEqs
namespace Extension
structure Entry where

View file

@ -1,5 +1,7 @@
#include "util/options.h"
// please update this
namespace lean {
options get_default_options() {
options opts;