feat: JpCases for join points with multiple parameters
This commit is contained in:
parent
12deab6516
commit
f0be5439e6
4 changed files with 223 additions and 68 deletions
|
|
@ -50,7 +50,7 @@ def checkpoint (stepName : Name) (decls : Array Decl) : CompilerM Unit := do
|
|||
withOptions (fun opts => opts.setBool `pp.motives.pi false) do
|
||||
let clsName := `Compiler ++ stepName
|
||||
if (← Lean.isTracingEnabledFor clsName) then
|
||||
Lean.addTrace clsName m!"size: {decl.size}\n{← ppDecl decl}"
|
||||
Lean.addTrace clsName m!"size: {decl.size}\n{← ppDecl' decl}"
|
||||
if compiler.check.get (← getOptions) then
|
||||
decl.check
|
||||
if compiler.check.get (← getOptions) then
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ namespace Lean.Compiler.LCNF
|
|||
namespace Simp
|
||||
|
||||
/--
|
||||
Given the function declaration `decl`, return `true` if it is of the form
|
||||
Given the function declaration `decl`, return `some idx` if it is of the form
|
||||
```
|
||||
f y :=
|
||||
... /- This part is not bigger than smallThreshold. -/
|
||||
|
|
@ -20,55 +20,68 @@ f y :=
|
|||
| ... => ...
|
||||
...
|
||||
```
|
||||
`idx` is the index of the parameter used in the `cases` statement.
|
||||
-/
|
||||
def isJpCases (decl : FunDecl) : CompilerM Bool := do
|
||||
if decl.params.size != 1 then
|
||||
return false
|
||||
def isJpCases? (decl : FunDecl) : CompilerM (Option Nat) := do
|
||||
if decl.params.size == 0 then
|
||||
return none
|
||||
else
|
||||
let param := decl.params[0]!
|
||||
let small := (← getConfig).smallThreshold
|
||||
let rec go (code : Code) (prefixSize : Nat) : Bool :=
|
||||
prefixSize <= small &&
|
||||
let rec go (code : Code) (prefixSize : Nat) : Option Nat :=
|
||||
if prefixSize > small then none else
|
||||
match code with
|
||||
| .let _ k => go k (prefixSize + 1) /- TODO: we should have uniform heuristics for estimating the size. -/
|
||||
| .cases c => c.discr == param.fvarId
|
||||
| _ => false
|
||||
| .cases c => decl.params.findIdx? fun param => c.discr == param.fvarId
|
||||
| _ => none
|
||||
return go decl.value 0
|
||||
|
||||
abbrev JpCasesInfo := FVarIdMap NameSet
|
||||
/--
|
||||
Information for join points that satisfy `isJpCases?`
|
||||
-/
|
||||
structure JpCasesInfo where
|
||||
/-- Parameter index returned by `isJpCases?`. This parameter is the one the join point is performing the case-split. -/
|
||||
paramIdx : Nat
|
||||
/--
|
||||
Set of constructor names s.t. `ctorName` is in the set if there is a jump to the join point where the parameter
|
||||
`paramIdx` is a constructor application.
|
||||
-/
|
||||
ctorNames : NameSet := {}
|
||||
deriving Inhabited
|
||||
|
||||
abbrev JpCasesInfoMap := FVarIdMap JpCasesInfo
|
||||
|
||||
/-- Return `true` if the collected information suggests opportunities for the `JpCases` optimization. -/
|
||||
def JpCasesInfo.isCandidate (info : JpCasesInfo) : Bool :=
|
||||
info.any fun _ s => !s.isEmpty
|
||||
def JpCasesInfoMap.isCandidate (info : JpCasesInfoMap) : Bool :=
|
||||
info.any fun _ s => !s.ctorNames.isEmpty
|
||||
|
||||
/--
|
||||
Return a map containing entries `jpFVarId ↦ ctorNames` where `jpFVarId` is the id of join point
|
||||
Return a map containing entries `jpFVarId ↦ { paramIdx, ctorNames }` where `jpFVarId` is the id of join point
|
||||
in code that satisfies `isJpCases`, and `ctorNames` is a set of constructor names such that
|
||||
there is a jump `.jmp jpFVarId #[x]` in `code` and `x` is a constructor application.
|
||||
there is a jump `.jmp jpFVarId #[..., x, ...]` in `code` and `x` is a constructor application.
|
||||
`paramIdx` is the index of the parameter
|
||||
-/
|
||||
partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfo := do
|
||||
partial def collectJpCasesInfo (code : Code) : CompilerM JpCasesInfoMap := do
|
||||
let (_, s) ← go code |>.run {}
|
||||
return s
|
||||
where
|
||||
go (code : Code) : StateRefT JpCasesInfo CompilerM Unit := do
|
||||
go (code : Code) : StateRefT JpCasesInfoMap CompilerM Unit := do
|
||||
match code with
|
||||
| .let _ k => go k
|
||||
| .fun decl k => go decl.value; go k
|
||||
| .jp decl k =>
|
||||
if (← isJpCases decl) then
|
||||
modify fun s => s.insert decl.fvarId {}
|
||||
if let some paramIdx ← isJpCases? decl then
|
||||
modify fun s => s.insert decl.fvarId { paramIdx }
|
||||
go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .return .. | .unreach .. => return ()
|
||||
| .jmp fvarId args =>
|
||||
if args.size == 1 then
|
||||
if let some ctorNames := (← get).find? fvarId then
|
||||
let arg ← findExpr args[0]!
|
||||
if let some info := (← get).find? fvarId then
|
||||
let arg ← findExpr args[info.paramIdx]!
|
||||
let some (cval, _) := arg.constructorApp? (← getEnv) | return ()
|
||||
modify fun s => s.insert fvarId <| ctorNames.insert cval.name
|
||||
modify fun map => map.insert fvarId <| { info with ctorNames := info.ctorNames.insert cval.name }
|
||||
|
||||
/--
|
||||
Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases`.
|
||||
Extract the let-declarations and `cases` for a join point body that satisfies `isJpCases?`.
|
||||
-/
|
||||
private def extractJpCases (code : Code) : Array CodeDecl × Cases :=
|
||||
go code #[]
|
||||
|
|
@ -87,21 +100,49 @@ structure JpCasesAlt where
|
|||
abbrev Ctor2JpCasesAlt := FVarIdMap (NameMap JpCasesAlt)
|
||||
|
||||
open Internalize in
|
||||
private def mkJpAlt (decls : Array CodeDecl) (discr : Param) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do
|
||||
/--
|
||||
Construct an auxiliary join point for a particular alternative in a join-point that satifies `isJpCases?`.
|
||||
- `decls` is the prefix (before the `cases`). See `isJpCases?`.
|
||||
- `params` are the parameters of the main join point that satisfies `isJpCases?`.
|
||||
- `targetParamIdx` is the index of the parameter that we are expanding to `fields`
|
||||
- `fields` are the fields/parameter of the alternative.
|
||||
- `k` is the body of the alternative.
|
||||
- `default` is true if it is a default alternative.
|
||||
-/
|
||||
private def mkJpAlt (decls : Array CodeDecl) (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (k : Code) (default : Bool) : CompilerM JpCasesAlt := do
|
||||
go |>.run' {}
|
||||
where
|
||||
go : InternalizeM JpCasesAlt := do
|
||||
let s : FVarIdSet := {}
|
||||
let mut paramsNew := #[]
|
||||
let dependsOnDiscr := k.dependsOn (s.insert discr.fvarId)
|
||||
if dependsOnDiscr then
|
||||
paramsNew := paramsNew.push (← internalizeParam discr)
|
||||
paramsNew := paramsNew ++ (← fields.mapM internalizeParam)
|
||||
let dependsOnDiscr := k.dependsOn (s.insert params[targetParamIdx]!.fvarId)
|
||||
for i in [:params.size] do
|
||||
let param := params[i]!
|
||||
if targetParamIdx == i then
|
||||
if dependsOnDiscr then
|
||||
paramsNew := paramsNew.push (← internalizeParam param)
|
||||
paramsNew := paramsNew ++ (← fields.mapM internalizeParam)
|
||||
else
|
||||
paramsNew := paramsNew.push (← internalizeParam param)
|
||||
let decls ← decls.mapM internalizeCodeDecl
|
||||
let k ← internalizeCode k
|
||||
let value := LCNF.attachCodeDecls decls k
|
||||
return { decl := (← mkAuxJpDecl paramsNew value), default, dependsOnDiscr }
|
||||
|
||||
/-- Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`. -/
|
||||
private def mkJmpNewArgs (args : Array Expr) (targetParamIdx : Nat) (fields : Array Expr) (dependsOnTarget : Bool) : Array Expr :=
|
||||
if dependsOnTarget then
|
||||
args[:targetParamIdx+1] ++ fields ++ args[targetParamIdx+1:]
|
||||
else
|
||||
args[:targetParamIdx] ++ fields ++ args[targetParamIdx+1:]
|
||||
|
||||
/--
|
||||
Create the arguments for a jump to an auxiliary join point created using `mkJpAlt`.
|
||||
This function is used to create jumps from the join point satisfying `isJpCases?` to the new auxiliary join points created using `mkJpAlt`.
|
||||
-/
|
||||
private def mkJmpArgsAtJp (params : Array Param) (targetParamIdx : Nat) (fields : Array Param) (dependsOnTarget : Bool) : Array Expr := Id.run do
|
||||
mkJmpNewArgs (params.map (mkFVar ·.fvarId)) targetParamIdx (fields.map (mkFVar ·.fvarId)) dependsOnTarget
|
||||
|
||||
/--
|
||||
Try to optimize `jpCases` join points.
|
||||
We say a join point is a `jpCases` when it satifies the predicate `isJpCases`.
|
||||
|
|
@ -147,16 +188,16 @@ Note that if all jumps to the join point are with constructors,
|
|||
then the join point is eliminated as dead code.
|
||||
-/
|
||||
partial def simpJpCases? (code : Code) : CompilerM (Option Code) := do
|
||||
let info ← collectJpCasesInfo code
|
||||
unless info.isCandidate do return none
|
||||
let map ← collectJpCasesInfo code
|
||||
unless map.isCandidate do return none
|
||||
traceM `Compiler.simp.jpCases do
|
||||
let mut msg : MessageData := "candidates"
|
||||
for (fvarId, ctorName) in info.toList do
|
||||
msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {ctorName.toList}"
|
||||
for (fvarId, info) in map.toList do
|
||||
msg := msg ++ indentD m!"{mkFVar fvarId} ↦ {info.ctorNames.toList}"
|
||||
return msg
|
||||
visit code info |>.run' {}
|
||||
visit code map |>.run' {}
|
||||
where
|
||||
visit (code : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) Code := do
|
||||
visit (code : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
return code.updateLet! decl (← visit k)
|
||||
|
|
@ -179,11 +220,10 @@ where
|
|||
let some code ← visitJmp? fvarId args | return code
|
||||
return code
|
||||
|
||||
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
let some s := (← read).find? decl.fvarId | return none
|
||||
if s.isEmpty then return none
|
||||
-- This join point satisfies `isJp` and there jumps with constructors in `s` to it.
|
||||
let p := decl.params[0]!
|
||||
visitJp? (decl : FunDecl) (k : Code) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
let some info := (← read).find? decl.fvarId | return none
|
||||
if info.ctorNames.isEmpty then return none
|
||||
-- This join point satisfies `isJpCases?` and there are jumps with constructors in `info` to it.
|
||||
let (decls, cases) := extractJpCases decl.value
|
||||
let mut jpAltMap := {}
|
||||
let mut jpAltDecls := #[]
|
||||
|
|
@ -193,26 +233,24 @@ where
|
|||
| .default k =>
|
||||
let k ← visit k
|
||||
let explicitCtorNames := cases.getCtorNames
|
||||
if s.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then
|
||||
let jpAlt ← mkJpAlt decls p #[] k (default := true)
|
||||
if info.ctorNames.any fun ctorNameInJump => !explicitCtorNames.contains ctorNameInJump then
|
||||
let jpAlt ← mkJpAlt decls decl.params info.paramIdx #[] k (default := true)
|
||||
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
|
||||
eraseCode k
|
||||
for ctorNameInJmp in s do
|
||||
for ctorNameInJmp in info.ctorNames do
|
||||
unless explicitCtorNames.contains ctorNameInJmp do
|
||||
jpAltMap := jpAltMap.insert ctorNameInJmp jpAlt
|
||||
let args := if jpAlt.dependsOnDiscr then #[.fvar p.fvarId] else #[]
|
||||
let args := mkJmpArgsAtJp decl.params info.paramIdx #[] jpAlt.dependsOnDiscr
|
||||
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
|
||||
else
|
||||
altsNew := altsNew.push (alt.updateCode k)
|
||||
| .alt ctorName fields k =>
|
||||
let k ← visit k
|
||||
if s.contains ctorName then
|
||||
let jpAlt ← mkJpAlt decls p fields k (default := false)
|
||||
if info.ctorNames.contains ctorName then
|
||||
let jpAlt ← mkJpAlt decls decl.params info.paramIdx fields k (default := false)
|
||||
jpAltDecls := jpAltDecls.push (.jp jpAlt.decl)
|
||||
jpAltMap := jpAltMap.insert ctorName jpAlt
|
||||
let mut args := fields.map (mkFVar ·.fvarId)
|
||||
if jpAlt.dependsOnDiscr then
|
||||
args := #[mkFVar p.fvarId] ++ args
|
||||
let args := mkJmpArgsAtJp decl.params info.paramIdx fields jpAlt.dependsOnDiscr
|
||||
eraseCode k
|
||||
altsNew := altsNew.push (alt.updateCode (.jmp jpAlt.decl.fvarId args))
|
||||
else
|
||||
|
|
@ -223,23 +261,18 @@ where
|
|||
let code := .jp decl (← visit k)
|
||||
return LCNF.attachCodeDecls jpAltDecls code
|
||||
|
||||
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfo (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
visitJmp? (fvarId : FVarId) (args : Array Expr) : ReaderT JpCasesInfoMap (StateRefT Ctor2JpCasesAlt CompilerM) (Option Code) := do
|
||||
let some ctorJpAltMap := (← get).find? fvarId | return none
|
||||
assert! args.size == 1
|
||||
let arg ← findExpr args[0]!
|
||||
let some info := (← read).find? fvarId | return none
|
||||
let arg ← findExpr args[info.paramIdx]!
|
||||
let some (ctorVal, ctorArgs) := arg.constructorApp? (← getEnv) (useRaw := true) | return none
|
||||
let some jpAlt := ctorJpAltMap.find? ctorVal.name | return none
|
||||
if jpAlt.default then
|
||||
if jpAlt.dependsOnDiscr then
|
||||
return some <| .jmp jpAlt.decl.fvarId args
|
||||
else
|
||||
return some <| .jmp jpAlt.decl.fvarId #[]
|
||||
else
|
||||
let fields := ctorArgs[ctorVal.numParams:]
|
||||
-- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too.
|
||||
-- We use a for-loop because we may have other special cases in the future.
|
||||
let mut auxDecls := #[]
|
||||
let mut fieldsNew := #[]
|
||||
let fields := ctorArgs[ctorVal.numParams:]
|
||||
-- Recall that if `arg` is a `Nat` literal, then `ctorArgs` is a literal too.
|
||||
-- We use a for-loop because we may have other special cases in the future.
|
||||
let mut auxDecls := #[]
|
||||
let mut fieldsNew := #[]
|
||||
unless jpAlt.default do
|
||||
for field in fields do
|
||||
if field.isFVar then
|
||||
fieldsNew := fieldsNew.push field
|
||||
|
|
@ -247,11 +280,8 @@ where
|
|||
let letDecl ← mkAuxLetDecl field
|
||||
auxDecls := auxDecls.push (CodeDecl.let letDecl)
|
||||
fieldsNew := fieldsNew.push (.fvar letDecl.fvarId)
|
||||
let code ← if jpAlt.dependsOnDiscr then
|
||||
pure <| .jmp jpAlt.decl.fvarId (args ++ fieldsNew)
|
||||
else
|
||||
pure <| .jmp jpAlt.decl.fvarId fieldsNew
|
||||
return some <| LCNF.attachCodeDecls auxDecls code
|
||||
let argsNew := mkJmpNewArgs args info.paramIdx fieldsNew jpAlt.dependsOnDiscr
|
||||
return some <| LCNF.attachCodeDecls auxDecls (.jmp jpAlt.decl.fvarId argsNew)
|
||||
|
||||
end Simp
|
||||
|
||||
|
|
|
|||
47
tests/lean/jpCasesNary.lean
Normal file
47
tests/lean/jpCasesNary.lean
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
set_option trace.Compiler.saveBase true in
|
||||
def f1 (c : Bool) (a b : Nat) :=
|
||||
let k d y z :=
|
||||
match d with
|
||||
| true => y + z + z*y
|
||||
| false => z + y
|
||||
match c with
|
||||
| true => k false a b
|
||||
| false => k true b a
|
||||
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def f2 (c : Bool) (a b : Nat) :=
|
||||
let k y d z :=
|
||||
match d with
|
||||
| true => y + z + z*y
|
||||
| false => z + y
|
||||
match c with
|
||||
| true => k a false b
|
||||
| false => k b true a
|
||||
|
||||
inductive C where
|
||||
| c1 | c2 | c3 | c4
|
||||
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def f3 (c c' : C) (a b : Nat) :=
|
||||
let k y (d : C) z :=
|
||||
match d with
|
||||
| C.c1 => y + z + z*y
|
||||
| _ => z + y + y
|
||||
match c with
|
||||
| .c1 => k a .c2 b
|
||||
| .c2 => k b .c1 a
|
||||
| .c3 => k b c' a
|
||||
| .c4 => k a c' a
|
||||
|
||||
set_option trace.Compiler.saveBase true in
|
||||
def f4 (c c' : C) (a b : Nat) :=
|
||||
let k y z (d : C) :=
|
||||
match d with
|
||||
| C.c1 => y + z + z*y
|
||||
| C.c3 => y*y+a
|
||||
| _ => z + y + y
|
||||
match c with
|
||||
| .c1 => k a b .c2
|
||||
| .c2 => k b b .c1
|
||||
| .c3 => k b a c'
|
||||
| .c4 => k a a c'
|
||||
78
tests/lean/jpCasesNary.lean.expected.out
Normal file
78
tests/lean/jpCasesNary.lean.expected.out
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
[Compiler.saveBase] size: 7
|
||||
def f1 c a b :=
|
||||
cases c
|
||||
| Bool.false =>
|
||||
let _x.1 := Nat.add b a
|
||||
let _x.2 := Nat.mul a b
|
||||
let _x.3 := Nat.add _x.1 _x.2
|
||||
_x.3
|
||||
| Bool.true =>
|
||||
let _x.4 := Nat.add b a
|
||||
_x.4
|
||||
[Compiler.saveBase] size: 7
|
||||
def f2 c a b :=
|
||||
cases c
|
||||
| Bool.false =>
|
||||
let _x.1 := Nat.add b a
|
||||
let _x.2 := Nat.mul a b
|
||||
let _x.3 := Nat.add _x.1 _x.2
|
||||
_x.3
|
||||
| Bool.true =>
|
||||
let _x.4 := Nat.add b a
|
||||
_x.4
|
||||
[Compiler.saveBase] size: 19
|
||||
def f3 c c' a b :=
|
||||
jp _jp.1 y z :=
|
||||
let _x.2 := Nat.add y z
|
||||
let _x.3 := Nat.mul z y
|
||||
let _x.4 := Nat.add _x.2 _x.3
|
||||
_x.4
|
||||
jp _jp.5 y z :=
|
||||
let _x.6 := Nat.add z y
|
||||
let _x.7 := Nat.add _x.6 y
|
||||
_x.7
|
||||
jp _jp.8 y d z :=
|
||||
cases d
|
||||
| C.c1 =>
|
||||
goto _jp.1 y z
|
||||
| _ =>
|
||||
goto _jp.5 y z
|
||||
cases c
|
||||
| C.c1 =>
|
||||
goto _jp.5 a b
|
||||
| C.c2 =>
|
||||
goto _jp.1 b a
|
||||
| C.c3 =>
|
||||
goto _jp.8 b c' a
|
||||
| C.c4 =>
|
||||
goto _jp.8 a c' a
|
||||
[Compiler.saveBase] size: 22
|
||||
def f4 c c' a b :=
|
||||
jp _jp.1 y z :=
|
||||
let _x.2 := Nat.add y z
|
||||
let _x.3 := Nat.mul z y
|
||||
let _x.4 := Nat.add _x.2 _x.3
|
||||
_x.4
|
||||
jp _jp.5 y z :=
|
||||
let _x.6 := Nat.add z y
|
||||
let _x.7 := Nat.add _x.6 y
|
||||
_x.7
|
||||
jp _jp.8 y z d :=
|
||||
cases d
|
||||
| C.c1 =>
|
||||
goto _jp.1 y z
|
||||
| C.c3 =>
|
||||
let _x.9 := Nat.mul y y
|
||||
let _x.10 := Nat.add _x.9 a
|
||||
_x.10
|
||||
| _ =>
|
||||
goto _jp.5 y z
|
||||
cases c
|
||||
| C.c1 =>
|
||||
goto _jp.5 a b
|
||||
| C.c2 =>
|
||||
goto _jp.1 b b
|
||||
| C.c3 =>
|
||||
goto _jp.8 b a c'
|
||||
| C.c4 =>
|
||||
goto _jp.8 a a c'
|
||||
Loading…
Add table
Reference in a new issue