feat: eager lambda lifting

This commit is contained in:
Leonardo de Moura 2022-10-08 18:57:37 -07:00
parent 878e72b2f9
commit b7d4fd03a3
2 changed files with 16 additions and 4 deletions

View file

@ -3,6 +3,7 @@ Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Instances
import Lean.Compiler.LCNF.Closure
import Lean.Compiler.LCNF.Types
import Lean.Compiler.LCNF.MonadScope
@ -14,6 +15,7 @@ namespace LambdaLifting
structure Context where
liftInstParamOnly : Bool := false
suffix : Name
mainDecl : Decl
structure State where
@ -34,7 +36,7 @@ open Internalize in
def mkAuxDecl (paramsNew : Array Param) (decl : FunDecl) : LiftM LetDecl := do
let mainDecl := (← read).mainDecl
let nextIdx := (← get).decls.size
let nameNew := mainDecl.name ++ (`_lambda).appendIndexAfter nextIdx
let nameNew := mainDecl.name ++ (← read).suffix.appendIndexAfter nextIdx
let auxDecl ← go nameNew mainDecl.safe |>.run' {}
auxDecl.save
modify fun { decls, .. } => { decls := decls.push auxDecl }
@ -94,17 +96,26 @@ def main (decl : Decl) : LiftM Decl := do
end LambdaLifting
partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) : CompilerM (Array Decl) := do
let (decl, s) ← LambdaLifting.main decl |>.run { mainDecl := decl, liftInstParamOnly } |>.run {} |>.run {}
partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (suffix : Name) : CompilerM (Array Decl) := do
let (decl, s) ← LambdaLifting.main decl |>.run { mainDecl := decl, liftInstParamOnly, suffix } |>.run {} |>.run {}
return s.decls.push decl
def lambdaLifting : Pass where
phase := .mono
name := `lambdaLifting
run := fun decls => do
decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.lambdaLifting false)
decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.lambdaLifting false (suffix := `_lambda))
def eagerLambdaLifting : Pass where
phase := .base
name := `eagerLambdaLifting
run := fun decls => do
decls.foldlM (init := #[]) fun decls decl => do
let liftInstParamOnly := !(← Meta.isInstance decl.name)
return decls ++ (← decl.lambdaLifting liftInstParamOnly (suffix := `_elambda))
builtin_initialize
registerTraceClass `Compiler.eagerLambdaLifting (inherited := true)
registerTraceClass `Compiler.lambdaLifting (inherited := true)
end Lean.Compiler.LCNF

View file

@ -50,6 +50,7 @@ def builtinPassManager : PassManager := {
pullFunDecls,
reduceJpArity,
simp { etaPoly := true, inlinePartial := true, implementedBy := true } (occurrence := 1),
eagerLambdaLifting,
specialize,
simp (occurrence := 2),
cse,