refactor: adopt do notation (#9882)

This commit is contained in:
Cameron Zwarich 2025-08-12 15:12:59 -07:00 committed by GitHub
parent 6f7ca5e5d3
commit 639baaaa03
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -70,15 +70,16 @@ def initBorrow (ps : Array Param) : Array Param :=
def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param :=
if exported then ps else initBorrow ps
partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
| .jdecl j xs v b => do
partial def visitFnBody (fnid : FunId) (b : FnBody) : StateM ParamMap Unit := do
match b with
| .jdecl j xs v b =>
modify fun m => m.insert (.jp fnid j) (initBorrow xs)
visitFnBody fnid v
visitFnBody fnid b
| .case _ _ _ alts => alts.forM fun alt => visitFnBody fnid alt.body
| e => do
unless e.isTerminal do
visitFnBody fnid e.body
| _ => do
unless b.isTerminal do
visitFnBody fnid b.body
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
decls.forM fun decl => match decl with
@ -231,19 +232,33 @@ def ownArgsIfParam (xs : Array Arg) : M Unit := do
| .var x => if ctx.paramSet.contains x.idx then ownVar x
| .erased => pure ()
def collectExpr (z : VarId) : Expr → M Unit
| .reset _ x => ownVar z *> ownVar x
| .reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys
| .ctor _ xs => ownVar z *> ownArgsIfParam xs
| .proj _ x => do
def collectExpr (z : VarId) (e : Expr) : M Unit := do
match e with
| .reset _ x =>
ownVar z
ownVar x
| .reuse x _ _ ys =>
ownVar z
ownVar x
ownArgsIfParam ys
| .ctor _ xs =>
ownVar z
ownArgsIfParam xs
| .proj _ x =>
if (← isOwned x) then ownVar z
if (← isOwned z) then ownVar x
| .fap g xs => do
| .fap g xs =>
let ps ← getParamInfo (.decl g)
ownVar z *> ownArgsUsingParams xs ps
| .ap x ys => ownVar z *> ownVar x *> ownArgs ys
| .pap _ xs => ownVar z *> ownArgs xs
| _ => pure ()
ownVar z
ownArgsUsingParams xs ps
| .ap x ys =>
ownVar z
ownVar x
ownArgs ys
| .pap _ xs =>
ownVar z
ownArgs xs
| _ => pure ()
def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do
let ctx ← read
@ -258,20 +273,24 @@ def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do
def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx :=
{ ctx with paramSet := ps.foldl (fun s p => s.insert p.x.idx) ctx.paramSet }
partial def collectFnBody : FnBody → M Unit
| .jdecl j ys v b => do
partial def collectFnBody (b : FnBody) : M Unit := do
match b with
| .jdecl j ys v b =>
withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v)
let ctx ← read
updateParamMap (.jp ctx.currFn j)
collectFnBody b
| .vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b
| .jmp j ys => do
| .vdecl x _ v b =>
collectFnBody b
collectExpr x v
preserveTailCall x v b
| .jmp j ys =>
let ctx ← read
let ps ← getParamInfo (.jp ctx.currFn j)
ownArgsUsingParams ys ps -- for making sure the join point can reuse
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
| .case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body
| e => do unless e.isTerminal do collectFnBody e.body
| _ => do unless b.isTerminal do collectFnBody b.body
partial def collectDecl : Decl → M Unit
| .fdecl (f := f) (xs := ys) (body := b) .. =>