chore(library/init/lean/compiler/ir): clarify
This commit is contained in:
parent
51fe185989
commit
67d9f4cd1e
5 changed files with 104 additions and 89 deletions
|
|
@ -19,7 +19,7 @@ namespace IR
|
|||
def test (d : Decl) : IO Unit :=
|
||||
do
|
||||
IO.println d,
|
||||
IO.println $ ("Max variable " ++ toString d.maxVar),
|
||||
IO.println $ ("Max index " ++ toString d.maxIndex),
|
||||
let d := d.pushProj,
|
||||
IO.println "=== After push projections ===",
|
||||
IO.println d,
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ partial def reshapeWithoutDeadAux : Array FnBody → FnBody → IndexSet → FnB
|
|||
let curr := bs.back in
|
||||
let bs := bs.pop in
|
||||
let keep (_ : Unit) :=
|
||||
let used := curr.collectFreeVars used in
|
||||
let used := curr.collectFreeIndices used in
|
||||
let b := curr <;> b in
|
||||
reshapeWithoutDeadAux bs b used in
|
||||
let keepIfUsed (vidx : Index) :=
|
||||
|
|
@ -29,7 +29,7 @@ partial def reshapeWithoutDeadAux : Array FnBody → FnBody → IndexSet → FnB
|
|||
| _ := keep ()
|
||||
|
||||
def reshapeWithoutDead (bs : Array FnBody) (term : FnBody) : FnBody :=
|
||||
reshapeWithoutDeadAux bs term term.freeVars
|
||||
reshapeWithoutDeadAux bs term term.freeIndices
|
||||
|
||||
partial def FnBody.elimDead : FnBody → FnBody
|
||||
| b :=
|
||||
|
|
|
|||
|
|
@ -9,10 +9,89 @@ import init.lean.compiler.ir.basic
|
|||
namespace Lean
|
||||
namespace IR
|
||||
|
||||
namespace MaxIndex
|
||||
/- Compute the maximum index `M` used in a declaration.
|
||||
We `M` to initialize the fresh index generator used to create fresh
|
||||
variable and join point names.
|
||||
|
||||
Recall that we variable and join points share the same namespace in
|
||||
our implementation.
|
||||
-/
|
||||
|
||||
abbrev Collector := Index → Index
|
||||
|
||||
@[inline] private def skip : Collector := id
|
||||
@[inline] private def collect (x : Index) : Collector := λ y, if x > y then x else y
|
||||
@[inline] private def collectVar (x : VarId) : Collector := collect x.idx
|
||||
@[inline] private def collectJP (j : JoinPointId) : Collector := collect j.idx
|
||||
@[inline] private def seq (k₁ k₂ : Collector) : Collector := k₂ ∘ k₁
|
||||
instance : HasAndthen Collector := ⟨seq⟩
|
||||
|
||||
private def collectArg : Arg → Collector
|
||||
| (Arg.var x) := collectVar x
|
||||
| irrelevant := skip
|
||||
|
||||
@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
|
||||
λ m, as.foldl (λ m a, f a m) m
|
||||
|
||||
private def collectArgs (as : Array Arg) : Collector := collectArray as collectArg
|
||||
private def collectParam (p : Param) : Collector := collectVar p.x
|
||||
private def collectParams (ps : Array Param) : Collector := collectArray ps collectParam
|
||||
|
||||
private def collectExpr : Expr → Collector
|
||||
| (Expr.ctor _ ys) := collectArgs ys
|
||||
| (Expr.reset x) := collectVar x
|
||||
| (Expr.reuse x _ _ ys) := collectVar x; collectArgs ys
|
||||
| (Expr.proj _ x) := collectVar x
|
||||
| (Expr.uproj _ x) := collectVar x
|
||||
| (Expr.sproj _ _ x) := collectVar x
|
||||
| (Expr.fap _ ys) := collectArgs ys
|
||||
| (Expr.pap _ ys) := collectArgs ys
|
||||
| (Expr.ap x ys) := collectVar x; collectArgs ys
|
||||
| (Expr.box _ x) := collectVar x
|
||||
| (Expr.unbox x) := collectVar x
|
||||
| (Expr.lit v) := skip
|
||||
| (Expr.isShared x) := collectVar x
|
||||
| (Expr.isTaggedPtr x) := collectVar x
|
||||
|
||||
private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector :=
|
||||
collectArray alts $ λ alt, f alt.body
|
||||
|
||||
partial def collectFnBody : FnBody → Collector
|
||||
| (FnBody.vdecl x _ v b) := collectExpr v; collectFnBody b
|
||||
| (FnBody.jdecl j ys _ v b) := collectFnBody v; collectParams ys; collectFnBody b
|
||||
| (FnBody.set x _ y b) := collectVar x; collectVar y; collectFnBody b
|
||||
| (FnBody.uset x _ y b) := collectVar x; collectVar y; collectFnBody b
|
||||
| (FnBody.sset x _ _ y _ b) := collectVar x; collectVar y; collectFnBody b
|
||||
| (FnBody.release x _ b) := collectVar x; collectFnBody b
|
||||
| (FnBody.inc x _ _ b) := collectVar x; collectFnBody b
|
||||
| (FnBody.dec x _ _ b) := collectVar x; collectFnBody b
|
||||
| (FnBody.mdata _ b) := collectFnBody b
|
||||
| (FnBody.case _ x alts) := collectVar x; collectAlts collectFnBody alts
|
||||
| (FnBody.jmp j ys) := collectJP j; collectArgs ys
|
||||
| (FnBody.ret x) := collectArg x
|
||||
| FnBody.unreachable := skip
|
||||
|
||||
partial def collectDecl : Decl → Collector
|
||||
| (Decl.fdecl _ xs _ b) := collectParams xs; collectFnBody b
|
||||
| (Decl.extern _ xs _) := collectParams xs
|
||||
|
||||
end MaxIndex
|
||||
|
||||
def FnBody.maxIndex (b : FnBody) : Index :=
|
||||
MaxIndex.collectFnBody b 0
|
||||
|
||||
def Decl.maxIndex (d : Decl) : Index :=
|
||||
MaxIndex.collectDecl d 0
|
||||
|
||||
/-- Set of variable and join point names -/
|
||||
abbrev IndexSet := RBTree Index (λ a b, a < b)
|
||||
instance vsetInh : Inhabited IndexSet := ⟨{}⟩
|
||||
|
||||
namespace FreeVariables
|
||||
namespace FreeIndices
|
||||
/- We say a variable (join point) index (aka name) is free in a function body
|
||||
if there isn't a `FnBody.vdecl` (`Fnbody.jdecl`) binding it. -/
|
||||
|
||||
abbrev Collector := IndexSet → IndexSet → IndexSet
|
||||
|
||||
@[inline] private def skip : Collector :=
|
||||
|
|
@ -91,83 +170,19 @@ partial def collectFnBody : FnBody → Collector
|
|||
| (FnBody.ret x) := collectArg x
|
||||
| FnBody.unreachable := skip
|
||||
|
||||
end FreeVariables
|
||||
end FreeIndices
|
||||
|
||||
def FnBody.collectFreeVars (b : FnBody) (vs : IndexSet) : IndexSet :=
|
||||
FreeVariables.collectFnBody b {} vs
|
||||
def FnBody.collectFreeIndices (b : FnBody) (vs : IndexSet) : IndexSet :=
|
||||
FreeIndices.collectFnBody b {} vs
|
||||
|
||||
def FnBody.freeVars (b : FnBody) : IndexSet :=
|
||||
b.collectFreeVars {}
|
||||
|
||||
namespace MaxVar
|
||||
abbrev Collector := Index → Index
|
||||
|
||||
@[inline] private def skip : Collector := id
|
||||
@[inline] private def collect (x : Index) : Collector := λ y, if x > y then x else y
|
||||
@[inline] private def collectVar (x : VarId) : Collector := collect x.idx
|
||||
@[inline] private def collectJP (j : JoinPointId) : Collector := collect j.idx
|
||||
@[inline] private def seq (k₁ k₂ : Collector) : Collector := k₂ ∘ k₁
|
||||
instance : HasAndthen Collector := ⟨seq⟩
|
||||
|
||||
private def collectArg : Arg → Collector
|
||||
| (Arg.var x) := collectVar x
|
||||
| irrelevant := skip
|
||||
|
||||
@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
|
||||
λ m, as.foldl (λ m a, f a m) m
|
||||
|
||||
private def collectArgs (as : Array Arg) : Collector := collectArray as collectArg
|
||||
private def collectParam (p : Param) : Collector := collectVar p.x
|
||||
private def collectParams (ps : Array Param) : Collector := collectArray ps collectParam
|
||||
|
||||
private def collectExpr : Expr → Collector
|
||||
| (Expr.ctor _ ys) := collectArgs ys
|
||||
| (Expr.reset x) := collectVar x
|
||||
| (Expr.reuse x _ _ ys) := collectVar x; collectArgs ys
|
||||
| (Expr.proj _ x) := collectVar x
|
||||
| (Expr.uproj _ x) := collectVar x
|
||||
| (Expr.sproj _ _ x) := collectVar x
|
||||
| (Expr.fap _ ys) := collectArgs ys
|
||||
| (Expr.pap _ ys) := collectArgs ys
|
||||
| (Expr.ap x ys) := collectVar x; collectArgs ys
|
||||
| (Expr.box _ x) := collectVar x
|
||||
| (Expr.unbox x) := collectVar x
|
||||
| (Expr.lit v) := skip
|
||||
| (Expr.isShared x) := collectVar x
|
||||
| (Expr.isTaggedPtr x) := collectVar x
|
||||
|
||||
private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector :=
|
||||
collectArray alts $ λ alt, f alt.body
|
||||
|
||||
partial def collectFnBody : FnBody → Collector
|
||||
| (FnBody.vdecl x _ v b) := collectExpr v; collectFnBody b
|
||||
| (FnBody.jdecl j ys _ v b) := collectFnBody v; collectParams ys; collectFnBody b
|
||||
| (FnBody.set x _ y b) := collectVar x; collectVar y; collectFnBody b
|
||||
| (FnBody.uset x _ y b) := collectVar x; collectVar y; collectFnBody b
|
||||
| (FnBody.sset x _ _ y _ b) := collectVar x; collectVar y; collectFnBody b
|
||||
| (FnBody.release x _ b) := collectVar x; collectFnBody b
|
||||
| (FnBody.inc x _ _ b) := collectVar x; collectFnBody b
|
||||
| (FnBody.dec x _ _ b) := collectVar x; collectFnBody b
|
||||
| (FnBody.mdata _ b) := collectFnBody b
|
||||
| (FnBody.case _ x alts) := collectVar x; collectAlts collectFnBody alts
|
||||
| (FnBody.jmp j ys) := collectJP j; collectArgs ys
|
||||
| (FnBody.ret x) := collectArg x
|
||||
| FnBody.unreachable := skip
|
||||
|
||||
partial def collectDecl : Decl → Collector
|
||||
| (Decl.fdecl _ xs _ b) := collectParams xs; collectFnBody b
|
||||
| (Decl.extern _ xs _) := collectParams xs
|
||||
|
||||
end MaxVar
|
||||
|
||||
def FnBody.maxVar (b : FnBody) : Index :=
|
||||
MaxVar.collectFnBody b 0
|
||||
|
||||
def Decl.maxVar (d : Decl) : Index :=
|
||||
MaxVar.collectDecl d 0
|
||||
def FnBody.freeIndices (b : FnBody) : IndexSet :=
|
||||
b.collectFreeIndices {}
|
||||
|
||||
namespace HasIndex
|
||||
|
||||
/- In principle, we can check whether a function body `b` contains an index `i` using
|
||||
`b.freeIndices.contains i`, but it is more efficient to avoid the construction
|
||||
of the set of freeIndices and just search whether `i` occurs in `b` or not.
|
||||
-/
|
||||
def visitVar (w : Index) (x : VarId) : Bool := w == x.idx
|
||||
def visitJP (w : Index) (x : JoinPointId) : Bool := w == x.idx
|
||||
|
||||
|
|
|
|||
|
|
@ -11,23 +11,23 @@ namespace Lean
|
|||
namespace IR
|
||||
|
||||
partial def pushProjs : Array FnBody → Array Alt → Array IndexSet → Array FnBody → IndexSet → Array FnBody × Array Alt
|
||||
| bs alts afvs ps vs :=
|
||||
if bs.isEmpty then (ps.reverse, alts)
|
||||
| bs alts altsF ctx ctxF :=
|
||||
if bs.isEmpty then (ctx.reverse, alts)
|
||||
else
|
||||
let b := bs.back in
|
||||
let bs := bs.pop in
|
||||
let done (_ : Unit) := (bs.push b ++ ps.reverse, alts) in
|
||||
let skip (_ : Unit) := pushProjs bs alts afvs (ps.push b) (b.collectFreeVars vs) in
|
||||
let done (_ : Unit) := (bs.push b ++ ctx.reverse, alts) in
|
||||
let skip (_ : Unit) := pushProjs bs alts altsF (ctx.push b) (b.collectFreeIndices ctxF) in
|
||||
match b with
|
||||
| FnBody.vdecl x t v _ :=
|
||||
match v with
|
||||
| Expr.proj _ _ :=
|
||||
if !vs.contains x.idx && !afvs.all (λ s, s.contains x.idx) then
|
||||
if !ctxF.contains x.idx && !altsF.all (λ s, s.contains x.idx) then
|
||||
let alts := alts.hmapIdx $ λ i alt, alt.modifyBody $ λ b',
|
||||
if (afvs.get i).contains x.idx then b <;> b'
|
||||
if (altsF.get i).contains x.idx then b <;> b'
|
||||
else b' in
|
||||
let fvs := afvs.hmap $ λ s, if s.contains x.idx then b.collectFreeVars s else s in
|
||||
pushProjs bs alts fvs ps vs
|
||||
let altsF := altsF.hmap $ λ s, if s.contains x.idx then b.collectFreeIndices s else s in
|
||||
pushProjs bs alts altsF ctx ctxF
|
||||
else
|
||||
skip ()
|
||||
| Expr.uproj _ _ := skip ()
|
||||
|
|
@ -41,8 +41,8 @@ partial def FnBody.pushProj : FnBody → FnBody
|
|||
let bs := modifyJPs bs FnBody.pushProj in
|
||||
match term with
|
||||
| FnBody.case tid x alts :=
|
||||
let afvs := alts.map $ λ alt, alt.body.freeVars in
|
||||
let (bs, alts) := pushProjs bs alts afvs Array.empty {x.idx} in
|
||||
let altsF := alts.map $ λ alt, alt.body.freeIndices in
|
||||
let (bs, alts) := pushProjs bs alts altsF Array.empty {x.idx} in
|
||||
let alts := alts.hmap $ λ alt, alt.modifyBody FnBody.pushProj in
|
||||
let term := FnBody.case tid x alts in
|
||||
reshape bs term
|
||||
|
|
|
|||
|
|
@ -116,8 +116,8 @@ private partial def R : FnBody → M FnBody
|
|||
|
||||
def Decl.insertResetReuse : Decl → Decl
|
||||
| d@(Decl.fdecl f xs t b) :=
|
||||
let nextVar := d.maxVar + 1 in
|
||||
let b := (R b).run' nextVar in
|
||||
let nextIndex := d.maxIndex + 1 in
|
||||
let b := (R b).run' nextIndex in
|
||||
Decl.fdecl f xs t b
|
||||
| other := other
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue