feat(library/init/lean/compiler/ir): develop expandresetreuse
This commit is contained in:
parent
f9f4e6c14b
commit
c6c46df285
3 changed files with 193 additions and 14 deletions
|
|
@ -560,5 +560,11 @@ namespace VarIdSet
|
|||
instance : Inhabited VarIdSet := ⟨{}⟩
|
||||
end VarIdSet
|
||||
|
||||
def mkIf (x : VarId) (t e : FnBody) : FnBody :=
|
||||
FnBody.case `Bool x [
|
||||
Alt.ctor {name := `Bool.false, cidx := 0, size := 0, usize := 0, ssize := 0} e,
|
||||
Alt.ctor {name := `Bool.true, cidx := 1, size := 0, usize := 0, ssize := 0} t
|
||||
].toArray
|
||||
|
||||
end IR
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -78,7 +78,6 @@ abbrev VarTypeMap := HashMap VarId IRType
|
|||
abbrev JPParamsMap := HashMap JoinPointId (Array Param)
|
||||
|
||||
namespace CollectMaps
|
||||
/- Auxiliary monad for collecting Decl information -/
|
||||
abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap)
|
||||
@[inline] def collectVar (x : VarId) (t : IRType) : Collector
|
||||
| (vs, js) := (vs.insert x t, js)
|
||||
|
|
|
|||
|
|
@ -7,11 +7,42 @@ prelude
|
|||
import init.control.state
|
||||
import init.control.reader
|
||||
import init.lean.compiler.ir.compilerm
|
||||
import init.lean.compiler.ir.normids
|
||||
import init.lean.compiler.ir.freevars
|
||||
|
||||
namespace Lean
|
||||
namespace IR
|
||||
namespace ExpandResetReuse
|
||||
|
||||
/- Mapping from variable to projections -/
|
||||
abbrev ProjMap := HashMap VarId Expr
|
||||
namespace CollectProjMap
|
||||
abbrev Collector := ProjMap → ProjMap
|
||||
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector :=
|
||||
λ m, match v with
|
||||
| Expr.proj _ _ := m.insert x v
|
||||
| Expr.sproj _ _ _ := m.insert x v
|
||||
| Expr.uproj _ _ := m.insert x v
|
||||
| _ := m
|
||||
local infix ` >> `:50 := Function.comp
|
||||
|
||||
partial def collectFnBody : FnBody → Collector
|
||||
| (FnBody.vdecl x _ v b) := collectVDecl x v >> collectFnBody b
|
||||
| (FnBody.jdecl _ _ v b) := collectFnBody v >> collectFnBody b
|
||||
| (FnBody.case _ _ alts) := λ s, alts.foldl (λ s alt, collectFnBody alt.body s) s
|
||||
| e := if e.isTerminal then id else collectFnBody e.body
|
||||
end CollectProjMap
|
||||
|
||||
/- Create a mapping from variables to projections.
|
||||
This function assumes variable ids have been normalized -/
|
||||
def mkProjMap (d : Decl) : ProjMap :=
|
||||
match d with
|
||||
| Decl.fdecl _ _ _ b := CollectProjMap.collectFnBody b {}
|
||||
| _ := {}
|
||||
|
||||
structure Context :=
|
||||
(projMap : ProjMap)
|
||||
|
||||
/- Return true iff `x` is consumed in all branches of the current block.
|
||||
Here consumption means the block contains a `dec x` or `reuse x ...`. -/
|
||||
partial def consumed (x : VarId) : FnBody → Bool
|
||||
|
|
@ -23,29 +54,172 @@ partial def consumed (x : VarId) : FnBody → Bool
|
|||
| (FnBody.case _ _ alts) := alts.any $ λ alt, consumed alt.body
|
||||
| e := !e.isTerminal && consumed e.body
|
||||
|
||||
partial def expand (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : FnBody :=
|
||||
-- dbgTrace ("FOUND " ++ toString x) $ λ _,
|
||||
reshape bs (FnBody.vdecl x IRType.object (Expr.reset n y) b)
|
||||
abbrev Mask := Array (Option VarId)
|
||||
|
||||
partial def searchAndExpand : FnBody → Array FnBody → FnBody
|
||||
/- Auxiliary function for eraseProjIncFor -/
|
||||
partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnBody → Array FnBody × Mask
|
||||
| bs mask keep :=
|
||||
let done (_ : Unit) := (bs ++ keep.reverse, mask) in
|
||||
let keepInstr (b : FnBody) := eraseProjIncForAux bs.pop mask (keep.push b) in
|
||||
if bs.size < 2 then done ()
|
||||
else
|
||||
let b := bs.back in
|
||||
match b with
|
||||
| (FnBody.vdecl _ _ (Expr.sproj _ _ _) _) := keepInstr b
|
||||
| (FnBody.vdecl _ _ (Expr.uproj _ _) _) := keepInstr b
|
||||
| (FnBody.inc z n c _) :=
|
||||
if n == 0 then done () else
|
||||
let b' := bs.get (bs.size - 2) in
|
||||
match b' with
|
||||
| (FnBody.vdecl w _ (Expr.proj i x) _) :=
|
||||
if w == z && y == x then
|
||||
/- Found
|
||||
```
|
||||
let z := proj[i] y;
|
||||
inc z n c
|
||||
```
|
||||
We keep `proj`, and `inc` when `n > 1`
|
||||
-/
|
||||
let bs := bs.pop.pop in
|
||||
let mask := mask.set i (some z) in
|
||||
let keep := keep.push b' in
|
||||
let keep := if n == 1 then keep else keep.push (FnBody.inc z (n-1) c FnBody.nil) in
|
||||
eraseProjIncForAux bs mask keep
|
||||
else done ()
|
||||
| other := done ()
|
||||
| other := done ()
|
||||
|
||||
/- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`.
|
||||
Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/
|
||||
def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask :=
|
||||
eraseProjIncForAux y bs (mkArray n none) Array.empty
|
||||
|
||||
/- Replace `reuse x ctor ...` with `ctor ...`, and remoce `dec x` -/
|
||||
partial def reuseToCtor (x : VarId) : FnBody → FnBody
|
||||
| (FnBody.dec y n c b) :=
|
||||
if x == y then b -- n must be 1 since `x := reset ...`
|
||||
else FnBody.dec y n c (reuseToCtor b)
|
||||
| (FnBody.vdecl z t v b) :=
|
||||
match v with
|
||||
| Expr.reuse y c u xs :=
|
||||
if x == y then FnBody.vdecl z t (Expr.ctor c xs) b
|
||||
else FnBody.vdecl z t v (reuseToCtor b)
|
||||
| _ :=
|
||||
FnBody.vdecl z t v (reuseToCtor b)
|
||||
| (FnBody.case tid y alts) :=
|
||||
let alts := alts.hmap $ λ alt, alt.modifyBody reuseToCtor in
|
||||
FnBody.case tid y alts
|
||||
| e :=
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split in
|
||||
let b := reuseToCtor b in
|
||||
instr <;> b
|
||||
|
||||
/-
|
||||
replace
|
||||
```
|
||||
x := reset y; b
|
||||
```
|
||||
with
|
||||
```
|
||||
inc z_1; ...; inc z_i; dec y; b'
|
||||
```
|
||||
where `z_i`'s are the variables in `mask`,
|
||||
and `b'` is `b` where we removed `dec x` and replaced `reuse x ctor_i ...` with `ctor_i ...`.
|
||||
-/
|
||||
def mkSlowPath (x y : VarId) (mask : Mask) (b : FnBody) : FnBody :=
|
||||
let b := reuseToCtor x b in
|
||||
let b := FnBody.dec y 1 true b in
|
||||
mask.foldl
|
||||
(λ b m, match m with
|
||||
| some z := FnBody.inc z 1 true b
|
||||
| none := b)
|
||||
b
|
||||
|
||||
abbrev M := ReaderT Context (State Nat)
|
||||
def mkFresh : M VarId :=
|
||||
do idx ← get, modify (+1), pure { idx := idx }
|
||||
|
||||
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
|
||||
mask.size.mfold
|
||||
(λ i b,
|
||||
match mask.get i with
|
||||
| some _ := pure b -- code took ownership of this field
|
||||
| none := do
|
||||
fld ← mkFresh,
|
||||
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true b)))
|
||||
b
|
||||
|
||||
/-
|
||||
replace
|
||||
```
|
||||
x := reset y; b
|
||||
```
|
||||
with
|
||||
```
|
||||
let f_i_1 := proj[i_1] y;
|
||||
...
|
||||
let f_i_k := proj[i_k] y;
|
||||
b'
|
||||
```
|
||||
where `i_j`s are the field indexes
|
||||
that the code did not touch immediately before the reset.
|
||||
That is `mask[j] == none`.
|
||||
`b'` is `b` where `y` `dec x` is replaced with `del y`,
|
||||
and `z := reuse x ctor_i ws; F` is replaced with
|
||||
`set x i ws[i]` operations, and we replace `z` with `x` in `F`
|
||||
-/
|
||||
def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
|
||||
do
|
||||
let b := FnBody.vdecl x IRType.object (Expr.reset mask.size y) b, -- todo
|
||||
releaseUnreadFields y mask b
|
||||
|
||||
-- Expand `bs; x := reset[n] y; b`
|
||||
partial def expand (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M (Array FnBody × FnBody) :=
|
||||
do
|
||||
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b,
|
||||
let (bs, mask) := eraseProjIncFor n y bs,
|
||||
let bSlow := mkSlowPath x y mask b,
|
||||
bFast ← mkFastPath x y mask b,
|
||||
c ← mkFresh,
|
||||
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast) in
|
||||
-- dbgTrace ("expand\n" ++ toString (reshape bs b)) $ λ _,
|
||||
pure (bs, b)
|
||||
|
||||
partial def searchAndExpand : FnBody → Array FnBody → M FnBody
|
||||
| d@(FnBody.vdecl x t (Expr.reset n y) b) bs :=
|
||||
if consumed x b then
|
||||
expand bs x n y b
|
||||
if consumed x b then do
|
||||
(bs, b) ← expand bs x n y b,
|
||||
pure $ reshape bs b -- TODO
|
||||
else
|
||||
searchAndExpand b (push bs d)
|
||||
| (FnBody.case tid x alts) bs :=
|
||||
let alts := alts.hmap $ λ alt, alt.modifyBody $ λ b, searchAndExpand b Array.empty in
|
||||
reshape bs (FnBody.case tid x alts)
|
||||
| (FnBody.jdecl j xs v b) bs := do
|
||||
v ← searchAndExpand v Array.empty,
|
||||
searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil))
|
||||
| (FnBody.case tid x alts) bs := do
|
||||
alts ← alts.hmmap $ λ alt, alt.mmodifyBody $ λ b, searchAndExpand b Array.empty,
|
||||
pure $ reshape bs (FnBody.case tid x alts)
|
||||
| b bs :=
|
||||
if b.isTerminal then reshape bs b
|
||||
if b.isTerminal then pure $ reshape bs b
|
||||
else searchAndExpand b.body (push bs b)
|
||||
|
||||
def main (d : Decl) : Decl :=
|
||||
let d := d.normalizeIds in
|
||||
match d with
|
||||
| (Decl.fdecl f xs t b) :=
|
||||
let m := mkProjMap d in
|
||||
let nextIdx := d.maxIndex + 1 in
|
||||
let b := (searchAndExpand b Array.empty { projMap := m }).run' nextIdx in
|
||||
Decl.fdecl f xs t b
|
||||
| d := d
|
||||
|
||||
end ExpandResetReuse
|
||||
|
||||
/-- (Try to) expand `reset` and `reuse` instructions. -/
|
||||
def Decl.expandResetReuse : Decl → Decl
|
||||
| (Decl.fdecl f xs t b) := Decl.fdecl f xs t (ExpandResetReuse.searchAndExpand b Array.empty)
|
||||
| other := other
|
||||
def Decl.expandResetReuse (d : Decl) : Decl :=
|
||||
d
|
||||
-- ExpandResetReuse.main d
|
||||
|
||||
end IR
|
||||
end Lean
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue