feat: use realizeConst for all equation, unfold, induction, and partial fixpoint theorems (#7261)

This PR ensures all equation, unfold, induction, and partial fixpoint
theorem generators in core are compatible with parallelism.

Stacked on #7247
This commit is contained in:
Sebastian Ullrich 2025-03-06 16:38:04 +01:00 committed by GitHub
parent 141e519009
commit 24db5b598b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 125 additions and 88 deletions

View file

@ -39,3 +39,4 @@ import Lean.AddDecl
import Lean.Replay
import Lean.PrivateName
import Lean.PremiseSelection
import Lean.Namespace

View file

@ -5,6 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.CoreM
import Lean.Namespace
namespace Lean

View file

@ -64,6 +64,7 @@ Assign final attributes to the definitions. Assumes the EqnInfos to be already p
def addPreDefAttributes (preDefs : Array PreDefinition) : TermElabM Unit := do
for preDef in preDefs do
markAsRecursive preDef.declName
-- must happen before `generateEagerEqns`
enableRealizationsForConst preDef.declName
generateEagerEqns preDef.declName
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation

View file

@ -63,7 +63,9 @@ private def numberNames (n : Nat) (base : String) : Array Name :=
.ofFn (n := n) fun ⟨i, _⟩ =>
if n == 1 then .mkSimple base else .mkSimple s!"{base}_{i+1}"
def deriveInduction (name : Name) : MetaM Unit := do
def deriveInduction (name : Name) : MetaM Unit :=
let inductName := name ++ `fixpoint_induct
realizeConst name inductName do
mapError (f := (m!"Cannot derive fixpoint induction principle (please report this issue)\n{indentD ·}")) do
let some eqnInfo := eqnInfoExt.find? (← getEnv) name |
throwError "{name} is not defined by partial_fixpoint"
@ -156,7 +158,6 @@ def deriveInduction (name : Name) : MetaM Unit := do
-- Prune unused level parameters, preserving the original order
let us := infos[0]!.levelParams.filter (params.contains ·)
let inductName := name ++ `fixpoint_induct
addDecl <| Declaration.thmDecl
{ name := inductName, levelParams := us, type := eTyp, value := e' }
@ -223,6 +224,8 @@ def mkOptionAdm (motive : Expr) : MetaM Expr := do
pure inst
def derivePartialCorrectness (name : Name) : MetaM Unit := do
let inductName := name ++ `partial_correctness
realizeConst name inductName do
let fixpointInductThm := name ++ `fixpoint_induct
unless (← getEnv).contains fixpointInductThm do
deriveInduction name
@ -278,7 +281,6 @@ def derivePartialCorrectness (name : Name) : MetaM Unit := do
-- Prune unused level parameters, preserving the original order
let us := infos[0]!.levelParams.filter (params.contains ·)
let inductName := name ++ `partial_correctness
addDecl <| Declaration.thmDecl
{ name := inductName, levelParams := us, type := eTyp, value := e' }

View file

@ -68,8 +68,10 @@ def wfRecursion (preDefs : Array PreDefinition) (termMeasure?s : Array (Option T
unless preDef.kind.isTheorem do
unless (← isProp preDef.type) do
WF.mkUnfoldEq preDef preDefNonRec.declName wfPreprocessProof
Mutual.addPreDefAttributes preDefs
-- must happen before `addPreDefAttributes` enables realizations for the top-level functions,
-- which may need to use realizations on `preDefNonRec`
enableRealizationsForConst preDefNonRec.declName
Mutual.addPreDefAttributes preDefs
builtin_initialize registerTraceClass `Elab.definition.wf

View file

@ -1485,25 +1485,30 @@ end TagDeclarationExtension
/-- Environment extension for mapping declarations to values.
Declarations must only be inserted into the mapping in the module where they were declared. -/
def MapDeclarationExtension (α : Type) := PersistentEnvExtension (Name × α) (Name × α) (NameMap α)
structure MapDeclarationExtension (α : Type) extends PersistentEnvExtension (Name × α) (Name × α) (NameMap α)
deriving Inhabited
def mkMapDeclarationExtension (name : Name := by exact decl_name%) : IO (MapDeclarationExtension α) :=
registerPersistentEnvExtension {
.mk <$> registerPersistentEnvExtension {
name := name,
mkInitial := pure {}
addImportedFn := fun _ => pure {}
addEntryFn := fun s (n, v) => s.insert n v
exportEntriesFn := fun s => s.toArray
asyncMode := .async
replay? := some fun _ newState newConsts s =>
newConsts.foldl (init := s) fun s c =>
if let some a := newState.find? c then
s.insert c a
else s
}
namespace MapDeclarationExtension
instance : Inhabited (MapDeclarationExtension α) :=
inferInstanceAs (Inhabited (PersistentEnvExtension ..))
def insert (ext : MapDeclarationExtension α) (env : Environment) (declName : Name) (val : α) : Environment :=
have : Inhabited Environment := ⟨env⟩
assert! env.getModuleIdxFor? declName |>.isNone -- See comment at `MapDeclarationExtension`
assert! env.asyncMayContain declName
ext.addEntry env (declName, val)
def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name) : Option α :=
@ -1512,12 +1517,12 @@ def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment)
match (ext.getModuleEntries env modIdx).binSearch (declName, default) (fun a b => Name.quickLt a.1 b.1) with
| some e => some e.2
| none => none
| none => (ext.getState env).find? declName
| none => (ext.findStateAsync env declName).find? declName
def contains [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name) : Bool :=
match env.getModuleIdxFor? declName with
| some modIdx => (ext.getModuleEntries env modIdx).binSearchContains (declName, default) (fun a b => Name.quickLt a.1 b.1)
| none => (ext.getState env).contains declName
| none => (ext.findStateAsync env declName).contains declName
end MapDeclarationExtension
@ -1795,27 +1800,6 @@ unsafe def withImportModules {α : Type} (imports : Array Import) (opts : Option
let env ← importModules imports opts trustLevel
try act env finally env.freeRegions
/--
Environment extension for tracking all `namespace` declared by users.
-/
builtin_initialize namespacesExt : SimplePersistentEnvExtension Name NameSSet ←
registerSimplePersistentEnvExtension {
addImportedFn := fun as =>
/-
We compute a `HashMap Name Unit` and then convert to `NameSSet` to improve Lean startup time.
Note: we have used `perf` to profile Lean startup cost when processing a file containing just `import Lean`.
6.18% of the runtime is here. It was 9.31% before the `HashMap` optimization.
-/
let capacity := as.foldl (init := 0) fun r e => r + e.size
let map : Std.HashMap Name Unit := Std.HashMap.empty capacity
let map := mkStateFromImportedEntries (fun map name => map.insert name ()) map as
SMap.fromHashMap map |>.switch
addEntryFn := fun s n => s.insert n
-- Namespaces from local helper constants can be disregarded in other environment branches. We
-- do *not* want `getNamespaceSet` to have to wait on all prior branches.
asyncMode := .local
}
@[inherit_doc Kernel.Environment.enableDiag]
def Kernel.enableDiag (env : Lean.Environment) (flag : Bool) : Lean.Environment :=
env.modifyCheckedAsync (·.enableDiag flag)
@ -1834,18 +1818,6 @@ def Kernel.setDiagnostics (env : Lean.Environment) (diag : Diagnostics) : Lean.E
namespace Environment
/-- Register a new namespace in the environment. -/
def registerNamespace (env : Environment) (n : Name) : Environment :=
if (namespacesExt.getState env).contains n then env else namespacesExt.addEntry env n
/-- Return `true` if `n` is the name of a namespace in `env`. -/
def isNamespace (env : Environment) (n : Name) : Bool :=
(namespacesExt.getState env).contains n
/-- Return a set containing all namespaces in `env`. -/
def getNamespaceSet (env : Environment) : NameSSet :=
namespacesExt.getState env
@[export lean_elab_environment_update_base_after_kernel_add]
private def updateBaseAfterKernelAdd (env : Environment) (kenv : Kernel.Environment) (decl : Declaration) : Environment :=
{ env with

View file

@ -2302,24 +2302,27 @@ where
IO.FS.withIsolatedStreams (isolateStderr := Core.stderrAsMessages.get opts) do
-- catch all exceptions
let _ : MonadExceptOf _ MetaM := MonadAlwaysExcept.except
try
observing do
realize
if !(← getEnv).contains constName then
throwError "Lean.Meta.realizeConst: {constName} was not added to the environment"
finally
addTraceAsMessages
<* addTraceAsMessages
let res? ← act |>.run' |>.run coreCtx { env } |>.toBaseIO
match res? with
| .ok ((output, ()), st) => pure (st.env, .mk {
| .ok ((output, err?), st) => pure (st.env, .mk {
snap := (← Core.mkSnapshot output coreCtx st)
error? := none
: RealizeConstantResult
})
| .error e => pure (env, .mk {
snap := toSnapshotTree { diagnostics := .empty : Language.SnapshotLeaf}
error? := some e
error? := match err? with
| .ok () => none
| .error e => some e
: RealizeConstantResult
})
| _ =>
let _ : Inhabited (Environment × Dynamic) := ⟨env, .mk {
snap := (← Core.mkSnapshot "" coreCtx { env })
error? := none
: RealizeConstantResult
}⟩
unreachable!
end Meta

View file

@ -417,7 +417,7 @@ def mkHCongrWithArityForConst? (declName : Name) (levels : List Level) (numArgs
executeReservedNameAction thmName
let proof := mkConst thmName levels
let type ← inferType proof
let some argKinds := congrKindsExt.getState (← getEnv) |>.find? thmName
let some argKinds := congrKindsExt.find? (← getEnv) thmName
| unreachable!
return some { proof, type, argKinds }
catch _ =>
@ -434,7 +434,7 @@ def mkCongrSimpForConst? (declName : Name) (levels : List Level) : MetaM (Option
executeReservedNameAction thmName
let proof := mkConst thmName levels
let type ← inferType proof
let some argKinds := congrKindsExt.getState (← getEnv) |>.find? thmName
let some argKinds := congrKindsExt.find? (← getEnv) thmName
| unreachable!
return some { proof, type, argKinds }
catch _ =>

View file

@ -673,10 +673,10 @@ Given a unary definition `foo` defined via `WellFounded.fixF`, derive a suitable
-/
def deriveUnaryInduction (name : Name) : MetaM Name := do
let inductName := getFunInductName name
if ← hasConst inductName then return inductName
realizeConst name inductName (doRealize inductName)
return inductName
where doRealize (inductName : Name) := do
let info ← getConstInfoDefn name
let varNames ← forallTelescope info.type fun xs _ => xs.mapM (·.fvarId!.getUserName)
-- Uses of WellFounded.fix can be partially applied. Here we eta-expand the body
@ -762,19 +762,17 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do
levelMask := usMask
params := paramMask.map (cond · .param .dropped) ++ #[.target]
}
return inductName
/--
Given `foo.mutual_induct`, defined `foo.induct`, `bar.induct` etc.
Given a realizer for `foo.mutual_induct`, defines `foo.induct`, `bar.induct` etc.
Used for well-founded and structural recursion.
-/
def projectMutualInduct (names : Array Name) (mutualInduct : Name) : MetaM Unit := do
let ci ← getConstInfo mutualInduct
let levelParams := ci.levelParams
def projectMutualInduct (names : Array Name) (mutualInduct : MetaM Name) (finalizeFirstInd : MetaM Unit) : MetaM Unit := do
for name in names, idx in [:names.size] do
let inductName := getFunInductName name
unless ← hasConst inductName do
realizeConst name inductName do
let ci ← getConstInfo (← mutualInduct)
let levelParams := ci.levelParams
let value ← forallTelescope ci.type fun xs _body => do
let value := .const ci.name (levelParams.map mkLevelParam)
let value := mkAppN value xs
@ -783,6 +781,8 @@ def projectMutualInduct (names : Array Name) (mutualInduct : Name) : MetaM Unit
let type ← inferType value
addDecl <| Declaration.thmDecl { name := inductName, levelParams, type, value }
if idx == 0 then finalizeFirstInd
/--
For a (non-mutual!) definition of `name`, uses the `FunIndInfo` associated with the `unaryInduct` and
derives the one for the n-ary function.
@ -864,17 +864,19 @@ def cleanPackedArgs (eqnInfo : WF.EqnInfo) (value : Expr) : MetaM Expr := do
mkExpectedTypeHint value cleanType
/--
Takes `foo._unary.induct`, where the motive is a `PSigma`/`PSum` type and
Retrieves `foo._unary.induct`, where the motive is a `PSigma`/`PSum` type, and
unpacks it into a n-ary and (possibly) joint induction principle.
-/
def unpackMutualInduction (eqnInfo : WF.EqnInfo) (unaryInductName : Name) : MetaM Name := do
def unpackMutualInduction (eqnInfo : WF.EqnInfo) : MetaM Name := do
let inductName := if eqnInfo.declNames.size > 1 then
getMutualInductName eqnInfo.declNames[0]!
else
-- If there is no mutual recursion, we generate the `foo.induct` directly.
getFunInductName eqnInfo.declNames[0]!
if ← hasConst inductName then return inductName
realizeConst eqnInfo.declNames[0]! inductName (doRealize inductName)
return inductName
where doRealize inductName := do
let unaryInductName ← deriveUnaryInduction eqnInfo.declNameNonRec
let ci ← getConstInfo unaryInductName
let us := ci.levelParams
let value := .const ci.name (us.map mkLevelParam)
@ -910,8 +912,9 @@ def unpackMutualInduction (eqnInfo : WF.EqnInfo) (unaryInductName : Name) : Meta
addDecl <| Declaration.thmDecl
{ name := inductName, levelParams := ci.levelParams, type, value }
return inductName
if eqnInfo.argsPacker.numFuncs = 1 then
setNaryFunIndInfo eqnInfo.fixedParamPerms eqnInfo.declNames[0]! unaryInductName
def withLetDecls {α} (name : Name) (ts : Array Expr) (es : Array Expr) (k : Array Expr → MetaM α) : MetaM α := do
assert! es.size = ts.size
@ -929,7 +932,15 @@ Given a recursive definition `foo` defined via structural recursion, derive `foo
if needed, and `foo.induct` for all functions in the group.
See module doc for details.
-/
def deriveInductionStructural (names : Array Name) (fixedParamPerms : FixedParamPerms) : MetaM Unit := do
def deriveInductionStructural (names : Array Name) (fixedParamPerms : FixedParamPerms) : MetaM Name := do
let inductName :=
if names.size = 1 then
getFunInductName names[0]!
else
getMutualInductName names[0]!
realizeConst names[0]! inductName (doRealize inductName)
return inductName
where doRealize inductName := do
let infos ← names.mapM getConstInfoDefn
assert! infos.size > 0
-- First open up the fixed parameters everywhere
@ -1118,19 +1129,9 @@ def deriveInductionStructural (names : Array Name) (fixedParamPerms : FixedParam
let usMask := funUs.map (levelParams.contains ·)
let us := maskArray usMask funUs |>.toList
let inductName :=
if names.size = 1 then
getFunInductName names[0]!
else
getMutualInductName names[0]!
addDecl <| Declaration.thmDecl
{ name := inductName, levelParams := us, type := eTyp, value := e' }
if names.size > 1 then
projectMutualInduct names inductName
if names.size = 1 then
setFunIndInfo {
funIndName := inductName
@ -1153,6 +1154,8 @@ targets that are unchanged in each case, so simplify applying the lemma when the
are not variables, to avoid having to generalize them.
-/
def deriveCases (name : Name) : MetaM Unit := do
let casesName := getFunCasesName name
realizeConst name casesName do
mapError (f := (m!"Cannot derive functional cases principle (please report this issue)\n{indentD ·}")) do
let info ← getConstInfo name
let value ←
@ -1197,7 +1200,6 @@ def deriveCases (name : Name) : MetaM Unit := do
let usMask := funUs.map (levelParams.contains ·)
let us := maskArray usMask funUs |>.toList
let casesName := getFunCasesName info.name
addDecl <| Declaration.thmDecl
{ name := casesName, levelParams := us, type := eTyp, value := e' }
@ -1215,12 +1217,20 @@ def deriveInduction (name : Name) : MetaM Unit := do
mapError (f := (m!"Cannot derive functional induction principle (please report this issue)\n{indentD ·}")) do
if let some eqnInfo := WF.eqnInfoExt.find? (← getEnv) name then
let unaryInductName ← deriveUnaryInduction eqnInfo.declNameNonRec
let unpackedInductName ← unpackMutualInduction eqnInfo unaryInductName
projectMutualInduct eqnInfo.declNames unpackedInductName
if eqnInfo.argsPacker.numFuncs = 1 then
setNaryFunIndInfo eqnInfo.fixedParamPerms eqnInfo.declNames[0]! unaryInductName
if eqnInfo.declNames.size > 1 then
projectMutualInduct eqnInfo.declNames (unpackMutualInduction eqnInfo) do
-- We set the FunIndInfo on the first induction principle, which must happen inside its
-- realization.
if eqnInfo.argsPacker.numFuncs = 1 then
setNaryFunIndInfo eqnInfo.fixedParamPerms eqnInfo.declNames[0]! unaryInductName
else
-- (in this case, `unpackMutualInduction` already does `setNaryFunIndInfo`)
let _ ← unpackMutualInduction eqnInfo
else if let some eqnInfo := Structural.eqnInfoExt.find? (← getEnv) name then
deriveInductionStructural eqnInfo.declNames eqnInfo.fixedParamPerms
if eqnInfo.declNames.size > 1 then
projectMutualInduct eqnInfo.declNames (deriveInductionStructural eqnInfo.declNames eqnInfo.fixedParamPerms) (pure ())
else
let _ ← deriveInductionStructural eqnInfo.declNames eqnInfo.fixedParamPerms
else
throwError "constant '{name}' is not structurally or well-founded recursive"

44
src/Lean/Namespace.lean Normal file
View file

@ -0,0 +1,44 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Environment
namespace Lean
/--
Environment extension for tracking all `namespace` declared by users.
-/
builtin_initialize namespacesExt : SimplePersistentEnvExtension Name NameSSet ←
registerSimplePersistentEnvExtension {
addImportedFn := fun as =>
/-
We compute a `HashMap Name Unit` and then convert to `NameSSet` to improve Lean startup time.
Note: we have used `perf` to profile Lean startup cost when processing a file containing just `import Lean`.
6.18% of the runtime is here. It was 9.31% before the `HashMap` optimization.
-/
let capacity := as.foldl (init := 0) fun r e => r + e.size
let map : Std.HashMap Name Unit := Std.HashMap.empty capacity
let map := mkStateFromImportedEntries (fun map name => map.insert name ()) map as
SMap.fromHashMap map |>.switch
addEntryFn := fun s n => s.insert n
-- Namespaces from local helper constants can be disregarded in other environment branches. We
-- do *not* want `getNamespaceSet` to have to wait on all prior branches.
asyncMode := .local
}
namespace Environment
/-- Register a new namespace in the environment. -/
def registerNamespace (env : Environment) (n : Name) : Environment :=
if (namespacesExt.getState env).contains n then env else namespacesExt.addEntry env n
/-- Return `true` if `n` is the name of a namespace in `env`. -/
def isNamespace (env : Environment) (n : Name) : Bool :=
(namespacesExt.getState env).contains n
/-- Return a set containing all namespaces in `env`. -/
def getNamespaceSet (env : Environment) : NameSSet :=
namespacesExt.getState env

View file

@ -8,6 +8,7 @@ import Lean.Data.OpenDecl
import Lean.Hygiene
import Lean.Modifiers
import Lean.Exception
import Lean.Namespace
namespace Lean
/-!