feat(library/init/lean/compiler/ir): develop expandresetreuse

This commit is contained in:
Leonardo de Moura 2019-05-23 12:42:31 -07:00
parent f9f4e6c14b
commit c6c46df285
3 changed files with 193 additions and 14 deletions

View file

@ -560,5 +560,11 @@ namespace VarIdSet
instance : Inhabited VarIdSet := ⟨{}⟩
end VarIdSet
def mkIf (x : VarId) (t e : FnBody) : FnBody :=
FnBody.case `Bool x [
Alt.ctor {name := `Bool.false, cidx := 0, size := 0, usize := 0, ssize := 0} e,
Alt.ctor {name := `Bool.true, cidx := 1, size := 0, usize := 0, ssize := 0} t
].toArray
end IR
end Lean

View file

@ -78,7 +78,6 @@ abbrev VarTypeMap := HashMap VarId IRType
abbrev JPParamsMap := HashMap JoinPointId (Array Param)
namespace CollectMaps
/- Auxiliary monad for collecting Decl information -/
abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap)
@[inline] def collectVar (x : VarId) (t : IRType) : Collector
| (vs, js) := (vs.insert x t, js)

View file

@ -7,11 +7,42 @@ prelude
import init.control.state
import init.control.reader
import init.lean.compiler.ir.compilerm
import init.lean.compiler.ir.normids
import init.lean.compiler.ir.freevars
namespace Lean
namespace IR
namespace ExpandResetReuse
/- Mapping from variable to projections -/
abbrev ProjMap := HashMap VarId Expr
namespace CollectProjMap
abbrev Collector := ProjMap → ProjMap
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector :=
λ m, match v with
| Expr.proj _ _ := m.insert x v
| Expr.sproj _ _ _ := m.insert x v
| Expr.uproj _ _ := m.insert x v
| _ := m
local infix ` >> `:50 := Function.comp
partial def collectFnBody : FnBody → Collector
| (FnBody.vdecl x _ v b) := collectVDecl x v >> collectFnBody b
| (FnBody.jdecl _ _ v b) := collectFnBody v >> collectFnBody b
| (FnBody.case _ _ alts) := λ s, alts.foldl (λ s alt, collectFnBody alt.body s) s
| e := if e.isTerminal then id else collectFnBody e.body
end CollectProjMap
/- Create a mapping from variables to projections.
This function assumes variable ids have been normalized -/
def mkProjMap (d : Decl) : ProjMap :=
match d with
| Decl.fdecl _ _ _ b := CollectProjMap.collectFnBody b {}
| _ := {}
structure Context :=
(projMap : ProjMap)
/- Return true iff `x` is consumed in all branches of the current block.
Here consumption means the block contains a `dec x` or `reuse x ...`. -/
partial def consumed (x : VarId) : FnBody → Bool
@ -23,29 +54,172 @@ partial def consumed (x : VarId) : FnBody → Bool
| (FnBody.case _ _ alts) := alts.any $ λ alt, consumed alt.body
| e := !e.isTerminal && consumed e.body
partial def expand (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : FnBody :=
-- dbgTrace ("FOUND " ++ toString x) $ λ _,
reshape bs (FnBody.vdecl x IRType.object (Expr.reset n y) b)
abbrev Mask := Array (Option VarId)
partial def searchAndExpand : FnBody → Array FnBody → FnBody
/- Auxiliary function for eraseProjIncFor -/
partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnBody → Array FnBody × Mask
| bs mask keep :=
let done (_ : Unit) := (bs ++ keep.reverse, mask) in
let keepInstr (b : FnBody) := eraseProjIncForAux bs.pop mask (keep.push b) in
if bs.size < 2 then done ()
else
let b := bs.back in
match b with
| (FnBody.vdecl _ _ (Expr.sproj _ _ _) _) := keepInstr b
| (FnBody.vdecl _ _ (Expr.uproj _ _) _) := keepInstr b
| (FnBody.inc z n c _) :=
if n == 0 then done () else
let b' := bs.get (bs.size - 2) in
match b' with
| (FnBody.vdecl w _ (Expr.proj i x) _) :=
if w == z && y == x then
/- Found
```
let z := proj[i] y;
inc z n c
```
We keep `proj`, and `inc` when `n > 1`
-/
let bs := bs.pop.pop in
let mask := mask.set i (some z) in
let keep := keep.push b' in
let keep := if n == 1 then keep else keep.push (FnBody.inc z (n-1) c FnBody.nil) in
eraseProjIncForAux bs mask keep
else done ()
| other := done ()
| other := done ()
/- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`.
Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/
def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask :=
eraseProjIncForAux y bs (mkArray n none) Array.empty
/- Replace `reuse x ctor ...` with `ctor ...`, and remoce `dec x` -/
partial def reuseToCtor (x : VarId) : FnBody → FnBody
| (FnBody.dec y n c b) :=
if x == y then b -- n must be 1 since `x := reset ...`
else FnBody.dec y n c (reuseToCtor b)
| (FnBody.vdecl z t v b) :=
match v with
| Expr.reuse y c u xs :=
if x == y then FnBody.vdecl z t (Expr.ctor c xs) b
else FnBody.vdecl z t v (reuseToCtor b)
| _ :=
FnBody.vdecl z t v (reuseToCtor b)
| (FnBody.case tid y alts) :=
let alts := alts.hmap $ λ alt, alt.modifyBody reuseToCtor in
FnBody.case tid y alts
| e :=
if e.isTerminal then e
else
let (instr, b) := e.split in
let b := reuseToCtor b in
instr <;> b
/-
replace
```
x := reset y; b
```
with
```
inc z_1; ...; inc z_i; dec y; b'
```
where `z_i`'s are the variables in `mask`,
and `b'` is `b` where we removed `dec x` and replaced `reuse x ctor_i ...` with `ctor_i ...`.
-/
def mkSlowPath (x y : VarId) (mask : Mask) (b : FnBody) : FnBody :=
let b := reuseToCtor x b in
let b := FnBody.dec y 1 true b in
mask.foldl
(λ b m, match m with
| some z := FnBody.inc z 1 true b
| none := b)
b
abbrev M := ReaderT Context (State Nat)
def mkFresh : M VarId :=
do idx ← get, modify (+1), pure { idx := idx }
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
mask.size.mfold
(λ i b,
match mask.get i with
| some _ := pure b -- code took ownership of this field
| none := do
fld ← mkFresh,
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true b)))
b
/-
replace
```
x := reset y; b
```
with
```
let f_i_1 := proj[i_1] y;
...
let f_i_k := proj[i_k] y;
b'
```
where `i_j`s are the field indexes
that the code did not touch immediately before the reset.
That is `mask[j] == none`.
`b'` is `b` where `y` `dec x` is replaced with `del y`,
and `z := reuse x ctor_i ws; F` is replaced with
`set x i ws[i]` operations, and we replace `z` with `x` in `F`
-/
def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
do
let b := FnBody.vdecl x IRType.object (Expr.reset mask.size y) b, -- todo
releaseUnreadFields y mask b
-- Expand `bs; x := reset[n] y; b`
partial def expand (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M (Array FnBody × FnBody) :=
do
let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b,
let (bs, mask) := eraseProjIncFor n y bs,
let bSlow := mkSlowPath x y mask b,
bFast ← mkFastPath x y mask b,
c ← mkFresh,
let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast) in
-- dbgTrace ("expand\n" ++ toString (reshape bs b)) $ λ _,
pure (bs, b)
partial def searchAndExpand : FnBody → Array FnBody → M FnBody
| d@(FnBody.vdecl x t (Expr.reset n y) b) bs :=
if consumed x b then
expand bs x n y b
if consumed x b then do
(bs, b) ← expand bs x n y b,
pure $ reshape bs b -- TODO
else
searchAndExpand b (push bs d)
| (FnBody.case tid x alts) bs :=
let alts := alts.hmap $ λ alt, alt.modifyBody $ λ b, searchAndExpand b Array.empty in
reshape bs (FnBody.case tid x alts)
| (FnBody.jdecl j xs v b) bs := do
v ← searchAndExpand v Array.empty,
searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil))
| (FnBody.case tid x alts) bs := do
alts ← alts.hmmap $ λ alt, alt.mmodifyBody $ λ b, searchAndExpand b Array.empty,
pure $ reshape bs (FnBody.case tid x alts)
| b bs :=
if b.isTerminal then reshape bs b
if b.isTerminal then pure $ reshape bs b
else searchAndExpand b.body (push bs b)
def main (d : Decl) : Decl :=
let d := d.normalizeIds in
match d with
| (Decl.fdecl f xs t b) :=
let m := mkProjMap d in
let nextIdx := d.maxIndex + 1 in
let b := (searchAndExpand b Array.empty { projMap := m }).run' nextIdx in
Decl.fdecl f xs t b
| d := d
end ExpandResetReuse
/-- (Try to) expand `reset` and `reuse` instructions. -/
def Decl.expandResetReuse : Decl → Decl
| (Decl.fdecl f xs t b) := Decl.fdecl f xs t (ExpandResetReuse.searchAndExpand b Array.empty)
| other := other
def Decl.expandResetReuse (d : Decl) : Decl :=
d
-- ExpandResetReuse.main d
end IR
end Lean