From 1812e86c7ff0bbc348a18fb505f76a97823ed244 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 5 Sep 2022 19:33:22 -0700 Subject: [PATCH] feat: eta expand partial applications of functions that take local instances as arguments --- src/Lean/Compiler/LCNF/Passes.lean | 5 +- src/Lean/Compiler/LCNF/Simp.lean | 79 ++++++++++++++++++++++++++---- 2 files changed, 71 insertions(+), 13 deletions(-) diff --git a/src/Lean/Compiler/LCNF/Passes.lean b/src/Lean/Compiler/LCNF/Passes.lean index 97b2786b6e..e2e94111a7 100644 --- a/src/Lean/Compiler/LCNF/Passes.lean +++ b/src/Lean/Compiler/LCNF/Passes.lean @@ -12,8 +12,7 @@ import Lean.Compiler.LCNF.ReduceJpArity namespace Lean.Compiler.LCNF -@[cpass] -def builtin : PassInstaller := - .append #[pullInstances, cse, simp, pullFunDecls, reduceJpArity, simp] +@[cpass] def builtin : PassInstaller := + .append #[pullInstances, cse, simp, pullFunDecls, reduceJpArity, simp { etaPoly := true }] end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index 332e84d770..b62d3b274d 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ import Lean.Util.Recognizers +import Lean.Meta.Instances import Lean.Compiler.InlineAttrs import Lean.Compiler.LCNF.CompilerM import Lean.Compiler.LCNF.ElimDead @@ -101,7 +102,20 @@ partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do | _ => return e structure Config where + /-- + Any local function declaration or join point with size `≤ smallThresold` is inlined + even if there are multiple occurrences. + We currently don't do the same for global declarations because we are not saving + the stage1 compilation result in .olean files yet. + -/ smallThreshold : Nat := 1 + /-- + If `etaPoly` is true, we eta expand any global function application when + the function takes local instances. The idea is that we do not generate code + for this kind of application, and we want all of them to specialized or inlined. + -/ + etaPoly : Bool := false + deriving Inhabited structure Context where config : Config := {} @@ -715,6 +729,37 @@ it is a type, type former, or `lcErased`. def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit := modify fun s => { s with subst := s.subst.insert fvarId val } +/-- +Return `true` if the arrow type contains an instance implicit argument. +-/ +def hasLocalInst (type : Expr) : Bool := + match type with + | .forallE _ _ b bi => bi.isInstImplicit || hasLocalInst b + | _ => false + +/-- +When the configuration flag `etaPoly = true`, we eta-expand +partial applications of functions that take local instances as arguments. +This kind of function is inlined or specialized, and we create new +simplification opportunities by eta-expanding them. +-/ +def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do + guard <| (← read).config.etaPoly + let .const declName _ := letDecl.value.getAppFn | failure + let info ← getConstInfo declName + guard <| hasLocalInst info.type + guard <| !(← Meta.isInstance declName) + let some decl ← getStage1Decl? declName | failure + let numArgs := letDecl.value.getAppNumArgs + guard <| decl.getArity > numArgs + let params ← mkNewParams letDecl.type + let value := mkAppN letDecl.value (params.map (.fvar ·.fvarId)) + let auxDecl ← mkAuxLetDecl value + let funDecl ← mkAuxFunDecl params (.let auxDecl (.return auxDecl.fvarId)) + addSubst letDecl.fvarId (.fvar funDecl.fvarId) + eraseLocalDecl letDecl.fvarId + return funDecl + mutual /-- Simplify the given local function declaration. @@ -748,14 +793,16 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do /-- Simplify `code` -/ -partial def simp (code : Code) : SimpM Code := do +partial def simp (code : Code) : SimpM Code := withIncRecDepth do incVisited match code with | .let decl k => let mut decl ← normLetDecl decl if let some value ← simpValue? decl.value then decl ← decl.updateValue value - if decl.value.isFVar then + if let some funDecl ← etaPolyApp? decl then + simp (.fun funDecl k) + else if decl.value.isFVar then /- Eliminate `let _x_i := _x_j;` -/ addSubst decl.fvarId decl.value eraseLocalDecl decl.fvarId @@ -865,15 +912,27 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do else return none -partial def Decl.simp (decl : Decl) : CompilerM Decl := do - if let some decl ← decl.simp? |>.run {} |>.run' {} then - -- TODO: bound number of steps? - decl.simp - else - return decl +partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do + let mut config := config + if (← pure (Simp.hasLocalInst decl.type) <||> Meta.isInstance decl.name) then + /- + We do not eta-expand partial applications in instances or when the declaration type + has local instances. Recall we do not generate code for them. + Remark: by eta-expanding partial applications in instaces, we also make the simplifier + work harder when inlining instance projections. + -/ + config := { config with etaPoly := false } + go decl config +where + go (decl : Decl) (config : Config) : CompilerM Decl := do + if let some decl ← decl.simp? |>.run { config } |>.run' {} then + -- TODO: bound number of steps? + go decl config + else + return decl -def simp : Pass := - .mkPerDeclaration `simp Decl.simp +def simp (config : Config := {}) : Pass := + .mkPerDeclaration `simp (Decl.simp · config) builtin_initialize registerTraceClass `Compiler.simp (inherited := true)