lean4-htt/src/Lean/Compiler/IR/ExpandResetReuse.lean
2020-10-16 16:37:56 -07:00

289 lines
9.5 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#lang lean4
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.IR.CompilerM
import Lean.Compiler.IR.NormIds
import Lean.Compiler.IR.FreeVars
namespace Lean.IR.ExpandResetReuse
/- Mapping from a variable id to the projection expression (`proj`, `sproj` or `uproj`)
   that `CollectProjMap.collectVDecl` recorded for it. -/
abbrev ProjMap := Std.HashMap VarId Expr
namespace CollectProjMap

/- A collector takes the map built so far and returns the extended map. -/
abbrev Collector := ProjMap → ProjMap

/- Record `x ↦ v` when `v` is a `proj`, `sproj` or `uproj` expression;
   any other expression leaves the map untouched. -/
@[inline] def collectVDecl (x : VarId) (v : Expr) : Collector :=
  fun projs =>
    match v with
    | Expr.proj _ _    => projs.insert x v
    | Expr.sproj _ _ _ => projs.insert x v
    | Expr.uproj _ _   => projs.insert x v
    | _                => projs

/- Traverse a function body and collect every projection declaration in it.
   The continuation of an instruction is visited before the instruction itself,
   join point bodies and all `case` alternatives are visited as well. -/
partial def collectFnBody : FnBody → Collector
  | FnBody.vdecl x _ v rest      => fun projs => collectVDecl x v (collectFnBody rest projs)
  | FnBody.jdecl _ _ jpBody rest => fun projs => collectFnBody jpBody (collectFnBody rest projs)
  | FnBody.case _ _ _ alts       => fun projs => alts.foldl (fun acc alt => collectFnBody alt.body acc) projs
  | other                        => if other.isTerminal then id else collectFnBody other.body

end CollectProjMap
/- Build the variable-to-projection map of a declaration.
   Only `Decl.fdecl` carries a body to scan; any other declaration yields the empty map.
   This function assumes variable ids have been normalized. -/
def mkProjMap (d : Decl) : ProjMap :=
  let empty : ProjMap := {}
  match d with
  | Decl.fdecl _ _ _ body => CollectProjMap.collectFnBody body empty
  | _                     => empty
/- Read-only context of the pass.
   `projMap` is the variable-to-projection map; it is consulted by
   `isSelfSet`, `isSelfUSet` and `isSelfSSet`. -/
structure Context :=
(projMap : ProjMap)
/- Return true iff `x` is consumed on every path through the given block,
   where "consumed" means the path contains a `dec x` or a `reuse x ...`.
   For a `case`, every alternative has to consume `x`.
   A terminal instruction reached without meeting either one yields `false`. -/
partial def consumed (x : VarId) : FnBody → Bool
  | FnBody.vdecl _ _ (Expr.reuse tok _ _ _) rest => x == tok || consumed x rest
  | FnBody.vdecl _ _ _ rest                      => consumed x rest
  | FnBody.dec tok _ _ _ rest                    => x == tok || consumed x rest
  | FnBody.case _ _ _ alts                       => alts.all (fun alt => consumed x alt.body)
  | other                                        => if other.isTerminal then false else consumed x other.body
/- Per-field record of erased increments. `eraseProjIncFor` creates it with `n` entries set to `none`;
   entry `i` becomes `some z` when an `inc` on `z := proj[i] y` has been erased. -/
abbrev Mask := Array (Option VarId)
/- Auxiliary function for `eraseProjIncFor`.

   Scans `bs` backwards, starting at its last instruction:
   - an `sproj` or `uproj` declaration is moved to `keep` and the scan continues;
   - a pair `let z := proj[i] y; inc z n` with `n > 0` loses one increment,
     `mask[i]` is set to `some z`, and the scan continues;
   - anything else (or fewer than two instructions left) stops the scan.

   `keep` stores the instructions that were scanned but must stay, in *reverse*
   program order (latest instruction first). When the scan stops, `keep.reverse`
   is appended to the unscanned prefix of `bs`.

   Returns the rebuilt instruction array together with the updated mask. -/
partial def eraseProjIncForAux (y : VarId) : Array FnBody → Mask → Array FnBody → Array FnBody × Mask
  | bs, mask, keep =>
    let done (_ : Unit) := (bs ++ keep.reverse, mask)
    let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b)
    if bs.size < 2 then done ()
    else
      let b := bs.back
      match b with
      | (FnBody.vdecl _ _ (Expr.sproj _ _ _) _) => keepInstr b
      | (FnBody.vdecl _ _ (Expr.uproj _ _) _)   => keepInstr b
      | (FnBody.inc z n c p _) =>
        if n == 0 then done () else
        let b' := bs[bs.size - 2]
        match b' with
        | (FnBody.vdecl w _ (Expr.proj i x) _) =>
          if w == z && y == x then
            /- Found
               ```
               let z := proj[i] y
               inc z n c
               ```
               We keep `proj`, and `inc` when `n > 1`
            -/
            let bs   := bs.pop.pop
            let mask := mask.set! i (some z)
            /- `keep` is in reverse program order: the residual `inc` has to be pushed
               before the `proj`, so that after `keep.reverse` the declaration of `z`
               still precedes the `inc z`. -/
            let keep := if n == 1 then keep else keep.push (FnBody.inc z (n-1) c p FnBody.nil)
            let keep := keep.push b'
            eraseProjIncForAux y bs mask keep
          else done ()
        | _ => done ()
      | _ => done ()
/- Try to erase `inc` instructions on projections of `y` that occur at the tail of `bs`.
   Returns the updated `bs` and a mask with `n` entries; entry `i` is `some z` when an
   increment of `z := proj[i] y` has been removed, and `none` otherwise. -/
def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask :=
  let emptyMask : Mask := mkArray n none
  let nothingKept : Array FnBody := #[]
  eraseProjIncForAux y bs emptyMask nothingKept
/- Replace `reuse x ctor ...` with `ctor ...`, and remove `dec x`.
   The traversal of a path stops at the first `dec x` or `reuse x` it meets;
   all alternatives of a `case` are rewritten. -/
partial def reuseToCtor (x : VarId) : FnBody → FnBody
  | FnBody.dec tok n c p rest =>
    -- when `tok` is `x`, `n` must be 1 since `x := reset ...`
    if x == tok then rest
    else FnBody.dec tok n c p (reuseToCtor x rest)
  | FnBody.vdecl z ty val rest =>
    let unchanged (_ : Unit) := FnBody.vdecl z ty val (reuseToCtor x rest)
    match val with
    | Expr.reuse tok info _ args =>
      if x == tok then FnBody.vdecl z ty (Expr.ctor info args) rest
      else unchanged ()
    | _ => unchanged ()
  | FnBody.case tid scrut scrutTy alts =>
    FnBody.case tid scrut scrutTy (alts.map fun alt => alt.modifyBody (reuseToCtor x))
  | other =>
    if other.isTerminal then other
    else
      let (head, rest) := other.split
      head.setBody (reuseToCtor x rest)
/-
Build the slow path, i.e. the code used when `y` cannot be reused. Replace
```
x := reset y; b
```
with
```
inc z_1; ...; inc z_i; dec y; b'
```
where the `z_i`s are the variables stored in `mask` (one `inc` per `some` entry),
and `b'` is `b` where `dec x` was removed and every `reuse x ctor_i ...`
was replaced with `ctor_i ...`.
-/
def mkSlowPath (x y : VarId) (mask : Mask) (b : FnBody) : FnBody :=
  let addInc (acc : FnBody) (entry : Option VarId) : FnBody :=
    match entry with
    | some z => FnBody.inc z 1 true false acc
    | none   => acc
  mask.foldl addInc (FnBody.dec y 1 true false (reuseToCtor x b))
/- Monad of the pass: reads a `Context` and threads a `Nat` counter,
   which `mkFresh` uses as the index of the next fresh variable. -/
abbrev M := ReaderT Context (StateM Nat)
/- Return a variable whose index is the current counter value, and bump the counter by one. -/
def mkFresh : M VarId :=
  modifyGet (fun next => ({ idx := next }, next + 1))
/- For every index `i` with `mask[i] == none`, put
   ```
   let fld := proj[i] y; dec fld;
   ```
   in front of `b`, where `fld` is a fresh variable.
   Indices with `mask[i] == some _` are skipped: the code took ownership of those fields. -/
def releaseUnreadFields (y : VarId) (mask : Mask) (b : FnBody) : M FnBody :=
  let releaseField (i : Nat) (rest : FnBody) : M FnBody :=
    match mask.get! i with
    | some _ => pure rest
    | none   => do
      let fld ← mkFresh
      let decFld := FnBody.dec fld 1 true false rest
      pure $ FnBody.vdecl fld IRType.object (Expr.proj i y) decFld
  mask.size.foldM releaseField b
/- Put one `set y[i] := zs[i]` instruction in front of `b` for every index `i` of `zs`. -/
def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
  let setField (i : Nat) (rest : FnBody) : FnBody :=
    FnBody.set y i (zs.get! i) rest
  zs.size.fold setField b
/- Given `set x[i] := y`, return true iff `y` is a variable that the projection map
   records as `y := proj[i] x`. Arguments that are not variables yield `false`. -/
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
  let defn : Option Expr :=
    match y with
    | Arg.var v => ctx.projMap.find? v
    | _         => none
  match defn with
  | some (Expr.proj fieldIdx src) => fieldIdx == i && src == x
  | _                             => false
/- Given `uset x[i] := y`, return true iff the projection map records `y := uproj[i] x`. -/
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
  match ctx.projMap.find? y with
  | some (Expr.uproj fieldIdx src) =>
    if fieldIdx == i then src == x else false
  | _ => false
/- Given `sset x[n, i] := y`, return true iff the projection map records `y := sproj[n, i] x`. -/
def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Bool :=
  match ctx.projMap.find? y with
  | some (Expr.sproj m offset src) =>
    if n == m then offset == i && src == x else false
  | _ => false
/- Remove unnecessary `set/uset/sset` operations, i.e. those that write back into a field
   the very value that was projected from it (see `isSelfSet`, `isSelfUSet`, `isSelfSSet`).
   All alternatives of a `case` are processed. -/
partial def removeSelfSet (ctx : Context) : FnBody → FnBody
  | FnBody.set x i y rest =>
    let rest := removeSelfSet ctx rest
    if isSelfSet ctx x i y then rest else FnBody.set x i y rest
  | FnBody.uset x i y rest =>
    let rest := removeSelfSet ctx rest
    if isSelfUSet ctx x i y then rest else FnBody.uset x i y rest
  | FnBody.sset x n i y ty rest =>
    let rest := removeSelfSet ctx rest
    if isSelfSSet ctx x n i y then rest else FnBody.sset x n i y ty rest
  | FnBody.case tid scrut scrutTy alts =>
    FnBody.case tid scrut scrutTy (alts.map fun alt => alt.modifyBody (removeSelfSet ctx))
  | other =>
    if other.isTerminal then other
    else
      let (head, rest) := other.split
      head.setBody (removeSelfSet ctx rest)
/- Rewrite a block for the case where the memory cell of `y` is updated in place:
   - `dec x` becomes `del y`;
   - `z := reuse x ctor ws; F` becomes a sequence of `set y[i] := ws[i]` instructions
     (preceded by a `setTag` when the `reuse` carries the update flag) followed by `F`
     with `z` replaced by `y`; redundant sets are then dropped by `removeSelfSet`.
   The traversal of a path stops at the first `dec x` or `reuse x` it meets. -/
partial def reuseToSet (ctx : Context) (x y : VarId) : FnBody → FnBody
  | FnBody.dec z n c p rest =>
    if x == z then FnBody.del y rest
    else FnBody.dec z n c p (reuseToSet ctx x y rest)
  | FnBody.vdecl z ty val rest =>
    let unchanged (_ : Unit) := FnBody.vdecl z ty val (reuseToSet ctx x y rest)
    match val with
    | Expr.reuse tok info updtHeader args =>
      if x == tok then
        let renamed  := rest.replaceVar z y
        let withSets := setFields y args renamed
        let withTag  := if updtHeader then FnBody.setTag y info.cidx withSets else withSets
        removeSelfSet ctx withTag
      else unchanged ()
    | _ => unchanged ()
  | FnBody.case tid scrut scrutTy alts =>
    FnBody.case tid scrut scrutTy (alts.map fun alt => alt.modifyBody (reuseToSet ctx x y))
  | other =>
    if other.isTerminal then other
    else
      let (head, rest) := other.split
      head.setBody (reuseToSet ctx x y rest)
/-
Build the fast path, i.e. the code used when the memory cell of `y` is updated in place. Replace
```
x := reset y; b
```
with
```
let f_i_1 := proj[i_1] y; dec f_i_1;
...
let f_i_k := proj[i_k] y; dec f_i_k;
b'
```
where the `i_j`s are the field indexes that the code did not touch
immediately before the reset, that is, `mask[i_j] == none`.
`b'` is `b` where `dec x` is replaced with `del y`, and
`z := reuse x ctor_i ws; F` is replaced with `set y[i] := ws[i]` operations
followed by `F` with `z` replaced by `y`.
-/
def mkFastPath (x y : VarId) (mask : Mask) (b : FnBody) : M FnBody := do
  let ctx ← read
  releaseUnreadFields y mask (reuseToSet ctx x y b)
/- Expand `bs; x := reset[n] y; b`.

   The result is `bs'` followed by
   ```
   let c := isShared y;
   mkIf c bSlow bFast
   ```
   where `bs'` is `bs` after `eraseProjIncFor n y`, `bSlow` is produced by `mkSlowPath`,
   `bFast` by `mkFastPath`, and `c` is a fresh variable.
   `mainFn` is the driver used to keep expanding `reset`s inside the fast path. -/
partial def expand (mainFn : FnBody → Array FnBody → M FnBody)
    (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M FnBody := do
  let (bs, mask) := eraseProjIncFor n y bs
  /- Remark: we may be duplicating variable/JP indices. That is, `bSlow` and `bFast` may
     have duplicate indices. We run `normalizeIds` to fix the ids after we have expanded them. -/
  let bSlow := mkSlowPath x y mask b
  let bFast ← mkFastPath x y mask b
  /- We only optimize the fast path recursively. -/
  let bFast ← mainFn bFast #[]
  let c ← mkFresh
  let b := FnBody.vdecl c IRType.uint8 (Expr.isShared y) (mkIf c bSlow bFast)
  pure $ reshape bs b
/- Walk a function body looking for `x := reset[n] y; b` where `x` is consumed on every
   path of `b`, and expand each such occurrence with `expand`.
   The second argument accumulates (via `push`) the instructions already walked in the
   current block; it is turned back into a body with `reshape`.
   Join point bodies and `case` alternatives are processed with a fresh, empty accumulator. -/
partial def searchAndExpand : FnBody → Array FnBody → M FnBody
  | decl@(FnBody.vdecl x _ (Expr.reset n y) rest), seen =>
    if consumed x rest then
      expand searchAndExpand seen x n y rest
    else
      searchAndExpand rest (push seen decl)
  | FnBody.jdecl j params jpBody rest, seen => do
    let jpBody ← searchAndExpand jpBody #[]
    let jp := FnBody.jdecl j params jpBody FnBody.nil
    searchAndExpand rest (push seen jp)
  | FnBody.case tid scrut scrutTy alts, seen => do
    let alts ← alts.mapM (fun alt => alt.mmodifyBody (fun altBody => searchAndExpand altBody #[]))
    pure (reshape seen (FnBody.case tid scrut scrutTy alts))
  | other, seen =>
    if other.isTerminal then pure (reshape seen other)
    else searchAndExpand other.body (push seen other)
/- Entry point of the pass. For a function declaration, run `searchAndExpand` on its body
   with the projection map of `d` as context and `d.maxIndex + 1` as the first fresh index.
   Any other declaration is returned unchanged. -/
def main (d : Decl) : Decl :=
  match d with
  | Decl.fdecl f xs t body =>
    let ctx : Context := { projMap := mkProjMap d }
    let firstFresh := d.maxIndex + 1
    let body' := (searchAndExpand body #[] ctx).run' firstFresh
    Decl.fdecl f xs t body'
  | other => other
end ExpandResetReuse
/-- (Try to) expand `reset` and `reuse` instructions, then normalize the ids of the result. -/
def Decl.expandResetReuse (d : Decl) : Decl :=
  let expanded := ExpandResetReuse.main d
  expanded.normalizeIds
end Lean.IR