perf: eta contract instead of lambda lifting if possible (#11451)
This PR adapts the lambda lifter in LCNF to eta contract instead of lambda lift if possible. This prevents the creation of a few hundred unnecessary lambdas across the code base.
This commit is contained in:
parent
0646bc5979
commit
3dd99fc29c
2 changed files with 64 additions and 12 deletions
|
|
@ -40,6 +40,10 @@ structure Context where
|
|||
We use this feature to implement `@[inline] instance ...` and `@[always_inline] instance ...`
|
||||
-/
|
||||
minSize : Nat := 0
|
||||
/--
|
||||
Allow for eta contraction instead of lifting to a lambda if possible.
|
||||
-/
|
||||
allowEtaContraction : Bool := true
|
||||
|
||||
|
||||
/-- State for the `LiftM` monad. -/
|
||||
|
|
@ -81,6 +85,13 @@ partial def mkAuxDeclName : LiftM Name := do
|
|||
if (← getDecl? nameNew).isNone then return nameNew
|
||||
mkAuxDeclName
|
||||
|
||||
def replaceFunDecl (decl : FunDecl) (value : LetValue) : LiftM LetDecl := do
|
||||
/- We reuse `decl`s `fvarId` to avoid substitution -/
|
||||
let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value }
|
||||
modifyLCtx fun lctx => lctx.addLetDecl declNew
|
||||
eraseFunDecl decl
|
||||
return declNew
|
||||
|
||||
open Internalize in
|
||||
/--
|
||||
Create a new auxiliary declaration. The array `closure` contains all free variables
|
||||
|
|
@ -100,11 +111,7 @@ def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do
|
|||
auxDecl.erase
|
||||
pure declName
|
||||
let value := .const auxDeclName us (closure.map (.fvar ·.fvarId))
|
||||
/- We reuse `decl`s `fvarId` to avoid substitution -/
|
||||
let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value }
|
||||
modifyLCtx fun lctx => lctx.addLetDecl declNew
|
||||
eraseFunDecl decl
|
||||
return declNew
|
||||
replaceFunDecl decl value
|
||||
where
|
||||
go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM Decl := do
|
||||
let params := (← closure.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam)
|
||||
|
|
@ -115,6 +122,20 @@ where
|
|||
let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl }
|
||||
return decl.setLevelParams
|
||||
|
||||
def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do
|
||||
if !(← read).allowEtaContraction then return none
|
||||
let .let { fvarId := letVar, value := .const declName us args, .. } (.return retVar) := decl.value
|
||||
| return none
|
||||
if letVar != retVar then return none
|
||||
if args.size != decl.params.size then return none
|
||||
if (← getDecl? declName).isNone then return none
|
||||
for arg in args, param in decl.params do
|
||||
let .fvar argVar := arg | return none
|
||||
if argVar != param.fvarId then return none
|
||||
|
||||
let value := .const declName us #[]
|
||||
replaceFunDecl decl value
|
||||
|
||||
mutual
|
||||
partial def visitFunDecl (funDecl : FunDecl) : LiftM FunDecl := do
|
||||
let value ← withParams funDecl.params <| visitCode funDecl.value
|
||||
|
|
@ -128,9 +149,13 @@ mutual
|
|||
| .fun decl k =>
|
||||
let decl ← visitFunDecl decl
|
||||
if (← shouldLift decl) then
|
||||
let scope ← getScope
|
||||
let (_, params, _) ← Closure.run (inScope := scope.contains) <| Closure.collectFunDecl decl
|
||||
let declNew ← mkAuxDecl params decl
|
||||
let declNew ← do
|
||||
if let some letDecl ← etaContractibleDecl? decl then
|
||||
pure letDecl
|
||||
else
|
||||
let scope ← getScope
|
||||
let (_, params, _) ← Closure.run (inScope := scope.contains) <| Closure.collectFunDecl decl
|
||||
mkAuxDecl params decl
|
||||
let k ← withFVar declNew.fvarId <| visitCode k
|
||||
return .let declNew k
|
||||
else
|
||||
|
|
@ -155,8 +180,17 @@ def main (decl : Decl) : LiftM Decl := do
|
|||
|
||||
end LambdaLifting
|
||||
|
||||
partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array Decl) := do
|
||||
let (decl, s) ← LambdaLifting.main decl |>.run { mainDecl := decl, liftInstParamOnly, suffix, inheritInlineAttrs, minSize } |>.run {} |>.run {}
|
||||
partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (allowEtaContraction : Bool)
|
||||
(suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array Decl) := do
|
||||
let ctx := {
|
||||
mainDecl := decl,
|
||||
liftInstParamOnly,
|
||||
suffix,
|
||||
inheritInlineAttrs,
|
||||
minSize,
|
||||
allowEtaContraction
|
||||
}
|
||||
let (decl, s) ← LambdaLifting.main decl |>.run ctx |>.run {} |>.run {}
|
||||
return s.decls.push decl
|
||||
|
||||
/--
|
||||
|
|
@ -166,7 +200,8 @@ def lambdaLifting : Pass where
|
|||
phase := .mono
|
||||
name := `lambdaLifting
|
||||
run := fun decls => do
|
||||
decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.lambdaLifting false (suffix := `_lam))
|
||||
decls.foldlM (init := #[]) fun decls decl =>
|
||||
return decls ++ (← decl.lambdaLifting false true (suffix := `_lam))
|
||||
|
||||
/--
|
||||
During eager lambda lifting, we inspect declarations that are not inlineable or instances (doing it
|
||||
|
|
@ -182,7 +217,7 @@ def eagerLambdaLifting : Pass where
|
|||
if decl.inlineable || (← Meta.isInstance decl.name) then
|
||||
return decls.push decl
|
||||
else
|
||||
return decls ++ (← decl.lambdaLifting (liftInstParamOnly := true) (suffix := `_elam))
|
||||
return decls ++ (← decl.lambdaLifting (liftInstParamOnly := true) (allowEtaContraction := false) (suffix := `_elam))
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.eagerLambdaLifting (inherited := true)
|
||||
|
|
|
|||
17
tests/lean/run/eta_lambda_lift.lean
Normal file
17
tests/lean/run/eta_lambda_lift.lean
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import Lean.Util.FindExpr
|
||||
|
||||
/-!
|
||||
This test asserts that the compiler will eta contract trivial lambdas instead of lambda lifting
|
||||
them.
|
||||
-/
|
||||
|
||||
/--
|
||||
trace: [Compiler.lambdaLifting] size: 2
|
||||
def test e : Option Lean.Expr :=
|
||||
let _f.1 := Lean.Expr.hasMVar;
|
||||
let _x.2 := Lean.Expr.findImpl? _f.1 e;
|
||||
return _x.2
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.lambdaLifting true in
|
||||
def test (e : Lean.Expr) := e.find? (fun e => e.hasMVar)
|
||||
Loading…
Add table
Reference in a new issue