refactor: adopt do notation (#9882)
This commit is contained in:
parent
6f7ca5e5d3
commit
639baaaa03
1 changed files with 39 additions and 20 deletions
|
|
@ -70,15 +70,16 @@ def initBorrow (ps : Array Param) : Array Param :=
|
|||
def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param :=
|
||||
if exported then ps else initBorrow ps
|
||||
|
||||
partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
|
||||
| .jdecl j xs v b => do
|
||||
partial def visitFnBody (fnid : FunId) (b : FnBody) : StateM ParamMap Unit := do
|
||||
match b with
|
||||
| .jdecl j xs v b =>
|
||||
modify fun m => m.insert (.jp fnid j) (initBorrow xs)
|
||||
visitFnBody fnid v
|
||||
visitFnBody fnid b
|
||||
| .case _ _ _ alts => alts.forM fun alt => visitFnBody fnid alt.body
|
||||
| e => do
|
||||
unless e.isTerminal do
|
||||
visitFnBody fnid e.body
|
||||
| _ => do
|
||||
unless b.isTerminal do
|
||||
visitFnBody fnid b.body
|
||||
|
||||
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
|
||||
decls.forM fun decl => match decl with
|
||||
|
|
@ -231,19 +232,33 @@ def ownArgsIfParam (xs : Array Arg) : M Unit := do
|
|||
| .var x => if ctx.paramSet.contains x.idx then ownVar x
|
||||
| .erased => pure ()
|
||||
|
||||
def collectExpr (z : VarId) : Expr → M Unit
|
||||
| .reset _ x => ownVar z *> ownVar x
|
||||
| .reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys
|
||||
| .ctor _ xs => ownVar z *> ownArgsIfParam xs
|
||||
| .proj _ x => do
|
||||
def collectExpr (z : VarId) (e : Expr) : M Unit := do
|
||||
match e with
|
||||
| .reset _ x =>
|
||||
ownVar z
|
||||
ownVar x
|
||||
| .reuse x _ _ ys =>
|
||||
ownVar z
|
||||
ownVar x
|
||||
ownArgsIfParam ys
|
||||
| .ctor _ xs =>
|
||||
ownVar z
|
||||
ownArgsIfParam xs
|
||||
| .proj _ x =>
|
||||
if (← isOwned x) then ownVar z
|
||||
if (← isOwned z) then ownVar x
|
||||
| .fap g xs => do
|
||||
| .fap g xs =>
|
||||
let ps ← getParamInfo (.decl g)
|
||||
ownVar z *> ownArgsUsingParams xs ps
|
||||
| .ap x ys => ownVar z *> ownVar x *> ownArgs ys
|
||||
| .pap _ xs => ownVar z *> ownArgs xs
|
||||
| _ => pure ()
|
||||
ownVar z
|
||||
ownArgsUsingParams xs ps
|
||||
| .ap x ys =>
|
||||
ownVar z
|
||||
ownVar x
|
||||
ownArgs ys
|
||||
| .pap _ xs =>
|
||||
ownVar z
|
||||
ownArgs xs
|
||||
| _ => pure ()
|
||||
|
||||
def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do
|
||||
let ctx ← read
|
||||
|
|
@ -258,20 +273,24 @@ def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do
|
|||
def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx :=
|
||||
{ ctx with paramSet := ps.foldl (fun s p => s.insert p.x.idx) ctx.paramSet }
|
||||
|
||||
partial def collectFnBody : FnBody → M Unit
|
||||
| .jdecl j ys v b => do
|
||||
partial def collectFnBody (b : FnBody) : M Unit := do
|
||||
match b with
|
||||
| .jdecl j ys v b =>
|
||||
withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v)
|
||||
let ctx ← read
|
||||
updateParamMap (.jp ctx.currFn j)
|
||||
collectFnBody b
|
||||
| .vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b
|
||||
| .jmp j ys => do
|
||||
| .vdecl x _ v b =>
|
||||
collectFnBody b
|
||||
collectExpr x v
|
||||
preserveTailCall x v b
|
||||
| .jmp j ys =>
|
||||
let ctx ← read
|
||||
let ps ← getParamInfo (.jp ctx.currFn j)
|
||||
ownArgsUsingParams ys ps -- for making sure the join point can reuse
|
||||
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
|
||||
| .case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body
|
||||
| e => do unless e.isTerminal do collectFnBody e.body
|
||||
| _ => do unless b.isTerminal do collectFnBody b.body
|
||||
|
||||
partial def collectDecl : Decl → M Unit
|
||||
| .fdecl (f := f) (xs := ys) (body := b) .. =>
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue