perf: clarify and granularize access to async env ext state (#9587)

* Have asynchronous environment extensions specify whether they are
manipulate data for declarations from the "outside"/main branch (e.g.
attributes) or from the "inside"/async branch (e.g. data collected from
body elaboration) in order to avoid unnecessary waiting.
* Merge `findStateAsync?` into `getState` via a new, optional
`asyncDecl` parameter.
* Make `mayContainAsync` check an automatic part of `modifyState`.
This commit is contained in:
Sebastian Ullrich 2025-08-02 19:01:08 +02:00 committed by GitHub
parent df9ca20339
commit 1e83f62d31
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 227 additions and 185 deletions

View file

@ -50,7 +50,8 @@ where go env
| _ => env
private builtin_initialize privateConstKindsExt : MapDeclarationExtension ConstantKind ←
mkMapDeclarationExtension
-- Use `sync` so we can add entries from anywhere without restrictions
mkMapDeclarationExtension (asyncMode := .sync)
/--
Returns the kind of the declaration as originally declared instead of as exported. This information
@ -58,7 +59,9 @@ is stored by `Lean.addDecl` and may be inaccurate if that function was circumven
if the declaration was not found.
-/
def getOriginalConstKind? (env : Environment) (declName : Name) : Option ConstantKind := do
privateConstKindsExt.find? env declName <|>
-- Use `local` as for asynchronous decls from the current module, `findAsync?` below will yield
-- the same result but potentially earlier (after `addConstAsync` instead of `addDecl`)
privateConstKindsExt.find? (asyncMode := .local) env declName <|>
(env.setExporting false |>.findAsync? declName).map (·.kind)
/--

View file

@ -184,10 +184,10 @@ def registerTagAttribute (name : Name) (descr : String)
let env ← getEnv
unless (env.getModuleIdxFor? decl).isNone do
throwAttrDeclInImportedModule name decl
unless env.asyncMayContain decl do
unless ext.toEnvExtension.asyncMayModify env decl do
throwAttrNotInAsyncCtx name decl env.asyncPrefix?
validate decl
modifyEnv fun env => ext.addEntry env decl
modifyEnv fun env => ext.addEntry (asyncDecl := decl) env decl
}
registerBuiltinAttribute attrImpl
return { attr := attrImpl, ext := ext }
@ -199,22 +199,14 @@ def setTag [Monad m] [MonadError m] [MonadEnv m] (attr : TagAttribute) (decl :
let env ← getEnv
unless (env.getModuleIdxFor? decl).isNone do
throwAttrDeclInImportedModule attr.attr.name decl
unless env.asyncMayContain decl do
unless attr.ext.toEnvExtension.asyncMayModify env decl do
throwAttrNotInAsyncCtx attr.attr.name decl env.asyncPrefix?
modifyEnv fun env => attr.ext.addEntry env decl
modifyEnv fun env => attr.ext.addEntry (asyncDecl := decl) env decl
def hasTag (attr : TagAttribute) (env : Environment) (decl : Name) : Bool :=
match env.getModuleIdxFor? decl with
| some modIdx => (attr.ext.getModuleEntries env modIdx).binSearchContains decl Name.quickLt
| none =>
if attr.ext.toEnvExtension.asyncMode matches .async then
-- It seems that the env extension API doesn't quite allow querying attributes in a way
-- that works for realizable constants, but without waiting on proofs to finish.
-- Until then, we use the following overapproximation, to be refined later:
(attr.ext.findStateAsync env decl).contains decl ||
(attr.ext.getState env (asyncMode := .local)).contains decl
else
(attr.ext.getState env).contains decl
| none => (attr.ext.getState (asyncDecl := decl) env).contains decl
end TagAttribute
@ -253,7 +245,7 @@ def registerParametricAttribute (impl : ParametricAttributeImpl α) : IO (Parame
unless (env.getModuleIdxFor? decl).isNone do
throwAttrDeclInImportedModule impl.name decl
let val ← impl.getParam decl stx
modifyEnv fun env => ext.addEntry env (decl, val)
modifyEnv fun env => ext.addEntry (asyncDecl := decl) env (decl, val)
try impl.afterSet decl val catch _ => setEnv env
}
registerBuiltinAttribute attrImpl
@ -301,9 +293,9 @@ def registerEnumAttributes (attrDescrs : List (Name × String × α))
let r : Array (Name × α) := m.foldl (fun a n p => a.push (n, p)) #[]
r.qsort (fun a b => Name.quickLt a.1 b.1)
statsFn := fun s => "enumeration attribute extension" ++ Format.line ++ "number of local entries: " ++ format s.size
-- We assume (and check below) that, if used asynchronously, enum attributes are set only in the
-- same context in which the tagged declaration was created
asyncMode := .async
-- We assume (and check in `modifyState`) that, if used asynchronously, enum attributes are set
-- only in the same context in which the tagged declaration was created
asyncMode := .async .mainEnv
replay? := some fun _ newState consts st => consts.foldl (init := st) fun st c =>
match newState.find? c with
| some v => st.insert c v
@ -320,7 +312,7 @@ def registerEnumAttributes (attrDescrs : List (Name × String × α))
unless (env.getModuleIdxFor? decl).isNone do
throwAttrDeclInImportedModule name decl
validate decl val
modifyEnv fun env => ext.addEntry env (decl, val)
modifyEnv fun env => ext.addEntry (asyncDecl := decl) env (decl, val)
applicationTime := applicationTime
: AttributeImpl
}
@ -335,17 +327,17 @@ def getValue [Inhabited α] (attr : EnumAttributes α) (env : Environment) (decl
match (attr.ext.getModuleEntries env modIdx).binSearch (decl, default) (fun a b => Name.quickLt a.1 b.1) with
| some (_, val) => some val
| none => none
| none => (attr.ext.findStateAsync env decl).find? decl
| none => (attr.ext.getState (asyncDecl := decl) env).find? decl
def setValue (attrs : EnumAttributes α) (env : Environment) (decl : Name) (val : α) : Except String Environment := do
let pfx := s!"Internal error calling `{attrs.ext.name}.setValue` for `{decl}`"
if (env.getModuleIdxFor? decl).isSome then
throw s!"{pfx}: Declaration is in an imported module"
if !env.asyncMayContain decl then
unless attrs.ext.toEnvExtension.asyncMayModify env decl do
throw s!"{pfx}: Declaration is not from this async context `{env.asyncPrefix?}`"
if ((attrs.ext.findStateAsync env decl).find? decl).isSome then
if ((attrs.ext.getState (asyncDecl := decl) env).find? decl).isSome then
throw s!"{pfx}: Attribute has already been set"
return attrs.ext.addEntry env (decl, val)
return attrs.ext.addEntry (asyncDecl := decl) env (decl, val)
end EnumAttributes

View file

@ -12,7 +12,9 @@ public section
namespace Lean
builtin_initialize metaExt : TagDeclarationExtension ← mkTagDeclarationExtension (asyncMode := .async)
builtin_initialize metaExt : TagDeclarationExtension ←
-- set by `addPreDefinitions`
mkTagDeclarationExtension (asyncMode := .async .asyncEnv)
/-- Marks in the environment extension that the given declaration has been declared by the user as `meta`. -/
def addMeta (env : Environment) (declName : Name) : Environment :=

View file

@ -28,6 +28,9 @@ def addBuiltinDeclarationRanges (declName : Name) (declRanges : DeclarationRange
builtinDeclRanges.modify (·.insert declName declRanges)
def addDeclarationRanges [Monad m] [MonadEnv m] (declName : Name) (declRanges : DeclarationRanges) : m Unit := do
if declName.isAnonymous then
-- This can happen on elaboration of partial syntax and would panic in `modifyState` otherwise
return
modifyEnv fun env => declRangeExt.insert env declName declRanges
def findDeclarationRangesCore? [Monad m] [MonadEnv m] (declName : Name) : m (Option DeclarationRanges) :=

View file

@ -68,7 +68,7 @@ need to be unfolded to prove the theorem are exported and exposed.
builtin_initialize defeqAttr : TagAttribute ←
registerTagAttribute `defeq "mark theorem as a definitional equality, to be used by `dsimp`"
(validate := validateDefEqAttr) (applicationTime := .afterTypeChecking)
(asyncMode := .async)
(asyncMode := .async .mainEnv)
private partial def isRflProofCore (type : Expr) (proof : Expr) : CoreM Bool := do
match type with

View file

@ -71,8 +71,8 @@ def getEntries {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension
/-- Get the current state of the given `SimplePersistentEnvExtension`. -/
def getState {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment)
(asyncMode := ext.toEnvExtension.asyncMode) : σ :=
(PersistentEnvExtension.getState (asyncMode := asyncMode) ext env).2
(asyncMode := ext.toEnvExtension.asyncMode) (asyncDecl : Name := .anonymous) : σ :=
(PersistentEnvExtension.getState (asyncMode := asyncMode) (asyncDecl := asyncDecl) ext env).2
/-- Set the current state of the given `SimplePersistentEnvExtension`. This change is *not* persisted across files. -/
def setState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) (s : σ) : Environment :=
@ -82,11 +82,6 @@ def setState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : En
def modifyState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) (f : σσ) : Environment :=
PersistentEnvExtension.modifyState ext env (fun ⟨entries, s⟩ => (entries, f s))
@[inherit_doc PersistentEnvExtension.findStateAsync]
def findStateAsync {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ)
(env : Environment) (declPrefix : Name) : σ :=
PersistentEnvExtension.findStateAsync ext env declPrefix |>.2
end SimplePersistentEnvExtension
/-- Environment extension for tagging declarations.
@ -111,16 +106,13 @@ instance : Inhabited TagDeclarationExtension :=
def tag (ext : TagDeclarationExtension) (env : Environment) (declName : Name) : Environment :=
have : Inhabited Environment := ⟨env⟩
assert! env.getModuleIdxFor? declName |>.isNone -- See comment at `TagDeclarationExtension`
assert! env.asyncMayContain declName
ext.addEntry env declName
ext.addEntry (asyncDecl := declName) env declName
def isTagged (ext : TagDeclarationExtension) (env : Environment) (declName : Name) : Bool :=
def isTagged (ext : TagDeclarationExtension) (env : Environment) (declName : Name)
(asyncMode := ext.toEnvExtension.asyncMode) : Bool :=
match env.getModuleIdxFor? declName with
| some modIdx => (ext.getModuleEntries env modIdx).binSearchContains declName Name.quickLt
| none => if ext.toEnvExtension.asyncMode matches .async then
(ext.findStateAsync env declName).contains declName
else
(ext.getState env).contains declName
| none => (ext.getState (asyncMode := asyncMode) (asyncDecl := declName) env).contains declName
end TagDeclarationExtension
@ -131,6 +123,7 @@ structure MapDeclarationExtension (α : Type) extends PersistentEnvExtension (Na
deriving Inhabited
def mkMapDeclarationExtension (name : Name := by exact decl_name%)
(asyncMode : EnvExtension.AsyncMode := .async .mainEnv)
(exportEntriesFn : Environment → NameMap α → OLeanLevel → Array (Name × α) :=
fun _ s _ => s.toArray) :
IO (MapDeclarationExtension α) :=
@ -140,7 +133,7 @@ def mkMapDeclarationExtension (name : Name := by exact decl_name%)
addImportedFn := fun _ => pure {}
addEntryFn := fun s (n, v) => s.insert n v
exportEntriesFnEx env s level := exportEntriesFn env s level
asyncMode := .async
asyncMode
replay? := some fun _ newState newConsts s =>
newConsts.foldl (init := s) fun s c =>
if let some a := newState.find? c then
@ -153,23 +146,20 @@ namespace MapDeclarationExtension
def insert (ext : MapDeclarationExtension α) (env : Environment) (declName : Name) (val : α) : Environment :=
have : Inhabited Environment := ⟨env⟩
assert! env.getModuleIdxFor? declName |>.isNone -- See comment at `MapDeclarationExtension`
if !env.asyncMayContain declName then
panic! s!"MapDeclarationExtension.insert: cannot insert {declName} into {ext.name}, it is not contained in {env.asyncPrefix?}"
else
ext.addEntry env (declName, val)
ext.addEntry (asyncDecl := declName) env (declName, val)
def find? [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name)
(level := OLeanLevel.exported) : Option α :=
(asyncMode := ext.toEnvExtension.asyncMode) (level := OLeanLevel.exported) : Option α :=
match env.getModuleIdxFor? declName with
| some modIdx =>
match (ext.getModuleEntries (level := level) env modIdx).binSearch (declName, default) (fun a b => Name.quickLt a.1 b.1) with
| some e => some e.2
| none => none
| none => (ext.findStateAsync env declName).find? declName
| none => (ext.getState (asyncMode := asyncMode) (asyncDecl := declName) env).find? declName
def contains [Inhabited α] (ext : MapDeclarationExtension α) (env : Environment) (declName : Name) : Bool :=
match env.getModuleIdxFor? declName with
| some modIdx => (ext.getModuleEntries env modIdx).binSearchContains (declName, default) (fun a b => Name.quickLt a.1 b.1)
| none => (ext.findStateAsync env declName).contains declName
| none => (ext.getState (asyncDecl := declName) env).contains declName
end MapDeclarationExtension

View file

@ -434,17 +434,19 @@ private def AsyncContext.mayContain (ctx : AsyncContext) (n : Name) : Bool :=
Constant info and environment extension states eventually resulting from async elaboration.
-/
private structure AsyncConst where
constInfo : AsyncConstantInfo
constInfo : AsyncConstantInfo
/--
Reported extension state eventually fulfilled by promise; may be missing for tasks (e.g. kernel
checking) that can eagerly guarantee they will not report any state.
-/
exts? : Option (Task (Array EnvExtensionState))
exts? : Option (Task (Array EnvExtensionState))
/--
`Task AsyncConsts` except for problematic recursion. The set of nested constants created while
elaborating this constant.
-/
consts : Task Dynamic
aconstsImpl : Task Dynamic
/-- True if generated by `realizeConst`. -/
isRealized : Bool := false
/-- Data structure holding a sequence of `AsyncConst`s optimized for efficient access. -/
private structure AsyncConsts where
@ -456,6 +458,10 @@ private structure AsyncConsts where
normalizedTrie : NameTrie AsyncConst
deriving Inhabited, TypeName
private def AsyncConst.aconsts (c : AsyncConst) : Task AsyncConsts :=
c.aconstsImpl.map (sync := true) fun dyn =>
dyn.get? AsyncConsts |>.getD default
private def AsyncConsts.add (aconsts : AsyncConsts) (aconst : AsyncConst) : AsyncConsts :=
let normalizedName := privateToUserName aconst.constInfo.name
if let some aconst' := aconsts.normalizedTrie.find? normalizedName then
@ -488,7 +494,7 @@ private partial def AsyncConsts.findRec? (aconsts : AsyncConsts) (declName : Nam
-- If privacy is the only difference between `declName` and `findPrefix?` result, we can assume
-- `declName` does not exist according to the `add` invariant
guard <| privateToUserName c.constInfo.name != privateToUserName declName
let aconsts ← c.consts.get.get? AsyncConsts
let aconsts ← c.aconsts.get
AsyncConsts.findRec? aconsts declName
/-- Like `findRec?`; allocating tasks is (currently?) too costly to do always. -/
@ -496,10 +502,21 @@ private partial def AsyncConsts.findRecTask (aconsts : AsyncConsts) (declName :
let some c := aconsts.findPrefix? declName | .pure none
if c.constInfo.name == declName then
return .pure c
c.consts.bind (sync := true) fun aconsts => Id.run do
let some aconsts := aconsts.get? AsyncConsts | .pure none
c.aconsts.bind (sync := true) fun aconsts => Id.run do
AsyncConsts.findRecTask aconsts declName
/-- Like `findRec?` but also returns the constant that has `declName` in its `consts`, if any. -/
private partial def AsyncConsts.findRecAndParent? (aconsts : AsyncConsts) (declName : Name) : Option (AsyncConst × Option AsyncConst) :=
go none aconsts
where go parent? aconsts := do
let c ← aconsts.findPrefix? declName
if c.constInfo.name == declName then
return (c, parent?)
-- If privacy is the only difference between `declName` and `findPrefix?` result, we can assume
-- `declName` does not exist according to the `add` invariant
guard <| privateToUserName c.constInfo.name != privateToUserName declName
go (some c) c.aconsts.get
/-- Accessibility levels of declarations in `Lean.Environment`. -/
private inductive Visibility where
/-- Information private to the module. -/
@ -593,7 +610,7 @@ structure Environment where
/--
Task collecting all realizations from the current and already-forked environment branches, akin to
how `checked` collects all declarations. We only use it as a fallback in
`findAsyncCore?`/`findStateAsync`; see there.
`findAsyncCore?`/`getState`; see there.
-/
private allRealizations : Task (NameMap AsyncConst) := .pure {}
/--
@ -674,18 +691,6 @@ def importEnv? (env : Environment) : Option Environment :=
def unlockAsync (env : Environment) : Environment :=
{ env with asyncCtx? := none }
/--
Checks whether the given declaration name may potentially added, or have been added, to the current
environment branch, which is the case either if this is the main branch or if the declaration name
is a suffix (modulo privacy and hygiene information) of the top-level declaration name for which
this branch was created.
This function should always be checked before modifying an `AsyncMode.async` environment extension
to ensure `findStateAsync` will be able to find the modification from other branches.
-/
def asyncMayContain (env : Environment) (declName : Name) : Bool :=
env.asyncCtx?.all (·.mayContain declName)
@[extern "lean_elab_add_decl"]
private opaque addDeclCheck (env : Environment) (maxHeartbeats : USize) (decl : @& Declaration)
(cancelTk? : @& Option IO.CancelToken) : Except Kernel.Exception Environment
@ -720,14 +725,14 @@ def addDeclCore (env : Environment) (maxHeartbeats : USize) (decl : @& Declarati
env := { env with asyncConstsMap.private := env.asyncConstsMap.private.add {
constInfo := .ofConstantInfo info
exts? := none
consts := .pure <| .mk (α := AsyncConsts) default
aconstsImpl := .pure <| .mk (α := AsyncConsts) default
} }
-- TODO
if true /- !isPrivateName n-/ then
env := { env with asyncConstsMap.public := env.asyncConstsMap.public.add {
constInfo := .ofConstantInfo info
exts? := none
consts := .pure <| .mk (α := AsyncConsts) default
aconstsImpl := .pure <| .mk (α := AsyncConsts) default
} }
return env
@ -749,7 +754,7 @@ private def lakeAdd (env : Environment) (cinfo : ConstantInfo) : Environment :=
asyncConstsMap := env.asyncConstsMap.map (·.add {
constInfo := .ofConstantInfo cinfo
exts? := none
consts := .pure <| .mk (α := AsyncConsts) default
aconstsImpl := .pure <| .mk (α := AsyncConsts) default
})
}
@ -757,23 +762,27 @@ private def lakeAdd (env : Environment) (cinfo : ConstantInfo) : Environment :=
@[extern "lean_is_reserved_name"]
private opaque isReservedName (env : Environment) (name : Name) : Bool
/-- `findAsync?` after `base` access -/
private def findAsyncCore? (env : Environment) (n : Name) (skipRealize := false) :
Option AsyncConstantInfo := do
@[inline] private def findAsyncConst? (env : Environment) (n : Name) (skipRealize := false) :
Option AsyncConst := do
if let some c := env.asyncConsts.find? n then
-- Constant for which an asynchronous elaboration task was spawned
-- (this is an optimized special case of the next branch)
return c.constInfo
return c
if let some c := env.asyncConsts.findRec? n then
-- Constant generated in a different environment branch
return c.constInfo
return c
if !skipRealize && isReservedName env n then
if let some c := env.allRealizations.get.find? n then
return c.constInfo
return c
-- Not in the kernel environment nor in the name prefix of a known environment branch: undefined
-- by `addDeclCore` invariant.
none
/-- `findAsync?` after `base` access -/
private def findAsyncCore? (env : Environment) (n : Name) (skipRealize := false) :
Option AsyncConstantInfo := do
env.findAsyncConst? n (skipRealize := skipRealize) |>.map (·.constInfo)
/-- Like `findAsyncCore?`; allocating tasks is (currently?) too costly to do always. -/
private def findTaskCore (env : Environment) (n : Name) (skipRealize := false) :
Task (Option AsyncConstantInfo) := Id.run do
@ -1032,7 +1041,7 @@ def addConstAsync (env : Environment) (constName : Name) (kind : ConstantKind)
| some v => v.exts
-- any value should work here, `base` does not block
| none => env.base.private.extensions)
consts := constPromise.result?.map (sync := true) fun
aconstsImpl := constPromise.result?.map (sync := true) fun
| some v => .mk v.nestedConsts.private
| none => .mk (α := AsyncConsts) default
}
@ -1043,7 +1052,7 @@ def addConstAsync (env : Environment) (constName : Name) (kind : ConstantKind)
| some c => c.exportedConstInfo
| none => mkFallbackConstInfo constName exportedKind
}
consts := constPromise.result?.map (sync := true) fun
aconstsImpl := constPromise.result?.map (sync := true) fun
| some v => .mk v.nestedConsts.public
| none => .mk (α := AsyncConsts) default
}
@ -1190,6 +1199,23 @@ def instantiateValueLevelParams! (c : ConstantInfo) (ls : List Level) : Expr :=
end ConstantInfo
/--
Branch specification for asynchronous environment extension access.
Note: For declarations not created via `addConstAsync`, including those created via `realizeConst`,
the two specifiers are equivalent.
-/
inductive AsyncBranch where
/--
The main branch that initiated adding a declaration, i.e. `AddConstAsyncResult.mainEnv`.
This is the more common case and true for e.g. all accesses from attributes.
-/
| mainEnv
/-- The async branch that finished adding a declaration, i.e. `AddConstAsyncResult.asyncEnv`. -/
| asyncEnv
deriving BEq
/--
Async access mode for environment extensions used in `EnvExtension.get/set/modifyState`.
When modified in concurrent contexts, extensions may need to switch to a different mode than the
@ -1198,8 +1224,9 @@ registration time but can be overridden when calling the mentioned functions in
for specific accesses.
In all modes, the state stored into the `.olean` file for persistent environment extensions is the
result of `getState` called on the main environment branch at the end of the file, i.e. it
encompasses all modifications for all modes but `local`.
result of `getState (asyncMode := .sync)` called on the main environment branch at the end of the
file, i.e. it encompasses all modifications on all branches except for `local` modifications for
which only the main branch is included.
-/
inductive EnvExtension.AsyncMode where
/--
@ -1235,22 +1262,20 @@ inductive EnvExtension.AsyncMode where
-/
| mainOnly
/--
Accumulates modifications in the `checked` environment like `sync`, but `getState` will panic
instead of blocking. Instead `findStateAsync` should be used, which will access the state of the
environment branch corresponding to the passed declaration name, if any, or otherwise the state
of the current branch. In other words, at most one environment branch will be blocked on instead
of all prior branches. The local state can still be accessed by calling `getState` with mode
`local` explicitly.
Accumulates modifications in the `checked` environment like `sync`, but `get/modify/setState` will
panic instead of blocking unless their `asyncDecl` parameter is specified, which will access the
state of the environment branch corresponding to the passed declaration name, if any; see
`AsyncBranch` for a description of the specific state accessed. In other words, at most the
environment branch corresponding to that declaration will be blocked on instead of all prior
branches. The local state can still be accessed by calling `getState` with mode `local`
explicitly.
This mode is suitable for extensions with map-like state where the key uniquely identifies the
top-level declaration where it could have been set, e.g. because the key on modification is always
the surrounding declaration's name. Any calls to `modifyState`/`setState` should assert
`asyncMayContain` with that key to ensure state is never accidentally stored in a branch where it
cannot be found by `findStateAsync`. In particular, this mode is closest to how the environment's
own constant map works which asserts the same predicate on modification and provides `findAsync?`
for block-avoiding access.
the surrounding declaration's name. In particular, this mode is closest to how the environment's
own constant map works which provides `findAsync?` for block-avoiding access.
-/
| async
| async (branch : AsyncBranch)
deriving Inhabited
abbrev ReplayFn (σ : Type) :=
@ -1327,6 +1352,24 @@ def mkInitialExtStates : IO (Array EnvExtensionState) := do
let exts ← envExtensionsRef.get
exts.mapM fun ext => ext.mkInitial
/--
Checks whether `modifyState (asyncDecl := declName)` may be called on an async environment
extension; see `AsyncMode.async` for details.
-/
def asyncMayModify (ext : EnvExtension σ) (env : Environment) (asyncDecl : Name)
(asyncMode := ext.asyncMode) : Bool :=
env.asyncCtx?.all fun ctx =>
match asyncMode with
-- The main env's async context, if any, should be a strict prefix of `asyncDecl`. This does not
-- conclusively check that we are not in some parent branch of `mainEnv` but it covers the most
-- common case of confusing `mainEnv` and `asyncEnv`.
| .async .mainEnv => ctx.mayContain asyncDecl && ctx.declPrefix != asyncDecl
-- The async env's async context should either be `asyncDecl` itself or `asyncDecl` is a nested
-- declaration that is not itself async.
| .async .asyncEnv => ctx.declPrefix == asyncDecl ||
(ctx.mayContain asyncDecl && (env.findAsyncConst? asyncDecl).any (·.exts?.isNone))
| _ => true
/--
Applies the given function to the extension state. See `AsyncMode` for details on how modifications
from different environment branches are reconciled.
@ -1335,7 +1378,7 @@ Note that in modes `sync` and `async`, `f` will be called twice, on the local an
state.
-/
def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σσ)
(asyncMode := ext.asyncMode) : Environment := Id.run do
(asyncMode := ext.asyncMode) (asyncDecl : Name := .anonymous) : Environment := Id.run do
-- for panics
let _ : Inhabited Environment := ⟨env⟩
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
@ -1348,6 +1391,14 @@ def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σ
| .local =>
return { env with base.private.extensions := unsafe ext.modifyStateImpl env.base.private.extensions f }
| _ =>
if asyncMode matches .async _ then
if asyncDecl.isAnonymous then
return panic! "called on `async` extension, must set `asyncDecl` in that case"
if let some ctx := env.asyncCtx? then
if !ext.asyncMayModify (asyncMode := asyncMode) env asyncDecl then
return panic! s!"`asyncDecl` `{asyncDecl}` is outside current context {ctx.declPrefix}"
if ext.replay?.isNone then
if let some (n :: _) := env.asyncCtx?.map (·.realizingStack) then
return panic! s!"environment extension must set `replay?` field to be \
@ -1359,72 +1410,74 @@ def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σ
Sets the extension state to the given value. See `AsyncMode` for details on how modifications from
different environment branches are reconciled.
-/
def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment :=
inline <| modifyState ext env fun _ => s
def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) (asyncMode := ext.asyncMode) : Environment :=
inline <| modifyState (asyncMode := asyncMode) ext env fun _ => s
-- `unsafe` fails to infer `Nonempty` here
private unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ)
(env : Environment) (asyncMode := ext.asyncMode) : σ :=
(env : Environment) (asyncMode := ext.asyncMode) (asyncDecl : Name := .anonymous) : σ := Id.run do
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
match asyncMode with
| .sync => ext.getStateImpl env.checked.get.extensions
| .async => panic! "called on `async` extension, use `findStateAsync` \
instead or pass `(asyncMode := .local)` to explicitly access local state"
| .sync => ext.getStateImpl env.checked.get.extensions
| .async branch =>
if asyncDecl.isAnonymous then
panic! "called on `async` extension, must set `asyncDecl` \
or pass `(asyncMode := .local)` to explicitly access local state"
-- analogous structure to `findAsync?`; see there
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
if env.base.get env |>.constants.contains asyncDecl then
return ext.getStateImpl env.base.private.extensions
-- specialization of the following branch, nested async decls are rare
if let some c := env.asyncConsts.find? asyncDecl then
match branch with
| .asyncEnv =>
if let some exts := c.exts? then
return ext.getStateImpl exts.get
else
return ext.getStateImpl env.base.private.extensions
| .mainEnv =>
if c.isRealized then
if let some exts := c.exts? then
return ext.getStateImpl exts.get
else
return ext.getStateImpl env.base.private.extensions
if let some (c, parent?) := env.asyncConsts.findRecAndParent? asyncDecl then
-- If `parent?` is `none`, the current branch is the parent
let parentExts? := match parent? with
| some c => c.exts?
| none => some <| .pure env.base.private.extensions
if let some exts := (match branch with
-- If the constant is not async, fall back to parent
| .asyncEnv => c.exts? <|> parentExts?
-- If the constant is realized, parent branch is empty and we should always look at `c`. In
-- this specific case, accessing the latter will in particular not block longer than the
-- former.
| .mainEnv => if c.isRealized then c.exts? else parentExts?) then
return ext.getStateImpl exts.get
-- NOTE: if `exts?` is `none`, we should *not* try the following, more expensive branches that
-- will just come to the same conclusion
else if let some c := env.allRealizations.get.find? asyncDecl then
if let some exts := c.exts? then
return ext.getStateImpl exts.get
-- fallback; we could enforce that `asyncDecl` and its extension state always exist but the
-- upside of doing is unclear and it is not true in e.g. the compiler. One alternative would be
-- to add a `getState?` that does not panic in such cases.
ext.getStateImpl env.base.private.extensions
| _ => ext.getStateImpl env.base.private.extensions
/--
Returns the current extension state. See `AsyncMode` for details on how modifications from
different environment branches are reconciled. Panics if the extension is marked as `async`; see its
documentation for more details. Overriding the extension's default `AsyncMode` is usually not
recommended and should be considered only for important optimizations.
different environment branches are reconciled.
Overriding the extension's default `AsyncMode` is usually not recommended and should be considered
only for important optimizations.
-/
@[implemented_by getStateUnsafe]
opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment)
(asyncMode := ext.asyncMode) : σ
-- `unsafe` fails to infer `Nonempty` here
private unsafe def findStateAsyncUnsafe {σ : Type} [Inhabited σ]
(ext : EnvExtension σ) (env : Environment) (declName : Name) : σ := Id.run do
-- analogous structure to `findAsync?`; see there
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
if env.base.get env |>.constants.contains declName then
return ext.getStateImpl env.base.private.extensions
if let some c := env.asyncConsts.find? declName then
if let some exts := c.exts? then
return ext.getStateImpl exts.get
-- NOTE: if `exts?` is `none`, we should *not* try the following, more expensive branches that
-- will just come to the same conclusion
else if let some exts := findRecExts? none env.asyncConsts declName then
return ext.getStateImpl exts.get
else if let some c := env.allRealizations.get.find? declName then
if let some exts := c.exts? then
return ext.getStateImpl exts.get
-- fallback; we could enforce that `findStateAsync` is only used on existing constants but the
-- upside of doing is unclear
ext.getStateImpl env.base.private.extensions
where
/--
Like `AsyncConsts.findRec?`, but if `AsyncConst.exts?` is `none`, returns the extension state of
the surrounding `AsyncConst` instead, which is where state for synchronously added constants is
stored.
-/
findRecExts? (parent? : Option AsyncConst) (aconsts : AsyncConsts) (declName : Name) :
Option (Task (Array EnvExtensionState)) := do
let c ← aconsts.findPrefix? declName
if c.constInfo.name == declName then
return (← c.exts?.or (parent?.bind (·.exts?)))
let aconsts ← c.consts.get.get? AsyncConsts
findRecExts? c aconsts declName
/--
Returns the final extension state on the environment branch corresponding to the passed declaration
name, if any, or otherwise the state on the current branch. In other words, at most one environment
branch will be blocked on.
-/
@[implemented_by findStateAsyncUnsafe]
opaque findStateAsync {σ : Type} [Inhabited σ] (ext : EnvExtension σ)
(env : Environment) (declName : Name) : σ
(asyncMode := ext.asyncMode) (asyncDecl : Name := .anonymous) : σ
end EnvExtension
@ -1586,15 +1639,16 @@ def getModuleIREntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExte
-- safety: as in `getStateUnsafe`
unsafe (ext.toEnvExtension.getStateImpl env.base.private.irBaseExts).importedEntries[m]!
def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
ext.toEnvExtension.modifyState env fun s =>
def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β)
(asyncMode := ext.toEnvExtension.asyncMode) (asyncDecl : Name := .anonymous) : Environment :=
ext.toEnvExtension.modifyState (asyncMode := asyncMode) (asyncDecl := asyncDecl) env fun s =>
let state := ext.addEntryFn s.state b;
{ s with state := state }
/-- Get the current state of the given extension in the given environment. -/
def getState {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment)
(asyncMode := ext.toEnvExtension.asyncMode) : σ :=
(ext.toEnvExtension.getState (asyncMode := asyncMode) env).state
(asyncMode := ext.toEnvExtension.asyncMode) (asyncDecl : Name := .anonymous) : σ :=
(ext.toEnvExtension.getState (asyncMode := asyncMode) (asyncDecl := asyncDecl) env).state
/-- Set the current state of the given extension in the given environment. -/
def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (s : σ) : Environment :=
@ -1605,12 +1659,6 @@ def modifyState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env :
(asyncMode := ext.toEnvExtension.asyncMode) : Environment :=
ext.toEnvExtension.modifyState (asyncMode := asyncMode) env fun ps => { ps with state := f (ps.state) }
@[inherit_doc EnvExtension.findStateAsync]
def findStateAsync {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ)
(env : Environment) (declPrefix : Name) : σ :=
ext.toEnvExtension.findStateAsync env declPrefix |>.state
end PersistentEnvExtension
builtin_initialize persistentEnvExtensionsRef : IO.Ref (Array (PersistentEnvExtension EnvExtensionEntry EnvExtensionEntry EnvExtensionState)) ← IO.mkRef #[]
@ -1736,7 +1784,7 @@ def mkModuleData (env : Environment) (level : OLeanLevel := .private) : IO Modul
let entries := pExts.map fun pExt => Id.run do
-- get state from `checked` at the end if `async`; it would otherwise panic
let mut asyncMode := pExt.toEnvExtension.asyncMode
if asyncMode matches .async then
if asyncMode matches .async _ then
asyncMode := .sync
let state := pExt.getState (asyncMode := asyncMode) env
(pExt.name, pExt.exportEntriesFn env state level)
@ -1845,13 +1893,13 @@ where
let pExtDescrs ← persistentEnvExtensionsRef.get
if h : i < pExtDescrs.size then
let extDescr := pExtDescrs[i]
-- `local` as `async` does not allow for `getState` but it's all safe here as there is only
-- one environment branch at this point.
let s := extDescr.toEnvExtension.getState (asyncMode := .local) env
-- Use `sync` to avoid `async` checks; there is only one environment branch at this point
-- anyway.
let s := extDescr.toEnvExtension.getState (asyncMode := .sync) env
let prevSize := (← persistentEnvExtensionsRef.get).size
let prevAttrSize ← getNumBuiltinAttributes
let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts }
let mut env := extDescr.toEnvExtension.setState env { s with state := newState }
let mut env := extDescr.toEnvExtension.setState (asyncMode := .sync) env { s with state := newState }
env ← ensureExtensionsArraySize env
if (← persistentEnvExtensionsRef.get).size > prevSize || (← getNumBuiltinAttributes) > prevAttrSize then
-- This branch is executed when `pExtDescrs[i]` is the extension associated with the `init` attribute, and
@ -2247,7 +2295,7 @@ private def updateBaseAfterKernelAdd (env : Environment) (kenv : Kernel.Environm
asyncConsts.add {
constInfo := .ofConstantInfo (kenv.find? n |>.get!)
exts? := none
consts := .pure <| .mk (α := AsyncConsts) default
aconstsImpl := .pure <| .mk (α := AsyncConsts) default
}
else asyncConsts
}
@ -2266,7 +2314,7 @@ def displayStats (env : Environment) : IO Unit := do
IO.println ("extension '" ++ toString extDescr.name ++ "'")
-- get state from `checked` at the end if `async`; it would otherwise panic
let mut asyncMode := extDescr.toEnvExtension.asyncMode
if asyncMode matches .async then
if asyncMode matches .async _ then
asyncMode := .sync
let s := extDescr.toEnvExtension.getState (asyncMode := asyncMode) env
let fmt := extDescr.statsFn s.state
@ -2434,12 +2482,14 @@ def realizeConst (env : Environment) (forConst : Name) (constName : Name)
let numNewPrivateConsts := realizeEnv'.asyncConstsMap.private.size - realizeEnv.asyncConstsMap.private.size
let newPrivateConsts := realizeEnv'.asyncConstsMap.private.revList.take numNewPrivateConsts |>.reverse
let newPrivateConsts := newPrivateConsts.map fun c =>
let c := { c with isRealized := true }
if c.exts?.isNone then
{ c with exts? := some <| .pure realizeEnv'.base.private.extensions }
else c
let numNewPublicConsts := realizeEnv'.asyncConstsMap.public.size - realizeEnv.asyncConstsMap.public.size
let newPublicConsts := realizeEnv'.asyncConstsMap.public.revList.take numNewPublicConsts |>.reverse
let newPublicConsts := newPublicConsts.map fun c =>
let c := { c with isRealized := true }
if c.exts?.isNone then
{ c with exts? := some <| .pure realizeEnv'.base.private.extensions }
else c

View file

@ -2483,9 +2483,9 @@ output are reported at all callers via `Core.logSnapshotTask` (so that the locat
diagnostics is deterministic). Note that, as `realize` is run using the options at declaration time
of `forConst`, trace options must be set prior to that (or, for imported constants, on the cmdline)
in order to be active. The environment extension state at the end of `realize` is available to each
caller via `EnvExtension.findStateAsync` for `constName`. If `realize` throws an exception or fails
to add `constName` to the environment, an appropriate diagnostic is reported to all callers but no
constants are added to the environment.
caller via `EnvExtension.getState (asyncDecl := constName)`. If `realize` throws an exception or
fails to add `constName` to the environment, an appropriate diagnostic is reported to all callers
but no constants are added to the environment.
-/
def realizeConst (forConst : Name) (constName : Name) (realize : MetaM Unit) :
MetaM Unit := do

View file

@ -462,7 +462,7 @@ def mkHCongrWithArityForConst? (declName : Name) (levels : List Level) (numArgs
try
let suffix := hcongrThmSuffixBasePrefix ++ toString numArgs
let thmName := Name.str declName suffix
unless (← getEnv).contains thmName do
unless (← getEnv).containsOnBranch thmName do
let _ ← executeReservedNameAction thmName
let proof := mkConst thmName levels
let type ← inferType proof
@ -479,7 +479,7 @@ same congruence theorem over and over again.
def mkCongrSimpForConst? (declName : Name) (levels : List Level) : MetaM (Option CongrTheorem) := do
let thmName := Name.str declName congrSimpSuffix
try
unless (← getEnv).contains thmName do
unless (← getEnv).containsOnBranch thmName do
let _ ← executeReservedNameAction thmName
let proof := mkConst thmName levels
let type ← inferType proof

View file

@ -49,7 +49,7 @@ This information is populated by the `PreDefinition` module, but the simplifier
uses when unfolding declarations.
-/
builtin_initialize recExt : TagDeclarationExtension ←
mkTagDeclarationExtension `recExt (asyncMode := .async)
mkTagDeclarationExtension `recExt (asyncMode := .async .asyncEnv)
/--
Marks the given declaration as recursive.

View file

@ -744,9 +744,9 @@ def getEquationsForImpl (matchDeclName : Name) : MetaM MatchEqns := do
let splitterName := baseName ++ `splitter
-- NOTE: `go` will generate both splitter and equations but we use the splitter as the "key" for
-- `realizeConst` as well as for looking up the resultant environment extension state via
-- `findStateAsync`.
-- `getState`.
realizeConst matchDeclName splitterName (go baseName splitterName)
return matchEqnsExt.findStateAsync (← getEnv) splitterName |>.map.find! matchDeclName
return matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := splitterName) (← getEnv) |>.map.find! matchDeclName
where go baseName splitterName := withConfig (fun c => { c with etaStruct := .none }) do
let constInfo ← getConstInfo matchDeclName
let us := constInfo.levelParams.map mkLevelParam
@ -843,7 +843,7 @@ def isCongrEqnReservedNameSuffix (s : String) : Bool :=
/- We generate the equations and splitter on demand, and do not save them on .olean files. -/
builtin_initialize matchCongrEqnsExt : EnvExtension (PHashMap Name (Array Name)) ←
-- Using `local` allows us to use the extension in `realizeConst` without specifying `replay?`.
-- The resulting state can still be accessed on the generated declarations using `findStateAsync`;
-- The resulting state can still be accessed on the generated declarations using `.asyncEnv`;
-- see below
registerEnvExtension (pure {}) (asyncMode := .local)
@ -866,7 +866,7 @@ def genMatchCongrEqns (matchDeclName : Name) : MetaM (Array Name) := do
let baseName := mkPrivateName (← getEnv) matchDeclName
let firstEqnName := .str baseName congrEqn1ThmSuffix
realizeConst matchDeclName firstEqnName (go baseName)
return matchCongrEqnsExt.findStateAsync (← getEnv) firstEqnName |>.find! matchDeclName
return matchCongrEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := firstEqnName) (← getEnv) |>.find! matchDeclName
where go baseName := withConfig (fun c => { c with etaStruct := .none }) do
withConfig (fun c => { c with etaStruct := .none }) do
let constInfo ← getConstInfo matchDeclName

View file

@ -30,7 +30,7 @@ structure MatchEqnsExtState where
/- We generate the equations and splitter on demand, and do not save them on .olean files. -/
builtin_initialize matchEqnsExt : EnvExtension MatchEqnsExtState ←
-- Using `local` allows us to use the extension in `realizeConst` without specifying `replay?`.
-- The resulting state can still be accessed on the generated declarations using `findStateAsync`;
-- The resulting state can still be accessed on the generated declarations using `.asyncEnv`;
-- see below
registerEnvExtension (pure {}) (asyncMode := .local)
@ -54,6 +54,6 @@ def isMatchEqnTheorem (env : Environment) (declName : Name) : Bool := Id.run do
let .str _ s := declName.eraseMacroScopes | return false
if !isEqnLikeSuffix s then
return false
(matchEqnsExt.findStateAsync env declName).eqns.contains declName
(matchEqnsExt.getState (asyncMode := .async .asyncEnv) (asyncDecl := declName) env).eqns.contains declName
end Lean.Meta.Match

View file

@ -86,19 +86,18 @@ builtin_initialize extension : SimplePersistentEnvExtension Entry State ←
registerSimplePersistentEnvExtension {
addEntryFn := State.addEntry
addImportedFn := fun es => (mkStateFromImportedEntries State.addEntry {} es).switch
asyncMode := .async
asyncMode := .async .mainEnv
}
def addMatcherInfo (env : Environment) (matcherName : Name) (info : MatcherInfo) : Environment :=
let _ : Inhabited Environment := ⟨env⟩
assert! env.asyncMayContain matcherName
extension.addEntry env { name := matcherName, info := info }
extension.addEntry (asyncDecl := matcherName) env { name := matcherName, info := info }
def getMatcherInfo? (env : Environment) (declName : Name) : Option MatcherInfo := do
-- avoid blocking on async decls whose names look nothing like matchers
let .str _ s := declName.eraseMacroScopes | none
guard <| s.startsWith "match_"
(extension.findStateAsync env declName).map.find? declName
(extension.getState (asyncDecl := declName) env).map.find? declName
end Extension

View file

@ -35,7 +35,8 @@ builtin_initialize reducibilityCoreExt : PersistentEnvExtension (Name × Reducib
let r : Array (Name × ReducibilityStatus) := m.foldl (fun a n p => a.push (n, p)) #[]
r.qsort (fun a b => Name.quickLt a.1 b.1)
statsFn := fun s => "reducibility attribute core extension" ++ Format.line ++ "number of local entries: " ++ format s.size
asyncMode := .async
-- attribute is set by `addPreDefinitions`
asyncMode := .async .asyncEnv
}
builtin_initialize reducibilityExtraExt : SimpleScopedEnvExtension (Name × ReducibilityStatus) (SMap Name ReducibilityStatus) ←
@ -56,7 +57,7 @@ def getReducibilityStatusCore (env : Environment) (declName : Name) : Reducibili
match (reducibilityCoreExt.getModuleEntries env modIdx).binSearch (declName, .semireducible) (fun a b => Name.quickLt a.1 b.1) with
| some (_, status) => status
| none => .semireducible
| none => (reducibilityCoreExt.findStateAsync env declName).find? declName |>.getD .semireducible
| none => (reducibilityCoreExt.getState (asyncDecl := declName) env).find? declName |>.getD .semireducible
private def setReducibilityStatusCore (env : Environment) (declName : Name) (status : ReducibilityStatus) (attrKind : AttributeKind) (currNamespace : Name) : Environment :=
if attrKind matches .global then
@ -66,8 +67,7 @@ private def setReducibilityStatusCore (env : Environment) (declName : Name) (sta
reducibilityExtraExt.addEntry env (declName, status)
| none =>
let _ : Inhabited Environment := ⟨env⟩
assert! env.asyncMayContain declName
reducibilityCoreExt.addEntry env (declName, status)
reducibilityCoreExt.addEntry (asyncDecl := declName) env (declName, status)
else
-- `scoped` and `local` must be handled by `reducibilityExtraExt`
reducibilityExtraExt.addCore env (declName, status) attrKind currNamespace

View file

@ -118,7 +118,10 @@ def importConfigFileCore (olean : FilePath) (leanOpts : Options) : IO Environmen
let env := mod.entries.foldl (init := env) fun env (extName, ents) =>
if lakeExts.contains extName then
match extNameIdx[extName]? with
| some entryIdx => ents.foldl extDescrs[entryIdx]!.addEntry env
| some entryIdx =>
-- Use `sync` to avoid `async` checks, which are not relevant here as there is only one
-- environment branch.
ents.foldl (extDescrs[entryIdx]!.addEntry (asyncMode := .sync)) env
| none => env
else
env