fix: correctly handle never_extract attribute in LCNF CSE (#8952)
This PR fixes the handling of the `never_extract` attribute in the compiler's CSE pass. There is an interesting debate to be had about exactly how hard the compiler should try to avoid duplicating anything that transitively uses `never_extract`, but this is the simplest form and roughly matches the check in the old compiler (although due to different handling of local function decls in the two compilers, the consequences might be slightly different). This gets half of the way to #8944.
This commit is contained in:
parent
b0269d2875
commit
24cbd4efbe
3 changed files with 35 additions and 8 deletions
|
|
@ -7,6 +7,7 @@ prelude
|
|||
import Lean.Compiler.LCNF.CompilerM
|
||||
import Lean.Compiler.LCNF.ToExpr
|
||||
import Lean.Compiler.LCNF.PassManager
|
||||
import Lean.Compiler.NeverExtractAttr
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
|
|
@ -44,6 +45,13 @@ def replaceFun (decl : FunDecl) (fvarId : FVarId) : M Unit := do
|
|||
eraseFunDecl decl
|
||||
addFVarSubst decl.fvarId fvarId
|
||||
|
||||
def hasNeverExtract (v : LetValue) : CompilerM Bool :=
|
||||
match v with
|
||||
| .const declName .. =>
|
||||
return hasNeverExtractAttribute (← getEnv) declName
|
||||
| .lit _ | .erased | .proj .. | .fvar .. =>
|
||||
return false
|
||||
|
||||
partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code) : CompilerM Code :=
|
||||
go code |>.run' {}
|
||||
where
|
||||
|
|
@ -57,15 +65,18 @@ where
|
|||
match code with
|
||||
| .let decl k =>
|
||||
let decl ← normLetDecl decl
|
||||
-- We only apply CSE to pure code
|
||||
let key := decl.value.toExpr
|
||||
match (← get).map.find? key with
|
||||
| some fvarId =>
|
||||
replaceLet decl fvarId
|
||||
go k
|
||||
| none =>
|
||||
addEntry key decl.fvarId
|
||||
if (← hasNeverExtract decl.value) then
|
||||
return code.updateLet! decl (← go k)
|
||||
else
|
||||
-- We only apply CSE to pure code
|
||||
let key := decl.value.toExpr
|
||||
match (← get).map.find? key with
|
||||
| some fvarId =>
|
||||
replaceLet decl fvarId
|
||||
go k
|
||||
| none =>
|
||||
addEntry key decl.fvarId
|
||||
return code.updateLet! decl (← go k)
|
||||
| .fun decl k =>
|
||||
let decl ← goFunDecl decl
|
||||
if shouldElimFunDecls then
|
||||
|
|
|
|||
13
tests/compiler/never_extract.lean
Normal file
13
tests/compiler/never_extract.lean
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
def test1 (a : Nat) : Nat :=
|
||||
let f a :=
|
||||
dbg_trace s!"{a}"
|
||||
a
|
||||
let g a :=
|
||||
dbg_trace s!"{a + 0}"
|
||||
a
|
||||
(f a) + (g a)
|
||||
|
||||
def main : IO Unit :=
|
||||
-- Use `eprintln` because that is what `dbg_trace` uses.
|
||||
IO.eprintln f!"{test1 1}"
|
||||
|
||||
3
tests/compiler/never_extract.lean.expected.out
Normal file
3
tests/compiler/never_extract.lean.expected.out
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
1
|
||||
1
|
||||
2
|
||||
Loading…
Add table
Reference in a new issue