feat: inline applications of the form inline (f ...)
The `inline` identity function is a directive for the compiler.
This commit is contained in:
parent
1953f5953f
commit
ca098d3769
2 changed files with 44 additions and 21 deletions
|
|
@ -311,6 +311,8 @@ structure InlineCandidateInfo where
|
|||
params : Array Param
|
||||
/-- Value (lambda expression) of the function to be inlined. -/
|
||||
value : Code
|
||||
f : Expr
|
||||
args : Array Expr
|
||||
|
||||
/-- The arity (aka number of parameters) of the function to be inlined. -/
|
||||
def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
|
||||
|
|
@ -320,31 +322,40 @@ def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
|
|||
Return `some info` if `e` should be inlined.
|
||||
-/
|
||||
def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
|
||||
let mut e := e
|
||||
let mut mustInline := false
|
||||
if e.isAppOfArity ``inline 2 then
|
||||
e ← findExpr e.appArg!
|
||||
mustInline := true
|
||||
let numArgs := e.getAppNumArgs
|
||||
let f := e.getAppFn
|
||||
if let .const declName us ← findExpr f then
|
||||
unless hasInlineAttribute (← getEnv) declName do return none
|
||||
unless mustInline || hasInlineAttribute (← getEnv) declName do return none
|
||||
-- TODO: check whether function is recursive or not.
|
||||
-- We can skip the test and store function inline so far.
|
||||
let some decl ← getStage1Decl? declName | return none
|
||||
let arity := decl.getArity
|
||||
let inlinePartial := (← read).config.inlinePartial
|
||||
if !inlinePartial && numArgs < arity then return none
|
||||
if !mustInline && !inlinePartial && numArgs < arity then return none
|
||||
let params := decl.instantiateParamsLevelParams us
|
||||
let value := decl.instantiateValueLevelParams us
|
||||
incInline
|
||||
return some {
|
||||
isLocal := false
|
||||
f := e.getAppFn
|
||||
args := e.getAppArgs
|
||||
params, value
|
||||
}
|
||||
else if let some decl ← findFunDecl? f then
|
||||
unless numArgs > 0 do return none -- It is not worth to inline a local function that does not take any arguments
|
||||
unless (← shouldInlineLocal decl) do return none
|
||||
unless mustInline || (← shouldInlineLocal decl) do return none
|
||||
-- Remark: we inline local function declarations even if they are partial applied
|
||||
incInlineLocal
|
||||
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
|
||||
return some {
|
||||
isLocal := true
|
||||
f := e.getAppFn
|
||||
args := e.getAppArgs
|
||||
params := decl.params
|
||||
value := decl.value
|
||||
}
|
||||
|
|
@ -392,15 +403,15 @@ def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInl
|
|||
return code
|
||||
|
||||
/--
|
||||
Create a new local function declaration when `args.size < info.params.size`.
|
||||
Create a new local function declaration when `info.args.size < info.params.size`.
|
||||
We use this function to inline/specialize a partial application of a local function.
|
||||
-/
|
||||
def specializePartialApp (info : InlineCandidateInfo) (args : Array Expr) : SimpM FunDecl := do
|
||||
def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
|
||||
let mut subst := {}
|
||||
for param in info.params, arg in args do
|
||||
for param in info.params, arg in info.args do
|
||||
subst := subst.insert param.fvarId arg
|
||||
let mut paramsNew := #[]
|
||||
for param in info.params[args.size:] do
|
||||
for param in info.params[info.args.size:] do
|
||||
let type ← replaceExprFVars param.type subst
|
||||
let paramNew ← mkAuxParam type
|
||||
paramsNew := paramsNew.push paramNew
|
||||
|
|
@ -410,27 +421,23 @@ def specializePartialApp (info : InlineCandidateInfo) (args : Array Expr) : Simp
|
|||
mkAuxFunDecl paramsNew code
|
||||
|
||||
/--
|
||||
If `e` is an application that can be inlined, inline it.
|
||||
If the value of the given let-declaration is an application that can be inlined, inline it.
|
||||
|
||||
`k?` is the optional "continuation" for `e`, and it may contain loose bound variables
|
||||
that need to instantiated with `xs`. That is, if `k? = some k`, then `k.instantiateRev xs`
|
||||
is an expression without loose bound variables.
|
||||
`k` is the "continuation" for the let declaration.
|
||||
-/
|
||||
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
|
||||
if k matches .unreach .. then return some k
|
||||
let e := letDecl.value
|
||||
let some info ← inlineCandidate? e | return none
|
||||
let some info ← inlineCandidate? letDecl.value | return none
|
||||
markSimplified
|
||||
let args := e.getAppArgs
|
||||
let numArgs := args.size
|
||||
trace[Compiler.simp.inline] "inlining {e}"
|
||||
let numArgs := info.args.size
|
||||
trace[Compiler.simp.inline] "inlining {letDecl.value}"
|
||||
let fvarId := letDecl.fvarId
|
||||
if numArgs < info.arity then
|
||||
let funDecl ← specializePartialApp info args
|
||||
let funDecl ← specializePartialApp info
|
||||
addSubst letDecl.fvarId (.fvar funDecl.fvarId)
|
||||
return some (.fun funDecl k)
|
||||
else
|
||||
let code ← betaReduce info.params info.value args[:info.arity]
|
||||
let code ← betaReduce info.params info.value info.args[:info.arity]
|
||||
if k.isReturnOf fvarId && numArgs == info.arity then
|
||||
/- Easy case, the continuation `k` is just returning the result of the application. -/
|
||||
return code
|
||||
|
|
@ -442,7 +449,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
|
|||
code.bind fun fvarId' => do
|
||||
/- fvarId' is the result of the computation -/
|
||||
if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') args[info.arity:])
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
|
||||
let k ← replaceFVar k fvarId decl.fvarId
|
||||
return .let decl k
|
||||
else
|
||||
|
|
@ -452,9 +459,9 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
|
|||
`code` has multiple exit points, and the continuation is non-trivial
|
||||
Thus, we create an auxiliary join point.
|
||||
-/
|
||||
let jpParam ← mkAuxParam (← inferType (mkAppN e.getAppFn args[:info.arity]))
|
||||
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
|
||||
let jpValue ← if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) args[info.arity:])
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
|
||||
let k ← replaceFVar k fvarId decl.fvarId
|
||||
pure <| .let decl k
|
||||
else
|
||||
|
|
|
|||
16
tests/lean/run/inlineApp.lean
Normal file
16
tests/lean/run/inlineApp.lean
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
import Lean
|
||||
|
||||
def f (x : Nat) :=
|
||||
(x - 1) + x * 2 + x*x
|
||||
|
||||
def h (x : Nat) :=
|
||||
inline <| f (x + x)
|
||||
|
||||
#eval Lean.Compiler.compile #[``h]
|
||||
|
||||
open Lean Compiler LCNF in
|
||||
@[cpass] def simpInline : PassInstaller :=
|
||||
Testing.assertDoesNotContainConstAfter `simp `simpInlinesInline ``inline "simp did not inline `inline`"
|
||||
|
||||
set_option trace.Compiler.result true
|
||||
#eval Lean.Compiler.compile #[``h]
|
||||
Loading…
Add table
Reference in a new issue