From bf2c0bf5b73d6808d5ce8c12e0678276aaf23335 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 18 Aug 2022 00:12:24 -0700 Subject: [PATCH] feat: avoid generation of auxliary join points when inlining functions that have only one exit point --- src/Lean/Compiler/Simp.lean | 160 ++++++++++++++++++++---------------- src/Lean/Compiler/Util.lean | 14 ++++ 2 files changed, 104 insertions(+), 70 deletions(-) diff --git a/src/Lean/Compiler/Simp.lean b/src/Lean/Compiler/Simp.lean index 027c139f50..1d19e2ad32 100644 --- a/src/Lean/Compiler/Simp.lean +++ b/src/Lean/Compiler/Simp.lean @@ -181,23 +181,41 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do def shouldInline (localDecl : LocalDecl) : SimpM Bool := return (← read).localInline && (← read).stats.shouldInline localDecl.userName -def inlineCandidate? (e : Expr) : SimpM (Option Nat) := do +structure InlineCandidateInfo where + isLocal : Bool + arity : Nat + /-- Value (lambda expression) of the function to be inlined. -/ + value : Expr + +def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do let f := e.getAppFn - let arity ← match f with - | .const declName _ => - unless hasInlineAttribute (← getEnv) declName do return none - -- TODO: check whether function is recursive or not. - -- We can skip the test and store function inline so far. - let some decl ← getStage1Decl? declName | return none - pure decl.getArity - | _ => - match (← findLambda? f) with - | none => return none - | some localDecl => - unless (← shouldInline localDecl) do return none - pure (getLambdaArity localDecl.value) - if e.getAppNumArgs < arity then return none - return e.getAppNumArgs - arity + match f with + | .const declName us => + unless hasInlineAttribute (← getEnv) declName do return none + -- TODO: check whether function is recursive or not. + -- We can skip the test and store function inline so far. + let some decl ← getStage1Decl? declName | return none + let numArgs := e.getAppNumArgs + let arity := decl.getArity + if numArgs < arity then return none + return some { + arity + isLocal := false + value := decl.value.instantiateLevelParams decl.levelParams us + } + | _ => + match (← findLambda? f) with + | none => return none + | some localDecl => + unless (← shouldInline localDecl) do return none + let numArgs := e.getAppNumArgs + let arity := getLambdaArity localDecl.value + if numArgs < arity then return none + return some { + arity + isLocal := true + value := localDecl.value + } /-- If `e` if a free variable that expands to a valid LCNF terminal `let`-block expression `e'`, @@ -255,23 +273,6 @@ partial def inlineProjInst? (e : Expr) : OptionT SimpM Expr := do saved.restore return none -partial def inlineApp (e : Expr) (jp? : Option Expr := none) : SimpM Expr := do - let f := e.getAppFn - trace[Compiler.simp.inline] "inlining {e}" - let value ← match f with - | .const declName us => - let some decl ← getStage1Decl? declName | unreachable! - pure <| decl.value.instantiateLevelParams decl.levelParams us - | _ => - let some localDecl ← findLambda? f | unreachable! - pure localDecl.value - let args := e.getAppArgs - let value := value.beta args - let value ← attachOptJp value jp? - assert! !value.isLambda - markSimplified - withLocalInline (!f.isConst) do visitLet value - /-- If `e` is an application that can be inlined, inline it. @@ -280,45 +281,62 @@ that need to instantiated with `xs`. That is, if `k? = some k`, then `k.instanti is an expression without loose bound variables. -/ partial def inlineApp? (e : Expr) (xs : Array Expr) (k? : Option Expr) : SimpM (Option Expr) := do - let some numExtraArgs ← inlineCandidate? e | return none + let some info ← inlineCandidate? e | return none let args := e.getAppArgs - if k?.isNone && numExtraArgs == 0 then - -- Easy case, there is not continuation and `e` is not over applied - inlineApp e + let numArgs := args.size + trace[Compiler.simp.inline] "inlining {e}" + markSimplified + withLocalInline info.isLocal do + if !(← manyExitPoints info.value) then + -- If `info.value` has only one exit point, we don't need to create a new join point + let value := info.value.beta args[:info.arity] + let value ← visitLet value #[] + trace[Meta.debug] "value: {value}" + match numArgs == info.arity, k? with + | true, none => return value + | false, none => return mkAppN (← mkAuxLetDecl value) args[info.arity:] + | true, some k => let x ← mkAuxLetDecl value; visitLet k (xs.push x) + | false, some k => + let x ← mkAuxLetDecl value + let x ← mkAuxLetDecl (mkAppN x args[info.arity:]) + visitLet k (xs.push x) else - /- - There is a continuation `k` or `e` is over applied. - If `e` is over applied, the extra arguments act as continuation. - -/ - let toInline := mkAppN e.getAppFn args[:args.size - numExtraArgs] - /- - `toInline` is the application that is going to be inline - We create a new join point - ``` - let jp := fun y => - let x := y -- if `e` is over applied - k - ``` - Recall that `visitLet` incorporates the current continuation - to the new join point `jp`. - -/ - let jpDomain ← inferType toInline - let binderName ← mkFreshUserName `_y - let jp ← withNewScope do - let y ← mkLocalDecl binderName jpDomain - let body ← if numExtraArgs == 0 then - visitLet k?.get! (xs.push y) - else - let x ← mkAuxLetDecl (mkAppN y args[args.size - numExtraArgs:]) - if let some k := k? then - visitLet k (xs.push x) + let args := e.getAppArgs + if k?.isNone && numArgs == info.arity then + -- Easy case, there is no continuation and `e` is not overapplied + return info.value.beta args + else + /- + There is a continuation `k` or `e` is over applied. + If `e` is over applied, the extra arguments act as a continuation. + + We create a new join point + ``` + let jp := fun y => + let x := y -- if `e` is over applied + k + ``` + Recall that `visitLet` incorporates the current continuation + to the new join point `jp`. + -/ + let jpDomain ← inferType (mkAppN e.getAppFn args[:info.arity]) + let binderName ← mkFreshUserName `_y + let jp ← withNewScope do + let y ← mkLocalDecl binderName jpDomain + let body ← if numArgs == info.arity then + visitLet k?.get! (xs.push y) else - visitLet x (xs.push x) - let body ← mkLetUsingScope body - mkLambda #[y] body - let jp ← mkJpDeclIfNotSimple jp - /- Inline `toInline` and "go-to" `jp` with the result. -/ - inlineApp toInline jp + let x ← mkAuxLetDecl (mkAppN y args[info.arity:]) + if let some k := k? then + visitLet k (xs.push x) + else + visitLet x (xs.push x) + let body ← mkLetUsingScope body + mkLambda #[y] body + let jp ← mkJpDeclIfNotSimple jp + let value := info.value.beta args[:info.arity] + let value ← attachJp value jp + visitLet value /-- Try to apply simple simplifications. -/ partial def simpValue? (e : Expr) : SimpM (Option Expr) := @@ -365,6 +383,7 @@ def Decl.simp? (decl : Decl) : CoreM (Option Decl) := do trace[Compiler.simp.inline.stats] "{decl.name}:{Format.nest 2 (format stats)}" let (value, s) ← Simp.visitLambda decl.value |>.run { stats } |>.run { simplified := false } |>.run' {} trace[Compiler.simp.step] "{decl.name} :=\n{decl.value}" + trace[Compiler.simp.stat] "{decl.name}: {← getLCNFSize decl.value}" if s.simplified then return some { decl with value } else @@ -379,7 +398,8 @@ partial def Decl.simp (decl : Decl) : CoreM Decl := do builtin_initialize registerTraceClass `Compiler.simp.inline + registerTraceClass `Compiler.simp.stat registerTraceClass `Compiler.simp.step registerTraceClass `Compiler.simp.inline.stats -end Lean.Compiler \ No newline at end of file +end Lean.Compiler diff --git a/src/Lean/Compiler/Util.lean b/src/Lean/Compiler/Util.lean index 71fa94e107..6861a13710 100644 --- a/src/Lean/Compiler/Util.lean +++ b/src/Lean/Compiler/Util.lean @@ -175,4 +175,18 @@ def isJump? [Monad m] [MonadLCtx m] (e : Expr) : m (Option FVarId) := do else return none +/-- +Return if the LCNF expression has many exit points. +It assumes `cases` expressions only occur at the end of `let`-blocks. +That is, `terminalCases` has already been applied. +It also assumes that if contains a join point, then it has multiple +exit points. This is a reasonable assumption because the simplifier +inlines any join point that was used only once. +-/ +def manyExitPoints (e : Expr) : CoreM Bool := do + match e with + | .lam _ _ b _ => manyExitPoints b + | .letE n _ _ b _ => pure (isJpBinderName n) <||> manyExitPoints b + | e => return (← isCasesApp? e).isSome + end Lean.Compiler