feat: asynchronous code generation (#6770)

This PR enables code generation to proceed in parallel to further
elaboration.

It does not aim to make further refinements such as generating code for
different declarations in parallel or removing the dependency on kernel
checking.
This commit is contained in:
Sebastian Ullrich 2025-02-03 18:17:18 +01:00 committed by GitHub
parent a4ad409ae0
commit d01e038210
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 90 additions and 86 deletions

View file

@ -67,6 +67,7 @@ def addDecl (decl : Declaration) : CoreM Unit := do
| .thmDecl thm => pure (thm.name, .thmInfo thm, .thm)
| .defnDecl defn => pure (defn.name, .defnInfo defn, .defn)
| .mutualDefnDecl [defn] => pure (defn.name, .defnInfo defn, .defn)
| .axiomDecl ax => pure (ax.name, .axiomInfo ax, .axiom)
| _ => return (← doAdd)
-- no environment extension changes to report after kernel checking; ensures we do not

View file

@ -89,6 +89,11 @@ builtin_initialize declMapExt : SimplePersistentEnvExtension Decl DeclMap ←
toArrayFn := fun s =>
let decls := s.foldl (init := #[]) fun decls decl => decls.push decl
sortDecls decls
-- Written to on codegen environment branch but accessed from other elaboration branches when
-- calling into the interpreter. We cannot use `async` as the IR declarations added may not
-- share a name prefix with the top-level Lean declaration being compiled, e.g. from
-- specialization.
asyncMode := .sync
}
@[export lean_ir_find_env_decl]

View file

@ -12,7 +12,8 @@ namespace Lean.Compiler.LCNF
abbrev AuxDeclCache := PHashMap Decl Name
builtin_initialize auxDeclCacheExt : EnvExtension AuxDeclCache ← registerEnvExtension (pure {})
builtin_initialize auxDeclCacheExt : EnvExtension AuxDeclCache ←
registerEnvExtension (pure {}) (asyncMode := .sync) -- compilation is non-parallel anyway
inductive CacheAuxDeclResult where
| new
@ -29,4 +30,3 @@ def cacheAuxDecl (decl : Decl) : CompilerM CacheAuxDeclResult := do
return .new
end Lean.Compiler.LCNF

View file

@ -32,6 +32,7 @@ def mkDeclExt (name : Name := by exact decl_name%) : IO DeclExt := do
exportEntriesFn := fun s =>
let decls := s.foldl (init := #[]) fun decls _ decl => decls.push decl
sortDecls decls
asyncMode := .sync -- compilation is non-parallel anyway
}
builtin_initialize baseExt : PersistentEnvExtension Decl Decl DeclExtState ← mkDeclExt

View file

@ -110,6 +110,7 @@ builtin_initialize specExtension : SimplePersistentEnvExtension SpecEntry SpecSt
registerSimplePersistentEnvExtension {
addEntryFn := SpecState.addEntry,
addImportedFn := fun es => (mkStateFromImportedEntries SpecState.addEntry {} es).switch
asyncMode := .sync -- compilation is non-parallel anyway
}
@[export lean_add_specialization_info]

View file

@ -521,26 +521,24 @@ opaque compileDeclsNew (declNames : List Name) : CoreM Unit
@[extern "lean_compile_decls"]
opaque compileDeclsOld (env : Environment) (opt : @& Options) (decls : @& List Name) : Except Kernel.Exception Environment
def compileDecl (decl : Declaration) : CoreM Unit := do
-- don't compile if kernel errored; should be converted into a task dependency when compilation
-- is made async as well
if !decl.getNames.all (← getEnv).constants.contains then
-- `ref?` is used for error reporting if available
partial def compileDecls (decls : List Name) (ref? : Option Declaration := none)
(logErrors := true) : CoreM Unit := do
if !Elab.async.get (← getOptions) then
doCompile
return
let opts ← getOptions
let decls := Compiler.getDeclNamesForCodeGen decl
if compiler.enableNew.get opts then
compileDeclsNew decls
let res ← withTraceNode `compiler (fun _ => return m!"compiling old: {decls}") do
return compileDeclsOld (← getEnv) opts decls
match res with
| Except.ok env => setEnv env
| Except.error (.other msg) =>
checkUnsupported decl -- Generate nicer error message for unsupported recursors and axioms
throwError msg
| Except.error ex =>
throwKernelException ex
def compileDecls (decls : List Name) : CoreM Unit := do
let env ← getEnv
let (postEnv, prom) ← env.promiseChecked
let checkAct ← Core.wrapAsyncAsSnapshot fun _ => do
try
doCompile
finally
prom.resolve (← getEnv)
let t ← BaseIO.mapTask (fun _ => checkAct) env.checked
let endRange? := (← getRef).getTailPos?.map fun pos => ⟨pos, pos⟩
Core.logSnapshotTask { range? := endRange?, task := t }
setEnv postEnv
where doCompile := do
-- don't compile if kernel errored; should be converted into a task dependency when compilation
-- is made async as well
if !decls.all (← getEnv).constants.contains then
@ -548,12 +546,22 @@ def compileDecls (decls : List Name) : CoreM Unit := do
let opts ← getOptions
if compiler.enableNew.get opts then
compileDeclsNew decls
match compileDeclsOld (← getEnv) opts decls with
let res ← withTraceNode `compiler (fun _ => return m!"compiling old: {decls}") do
return compileDeclsOld (← getEnv) opts decls
match res with
| Except.ok env => setEnv env
| Except.error (.other msg) =>
throwError msg
if logErrors then
if let some decl := ref? then
checkUnsupported decl -- Generate nicer error message for unsupported recursors and axioms
throwError msg
| Except.error ex =>
throwKernelException ex
if logErrors then
throwKernelException ex
def compileDecl (decl : Declaration) (logErrors := true) : CoreM Unit := do
compileDecls (Compiler.getDeclNamesForCodeGen decl) decl logErrors
def getDiag (opts : Options) : Bool :=
diagnostics.get opts
@ -637,4 +645,8 @@ def logMessageKind (kind : Name) : CoreM Bool := do
modify fun s => { s with messages.loggedKinds := s.messages.loggedKinds.insert kind }
return true
builtin_initialize
registerTraceClass `Elab.async
registerTraceClass `Elab.block
end Lean

View file

@ -102,15 +102,8 @@ def addAsAxiom (preDef : PreDefinition) : MetaM Unit := do
private def shouldGenCodeFor (preDef : PreDefinition) : Bool :=
!preDef.kind.isTheorem && !preDef.modifiers.isNoncomputable
private def compileDecl (decl : Declaration) : TermElabM Bool := do
try
Lean.compileDecl decl
catch ex =>
if (← read).isNoncomputableSection then
return false
else
throw ex
return true
private def compileDecl (decl : Declaration) : TermElabM Unit := do
Lean.compileDecl (logErrors := !(← read).isNoncomputableSection) decl
register_builtin_option diagnostics.threshold.proofSize : Nat := {
defValue := 16384
@ -166,7 +159,7 @@ private def addNonRecAux (preDef : PreDefinition) (compile : Bool) (all : List N
if preDef.modifiers.isNoncomputable then
modifyEnv fun env => addNoncomputable env preDef.declName
if compile && shouldGenCodeFor preDef then
discard <| compileDecl decl
compileDecl decl
if applyAttrAfterCompilation then
generateEagerEqns preDef.declName
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
@ -206,7 +199,7 @@ def addAndCompileUnsafe (preDefs : Array PreDefinition) (safety := DefinitionSaf
for preDef in preDefs do
addTermInfo' preDef.ref (← mkConstWithLevelParams preDef.declName) (isBinder := true)
applyAttributesOf preDefs AttributeApplicationTime.afterTypeChecking
discard <| compileDecl decl
compileDecl decl
applyAttributesOf preDefs AttributeApplicationTime.afterCompilation
return ()

View file

@ -300,12 +300,13 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC
try
if preDefs.all fun preDef => (preDef.kind matches DefKind.def | DefKind.instance) || preDefs.all fun preDef => preDef.kind == DefKind.abbrev then
-- try to add as partial definition
try
addAndCompilePartial preDefs (useSorry := true)
catch _ =>
-- Compilation failed try again just as axiom
s.restore
addAsAxioms preDefs
withOptions (Elab.async.set · false) do
try
addAndCompilePartial preDefs (useSorry := true)
catch _ =>
-- Compilation failed try again just as axiom
s.restore
addAsAxioms preDefs
else if preDefs.all fun preDef => preDef.kind == DefKind.theorem then
addAsAxioms preDefs
catch _ => s.restore

View file

@ -451,9 +451,7 @@ def ofKernelEnv (env : Kernel.Environment) : Environment :=
@[export lean_elab_environment_to_kernel_env]
def toKernelEnv (env : Environment) : Kernel.Environment :=
-- TODO: should just be the following when we store extension data in `checked`
--env.checked.get
{ env.checked.get with extensions := env.checkedWithoutAsync.extensions }
env.checked.get
/-- Consistently updates synchronous and asynchronous parts of the environment without blocking. -/
private def modifyCheckedAsync (env : Environment) (f : Kernel.Environment → Kernel.Environment) : Environment :=
@ -463,6 +461,10 @@ private def modifyCheckedAsync (env : Environment) (f : Kernel.Environment → K
private def setCheckedSync (env : Environment) (newChecked : Kernel.Environment) : Environment :=
{ env with checked := .pure newChecked, checkedWithoutAsync := newChecked }
def promiseChecked (env : Environment) : BaseIO (Environment × IO.Promise Environment) := do
let prom ← IO.Promise.new
return ({ env with checked := prom.result.bind (sync := true) (·.checked) }, prom)
/--
Checks whether the given declaration name may potentially added, or have been added, to the current
environment branch, which is the case either if this is the main branch or if the declaration name
@ -527,7 +529,11 @@ def addExtraName (env : Environment) (name : Name) : Environment :=
/-- Find base case: name did not match any asynchronous declaration. -/
private def findNoAsync (env : Environment) (n : Name) : Option ConstantInfo := do
if let some _ := env.asyncConsts.findPrefix? n then
if env.asyncMayContain n then
-- Constant definitely not generated in a different environment branch: return none, callers
-- have already checked this branch.
none
else if let some _ := env.asyncConsts.findPrefix? n then
-- Constant generated in a different environment branch: wait for final kernel environment. Rare
-- case when only proofs are elaborated asynchronously as they are rarely inspected. Could be
-- optimized in the future by having the elaboration thread publish an (incremental?) map of
@ -621,6 +627,7 @@ information.
-/
def addConstAsync (env : Environment) (constName : Name) (kind : ConstantKind) (reportExts := true) :
IO AddConstAsyncResult := do
assert! env.asyncMayContain constName
let sigPromise ← IO.Promise.new
let infoPromise ← IO.Promise.new
let extensionsPromise ← IO.Promise.new
@ -702,6 +709,9 @@ def AddConstAsyncResult.commitFailure (res : AddConstAsyncResult) : BaseIO Unit
| .thm => .thmInfo { val with
value := mkApp2 (mkConst ``sorryAx [0]) val.type (mkConst ``true)
}
| .axiom => .axiomInfo { val with
isUnsafe := false
}
| k => panic! s!"AddConstAsyncResult.commitFailure: unsupported constant kind {repr k}"
res.extensionsPromise.resolve #[]
let _ ← BaseIO.mapTask (t := res.asyncEnv.checked) (sync := true) res.checkedEnvPromise.resolve
@ -1565,7 +1575,7 @@ def getNamespaceSet (env : Environment) : NameSSet :=
@[export lean_elab_environment_update_base_after_kernel_add]
private def updateBaseAfterKernelAdd (env : Environment) (kernel : Kernel.Environment) : Environment :=
env.setCheckedSync { kernel with extensions := env.checkedWithoutAsync.extensions }
{ env with checked := .pure kernel, checkedWithoutAsync := { kernel with extensions := env.checkedWithoutAsync.extensions } }
@[export lean_display_stats]
def displayStats (env : Environment) : IO Unit := do

View file

@ -24,8 +24,15 @@ unsafe def evalExprCore (α) (value : Expr) (checkType : Expr → MetaM Unit) (s
value, hints := ReducibilityHints.opaque,
safety
}
addAndCompile decl
evalConst α name
-- compilation will invariably wait on `checked`, do it now and tag as blocker
unless (← IO.hasFinished (← getEnv).checked) do
withTraceNode `Elab.block (fun _ => pure "") do
let _ ← IO.wait (← getEnv).checked
-- now that we've already waited, async would just introduce (minor) overhead and trigger
-- `Task.get` blocking debug code
withOptions (Elab.async.set · false) do
addAndCompile decl
evalConst α name
unsafe def evalExpr' (α) (typeName : Name) (value : Expr) (safety := DefinitionSafety.safe) : MetaM α :=
evalExprCore (safety := safety) α value fun type => do

View file

@ -31,7 +31,7 @@ instance : ToJson Microseconds where toJson x := toJson x.μs
instance : FromJson Microseconds where fromJson? j := Microseconds.mk <$> fromJson? j
structure Category where
name : String
name : Name
color : String
subcategories : Array String := #[]
deriving FromJson, ToJson
@ -164,10 +164,13 @@ structure ThreadWithMaps extends Thread where
lastTime : Float := 0
-- TODO: add others, dynamically?
-- NOTE: more specific prefixes should come first
def categories : Array Category := #[
{ name := "Other", color := "gray" },
{ name := "Elab", color := "red" },
{ name := "Meta", color := "yellow" }
{ name := `Other, color := "gray" },
{ name := `Elab.async, color := "gray" },
{ name := `Elab.block, color := "brown" },
{ name := `Elab, color := "red" },
{ name := `Meta, color := "yellow" }
]
/-- Returns first `startTime` in the trace tree, if any. -/
@ -201,7 +204,7 @@ where
(thread.stringMap.size, { thread with
stringArray := thread.stringArray.push funcName
stringMap := thread.stringMap.insert funcName thread.stringMap.size })
let category := categories.findIdx? (·.name == data.cls.getRoot.toString) |>.getD 0
let category := categories.findIdx? (·.name.isPrefixOf data.cls) |>.getD 0
let funcIdx ← modifyGet fun thread =>
if let some idx := thread.funcMap[strIdx]? then
(idx, thread)

View file

@ -54,6 +54,7 @@ def Exp.casesOn._override._rarg._boxed (x_1 : obj) (x_2 : obj) (x_3 : obj) (x_4
dec x_5;
dec x_4;
ret x_9
[result]
def Exp.var._override (x_1 : u32) : obj :=
let x_2 : u64 := UInt32.toUInt64 x_1;

View file

@ -1,31 +0,0 @@
partial def f : List Nat → Bool
| [] => false
| (a::as) => a > 0 && f as
#print f._cstage2
#exit
mutual def f1, f2, f3, f4, f5
with f1 : Nat → Bool
| 0 := f3 0
| x := f2 x
with f2 : Nat → Bool
| 0 := f4 0
| x := f3 x
with f3 : Nat → Bool
| 0 := f5 0
| x := f4 x
with f4 : Nat → Bool
| 0 := f5 0
| (x+1) := f4 x
with f5 : Nat → Bool
| 0 := true
| _ := false
#check f1._main._cstage2
#check f2._main._cstage2
#check f3._main._cstage2
#check f4._main._cstage2
#check f5._main._cstage2

View file

@ -8,6 +8,8 @@ The following example would cause the pretty printer to panic.
set_option trace.compiler.simp true in
/--
info: [0]
---
info: [compiler.simp] >> _eval
let _x_21 := `Nat;
let _x_22 := [];
@ -105,8 +107,6 @@ let _x_1 :=
Lean.List.toExprAux._at._eval._spec_1✝ _eval._closed_9 _eval._closed_13
_eval._closed_14;
Lean.MessageData.ofExpr _x_1
---
info: [0]
-/
#guard_msgs in
#eval [0]