fix: toposort declarations to ensure proper constant initialization (#11388)
This PR is a followup of #11381 and enforces the invariants on ordering of closed terms and constants required by the EmitC pass properly by toposorting before saving the declarations into the Environment.
This commit is contained in:
parent
8639afacf8
commit
5dde403ec0
5 changed files with 99 additions and 37 deletions
|
|
@ -27,6 +27,7 @@ public import Lean.Compiler.IR.Sorry
|
|||
public import Lean.Compiler.IR.ToIR
|
||||
public import Lean.Compiler.IR.ToIRType
|
||||
public import Lean.Compiler.IR.Meta
|
||||
public import Lean.Compiler.IR.Toposort
|
||||
|
||||
-- The following imports are not required by the compiler. They are here to ensure that there
|
||||
-- are no orphaned modules.
|
||||
|
|
@ -71,6 +72,7 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do
|
|||
decls ← updateSorryDep decls
|
||||
logDecls `result decls
|
||||
checkDecls decls
|
||||
decls ← toposortDecls decls
|
||||
addDecls decls
|
||||
inferMeta decls
|
||||
return decls
|
||||
|
|
|
|||
70
src/Lean/Compiler/IR/Toposort.lean
Normal file
70
src/Lean/Compiler/IR/Toposort.lean
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Compiler.IR.CompilerM
|
||||
|
||||
/-!
|
||||
This module "topologically sorts" an SCC of decls (an SCC of decls in the pipeline may in fact
|
||||
contain more than one SCC at the moment). This is relevant because EmitC relies on the invariant
|
||||
that the constants (and in particular also the closed terms) occur in a reverse topologically sorted
|
||||
order for emitting them.
|
||||
-/
|
||||
|
||||
namespace Lean.IR
|
||||
|
||||
structure TopoState where
|
||||
declsMap : Std.HashMap Name Decl
|
||||
seen : Std.HashSet Name
|
||||
order : Array Decl
|
||||
|
||||
abbrev ToposortM := StateRefT TopoState CompilerM
|
||||
|
||||
partial def toposort (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let declsMap := .ofList (decls.toList.map (fun d => (d.name, d)))
|
||||
let (_, s) ← go decls |>.run {
|
||||
declsMap,
|
||||
seen := .emptyWithCapacity decls.size,
|
||||
order := .emptyWithCapacity decls.size
|
||||
}
|
||||
return s.order
|
||||
where
|
||||
go (decls : Array Decl) : ToposortM Unit := do
|
||||
decls.forM process
|
||||
|
||||
process (decl : Decl) : ToposortM Unit := do
|
||||
if (← get).seen.contains decl.name then
|
||||
return ()
|
||||
|
||||
modify fun s => { s with seen := s.seen.insert decl.name }
|
||||
let .fdecl (body := body) .. := decl | unreachable!
|
||||
processBody body
|
||||
modify fun s => { s with order := s.order.push decl }
|
||||
|
||||
processBody (b : FnBody) : ToposortM Unit := do
|
||||
match b with
|
||||
| .vdecl _ _ e b =>
|
||||
match e with
|
||||
| .fap c .. | .pap c .. =>
|
||||
if let some decl := (← get).declsMap[c]? then
|
||||
process decl
|
||||
| _ => pure ()
|
||||
processBody b
|
||||
| .jdecl _ _ v b =>
|
||||
processBody v
|
||||
processBody b
|
||||
| .case _ _ _ cs => cs.forM (fun alt => processBody alt.body)
|
||||
| .jmp .. | .ret .. | .unreachable => return ()
|
||||
| _ => processBody b.body
|
||||
|
||||
|
||||
public def toposortDecls (decls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let (externDecls, otherDecls) := decls.partition (fun decl => decl.isExtern)
|
||||
let otherDecls ← toposort otherDecls
|
||||
return externDecls ++ otherDecls
|
||||
|
||||
end Lean.IR
|
||||
|
|
@ -141,27 +141,17 @@ def visitDecl (decl : Decl) : M Decl := do
|
|||
|
||||
end ExtractClosed
|
||||
|
||||
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Decl × Array Decl) := do
|
||||
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do
|
||||
let ⟨decl, s⟩ ← ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {}
|
||||
return (decl, s.decls)
|
||||
return s.decls.push decl
|
||||
|
||||
def extractClosed : Pass where
|
||||
phase := .mono
|
||||
name := `extractClosed
|
||||
run := fun decls => do
|
||||
if (← getConfig).extractClosed then
|
||||
let mut changedDecls := Array.emptyWithCapacity decls.size
|
||||
let mut closedDecls := #[]
|
||||
for decl in decls do
|
||||
let (change, new) ← decl.extractClosed decls
|
||||
changedDecls := changedDecls.push change
|
||||
closedDecls := closedDecls ++ new
|
||||
|
||||
/-
|
||||
EmitC later relies on the fact that within an SCC the closed term declarations come first,
|
||||
then the declarations that rely on them.
|
||||
-/
|
||||
return closedDecls ++ changedDecls
|
||||
decls.foldlM (init := #[]) fun newDecls decl =>
|
||||
return newDecls ++ (← decl.extractClosed decls)
|
||||
else
|
||||
return decls
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,17 @@ trace: [Compiler.result] size: 0
|
|||
def f x : Nat :=
|
||||
⊥
|
||||
---
|
||||
trace: [Compiler.result] size: 1
|
||||
trace: [Compiler.result] size: 5
|
||||
def _private.lean.run.emptyLcnf.0._eval._lam_0 _x.1 _x.2 _y.3 _y.4 _y.5 _y.6 _y.7 _y.8 _y.9 : EST.Out Lean.Exception
|
||||
lcAny PUnit :=
|
||||
let _x.10 := Lean.Compiler.compile _x.1 _y.7 _y.8 _y.9;
|
||||
cases _x.10 : EST.Out Lean.Exception lcAny PUnit
|
||||
| EST.Out.ok a.11 a.12 =>
|
||||
let _x.13 := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12;
|
||||
return _x.13
|
||||
| EST.Out.error a.14 a.15 =>
|
||||
return _x.10
|
||||
[Compiler.result] size: 1
|
||||
def _private.lean.run.emptyLcnf.0._eval._closed_0 : String :=
|
||||
let _x.1 := "f";
|
||||
return _x.1
|
||||
|
|
@ -31,16 +41,6 @@ trace: [Compiler.result] size: 1
|
|||
let _x.2 := _eval._closed_2.2;
|
||||
let _x.3 := Array.push ◾ _x.2 _x.1;
|
||||
return _x.3
|
||||
[Compiler.result] size: 5
|
||||
def _private.lean.run.emptyLcnf.0._eval._lam_0 _x.1 _x.2 _y.3 _y.4 _y.5 _y.6 _y.7 _y.8 _y.9 : EST.Out Lean.Exception
|
||||
lcAny PUnit :=
|
||||
let _x.10 := Lean.Compiler.compile _x.1 _y.7 _y.8 _y.9;
|
||||
cases _x.10 : EST.Out Lean.Exception lcAny PUnit
|
||||
| EST.Out.ok a.11 a.12 =>
|
||||
let _x.13 := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12;
|
||||
return _x.13
|
||||
| EST.Out.error a.14 a.15 =>
|
||||
return _x.10
|
||||
[Compiler.result] size: 8
|
||||
def _private.lean.run.emptyLcnf.0._eval a.1 a.2 a.3 : EST.Out Lean.Exception lcAny PUnit :=
|
||||
let _x.4 := _eval._closed_0.2;
|
||||
|
|
|
|||
|
|
@ -25,7 +25,18 @@ trace: [Compiler.result] size: 1
|
|||
let _x.1 : PSigma lcErased lcAny := PSigma.mk ◾ ◾ ◾ ◾;
|
||||
return _x.1
|
||||
---
|
||||
trace: [Compiler.result] size: 1
|
||||
trace: [Compiler.result] size: 5
|
||||
def _private.lean.run.erased.0._eval._lam_0 (_x.1 : Array
|
||||
Lean.Name) (_x.2 : PUnit) (_y.3 : Lean.Elab.Term.Context) (_y.4 : lcAny) (_y.5 : Lean.Meta.Context) (_y.6 : lcAny) (_y.7 : Lean.Core.Context) (_y.8 : lcAny) (_y.9 : lcVoid) : EST.Out
|
||||
Lean.Exception lcAny PUnit :=
|
||||
let _x.10 : EST.Out Lean.Exception lcAny PUnit := compile _x.1 _y.7 _y.8 _y.9;
|
||||
cases _x.10 : EST.Out Lean.Exception lcAny PUnit
|
||||
| EST.Out.ok (a.11 : PUnit) (a.12 : lcVoid) =>
|
||||
let _x.13 : EST.Out Lean.Exception lcAny PUnit := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12;
|
||||
return _x.13
|
||||
| EST.Out.error (a.14 : Lean.Exception) (a.15 : lcVoid) =>
|
||||
return _x.10
|
||||
[Compiler.result] size: 1
|
||||
def _private.lean.run.erased.0._eval._closed_0 : String :=
|
||||
let _x.1 : String := "Erased";
|
||||
return _x.1
|
||||
|
|
@ -50,17 +61,6 @@ trace: [Compiler.result] size: 1
|
|||
let _x.2 : Array Lean.Name := _eval._closed_3.2;
|
||||
let _x.3 : Array Lean.Name := Array.push ◾ _x.2 _x.1;
|
||||
return _x.3
|
||||
[Compiler.result] size: 5
|
||||
def _private.lean.run.erased.0._eval._lam_0 (_x.1 : Array
|
||||
Lean.Name) (_x.2 : PUnit) (_y.3 : Lean.Elab.Term.Context) (_y.4 : lcAny) (_y.5 : Lean.Meta.Context) (_y.6 : lcAny) (_y.7 : Lean.Core.Context) (_y.8 : lcAny) (_y.9 : lcVoid) : EST.Out
|
||||
Lean.Exception lcAny PUnit :=
|
||||
let _x.10 : EST.Out Lean.Exception lcAny PUnit := compile _x.1 _y.7 _y.8 _y.9;
|
||||
cases _x.10 : EST.Out Lean.Exception lcAny PUnit
|
||||
| EST.Out.ok (a.11 : PUnit) (a.12 : lcVoid) =>
|
||||
let _x.13 : EST.Out Lean.Exception lcAny PUnit := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12;
|
||||
return _x.13
|
||||
| EST.Out.error (a.14 : Lean.Exception) (a.15 : lcVoid) =>
|
||||
return _x.10
|
||||
[Compiler.result] size: 9
|
||||
def _private.lean.run.erased.0._eval (a.1 : Lean.Elab.Command.Context) (a.2 : lcAny) (a.3 : lcVoid) : EST.Out
|
||||
Lean.Exception lcAny PUnit :=
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue