diff --git a/src/Lean/Compiler/LCNF/CSE.lean b/src/Lean/Compiler/LCNF/CSE.lean index b49d260460..bf0cae5557 100644 --- a/src/Lean/Compiler/LCNF/CSE.lean +++ b/src/Lean/Compiler/LCNF/CSE.lean @@ -7,6 +7,7 @@ prelude import Lean.Compiler.LCNF.CompilerM import Lean.Compiler.LCNF.ToExpr import Lean.Compiler.LCNF.PassManager +import Lean.Compiler.NeverExtractAttr namespace Lean.Compiler.LCNF @@ -44,6 +45,13 @@ def replaceFun (decl : FunDecl) (fvarId : FVarId) : M Unit := do eraseFunDecl decl addFVarSubst decl.fvarId fvarId +def hasNeverExtract (v : LetValue) : CompilerM Bool := + match v with + | .const declName .. => + return hasNeverExtractAttribute (← getEnv) declName + | .lit _ | .erased | .proj .. | .fvar .. => + return false + partial def _root_.Lean.Compiler.LCNF.Code.cse (shouldElimFunDecls : Bool) (code : Code) : CompilerM Code := go code |>.run' {} where @@ -57,15 +65,18 @@ where match code with | .let decl k => let decl ← normLetDecl decl - -- We only apply CSE to pure code - let key := decl.value.toExpr - match (← get).map.find? key with - | some fvarId => - replaceLet decl fvarId - go k - | none => - addEntry key decl.fvarId + if (← hasNeverExtract decl.value) then return code.updateLet! decl (← go k) + else + -- We only apply CSE to pure code + let key := decl.value.toExpr + match (← get).map.find? key with + | some fvarId => + replaceLet decl fvarId + go k + | none => + addEntry key decl.fvarId + return code.updateLet! decl (← go k) | .fun decl k => let decl ← goFunDecl decl if shouldElimFunDecls then diff --git a/tests/compiler/never_extract.lean b/tests/compiler/never_extract.lean new file mode 100644 index 0000000000..de16a9291d --- /dev/null +++ b/tests/compiler/never_extract.lean @@ -0,0 +1,13 @@ +def test1 (a : Nat) : Nat := + let f a := + dbg_trace s!"{a}" + a + let g a := + dbg_trace s!"{a + 0}" + a + (f a) + (g a) + +def main : IO Unit := + -- Use `eprintln` because that is what `dbg_trace` uses. + IO.eprintln f!"{test1 1}" + diff --git a/tests/compiler/never_extract.lean.expected.out b/tests/compiler/never_extract.lean.expected.out new file mode 100644 index 0000000000..33280629d4 --- /dev/null +++ b/tests/compiler/never_extract.lean.expected.out @@ -0,0 +1,3 @@ +1 +1 +2