From 63bd0b5e77fbd68079b0b945b33cc776dbd3e4e5 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Wed, 19 Nov 2025 16:09:17 +0100 Subject: [PATCH] refactor: introduce Match.altInfos (#11256) This PR replaces `MatcherInfo.numAltParams` with a more detailed data structure that allows us, in particular, to distinguish between an alternative for a constructor with a `Unit` field and the alternative for a nullary constructor, where an artificial `Unit` argument is introduced. --- src/Lean/Compiler/LCNF/ToDecl.lean | 3 +- .../Elab/PreDefinition/Structural/BRecOn.lean | 6 ++-- src/Lean/Elab/PreDefinition/WF/Fix.lean | 6 ++-- .../Meta/Constructions/CasesOnSameCtor.lean | 7 ++-- src/Lean/Meta/Match/Match.lean | 34 +++++++++---------- src/Lean/Meta/Match/MatcherApp/Basic.lean | 27 +++++---------- src/Lean/Meta/Match/MatcherApp/Transform.lean | 21 +++++++----- src/Lean/Meta/Match/MatcherInfo.lean | 23 +++++++++++-- stage0/src/stdlib_flags.h | 2 ++ 9 files changed, 71 insertions(+), 58 deletions(-) diff --git a/src/Lean/Compiler/LCNF/ToDecl.lean b/src/Lean/Compiler/LCNF/ToDecl.lean index 0bcaa37b6c..8752390a12 100644 --- a/src/Lean/Compiler/LCNF/ToDecl.lean +++ b/src/Lean/Compiler/LCNF/ToDecl.lean @@ -49,10 +49,9 @@ partial def inlineMatchers (e : Expr) : CoreM Expr := return .visit (← Meta.mkLambdaFVars xs (mkAppN e xs)) else let mut args := e.getAppArgs - let numAlts := info.numAlts let altNumParams := info.altNumParams let rec inlineMatcher (i : Nat) (args : Array Expr) (letFVars : Array Expr) : MetaM Expr := do - if h : i < numAlts then + if h : i < altNumParams.size then let altIdx := i + info.getFirstAltPos let numParams := altNumParams[i] let alt ← normalizeAlt args[altIdx]! numParams diff --git a/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean b/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean index 11fa11e246..14dd046ec5 100644 --- a/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean +++ b/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean @@ -187,11 +187,11 @@ private partial def replaceRecApps (recArgInfos : Array RecArgInfo) (positions : trace[Elab.definition.structural] "below before matcherApp.addArg: {below} : {← inferType below}" if let some matcherApp ← matcherApp.addArg? below then let altsNew ← matcherApp.alts.zipWithM (bs := matcherApp.altNumParams) fun alt numParams => - lambdaBoundedTelescope alt numParams fun xs altBody => do + lambdaBoundedTelescope alt (numParams + 1) fun xs altBody => do trace[Elab.definition.structural] "altNumParams: {numParams}, xs: {xs}" - unless xs.size = numParams do + unless xs.size = numParams + 1 do throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}" - let belowForAlt := xs[numParams - 1]! + let belowForAlt := xs[numParams]! mkLambdaFVars xs (← loop belowForAlt altBody) pure { matcherApp with alts := altsNew }.toExpr else diff --git a/src/Lean/Elab/PreDefinition/WF/Fix.lean b/src/Lean/Elab/PreDefinition/WF/Fix.lean index e1813edbcc..afe0464f92 100644 --- a/src/Lean/Elab/PreDefinition/WF/Fix.lean +++ b/src/Lean/Elab/PreDefinition/WF/Fix.lean @@ -101,10 +101,10 @@ where | some matcherApp => if let some matcherApp ← matcherApp.addArg? F then let altsNew ← matcherApp.alts.zipWithM (bs := matcherApp.altNumParams) fun alt numParams => - lambdaBoundedTelescope alt numParams fun xs altBody => do - unless xs.size = numParams do + lambdaBoundedTelescope alt (numParams + 1) fun xs altBody => do + unless xs.size = (numParams + 1) do throwError "unexpected matcher application alternative{indentExpr alt}\nat application{indentExpr e}" - let FAlt := xs[numParams - 1]! + let FAlt := xs[numParams]! let altBody' ← loop FAlt altBody mkLambdaFVars xs altBody' return { matcherApp with alts := altsNew, discrs := (← matcherApp.discrs.mapM (loop F)) }.toExpr diff --git a/src/Lean/Meta/Constructions/CasesOnSameCtor.lean b/src/Lean/Meta/Constructions/CasesOnSameCtor.lean index 92f0d0d615..07ecbfd4de 100644 --- a/src/Lean/Meta/Constructions/CasesOnSameCtor.lean +++ b/src/Lean/Meta/Constructions/CasesOnSameCtor.lean @@ -153,7 +153,7 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit := let motiveType ← mkForallFVars (is ++ #[x1,x2,heq]) (mkSort v) withLocalDecl `motive .implicit motiveType fun motive => do - let altTypes ← info.ctors.toArray.mapIdxM fun i ctorName => do + let (altTypes, altInfos) ← Array.unzip <$> info.ctors.toArray.mapIdxM fun i ctorName => do let ctor := mkAppN (mkConst ctorName us) params withSharedCtorIndices ctor fun zs12 is fields1 fields2 => do let ctorApp1 := mkAppN ctor fields1 @@ -164,7 +164,8 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit := let name := match ctorName with | Name.str _ s => Name.mkSimple s | _ => Name.mkSimple s!"alt{i+1}" - return (name, e) + let altInfo := { numFields := zs12.size, numOverlaps := 0, hasUnitThunk := zs12.isEmpty : Match.AltParamInfo} + return ((name, e), altInfo) withLocalDeclsDND altTypes fun alts => do forallBoundedTelescope t0 (some (info.numIndices + 1)) fun ism1' _ => forallBoundedTelescope t0 (some (info.numIndices + 1)) fun ism2' _ => do @@ -210,7 +211,7 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit := let matcherInfo : MatcherInfo := { numParams := info.numParams numDiscrs := info.numIndices + 3 - altNumParams := altTypes.map (·.2.getNumHeadForalls) + altInfos uElimPos? := some 0 discrInfos := #[{}, {}, {}]} diff --git a/src/Lean/Meta/Match/Match.lean b/src/Lean/Meta/Match/Match.lean index d272b4f95a..03c6e24f29 100644 --- a/src/Lean/Meta/Match/Match.lean +++ b/src/Lean/Meta/Match/Match.lean @@ -91,8 +91,9 @@ where k hs /-- Given a list of `AltLHS`, create a minor premise for each one, convert them into `Alt`, and then execute `k` -/ -private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo) (lhss : List AltLHS) (k : List Alt → Array (Expr × Nat) → MetaM α) : MetaM α := - loop lhss [] #[] +private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo) + (lhss : List AltLHS) (k : List Alt → Array Expr → Array AltParamInfo → MetaM α) : MetaM α := + loop lhss [] #[] #[] where mkMinorType (xs : Array Expr) (lhs : AltLHS) : MetaM Expr := withExistingLocalDecls lhs.fvarDecls do @@ -101,23 +102,24 @@ where withEqs discrs args discrInfos fun eqs => do mkForallFVars (xs ++ eqs) minorType - loop (lhss : List AltLHS) (alts : List Alt) (minors : Array (Expr × Nat)) : MetaM α := do + loop (lhss : List AltLHS) (alts : List Alt) (minors : Array Expr) (altInfos : Array AltParamInfo) : MetaM α := do match lhss with - | [] => k alts.reverse minors + | [] => k alts.reverse minors altInfos | lhs::lhss => let xs := lhs.fvarDecls.toArray.map LocalDecl.toExpr let minorType ← mkMinorType xs lhs let hasParams := !xs.isEmpty || discrInfos.any fun info => info.hName?.isSome - let (minorType, minorNumParams) := if hasParams then (minorType, xs.size) else (mkSimpleThunkType minorType, 1) + let minorType := if hasParams then minorType else mkSimpleThunkType minorType let idx := alts.length let minorName := (`h).appendIndexAfter (idx+1) trace[Meta.Match.debug] "minor premise {minorName} : {minorType}" withLocalDeclD minorName minorType fun minor => do let rhs := if hasParams then mkAppN minor xs else mkApp minor (mkConst `Unit.unit) - let minors := minors.push (minor, minorNumParams) + let minors := minors.push minor + let altInfos := altInfos.push { numFields := xs.size, numOverlaps := 0, hasUnitThunk := !hasParams } let fvarDecls ← lhs.fvarDecls.mapM instantiateLocalDeclMVars let alts := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns, cnstrs := [] } :: alts - loop lhss alts minors + loop lhss alts minors altInfos structure State where /-- Used alternatives -/ @@ -1094,7 +1096,6 @@ where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition. def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor input do let ⟨matcherName, matchType, discrInfos, lhss⟩ := input let numDiscrs := discrInfos.size - let numEqs := getNumEqsFromDiscrInfos discrInfos checkNumPatterns numDiscrs lhss forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do /- We generate an matcher that can eliminate using different motives with different universe levels. @@ -1103,9 +1104,8 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor This is useful for implementing `MatcherApp.addArg` because it may have to change the universe level. -/ let uElim ← getLevel matchTypeBody let uElimGen ← if uElim == levelZero then pure levelZero else mkFreshLevelMVar - let mkMatcher (type val : Expr) (minors : Array (Expr × Nat)) (s : State) : MetaM MatcherResult := do + let mkMatcher (type val : Expr) (altInfos : Array AltParamInfo) (s : State) : MetaM MatcherResult := do trace[Meta.Match.debug] "matcher value: {val}\ntype: {type}" - trace[Meta.Match.debug] "minors num params: {minors.map (·.2)}" /- The option `bootstrap.gen_matcher_code` is a helper hack. It is useful, for example, for compiling `src/Init/Data/Int`. It is needed because the compiler uses `Int.decLt` for generating code for `Int.casesOn` applications, but `Int.casesOn` is used to @@ -1125,7 +1125,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor match addMatcher with | some addMatcher => addMatcher <| { numParams := matcher.getAppNumArgs - altNumParams := minors.map fun minor => minor.2 + numEqs + altInfos discrInfos numDiscrs uElimPos? @@ -1153,7 +1153,7 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor let isEqMask ← eqs.mapM fun eq => return (← inferType eq).isEq return (mvarType, isEqMask) trace[Meta.Match.debug] "target: {mvarType}" - withAlts motive discrs discrInfos lhss fun alts minors => do + withAlts motive discrs discrInfos lhss fun alts minors altInfos => do let mvar ← mkFreshExprMVar mvarType trace[Meta.Match.debug] "goal\n{mvar.mvarId!}" let examples := discrs'.toList.map fun discr => Example.var discr.fvarId! @@ -1170,21 +1170,21 @@ def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor rfls := rfls.push (← mkHEqRefl discr) isEqMaskIdx := isEqMaskIdx + 1 let val := mkAppN (mkAppN val discrs) rfls - let args := #[motive] ++ discrs ++ minors.map Prod.fst + let args := #[motive] ++ discrs ++ minors let val ← mkLambdaFVars args val let type ← mkForallFVars args (mkAppN motive discrs) - mkMatcher type val minors s + mkMatcher type val altInfos s else let mvarType := mkAppN motive discrs trace[Meta.Match.debug] "target: {mvarType}" - withAlts motive discrs discrInfos lhss fun alts minors => do + withAlts motive discrs discrInfos lhss fun alts minors altInfos => do let mvar ← mkFreshExprMVar mvarType let examples := discrs.toList.map fun discr => Example.var discr.fvarId! let (_, s) ← (process { mvarId := mvar.mvarId!, vars := discrs.toList, alts := alts, examples := examples }).run {} - let args := #[motive] ++ discrs ++ minors.map Prod.fst + let args := #[motive] ++ discrs ++ minors let type ← mkForallFVars args mvarType let val ← mkLambdaFVars args mvar - mkMatcher type val minors s + mkMatcher type val altInfos s def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput := do let matcherName := matcherApp.matcherName diff --git a/src/Lean/Meta/Match/MatcherApp/Basic.lean b/src/Lean/Meta/Match/MatcherApp/Basic.lean index f7119e9866..0f35e5ab79 100644 --- a/src/Lean/Meta/Match/MatcherApp/Basic.lean +++ b/src/Lean/Meta/Match/MatcherApp/Basic.lean @@ -12,15 +12,12 @@ public section namespace Lean.Meta -structure MatcherApp where +structure MatcherApp extends Match.MatcherInfo where matcherName : Name matcherLevels : Array Level - uElimPos? : Option Nat - discrInfos : Array Match.DiscrInfo params : Array Expr motive : Expr discrs : Array Expr - altNumParams : Array Nat alts : Array Expr remaining : Array Expr @@ -39,14 +36,12 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases if args.size < info.arity then return none return some { + info with matcherName := declName matcherLevels := declLevels.toArray - uElimPos? := info.uElimPos? - discrInfos := info.discrInfos params := args.extract 0 info.numParams motive := args[info.getMotivePos]! discrs := args[(info.numParams + 1)...(info.numParams + 1 + info.numDiscrs)] - altNumParams := info.altNumParams alts := args[(info.numParams + 1 + info.numDiscrs)...(info.numParams + 1 + info.numDiscrs + info.numAlts)] remaining := args[(info.numParams + 1 + info.numDiscrs + info.numAlts)...args.size] } @@ -63,24 +58,20 @@ def matchMatcherApp? [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (alsoCases let alts := args[(info.numParams + 1 + info.numIndices + 1)...(info.numParams + 1 + info.numIndices + 1 + info.numCtors)] let remaining := args[(info.numParams + 1 + info.numIndices + 1 + info.numCtors)...*] let uElimPos? := if info.levelParams.length == declLevels.length then none else some 0 - let mut altNumParams := #[] - for ctor in info.ctors do - let .ctorInfo ctorInfo ← getConstInfo ctor | unreachable! - altNumParams := altNumParams.push ctorInfo.numFields + let altInfos ← info.ctors.toArray.mapM fun ctor => do + let .ctorInfo ctorInfo ← getConstInfo ctor | panic! "expected constructor" + return { numFields := ctorInfo.numFields, numOverlaps := 0, hasUnitThunk := false : Match.AltParamInfo} return some { + numParams := params.size + numDiscrs := discrs.size matcherName := declName matcherLevels := declLevels.toArray - uElimPos?, discrInfos, params, motive, discrs, alts, remaining, altNumParams + uElimPos?, discrInfos, params, motive, discrs, alts, remaining, altInfos } return none -def MatcherApp.toMatcherInfo (matcherApp : MatcherApp) : MatcherInfo where - uElimPos? := matcherApp.uElimPos? - discrInfos := matcherApp.discrInfos - numParams := matcherApp.params.size - numDiscrs := matcherApp.discrs.size - altNumParams := matcherApp.altNumParams +def MatcherApp.altNumParams (matcherApp : MatcherApp) := matcherApp.toMatcherInfo.altNumParams def MatcherApp.toExpr (matcherApp : MatcherApp) : Expr := let result := mkAppN (mkConst matcherApp.matcherName matcherApp.matcherLevels.toList) matcherApp.params diff --git a/src/Lean/Meta/Match/MatcherApp/Transform.lean b/src/Lean/Meta/Match/MatcherApp/Transform.lean index 13a3273b05..f464bc3175 100644 --- a/src/Lean/Meta/Match/MatcherApp/Transform.lean +++ b/src/Lean/Meta/Match/MatcherApp/Transform.lean @@ -15,7 +15,7 @@ public section namespace Lean.Meta.MatcherApp /-- Auxiliary function for MatcherApp.addArg -/ -private partial def updateAlts (unrefinedArgType : Expr) (typeNew : Expr) (altNumParams : Array Nat) (alts : Array Expr) (refined : Bool) (i : Nat) : MetaM (Array Nat × Array Expr) := do +private partial def updateAlts (unrefinedArgType : Expr) (typeNew : Expr) (altNumParams : Array Nat) (alts : Array Expr) (refined : Bool) (i : Nat) : MetaM (Array Expr) := do if h : i < alts.size then let alt := alts[i] let numParams := altNumParams[i]! @@ -33,11 +33,11 @@ private partial def updateAlts (unrefinedArgType : Expr) (typeNew : Expr) (altNu else pure <| !(← isDefEq unrefinedArgType (← inferType x[0]!)) return (← mkLambdaFVars xs alt, refined) - updateAlts unrefinedArgType (b.instantiate1 alt) (altNumParams.set! i (numParams+1)) (alts.set i alt) refined (i+1) + updateAlts unrefinedArgType (b.instantiate1 alt) altNumParams (alts.set i alt) refined (i+1) | _ => throwError "unexpected type at MatcherApp.addArg" else if refined then - return (altNumParams, alts) + return alts else throwError "failed to add argument to matcher application, argument type was not refined by `casesOn`" @@ -91,12 +91,11 @@ def addArg (matcherApp : MatcherApp) (e : Expr) : MetaM MatcherApp := unless (← isTypeCorrect aux) do throwError "failed to add argument to matcher application, type error when constructing the new motive" let auxType ← inferType aux - let (altNumParams, alts) ← updateAlts eType auxType matcherApp.altNumParams matcherApp.alts false 0 + let alts ← updateAlts eType auxType matcherApp.altNumParams matcherApp.alts false 0 return { matcherApp with matcherLevels := matcherLevels, motive := motive, alts := alts, - altNumParams := altNumParams, remaining := #[e] ++ matcherApp.remaining } @@ -245,7 +244,7 @@ def transform let params' ← matcherApp.params.mapM onParams let discrs' ← matcherApp.discrs.mapM onParams - let (motive', uElim, addHEqualities) ← lambdaTelescope matcherApp.motive fun motiveArgs motiveBody => do + let (motive', uElim, addHEqualities, discrInfos') ← lambdaTelescope matcherApp.motive fun motiveArgs motiveBody => do unless motiveArgs.size == matcherApp.discrs.size do throwError "unexpected matcher application, motive must be lambda expression with #{matcherApp.discrs.size} arguments" let mut motiveBody' ← onMotive motiveArgs motiveBody @@ -253,18 +252,22 @@ def transform -- Prepend `(x = e) →` or `(x ≍ e) → ` to the motive when an equality is requested -- and not already present, and remember whether we added an Eq or a HEq let mut addHEqualities : Array (Option Bool) := #[] + let mut discrInfos' := #[] for arg in motiveArgs, discr in discrs', di in matcherApp.discrInfos do if addEqualities && di.hName?.isNone then if ← isProof arg then addHEqualities := addHEqualities.push none + discrInfos' := discrInfos'.push di else let heq ← mkEqHEq discr arg motiveBody' ← liftMetaM <| mkArrow heq motiveBody' addHEqualities := addHEqualities.push heq.isHEq + discrInfos' := discrInfos'.push { hName? := some .anonymous } else addHEqualities := addHEqualities.push none + discrInfos' := discrInfos'.push di - return (← mkLambdaFVars motiveArgs motiveBody', ← getLevel motiveBody', addHEqualities) + return (← mkLambdaFVars motiveArgs motiveBody', ← getLevel motiveBody', addHEqualities, discrInfos') let matcherLevels ← match matcherApp.uElimPos? with | none => pure matcherApp.matcherLevels @@ -342,7 +345,7 @@ def transform params := params' motive := motive' discrs := discrs' - altNumParams := matchEqns.splitterAltNumParams.map (· + extraEqualities) + discrInfos := discrInfos' alts := alts' remaining := remaining' } @@ -377,7 +380,7 @@ def transform params := params' motive := motive' discrs := discrs' - altNumParams := matcherApp.altNumParams.map (· + extraEqualities) + discrInfos := discrInfos' alts := alts' remaining := remaining' } diff --git a/src/Lean/Meta/Match/MatcherInfo.lean b/src/Lean/Meta/Match/MatcherInfo.lean index 7191f34e7f..41b83bb325 100644 --- a/src/Lean/Meta/Match/MatcherInfo.lean +++ b/src/Lean/Meta/Match/MatcherInfo.lean @@ -30,12 +30,24 @@ def Overlaps.overlapping (o : Overlaps) (overlapped : Nat) : Array Nat := | some s => s.toArray | none => #[] +/-- +Informatino about the parameter structure for the alternative of a matcher or splitter. +-/ +structure AltParamInfo where + /-- Actual fields (not incuding discr eqns) -/ + numFields : Nat + /-- Overlap assumption (for splitters only) -/ + numOverlaps : Nat + /-- Whether this alternatie has an artifcial `Unit` parameter -/ + hasUnitThunk : Bool +deriving Inhabited + /-- A "matcher" auxiliary declaration has the following structure: - `numParams` parameters - motive - `numDiscrs` discriminators (aka major premises) -- `altNumParams.size` alternatives (aka minor premises) where alternative `i` has `altNumParams[i]` parameters +- `altInfos.size` alternatives (aka minor premises) with parameter structure information - `uElimPos?` is `some pos` when the matcher can eliminate in different universe levels, and `pos` is the position of the universe level parameter that specifies the elimination universe. It is `none` if the matcher only eliminates into `Prop`. @@ -44,7 +56,7 @@ A "matcher" auxiliary declaration has the following structure: structure MatcherInfo where numParams : Nat numDiscrs : Nat - altNumParams : Array Nat + altInfos : Array AltParamInfo uElimPos? : Option Nat /-- `discrInfos[i] = { hName? := some h }` if the i-th discriminant was annotated with `h :`. @@ -53,7 +65,7 @@ structure MatcherInfo where overlaps : Overlaps := {} @[expose] def MatcherInfo.numAlts (info : MatcherInfo) : Nat := - info.altNumParams.size + info.altInfos.size def MatcherInfo.arity (info : MatcherInfo) : Nat := info.numParams + 1 + info.numDiscrs + info.numAlts @@ -83,6 +95,11 @@ def getNumEqsFromDiscrInfos (infos : Array DiscrInfo) : Nat := Id.run do def MatcherInfo.getNumDiscrEqs (info : MatcherInfo) : Nat := getNumEqsFromDiscrInfos info.discrInfos +def MatcherInfo.altNumParams (info : MatcherInfo) : Array Nat := + info.altInfos.map fun {numFields, numOverlaps, hasUnitThunk} => + numFields + numOverlaps + (if hasUnitThunk then 1 else 0) + info.getNumDiscrEqs + + namespace Extension structure Entry where diff --git a/stage0/src/stdlib_flags.h b/stage0/src/stdlib_flags.h index 79a0e58edd..420ef43adc 100644 --- a/stage0/src/stdlib_flags.h +++ b/stage0/src/stdlib_flags.h @@ -1,5 +1,7 @@ #include "util/options.h" +// please update this + namespace lean { options get_default_options() { options opts;