fix: more realizeConst fixes (#7300)

Found and debugged while working on stage 2 of #7247
This commit is contained in:
Sebastian Ullrich 2025-03-03 13:10:40 +01:00 committed by GitHub
parent e7a411a66d
commit 0a55f4bf36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 34 additions and 5 deletions

View file

@ -11,10 +11,17 @@ namespace Lean
structure ClosedTermCache where
map : PHashMap Expr Name := {}
constNames : NameSet := {}
-- used for `replay?` only
revExprs : List Expr := []
deriving Inhabited
builtin_initialize closedTermCacheExt : EnvExtension ClosedTermCache ←
registerEnvExtension (pure {}) (asyncMode := .sync) -- compilation is non-parallel anyway
(replay? := some fun oldState newState _ s =>
let newExprs := newState.revExprs.take (newState.revExprs.length - oldState.revExprs.length)
newExprs.foldl (init := s) fun s e =>
let c := newState.map.find! e
{ s with map := s.map.insert e c, constNames := s.constNames.insert c, revExprs := e :: s.revExprs })
@[export lean_cache_closed_term_name]
def cacheClosedTermName (env : Environment) (e : Expr) (n : Name) : Environment :=

View file

@ -94,6 +94,7 @@ builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ←
-- share a name prefix with the top-level Lean declaration being compiled, e.g. from
-- specialization.
asyncMode := .sync
replay? := some <| SimplePersistentEnvExtension.replayOfFilter (!·.contains ·.name) (fun s d => s.insert d.name d)
}
@[export lean_ir_find_env_decl]

View file

@ -143,6 +143,7 @@ builtin_initialize functionSummariesExt : SimplePersistentEnvExtension (FunId ×
addEntryFn := fun s ⟨e, n⟩ => s.insert e n
toArrayFn := fun s => sortEntries s.toArray
asyncMode := .sync -- compilation is non-parallel anyway
replay? := some <| SimplePersistentEnvExtension.replayOfFilter (!·.contains ·.1) (fun s ⟨e, n⟩ => s.insert e n)
}
def addFunctionSummary (env : Environment) (fid : FunId) (v : Value) : Environment :=

View file

@ -111,6 +111,9 @@ builtin_initialize specExtension : SimplePersistentEnvExtension SpecEntry SpecSt
addEntryFn := SpecState.addEntry,
addImportedFn := fun es => (mkStateFromImportedEntries SpecState.addEntry {} es).switch
asyncMode := .sync -- compilation is non-parallel anyway
replay? := some <| SimplePersistentEnvExtension.replayOfFilter (fun
| s, .info n _ => !s.specInfo.contains n
| s, .cache key _ => !s.cache.contains key) SpecState.addEntry
}
@[export lean_add_specialization_info]

View file

@ -1351,6 +1351,18 @@ structure SimplePersistentEnvExtensionDescr (α σ : Type) where
addImportedFn : Array (Array α) → σ
toArrayFn : List α → Array α := fun es => es.toArray
asyncMode : EnvExtension.AsyncMode := .mainOnly
replay? : Option ((newEntries : List α) → (newState : σ) → σ → List α × σ) := none
/--
Returns a function suitable for `SimplePersistentEnvExtensionDescr.replay?` that replays all new
entries onto the state using `addEntryFn`. `p` should filter out entries that have already been
added to the state by a prior replay of the same realizable constant.
-/
def SimplePersistentEnvExtension.replayOfFilter (p : σα → Bool)
(addEntryFn : σασ) : List ασσ → List α × σ :=
fun newEntries _ s =>
let newEntries := newEntries.filter (p s)
(newEntries, newEntries.foldl (init := s) addEntryFn)
def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr : SimplePersistentEnvExtensionDescr α σ) : IO (SimplePersistentEnvExtension α σ) :=
registerPersistentEnvExtension {
@ -1362,9 +1374,10 @@ def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr :
exportEntriesFn := fun s => descr.toArrayFn s.1.reverse,
statsFn := fun s => format "number of local entries: " ++ format s.1.length
asyncMode := descr.asyncMode
replay? := some fun oldState newState _ (entries, s) =>
let newEntries := newState.1.drop oldState.1.length
(newEntries ++ entries, newEntries.foldl descr.addEntryFn s)
replay? := descr.replay?.map fun replay oldState newState _ (entries, s) =>
let newEntries := newState.1.take (newState.1.length - oldState.1.length)
let (newEntries, s) := replay newEntries newState.2 s
(entries ++ newEntries, s)
}
namespace SimplePersistentEnvExtension

View file

@ -2255,8 +2255,12 @@ def realizeConst (forConst : Name) (constName : Name) (realize : MetaM Unit) :
return
withTraceNode `Meta.realizeConst (fun _ => return constName) do
let coreCtx ← readThe Core.Context
-- these fields should be invariant throughout the file
let coreCtx := { fileName := coreCtx.fileName, fileMap := coreCtx.fileMap }
let coreCtx := {
-- these fields should be invariant throughout the file
fileName := coreCtx.fileName, fileMap := coreCtx.fileMap
-- heartbeat limits inside `realizeAndReport` should be measured from this point on
initHeartbeats := (← IO.getNumHeartbeats)
}
let (env, dyn) ← env.realizeConst forConst constName (realizeAndReport coreCtx)
if let some res := dyn.get? RealizeConstantResult then
let mut snap := res.snap