feat: eta expand local function declarations that are not being inlined
This commit is contained in:
parent
fa7769260a
commit
778f9aa08f
3 changed files with 63 additions and 5 deletions
|
|
@ -152,6 +152,20 @@ def visitLambda (e : Expr) : CompilerM (Array Expr × Expr) := do
|
|||
let (fvars, e) ← visitLambdaCore e
|
||||
return (fvars, e.instantiateRev fvars)
|
||||
|
||||
/--
|
||||
Similar to `visitLambda` but for arrow-types.
|
||||
-/
|
||||
def visitArrow (type : Expr) : CompilerM (Array Expr × Expr) := do
|
||||
go type #[]
|
||||
where
|
||||
go (type : Expr) (fvars : Array Expr) := do
|
||||
if let .forallE binderName type body binderInfo := type then
|
||||
let type := type.instantiateRev fvars
|
||||
let fvar ← mkLocalDecl binderName type binderInfo
|
||||
go body (fvars.push fvar)
|
||||
else
|
||||
return (fvars, type.instantiateRev fvars)
|
||||
|
||||
/--
|
||||
Given an expression representing a `match` return a tuple consisting of:
|
||||
1. The motive
|
||||
|
|
|
|||
|
|
@ -264,11 +264,14 @@ def isOnceOrMustInline (binderName : Name) : SimpM Bool := do
|
|||
| some .once | some .mustInline => return true
|
||||
| _ => return false
|
||||
|
||||
def isSmallValue (value : Expr) : SimpM Bool := do
|
||||
lcnfSizeLe value (← read).config.smallThreshold
|
||||
|
||||
def shouldInlineLocal (localDecl : LocalDecl) : SimpM Bool := do
|
||||
if (← isOnceOrMustInline localDecl.userName) then
|
||||
return true
|
||||
else
|
||||
lcnfSizeLe localDecl.value (← read).config.smallThreshold
|
||||
isSmallValue localDecl.value
|
||||
|
||||
structure InlineCandidateInfo where
|
||||
isLocal : Bool
|
||||
|
|
@ -365,6 +368,23 @@ private def simpUsingEtaReduction (e : Expr) : Expr :=
|
|||
| .letE n t v b d => .letE n t v (simpUsingEtaReduction b) d
|
||||
| _ => e
|
||||
|
||||
private def etaExpand (type : Expr) (value : Expr) : CompilerM Expr := do
|
||||
let typeArity := getArrowArity type
|
||||
let valueArity := getLambdaArity value
|
||||
if typeArity <= valueArity then
|
||||
-- It can be < because of the "any" type
|
||||
return value
|
||||
else
|
||||
withNewScope do
|
||||
let (xs, _) ← visitArrow type
|
||||
let value := getLambdaBody value
|
||||
let value := value.instantiateRev xs[:valueArity]
|
||||
let valueType ← inferType value
|
||||
let f ← mkLocalDecl (← mkFreshUserName `_f) valueType
|
||||
let k ← mkLambda #[f] (mkAppN f xs[valueArity:])
|
||||
let value ← attachJp value k
|
||||
mkLambda xs value
|
||||
|
||||
/--
|
||||
Auxiliary function for projecting "type class dictionary access".
|
||||
That is, we are trying to extract one of the type class instance elements.
|
||||
|
|
@ -434,9 +454,6 @@ where
|
|||
|
||||
def betaReduce (e : Expr) (args : Array Expr) : SimpM Expr := do
|
||||
-- TODO: add necessary casts to `args`
|
||||
let rec getLambdaBody : Expr → Expr
|
||||
| .lam _ _ b _ => getLambdaBody b
|
||||
| b => b
|
||||
let result ← instantiateRevInternalize (getLambdaBody e) args
|
||||
trace[Meta.debug] "inline:\n{result}"
|
||||
return result
|
||||
|
|
@ -632,6 +649,7 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
|
|||
modify fun s => { s with counter := s.counter + 1 }
|
||||
match e with
|
||||
| .letE binderName type value body nonDep =>
|
||||
let type := type.instantiateRev xs
|
||||
let mut value := value.instantiateRev xs
|
||||
if value.isLambda then
|
||||
unless (← isOnceOrMustInline binderName) do
|
||||
|
|
@ -640,6 +658,25 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
|
|||
we do it after its is inlined and we have information about the actual arguments.
|
||||
-/
|
||||
value ← visitLambda value
|
||||
unless isJpBinderName binderName || (← isSmallValue value) do
|
||||
/-
|
||||
This lambda is not going to be inlined. So, we eta-expand it IF it is not a join point.
|
||||
Recall that local function declarations that are not join points will be lambda lifted
|
||||
anyway. Eta-expanding here also creates new simplification opportunities for
|
||||
monadic local functions before we perform the lambda-lifting.
|
||||
For example, consider the local function
|
||||
```
|
||||
let _x.23 := fun xs body =>
|
||||
...
|
||||
let _x.29 := StateRefT'.lift _x.24
|
||||
let _x.30 := _x.25 _x.29
|
||||
let _x.31 := fun a => ...
|
||||
ReaderT.bind _x.30 _x.31
|
||||
```
|
||||
The function applications `StateRefT'.lift` and `ReaderT.bind` are not inlined because
|
||||
they are partially applied. After, we eta-expand this code, it will be reduced at this stage.
|
||||
-/
|
||||
value ← etaExpand type value
|
||||
else if let some value' ← simpValue? value then
|
||||
if value'.isLet then
|
||||
let e := mkFlatLet binderName type value' body nonDep
|
||||
|
|
@ -653,7 +690,6 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
|
|||
else if let some e ← inlineApp? value xs body then
|
||||
return e
|
||||
else
|
||||
let type := type.instantiateRev xs
|
||||
let x ← mkLetDecl binderName type value nonDep
|
||||
visitLet body (xs.push x)
|
||||
| _ =>
|
||||
|
|
|
|||
|
|
@ -170,11 +170,19 @@ where
|
|||
return false
|
||||
return true
|
||||
|
||||
def getArrowArity (e : Expr) :=
|
||||
match e with
|
||||
| .forallE _ _ b _ => getArrowArity b + 1
|
||||
| _ => 0
|
||||
|
||||
def getLambdaArity (e : Expr) :=
|
||||
match e with
|
||||
| .lam _ _ b _ => getLambdaArity b + 1
|
||||
| _ => 0
|
||||
|
||||
def getLambdaBody : Expr → Expr
|
||||
| .lam _ _ b _ => getLambdaBody b
|
||||
| b => b
|
||||
|
||||
/--
|
||||
Whether a given local declaration is a join point.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue