fix: toposort declarations to ensure proper constant initialization (#11388)

This PR is a followup of #11381 and enforces the invariants on ordering
of closed terms and constants required by the EmitC pass properly by
toposorting before saving the declarations into the Environment.
This commit is contained in:
Henrik Böving 2025-11-26 19:17:17 +01:00 committed by GitHub
parent 8639afacf8
commit 5dde403ec0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 99 additions and 37 deletions

View file

@ -27,6 +27,7 @@ public import Lean.Compiler.IR.Sorry
public import Lean.Compiler.IR.ToIR
public import Lean.Compiler.IR.ToIRType
public import Lean.Compiler.IR.Meta
public import Lean.Compiler.IR.Toposort
-- The following imports are not required by the compiler. They are here to ensure that there
-- are no orphaned modules.
@ -71,6 +72,7 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do
decls ← updateSorryDep decls
logDecls `result decls
checkDecls decls
decls ← toposortDecls decls
addDecls decls
inferMeta decls
return decls

View file

@ -0,0 +1,70 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
module
prelude
public import Lean.Compiler.IR.CompilerM
/-!
This module "topologically sorts" an SCC of decls (an SCC of decls in the pipeline may in fact
contain more than one SCC at the moment). This is relevant because EmitC relies on the invariant
that the constants (and in particular also the closed terms) occur in a reverse topologically sorted
order for emitting them.
-/
namespace Lean.IR
structure TopoState where
declsMap : Std.HashMap Name Decl
seen : Std.HashSet Name
order : Array Decl
abbrev ToposortM := StateRefT TopoState CompilerM
partial def toposort (decls : Array Decl) : CompilerM (Array Decl) := do
let declsMap := .ofList (decls.toList.map (fun d => (d.name, d)))
let (_, s) ← go decls |>.run {
declsMap,
seen := .emptyWithCapacity decls.size,
order := .emptyWithCapacity decls.size
}
return s.order
where
go (decls : Array Decl) : ToposortM Unit := do
decls.forM process
process (decl : Decl) : ToposortM Unit := do
if (← get).seen.contains decl.name then
return ()
modify fun s => { s with seen := s.seen.insert decl.name }
let .fdecl (body := body) .. := decl | unreachable!
processBody body
modify fun s => { s with order := s.order.push decl }
processBody (b : FnBody) : ToposortM Unit := do
match b with
| .vdecl _ _ e b =>
match e with
| .fap c .. | .pap c .. =>
if let some decl := (← get).declsMap[c]? then
process decl
| _ => pure ()
processBody b
| .jdecl _ _ v b =>
processBody v
processBody b
| .case _ _ _ cs => cs.forM (fun alt => processBody alt.body)
| .jmp .. | .ret .. | .unreachable => return ()
| _ => processBody b.body
public def toposortDecls (decls : Array Decl) : CompilerM (Array Decl) := do
let (externDecls, otherDecls) := decls.partition (fun decl => decl.isExtern)
let otherDecls ← toposort otherDecls
return externDecls ++ otherDecls
end Lean.IR

View file

@ -141,27 +141,17 @@ def visitDecl (decl : Decl) : M Decl := do
end ExtractClosed
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Decl × Array Decl) := do
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do
let ⟨decl, s⟩ ← ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {}
return (decl, s.decls)
return s.decls.push decl
def extractClosed : Pass where
phase := .mono
name := `extractClosed
run := fun decls => do
if (← getConfig).extractClosed then
let mut changedDecls := Array.emptyWithCapacity decls.size
let mut closedDecls := #[]
for decl in decls do
let (change, new) ← decl.extractClosed decls
changedDecls := changedDecls.push change
closedDecls := closedDecls ++ new
/-
EmitC later relies on the fact that within an SCC the closed term declarations come first,
then the declarations that rely on them.
-/
return closedDecls ++ changedDecls
decls.foldlM (init := #[]) fun newDecls decl =>
return newDecls ++ (← decl.extractClosed decls)
else
return decls

View file

@ -11,7 +11,17 @@ trace: [Compiler.result] size: 0
def f x : Nat :=
---
trace: [Compiler.result] size: 1
trace: [Compiler.result] size: 5
def _private.lean.run.emptyLcnf.0._eval._lam_0 _x.1 _x.2 _y.3 _y.4 _y.5 _y.6 _y.7 _y.8 _y.9 : EST.Out Lean.Exception
lcAny PUnit :=
let _x.10 := Lean.Compiler.compile _x.1 _y.7 _y.8 _y.9;
cases _x.10 : EST.Out Lean.Exception lcAny PUnit
| EST.Out.ok a.11 a.12 =>
let _x.13 := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12;
return _x.13
| EST.Out.error a.14 a.15 =>
return _x.10
[Compiler.result] size: 1
def _private.lean.run.emptyLcnf.0._eval._closed_0 : String :=
let _x.1 := "f";
return _x.1
@ -31,16 +41,6 @@ trace: [Compiler.result] size: 1
let _x.2 := _eval._closed_2.2;
let _x.3 := Array.push ◾ _x.2 _x.1;
return _x.3
[Compiler.result] size: 5
def _private.lean.run.emptyLcnf.0._eval._lam_0 _x.1 _x.2 _y.3 _y.4 _y.5 _y.6 _y.7 _y.8 _y.9 : EST.Out Lean.Exception
lcAny PUnit :=
let _x.10 := Lean.Compiler.compile _x.1 _y.7 _y.8 _y.9;
cases _x.10 : EST.Out Lean.Exception lcAny PUnit
| EST.Out.ok a.11 a.12 =>
let _x.13 := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12;
return _x.13
| EST.Out.error a.14 a.15 =>
return _x.10
[Compiler.result] size: 8
def _private.lean.run.emptyLcnf.0._eval a.1 a.2 a.3 : EST.Out Lean.Exception lcAny PUnit :=
let _x.4 := _eval._closed_0.2;

View file

@ -25,7 +25,18 @@ trace: [Compiler.result] size: 1
let _x.1 : PSigma lcErased lcAny := PSigma.mk ◾ ◾ ◾ ◾;
return _x.1
---
trace: [Compiler.result] size: 1
trace: [Compiler.result] size: 5
def _private.lean.run.erased.0._eval._lam_0 (_x.1 : Array
Lean.Name) (_x.2 : PUnit) (_y.3 : Lean.Elab.Term.Context) (_y.4 : lcAny) (_y.5 : Lean.Meta.Context) (_y.6 : lcAny) (_y.7 : Lean.Core.Context) (_y.8 : lcAny) (_y.9 : lcVoid) : EST.Out
Lean.Exception lcAny PUnit :=
let _x.10 : EST.Out Lean.Exception lcAny PUnit := compile _x.1 _y.7 _y.8 _y.9;
cases _x.10 : EST.Out Lean.Exception lcAny PUnit
| EST.Out.ok (a.11 : PUnit) (a.12 : lcVoid) =>
let _x.13 : EST.Out Lean.Exception lcAny PUnit := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12;
return _x.13
| EST.Out.error (a.14 : Lean.Exception) (a.15 : lcVoid) =>
return _x.10
[Compiler.result] size: 1
def _private.lean.run.erased.0._eval._closed_0 : String :=
let _x.1 : String := "Erased";
return _x.1
@ -50,17 +61,6 @@ trace: [Compiler.result] size: 1
let _x.2 : Array Lean.Name := _eval._closed_3.2;
let _x.3 : Array Lean.Name := Array.push ◾ _x.2 _x.1;
return _x.3
[Compiler.result] size: 5
def _private.lean.run.erased.0._eval._lam_0 (_x.1 : Array
Lean.Name) (_x.2 : PUnit) (_y.3 : Lean.Elab.Term.Context) (_y.4 : lcAny) (_y.5 : Lean.Meta.Context) (_y.6 : lcAny) (_y.7 : Lean.Core.Context) (_y.8 : lcAny) (_y.9 : lcVoid) : EST.Out
Lean.Exception lcAny PUnit :=
let _x.10 : EST.Out Lean.Exception lcAny PUnit := compile _x.1 _y.7 _y.8 _y.9;
cases _x.10 : EST.Out Lean.Exception lcAny PUnit
| EST.Out.ok (a.11 : PUnit) (a.12 : lcVoid) =>
let _x.13 : EST.Out Lean.Exception lcAny PUnit := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12;
return _x.13
| EST.Out.error (a.14 : Lean.Exception) (a.15 : lcVoid) =>
return _x.10
[Compiler.result] size: 9
def _private.lean.run.erased.0._eval (a.1 : Lean.Elab.Command.Context) (a.2 : lcAny) (a.3 : lcVoid) : EST.Out
Lean.Exception lcAny PUnit :=