chore(library/init/lean/compiler/ir): clarify

This commit is contained in:
Leonardo de Moura 2019-05-06 10:52:33 -07:00
parent 51fe185989
commit 67d9f4cd1e
5 changed files with 104 additions and 89 deletions

View file

@ -19,7 +19,7 @@ namespace IR
def test (d : Decl) : IO Unit :=
do
IO.println d,
IO.println $ ("Max variable " ++ toString d.maxVar),
IO.println $ ("Max index " ++ toString d.maxIndex),
let d := d.pushProj,
IO.println "=== After push projections ===",
IO.println d,

View file

@ -17,7 +17,7 @@ partial def reshapeWithoutDeadAux : Array FnBody → FnBody → IndexSet → FnB
let curr := bs.back in
let bs := bs.pop in
let keep (_ : Unit) :=
let used := curr.collectFreeVars used in
let used := curr.collectFreeIndices used in
let b := curr <;> b in
reshapeWithoutDeadAux bs b used in
let keepIfUsed (vidx : Index) :=
@ -29,7 +29,7 @@ partial def reshapeWithoutDeadAux : Array FnBody → FnBody → IndexSet → FnB
| _ := keep ()
def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody :=
reshapeWithoutDeadAux bs term term.freeVars
reshapeWithoutDeadAux bs term term.freeIndices
partial def FnBody.elimDead : FnBody → FnBody
| b :=

View file

@ -9,10 +9,89 @@ import init.lean.compiler.ir.basic
namespace Lean
namespace IR
namespace MaxIndex
/- Compute the maximum index `M` used in a declaration.
We `M` to initialize the fresh index generator used to create fresh
variable and join point names.
Recall that we variable and join points share the same namespace in
our implementation.
-/
abbrev Collector := Index → Index
@[inline] private def skip : Collector := id
@[inline] private def collect (x : Index) : Collector := λ y, if x > y then x else y
@[inline] private def collectVar (x : VarId) : Collector := collect x.idx
@[inline] private def collectJP (j : JoinPointId) : Collector := collect j.idx
@[inline] private def seq (k₁ k₂ : Collector) : Collector := k₂ ∘ k₁
instance : HasAndthen Collector := ⟨seq⟩
private def collectArg : Arg → Collector
| (Arg.var x) := collectVar x
| irrelevant := skip
@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
λ m, as.foldl (λ m a, f a m) m
private def collectArgs (as : Array Arg) : Collector := collectArray as collectArg
private def collectParam (p : Param) : Collector := collectVar p.x
private def collectParams (ps : Array Param) : Collector := collectArray ps collectParam
private def collectExpr : Expr → Collector
| (Expr.ctor _ ys) := collectArgs ys
| (Expr.reset x) := collectVar x
| (Expr.reuse x _ _ ys) := collectVar x; collectArgs ys
| (Expr.proj _ x) := collectVar x
| (Expr.uproj _ x) := collectVar x
| (Expr.sproj _ _ x) := collectVar x
| (Expr.fap _ ys) := collectArgs ys
| (Expr.pap _ ys) := collectArgs ys
| (Expr.ap x ys) := collectVar x; collectArgs ys
| (Expr.box _ x) := collectVar x
| (Expr.unbox x) := collectVar x
| (Expr.lit v) := skip
| (Expr.isShared x) := collectVar x
| (Expr.isTaggedPtr x) := collectVar x
private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector :=
collectArray alts $ λ alt, f alt.body
partial def collectFnBody : FnBody → Collector
| (FnBody.vdecl x _ v b) := collectExpr v; collectFnBody b
| (FnBody.jdecl j ys _ v b) := collectFnBody v; collectParams ys; collectFnBody b
| (FnBody.set x _ y b) := collectVar x; collectVar y; collectFnBody b
| (FnBody.uset x _ y b) := collectVar x; collectVar y; collectFnBody b
| (FnBody.sset x _ _ y _ b) := collectVar x; collectVar y; collectFnBody b
| (FnBody.release x _ b) := collectVar x; collectFnBody b
| (FnBody.inc x _ _ b) := collectVar x; collectFnBody b
| (FnBody.dec x _ _ b) := collectVar x; collectFnBody b
| (FnBody.mdata _ b) := collectFnBody b
| (FnBody.case _ x alts) := collectVar x; collectAlts collectFnBody alts
| (FnBody.jmp j ys) := collectJP j; collectArgs ys
| (FnBody.ret x) := collectArg x
| FnBody.unreachable := skip
partial def collectDecl : Decl → Collector
| (Decl.fdecl _ xs _ b) := collectParams xs; collectFnBody b
| (Decl.extern _ xs _) := collectParams xs
end MaxIndex
def FnBody.maxIndex (b : FnBody) : Index :=
MaxIndex.collectFnBody b 0
def Decl.maxIndex (d : Decl) : Index :=
MaxIndex.collectDecl d 0
/-- Set of variable and join point names -/
abbrev IndexSet := RBTree Index (λ a b, a < b)
instance vsetInh : Inhabited IndexSet := ⟨{}⟩
namespace FreeVariables
namespace FreeIndices
/- We say a variable (join point) index (aka name) is free in a function body
if there isn't a `FnBody.vdecl` (`Fnbody.jdecl`) binding it. -/
abbrev Collector := IndexSet → IndexSet → IndexSet
@[inline] private def skip : Collector :=
@ -91,83 +170,19 @@ partial def collectFnBody : FnBody → Collector
| (FnBody.ret x) := collectArg x
| FnBody.unreachable := skip
end FreeVariables
end FreeIndices
def FnBody.collectFreeVars (b : FnBody) (vs : IndexSet) : IndexSet :=
FreeVariables.collectFnBody b {} vs
def FnBody.collectFreeIndices (b : FnBody) (vs : IndexSet) : IndexSet :=
FreeIndices.collectFnBody b {} vs
def FnBody.freeVars (b : FnBody) : IndexSet :=
b.collectFreeVars {}
namespace MaxVar
abbrev Collector := Index → Index
@[inline] private def skip : Collector := id
@[inline] private def collect (x : Index) : Collector := λ y, if x > y then x else y
@[inline] private def collectVar (x : VarId) : Collector := collect x.idx
@[inline] private def collectJP (j : JoinPointId) : Collector := collect j.idx
@[inline] private def seq (k₁ k₂ : Collector) : Collector := k₂ ∘ k₁
instance : HasAndthen Collector := ⟨seq⟩
private def collectArg : Arg → Collector
| (Arg.var x) := collectVar x
| irrelevant := skip
@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
λ m, as.foldl (λ m a, f a m) m
private def collectArgs (as : Array Arg) : Collector := collectArray as collectArg
private def collectParam (p : Param) : Collector := collectVar p.x
private def collectParams (ps : Array Param) : Collector := collectArray ps collectParam
private def collectExpr : Expr → Collector
| (Expr.ctor _ ys) := collectArgs ys
| (Expr.reset x) := collectVar x
| (Expr.reuse x _ _ ys) := collectVar x; collectArgs ys
| (Expr.proj _ x) := collectVar x
| (Expr.uproj _ x) := collectVar x
| (Expr.sproj _ _ x) := collectVar x
| (Expr.fap _ ys) := collectArgs ys
| (Expr.pap _ ys) := collectArgs ys
| (Expr.ap x ys) := collectVar x; collectArgs ys
| (Expr.box _ x) := collectVar x
| (Expr.unbox x) := collectVar x
| (Expr.lit v) := skip
| (Expr.isShared x) := collectVar x
| (Expr.isTaggedPtr x) := collectVar x
private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector :=
collectArray alts $ λ alt, f alt.body
partial def collectFnBody : FnBody → Collector
| (FnBody.vdecl x _ v b) := collectExpr v; collectFnBody b
| (FnBody.jdecl j ys _ v b) := collectFnBody v; collectParams ys; collectFnBody b
| (FnBody.set x _ y b) := collectVar x; collectVar y; collectFnBody b
| (FnBody.uset x _ y b) := collectVar x; collectVar y; collectFnBody b
| (FnBody.sset x _ _ y _ b) := collectVar x; collectVar y; collectFnBody b
| (FnBody.release x _ b) := collectVar x; collectFnBody b
| (FnBody.inc x _ _ b) := collectVar x; collectFnBody b
| (FnBody.dec x _ _ b) := collectVar x; collectFnBody b
| (FnBody.mdata _ b) := collectFnBody b
| (FnBody.case _ x alts) := collectVar x; collectAlts collectFnBody alts
| (FnBody.jmp j ys) := collectJP j; collectArgs ys
| (FnBody.ret x) := collectArg x
| FnBody.unreachable := skip
partial def collectDecl : Decl → Collector
| (Decl.fdecl _ xs _ b) := collectParams xs; collectFnBody b
| (Decl.extern _ xs _) := collectParams xs
end MaxVar
def FnBody.maxVar (b : FnBody) : Index :=
MaxVar.collectFnBody b 0
def Decl.maxVar (d : Decl) : Index :=
MaxVar.collectDecl d 0
def FnBody.freeIndices (b : FnBody) : IndexSet :=
b.collectFreeIndices {}
namespace HasIndex
/- In principle, we can check whether a function body `b` contains an index `i` using
`b.freeIndices.contains i`, but it is more efficient to avoid the construction
of the set of freeIndices and just search whether `i` occurs in `b` or not.
-/
def visitVar (w : Index) (x : VarId) : Bool := w == x.idx
def visitJP (w : Index) (x : JoinPointId) : Bool := w == x.idx

View file

@ -11,23 +11,23 @@ namespace Lean
namespace IR
partial def pushProjs : Array FnBody → Array Alt → Array IndexSet → Array FnBody → IndexSet → Array FnBody × Array Alt
| bs alts afvs ps vs :=
if bs.isEmpty then (ps.reverse, alts)
| bs alts altsF ctx ctxF :=
if bs.isEmpty then (ctx.reverse, alts)
else
let b := bs.back in
let bs := bs.pop in
let done (_ : Unit) := (bs.push b ++ ps.reverse, alts) in
let skip (_ : Unit) := pushProjs bs alts afvs (ps.push b) (b.collectFreeVars vs) in
let done (_ : Unit) := (bs.push b ++ ctx.reverse, alts) in
let skip (_ : Unit) := pushProjs bs alts altsF (ctx.push b) (b.collectFreeIndices ctxF) in
match b with
| FnBody.vdecl x t v _ :=
match v with
| Expr.proj _ _ :=
if !vs.contains x.idx && !afvs.all (λ s, s.contains x.idx) then
if !ctxF.contains x.idx && !altsF.all (λ s, s.contains x.idx) then
let alts := alts.hmapIdx $ λ i alt, alt.modifyBody $ λ b',
if (afvs.get i).contains x.idx then b <;> b'
if (altsF.get i).contains x.idx then b <;> b'
else b' in
let fvs := afvs.hmap $ λ s, if s.contains x.idx then b.collectFreeVars s else s in
pushProjs bs alts fvs ps vs
let altsF := altsF.hmap $ λ s, if s.contains x.idx then b.collectFreeIndices s else s in
pushProjs bs alts altsF ctx ctxF
else
skip ()
| Expr.uproj _ _ := skip ()
@ -41,8 +41,8 @@ partial def FnBody.pushProj : FnBody → FnBody
let bs := modifyJPs bs FnBody.pushProj in
match term with
| FnBody.case tid x alts :=
let afvs := alts.map $ λ alt, alt.body.freeVars in
let (bs, alts) := pushProjs bs alts afvs Array.empty {x.idx} in
let altsF := alts.map $ λ alt, alt.body.freeIndices in
let (bs, alts) := pushProjs bs alts altsF Array.empty {x.idx} in
let alts := alts.hmap $ λ alt, alt.modifyBody FnBody.pushProj in
let term := FnBody.case tid x alts in
reshape bs term

View file

@ -116,8 +116,8 @@ private partial def R : FnBody → M FnBody
def Decl.insertResetReuse : Decl → Decl
| d@(Decl.fdecl f xs t b) :=
let nextVar := d.maxVar + 1 in
let b := (R b).run' nextVar in
let nextIndex := d.maxIndex + 1 in
let b := (R b).run' nextIndex in
Decl.fdecl f xs t b
| other := other