refactor: use a state monad rather than combinators for computing free indices (#9711)
This commit is contained in:
parent
ae728d84f0
commit
78b941019b
1 changed files with 56 additions and 59 deletions
|
|
@ -108,87 +108,84 @@ namespace FreeIndices
|
|||
/-! We say a variable (join point) index (aka name) is free in a function body
|
||||
if there isn't a `FnBody.vdecl` (`Fnbody.jdecl`) binding it. -/
|
||||
|
||||
abbrev Collector := IndexSet → IndexSet → IndexSet
|
||||
structure State where
|
||||
freeIndices : IndexSet := {}
|
||||
|
||||
@[inline] private def skip : Collector :=
|
||||
fun _ fv => fv
|
||||
abbrev M := StateM State
|
||||
|
||||
@[inline] private def collectIndex (x : Index) : Collector :=
|
||||
fun bv fv => if bv.contains x then fv else fv.insert x
|
||||
private def visitIndex (x : Index) : M Unit := do
|
||||
modify fun s => { s with freeIndices := s.freeIndices.insert x }
|
||||
|
||||
@[inline] private def collectVar (x : VarId) : Collector :=
|
||||
collectIndex x.idx
|
||||
private def visitVar (x : VarId) : M Unit :=
|
||||
visitIndex x.idx
|
||||
|
||||
@[inline] private def collectJP (x : JoinPointId) : Collector :=
|
||||
collectIndex x.idx
|
||||
private def visitJP (j : JoinPointId) : M Unit :=
|
||||
visitIndex j.idx
|
||||
|
||||
@[inline] private def withIndex (x : Index) : Collector → Collector :=
|
||||
fun k bv fv => k (bv.insert x) fv
|
||||
private def visitArg (arg : Arg) : M Unit :=
|
||||
match arg with
|
||||
| .var x => visitVar x
|
||||
| .erased => pure ()
|
||||
|
||||
@[inline] private def withVar (x : VarId) : Collector → Collector :=
|
||||
withIndex x.idx
|
||||
private def visitParam (p : Param) : M Unit :=
|
||||
visitVar p.x
|
||||
|
||||
@[inline] private def withJP (x : JoinPointId) : Collector → Collector :=
|
||||
withIndex x.idx
|
||||
|
||||
def insertParams (s : IndexSet) (ys : Array Param) : IndexSet :=
|
||||
ys.foldl (init := s) fun s p => s.insert p.x.idx
|
||||
|
||||
@[inline] private def withParams (ys : Array Param) : Collector → Collector :=
|
||||
fun k bv fv => k (insertParams bv ys) fv
|
||||
|
||||
@[inline] private def seq : Collector → Collector → Collector :=
|
||||
fun k₁ k₂ bv fv => k₂ bv (k₁ bv fv)
|
||||
|
||||
instance : AndThen Collector where
|
||||
andThen a b := private seq a (b ())
|
||||
|
||||
private def collectArg : Arg → Collector
|
||||
| .var x => collectVar x
|
||||
| .erased => skip
|
||||
|
||||
private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector :=
|
||||
fun bv fv => as.foldl (fun fv a => f a bv fv) fv
|
||||
|
||||
private def collectArgs (as : Array Arg) : Collector :=
|
||||
collectArray as collectArg
|
||||
|
||||
private def collectExpr : Expr → Collector
|
||||
private def visitExpr (e : Expr) : M Unit := do
|
||||
match e with
|
||||
| .proj _ x | .uproj _ x | .sproj _ _ x | .box _ x | .unbox x | .reset _ x | .isShared x =>
|
||||
collectVar x
|
||||
visitVar x
|
||||
| .ctor _ ys | .fap _ ys | .pap _ ys =>
|
||||
collectArgs ys
|
||||
ys.forM visitArg
|
||||
| .ap x ys | .reuse x _ _ ys =>
|
||||
collectVar x >> collectArgs ys
|
||||
| .lit _ => skip
|
||||
visitVar x
|
||||
ys.forM visitArg
|
||||
| .lit _ => pure ()
|
||||
|
||||
private def collectAlts (f : FnBody → Collector) (alts : Array Alt) : Collector :=
|
||||
collectArray alts fun alt => f alt.body
|
||||
|
||||
partial def collectFnBody : FnBody → Collector
|
||||
partial def visitFnBody (fnBody : FnBody) : M Unit := do
|
||||
match fnBody with
|
||||
| .vdecl x _ v b =>
|
||||
collectExpr v >> withVar x (collectFnBody b)
|
||||
visitVar x
|
||||
visitExpr v
|
||||
visitFnBody b
|
||||
| .jdecl j ys v b =>
|
||||
withParams ys (collectFnBody v) >> withJP j (collectFnBody b)
|
||||
visitJP j
|
||||
visitFnBody v
|
||||
ys.forM visitParam
|
||||
visitFnBody b
|
||||
| .set x _ y b =>
|
||||
collectVar x >> collectArg y >> collectFnBody b
|
||||
visitVar x
|
||||
visitArg y
|
||||
visitFnBody b
|
||||
| .uset x _ y b | .sset x _ _ y _ b =>
|
||||
collectVar x >> collectVar y >> collectFnBody b
|
||||
visitVar x
|
||||
visitVar y
|
||||
visitFnBody b
|
||||
| .setTag x _ b | .inc x _ _ _ b | .dec x _ _ _ b | .del x b =>
|
||||
collectVar x >> collectFnBody b
|
||||
visitVar x
|
||||
visitFnBody b
|
||||
| .case _ x _ alts =>
|
||||
collectVar x >>
|
||||
collectAlts collectFnBody alts
|
||||
visitVar x
|
||||
alts.forM (visitFnBody ·.body)
|
||||
| .jmp j ys =>
|
||||
collectJP j >> collectArgs ys
|
||||
visitJP j
|
||||
ys.forM visitArg
|
||||
| .ret x =>
|
||||
collectArg x
|
||||
| .unreachable => skip
|
||||
visitArg x
|
||||
| .unreachable => pure ()
|
||||
|
||||
private def visitDecl (decl : Decl) : M Unit := do
|
||||
match decl with
|
||||
| .fdecl (xs := xs) (body := b) .. =>
|
||||
xs.forM visitParam
|
||||
visitFnBody b
|
||||
| .extern (xs := xs) .. =>
|
||||
xs.forM visitParam
|
||||
|
||||
end FreeIndices
|
||||
|
||||
def FnBody.collectFreeIndices (b : FnBody) (vs : IndexSet) : IndexSet :=
|
||||
FreeIndices.collectFnBody b {} vs
|
||||
def FnBody.collectFreeIndices (b : FnBody) (init : IndexSet) : IndexSet := Id.run do
|
||||
let ⟨_, { freeIndices }⟩ := FreeIndices.visitFnBody b |>.run { freeIndices := init }
|
||||
return freeIndices
|
||||
|
||||
def FnBody.freeIndices (b : FnBody) : IndexSet :=
|
||||
b.collectFreeIndices {}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue