lean4-htt/src/Lean/Compiler/LCNF/LambdaLifting.lean
2022-10-14 08:42:50 -07:00

197 lines
7.5 KiB
Text

/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Instances
import Lean.Compiler.InlineAttrs
import Lean.Compiler.LCNF.Closure
import Lean.Compiler.LCNF.Types
import Lean.Compiler.LCNF.MonadScope
import Lean.Compiler.LCNF.Internalize
import Lean.Compiler.LCNF.Level
import Lean.Compiler.LCNF.AuxDeclCache
namespace Lean.Compiler.LCNF
namespace LambdaLifting
/-- Context for the `LiftM` monad. -/
structure Context where
  /--
  If `liftInstParamOnly` is `true`, then only local functions that take
  local instances as parameters are lambda lifted.
  -/
  liftInstParamOnly : Bool := false
  /-- Suffix for the new auxiliary declarations being created. -/
  suffix : Name
  /--
  Declaration where lambda lifting is being applied.
  We use it to provide the "base name" for auxiliary declarations and the flag `safe`.
  -/
  mainDecl : Decl
  /--
  If true, the lambda-lifted functions inherit the inline attribute from `mainDecl`.
  We use this feature to implement `@[inline] instance ...` and `@[alwaysInline] instance ...`
  -/
  inheritInlineAttrs := false
  /--
  Only local functions whose body size is at least `minSize`
  (i.e., `decl.value.size ≥ minSize`) are lambda lifted; see `shouldLift`.
  With the default `0`, size never prevents lifting. `eagerLambdaLifting` passes a
  positive value for instances so that tiny local functions are kept in place.
  -/
  minSize : Nat := 0
/-- Mutable state threaded through the `LiftM` monad. -/
structure State where
  /-- Auxiliary declarations created so far by lambda lifting. -/
  decls : Array Decl := #[]
  /-- Index tried next by `mkAuxDeclName` when building an auxiliary declaration name. -/
  nextIdx : Nat := 0
/--
Monad used by the lambda lifting pass: a read-only `Context`, a mutable `State`,
and scope tracking (`ScopeT`) on top of `CompilerM`.
-/
abbrev LiftM := ReaderT Context <| StateRefT State <| ScopeT CompilerM
/--
Check whether some parameter of `decl` has a type for which `isArrowClass?` succeeds,
i.e., whether the local function takes a local instance as a parameter.
Such local functions are lambda lifted before specialization.
-/
def hasInstParam (decl : FunDecl) : CompilerM Bool :=
  decl.params.anyM fun p => do
    let cls? ← isArrowClass? p.type
    return cls?.isSome
/--
Decide whether the local function `decl` should be lambda lifted.
Functions whose body size is below `Context.minSize` are never lifted; otherwise, when
`Context.liftInstParamOnly` is set, only functions satisfying `hasInstParam` are lifted.
-/
def shouldLift (decl : FunDecl) : LiftM Bool := do
  let ctx ← read
  if ctx.minSize ≤ decl.value.size then
    if ctx.liftInstParamOnly then
      hasInstParam decl
    else
      return true
  else
    return false
/--
Produce a fresh auxiliary declaration name: `mainDecl.name ++ suffix` with an index
appended via `Name.appendIndexAfter`. Indices are taken (and bumped) from `State.nextIdx`;
any candidate for which `getDecl?` already finds a declaration is skipped.
-/
partial def mkAuxDeclName : LiftM Name := do
  let idx ← modifyGet fun s => (s.nextIdx, { s with nextIdx := s.nextIdx + 1 })
  let ctx ← read
  let candidate := ctx.mainDecl.name ++ ctx.suffix.appendIndexAfter idx
  match (← getDecl? candidate) with
  | none => return candidate
  | some _ => mkAuxDeclName
open Internalize in
/--
Create a new auxiliary declaration for the local function `decl` and return a `let`
declaration that replaces `decl` by an application of the auxiliary declaration.

The array `closure` contains all free variables occurring in `decl`. They become the leading
parameters of the auxiliary declaration, followed by `decl.params`. The new `let` applies
the auxiliary constant to these closure variables and reuses `decl.fvarId`.

If `cacheAuxDecl` reports the declaration as already cached, the freshly built declaration
is erased and the cached declaration name is referenced instead. Otherwise the new
declaration is saved and appended to `State.decls`.
-/
def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do
  let nameNew ← mkAuxDeclName
  let inlineAttr? := if (← read).inheritInlineAttrs then (← read).mainDecl.inlineAttr? else none
  let auxDecl ← go nameNew (← read).mainDecl.safe inlineAttr? |>.run' {}
  let us := auxDecl.levelParams.map mkLevelParam
  let auxDeclName ← match (← cacheAuxDecl auxDecl) with
    | .new =>
      auxDecl.save
      -- Use `with` so the rest of the state (in particular `nextIdx`) is preserved;
      -- building a fresh `State` here would reset `nextIdx` to its default.
      modify fun s => { s with decls := s.decls.push auxDecl }
      pure auxDecl.name
    | .alreadyCached declName =>
      auxDecl.erase
      pure declName
  let value := mkAppN (.const auxDeclName us) (closure.map (mkFVar ·.fvarId))
  /- We reuse `decl`'s `fvarId` to avoid a substitution. -/
  let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value }
  modifyLCtx fun lctx => lctx.addLetDecl declNew
  eraseFunDecl decl
  return declNew
where
  /--
  Build the auxiliary `Decl`: internalize the closure parameters and `decl.params`,
  internalize the body, compute its type with `mkForallParams`, and set level parameters.
  -/
  go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM Decl := do
    let params := (← closure.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam)
    let value ← internalizeCode decl.value
    let type ← value.inferType
    let type ← mkForallParams params type
    let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl }
    return decl.setLevelParams
mutual
  /--
  Visit the body of `funDecl` with its parameters in scope, then rebuild the declaration
  with `FunDecl.update'`, keeping `funDecl.type`.
  -/
  partial def visitFunDecl (funDecl : FunDecl) : LiftM FunDecl := do
    let body ← withParams funDecl.params (visitCode funDecl.value)
    funDecl.update' funDecl.type body

  /--
  Traverse `code`, keeping track of the variables in scope. Every local function
  declaration (`.fun`) accepted by `shouldLift` is replaced by the `let` produced by
  `mkAuxDecl`. Join points (`.jp`) are never lifted, but their bodies are visited.
  -/
  partial def visitCode (code : Code) : LiftM Code := do
    match code with
    | .unreach .. | .jmp .. | .return .. =>
      return code
    | .let decl k =>
      let k' ← withFVar decl.fvarId (visitCode k)
      return code.updateLet! decl k'
    | .jp decl k =>
      let decl' ← visitFunDecl decl
      let k' ← withFVar decl'.fvarId (visitCode k)
      return code.updateFun! decl' k'
    | .fun decl k =>
      let decl' ← visitFunDecl decl
      unless (← shouldLift decl') do
        -- Not lifted: keep the local function and only visit its continuation.
        let k' ← withFVar decl'.fvarId (visitCode k)
        return code.updateFun! decl' k'
      -- Compute the closure of `decl'`, using the current scope as the `inScope` predicate.
      let scope ← getScope
      let (_, closure, _) ← Closure.run (inScope := scope.contains) <| Closure.collectFunDecl decl'
      let letDecl ← mkAuxDecl closure decl'
      let k' ← withFVar letDecl.fvarId (visitCode k)
      return .let letDecl k'
    | .cases c =>
      let alts ← c.alts.mapMonoM fun alt => do
        match alt with
        | .default k => return alt.updateCode (← visitCode k)
        | .alt _ ps k => withParams ps do return alt.updateCode (← visitCode k)
      return code.updateAlts! alts
end
/-- Lambda lift the local functions in the body of `decl`, with `decl.params` in scope. -/
def main (decl : Decl) : LiftM Decl := do
  let newValue ← withParams decl.params (visitCode decl.value)
  return { decl with value := newValue }
end LambdaLifting
/--
Apply lambda lifting to `decl`.

- `liftInstParamOnly`: only lift local functions that take local instances as parameters.
- `suffix`: suffix used (together with `decl.name` and an index) to name new auxiliary declarations.
- `inheritInlineAttrs`: new auxiliary declarations inherit `decl`'s inline attribute.
- `minSize`: local functions whose body size is below `minSize` are not lifted.

Returns the newly created auxiliary declarations followed by the updated `decl`.
-/
-- Not recursive, so `partial` is unnecessary (it would only make the definition opaque).
def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array Decl) := do
  let (decl, s) ← LambdaLifting.main decl |>.run { mainDecl := decl, liftInstParamOnly, suffix, inheritInlineAttrs, minSize } |>.run {} |>.run {}
  return s.decls.push decl
/--
Eliminate all local function declarations (`.fun`) by lambda lifting them into
auxiliary top-level declarations whose names use the suffix `_lambda`.
Runs in the `mono` phase.
-/
def lambdaLifting : Pass where
  name := `lambdaLifting
  phase := .mono
  run := fun decls => do
    let mut out : Array Decl := #[]
    for decl in decls do
      out := out ++ (← decl.lambdaLifting false (suffix := `_lambda))
    return out
/--
During eager lambda lifting (`base` phase), we lift
- in instances: local function declarations whose body size is at least 3; the new auxiliary
  declarations inherit the instance's inline attribute
  (motivation: make sure it is cheap to inline them later);
- in all other declarations: local function declarations that take local instances as
  parameters (motivation: ensure they are specialized).
-/
def eagerLambdaLifting : Pass where
  phase := .base
  name := `eagerLambdaLifting
  run := fun decls => do
    decls.foldlM (init := #[]) fun decls decl => do
      if (← Meta.isInstance decl.name) then
        /-
        Recall that we lambda lift local functions in instances to control code blow-up and to
        make sure they are cheap to inline. It is not worth lifting tiny ones, hence `minSize := 3`.
        TODO: evaluate whether we should add a compiler option to control the minimum size.
        When performing eager lambda lifting in instances, we propagate the `[inline]` annotations
        to the new auxiliary functions (`inheritInlineAttrs := true`).
        Note: we have tried `if decl.inlineable then return decls.push decl`, but it didn't help
        in our preliminary experiments.
        -/
        return decls ++ (← decl.lambdaLifting (liftInstParamOnly := false) (suffix := `_elambda) (inheritInlineAttrs := true) (minSize := 3))
      else
        return decls ++ (← decl.lambdaLifting (liftInstParamOnly := true) (suffix := `_elambda))
-- Register the trace classes for the eager and the regular lambda lifting passes,
-- both with `inherited := true`.
builtin_initialize
  registerTraceClass `Compiler.eagerLambdaLifting (inherited := true)
  registerTraceClass `Compiler.lambdaLifting (inherited := true)
end Lean.Compiler.LCNF