feat: eta expand local function declarations that are not being inlined

This commit is contained in:
Leonardo de Moura 2022-08-21 06:38:42 -07:00
parent fa7769260a
commit 778f9aa08f
3 changed files with 63 additions and 5 deletions

View file

@ -152,6 +152,20 @@ def visitLambda (e : Expr) : CompilerM (Array Expr × Expr) := do
let (fvars, e) ← visitLambdaCore e
return (fvars, e.instantiateRev fvars)
/--
Similar to `visitLambda` but for arrow-types.
-/
def visitArrow (type : Expr) : CompilerM (Array Expr × Expr) := do
go type #[]
where
go (type : Expr) (fvars : Array Expr) := do
if let .forallE binderName type body binderInfo := type then
let type := type.instantiateRev fvars
let fvar ← mkLocalDecl binderName type binderInfo
go body (fvars.push fvar)
else
return (fvars, type.instantiateRev fvars)
/--
Given an expression representing a `match` return a tuple consisting of:
1. The motive

View file

@ -264,11 +264,14 @@ def isOnceOrMustInline (binderName : Name) : SimpM Bool := do
| some .once | some .mustInline => return true
| _ => return false
def isSmallValue (value : Expr) : SimpM Bool := do
lcnfSizeLe value (← read).config.smallThreshold
def shouldInlineLocal (localDecl : LocalDecl) : SimpM Bool := do
if (← isOnceOrMustInline localDecl.userName) then
return true
else
lcnfSizeLe localDecl.value (← read).config.smallThreshold
isSmallValue localDecl.value
structure InlineCandidateInfo where
isLocal : Bool
@ -365,6 +368,23 @@ private def simpUsingEtaReduction (e : Expr) : Expr :=
| .letE n t v b d => .letE n t v (simpUsingEtaReduction b) d
| _ => e
private def etaExpand (type : Expr) (value : Expr) : CompilerM Expr := do
let typeArity := getArrowArity type
let valueArity := getLambdaArity value
if typeArity <= valueArity then
-- It can be < because of the "any" type
return value
else
withNewScope do
let (xs, _) ← visitArrow type
let value := getLambdaBody value
let value := value.instantiateRev xs[:valueArity]
let valueType ← inferType value
let f ← mkLocalDecl (← mkFreshUserName `_f) valueType
let k ← mkLambda #[f] (mkAppN f xs[valueArity:])
let value ← attachJp value k
mkLambda xs value
/--
Auxiliary function for projecting "type class dictionary access".
That is, we are trying to extract one of the type class instance elements.
@ -434,9 +454,6 @@ where
def betaReduce (e : Expr) (args : Array Expr) : SimpM Expr := do
-- TODO: add necessary casts to `args`
let rec getLambdaBody : Expr → Expr
| .lam _ _ b _ => getLambdaBody b
| b => b
let result ← instantiateRevInternalize (getLambdaBody e) args
trace[Meta.debug] "inline:\n{result}"
return result
@ -632,6 +649,7 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
modify fun s => { s with counter := s.counter + 1 }
match e with
| .letE binderName type value body nonDep =>
let type := type.instantiateRev xs
let mut value := value.instantiateRev xs
if value.isLambda then
unless (← isOnceOrMustInline binderName) do
@ -640,6 +658,25 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
we do it after its is inlined and we have information about the actual arguments.
-/
value ← visitLambda value
unless isJpBinderName binderName || (← isSmallValue value) do
/-
This lambda is not going to be inlined. So, we eta-expand it IF it is not a join point.
Recall that local function declarations that are not join points will be lambda lifted
anyway. Eta-expanding here also creates new simplification opportunities for
monadic local functions before we perform the lambda-lifting.
For example, consider the local function
```
let _x.23 := fun xs body =>
...
let _x.29 := StateRefT'.lift _x.24
let _x.30 := _x.25 _x.29
let _x.31 := fun a => ...
ReaderT.bind _x.30 _x.31
```
The function applications `StateRefT'.lift` and `ReaderT.bind` are not inlined because
they are partially applied. After, we eta-expand this code, it will be reduced at this stage.
-/
value ← etaExpand type value
else if let some value' ← simpValue? value then
if value'.isLet then
let e := mkFlatLet binderName type value' body nonDep
@ -653,7 +690,6 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
else if let some e ← inlineApp? value xs body then
return e
else
let type := type.instantiateRev xs
let x ← mkLetDecl binderName type value nonDep
visitLet body (xs.push x)
| _ =>

View file

@ -170,11 +170,19 @@ where
return false
return true
def getArrowArity (e : Expr) :=
match e with
| .forallE _ _ b _ => getArrowArity b + 1
| _ => 0
def getLambdaArity (e : Expr) :=
match e with
| .lam _ _ b _ => getLambdaArity b + 1
| _ => 0
def getLambdaBody : Expr → Expr
| .lam _ _ b _ => getLambdaBody b
| b => b
/--
Whether a given local declaration is a join point.