-- Snapshot metadata: lean4-htt/src/Lean/Compiler/LCNF/ExtractClosed.lean,
-- 2025-07-25 12:02:51 +00:00, 164 lines, 5.3 KiB, text.

/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Cameron Zwarich
-/
module
prelude
public import Lean.Compiler.ClosedTermCache
public import Lean.Compiler.NeverExtractAttr
public import Lean.Compiler.LCNF.Basic
public import Lean.Compiler.LCNF.InferType
public import Lean.Compiler.LCNF.Internalize
public import Lean.Compiler.LCNF.MonoTypes
public import Lean.Compiler.LCNF.PassManager
public import Lean.Compiler.LCNF.ToExpr
public section
namespace Lean.Compiler.LCNF
namespace ExtractClosed
/--
Monad used by `extractLetValue`, `extractArg` and `extractFVar`: `CompilerM` extended with a
mutable array that accumulates the collected let-declarations (as `CodeDecl`s).
-/
abbrev ExtractM : Type → Type := StateRefT (Array CodeDecl) CompilerM
mutual

/--
Pushes onto the `ExtractM` state the let-declarations that the let-value `v` depends on,
by visiting every free variable the value mentions.
-/
partial def extractLetValue (v : LetValue) : ExtractM Unit := do
  match v with
  | .lit _ | .erased => pure ()
  | .proj _ _ structVar => extractFVar structVar
  | .const _ _ args =>
    for arg in args do
      extractArg arg
  | .fvar fn args =>
    extractFVar fn
    for arg in args do
      extractArg arg

/-- Type and erased arguments mention no free variable; `.fvar` arguments are followed. -/
partial def extractArg (arg : Arg) : ExtractM Unit :=
  match arg with
  | .type _ | .erased => pure ()
  | .fvar fvarId => extractFVar fvarId

/--
If `findLetDecl?` finds a let-declaration for `fvarId`, pushes it and then recursively
collects the declarations its value depends on; otherwise does nothing.

A declaration is pushed before its own dependencies, so callers must reverse the collected
array to obtain a dependencies-first order.
NOTE(review): nothing records which declarations were already pushed, so a let-declaration
reachable through several occurrences is pushed (and its dependencies re-visited) once per
occurrence.
-/
partial def extractFVar (fvarId : FVarId) : ExtractM Unit := do
  let some letDecl ← findLetDecl? fvarId | pure ()
  modify fun collected => collected.push (.let letDecl)
  extractLetValue letDecl.value

end
/-- `true` for erased and type arguments, `false` for free-variable arguments. -/
def isIrrelevantArg (arg : Arg) : Bool :=
  arg matches .erased | .type _
/-- Read-only context for extracting closed terms from a single declaration. -/
structure Context where
  /-- Name of the declaration being processed; names of generated auxiliary constants are
  built by appending a `_closed` index to it. -/
  baseName : Name
  /-- Declarations whose calls `shouldExtractLetValue` never hoists; the `extractClosed`
  pass supplies the whole batch of declarations it is run on. -/
  sccDecls : Array Decl
/-- Mutable state of the extraction. -/
structure State where
  /-- Auxiliary declarations generated so far, in creation order (initially empty). -/
  decls : Array Decl := #[]
/-- Extraction monad: a read-only `Context` and a mutable `State` layered over `CompilerM`. -/
abbrev M : Type → Type := ReaderT Context (StateRefT State CompilerM)
mutual

/--
Decides whether the let-value `v` may be hoisted into a closed top-level constant.

`isRoot` is `true` for the value of the let-declaration that would itself be replaced by the
constant, and `false` for values reached through its free variables. Some values are fine as
dependencies but are not hoisted on their own at the root:
- nat literals below `2^63`,
- other non-string literals and `erased`,
- the filtered constant applications in the `.const` case.
-/
partial def shouldExtractLetValue (isRoot : Bool) (v : LetValue) : M Bool := do
  match v with
  | .lit (.str _) => return true
  | .lit (.nat n) =>
    -- A fixed 2^63 bound stands in for the runtime `is_scalar` check the old compiler used,
    -- which made the outcome depend on the architecture the compiler runs on.
    return !isRoot || n >= Nat.pow 2 63
  | .lit _ | .erased => return !isRoot
  | .const declName _ args =>
    let env ← getEnv
    -- Calls into `sccDecls`, or into declarations for which `hasNeverExtractAttribute`
    -- holds, are never hoisted.
    if (← read).sccDecls.any (·.name == declName) || hasNeverExtractAttribute env declName then
      return false
    if isRoot then
      if let some info := env.find? declName then
        -- At the root, skip definitions whose type is not a `forall` and constructor
        -- applications whose arguments are all irrelevant.
        let worthHoisting := match info with
          | .defnInfo val => val.type.isForall
          | .ctorInfo _ => args.any (fun arg => !isIrrelevantArg arg)
          | _ => true
        unless worthHoisting do
          return false
    args.allM shouldExtractArg
  | .fvar fn args =>
    let fnOk ← shouldExtractFVar fn
    let argsOk ← args.allM shouldExtractArg
    return fnOk && argsOk
  | .proj _ _ structVar => shouldExtractFVar structVar

/-- Type and erased arguments never block hoisting; free variables are checked recursively. -/
partial def shouldExtractArg (arg : Arg) : M Bool :=
  match arg with
  | .type _ | .erased => pure true
  | .fvar fvarId => shouldExtractFVar fvarId

/--
A free variable may occur in a closed term only if `findLetDecl?` finds a let-declaration for
it whose value is acceptable as a (non-root) dependency; otherwise hoisting is ruled out.
-/
partial def shouldExtractFVar (fvarId : FVarId) : M Bool := do
  match (← findLetDecl? fvarId) with
  | some letDecl => shouldExtractLetValue false letDecl.value
  | none => return false

end
/--
Hoists the let-declaration `decl`, whose value `shouldExtractLetValue true` accepted, into a
closed auxiliary constant and returns that constant's name.

The let-declarations collected by `extractLetValue` are reversed, because `extractFVar` pushes
a declaration before its dependencies. They are placed in front of `decl`, and the whole
sequence is copied through `Internalize.internalizeCodeDecl`. The resulting code returns the
copy of `decl`'s variable.

If `getClosedTermName?` already knows the same closed term, the copy is erased and the cached
name is returned. Otherwise a new `noinline` declaration is built, saved with `saveMono`,
recorded in the closed-term cache, and appended to `State.decls`.
-/
def hoistClosedLetDecl (decl : LetDecl) : M Name := do
  let ⟨_, deps⟩ ← extractLetValue decl.value |>.run {}
  let decls := deps.reverse.push (.let decl)
  let decls ← decls.mapM Internalize.internalizeCodeDecl |>.run' {}
  let closedCode := attachCodeDecls decls (.return decls.back!.fvarId)
  let closedExpr := closedCode.toExpr
  let env ← getEnv
  if let some cachedName := getClosedTermName? env closedExpr then
    -- An equal closed term already has a constant; discard the freshly built copy.
    eraseCode closedCode
    return cachedName
  -- Auxiliary constants are numbered by how many the current declaration has produced so far.
  let name := (← read).baseName ++ (`_closed).appendIndexAfter (← get).decls.size
  cacheClosedTermName env closedExpr name |> setEnv
  -- The `Decl` ascription gives the anonymous structure instance its expected type; without
  -- it the `{ … }` notation cannot determine which structure to build.
  let auxDecl : Decl := { name, levelParams := [], type := decl.type, params := #[],
                          value := .code closedCode, inlineAttr? := some .noinline }
  auxDecl.saveMono
  modify fun s => { s with decls := s.decls.push auxDecl }
  return name

/--
Rewrites `code`, replacing each let-declaration whose value `shouldExtractLetValue true`
accepts by a reference to a closed auxiliary constant (see `hoistClosedLetDecl`).

Recurses into:
- let continuations,
- local function and join point bodies and their continuations,
- `cases` alternatives.

`jmp`, `return` and `unreach` are returned unchanged.
-/
partial def visitCode (code : Code) : M Code := do
  match code with
  | .let decl k =>
    if (← shouldExtractLetValue true decl.value) then
      let name ← hoistClosedLetDecl decl
      let decl ← decl.updateValue (.const name [] #[])
      return code.updateLet! decl (← visitCode k)
    else
      return code.updateLet! decl (← visitCode k)
  | .fun decl k | .jp decl k =>
    -- Both node kinds carry a `FunDecl` and a continuation; `updateFun!` is applied to either.
    let decl ← decl.updateValue (← visitCode decl.value)
    return code.updateFun! decl (← visitCode k)
  | .cases cases =>
    let alts ← cases.alts.mapM (fun alt => do return alt.updateCode (← visitCode alt.getCode))
    return code.updateAlts! alts
  | .jmp .. | .return _ | .unreach .. => return code
/-- Runs `visitCode` over the code of `decl` via `mapCodeM`; all other fields are kept. -/
def visitDecl (decl : Decl) : M Decl := do
  return { decl with value := (← decl.value.mapCodeM visitCode) }
end ExtractClosed
/--
Runs closed-term extraction on `decl`. Calls to declarations in `sccDecls` are never hoisted.

Returns the auxiliary declarations generated for `decl` (in creation order), followed by the
rewritten `decl`.
-/
partial def Decl.extractClosed (decl : Decl) (sccDecls : Array Decl) : CompilerM (Array Decl) := do
  let ctx : ExtractClosed.Context := { baseName := decl.name, sccDecls }
  let (rewritten, st) ← (ExtractClosed.visitDecl decl).run ctx |>.run {}
  return st.decls.push rewritten
/--
The `extractClosed` pass, run in the `.mono` phase.

When the `extractClosed` flag of `getConfig` is set, each declaration of the batch is
processed by `Decl.extractClosed`, with the whole batch as `sccDecls`. The output lists each
declaration's generated auxiliary constants right before the rewritten declaration.

When the flag is not set, the batch is returned unchanged.
-/
def extractClosed : Pass where
  phase := .mono
  name := `extractClosed
  run decls := do
    unless (← getConfig).extractClosed do
      return decls
    let mut result : Array Decl := #[]
    for decl in decls do
      result := result ++ (← decl.extractClosed decls)
    return result
-- Register the `Compiler.extractClosed` trace class. `inherited := true` presumably ties its
-- enablement to the parent `Compiler` trace class — TODO confirm.
builtin_initialize registerTraceClass `Compiler.extractClosed (inherited := true)
end Lean.Compiler.LCNF