From 5dde403ec084da660d85b44824f82b6a06fa5a4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Wed, 26 Nov 2025 19:17:17 +0100 Subject: [PATCH] fix: toposort declarations to ensure proper constant initialization (#11388) This PR is a followup of #11381 and enforces the invariants on ordering of closed terms and constants required by the EmitC pass properly by toposorting before saving the declarations into the Environment. --- src/Lean/Compiler/IR.lean | 2 + src/Lean/Compiler/IR/Toposort.lean | 70 +++++++++++++++++++++++ src/Lean/Compiler/LCNF/ExtractClosed.lean | 18 ++---- tests/lean/run/emptyLcnf.lean | 22 +++---- tests/lean/run/erased.lean | 24 ++++---- 5 files changed, 99 insertions(+), 37 deletions(-) create mode 100644 src/Lean/Compiler/IR/Toposort.lean diff --git a/src/Lean/Compiler/IR.lean b/src/Lean/Compiler/IR.lean index 2320b76de8..bdda211013 100644 --- a/src/Lean/Compiler/IR.lean +++ b/src/Lean/Compiler/IR.lean @@ -27,6 +27,7 @@ public import Lean.Compiler.IR.Sorry public import Lean.Compiler.IR.ToIR public import Lean.Compiler.IR.ToIRType public import Lean.Compiler.IR.Meta +public import Lean.Compiler.IR.Toposort -- The following imports are not required by the compiler. They are here to ensure that there -- are no orphaned modules. @@ -71,6 +72,7 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do decls ← updateSorryDep decls logDecls `result decls checkDecls decls + decls ← toposortDecls decls addDecls decls inferMeta decls return decls diff --git a/src/Lean/Compiler/IR/Toposort.lean b/src/Lean/Compiler/IR/Toposort.lean new file mode 100644 index 0000000000..1110c4afe8 --- /dev/null +++ b/src/Lean/Compiler/IR/Toposort.lean @@ -0,0 +1,70 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Henrik Böving +-/ +module + +prelude +public import Lean.Compiler.IR.CompilerM + +/-! +This module "topologically sorts" an SCC of decls (an SCC of decls in the pipeline may in fact +contain more than one SCC at the moment). This is relevant because EmitC relies on the invariant +that the constants (and in particular also the closed terms) occur in a reverse topologically sorted +order for emitting them. +-/ + +namespace Lean.IR + +structure TopoState where + declsMap : Std.HashMap Name Decl + seen : Std.HashSet Name + order : Array Decl + +abbrev ToposortM := StateRefT TopoState CompilerM + +partial def toposort (decls : Array Decl) : CompilerM (Array Decl) := do + let declsMap := .ofList (decls.toList.map (fun d => (d.name, d))) + let (_, s) ← go decls |>.run { + declsMap, + seen := .emptyWithCapacity decls.size, + order := .emptyWithCapacity decls.size + } + return s.order +where + go (decls : Array Decl) : ToposortM Unit := do + decls.forM process + + process (decl : Decl) : ToposortM Unit := do + if (← get).seen.contains decl.name then + return () + + modify fun s => { s with seen := s.seen.insert decl.name } + let .fdecl (body := body) .. := decl | unreachable! + processBody body + modify fun s => { s with order := s.order.push decl } + + processBody (b : FnBody) : ToposortM Unit := do + match b with + | .vdecl _ _ e b => + match e with + | .fap c .. | .pap c .. => + if let some decl := (← get).declsMap[c]? then + process decl + | _ => pure () + processBody b + | .jdecl _ _ v b => + processBody v + processBody b + | .case _ _ _ cs => cs.forM (fun alt => processBody alt.body) + | .jmp .. | .ret .. | .unreachable => return () + | _ => processBody b.body + + +public def toposortDecls (decls : Array Decl) : CompilerM (Array Decl) := do + let (externDecls, otherDecls) := decls.partition (fun decl => decl.isExtern) + let otherDecls ← toposort otherDecls + return externDecls ++ otherDecls + +end Lean.IR diff --git a/src/Lean/Compiler/LCNF/ExtractClosed.lean b/src/Lean/Compiler/LCNF/ExtractClosed.lean index 1f9981b195..26461d3df1 100644 --- a/src/Lean/Compiler/LCNF/ExtractClosed.lean +++ b/src/Lean/Compiler/LCNF/ExtractClosed.lean @@ -141,27 +141,17 @@ def visitDecl (decl : Decl) : M Decl := do end ExtractClosed -partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Decl × Array Decl) := do +partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do let ⟨decl, s⟩ ← ExtractClosed.visitDecl decl |>.run { baseName := decl.name, sccDecls } |>.run {} - return (decl, s.decls) + return s.decls.push decl def extractClosed : Pass where phase := .mono name := `extractClosed run := fun decls => do if (← getConfig).extractClosed then - let mut changedDecls := Array.emptyWithCapacity decls.size - let mut closedDecls := #[] - for decl in decls do - let (change, new) ← decl.extractClosed decls - changedDecls := changedDecls.push change - closedDecls := closedDecls ++ new - - /- - EmitC later relies on the fact that within an SCC the closed term declarations come first, - then the declarations that rely on them. - -/ - return closedDecls ++ changedDecls + decls.foldlM (init := #[]) fun newDecls decl => + return newDecls ++ (← decl.extractClosed decls) else return decls diff --git a/tests/lean/run/emptyLcnf.lean b/tests/lean/run/emptyLcnf.lean index b2a7f12e7d..3e74cdd6f7 100644 --- a/tests/lean/run/emptyLcnf.lean +++ b/tests/lean/run/emptyLcnf.lean @@ -11,7 +11,17 @@ trace: [Compiler.result] size: 0 def f x : Nat := ⊥ --- -trace: [Compiler.result] size: 1 +trace: [Compiler.result] size: 5 + def _private.lean.run.emptyLcnf.0._eval._lam_0 _x.1 _x.2 _y.3 _y.4 _y.5 _y.6 _y.7 _y.8 _y.9 : EST.Out Lean.Exception + lcAny PUnit := + let _x.10 := Lean.Compiler.compile _x.1 _y.7 _y.8 _y.9; + cases _x.10 : EST.Out Lean.Exception lcAny PUnit + | EST.Out.ok a.11 a.12 => + let _x.13 := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12; + return _x.13 + | EST.Out.error a.14 a.15 => + return _x.10 +[Compiler.result] size: 1 def _private.lean.run.emptyLcnf.0._eval._closed_0 : String := let _x.1 := "f"; return _x.1 @@ -31,16 +41,6 @@ trace: [Compiler.result] size: 1 let _x.2 := _eval._closed_2.2; let _x.3 := Array.push ◾ _x.2 _x.1; return _x.3 -[Compiler.result] size: 5 - def _private.lean.run.emptyLcnf.0._eval._lam_0 _x.1 _x.2 _y.3 _y.4 _y.5 _y.6 _y.7 _y.8 _y.9 : EST.Out Lean.Exception - lcAny PUnit := - let _x.10 := Lean.Compiler.compile _x.1 _y.7 _y.8 _y.9; - cases _x.10 : EST.Out Lean.Exception lcAny PUnit - | EST.Out.ok a.11 a.12 => - let _x.13 := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12; - return _x.13 - | EST.Out.error a.14 a.15 => - return _x.10 [Compiler.result] size: 8 def _private.lean.run.emptyLcnf.0._eval a.1 a.2 a.3 : EST.Out Lean.Exception lcAny PUnit := let _x.4 := _eval._closed_0.2; diff --git a/tests/lean/run/erased.lean b/tests/lean/run/erased.lean index be9e7e45f8..6cccf2ac6e 100644 --- a/tests/lean/run/erased.lean +++ b/tests/lean/run/erased.lean @@ -25,7 +25,18 @@ trace: [Compiler.result] size: 1 let _x.1 : PSigma lcErased lcAny := PSigma.mk ◾ ◾ ◾ ◾; return _x.1 --- -trace: [Compiler.result] size: 1 +trace: [Compiler.result] size: 5 + def _private.lean.run.erased.0._eval._lam_0 (_x.1 : Array + Lean.Name) (_x.2 : PUnit) (_y.3 : Lean.Elab.Term.Context) (_y.4 : lcAny) (_y.5 : Lean.Meta.Context) (_y.6 : lcAny) (_y.7 : Lean.Core.Context) (_y.8 : lcAny) (_y.9 : lcVoid) : EST.Out + Lean.Exception lcAny PUnit := + let _x.10 : EST.Out Lean.Exception lcAny PUnit := compile _x.1 _y.7 _y.8 _y.9; + cases _x.10 : EST.Out Lean.Exception lcAny PUnit + | EST.Out.ok (a.11 : PUnit) (a.12 : lcVoid) => + let _x.13 : EST.Out Lean.Exception lcAny PUnit := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12; + return _x.13 + | EST.Out.error (a.14 : Lean.Exception) (a.15 : lcVoid) => + return _x.10 +[Compiler.result] size: 1 def _private.lean.run.erased.0._eval._closed_0 : String := let _x.1 : String := "Erased"; return _x.1 @@ -50,17 +61,6 @@ trace: [Compiler.result] size: 1 let _x.2 : Array Lean.Name := _eval._closed_3.2; let _x.3 : Array Lean.Name := Array.push ◾ _x.2 _x.1; return _x.3 -[Compiler.result] size: 5 - def _private.lean.run.erased.0._eval._lam_0 (_x.1 : Array - Lean.Name) (_x.2 : PUnit) (_y.3 : Lean.Elab.Term.Context) (_y.4 : lcAny) (_y.5 : Lean.Meta.Context) (_y.6 : lcAny) (_y.7 : Lean.Core.Context) (_y.8 : lcAny) (_y.9 : lcVoid) : EST.Out - Lean.Exception lcAny PUnit := - let _x.10 : EST.Out Lean.Exception lcAny PUnit := compile _x.1 _y.7 _y.8 _y.9; - cases _x.10 : EST.Out Lean.Exception lcAny PUnit - | EST.Out.ok (a.11 : PUnit) (a.12 : lcVoid) => - let _x.13 : EST.Out Lean.Exception lcAny PUnit := @EST.Out.ok ◾ ◾ ◾ _x.2 a.12; - return _x.13 - | EST.Out.error (a.14 : Lean.Exception) (a.15 : lcVoid) => - return _x.10 [Compiler.result] size: 9 def _private.lean.run.erased.0._eval (a.1 : Lean.Elab.Command.Context) (a.2 : lcAny) (a.3 : lcVoid) : EST.Out Lean.Exception lcAny PUnit :=