From fd1ae3118cd6b718cfff94164e28ac770be54646 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 25 Sep 2022 20:54:41 -0700 Subject: [PATCH] feat: replace `isCasesOnCases?` with `simpJpCases?` It addresses the code explosion issue with the old optimization. For example, the resulting size for `Lean.Json.Parser.escapedChar` went from 31593 to 361. --- src/Lean/Compiler/LCNF/Basic.lean | 9 + src/Lean/Compiler/LCNF/Simp.lean | 313 +++++++++++++----- .../linterUnusedVariables.lean.expected.out | 2 +- 3 files changed, 236 insertions(+), 88 deletions(-) diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index 0486591588..8d77d6e6a7 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -70,6 +70,15 @@ abbrev Alt := AltCore Code abbrev FunDecl := FunDeclCore Code abbrev Cases := CasesCore Code +/-- +Return the constructor names that have an explicit (non-default) alternative. +-/ +def CasesCore.getCtorNames (c : Cases) : NameSet := + c.alts.foldl (init := {}) fun ctorNames alt => + match alt with + | .default _ => ctorNames + | .alt ctorName .. => ctorNames.insert ctorName + inductive CodeDecl where | let (decl : LetDecl) | fun (decl : FunDecl) diff --git a/src/Lean/Compiler/LCNF/Simp.lean b/src/Lean/Compiler/LCNF/Simp.lean index 88276e8449..fa2ceb838f 100644 --- a/src/Lean/Compiler/LCNF/Simp.lean +++ b/src/Lean/Compiler/LCNF/Simp.lean @@ -14,6 +14,7 @@ import Lean.Compiler.LCNF.Bind import Lean.Compiler.LCNF.PrettyPrinter import Lean.Compiler.LCNF.PassManager import Lean.Compiler.LCNF.AlphaEqv +import Lean.Compiler.LCNF.DependsOn namespace Lean.Compiler.LCNF @@ -687,85 +688,6 @@ where | .return fvarId => visit (.fvar fvarId) projs | _ => failure -/-- -Given the function declaration `decl`, return `true` if it is of the form -``` -f y := - ... /- This part is not bigger than smallThreshold. -/ - cases y - | ... => ... - ... -``` --/ -def isJpCases (decl : FunDecl) (smallThreshold : Nat) : CompilerM Bool := do - if decl.params.size != 1 then - return false - else - let param := decl.params[0]! - let rec go (code : Code) (prefixSize : Nat) : Bool := - prefixSize <= smallThreshold && - match code with - | .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/ - | .cases c => c.discr == param.fvarId - | _ => false - return go decl.value 0 - -/-- -Return `some _jp.k` if the given `cases` is of the form -``` -cases _x.i - (... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁) - ... - (... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ) -``` -where `_jp.k` is a join point of the form -``` -let _jp.k y := - cases y ... -``` -The goal is to mark `_jp.k` as must inline in this scenario. -The function also allows `_jp.k` to have a small prefix before -`cases y`. The small prefix is set using the configuration option -`config.smallThreshold`. It is currently the same threshold used to -decide when to inline a function that has multiple occurrences. - -Example: consider the following declarations -``` -@[inline] def pred? (x : Nat) : Option Nat := - match x with - | 0 => none - | x+1 => some x - -def isZero (x : Nat) := - match pred? x with - | some _ => false - | none => true -``` -After inlining `pred?` in `isZero`, this simplification is applicable, producing - -Remark: this method does not assume `cases` has already been normalized, -but returns a normalized `FVarId` in case of success. --/ -def isCasesOnCases? (cases : Cases) : OptionT SimpM FVarId := do - let jpFirst ← isCtorJmp? cases.alts[0]!.getCode - let funDecl ← getFunDecl jpFirst - let smallThreshold := (← read).config.smallThreshold - guard <| (← isJpCases funDecl smallThreshold) - for alt in cases.alts[1:] do - let jp ← isCtorJmp? alt.getCode - guard <| jpFirst == jp - return jpFirst -where - isCtorJmp? (code : Code) : OptionT SimpM FVarId := do - match code with - | .let _ k | .jp _ k | .fun _ k => isCtorJmp? k - | .return .. | .unreach .. | .cases .. => failure - | .jmp jpFVarId args => - let #[arg] := args | failure - let arg ← findExpr (← normExpr arg) - guard <| arg.isConstructorApp (← getEnv) - normFVar jpFVarId - /-- Return the alternative in `alts` whose body appears in most arms, and the number of occurrences. @@ -1119,23 +1041,49 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do return alts[0]!.getCode else return code.updateCases! resultType discr alts - if let some jpFVarId ← isCasesOnCases? c then - withAddMustInline jpFVarId simpCasesDefault - else - simpCasesDefault + simpCasesDefault end +/-- +Given the function declaration `decl`, return `true` if it is of the form +``` +f y := + ... /- This part is not bigger than smallThreshold. -/ + cases y + | ... => ... + ... +``` +-/ +def isJpCases (decl : FunDecl) (smallThreshold : Nat) : CompilerM Bool := do + if decl.params.size != 1 then + return false + else + let param := decl.params[0]! + let rec go (code : Code) (prefixSize : Nat) : Bool := + prefixSize <= smallThreshold && + match code with + | .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/ + | .cases c => c.discr == param.fvarId + | _ => false + return go decl.value 0 + +abbrev JpCasesInfo := FVarIdMap NameSet + +/-- Return `true` if the collected information suggests opportunities for the `JpCases` optimization. -/ +def JpCasesInfo.isCandidate (info : JpCasesInfo) : Bool := + info.any fun _ s => !s.isEmpty + /-- Return a map containing entries `jpFVarId ↦ ctorNames` where `jpFVarId` is the id of join point in code that satisfies `isJpCases`, and `ctorNames` is a set of constructor names such that there is a jump `.jmp jpFVarId #[x]` in `code` and `x` is a constructor application. -/ -partial def collectJpCasesInfo (code : Code) (smallThreshold : Nat): CompilerM (FVarIdMap NameSet) := do +partial def collectJpCasesInfo (code : Code) (smallThreshold : Nat): CompilerM JpCasesInfo := do let (_, s) ← go code |>.run {} return s where - go (code : Code) : StateRefT (FVarIdMap NameSet) CompilerM Unit := do + go (code : Code) : StateRefT JpCasesInfo CompilerM Unit := do match code with | .let _ k => go k | .fun decl k => go decl.value; go k @@ -1152,6 +1100,192 @@ where let some (cval, _) := arg.constructorApp? (← getEnv) | return () modify fun s => s.insert fvarId <| ctorNames.insert cval.name +/-- +Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases`. +-/ +private def extractJpCases (code : Code) : Array CodeDecl × Cases := + go code #[] +where + go (code : Code) (decls : Array CodeDecl) := + match code with + | .let decl k => go k <| decls.push (.let decl) + | .cases c => (decls, c) + | _ => unreachable! -- `code` is not the body of a join point that satisfies `isJpCases` + +structure JpCasesAlt where + decl : FunDecl + default : Bool + dependsOnDiscr : Bool + +abbrev Ctor2JpCasesAlt := FVarIdMap (NameMap JpCasesAlt) + +open Internalize in +private def mkJpAlt (decls : Array CodeDecl) (discr : Param) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do + go |>.run' {} +where + go : InternalizeM JpCasesAlt := do + let s : FVarIdSet := {} + let mut paramsNew := #[] + let dependsOnDiscr := k.dependsOn (s.insert discr.fvarId) + if dependsOnDiscr then + paramsNew := paramsNew.push (← internalizeParam discr) + paramsNew := paramsNew ++ (← fields.mapM internalizeParam) + let decls ← decls.mapM internalizeCodeDecl + let k ← internalizeCode k + let value := LCNF.attachCodeDecls decls k + return { decl := (← mkAuxJpDecl paramsNew value), default, dependsOnDiscr } + +/-- +Try to optimize `jpCases` join points. +We say a join point is a `jpCases` when it satifies the predicate `isJpCases`. +If we have a jump to `jpCases` with a constructor, then we can optimize the code by creating an new join point for +the constructor. +Example: suppose we have +```lean +jp _jp.1 y := + let x.1 := true + cases y + | nil => let x.2 := g x.1; return x.2 + | cons h t => let x.3 := h x.1; return x.3 +... +cases x.4 +| ctor1 => + let x.5 := cons z.1 z.2 + jmp _jp.1 x.5 +| ctor2 => + let x.6 := f x.4 + jmp _jp.1 x.6 +``` +This `simpJpCases?` converts it to +```lean +jp _jp.2 h t := + let x.1 := true + let x.3 := h x.1 + return x.3 +jp _jp.1 y := + let x.1 := true + cases y + | nil => let x.2 := g x.1; return x.2 + | cons h t => jmp _jp.2 h t +... +cases x.4 +| ctor1 => + -- The constructor has been eliminated here + jmp _jp.2 z.1 z.2 +| ctor2 => + let x.6 := f x.4 + jmp _jp.1 x.6 +``` +Note that if all jumps to the join point are with constructors, +then the join point is eliminated as dead code. +-/ +partial def simpJpCases? (code : Code) (smallThreshold : Nat) : CompilerM (Option Code) := do + let info ← collectJpCasesInfo code smallThreshold + unless info.isCandidate do return none + traceM `Compiler.simp.jpCases do + let mut msg : MessageData := "candidates" + for (fvarId, ctorName) in info.toList do + msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {ctorName.toList}" + return msg + visit code info |>.run' {} +where + visit (code : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) Code := do + match code with + | .let decl k => + return code.updateLet! decl (← visit k) + | .fun decl k => + let value ← visit decl.value + let decl ← decl.updateValue value + return code.updateFun! decl (← visit k) + | .jp decl k => + if let some code ← visitJp? decl k then + return code + else + let value ← visit decl.value + let decl ← decl.updateValue value + return code.updateFun! decl (← visit k) + | .cases c => + let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← visit alt.getCode) + return code.updateAlts! alts + | .return _ | .unreach _ => return code + | .jmp fvarId args => + let some code ← visitJmp? fvarId args | return code + return code + + visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do + let some s := (← read).find? decl.fvarId | return none + if s.isEmpty then return none + -- This join point satisfies `isJp` and there jumps with constructors in `s` to it. + let p := decl.params[0]! + let (decls, cases) := extractJpCases decl.value + let mut jpAltMap := {} + let mut jpAltDecls := #[] + let mut altsNew := #[] + for alt in cases.alts do + match alt with + | .default k => + let k ← visit k + let explicitCtorNames := cases.getCtorNames + if s.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then + let jpAlt ← mkJpAlt decls p #[] k (default := true) + jpAltDecls := jpAltDecls.push (.jp jpAlt.decl) + eraseCode k + for ctorNameInJmp in s do + unless explicitCtorNames.contains ctorNameInJmp do + jpAltMap := jpAltMap.insert ctorNameInJmp jpAlt + let args := if jpAlt.dependsOnDiscr then #[.fvar p.fvarId] else #[] + altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args)) + else + altsNew := altsNew.push (alt.updateCode k) + | .alt ctorName fields k => + let k ← visit k + if s.contains ctorName then + let jpAlt ← mkJpAlt decls p fields k (default := false) + jpAltDecls := jpAltDecls.push (.jp jpAlt.decl) + jpAltMap := jpAltMap.insert ctorName jpAlt + let mut args := fields.map (mkFVar ·.fvarId) + if jpAlt.dependsOnDiscr then + args := #[mkFVar p.fvarId] ++ args + eraseCode k + altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args)) + else + altsNew := altsNew.push (alt.updateCode k) + modify fun s => s.insert decl.fvarId jpAltMap + let value := LCNF.attachCodeDecls decls (.cases { cases with alts := altsNew }) + let decl ← decl.updateValue value + let code := .jp decl (← visit k) + return LCNF.attachCodeDecls jpAltDecls code + + visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do + let some ctorJpAltMap := (← get).find? fvarId | return none + assert! args.size == 1 + let arg ← findExpr args[0]! + let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none + let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none + if jpAlt.default then + if jpAlt.dependsOnDiscr then + return some <| .jmp jpAlt.decl.fvarId args + else + return some <| .jmp jpAlt.decl.fvarId #[] + else + let fields := ctorArgs[ctorVal.numParams:] + -- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too. + -- We use a for-loop because we may have other special cases in the future. + let mut auxDecls := #[] + let mut fieldsNew := #[] + for field in fields do + if field.isFVar then + fieldsNew := fieldsNew.push field + else + let letDecl ← mkAuxLetDecl field + auxDecls := auxDecls.push (CodeDecl.let letDecl) + fieldsNew := fieldsNew.push (.fvar letDecl.fvarId) + let code ← if jpAlt.dependsOnDiscr then + pure <| .jmp jpAlt.decl.fvarId (args ++ fieldsNew) + else + pure <| .jmp jpAlt.decl.fvarId fieldsNew + return some <| LCNF.attachCodeDecls auxDecls code + end Simp open Simp @@ -1160,11 +1294,15 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do updateFunDeclInfo decl.value trace[Compiler.simp.inline.info] "{decl.name}:{Format.nest 2 (← (← get).funDeclInfoMap.format)}" traceM `Compiler.simp.step do ppDecl decl - let value ← simp decl.value + let mut value ← simp decl.value traceM `Compiler.simp.step.new do return m!"{decl.name} :=\n{← ppCode value}" let s ← get trace[Compiler.simp.stat] "{decl.name}, size: {value.size}, # visited: {s.visited}, # inline: {s.inline}, # inline local: {s.inlineLocal}" - if (← get).simplified then + let mut progress := (← get).simplified + if let some valueNew ← simpJpCases? value (← read).config.smallThreshold then + progress := true + value := valueNew + if progress then return some { decl with value } else return none @@ -1195,6 +1333,7 @@ builtin_initialize registerTraceClass `Compiler.simp (inherited := true) registerTraceClass `Compiler.simp.inline registerTraceClass `Compiler.simp.stat + registerTraceClass `Compiler.simp.jpCases registerTraceClass `Compiler.simp.step registerTraceClass `Compiler.simp.step.new diff --git a/tests/lean/linterUnusedVariables.lean.expected.out b/tests/lean/linterUnusedVariables.lean.expected.out index d2e8b55a11..2872589237 100644 --- a/tests/lean/linterUnusedVariables.lean.expected.out +++ b/tests/lean/linterUnusedVariables.lean.expected.out @@ -12,8 +12,8 @@ linterUnusedVariables.lean:51:11-51:12: warning: unused variable `z` [linter.unu linterUnusedVariables.lean:56:14-56:15: warning: unused variable `y` [linter.unusedVariables] linterUnusedVariables.lean:62:20-62:21: warning: unused variable `y` [linter.unusedVariables] linterUnusedVariables.lean:67:34-67:38: warning: unused variable `inst` [linter.unusedVariables] -linterUnusedVariables.lean:108:25-108:26: warning: unused variable `x` [linter.unusedVariables] linterUnusedVariables.lean:109:6-109:7: warning: unused variable `y` [linter.unusedVariables] +linterUnusedVariables.lean:108:25-108:26: warning: unused variable `x` [linter.unusedVariables] linterUnusedVariables.lean:115:6-115:7: warning: unused variable `a` [linter.unusedVariables] linterUnusedVariables.lean:125:26-125:27: warning: unused variable `z` [linter.unusedVariables] linterUnusedVariables.lean:133:9-133:10: warning: unused variable `h` [linter.unusedVariables]