feat: replace isCasesOnCases? with simpJpCases?

It addresses the code explosion issue with the old optimization.
For example, the resulting size for `Lean.Json.Parser.escapedChar`
went from 31593 to 361.
This commit is contained in:
Leonardo de Moura 2022-09-25 20:54:41 -07:00
parent bbac49e925
commit fd1ae3118c
3 changed files with 236 additions and 88 deletions

View file

@ -70,6 +70,15 @@ abbrev Alt := AltCore Code
abbrev FunDecl := FunDeclCore Code
abbrev Cases := CasesCore Code
/--
Collect the names of the constructors that `c` handles with an explicit `.alt` alternative.
`.default` alternatives contribute nothing to the result.
-/
def CasesCore.getCtorNames (c : Cases) : NameSet := Id.run do
  let mut names : NameSet := {}
  for alt in c.alts do
    -- Only explicit alternatives carry a constructor name.
    if let .alt ctorName .. := alt then
      names := names.insert ctorName
  return names
inductive CodeDecl where
| let (decl : LetDecl)
| fun (decl : FunDecl)

View file

@ -14,6 +14,7 @@ import Lean.Compiler.LCNF.Bind
import Lean.Compiler.LCNF.PrettyPrinter
import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.AlphaEqv
import Lean.Compiler.LCNF.DependsOn
namespace Lean.Compiler.LCNF
@ -687,85 +688,6 @@ where
| .return fvarId => visit (.fvar fvarId) projs
| _ => failure
/--
Return `true` when `decl` takes exactly one parameter and its body is a sequence of
at most `smallThreshold` `let`-declarations followed by a `cases` on that parameter:
```
f y :=
  ... /- This part is not bigger than smallThreshold. -/
  cases y
  | ... => ...
  ...
```
Any other shape (a different number of parameters, a prefix that is too long, or a body
that does not end in `cases y`) yields `false`.
-/
def isJpCases (decl : FunDecl) (smallThreshold : Nat) : CompilerM Bool := do
  if decl.params.size == 1 then
    let discr := decl.params[0]!
    let rec loop (body : Code) (numLets : Nat) : Bool :=
      if numLets > smallThreshold then
        -- The prefix seen so far is already too large.
        false
      else
        match body with
        | .cases c => c.discr == discr.fvarId
        | .let _ rest => loop rest (numLets + 1) /- TODO: we should have uniform heuristics for estimating the size. -/
        | _ => false
    return loop decl.value 0
  else
    return false
/--
Return `some _jp.k` if the given `cases` is of the form
```
cases _x.i
(... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁)
...
(... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ)
```
where `_jp.k` is a join point of the form
```
let _jp.k y :=
cases y ...
```
The goal is to mark `_jp.k` as must inline in this scenario.
The function also allows `_jp.k` to have a small prefix before
`cases y`. The maximum size of this prefix is set using the configuration option
`config.smallThreshold`. It is currently the same threshold used to
decide when to inline a function that has multiple occurrences.
Example: consider the following declarations
```
@[inline] def pred? (x : Nat) : Option Nat :=
match x with
| 0 => none
| x+1 => some x
def isZero (x : Nat) :=
match pred? x with
| some _ => false
| none => true
```
After inlining `pred?` in `isZero`, this simplification is applicable.
Remark: this method does not assume `cases` has already been normalized,
but returns a normalized `FVarId` in case of success.
-/
def isCasesOnCases? (cases : Cases) : OptionT SimpM FVarId := do
-- The first alternative determines the candidate join point.
let jpFirst ← isCtorJmp? cases.alts[0]!.getCode
let funDecl ← getFunDecl jpFirst
let smallThreshold := (← read).config.smallThreshold
-- The candidate must have the `cases y` shape checked by `isJpCases`.
guard <| (← isJpCases funDecl smallThreshold)
-- Every remaining alternative must jump to the same join point.
for alt in cases.alts[1:] do
let jp ← isCtorJmp? alt.getCode
guard <| jpFirst == jp
return jpFirst
where
-- Succeeds with the normalized join point id if `code` ends in a jump whose
-- single argument is a constructor application; fails otherwise.
isCtorJmp? (code : Code) : OptionT SimpM FVarId := do
match code with
-- Skip over declarations to reach the terminal instruction.
| .let _ k | .jp _ k | .fun _ k => isCtorJmp? k
| .return .. | .unreach .. | .cases .. => failure
| .jmp jpFVarId args =>
-- Only jumps with exactly one argument qualify.
let #[arg] := args | failure
let arg ← findExpr (← normExpr arg)
guard <| arg.isConstructorApp (← getEnv)
normFVar jpFVarId
/--
Return the alternative in `alts` whose body appears in most arms,
and the number of occurrences.
@ -1119,23 +1041,49 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
return alts[0]!.getCode
else
return code.updateCases! resultType discr alts
if let some jpFVarId ← isCasesOnCases? c then
withAddMustInline jpFVarId simpCasesDefault
else
simpCasesDefault
simpCasesDefault
end
/--
Return `true` when `decl` takes exactly one parameter and its body is a sequence of
at most `smallThreshold` `let`-declarations followed by a `cases` on that parameter:
```
f y :=
  ... /- This part is not bigger than smallThreshold. -/
  cases y
  | ... => ...
  ...
```
Any other shape (a different number of parameters, a prefix that is too long, or a body
that does not end in `cases y`) yields `false`.
-/
def isJpCases (decl : FunDecl) (smallThreshold : Nat) : CompilerM Bool := do
  if decl.params.size == 1 then
    let discr := decl.params[0]!
    let rec loop (body : Code) (numLets : Nat) : Bool :=
      if numLets > smallThreshold then
        -- The prefix seen so far is already too large.
        false
      else
        match body with
        | .cases c => c.discr == discr.fvarId
        | .let _ rest => loop rest (numLets + 1) /- TODO: we should have uniform heuristics for estimating the size. -/
        | _ => false
    return loop decl.value 0
  else
    return false
/--
Information gathered by `collectJpCasesInfo`: maps the id of a join point to the set of
constructor names that occur as the argument of jumps to it.
-/
abbrev JpCasesInfo := FVarIdMap NameSet
/--
Return `true` if at least one join point in `info` has a nonempty set of constructor names,
i.e., there is an opportunity for the `JpCases` optimization.
-/
def JpCasesInfo.isCandidate (info : JpCasesInfo) : Bool :=
  info.any fun _ ctorNames => !ctorNames.isEmpty
/--
Return a map containing entries `jpFVarId ↦ ctorNames` where `jpFVarId` is the id of join point
in code that satisfies `isJpCases`, and `ctorNames` is a set of constructor names such that
there is a jump `.jmp jpFVarId #[x]` in `code` and `x` is a constructor application.
-/
partial def collectJpCasesInfo (code : Code) (smallThreshold : Nat): CompilerM (FVarIdMap NameSet) := do
partial def collectJpCasesInfo (code : Code) (smallThreshold : Nat): CompilerM JpCasesInfo := do
let (_, s) ← go code |>.run {}
return s
where
go (code : Code) : StateRefT (FVarIdMap NameSet) CompilerM Unit := do
go (code : Code) : StateRefT JpCasesInfo CompilerM Unit := do
match code with
| .let _ k => go k
| .fun decl k => go decl.value; go k
@ -1152,6 +1100,192 @@ where
let some (cval, _) := arg.constructorApp? (← getEnv) | return ()
modify fun s => s.insert fvarId <| ctorNames.insert cval.name
/--
Split the body of a join point that satisfies `isJpCases` into its leading
let-declarations (in order) and the terminal `cases`.
-/
private def extractJpCases (code : Code) : Array CodeDecl × Cases :=
  go code #[]
where
  go (body : Code) (prefixDecls : Array CodeDecl) :=
    match body with
    | .cases c => (prefixDecls, c)
    | .let decl rest => go rest (prefixDecls.push (.let decl))
    | _ => unreachable! -- `code` is not the body of a join point that satisfies `isJpCases`
/-- Auxiliary join point created by `mkJpAlt` for one alternative of a `jpCases` join point. -/
structure JpCasesAlt where
/-- Declaration of the auxiliary join point. -/
decl : FunDecl
/-- `true` if the join point was created from a `.default` alternative. -/
default : Bool
/-- `true` if `decl` takes the discriminant as its first parameter. -/
dependsOnDiscr : Bool
/--
Maps the id of a `jpCases` join point to a map from constructor name to the auxiliary
join point (`JpCasesAlt`) that a jump with that constructor should target.
-/
abbrev Ctor2JpCasesAlt := FVarIdMap (NameMap JpCasesAlt)
open Internalize in
/--
Create the auxiliary join point for one alternative of a `jpCases` join point.

* `decls`: the let-declarations that precede the `cases` in the original join point.
* `discr`: the parameter of the original join point (the discriminant of the `cases`).
* `fields`: the constructor fields bound by the alternative (empty for a `.default` alternative).
* `k`: the body of the alternative.
* `default`: whether the alternative is a `.default` alternative.

The new join point has body `decls; k`. Its parameters are `fields`, preceded by `discr`
when the body mentions it. Parameters, declarations, and body are all passed through the
`Internalize` helpers before the join point is built.

The dependency check covers `decls` as well as `k`: the prefix is replayed inside the new
join point, so a prefix declaration that mentions the discriminant would otherwise refer to
a variable that is not in scope there.
-/
private def mkJpAlt (decls : Array CodeDecl) (discr : Param) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do
  go |>.run' {}
where
  go : InternalizeM JpCasesAlt := do
    let s : FVarIdSet := {}
    let discrSet := s.insert discr.fvarId
    -- Check the whole body of the new join point (`decls` followed by `k`), not just `k`.
    let dependsOnDiscr := (LCNF.attachCodeDecls decls k).dependsOn discrSet
    let mut paramsNew := #[]
    -- Internalization order matters: parameters first, then the declarations, then the body.
    if dependsOnDiscr then
      paramsNew := paramsNew.push (← internalizeParam discr)
    paramsNew := paramsNew ++ (← fields.mapM internalizeParam)
    let decls ← decls.mapM internalizeCodeDecl
    let k ← internalizeCode k
    let value := LCNF.attachCodeDecls decls k
    return { decl := (← mkAuxJpDecl paramsNew value), default, dependsOnDiscr }
/--
Try to optimize `jpCases` join points.
We say a join point is a `jpCases` when it satisfies the predicate `isJpCases`.
If we have a jump to `jpCases` with a constructor, then we can optimize the code by creating a new join point for
the constructor.
Example: suppose we have
```lean
jp _jp.1 y :=
let x.1 := true
cases y
| nil => let x.2 := g x.1; return x.2
| cons h t => let x.3 := h x.1; return x.3
...
cases x.4
| ctor1 =>
let x.5 := cons z.1 z.2
jmp _jp.1 x.5
| ctor2 =>
let x.6 := f x.4
jmp _jp.1 x.6
```
This `simpJpCases?` converts it to
```lean
jp _jp.2 h t :=
let x.1 := true
let x.3 := h x.1
return x.3
jp _jp.1 y :=
let x.1 := true
cases y
| nil => let x.2 := g x.1; return x.2
| cons h t => jmp _jp.2 h t
...
cases x.4
| ctor1 =>
-- The constructor has been eliminated here
jmp _jp.2 z.1 z.2
| ctor2 =>
let x.6 := f x.4
jmp _jp.1 x.6
```
Note that if all jumps to the join point are with constructors,
then the join point is eliminated as dead code.

Returns `none` when `collectJpCasesInfo` finds no candidate join point.
-/
partial def simpJpCases? (code : Code) (smallThreshold : Nat) : CompilerM (Option Code) := do
let info ← collectJpCasesInfo code smallThreshold
-- Nothing to do when no candidate join point is jumped to with a constructor.
unless info.isCandidate do return none
traceM `Compiler.simp.jpCases do
let mut msg : MessageData := "candidates"
for (fvarId, ctorName) in info.toList do
msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {ctorName.toList}"
return msg
-- `info` is the reader context; the state starts empty and is filled by `visitJp?`.
visit code info |>.run' {}
where
-- Traverse `code`, rewriting candidate join points (`visitJp?`) and jumps to them (`visitJmp?`).
visit (code : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) Code := do
match code with
| .let decl k =>
return code.updateLet! decl (← visit k)
| .fun decl k =>
let value ← visit decl.value
let decl ← decl.updateValue value
return code.updateFun! decl (← visit k)
| .jp decl k =>
if let some code ← visitJp? decl k then
return code
else
-- Not a candidate join point: traverse it like any other declaration.
let value ← visit decl.value
let decl ← decl.updateValue value
return code.updateFun! decl (← visit k)
| .cases c =>
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← visit alt.getCode)
return code.updateAlts! alts
| .return _ | .unreach _ => return code
| .jmp fvarId args =>
-- Keep the original jump when `visitJmp?` has nothing to redirect.
let some code ← visitJmp? fvarId args | return code
return code
-- Rewrite the join point `decl` (with continuation `k`) if it is a candidate.
-- Returns `none` when `decl` has no entry in the collected information or its constructor set is empty.
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
let some s := (← read).find? decl.fvarId | return none
if s.isEmpty then return none
-- This join point satisfies `isJpCases` and there are jumps with constructors in `s` to it.
let p := decl.params[0]!
let (decls, cases) := extractJpCases decl.value
let mut jpAltMap := {}
let mut jpAltDecls := #[]
let mut altsNew := #[]
for alt in cases.alts do
match alt with
| .default k =>
let k ← visit k
let explicitCtorNames := cases.getCtorNames
-- A join point for the default alternative is only needed if some jump uses a
-- constructor that has no explicit alternative.
if s.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then
let jpAlt ← mkJpAlt decls p #[] k (default := true)
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
eraseCode k
for ctorNameInJmp in s do
unless explicitCtorNames.contains ctorNameInJmp do
jpAltMap := jpAltMap.insert ctorNameInJmp jpAlt
let args := if jpAlt.dependsOnDiscr then #[.fvar p.fvarId] else #[]
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
else
altsNew := altsNew.push (alt.updateCode k)
| .alt ctorName fields k =>
let k ← visit k
if s.contains ctorName then
let jpAlt ← mkJpAlt decls p fields k (default := false)
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
jpAltMap := jpAltMap.insert ctorName jpAlt
let mut args := fields.map (mkFVar ·.fvarId)
-- The discriminant, when needed, is the first parameter of the new join point.
if jpAlt.dependsOnDiscr then
args := #[mkFVar p.fvarId] ++ args
eraseCode k
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
else
altsNew := altsNew.push (alt.updateCode k)
-- Record the new join points before visiting `k`, so that `visitJmp?` can redirect the jumps in `k`.
modify fun s => s.insert decl.fvarId jpAltMap
let value := LCNF.attachCodeDecls decls (.cases { cases with alts := altsNew })
let decl ← decl.updateValue value
let code := .jp decl (← visit k)
-- The new join points are declared in front of the original one.
return LCNF.attachCodeDecls jpAltDecls code
-- Redirect a jump to a rewritten join point when its argument is a constructor application
-- that has an auxiliary join point. Returns `none` if the jump must be kept as is.
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
let some ctorJpAltMap := (← get).find? fvarId | return none
assert! args.size == 1
let arg ← findExpr args[0]!
let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none
let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none
if jpAlt.default then
-- Join points created from a default alternative take no field parameters.
if jpAlt.dependsOnDiscr then
return some <| .jmp jpAlt.decl.fvarId args
else
return some <| .jmp jpAlt.decl.fvarId #[]
else
-- Skip the first `ctorVal.numParams` arguments; the rest are the fields.
let fields := ctorArgs[ctorVal.numParams:]
-- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too.
-- We use a for-loop because we may have other special cases in the future.
let mut auxDecls := #[]
let mut fieldsNew := #[]
for field in fields do
if field.isFVar then
fieldsNew := fieldsNew.push field
else
-- Bind each non-free-variable field with an auxiliary let-declaration and pass its variable instead.
let letDecl ← mkAuxLetDecl field
auxDecls := auxDecls.push (CodeDecl.let letDecl)
fieldsNew := fieldsNew.push (.fvar letDecl.fvarId)
let code ← if jpAlt.dependsOnDiscr then
pure <| .jmp jpAlt.decl.fvarId (args ++ fieldsNew)
else
pure <| .jmp jpAlt.decl.fvarId fieldsNew
return some <| LCNF.attachCodeDecls auxDecls code
end Simp
open Simp
@ -1160,11 +1294,15 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
updateFunDeclInfo decl.value
trace[Compiler.simp.inline.info] "{decl.name}:{Format.nest 2 (← (← get).funDeclInfoMap.format)}"
traceM `Compiler.simp.step do ppDecl decl
let value ← simp decl.value
let mut value ← simp decl.value
traceM `Compiler.simp.step.new do return m!"{decl.name} :=\n{← ppCode value}"
let s ← get
trace[Compiler.simp.stat] "{decl.name}, size: {value.size}, # visited: {s.visited}, # inline: {s.inline}, # inline local: {s.inlineLocal}"
if (← get).simplified then
let mut progress := (← get).simplified
if let some valueNew ← simpJpCases? value (← read).config.smallThreshold then
progress := true
value := valueNew
if progress then
return some { decl with value }
else
return none
@ -1195,6 +1333,7 @@ builtin_initialize
registerTraceClass `Compiler.simp (inherited := true)
registerTraceClass `Compiler.simp.inline
registerTraceClass `Compiler.simp.stat
registerTraceClass `Compiler.simp.jpCases
registerTraceClass `Compiler.simp.step
registerTraceClass `Compiler.simp.step.new

View file

@ -12,8 +12,8 @@ linterUnusedVariables.lean:51:11-51:12: warning: unused variable `z` [linter.unu
linterUnusedVariables.lean:56:14-56:15: warning: unused variable `y` [linter.unusedVariables]
linterUnusedVariables.lean:62:20-62:21: warning: unused variable `y` [linter.unusedVariables]
linterUnusedVariables.lean:67:34-67:38: warning: unused variable `inst` [linter.unusedVariables]
linterUnusedVariables.lean:108:25-108:26: warning: unused variable `x` [linter.unusedVariables]
linterUnusedVariables.lean:109:6-109:7: warning: unused variable `y` [linter.unusedVariables]
linterUnusedVariables.lean:108:25-108:26: warning: unused variable `x` [linter.unusedVariables]
linterUnusedVariables.lean:115:6-115:7: warning: unused variable `a` [linter.unusedVariables]
linterUnusedVariables.lean:125:26-125:27: warning: unused variable `z` [linter.unusedVariables]
linterUnusedVariables.lean:133:9-133:10: warning: unused variable `h` [linter.unusedVariables]