From c6c46df28540220104ea340b60d677eeac1971b4 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 23 May 2019 12:42:31 -0700 Subject: [PATCH] feat(library/init/lean/compiler/ir): develop expandresetreuse --- library/init/lean/compiler/ir/basic.lean | 6 + library/init/lean/compiler/ir/emitutil.lean | 1 - .../lean/compiler/ir/expandresetreuse.lean | 200 ++++++++++++++++-- 3 files changed, 193 insertions(+), 14 deletions(-) diff --git a/library/init/lean/compiler/ir/basic.lean b/library/init/lean/compiler/ir/basic.lean index 0279ddcdd6..b036a22f21 100644 --- a/library/init/lean/compiler/ir/basic.lean +++ b/library/init/lean/compiler/ir/basic.lean @@ -560,5 +560,11 @@ namespace VarIdSet instance : Inhabited VarIdSet := ⟨{}⟩ end VarIdSet +def mkIf (x : VarId) (t e : FnBody) : FnBody := +FnBody.case `Bool x [ + Alt.ctor {name := `Bool.false, cidx := 0, size := 0, usize := 0, ssize := 0} e, + Alt.ctor {name := `Bool.true, cidx := 1, size := 0, usize := 0, ssize := 0} t +].toArray + end IR end Lean diff --git a/library/init/lean/compiler/ir/emitutil.lean b/library/init/lean/compiler/ir/emitutil.lean index a01fb15029..dfb175aa48 100644 --- a/library/init/lean/compiler/ir/emitutil.lean +++ b/library/init/lean/compiler/ir/emitutil.lean @@ -78,7 +78,6 @@ abbrev VarTypeMap := HashMap VarId IRType abbrev JPParamsMap := HashMap JoinPointId (Array Param) namespace CollectMaps -/- Auxiliary monad for collecting Decl information -/ abbrev Collector := (VarTypeMap × JPParamsMap) → (VarTypeMap × JPParamsMap) @[inline] def collectVar (x : VarId) (t : IRType) : Collector | (vs, js) := (vs.insert x t, js) diff --git a/library/init/lean/compiler/ir/expandresetreuse.lean b/library/init/lean/compiler/ir/expandresetreuse.lean index 3d1e7acb0e..ccbfeaf308 100644 --- a/library/init/lean/compiler/ir/expandresetreuse.lean +++ b/library/init/lean/compiler/ir/expandresetreuse.lean @@ -7,11 +7,42 @@ prelude import init.control.state import init.control.reader import init.lean.compiler.ir.compilerm +import init.lean.compiler.ir.normids +import init.lean.compiler.ir.freevars namespace Lean namespace IR namespace ExpandResetReuse +/- Mapping from variable to projections -/ +abbrev ProjMap := HashMap VarId Expr +namespace CollectProjMap +abbrev Collector := ProjMap → ProjMap +@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector := +λ m, match v with + | Expr.proj _ _ := m.insert x v + | Expr.sproj _ _ _ := m.insert x v + | Expr.uproj _ _ := m.insert x v + | _ := m +local infix ` >> `:50 := Function.comp + +partial def collectFnBody : FnBody → Collector +| (FnBody.vdecl x _ v b) := collectVDecl x v >> collectFnBody b +| (FnBody.jdecl _ _ v b) := collectFnBody v >> collectFnBody b +| (FnBody.case _ _ alts) := λ s, alts.foldl (λ s alt, collectFnBody alt.body s) s +| e := if e.isTerminal then id else collectFnBody e.body +end CollectProjMap + +/- Create a mapping from variables to projections. + This function assumes variable ids have been normalized -/ +def mkProjMap (d : Decl) : ProjMap := +match d with +| Decl.fdecl _ _ _ b := CollectProjMap.collectFnBody b {} +| _ := {} + +structure Context := +(projMap : ProjMap) + /- Return true iff `x` is consumed in all branches of the current block. Here consumption means the block contains a `dec x` or `reuse x ...`. -/ partial def consumed (x : VarId) : FnBody → Bool @@ -23,29 +54,172 @@ partial def consumed (x : VarId) : FnBody → Bool | (FnBody.case _ _ alts) := alts.any $ λ alt, consumed alt.body | e := !e.isTerminal && consumed e.body -partial def expand (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : FnBody := --- dbgTrace ("FOUND " ++ toString x) $ λ _, -reshape bs (FnBody.vdecl x IRType.object (Expr.reset n y) b) +abbrev Mask := Array (Option VarId) -partial def searchAndExpand : FnBody → Array FnBody → FnBody +/- Auxiliary function for eraseProjIncFor -/ +partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnBody → Array FnBody × Mask +| bs mask keep := + let done (_ : Unit) := (bs ++ keep.reverse, mask) in + let keepInstr (b : FnBody) := eraseProjIncForAux bs.pop mask (keep.push b) in + if bs.size < 2 then done () + else + let b := bs.back in + match b with + | (FnBody.vdecl _ _ (Expr.sproj _ _ _) _) := keepInstr b + | (FnBody.vdecl _ _ (Expr.uproj _ _) _) := keepInstr b + | (FnBody.inc z n c _) := + if n == 0 then done () else + let b' := bs.get (bs.size - 2) in + match b' with + | (FnBody.vdecl w _ (Expr.proj i x) _) := + if w == z && y == x then + /- Found + ``` + let z := proj[i] y; + inc z n c + ``` + We keep `proj`, and `inc` when `n > 1` + -/ + let bs := bs.pop.pop in + let mask := mask.set i (some z) in + let keep := keep.push b' in + let keep := if n == 1 then keep else keep.push (FnBody.inc z (n-1) c FnBody.nil) in + eraseProjIncForAux bs mask keep + else done () + | other := done () + | other := done () + +/- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`. + Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/ +def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask := +eraseProjIncForAux y bs (mkArray n none) Array.empty + +/- Replace `reuse x ctor ...` with `ctor ...`, and remoce `dec x` -/ +partial def reuseToCtor (x : VarId) : FnBody → FnBody +| (FnBody.dec y n c b) := + if x == y then b -- n must be 1 since `x := reset ...` + else FnBody.dec y n c (reuseToCtor b) +| (FnBody.vdecl z t v b) := + match v with + | Expr.reuse y c u xs := + if x == y then FnBody.vdecl z t (Expr.ctor c xs) b + else FnBody.vdecl z t v (reuseToCtor b) + | _ := + FnBody.vdecl z t v (reuseToCtor b) +| (FnBody.case tid y alts) := + let alts := alts.hmap $ λ alt, alt.modifyBody reuseToCtor in + FnBody.case tid y alts +| e := + if e.isTerminal then e + else + let (instr, b) := e.split in + let b := reuseToCtor b in + instr <;> b + +/- +replace +``` +x := reset y; b +``` +with +``` +inc z_1; ...; inc z_i; dec y; b' +``` +where `z_i`'s are the variables in `mask`, +and `b'` is `b` where we removed `dec x` and replaced `reuse x ctor_i ...` with `ctor_i ...`. +-/ +def mkSlowPath (x y : VarId) (mask : Mask) (b : FnBody) : FnBody := +let b := reuseToCtor x b in +let b := FnBody.dec y 1 true b in +mask.foldl + (λ b m, match m with + | some z := FnBody.inc z 1 true b + | none := b) + b + +abbrev M := ReaderT Context (State Nat) +def mkFresh : M VarId := +do idx ← get, modify (+1), pure { idx := idx } + +def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody := +mask.size.mfold + (λ i b, + match mask.get i with + | some _ := pure b -- code took ownership of this field + | none := do + fld ← mkFresh, + pure (FnBody.vdecl fld IRType.object (Expr.proj i y) (FnBody.dec fld 1 true b))) + b + +/- +replace +``` +x := reset y; b +``` +with +``` +let f_i_1 := proj[i_1] y; +... +let f_i_k := proj[i_k] y; +b' +``` +where `i_j`s are the field indexes +that the code did not touch immediately before the reset. +That is `mask[j] == none`. +`b'` is `b` where `y` `dec x` is replaced with `del y`, +and `z := reuse x ctor_i ws; F` is replaced with +`set x i ws[i]` operations, and we replace `z` with `x` in `F` +-/ +def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody := +do +let b := FnBody.vdecl x IRType.object (Expr.reset mask.size y) b, -- todo +releaseUnreadFields y mask b + +-- Expand `bs; x := reset[n] y; b` +partial def expand (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M (Array FnBody × FnBody) := +do +let bOld := FnBody.vdecl x IRType.object (Expr.reset n y) b, +let (bs, mask) := eraseProjIncFor n y bs, +let bSlow := mkSlowPath x y mask b, +bFast ← mkFastPath x y mask b, +c ← mkFresh, +let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast) in +-- dbgTrace ("expand\n" ++ toString (reshape bs b)) $ λ _, +pure (bs, b) + +partial def searchAndExpand : FnBody → Array FnBody → M FnBody | d@(FnBody.vdecl x t (Expr.reset n y) b) bs := - if consumed x b then - expand bs x n y b + if consumed x b then do + (bs, b) ← expand bs x n y b, + pure $ reshape bs b -- TODO else searchAndExpand b (push bs d) -| (FnBody.case tid x alts) bs := - let alts := alts.hmap $ λ alt, alt.modifyBody $ λ b, searchAndExpand b Array.empty in - reshape bs (FnBody.case tid x alts) +| (FnBody.jdecl j xs v b) bs := do + v ← searchAndExpand v Array.empty, + searchAndExpand b (push bs (FnBody.jdecl j xs v FnBody.nil)) +| (FnBody.case tid x alts) bs := do + alts ← alts.hmmap $ λ alt, alt.mmodifyBody $ λ b, searchAndExpand b Array.empty, + pure $ reshape bs (FnBody.case tid x alts) | b bs := - if b.isTerminal then reshape bs b + if b.isTerminal then pure $ reshape bs b else searchAndExpand b.body (push bs b) +def main (d : Decl) : Decl := +let d := d.normalizeIds in +match d with +| (Decl.fdecl f xs t b) := + let m := mkProjMap d in + let nextIdx := d.maxIndex + 1 in + let b := (searchAndExpand b Array.empty { projMap := m }).run' nextIdx in + Decl.fdecl f xs t b +| d := d + end ExpandResetReuse /-- (Try to) expand `reset` and `reuse` instructions. -/ -def Decl.expandResetReuse : Decl → Decl -| (Decl.fdecl f xs t b) := Decl.fdecl f xs t (ExpandResetReuse.searchAndExpand b Array.empty) -| other := other +def Decl.expandResetReuse (d : Decl) : Decl := +d +-- ExpandResetReuse.main d end IR end Lean