feat: replace isCasesOnCases? with simpJpCases?
It addresses the code explosion issue with the old optimization. For example, the resulting size for `Lean.Json.Parser.escapedChar` went from 31593 to 361.
This commit is contained in:
parent
bbac49e925
commit
fd1ae3118c
3 changed files with 236 additions and 88 deletions
|
|
@ -70,6 +70,15 @@ abbrev Alt := AltCore Code
|
|||
abbrev FunDecl := FunDeclCore Code
|
||||
abbrev Cases := CasesCore Code
|
||||
|
||||
/--
Collect the constructor names for which `c` has an explicit alternative.
`.default` alternatives do not contribute any name to the result.
-/
def CasesCore.getCtorNames (c : Cases) : NameSet := Id.run do
  let mut explicitNames : NameSet := {}
  for alt in c.alts do
    -- Only `.alt` carries a constructor name; `.default` is skipped.
    if let .alt ctorName .. := alt then
      explicitNames := explicitNames.insert ctorName
  return explicitNames
|
||||
|
||||
inductive CodeDecl where
|
||||
| let (decl : LetDecl)
|
||||
| fun (decl : FunDecl)
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import Lean.Compiler.LCNF.Bind
|
|||
import Lean.Compiler.LCNF.PrettyPrinter
|
||||
import Lean.Compiler.LCNF.PassManager
|
||||
import Lean.Compiler.LCNF.AlphaEqv
|
||||
import Lean.Compiler.LCNF.DependsOn
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
|
|
@ -687,85 +688,6 @@ where
|
|||
| .return fvarId => visit (.fvar fvarId) projs
|
||||
| _ => failure
|
||||
|
||||
/--
|
||||
Given the function declaration `decl`, return `true` if it is of the form
|
||||
```
|
||||
f y :=
|
||||
... /- This part is not bigger than smallThreshold. -/
|
||||
cases y
|
||||
| ... => ...
|
||||
...
|
||||
```
|
||||
-/
|
||||
def isJpCases (decl : FunDecl) (smallThreshold : Nat) : CompilerM Bool := do
|
||||
if decl.params.size != 1 then
|
||||
return false
|
||||
else
|
||||
let param := decl.params[0]!
|
||||
let rec go (code : Code) (prefixSize : Nat) : Bool :=
|
||||
prefixSize <= smallThreshold &&
|
||||
match code with
|
||||
| .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/
|
||||
| .cases c => c.discr == param.fvarId
|
||||
| _ => false
|
||||
return go decl.value 0
|
||||
|
||||
/--
|
||||
Return `some _jp.k` if the given `cases` is of the form
|
||||
```
|
||||
cases _x.i
|
||||
(... let _x.j₁ := ctorⱼ₁ ...; _jp.k _x.j₁)
|
||||
...
|
||||
(... let _x.jₙ := ctorⱼₙ ...; _jp.k _x.jₙ)
|
||||
```
|
||||
where `_jp.k` is a join point of the form
|
||||
```
|
||||
let _jp.k y :=
|
||||
cases y ...
|
||||
```
|
||||
The goal is to mark `_jp.k` as must inline in this scenario.
|
||||
The function also allows `_jp.k` to have a small prefix before
|
||||
`cases y`. The small prefix is set using the configuration option
|
||||
`config.smallThreshold`. It is currently the same threshold used to
|
||||
decide when to inline a function that has multiple occurrences.
|
||||
|
||||
Example: consider the following declarations
|
||||
```
|
||||
@[inline] def pred? (x : Nat) : Option Nat :=
|
||||
match x with
|
||||
| 0 => none
|
||||
| x+1 => some x
|
||||
|
||||
def isZero (x : Nat) :=
|
||||
match pred? x with
|
||||
| some _ => false
|
||||
| none => true
|
||||
```
|
||||
After inlining `pred?` in `isZero`, this simplification is applicable, producing
|
||||
|
||||
Remark: this method does not assume `cases` has already been normalized,
|
||||
but returns a normalized `FVarId` in case of success.
|
||||
-/
|
||||
def isCasesOnCases? (cases : Cases) : OptionT SimpM FVarId := do
|
||||
let jpFirst ← isCtorJmp? cases.alts[0]!.getCode
|
||||
let funDecl ← getFunDecl jpFirst
|
||||
let smallThreshold := (← read).config.smallThreshold
|
||||
guard <| (← isJpCases funDecl smallThreshold)
|
||||
for alt in cases.alts[1:] do
|
||||
let jp ← isCtorJmp? alt.getCode
|
||||
guard <| jpFirst == jp
|
||||
return jpFirst
|
||||
where
|
||||
isCtorJmp? (code : Code) : OptionT SimpM FVarId := do
|
||||
match code with
|
||||
| .let _ k | .jp _ k | .fun _ k => isCtorJmp? k
|
||||
| .return .. | .unreach .. | .cases .. => failure
|
||||
| .jmp jpFVarId args =>
|
||||
let #[arg] := args | failure
|
||||
let arg ← findExpr (← normExpr arg)
|
||||
guard <| arg.isConstructorApp (← getEnv)
|
||||
normFVar jpFVarId
|
||||
|
||||
/--
|
||||
Return the alternative in `alts` whose body appears in most arms,
|
||||
and the number of occurrences.
|
||||
|
|
@ -1119,23 +1041,49 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
|
|||
return alts[0]!.getCode
|
||||
else
|
||||
return code.updateCases! resultType discr alts
|
||||
if let some jpFVarId ← isCasesOnCases? c then
|
||||
withAddMustInline jpFVarId simpCasesDefault
|
||||
else
|
||||
simpCasesDefault
|
||||
simpCasesDefault
|
||||
|
||||
end
|
||||
|
||||
/--
|
||||
Given the function declaration `decl`, return `true` if it is of the form
|
||||
```
|
||||
f y :=
|
||||
... /- This part is not bigger than smallThreshold. -/
|
||||
cases y
|
||||
| ... => ...
|
||||
...
|
||||
```
|
||||
-/
|
||||
def isJpCases (decl : FunDecl) (smallThreshold : Nat) : CompilerM Bool := do
|
||||
if decl.params.size != 1 then
|
||||
return false
|
||||
else
|
||||
let param := decl.params[0]!
|
||||
let rec go (code : Code) (prefixSize : Nat) : Bool :=
|
||||
prefixSize <= smallThreshold &&
|
||||
match code with
|
||||
| .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/
|
||||
| .cases c => c.discr == param.fvarId
|
||||
| _ => false
|
||||
return go decl.value 0
|
||||
|
||||
abbrev JpCasesInfo := FVarIdMap NameSet
|
||||
|
||||
/-- Return `true` if the collected information suggests opportunities for the `JpCases` optimization. -/
|
||||
def JpCasesInfo.isCandidate (info : JpCasesInfo) : Bool :=
|
||||
info.any fun _ s => !s.isEmpty
|
||||
|
||||
/--
|
||||
Return a map containing entries `jpFVarId ↦ ctorNames` where `jpFVarId` is the id of join point
|
||||
in code that satisfies `isJpCases`, and `ctorNames` is a set of constructor names such that
|
||||
there is a jump `.jmp jpFVarId #[x]` in `code` and `x` is a constructor application.
|
||||
-/
|
||||
partial def collectJpCasesInfo (code : Code) (smallThreshold : Nat): CompilerM (FVarIdMap NameSet) := do
|
||||
partial def collectJpCasesInfo (code : Code) (smallThreshold : Nat): CompilerM JpCasesInfo := do
|
||||
let (_, s) ← go code |>.run {}
|
||||
return s
|
||||
where
|
||||
go (code : Code) : StateRefT (FVarIdMap NameSet) CompilerM Unit := do
|
||||
go (code : Code) : StateRefT JpCasesInfo CompilerM Unit := do
|
||||
match code with
|
||||
| .let _ k => go k
|
||||
| .fun decl k => go decl.value; go k
|
||||
|
|
@ -1152,6 +1100,192 @@ where
|
|||
let some (cval, _) := arg.constructorApp? (← getEnv) | return ()
|
||||
modify fun s => s.insert fvarId <| ctorNames.insert cval.name
|
||||
|
||||
/--
|
||||
Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases`.
|
||||
-/
|
||||
private def extractJpCases (code : Code) : Array CodeDecl × Cases :=
|
||||
go code #[]
|
||||
where
|
||||
go (code : Code) (decls : Array CodeDecl) :=
|
||||
match code with
|
||||
| .let decl k => go k <| decls.push (.let decl)
|
||||
| .cases c => (decls, c)
|
||||
| _ => unreachable! -- `code` is not the body of a join point that satisfies `isJpCases`
|
||||
|
||||
structure JpCasesAlt where
|
||||
decl : FunDecl
|
||||
default : Bool
|
||||
dependsOnDiscr : Bool
|
||||
|
||||
abbrev Ctor2JpCasesAlt := FVarIdMap (NameMap JpCasesAlt)
|
||||
|
||||
open Internalize in
/--
Create the auxiliary join point for one alternative of a `jpCases` join point.

* `decls`: the let-declarations that precede the `cases` in the original join point
  (as produced by `extractJpCases`). They are copied in front of `k`.
* `discr`: the parameter of the original join point, i.e. the discriminant of its `cases`.
* `fields`: the constructor fields bound by the alternative (`#[]` for a default alternative).
* `k`: the body of the alternative.
* `default`: whether the alternative is the default one; stored as-is in the result.

The new join point takes `fields` as parameters, preceded by (a fresh copy of) `discr`
when the new body refers to it. The returned `dependsOnDiscr` flag tells callers whether
they must pass the discriminant as the first argument of a jump to the new join point.
-/
private def mkJpAlt (decls : Array CodeDecl) (discr : Param) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do
  go |>.run' {}
where
  go : InternalizeM JpCasesAlt := do
    let discrSet : FVarIdSet := ({} : FVarIdSet).insert discr.fvarId
    -- The body of the new join point is `decls` followed by `k`, so the dependency
    -- check must cover both. Checking only `k` would miss a prefix such as
    -- `let x := f y` that mentions the discriminant `y`, leaving `y` free
    -- (and out of scope) in the new join point.
    -- This must be computed before internalization, which replaces the free variables.
    let dependsOnDiscr := (LCNF.attachCodeDecls decls k).dependsOn discrSet
    let mut paramsNew := #[]
    if dependsOnDiscr then
      paramsNew := paramsNew.push (← internalizeParam discr)
    paramsNew := paramsNew ++ (← fields.mapM internalizeParam)
    let decls ← decls.mapM internalizeCodeDecl
    let k ← internalizeCode k
    let value := LCNF.attachCodeDecls decls k
    return { decl := (← mkAuxJpDecl paramsNew value), default, dependsOnDiscr }
|
||||
|
||||
/--
|
||||
Try to optimize `jpCases` join points.
|
||||
We say a join point is a `jpCases` when it satisfies the predicate `isJpCases`.
|
||||
If we have a jump to `jpCases` with a constructor, then we can optimize the code by creating a new join point for
|
||||
the constructor.
|
||||
Example: suppose we have
|
||||
```lean
|
||||
jp _jp.1 y :=
|
||||
let x.1 := true
|
||||
cases y
|
||||
| nil => let x.2 := g x.1; return x.2
|
||||
| cons h t => let x.3 := h x.1; return x.3
|
||||
...
|
||||
cases x.4
|
||||
| ctor1 =>
|
||||
let x.5 := cons z.1 z.2
|
||||
jmp _jp.1 x.5
|
||||
| ctor2 =>
|
||||
let x.6 := f x.4
|
||||
jmp _jp.1 x.6
|
||||
```
|
||||
This `simpJpCases?` converts it to
|
||||
```lean
|
||||
jp _jp.2 h t :=
|
||||
let x.1 := true
|
||||
let x.3 := h x.1
|
||||
return x.3
|
||||
jp _jp.1 y :=
|
||||
let x.1 := true
|
||||
cases y
|
||||
| nil => let x.2 := g x.1; return x.2
|
||||
| cons h t => jmp _jp.2 h t
|
||||
...
|
||||
cases x.4
|
||||
| ctor1 =>
|
||||
-- The constructor has been eliminated here
|
||||
jmp _jp.2 z.1 z.2
|
||||
| ctor2 =>
|
||||
let x.6 := f x.4
|
||||
jmp _jp.1 x.6
|
||||
```
|
||||
Note that if all jumps to the join point are with constructors,
|
||||
then the join point is eliminated as dead code.
|
||||
-/
|
||||
partial def simpJpCases? (code : Code) (smallThreshold : Nat) : CompilerM (Option Code) := do
|
||||
let info ← collectJpCasesInfo code smallThreshold
|
||||
unless info.isCandidate do return none
|
||||
traceM `Compiler.simp.jpCases do
|
||||
let mut msg : MessageData := "candidates"
|
||||
for (fvarId, ctorName) in info.toList do
|
||||
msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {ctorName.toList}"
|
||||
return msg
|
||||
visit code info |>.run' {}
|
||||
where
|
||||
visit (code : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
return code.updateLet! decl (← visit k)
|
||||
| .fun decl k =>
|
||||
let value ← visit decl.value
|
||||
let decl ← decl.updateValue value
|
||||
return code.updateFun! decl (← visit k)
|
||||
| .jp decl k =>
|
||||
if let some code ← visitJp? decl k then
|
||||
return code
|
||||
else
|
||||
let value ← visit decl.value
|
||||
let decl ← decl.updateValue value
|
||||
return code.updateFun! decl (← visit k)
|
||||
| .cases c =>
|
||||
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← visit alt.getCode)
|
||||
return code.updateAlts! alts
|
||||
| .return _ | .unreach _ => return code
|
||||
| .jmp fvarId args =>
|
||||
let some code ← visitJmp? fvarId args | return code
|
||||
return code
|
||||
|
||||
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
let some s := (← read).find? decl.fvarId | return none
|
||||
if s.isEmpty then return none
|
||||
-- This join point satisfies `isJpCases` and there are jumps with constructors in `s` to it.
|
||||
let p := decl.params[0]!
|
||||
let (decls, cases) := extractJpCases decl.value
|
||||
let mut jpAltMap := {}
|
||||
let mut jpAltDecls := #[]
|
||||
let mut altsNew := #[]
|
||||
for alt in cases.alts do
|
||||
match alt with
|
||||
| .default k =>
|
||||
let k ← visit k
|
||||
let explicitCtorNames := cases.getCtorNames
|
||||
if s.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then
|
||||
let jpAlt ← mkJpAlt decls p #[] k (default := true)
|
||||
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
|
||||
eraseCode k
|
||||
for ctorNameInJmp in s do
|
||||
unless explicitCtorNames.contains ctorNameInJmp do
|
||||
jpAltMap := jpAltMap.insert ctorNameInJmp jpAlt
|
||||
let args := if jpAlt.dependsOnDiscr then #[.fvar p.fvarId] else #[]
|
||||
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
|
||||
else
|
||||
altsNew := altsNew.push (alt.updateCode k)
|
||||
| .alt ctorName fields k =>
|
||||
let k ← visit k
|
||||
if s.contains ctorName then
|
||||
let jpAlt ← mkJpAlt decls p fields k (default := false)
|
||||
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
|
||||
jpAltMap := jpAltMap.insert ctorName jpAlt
|
||||
let mut args := fields.map (mkFVar ·.fvarId)
|
||||
if jpAlt.dependsOnDiscr then
|
||||
args := #[mkFVar p.fvarId] ++ args
|
||||
eraseCode k
|
||||
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
|
||||
else
|
||||
altsNew := altsNew.push (alt.updateCode k)
|
||||
modify fun s => s.insert decl.fvarId jpAltMap
|
||||
let value := LCNF.attachCodeDecls decls (.cases { cases with alts := altsNew })
|
||||
let decl ← decl.updateValue value
|
||||
let code := .jp decl (← visit k)
|
||||
return LCNF.attachCodeDecls jpAltDecls code
|
||||
|
||||
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
let some ctorJpAltMap := (← get).find? fvarId | return none
|
||||
assert! args.size == 1
|
||||
let arg ← findExpr args[0]!
|
||||
let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none
|
||||
let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none
|
||||
if jpAlt.default then
|
||||
if jpAlt.dependsOnDiscr then
|
||||
return some <| .jmp jpAlt.decl.fvarId args
|
||||
else
|
||||
return some <| .jmp jpAlt.decl.fvarId #[]
|
||||
else
|
||||
let fields := ctorArgs[ctorVal.numParams:]
|
||||
-- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too.
|
||||
-- We use a for-loop because we may have other special cases in the future.
|
||||
let mut auxDecls := #[]
|
||||
let mut fieldsNew := #[]
|
||||
for field in fields do
|
||||
if field.isFVar then
|
||||
fieldsNew := fieldsNew.push field
|
||||
else
|
||||
let letDecl ← mkAuxLetDecl field
|
||||
auxDecls := auxDecls.push (CodeDecl.let letDecl)
|
||||
fieldsNew := fieldsNew.push (.fvar letDecl.fvarId)
|
||||
let code ← if jpAlt.dependsOnDiscr then
|
||||
pure <| .jmp jpAlt.decl.fvarId (args ++ fieldsNew)
|
||||
else
|
||||
pure <| .jmp jpAlt.decl.fvarId fieldsNew
|
||||
return some <| LCNF.attachCodeDecls auxDecls code
|
||||
|
||||
end Simp
|
||||
|
||||
open Simp
|
||||
|
|
@ -1160,11 +1294,15 @@ def Decl.simp? (decl : Decl) : SimpM (Option Decl) := do
|
|||
updateFunDeclInfo decl.value
|
||||
trace[Compiler.simp.inline.info] "{decl.name}:{Format.nest 2 (← (← get).funDeclInfoMap.format)}"
|
||||
traceM `Compiler.simp.step do ppDecl decl
|
||||
let value ← simp decl.value
|
||||
let mut value ← simp decl.value
|
||||
traceM `Compiler.simp.step.new do return m!"{decl.name} :=\n{← ppCode value}"
|
||||
let s ← get
|
||||
trace[Compiler.simp.stat] "{decl.name}, size: {value.size}, # visited: {s.visited}, # inline: {s.inline}, # inline local: {s.inlineLocal}"
|
||||
if (← get).simplified then
|
||||
let mut progress := (← get).simplified
|
||||
if let some valueNew ← simpJpCases? value (← read).config.smallThreshold then
|
||||
progress := true
|
||||
value := valueNew
|
||||
if progress then
|
||||
return some { decl with value }
|
||||
else
|
||||
return none
|
||||
|
|
@ -1195,6 +1333,7 @@ builtin_initialize
|
|||
registerTraceClass `Compiler.simp (inherited := true)
|
||||
registerTraceClass `Compiler.simp.inline
|
||||
registerTraceClass `Compiler.simp.stat
|
||||
registerTraceClass `Compiler.simp.jpCases
|
||||
registerTraceClass `Compiler.simp.step
|
||||
registerTraceClass `Compiler.simp.step.new
|
||||
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ linterUnusedVariables.lean:51:11-51:12: warning: unused variable `z` [linter.unu
|
|||
linterUnusedVariables.lean:56:14-56:15: warning: unused variable `y` [linter.unusedVariables]
|
||||
linterUnusedVariables.lean:62:20-62:21: warning: unused variable `y` [linter.unusedVariables]
|
||||
linterUnusedVariables.lean:67:34-67:38: warning: unused variable `inst` [linter.unusedVariables]
|
||||
linterUnusedVariables.lean:108:25-108:26: warning: unused variable `x` [linter.unusedVariables]
|
||||
linterUnusedVariables.lean:109:6-109:7: warning: unused variable `y` [linter.unusedVariables]
|
||||
linterUnusedVariables.lean:108:25-108:26: warning: unused variable `x` [linter.unusedVariables]
|
||||
linterUnusedVariables.lean:115:6-115:7: warning: unused variable `a` [linter.unusedVariables]
|
||||
linterUnusedVariables.lean:125:26-125:27: warning: unused variable `z` [linter.unusedVariables]
|
||||
linterUnusedVariables.lean:133:9-133:10: warning: unused variable `h` [linter.unusedVariables]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue