diff --git a/src/Lean/Compiler/LCNF/LambdaLifting.lean b/src/Lean/Compiler/LCNF/LambdaLifting.lean index fe1648391c..06f607fbef 100644 --- a/src/Lean/Compiler/LCNF/LambdaLifting.lean +++ b/src/Lean/Compiler/LCNF/LambdaLifting.lean @@ -40,6 +40,10 @@ structure Context where We use this feature to implement `@[inline] instance ...` and `@[always_inline] instance ...` -/ minSize : Nat := 0 + /-- + Allow for eta contraction instead of lifting to a lambda if possible. + -/ + allowEtaContraction : Bool := true /-- State for the `LiftM` monad. -/ @@ -81,6 +85,13 @@ partial def mkAuxDeclName : LiftM Name := do if (← getDecl? nameNew).isNone then return nameNew mkAuxDeclName +def replaceFunDecl (decl : FunDecl) (value : LetValue) : LiftM LetDecl := do + /- We reuse `decl`s `fvarId` to avoid substitution -/ + let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value } + modifyLCtx fun lctx => lctx.addLetDecl declNew + eraseFunDecl decl + return declNew + open Internalize in /-- Create a new auxiliary declaration. The array `closure` contains all free variables @@ -100,11 +111,7 @@ def mkAuxDecl (closure : Array Param) (decl : FunDecl) : LiftM LetDecl := do auxDecl.erase pure declName let value := .const auxDeclName us (closure.map (.fvar ·.fvarId)) - /- We reuse `decl`s `fvarId` to avoid substitution -/ - let declNew := { fvarId := decl.fvarId, binderName := decl.binderName, type := decl.type, value } - modifyLCtx fun lctx => lctx.addLetDecl declNew - eraseFunDecl decl - return declNew + replaceFunDecl decl value where go (nameNew : Name) (safe : Bool) (inlineAttr? : Option InlineAttributeKind) : InternalizeM Decl := do let params := (← closure.mapM internalizeParam) ++ (← decl.params.mapM internalizeParam) @@ -115,6 +122,20 @@ where let decl := { name := nameNew, levelParams := [], params, type, value, safe, inlineAttr?, recursive := false : Decl } return decl.setLevelParams +def etaContractibleDecl? (decl : FunDecl) : LiftM (Option LetDecl) := do + if !(← read).allowEtaContraction then return none + let .let { fvarId := letVar, value := .const declName us args, .. } (.return retVar) := decl.value + | return none + if letVar != retVar then return none + if args.size != decl.params.size then return none + if (← getDecl? declName).isNone then return none + for arg in args, param in decl.params do + let .fvar argVar := arg | return none + if argVar != param.fvarId then return none + + let value := .const declName us #[] + replaceFunDecl decl value + mutual partial def visitFunDecl (funDecl : FunDecl) : LiftM FunDecl := do let value ← withParams funDecl.params <| visitCode funDecl.value @@ -128,9 +149,13 @@ mutual | .fun decl k => let decl ← visitFunDecl decl if (← shouldLift decl) then - let scope ← getScope - let (_, params, _) ← Closure.run (inScope := scope.contains) <| Closure.collectFunDecl decl - let declNew ← mkAuxDecl params decl + let declNew ← do + if let some letDecl ← etaContractibleDecl? decl then + pure letDecl + else + let scope ← getScope + let (_, params, _) ← Closure.run (inScope := scope.contains) <| Closure.collectFunDecl decl + mkAuxDecl params decl let k ← withFVar declNew.fvarId <| visitCode k return .let declNew k else @@ -155,8 +180,17 @@ def main (decl : Decl) : LiftM Decl := do end LambdaLifting -partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array Decl) := do - let (decl, s) ← LambdaLifting.main decl |>.run { mainDecl := decl, liftInstParamOnly, suffix, inheritInlineAttrs, minSize } |>.run {} |>.run {} +partial def Decl.lambdaLifting (decl : Decl) (liftInstParamOnly : Bool) (allowEtaContraction : Bool) + (suffix : Name) (inheritInlineAttrs := false) (minSize := 0) : CompilerM (Array Decl) := do + let ctx := { + mainDecl := decl, + liftInstParamOnly, + suffix, + inheritInlineAttrs, + minSize, + allowEtaContraction + } + let (decl, s) ← LambdaLifting.main decl |>.run ctx |>.run {} |>.run {} return s.decls.push decl /-- @@ -166,7 +200,8 @@ def lambdaLifting : Pass where phase := .mono name := `lambdaLifting run := fun decls => do - decls.foldlM (init := #[]) fun decls decl => return decls ++ (← decl.lambdaLifting false (suffix := `_lam)) + decls.foldlM (init := #[]) fun decls decl => + return decls ++ (← decl.lambdaLifting false true (suffix := `_lam)) /-- During eager lambda lifting, we inspect declarations that are not inlineable or instances (doing it @@ -182,7 +217,7 @@ def eagerLambdaLifting : Pass where if decl.inlineable || (← Meta.isInstance decl.name) then return decls.push decl else - return decls ++ (← decl.lambdaLifting (liftInstParamOnly := true) (suffix := `_elam)) + return decls ++ (← decl.lambdaLifting (liftInstParamOnly := true) (allowEtaContraction := false) (suffix := `_elam)) builtin_initialize registerTraceClass `Compiler.eagerLambdaLifting (inherited := true) diff --git a/tests/lean/run/eta_lambda_lift.lean b/tests/lean/run/eta_lambda_lift.lean new file mode 100644 index 0000000000..15f2a47048 --- /dev/null +++ b/tests/lean/run/eta_lambda_lift.lean @@ -0,0 +1,17 @@ +import Lean.Util.FindExpr + +/-! +This test asserts that the compiler will eta contract trivial lambdas instead of lambda lifting +them. +-/ + +/-- +trace: [Compiler.lambdaLifting] size: 2 + def test e : Option Lean.Expr := + let _f.1 := Lean.Expr.hasMVar; + let _x.2 := Lean.Expr.findImpl? _f.1 e; + return _x.2 +-/ +#guard_msgs in +set_option trace.Compiler.lambdaLifting true in +def test (e : Lean.Expr) := e.find? (fun e => e.hasMVar)