From b7d4fd03a3827fca5229e8b17516cf5af7c1ef57 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 8 Oct 2022 18:57:37 -0700 Subject: [PATCH] feat: eager lambda lifting --- src/Lean/Compiler/LCNF/LambdaLifting.lean | 19 +++++++++++++++---- src/Lean/Compiler/LCNF/Passes.lean | 1 + 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/Lean/Compiler/LCNF/LambdaLifting.lean b/src/Lean/Compiler/LCNF/LambdaLifting.lean index 198c31989a..e1850a7070 100644 --- a/src/Lean/Compiler/LCNF/LambdaLifting.lean +++ b/src/Lean/Compiler/LCNF/LambdaLifting.lean @@ -3,6 +3,7 @@ Copyright (c) 2022 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ +import Lean.Meta.Instances import Lean.Compiler.LCNF.Closure import Lean.Compiler.LCNF.Types import Lean.Compiler.LCNF.MonadScope @@ -14,6 +15,7 @@ namespace LambdaLifting structure Context where liftInstParamOnly : Bool := false + suffix : Name mainDecl : Decl structure State where @@ -34,7 +36,7 @@ open Internalize in def mkAuxDecl (paramsNew : Array Param) (decl : FunDecl) : LiftM LetDecl := do let mainDecl := (← read).mainDecl let nextIdx := (← get).decls.size - let nameNew := mainDecl.name ++ (`_lambda).appendIndexAfter nextIdx + let nameNew := mainDecl.name ++ (← read).suffix.appendIndexAfter nextIdx let auxDecl ← go nameNew mainDecl.safe |>.run' {} auxDecl.save modify fun { decls, .. } => { decls := decls.push auxDecl } @@ -94,17 +96,26 @@ def main (decl : Decl) : LiftM Decl := do end LambdaLifting -partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) : CompilerM (Array Decl) := do - let (decl, s) ← LambdaLifting.main decl |>.run { mainDecl := decl, liftInstParamOnly } |>.run {} |>.run {} +partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (suffix : Name) : CompilerM (Array Decl) := do + let (decl, s) ← LambdaLifting.main decl |>.run { mainDecl := decl, liftInstParamOnly, suffix } |>.run {} |>.run {} return s.decls.push decl def lambdaLifting : Pass where phase := .mono name := `lambdaLifting run := fun decls => do - decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.lambdaLifting false) + decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.lambdaLifting false (suffix := `_lambda)) + +def eagerLambdaLifting : Pass where + phase := .base + name := `eagerLambdaLifting + run := fun decls => do + decls.foldlM (init := #[]) fun decls decl => do + let liftInstParamOnly := !(← Meta.isInstance decl.name) + return decls ++ (← decl.lambdaLifting liftInstParamOnly (suffix := `_elambda)) builtin_initialize + registerTraceClass `Compiler.eagerLambdaLifting (inherited := true) registerTraceClass `Compiler.lambdaLifting (inherited := true) end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/Passes.lean b/src/Lean/Compiler/LCNF/Passes.lean index 9a694e8380..b687124cf7 100644 --- a/src/Lean/Compiler/LCNF/Passes.lean +++ b/src/Lean/Compiler/LCNF/Passes.lean @@ -50,6 +50,7 @@ def builtinPassManager : PassManager := { pullFunDecls, reduceJpArity, simp { etaPoly := true, inlinePartial := true, implementedBy := true } (occurrence := 1), + eagerLambdaLifting, specialize, simp (occurrence := 2), cse,