feat: improve code inliner

and fix bugs at the `onlyOneExitPoint` case.
This commit is contained in:
Leonardo de Moura 2022-08-19 11:33:16 -07:00
parent 6d11dc9b62
commit a7c96142ea

View file

@ -215,14 +215,45 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
If `e` if a free variable that expands to a valid LCNF terminal `let`-block expression `e'`,
return `e'`.
-/
def expandTrivialExpr? (e : Expr) : SimpM (Option Expr) := do
def expandTrivialExpr (e : Expr) : SimpM Expr := do
if e.isFVar then
let e' ← findExpr e
unless e'.isLambda do
if e != e' then
markSimplified
return some e'
return none
return e'
return e
/--
Given `value` of the form `let x_1 := v_1; ...; let x_n := v_n; e`,
return `let x_1; ...; let x_n := v_n; let y : type := e; body`.
This methods assumes `type` and `value` do not have loose bound variables.
Remark: `body` may have many loose bound variables, and the loose bound variables > 0
must be lifted by `n`.
-/
def mkFlatLet (y : Name) (type : Expr) (value : Expr) (body : Expr) (nonDep : Bool := false) : Expr :=
go value 0
where
go (value : Expr) (i : Nat) : Expr :=
match value with
| .letE n t v b d => .letE n t v (go b (i+1)) d
| _ => .letE y type value (body.liftLooseBVars 1 i) nonDep
/--
Update inlining statistics (`stats` field) with the local function
declarations in `e`.
We use this method to make sure type class instance elements are
inlined in the current compiler simp pass.
-/
private def updateStatsUsing (e : Expr) : SimpM Unit := do
match e with
| .letE binderName _ v b _ =>
if v.isLambda then
modify fun s => { s with stats := s.stats.add binderName 1 }
updateStatsUsing b
| _ => return ()
/--
Auxiliary function for projecting "type class dictionary access".
@ -257,12 +288,8 @@ partial def inlineProjInst? (e : Expr) : OptionT SimpM Expr := do
-/
let value ← withNewScope do mkLetUsingScope (← visitProj e)
let value ← ensureUniqueLetVarNames value
/- We use `visitLet` again to put back on the current local context the relevant let-declarations. -/
visitLet (m := SimpM) value fun binderName value => do
if value.isLambda then
/- make sure instance element can be beta reduced in this simp step. -/
modify fun s => { s with stats := s.stats.add binderName 1 }
return value
updateStatsUsing value
return value
where
visitProj (e : Expr) : OptionT SimpM Expr := do
let .proj _ i s := e | unreachable!
@ -285,7 +312,12 @@ where
guard <| decl.getArity == e.getAppNumArgs
let value := decl.value.instantiateLevelParams decl.levelParams us
let value := value.beta e.getAppArgs
let value ← visitLet (m := SimpM) value fun _ value => return value
/-
Here, we just go inside of the let-declaration block without trying to simplify it.
Reason: a type class instannce may have many elements, and it does not make sense to simplify
all of them when we are extracting only one of them.
-/
let value ← Compiler.visitLet (m := SimpM) value fun _ value => return value
visit value
mutual
@ -327,55 +359,51 @@ partial def inlineApp? (e : Expr) (xs : Array Expr) (k? : Option Expr) : SimpM (
let numArgs := args.size
trace[Compiler.simp.inline] "inlining {e}"
markSimplified
if !(← manyExitPoints info.value) then
/- If `info.value` has only one exit point, we don't need to create a new join point -/
let value := info.value.beta args[:info.arity]
let value ← visitLet value #[]
match numArgs == info.arity, k? with
| true, none => return value
| false, none => return mkAppN (← mkAuxLetDecl value) args[info.arity:]
| true, some k => let x ← mkAuxLetDecl value; visitLet k (xs.push x)
| false, some k =>
let x ← mkAuxLetDecl value
let x ← mkAuxLetDecl (mkAppN x args[info.arity:])
visitLet k (xs.push x)
if k?.isNone && numArgs == info.arity then
/- Easy case, there is no continuation and `e` is not over applied -/
visitLet (info.value.beta args)
else if (← onlyOneExitPoint info.value) then
/- If `info.value` has only one exit point, we don't need to create a new auxiliary join point -/
let mut value := info.value.beta args[:info.arity]
if numArgs > info.arity then
let type ← inferType (mkAppN e.getAppFn args[:info.arity])
value := mkFlatLet (← mkAuxLetDeclName) type value (mkAppN (.bvar 0) args[info.arity:])
if let some k := k? then
let type ← inferType e
value := mkFlatLet (← mkAuxLetDeclName) type value k
visitLet value xs
else
let args := e.getAppArgs
if k?.isNone && numArgs == info.arity then
/- Easy case, there is no continuation and `e` is not overapplied -/
return info.value.beta args
else
/-
There is a continuation `k` or `e` is over applied.
If `e` is over applied, the extra arguments act as a continuation.
/-
There is a continuation `k` or `e` is over applied.
If `e` is over applied, the extra arguments act as a continuation.
We create a new join point
```
let jp := fun y =>
let x := y <extra-arguments> -- if `e` is over applied
k
```
Recall that `visitLet` incorporates the current continuation
to the new join point `jp`.
-/
let jpDomain ← inferType (mkAppN e.getAppFn args[:info.arity])
let binderName ← mkFreshUserName `_y
let jp ← withNewScope do
let y ← mkLocalDecl binderName jpDomain
let body ← if numArgs == info.arity then
visitLet k?.get! (xs.push y)
We create a new join point
```
let jp := fun y =>
let x := y <extra-arguments> -- if `e` is over applied
k
```
Recall that `visitLet` incorporates the current continuation
to the new join point `jp`.
-/
let jpDomain ← inferType (mkAppN e.getAppFn args[:info.arity])
let binderName ← mkFreshUserName `_y
let jp ← withNewScope do
let y ← mkLocalDecl binderName jpDomain
let body ← if numArgs == info.arity then
visitLet k?.get! (xs.push y)
else
let x ← mkAuxLetDecl (mkAppN y args[info.arity:])
if let some k := k? then
visitLet k (xs.push x)
else
let x ← mkAuxLetDecl (mkAppN y args[info.arity:])
if let some k := k? then
visitLet k (xs.push x)
else
visitLet x (xs.push x)
let body ← mkLetUsingScope body
mkLambda #[y] body
let jp ← mkJpDeclIfNotSimple jp
let value := info.value.beta args[:info.arity]
let value ← attachJp value jp
visitLet value
visitLet x (xs.push x)
let body ← mkLetUsingScope body
mkLambda #[y] body
let jp ← mkJpDeclIfNotSimple jp
let value := info.value.beta args[:info.arity]
let value ← attachJp value jp
visitLet value
/-- Try to apply simple simplifications. -/
partial def simpValue? (e : Expr) : SimpM (Option Expr) :=
@ -392,28 +420,31 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do
if value.isLambda then
value ← visitLambda value
else if let some value' ← simpValue? value then
if value'.isLet then
let e := mkFlatLet binderName type value' body nonDep
let e ← visitLet e xs
return e
value := value'
if value.isFVar then
/- Eliminate `let _x_i := _x_j;` -/
markSimplified
visitLet body (xs.push value)
else if let some e ← inlineApp? value xs body then
visitLet e
return e
else
let type := type.instantiateRev xs
let x ← mkLetDecl binderName type value nonDep
visitLet body (xs.push x)
| _ =>
let e := e.instantiateRev xs
let e := (← simpValue? e).getD e
if let some casesInfo ← isCasesApp? e then
if let some value ← simpValue? e then
visitLet value
else if let some casesInfo ← isCasesApp? e then
visitCases casesInfo e
else if let some e ← inlineApp? e #[] none then
visitLet e
else if let some e ← expandTrivialExpr? e then
visitLet e
else
return e
else
expandTrivialExpr e
end
end Simp