From a7c96142ea0ec2fa02c808b5a126df818828dc8a Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 19 Aug 2022 11:33:16 -0700 Subject: [PATCH] feat: improve code inliner and fix bugs at the `onlyOneExitPoint` case. --- src/Lean/Compiler/Simp.lean | 157 +++++++++++++++++++++--------------- 1 file changed, 94 insertions(+), 63 deletions(-) diff --git a/src/Lean/Compiler/Simp.lean b/src/Lean/Compiler/Simp.lean index 66917f88e2..85d2b55d43 100644 --- a/src/Lean/Compiler/Simp.lean +++ b/src/Lean/Compiler/Simp.lean @@ -215,14 +215,45 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do If `e` if a free variable that expands to a valid LCNF terminal `let`-block expression `e'`, return `e'`. -/ -def expandTrivialExpr? (e : Expr) : SimpM (Option Expr) := do +def expandTrivialExpr (e : Expr) : SimpM Expr := do if e.isFVar then let e' ← findExpr e unless e'.isLambda do if e != e' then markSimplified - return some e' - return none + return e' + return e + +/-- +Given `value` of the form `let x_1 := v_1; ...; let x_n := v_n; e`, +return `let x_1; ...; let x_n := v_n; let y : type := e; body`. + +This methods assumes `type` and `value` do not have loose bound variables. + +Remark: `body` may have many loose bound variables, and the loose bound variables > 0 +must be lifted by `n`. +-/ +def mkFlatLet (y : Name) (type : Expr) (value : Expr) (body : Expr) (nonDep : Bool := false) : Expr := + go value 0 +where + go (value : Expr) (i : Nat) : Expr := + match value with + | .letE n t v b d => .letE n t v (go b (i+1)) d + | _ => .letE y type value (body.liftLooseBVars 1 i) nonDep + +/-- +Update inlining statistics (`stats` field) with the local function +declarations in `e`. +We use this method to make sure type class instance elements are +inlined in the current compiler simp pass. +-/ +private def updateStatsUsing (e : Expr) : SimpM Unit := do + match e with + | .letE binderName _ v b _ => + if v.isLambda then + modify fun s => { s with stats := s.stats.add binderName 1 } + updateStatsUsing b + | _ => return () /-- Auxiliary function for projecting "type class dictionary access". @@ -257,12 +288,8 @@ partial def inlineProjInst? (e : Expr) : OptionT SimpM Expr := do -/ let value ← withNewScope do mkLetUsingScope (← visitProj e) let value ← ensureUniqueLetVarNames value - /- We use `visitLet` again to put back on the current local context the relevant let-declarations. -/ - visitLet (m := SimpM) value fun binderName value => do - if value.isLambda then - /- make sure instance element can be beta reduced in this simp step. -/ - modify fun s => { s with stats := s.stats.add binderName 1 } - return value + updateStatsUsing value + return value where visitProj (e : Expr) : OptionT SimpM Expr := do let .proj _ i s := e | unreachable! @@ -285,7 +312,12 @@ where guard <| decl.getArity == e.getAppNumArgs let value := decl.value.instantiateLevelParams decl.levelParams us let value := value.beta e.getAppArgs - let value ← visitLet (m := SimpM) value fun _ value => return value + /- + Here, we just go inside of the let-declaration block without trying to simplify it. + Reason: a type class instannce may have many elements, and it does not make sense to simplify + all of them when we are extracting only one of them. + -/ + let value ← Compiler.visitLet (m := SimpM) value fun _ value => return value visit value mutual @@ -327,55 +359,51 @@ partial def inlineApp? (e : Expr) (xs : Array Expr) (k? : Option Expr) : SimpM ( let numArgs := args.size trace[Compiler.simp.inline] "inlining {e}" markSimplified - if !(← manyExitPoints info.value) then - /- If `info.value` has only one exit point, we don't need to create a new join point -/ - let value := info.value.beta args[:info.arity] - let value ← visitLet value #[] - match numArgs == info.arity, k? with - | true, none => return value - | false, none => return mkAppN (← mkAuxLetDecl value) args[info.arity:] - | true, some k => let x ← mkAuxLetDecl value; visitLet k (xs.push x) - | false, some k => - let x ← mkAuxLetDecl value - let x ← mkAuxLetDecl (mkAppN x args[info.arity:]) - visitLet k (xs.push x) + if k?.isNone && numArgs == info.arity then + /- Easy case, there is no continuation and `e` is not over applied -/ + visitLet (info.value.beta args) + else if (← onlyOneExitPoint info.value) then + /- If `info.value` has only one exit point, we don't need to create a new auxiliary join point -/ + let mut value := info.value.beta args[:info.arity] + if numArgs > info.arity then + let type ← inferType (mkAppN e.getAppFn args[:info.arity]) + value := mkFlatLet (← mkAuxLetDeclName) type value (mkAppN (.bvar 0) args[info.arity:]) + if let some k := k? then + let type ← inferType e + value := mkFlatLet (← mkAuxLetDeclName) type value k + visitLet value xs else - let args := e.getAppArgs - if k?.isNone && numArgs == info.arity then - /- Easy case, there is no continuation and `e` is not overapplied -/ - return info.value.beta args - else - /- - There is a continuation `k` or `e` is over applied. - If `e` is over applied, the extra arguments act as a continuation. + /- + There is a continuation `k` or `e` is over applied. + If `e` is over applied, the extra arguments act as a continuation. - We create a new join point - ``` - let jp := fun y => - let x := y -- if `e` is over applied - k - ``` - Recall that `visitLet` incorporates the current continuation - to the new join point `jp`. - -/ - let jpDomain ← inferType (mkAppN e.getAppFn args[:info.arity]) - let binderName ← mkFreshUserName `_y - let jp ← withNewScope do - let y ← mkLocalDecl binderName jpDomain - let body ← if numArgs == info.arity then - visitLet k?.get! (xs.push y) + We create a new join point + ``` + let jp := fun y => + let x := y -- if `e` is over applied + k + ``` + Recall that `visitLet` incorporates the current continuation + to the new join point `jp`. + -/ + let jpDomain ← inferType (mkAppN e.getAppFn args[:info.arity]) + let binderName ← mkFreshUserName `_y + let jp ← withNewScope do + let y ← mkLocalDecl binderName jpDomain + let body ← if numArgs == info.arity then + visitLet k?.get! (xs.push y) + else + let x ← mkAuxLetDecl (mkAppN y args[info.arity:]) + if let some k := k? then + visitLet k (xs.push x) else - let x ← mkAuxLetDecl (mkAppN y args[info.arity:]) - if let some k := k? then - visitLet k (xs.push x) - else - visitLet x (xs.push x) - let body ← mkLetUsingScope body - mkLambda #[y] body - let jp ← mkJpDeclIfNotSimple jp - let value := info.value.beta args[:info.arity] - let value ← attachJp value jp - visitLet value + visitLet x (xs.push x) + let body ← mkLetUsingScope body + mkLambda #[y] body + let jp ← mkJpDeclIfNotSimple jp + let value := info.value.beta args[:info.arity] + let value ← attachJp value jp + visitLet value /-- Try to apply simple simplifications. -/ partial def simpValue? (e : Expr) : SimpM (Option Expr) := @@ -392,28 +420,31 @@ partial def visitLet (e : Expr) (xs : Array Expr := #[]): SimpM Expr := do if value.isLambda then value ← visitLambda value else if let some value' ← simpValue? value then + if value'.isLet then + let e := mkFlatLet binderName type value' body nonDep + let e ← visitLet e xs + return e value := value' if value.isFVar then /- Eliminate `let _x_i := _x_j;` -/ markSimplified visitLet body (xs.push value) else if let some e ← inlineApp? value xs body then - visitLet e + return e else let type := type.instantiateRev xs let x ← mkLetDecl binderName type value nonDep visitLet body (xs.push x) | _ => let e := e.instantiateRev xs - let e := (← simpValue? e).getD e - if let some casesInfo ← isCasesApp? e then + if let some value ← simpValue? e then + visitLet value + else if let some casesInfo ← isCasesApp? e then visitCases casesInfo e else if let some e ← inlineApp? e #[] none then - visitLet e - else if let some e ← expandTrivialExpr? e then - visitLet e - else return e + else + expandTrivialExpr e end end Simp