doc: lambda lifting

This commit is contained in:
Leonardo de Moura 2022-10-08 19:08:58 -07:00
parent b7d4fd03a3
commit 2efb1dbdf1

View file

@ -13,19 +13,41 @@ import Lean.Compiler.LCNF.Level
namespace Lean.Compiler.LCNF
namespace LambdaLifting
/-- Context for the `LiftM` monad. -/
structure Context where
/--
If `liftInstParamOnly` is `true`, then only local functions that take
local instances as parameters are lambda lifted.
-/
liftInstParamOnly : Bool := false
/-- Suffix for the new auxiliary declarations being created. -/
suffix : Name
/--
Declaration where lambda lifting is being applied.
We use it to provide the "base name" for auxiliary declarations and the flag `safe`.
-/
mainDecl : Decl
/-- State for the `LiftM` monad. -/
structure State where
/--
New auxiliary declarations
-/
decls : Array Decl := #[]
/-- Monad for applying lambda lifting. -/
abbrev LiftM := ReaderT Context (StateRefT State (ScopeT CompilerM))
/--
Return `true` if the given declaration takes a local instance as a parameter.
We lambda lift this kind of local function declaration before specialization.
-/
def hasInstParam (decl : FunDecl) : CompilerM Bool :=
decl.params.anyM fun param => return (← isArrowClass? param.type).isSome
/--
Return `true` if the given declaration should be lambda lifted.
-/
def shouldLift (decl : FunDecl) : LiftM Bool := do
if (← read).liftInstParamOnly then
hasInstParam decl
@ -33,7 +55,11 @@ def shouldLift (decl : FunDecl) : LiftM Bool := do
return true
open Internalize in
def mkAuxDecl (paramsNew : Array Param) (decl : FunDecl) : LiftM LetDecl := do
/--
Create a new auxiliary declaration. The array `closure` contains all free variables
occurring in `decl`.
-/
def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do
let mainDecl := (← read).mainDecl
let nextIdx := (← get).decls.size
let nameNew := mainDecl.name ++ (← read).suffix.appendIndexAfter nextIdx
@ -41,7 +67,7 @@ def mkAuxDecl (paramsNew : Array Param) (decl : FunDecl) : LiftM LetDecl := do
auxDecl.save
modify fun { decls, .. } => { decls := decls.push auxDecl }
let us := auxDecl.levelParams.map mkLevelParam
let value := mkAppN (.const auxDecl.name us) (paramsNew.map (mkFVar ·.fvarId))
let value := mkAppN (.const auxDecl.name us) (closure.map (mkFVar ·.fvarId))
/- We reuse `decl`s `fvarId` to avoid substitution -/
let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value }
modifyLCtx fun lctx => lctx.addLetDecl declNew
@ -49,7 +75,7 @@ def mkAuxDecl (paramsNew : Array Param) (decl : FunDecl) : LiftM LetDecl := do
return declNew
where
go (nameNew : Name) (safe : Bool) : InternalizeM Decl := do
let params := (← paramsNew.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam)
let params := (← closure.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam)
let value ← internalizeCode decl.value
let type ← value.inferType
let type ← mkForallParams params type
@ -100,18 +126,27 @@ partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (suffix
let (decl, s) ← LambdaLifting.main decl |>.run { mainDecl := decl, liftInstParamOnly, suffix } |>.run {} |>.run {}
return s.decls.push decl
/--
Eliminate all local function declarations.
-/
def lambdaLifting : Pass where
phase := .mono
name := `lambdaLifting
run := fun decls => do
decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.lambdaLifting false (suffix := `_lambda))
/--
During eager lambda lifting, we lift
- All local function declarations from instances (motivation: make sure it is cheap to inline them later)
- Local function declarations that take local instances as parameters (motivation: ensure they are specialized)
-/
def eagerLambdaLifting : Pass where
phase := .base
name := `eagerLambdaLifting
run := fun decls => do
decls.foldlM (init := #[]) fun decls decl => do
let liftInstParamOnly := !(← Meta.isInstance decl.name)
-- TODO: when performing eager lambda lifting in instances, we must check whether they are tagged with `[inline]` and propagate annotation to new functions
return decls ++ (← decl.lambdaLifting liftInstParamOnly (suffix := `_elambda))
builtin_initialize