fix: correctly handle never_extract attribute in LCNF CSE (#8952)

This PR fixes the handling of the `never_extract` attribute in the
compiler's CSE pass. There is an interesting debate to be had about
exactly how hard the compiler should try to avoid duplicating anything
that transitively uses `never_extract`, but this is the simplest form
and roughly matches the check in the old compiler (although due to
different handling of local function decls in the two compilers, the
consequences might be slightly different).

This gets half of the way to #8944.
This commit is contained in:
Cameron Zwarich 2025-06-23 16:03:10 -07:00 committed by GitHub
parent b0269d2875
commit 24cbd4efbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 35 additions and 8 deletions

View file

@ -7,6 +7,7 @@ prelude
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.ToExpr
import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.NeverExtractAttr
namespace Lean.Compiler.LCNF
@ -44,6 +45,13 @@ def replaceFun (decl : FunDecl) (fvarId : FVarId) : M Unit := do
eraseFunDecl decl
addFVarSubst decl.fvarId fvarId
def hasNeverExtract (v : LetValue) : CompilerM Bool :=
match v with
| .const declName .. =>
return hasNeverExtractAttribute (← getEnv) declName
| .lit _ | .erased | .proj .. | .fvar .. =>
return false
partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code) : CompilerM Code :=
go code |>.run' {}
where
@ -57,15 +65,18 @@ where
match code with
| .let decl k =>
let decl ← normLetDecl decl
-- We only apply CSE to pure code
let key := decl.value.toExpr
match (← get).map.find? key with
| some fvarId =>
replaceLet decl fvarId
go k
| none =>
addEntry key decl.fvarId
if (← hasNeverExtract decl.value) then
return code.updateLet! decl (← go k)
else
-- We only apply CSE to pure code
let key := decl.value.toExpr
match (← get).map.find? key with
| some fvarId =>
replaceLet decl fvarId
go k
| none =>
addEntry key decl.fvarId
return code.updateLet! decl (← go k)
| .fun decl k =>
let decl ← goFunDecl decl
if shouldElimFunDecls then

View file

@ -0,0 +1,13 @@
def test1 (a : Nat) : Nat :=
let f a :=
dbg_trace s!"{a}"
a
let g a :=
dbg_trace s!"{a + 0}"
a
(f a) + (g a)
def main : IO Unit :=
-- Use `eprintln` because that is what `dbg_trace` uses.
IO.eprintln f!"{test1 1}"

View file

@ -0,0 +1,3 @@
1
1
2