feat: eta expand partial applications of functions that take local instances as arguments

This commit is contained in:
Leonardo de Moura 2022-09-05 19:33:22 -07:00
parent bf44e9fb2f
commit 1812e86c7f
2 changed files with 71 additions and 13 deletions

View file

@ -12,8 +12,7 @@ import Lean.Compiler.LCNF.ReduceJpArity
namespace Lean.Compiler.LCNF
@[cpass]
def builtin : PassInstaller :=
.append #[pullInstances, cse, simp, pullFunDecls, reduceJpArity, simp]
@[cpass] def builtin : PassInstaller :=
.append #[pullInstances, cse, simp, pullFunDecls, reduceJpArity, simp { etaPoly := true }]
end Lean.Compiler.LCNF

View file

@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.Recognizers
import Lean.Meta.Instances
import Lean.Compiler.InlineAttrs
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.ElimDead
@ -101,7 +102,20 @@ partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
| _ => return e
structure Config where
/--
Any local function declaration or join point with size `≤ smallThresold` is inlined
even if there are multiple occurrences.
We currently don't do the same for global declarations because we are not saving
the stage1 compilation result in .olean files yet.
-/
smallThreshold : Nat := 1
/--
If `etaPoly` is true, we eta expand any global function application when
the function takes local instances. The idea is that we do not generate code
for this kind of application, and we want all of them to specialized or inlined.
-/
etaPoly : Bool := false
deriving Inhabited
structure Context where
config : Config := {}
@ -715,6 +729,37 @@ it is a type, type former, or `lcErased`.
def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit :=
modify fun s => { s with subst := s.subst.insert fvarId val }
/--
Return `true` if the arrow type contains an instance implicit argument.
-/
def hasLocalInst (type : Expr) : Bool :=
match type with
| .forallE _ _ b bi => bi.isInstImplicit || hasLocalInst b
| _ => false
/--
When the configuration flag `etaPoly = true`, we eta-expand
partial applications of functions that take local instances as arguments.
This kind of function is inlined or specialized, and we create new
simplification opportunities by eta-expanding them.
-/
def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do
guard <| (← read).config.etaPoly
let .const declName _ := letDecl.value.getAppFn | failure
let info ← getConstInfo declName
guard <| hasLocalInst info.type
guard <| !(← Meta.isInstance declName)
let some decl ← getStage1Decl? declName | failure
let numArgs := letDecl.value.getAppNumArgs
guard <| decl.getArity > numArgs
let params ← mkNewParams letDecl.type
let value := mkAppN letDecl.value (params.map (.fvar ·.fvarId))
let auxDecl ← mkAuxLetDecl value
let funDecl ← mkAuxFunDecl params (.let auxDecl (.return auxDecl.fvarId))
addSubst letDecl.fvarId (.fvar funDecl.fvarId)
eraseLocalDecl letDecl.fvarId
return funDecl
mutual
/--
Simplify the given local function declaration.
@ -748,14 +793,16 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
/--
Simplify `code`
-/
partial def simp (code : Code) : SimpM Code := do
partial def simp (code : Code) : SimpM Code := withIncRecDepth do
incVisited
match code with
| .let decl k =>
let mut decl ← normLetDecl decl
if let some value ← simpValue? decl.value then
decl ← decl.updateValue value
if decl.value.isFVar then
if let some funDecl ← etaPolyApp? decl then
simp (.fun funDecl k)
else if decl.value.isFVar then
/- Eliminate `let _x_i := _x_j;` -/
addSubst decl.fvarId decl.value
eraseLocalDecl decl.fvarId
@ -865,15 +912,27 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
else
return none
partial def Decl.simp (decl : Decl) : CompilerM Decl := do
if let some decl ← decl.simp? |>.run {} |>.run' {} then
-- TODO: bound number of steps?
decl.simp
else
return decl
partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
let mut config := config
if (← pure (Simp.hasLocalInst decl.type) <||> Meta.isInstance decl.name) then
/-
We do not eta-expand partial applications in instances or when the declaration type
has local instances. Recall we do not generate code for them.
Remark: by eta-expanding partial applications in instaces, we also make the simplifier
work harder when inlining instance projections.
-/
config := { config with etaPoly := false }
go decl config
where
go (decl : Decl) (config : Config) : CompilerM Decl := do
if let some decl ← decl.simp? |>.run { config } |>.run' {} then
-- TODO: bound number of steps?
go decl config
else
return decl
def simp : Pass :=
.mkPerDeclaration `simp Decl.simp
def simp (config : Config := {}) : Pass :=
.mkPerDeclaration `simp (Decl.simp · config)
builtin_initialize
registerTraceClass `Compiler.simp (inherited := true)