feat: JpCases for join points with multiple parameters

This commit is contained in:
Leonardo de Moura 2022-10-03 18:35:16 -07:00
parent 12deab6516
commit f0be5439e6
4 changed files with 223 additions and 68 deletions

View file

@ -50,7 +50,7 @@ def checkpoint (stepName : Name) (decls : Array Decl) : CompilerM Unit := do
withOptions (fun opts => opts.setBool `pp.motives.pi false) do
let clsName := `Compiler ++ stepName
if (← Lean.isTracingEnabledFor clsName) then
Lean.addTrace clsName m!"size: {decl.size}\n{← ppDecl decl}"
Lean.addTrace clsName m!"size: {decl.size}\n{← ppDecl' decl}"
if compiler.check.get (← getOptions) then
decl.check
if compiler.check.get (← getOptions) then

View file

@ -12,7 +12,7 @@ namespace Lean.Compiler.LCNF
namespace Simp
/--
Given the function declaration `decl`, return `true` if it is of the form
Given the function declaration `decl`, return `some idx` if it is of the form
```
f y :=
... /- This part is not bigger than smallThreshold. -/
@ -20,55 +20,68 @@ f y :=
| ... => ...
...
```
`idx` is the index of the parameter used in the `cases` statement.
-/
def isJpCases (decl : FunDecl) : CompilerM Bool := do
if decl.params.size != 1 then
return false
def isJpCases? (decl : FunDecl) : CompilerM (Option Nat) := do
if decl.params.size == 0 then
return none
else
let param := decl.params[0]!
let small := (← getConfig).smallThreshold
let rec go (code : Code) (prefixSize : Nat) : Bool :=
prefixSize <= small &&
let rec go (code : Code) (prefixSize : Nat) : Option Nat :=
if prefixSize > small then none else
match code with
| .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/
| .cases c => c.discr == param.fvarId
| _ => false
| .cases c => decl.params.findIdx? fun param => c.discr == param.fvarId
| _ => none
return go decl.value 0
abbrev JpCasesInfo := FVarIdMap NameSet
/--
Information for join points that satisfy `isJpCases?`
-/
structure JpCasesInfo where
/-- Parameter index returned by `isJpCases?`. This parameter is the one the join point is performing the case-split. -/
paramIdx : Nat
/--
Set of constructor names s.t. `ctorName` is in the set if there is a jump to the join point where the parameter
`paramIdx` is a constructor application.
-/
ctorNames : NameSet := {}
deriving Inhabited
abbrev JpCasesInfoMap := FVarIdMap JpCasesInfo
/-- Return `true` if the collected information suggests opportunities for the `JpCases` optimization. -/
def JpCasesInfo.isCandidate (info : JpCasesInfo) : Bool :=
info.any fun _ s => !s.isEmpty
def JpCasesInfoMap.isCandidate (info : JpCasesInfoMap) : Bool :=
info.any fun _ s => !s.ctorNames.isEmpty
/--
Return a map containing entries `jpFVarId ↦ ctorNames` where `jpFVarId` is the id of join point
Return a map containing entries `jpFVarId ↦ { paramIdx, ctorNames }` where `jpFVarId` is the id of join point
in code that satisfies `isJpCases`, and `ctorNames` is a set of constructor names such that
there is a jump `.jmp jpFVarId #[x]` in `code` and `x` is a constructor application.
there is a jump `.jmp jpFVarId #[..., x, ...]` in `code` and `x` is a constructor application.
`paramIdx` is the index of the parameter
-/
partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfo := do
partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfoMap := do
let (_, s) ← go code |>.run {}
return s
where
go (code : Code) : StateRefT JpCasesInfo CompilerM Unit := do
go (code : Code) : StateRefT JpCasesInfoMap CompilerM Unit := do
match code with
| .let _ k => go k
| .fun decl k => go decl.value; go k
| .jp decl k =>
if (← isJpCases decl) then
modify fun s => s.insert decl.fvarId {}
if let some paramIdx ← isJpCases? decl then
modify fun s => s.insert decl.fvarId { paramIdx }
go decl.value; go k
| .cases c => c.alts.forM fun alt => go alt.getCode
| .return .. | .unreach .. => return ()
| .jmp fvarId args =>
if args.size == 1 then
if let some ctorNames := (← get).find? fvarId then
let arg ← findExpr args[0]!
if let some info := (← get).find? fvarId then
let arg ← findExpr args[info.paramIdx]!
let some (cval, _) := arg.constructorApp? (← getEnv) | return ()
modify fun s => s.insert fvarId <| ctorNames.insert cval.name
modify fun map => map.insert fvarId <| { info with ctorNames := info.ctorNames.insert cval.name }
/--
Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases`.
Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases?`.
-/
private def extractJpCases (code : Code) : Array CodeDecl × Cases :=
go code #[]
@ -87,21 +100,49 @@ structure JpCasesAlt where
abbrev Ctor2JpCasesAlt := FVarIdMap (NameMap JpCasesAlt)
open Internalize in
private def mkJpAlt (decls : Array CodeDecl) (discr : Param) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do
/--
Construct an auxiliary join point for a particular alternative in a join-point that satifies `isJpCases?`.
- `decls` is the prefix (before the `cases`). See `isJpCases?`.
- `params` are the parameters of the main join point that satisfies `isJpCases?`.
- `targetParamIdx` is the index of the parameter that we are expanding to `fields`
- `fields` are the fields/parameter of the alternative.
- `k` is the body of the alternative.
- `default` is true if it is a default alternative.
-/
private def mkJpAlt (decls : Array CodeDecl) (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do
go |>.run' {}
where
go : InternalizeM JpCasesAlt := do
let s : FVarIdSet := {}
let mut paramsNew := #[]
let dependsOnDiscr := k.dependsOn (s.insert discr.fvarId)
if dependsOnDiscr then
paramsNew := paramsNew.push (← internalizeParam discr)
paramsNew := paramsNew ++ (← fields.mapM internalizeParam)
let dependsOnDiscr := k.dependsOn (s.insert params[targetParamIdx]!.fvarId)
for i in [:params.size] do
let param := params[i]!
if targetParamIdx == i then
if dependsOnDiscr then
paramsNew := paramsNew.push (← internalizeParam param)
paramsNew := paramsNew ++ (← fields.mapM internalizeParam)
else
paramsNew := paramsNew.push (← internalizeParam param)
let decls ← decls.mapM internalizeCodeDecl
let k ← internalizeCode k
let value := LCNF.attachCodeDecls decls k
return { decl := (← mkAuxJpDecl paramsNew value), default, dependsOnDiscr }
/-- Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`. -/
private def mkJmpNewArgs (args : Array Expr) (targetParamIdx : Nat) (fields : Array Expr) (dependsOnTarget : Bool) : Array Expr :=
if dependsOnTarget then
args[:targetParamIdx+1] ++ fields ++ args[targetParamIdx+1:]
else
args[:targetParamIdx] ++ fields ++ args[targetParamIdx+1:]
/--
Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`.
This function is used to create jumps from the join point satisfying `isJpCases?` to the new auxiliary join points created using `mkJpAlt`.
-/
private def mkJmpArgsAtJp (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (dependsOnTarget : Bool) : Array Expr := Id.run do
mkJmpNewArgs (params.map (mkFVar ·.fvarId)) targetParamIdx (fields.map (mkFVar ·.fvarId)) dependsOnTarget
/--
Try to optimize `jpCases` join points.
We say a join point is a `jpCases` when it satifies the predicate `isJpCases`.
@ -147,16 +188,16 @@ Note that if all jumps to the join point are with constructors,
then the join point is eliminated as dead code.
-/
partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do
let info ← collectJpCasesInfo code
unless info.isCandidate do return none
let map ← collectJpCasesInfo code
unless map.isCandidate do return none
traceM `Compiler.simp.jpCases do
let mut msg : MessageData := "candidates"
for (fvarId, ctorName) in info.toList do
msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {ctorName.toList}"
for (fvarId, info) in map.toList do
msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {info.ctorNames.toList}"
return msg
visit code info |>.run' {}
visit code map |>.run' {}
where
visit (code : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) Code := do
visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) Code := do
match code with
| .let decl k =>
return code.updateLet! decl (← visit k)
@ -179,11 +220,10 @@ where
let some code ← visitJmp? fvarId args | return code
return code
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
let some s := (← read).find? decl.fvarId | return none
if s.isEmpty then return none
-- This join point satisfies `isJp` and there jumps with constructors in `s` to it.
let p := decl.params[0]!
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
let some info := (← read).find? decl.fvarId | return none
if info.ctorNames.isEmpty then return none
-- This join point satisfies `isJpCases?` and there are jumps with constructors in `info` to it.
let (decls, cases) := extractJpCases decl.value
let mut jpAltMap := {}
let mut jpAltDecls := #[]
@ -193,26 +233,24 @@ where
| .default k =>
let k ← visit k
let explicitCtorNames := cases.getCtorNames
if s.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then
let jpAlt ← mkJpAlt decls p #[] k (default := true)
if info.ctorNames.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then
let jpAlt ← mkJpAlt decls decl.params info.paramIdx #[] k (default := true)
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
eraseCode k
for ctorNameInJmp in s do
for ctorNameInJmp in info.ctorNames do
unless explicitCtorNames.contains ctorNameInJmp do
jpAltMap := jpAltMap.insert ctorNameInJmp jpAlt
let args := if jpAlt.dependsOnDiscr then #[.fvar p.fvarId] else #[]
let args := mkJmpArgsAtJp decl.params info.paramIdx #[] jpAlt.dependsOnDiscr
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
else
altsNew := altsNew.push (alt.updateCode k)
| .alt ctorName fields k =>
let k ← visit k
if s.contains ctorName then
let jpAlt ← mkJpAlt decls p fields k (default := false)
if info.ctorNames.contains ctorName then
let jpAlt ← mkJpAlt decls decl.params info.paramIdx fields k (default := false)
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
jpAltMap := jpAltMap.insert ctorName jpAlt
let mut args := fields.map (mkFVar ·.fvarId)
if jpAlt.dependsOnDiscr then
args := #[mkFVar p.fvarId] ++ args
let args := mkJmpArgsAtJp decl.params info.paramIdx fields jpAlt.dependsOnDiscr
eraseCode k
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
else
@ -223,23 +261,18 @@ where
let code := .jp decl (← visit k)
return LCNF.attachCodeDecls jpAltDecls code
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
let some ctorJpAltMap := (← get).find? fvarId | return none
assert! args.size == 1
let arg ← findExpr args[0]!
let some info := (← read).find? fvarId | return none
let arg ← findExpr args[info.paramIdx]!
let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none
let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none
if jpAlt.default then
if jpAlt.dependsOnDiscr then
return some <| .jmp jpAlt.decl.fvarId args
else
return some <| .jmp jpAlt.decl.fvarId #[]
else
let fields := ctorArgs[ctorVal.numParams:]
-- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too.
-- We use a for-loop because we may have other special cases in the future.
let mut auxDecls := #[]
let mut fieldsNew := #[]
let fields := ctorArgs[ctorVal.numParams:]
-- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too.
-- We use a for-loop because we may have other special cases in the future.
let mut auxDecls := #[]
let mut fieldsNew := #[]
unless jpAlt.default do
for field in fields do
if field.isFVar then
fieldsNew := fieldsNew.push field
@ -247,11 +280,8 @@ where
let letDecl ← mkAuxLetDecl field
auxDecls := auxDecls.push (CodeDecl.let letDecl)
fieldsNew := fieldsNew.push (.fvar letDecl.fvarId)
let code ← if jpAlt.dependsOnDiscr then
pure <| .jmp jpAlt.decl.fvarId (args ++ fieldsNew)
else
pure <| .jmp jpAlt.decl.fvarId fieldsNew
return some <| LCNF.attachCodeDecls auxDecls code
let argsNew := mkJmpNewArgs args info.paramIdx fieldsNew jpAlt.dependsOnDiscr
return some <| LCNF.attachCodeDecls auxDecls (.jmp jpAlt.decl.fvarId argsNew)
end Simp

View file

@ -0,0 +1,47 @@
set_option trace.Compiler.saveBase true in
def f1 (c : Bool) (a b : Nat) :=
let k d y z :=
match d with
| true => y + z + z*y
| false => z + y
match c with
| true => k false a b
| false => k true b a
set_option trace.Compiler.saveBase true in
def f2 (c : Bool) (a b : Nat) :=
let k y d z :=
match d with
| true => y + z + z*y
| false => z + y
match c with
| true => k a false b
| false => k b true a
inductive C where
| c1 | c2 | c3 | c4
set_option trace.Compiler.saveBase true in
def f3 (c c' : C) (a b : Nat) :=
let k y (d : C) z :=
match d with
| C.c1 => y + z + z*y
| _ => z + y + y
match c with
| .c1 => k a .c2 b
| .c2 => k b .c1 a
| .c3 => k b c' a
| .c4 => k a c' a
set_option trace.Compiler.saveBase true in
def f4 (c c' : C) (a b : Nat) :=
let k y z (d : C) :=
match d with
| C.c1 => y + z + z*y
| C.c3 => y*y+a
| _ => z + y + y
match c with
| .c1 => k a b .c2
| .c2 => k b b .c1
| .c3 => k b a c'
| .c4 => k a a c'

View file

@ -0,0 +1,78 @@
[Compiler.saveBase] size: 7
def f1 c a b :=
cases c
| Bool.false =>
let _x.1 := Nat.add b a
let _x.2 := Nat.mul a b
let _x.3 := Nat.add _x.1 _x.2
_x.3
| Bool.true =>
let _x.4 := Nat.add b a
_x.4
[Compiler.saveBase] size: 7
def f2 c a b :=
cases c
| Bool.false =>
let _x.1 := Nat.add b a
let _x.2 := Nat.mul a b
let _x.3 := Nat.add _x.1 _x.2
_x.3
| Bool.true =>
let _x.4 := Nat.add b a
_x.4
[Compiler.saveBase] size: 19
def f3 c c' a b :=
jp _jp.1 y z :=
let _x.2 := Nat.add y z
let _x.3 := Nat.mul z y
let _x.4 := Nat.add _x.2 _x.3
_x.4
jp _jp.5 y z :=
let _x.6 := Nat.add z y
let _x.7 := Nat.add _x.6 y
_x.7
jp _jp.8 y d z :=
cases d
| C.c1 =>
goto _jp.1 y z
| _ =>
goto _jp.5 y z
cases c
| C.c1 =>
goto _jp.5 a b
| C.c2 =>
goto _jp.1 b a
| C.c3 =>
goto _jp.8 b c' a
| C.c4 =>
goto _jp.8 a c' a
[Compiler.saveBase] size: 22
def f4 c c' a b :=
jp _jp.1 y z :=
let _x.2 := Nat.add y z
let _x.3 := Nat.mul z y
let _x.4 := Nat.add _x.2 _x.3
_x.4
jp _jp.5 y z :=
let _x.6 := Nat.add z y
let _x.7 := Nat.add _x.6 y
_x.7
jp _jp.8 y z d :=
cases d
| C.c1 =>
goto _jp.1 y z
| C.c3 =>
let _x.9 := Nat.mul y y
let _x.10 := Nat.add _x.9 a
_x.10
| _ =>
goto _jp.5 y z
cases c
| C.c1 =>
goto _jp.5 a b
| C.c2 =>
goto _jp.1 b b
| C.c3 =>
goto _jp.8 b a c'
| C.c4 =>
goto _jp.8 a a c'