diff --git a/src/Lean/Compiler/IR/Borrow.lean b/src/Lean/Compiler/IR/Borrow.lean index 93979c8964..492785bd6d 100644 --- a/src/Lean/Compiler/IR/Borrow.lean +++ b/src/Lean/Compiler/IR/Borrow.lean @@ -70,15 +70,16 @@ def initBorrow (ps : Array Param) : Array Param := def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param := if exported then ps else initBorrow ps -partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit - | .jdecl j xs v b => do +partial def visitFnBody (fnid : FunId) (b : FnBody) : StateM ParamMap Unit := do + match b with + | .jdecl j xs v b => modify fun m => m.insert (.jp fnid j) (initBorrow xs) visitFnBody fnid v visitFnBody fnid b | .case _ _ _ alts => alts.forM fun alt => visitFnBody fnid alt.body - | e => do - unless e.isTerminal do - visitFnBody fnid e.body + | _ => do + unless b.isTerminal do + visitFnBody fnid b.body def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit := decls.forM fun decl => match decl with @@ -231,19 +232,33 @@ def ownArgsIfParam (xs : Array Arg) : M Unit := do | .var x => if ctx.paramSet.contains x.idx then ownVar x | .erased => pure () -def collectExpr (z : VarId) : Expr → M Unit - | .reset _ x => ownVar z *> ownVar x - | .reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys - | .ctor _ xs => ownVar z *> ownArgsIfParam xs - | .proj _ x => do +def collectExpr (z : VarId) (e : Expr) : M Unit := do + match e with + | .reset _ x => + ownVar z + ownVar x + | .reuse x _ _ ys => + ownVar z + ownVar x + ownArgsIfParam ys + | .ctor _ xs => + ownVar z + ownArgsIfParam xs + | .proj _ x => if (← isOwned x) then ownVar z if (← isOwned z) then ownVar x - | .fap g xs => do + | .fap g xs => let ps ← getParamInfo (.decl g) - ownVar z *> ownArgsUsingParams xs ps - | .ap x ys => ownVar z *> ownVar x *> ownArgs ys - | .pap _ xs => ownVar z *> ownArgs xs - | _ => pure () + ownVar z + ownArgsUsingParams xs ps + | .ap x ys => + ownVar z + ownVar x + ownArgs ys + | .pap _ xs => + ownVar z + ownArgs xs + | _ => pure () def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do let ctx ← read @@ -258,20 +273,24 @@ def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx := { ctx with paramSet := ps.foldl (fun s p => s.insert p.x.idx) ctx.paramSet } -partial def collectFnBody : FnBody → M Unit - | .jdecl j ys v b => do +partial def collectFnBody (b : FnBody) : M Unit := do + match b with + | .jdecl j ys v b => withReader (fun ctx => updateParamSet ctx ys) (collectFnBody v) let ctx ← read updateParamMap (.jp ctx.currFn j) collectFnBody b - | .vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b - | .jmp j ys => do + | .vdecl x _ v b => + collectFnBody b + collectExpr x v + preserveTailCall x v b + | .jmp j ys => let ctx ← read let ps ← getParamInfo (.jp ctx.currFn j) ownArgsUsingParams ys ps -- for making sure the join point can reuse ownParamsUsingArgs ys ps -- for making sure the tail call is preserved | .case _ _ _ alts => alts.forM fun alt => collectFnBody alt.body - | e => do unless e.isTerminal do collectFnBody e.body + | _ => do unless b.isTerminal do collectFnBody b.body partial def collectDecl : Decl → M Unit | .fdecl (f := f) (xs := ys) (body := b) .. =>