From 2efb1dbdf1adff30b232ea8763b72379a4434dd0 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 8 Oct 2022 19:08:58 -0700 Subject: [PATCH] doc: lambda lifting --- src/Lean/Compiler/LCNF/LambdaLifting.lean | 41 +++++++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/src/Lean/Compiler/LCNF/LambdaLifting.lean b/src/Lean/Compiler/LCNF/LambdaLifting.lean index e1850a7070..32400d83c2 100644 --- a/src/Lean/Compiler/LCNF/LambdaLifting.lean +++ b/src/Lean/Compiler/LCNF/LambdaLifting.lean @@ -13,19 +13,41 @@ import Lean.Compiler.LCNF.Level namespace Lean.Compiler.LCNF namespace LambdaLifting +/-- Context for the `LiftM` monad. -/ structure Context where + /-- + If `liftInstParamOnly` is `true`, then only local functions that take + local instances as parameters are lambda lifted. + -/ liftInstParamOnly : Bool := false + /-- Suffix for the new auxiliary declarations being created. -/ suffix : Name + /-- + Declaration where lambda lifting is being applied. + We use it to provide the "base name" for auxiliary declarations and the flag `safe`. + -/ mainDecl : Decl +/-- State for the `LiftM` monad. -/ structure State where + /-- + New auxiliary declarations + -/ decls : Array Decl := #[] +/-- Monad for applying lambda lifting. -/ abbrev LiftM := ReaderT Context (StateRefT State (ScopeT CompilerM)) +/-- +Return `true` if the given declaration takes a local instance as a parameter. +We lambda lift this kind of local function declaration before specialization. +-/ def hasInstParam (decl : FunDecl) : CompilerM Bool := decl.params.anyM fun param => return (← isArrowClass? param.type).isSome +/-- +Return `true` if the given declaration should be lambda lifted. +-/ def shouldLift (decl : FunDecl) : LiftM Bool := do if (← read).liftInstParamOnly then hasInstParam decl @@ -33,7 +55,11 @@ def shouldLift (decl : FunDecl) : LiftM Bool := do return true open Internalize in -def mkAuxDecl (paramsNew : Array Param) (decl : FunDecl) : LiftM LetDecl := do +/-- +Create a new auxiliary declaration. The array `closure` contains all free variables +occurring in `decl`. +-/ +def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do let mainDecl := (← read).mainDecl let nextIdx := (← get).decls.size let nameNew := mainDecl.name ++ (← read).suffix.appendIndexAfter nextIdx @@ -41,7 +67,7 @@ def mkAuxDecl (paramsNew : Array Param) (decl : FunDecl) : LiftM LetDecl := do auxDecl.save modify fun { decls, .. } => { decls := decls.push auxDecl } let us := auxDecl.levelParams.map mkLevelParam - let value := mkAppN (.const auxDecl.name us) (paramsNew.map (mkFVar ·.fvarId)) + let value := mkAppN (.const auxDecl.name us) (closure.map (mkFVar ·.fvarId)) /- We reuse `decl`s `fvarId` to avoid substitution -/ let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value } modifyLCtx fun lctx => lctx.addLetDecl declNew @@ -49,7 +75,7 @@ def mkAuxDecl (paramsNew : Array Param) (decl : FunDecl) : LiftM LetDecl := do return declNew where go (nameNew : Name) (safe : Bool) : InternalizeM Decl := do - let params := (← paramsNew.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam) + let params := (← closure.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam) let value ← internalizeCode decl.value let type ← value.inferType let type ← mkForallParams params type @@ -100,18 +126,27 @@ partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (suffix let (decl, s) ← LambdaLifting.main decl |>.run { mainDecl := decl, liftInstParamOnly, suffix } |>.run {} |>.run {} return s.decls.push decl +/-- +Eliminate all local function declarations. +-/ def lambdaLifting : Pass where phase := .mono name := `lambdaLifting run := fun decls => do decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.lambdaLifting false (suffix := `_lambda)) +/-- +During eager lambda lifting, we lift +- All local function declarations from instances (motivation: make sure it is cheap to inline them later) +- Local function declarations that take local instances as parameters (motivation: ensure they are specialized) +-/ def eagerLambdaLifting : Pass where phase := .base name := `eagerLambdaLifting run := fun decls => do decls.foldlM (init := #[]) fun decls decl => do let liftInstParamOnly := !(← Meta.isInstance decl.name) + -- TODO: when performing eager lambda lifting in instances, we must check whether they are tagged with `[inline]` and propagate annotation to new functions return decls ++ (← decl.lambdaLifting liftInstParamOnly (suffix := `_elambda)) builtin_initialize