feat: inline applications of the form inline (f ...)

The `inline` identity function is a directive for the compiler.
This commit is contained in:
Leonardo de Moura 2022-09-10 13:26:22 -07:00
parent 1953f5953f
commit ca098d3769
2 changed files with 44 additions and 21 deletions

View file

@ -311,6 +311,8 @@ structure InlineCandidateInfo where
params : Array Param
/-- Value (lambda expression) of the function to be inlined. -/
value : Code
f : Expr
args : Array Expr
/-- The arity (aka number of parameters) of the function to be inlined. -/
def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
@ -320,31 +322,40 @@ def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
Return `some info` if `e` should be inlined.
-/
def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
let mut e := e
let mut mustInline := false
if e.isAppOfArity ``inline 2 then
e ← findExpr e.appArg!
mustInline := true
let numArgs := e.getAppNumArgs
let f := e.getAppFn
if let .const declName us ← findExpr f then
unless hasInlineAttribute (← getEnv) declName do return none
unless mustInline || hasInlineAttribute (← getEnv) declName do return none
-- TODO: check whether function is recursive or not.
-- We can skip the test and store function inline so far.
let some decl ← getStage1Decl? declName | return none
let arity := decl.getArity
let inlinePartial := (← read).config.inlinePartial
if !inlinePartial && numArgs < arity then return none
if !mustInline && !inlinePartial && numArgs < arity then return none
let params := decl.instantiateParamsLevelParams us
let value := decl.instantiateValueLevelParams us
incInline
return some {
isLocal := false
f := e.getAppFn
args := e.getAppArgs
params, value
}
else if let some decl ← findFunDecl? f then
unless numArgs > 0 do return none -- It is not worth to inline a local function that does not take any arguments
unless (← shouldInlineLocal decl) do return none
unless mustInline || (← shouldInlineLocal decl) do return none
-- Remark: we inline local function declarations even if they are partial applied
incInlineLocal
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
return some {
isLocal := true
f := e.getAppFn
args := e.getAppArgs
params := decl.params
value := decl.value
}
@ -392,15 +403,15 @@ def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInl
return code
/--
Create a new local function declaration when `args.size < info.params.size`.
Create a new local function declaration when `info.args.size < info.params.size`.
We use this function to inline/specialize a partial application of a local function.
-/
def specializePartialApp (info : InlineCandidateInfo) (args : Array Expr) : SimpM FunDecl := do
def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
let mut subst := {}
for param in info.params, arg in args do
for param in info.params, arg in info.args do
subst := subst.insert param.fvarId arg
let mut paramsNew := #[]
for param in info.params[args.size:] do
for param in info.params[info.args.size:] do
let type ← replaceExprFVars param.type subst
let paramNew ← mkAuxParam type
paramsNew := paramsNew.push paramNew
@ -410,27 +421,23 @@ def specializePartialApp (info : InlineCandidateInfo) (args : Array Expr) : Simp
mkAuxFunDecl paramsNew code
/--
If `e` is an application that can be inlined, inline it.
If the value of the given let-declaration is an application that can be inlined, inline it.
`k?` is the optional "continuation" for `e`, and it may contain loose bound variables
that need to instantiated with `xs`. That is, if `k? = some k`, then `k.instantiateRev xs`
is an expression without loose bound variables.
`k` is the "continuation" for the let declaration.
-/
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
if k matches .unreach .. then return some k
let e := letDecl.value
let some info ← inlineCandidate? e | return none
let some info ← inlineCandidate? letDecl.value | return none
markSimplified
let args := e.getAppArgs
let numArgs := args.size
trace[Compiler.simp.inline] "inlining {e}"
let numArgs := info.args.size
trace[Compiler.simp.inline] "inlining {letDecl.value}"
let fvarId := letDecl.fvarId
if numArgs < info.arity then
let funDecl ← specializePartialApp info args
let funDecl ← specializePartialApp info
addSubst letDecl.fvarId (.fvar funDecl.fvarId)
return some (.fun funDecl k)
else
let code ← betaReduce info.params info.value args[:info.arity]
let code ← betaReduce info.params info.value info.args[:info.arity]
if k.isReturnOf fvarId && numArgs == info.arity then
/- Easy case, the continuation `k` is just returning the result of the application. -/
return code
@ -442,7 +449,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
code.bind fun fvarId' => do
/- fvarId' is the result of the computation -/
if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') args[info.arity:])
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
let k ← replaceFVar k fvarId decl.fvarId
return .let decl k
else
@ -452,9 +459,9 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
`code` has multiple exit points, and the continuation is non-trivial
Thus, we create an auxiliary join point.
-/
let jpParam ← mkAuxParam (← inferType (mkAppN e.getAppFn args[:info.arity]))
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
let jpValue ← if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) args[info.arity:])
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
let k ← replaceFVar k fvarId decl.fvarId
pure <| .let decl k
else

View file

@ -0,0 +1,16 @@
import Lean
def f (x : Nat) :=
(x - 1) + x * 2 + x*x
def h (x : Nat) :=
inline <| f (x + x)
#eval Lean.Compiler.compile #[``h]
open Lean Compiler LCNF in
@[cpass] def simpInline : PassInstaller :=
Testing.assertDoesNotContainConstAfter `simp `simpInlinesInline ``inline "simp did not inline `inline`"
set_option trace.Compiler.result true
#eval Lean.Compiler.compile #[``h]