feat: eta expand partial applications of functions that take local instances as arguments
This commit is contained in:
parent
bf44e9fb2f
commit
1812e86c7f
2 changed files with 71 additions and 13 deletions
|
|
@ -12,8 +12,7 @@ import Lean.Compiler.LCNF.ReduceJpArity
|
|||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
@[cpass]
|
||||
def builtin : PassInstaller :=
|
||||
.append #[pullInstances, cse, simp, pullFunDecls, reduceJpArity, simp]
|
||||
@[cpass] def builtin : PassInstaller :=
|
||||
.append #[pullInstances, cse, simp, pullFunDecls, reduceJpArity, simp { etaPoly := true }]
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
import Lean.Util.Recognizers
|
||||
import Lean.Meta.Instances
|
||||
import Lean.Compiler.InlineAttrs
|
||||
import Lean.Compiler.LCNF.CompilerM
|
||||
import Lean.Compiler.LCNF.ElimDead
|
||||
|
|
@ -101,7 +102,20 @@ partial def findExpr (e : Expr) (skipMData := true) : CompilerM Expr := do
|
|||
| _ => return e
|
||||
|
||||
structure Config where
|
||||
/--
|
||||
Any local function declaration or join point with size `≤ smallThresold` is inlined
|
||||
even if there are multiple occurrences.
|
||||
We currently don't do the same for global declarations because we are not saving
|
||||
the stage1 compilation result in .olean files yet.
|
||||
-/
|
||||
smallThreshold : Nat := 1
|
||||
/--
|
||||
If `etaPoly` is true, we eta expand any global function application when
|
||||
the function takes local instances. The idea is that we do not generate code
|
||||
for this kind of application, and we want all of them to specialized or inlined.
|
||||
-/
|
||||
etaPoly : Bool := false
|
||||
deriving Inhabited
|
||||
|
||||
structure Context where
|
||||
config : Config := {}
|
||||
|
|
@ -715,6 +729,37 @@ it is a type, type former, or `lcErased`.
|
|||
def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit :=
|
||||
modify fun s => { s with subst := s.subst.insert fvarId val }
|
||||
|
||||
/--
|
||||
Return `true` if the arrow type contains an instance implicit argument.
|
||||
-/
|
||||
def hasLocalInst (type : Expr) : Bool :=
|
||||
match type with
|
||||
| .forallE _ _ b bi => bi.isInstImplicit || hasLocalInst b
|
||||
| _ => false
|
||||
|
||||
/--
|
||||
When the configuration flag `etaPoly = true`, we eta-expand
|
||||
partial applications of functions that take local instances as arguments.
|
||||
This kind of function is inlined or specialized, and we create new
|
||||
simplification opportunities by eta-expanding them.
|
||||
-/
|
||||
def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do
|
||||
guard <| (← read).config.etaPoly
|
||||
let .const declName _ := letDecl.value.getAppFn | failure
|
||||
let info ← getConstInfo declName
|
||||
guard <| hasLocalInst info.type
|
||||
guard <| !(← Meta.isInstance declName)
|
||||
let some decl ← getStage1Decl? declName | failure
|
||||
let numArgs := letDecl.value.getAppNumArgs
|
||||
guard <| decl.getArity > numArgs
|
||||
let params ← mkNewParams letDecl.type
|
||||
let value := mkAppN letDecl.value (params.map (.fvar ·.fvarId))
|
||||
let auxDecl ← mkAuxLetDecl value
|
||||
let funDecl ← mkAuxFunDecl params (.let auxDecl (.return auxDecl.fvarId))
|
||||
addSubst letDecl.fvarId (.fvar funDecl.fvarId)
|
||||
eraseLocalDecl letDecl.fvarId
|
||||
return funDecl
|
||||
|
||||
mutual
|
||||
/--
|
||||
Simplify the given local function declaration.
|
||||
|
|
@ -748,14 +793,16 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
|
|||
/--
|
||||
Simplify `code`
|
||||
-/
|
||||
partial def simp (code : Code) : SimpM Code := do
|
||||
partial def simp (code : Code) : SimpM Code := withIncRecDepth do
|
||||
incVisited
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let mut decl ← normLetDecl decl
|
||||
if let some value ← simpValue? decl.value then
|
||||
decl ← decl.updateValue value
|
||||
if decl.value.isFVar then
|
||||
if let some funDecl ← etaPolyApp? decl then
|
||||
simp (.fun funDecl k)
|
||||
else if decl.value.isFVar then
|
||||
/- Eliminate `let _x_i := _x_j;` -/
|
||||
addSubst decl.fvarId decl.value
|
||||
eraseLocalDecl decl.fvarId
|
||||
|
|
@ -865,15 +912,27 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
|
|||
else
|
||||
return none
|
||||
|
||||
partial def Decl.simp (decl : Decl) : CompilerM Decl := do
|
||||
if let some decl ← decl.simp? |>.run {} |>.run' {} then
|
||||
-- TODO: bound number of steps?
|
||||
decl.simp
|
||||
else
|
||||
return decl
|
||||
partial def Decl.simp (decl : Decl) (config : Config) : CompilerM Decl := do
|
||||
let mut config := config
|
||||
if (← pure (Simp.hasLocalInst decl.type) <||> Meta.isInstance decl.name) then
|
||||
/-
|
||||
We do not eta-expand partial applications in instances or when the declaration type
|
||||
has local instances. Recall we do not generate code for them.
|
||||
Remark: by eta-expanding partial applications in instaces, we also make the simplifier
|
||||
work harder when inlining instance projections.
|
||||
-/
|
||||
config := { config with etaPoly := false }
|
||||
go decl config
|
||||
where
|
||||
go (decl : Decl) (config : Config) : CompilerM Decl := do
|
||||
if let some decl ← decl.simp? |>.run { config } |>.run' {} then
|
||||
-- TODO: bound number of steps?
|
||||
go decl config
|
||||
else
|
||||
return decl
|
||||
|
||||
def simp : Pass :=
|
||||
.mkPerDeclaration `simp Decl.simp
|
||||
def simp (config : Config := {}) : Pass :=
|
||||
.mkPerDeclaration `simp (Decl.simp · config)
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.simp (inherited := true)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue