diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index 98bb8867c9..af8353d7bf 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -311,6 +311,8 @@ structure InlineCandidateInfo where params : Array Param /-- Value (lambda expression) of the function to be inlined. -/ value : Code + f : Expr + args : Array Expr /-- The arity (aka number of parameters) of the function to be inlined. -/ def InlineCandidateInfo.arity : InlineCandidateInfo → Nat @@ -320,31 +322,40 @@ def InlineCandidateInfo.arity : InlineCandidateInfo → Nat Return `some info` if `e` should be inlined. -/ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do + let mut e := e + let mut mustInline := false + if e.isAppOfArity ``inline 2 then + e ← findExpr e.appArg! + mustInline := true let numArgs := e.getAppNumArgs let f := e.getAppFn if let .const declName us ← findExpr f then - unless hasInlineAttribute (← getEnv) declName do return none + unless mustInline || hasInlineAttribute (← getEnv) declName do return none -- TODO: check whether function is recursive or not. -- We can skip the test and store function inline so far. let some decl ← getStage1Decl? declName | return none let arity := decl.getArity let inlinePartial := (← read).config.inlinePartial - if !inlinePartial && numArgs < arity then return none + if !mustInline && !inlinePartial && numArgs < arity then return none let params := decl.instantiateParamsLevelParams us let value := decl.instantiateValueLevelParams us incInline return some { isLocal := false + f := e.getAppFn + args := e.getAppArgs params, value } else if let some decl ← findFunDecl? f then unless numArgs > 0 do return none -- It is not worth to inline a local function that does not take any arguments - unless (← shouldInlineLocal decl) do return none + unless mustInline || (← shouldInlineLocal decl) do return none -- Remark: we inline local function declarations even if they are partial applied incInlineLocal modify fun s => { s with inlineLocal := s.inlineLocal + 1 } return some { isLocal := true + f := e.getAppFn + args := e.getAppArgs params := decl.params value := decl.value } @@ -392,15 +403,15 @@ def betaReduce (params : Array Param) (code : Code) (args : Array Expr) (mustInl return code /-- -Create a new local function declaration when `args.size < info.params.size`. +Create a new local function declaration when `info.args.size < info.params.size`. We use this function to inline/specialize a partial application of a local function. -/ -def specializePartialApp (info : InlineCandidateInfo) (args : Array Expr) : SimpM FunDecl := do +def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do let mut subst := {} - for param in info.params, arg in args do + for param in info.params, arg in info.args do subst := subst.insert param.fvarId arg let mut paramsNew := #[] - for param in info.params[args.size:] do + for param in info.params[info.args.size:] do let type ← replaceExprFVars param.type subst let paramNew ← mkAuxParam type paramsNew := paramsNew.push paramNew @@ -410,27 +421,23 @@ def specializePartialApp (info : InlineCandidateInfo) (args : Array Expr) : Simp mkAuxFunDecl paramsNew code /-- -If `e` is an application that can be inlined, inline it. +If the value of the given let-declaration is an application that can be inlined, inline it. -`k?` is the optional "continuation" for `e`, and it may contain loose bound variables -that need to instantiated with `xs`. That is, if `k? = some k`, then `k.instantiateRev xs` -is an expression without loose bound variables. +`k` is the "continuation" for the let declaration. -/ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do if k matches .unreach .. then return some k - let e := letDecl.value - let some info ← inlineCandidate? e | return none + let some info ← inlineCandidate? letDecl.value | return none markSimplified - let args := e.getAppArgs - let numArgs := args.size - trace[Compiler.simp.inline] "inlining {e}" + let numArgs := info.args.size + trace[Compiler.simp.inline] "inlining {letDecl.value}" let fvarId := letDecl.fvarId if numArgs < info.arity then - let funDecl ← specializePartialApp info args + let funDecl ← specializePartialApp info addSubst letDecl.fvarId (.fvar funDecl.fvarId) return some (.fun funDecl k) else - let code ← betaReduce info.params info.value args[:info.arity] + let code ← betaReduce info.params info.value info.args[:info.arity] if k.isReturnOf fvarId && numArgs == info.arity then /- Easy case, the continuation `k` is just returning the result of the application. -/ return code @@ -442,7 +449,7 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d code.bind fun fvarId' => do /- fvarId' is the result of the computation -/ if numArgs > info.arity then - let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') args[info.arity:]) + let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:]) let k ← replaceFVar k fvarId decl.fvarId return .let decl k else @@ -452,9 +459,9 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d `code` has multiple exit points, and the continuation is non-trivial Thus, we create an auxiliary join point. -/ - let jpParam ← mkAuxParam (← inferType (mkAppN e.getAppFn args[:info.arity])) + let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity])) let jpValue ← if numArgs > info.arity then - let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) args[info.arity:]) + let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:]) let k ← replaceFVar k fvarId decl.fvarId pure <| .let decl k else diff --git a/tests/lean/run/inlineApp.lean b/tests/lean/run/inlineApp.lean new file mode 100644 index 0000000000..edc720fbdb --- /dev/null +++ b/tests/lean/run/inlineApp.lean @@ -0,0 +1,16 @@ +import Lean + +def f (x : Nat) := + (x - 1) + x * 2 + x*x + +def h (x : Nat) := + inline <| f (x + x) + +#eval Lean.Compiler.compile #[``h] + +open Lean Compiler LCNF in +@[cpass] def simpInline : PassInstaller := + Testing.assertDoesNotContainConstAfter `simp `simpInlinesInline ``inline "simp did not inline `inline`" + +set_option trace.Compiler.result true +#eval Lean.Compiler.compile #[``h]