fix: Borrow.lean
The following bugs have been fixed - The implementation for the `Expr.proj` rule did not match our paper nor `RC.lean`. The correct rule is: ``` | Expr.proj _ x => whenM (isOwned x) $ ownVar z ``` - We take the OwnsetSet (`O` in our paper) into account when computing the fixpoint. - `applyParamMap` and `mkInitParamMap` were not visiting the alternatives of `case` statements.
This commit is contained in:
parent
f5741af39d
commit
7148fc1078
1 changed files with 99 additions and 90 deletions
|
|
@ -12,39 +12,54 @@ import Init.Lean.Compiler.IR.NormIds
|
|||
namespace Lean
|
||||
namespace IR
|
||||
namespace Borrow
|
||||
|
||||
namespace OwnedSet
|
||||
abbrev Key := FunId × Index
|
||||
|
||||
def beq : Key → Key → Bool
|
||||
| (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂
|
||||
instance : HasBeq Key := ⟨beq⟩
|
||||
|
||||
def getHash : Key → USize
|
||||
| (f, x) => mixHash (hash f) (hash x)
|
||||
instance : Hashable Key := ⟨getHash⟩
|
||||
end OwnedSet
|
||||
|
||||
abbrev OwnedSet := HashMap OwnedSet.Key Unit
|
||||
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := s.insert k ()
|
||||
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := s.contains k
|
||||
|
||||
/- We perform borrow inference in a block of mutually recursive functions.
|
||||
Join points are viewed as local functions, and are identified using
|
||||
their local id + the name of the surrounding function.
|
||||
|
||||
We keep a mapping from function and joint points to parameters (`Array Param`).
|
||||
Recall that `Param` contains the field `borrow`.
|
||||
The type `Key` is the the key of this map. -/
|
||||
Recall that `Param` contains the field `borrow`. -/
|
||||
namespace ParamMap
|
||||
inductive Key
|
||||
| decl (name : FunId)
|
||||
| jp (name : FunId) (jpid : JoinPointId)
|
||||
|
||||
namespace Key
|
||||
def beq : Key → Key → Bool
|
||||
| decl n₁, decl n₂ => n₁ == n₂
|
||||
| jp n₁ id₁, jp n₂ id₂ => n₁ == n₂ && id₁ == id₂
|
||||
| _, _ => false
|
||||
| Key.decl n₁, Key.decl n₂ => n₁ == n₂
|
||||
| Key.jp n₁ id₁, Key.jp n₂ id₂ => n₁ == n₂ && id₁ == id₂
|
||||
| _, _ => false
|
||||
|
||||
instance : HasBeq Key := ⟨beq⟩
|
||||
|
||||
def getHash : Key → USize
|
||||
| decl n => hash n
|
||||
| jp n id => mixHash (hash n) (hash id)
|
||||
| Key.decl n => hash n
|
||||
| Key.jp n id => mixHash (hash n) (hash id)
|
||||
|
||||
instance : Hashable Key := ⟨getHash⟩
|
||||
end Key
|
||||
end ParamMap
|
||||
|
||||
abbrev ParamMap := HashMap Key (Array Param)
|
||||
abbrev ParamMap := HashMap ParamMap.Key (Array Param)
|
||||
|
||||
def ParamMap.fmt (map : ParamMap) : Format :=
|
||||
let fmts := map.fold (fun fmt k ps =>
|
||||
let k := match k with
|
||||
| Key.decl n => format n
|
||||
| Key.jp n id => format n ++ ":" ++ format id;
|
||||
| ParamMap.Key.decl n => format n
|
||||
| ParamMap.Key.jp n id => format n ++ ":" ++ format id;
|
||||
fmt ++ Format.line ++ k ++ " -> " ++ formatParams ps)
|
||||
Format.nil;
|
||||
"{" ++ (Format.nest 1 fmts) ++ "}"
|
||||
|
|
@ -53,7 +68,6 @@ instance : HasFormat ParamMap := ⟨ParamMap.fmt⟩
|
|||
instance : HasToString ParamMap := ⟨fun m => Format.pretty (format m)⟩
|
||||
|
||||
namespace InitParamMap
|
||||
|
||||
/- Mark parameters that take a reference as borrow -/
|
||||
def initBorrow (ps : Array Param) : Array Param :=
|
||||
ps.map $ fun p => { borrow := p.ty.isObj, .. p }
|
||||
|
|
@ -63,17 +77,17 @@ ps.map $ fun p => { borrow := p.ty.isObj, .. p }
|
|||
These wrappers use smart pointers such as `object_ref`.
|
||||
When writing a new wrapper we need to know whether an argument is a borrow
|
||||
inference or not.
|
||||
|
||||
We can revise this decision when we implement code for generating
|
||||
the wrappers automatically. -/
|
||||
def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param :=
|
||||
if exported then ps else initBorrow ps
|
||||
|
||||
partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
|
||||
| FnBody.jdecl j xs v b => do
|
||||
modify $ fun m => m.insert (Key.jp fnid j) (initBorrow xs);
|
||||
| FnBody.jdecl j xs v b => do
|
||||
modify $ fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs);
|
||||
visitFnBody v;
|
||||
visitFnBody b
|
||||
| FnBody.case _ _ _ alts => alts.forM $ fun alt => visitFnBody alt.body
|
||||
| e =>
|
||||
unless (e.isTerminal) $ do
|
||||
let (instr, b) := e.split;
|
||||
|
|
@ -83,7 +97,7 @@ def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :
|
|||
decls.forM $ fun decl => match decl with
|
||||
| Decl.fdecl f xs _ b => do
|
||||
let exported := isExport env f;
|
||||
modify $ fun m => m.insert (Key.decl f) (initBorrowIfNotExported exported xs);
|
||||
modify $ fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs);
|
||||
visitFnBody f b
|
||||
| _ => pure ()
|
||||
end InitParamMap
|
||||
|
|
@ -95,27 +109,29 @@ def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap :=
|
|||
recursive functions. -/
|
||||
namespace ApplyParamMap
|
||||
|
||||
partial def visitFnBody : FnBody → FunId → ParamMap → FnBody
|
||||
| FnBody.jdecl j xs v b, fnid, map =>
|
||||
let v := visitFnBody v fnid map;
|
||||
let b := visitFnBody b fnid map;
|
||||
match map.find? (Key.jp fnid j) with
|
||||
partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
|
||||
| FnBody.jdecl j xs v b =>
|
||||
let v := visitFnBody v;
|
||||
let b := visitFnBody b;
|
||||
match paramMap.find? (ParamMap.Key.jp fn j) with
|
||||
| some ys => FnBody.jdecl j ys v b
|
||||
| none => FnBody.jdecl j xs v b
|
||||
| e, fnid, map =>
|
||||
| none => unreachable!
|
||||
| FnBody.case tid x xType alts =>
|
||||
FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody visitFnBody
|
||||
| e =>
|
||||
if e.isTerminal then e
|
||||
else
|
||||
let (instr, b) := e.split;
|
||||
let b := visitFnBody b fnid map;
|
||||
let b := visitFnBody b;
|
||||
instr.setBody b
|
||||
|
||||
def visitDecls (decls : Array Decl) (map : ParamMap) : Array Decl :=
|
||||
def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
|
||||
decls.map $ fun decl => match decl with
|
||||
| Decl.fdecl f xs ty b =>
|
||||
let b := visitFnBody b f map;
|
||||
match map.find? (Key.decl f) with
|
||||
let b := visitFnBody f paramMap b;
|
||||
match paramMap.find? (ParamMap.Key.decl f) with
|
||||
| some xs => Decl.fdecl f xs ty b
|
||||
| none => Decl.fdecl f xs ty b
|
||||
| none => unreachable!
|
||||
| other => other
|
||||
|
||||
end ApplyParamMap
|
||||
|
|
@ -125,66 +141,71 @@ def applyParamMap (decls : Array Decl) (map : ParamMap) : Array Decl :=
|
|||
ApplyParamMap.visitDecls decls map
|
||||
|
||||
structure BorrowInfCtx :=
|
||||
(env : Environment)
|
||||
(currFn : FunId := arbitrary _) -- Function being analyzed.
|
||||
(env : Environment)
|
||||
(currFn : FunId := arbitrary _) -- Function being analyzed.
|
||||
(paramSet : IndexSet := {}) -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams`
|
||||
|
||||
structure BorrowInfState :=
|
||||
/- `map` is a mapping storing the inferred borrow annotations for all functions (and joint points) in a mutually recursive declaration. -/
|
||||
(map : ParamMap)
|
||||
/- Set of variables that must be `owned`. -/
|
||||
(owned : IndexSet := {})
|
||||
(modifiedOwned : Bool := false)
|
||||
(modifiedParamMap : Bool := false)
|
||||
(owned : OwnedSet := {})
|
||||
(modified : Bool := false)
|
||||
(paramMap : ParamMap)
|
||||
|
||||
abbrev M := ReaderT BorrowInfCtx (StateM BorrowInfState)
|
||||
|
||||
def markModifiedParamMap : M Unit :=
|
||||
modify $ fun s => { modifiedParamMap := true, .. s }
|
||||
def getCurrFn : M FunId := do
|
||||
ctx ← read;
|
||||
pure ctx.currFn
|
||||
|
||||
def ownVar (x : VarId) : M Unit :=
|
||||
def markModified : M Unit :=
|
||||
modify $ fun s => { modified := true, .. s }
|
||||
|
||||
def ownVar (x : VarId) : M Unit := do
|
||||
-- dbgTrace ("ownVar " ++ toString x) $ fun _ =>
|
||||
currFn ← getCurrFn;
|
||||
modify $ fun s =>
|
||||
if s.owned.contains x.idx then s
|
||||
else { owned := s.owned.insert x.idx, modifiedOwned := true, .. s }
|
||||
if s.owned.contains (currFn, x.idx) then s
|
||||
else { owned := s.owned.insert (currFn, x.idx), modified := true, .. s }
|
||||
|
||||
def ownArg (x : Arg) : M Unit :=
|
||||
match x with
|
||||
| (Arg.var x) => ownVar x
|
||||
| _ => pure ()
|
||||
| Arg.var x => ownVar x
|
||||
| _ => pure ()
|
||||
|
||||
def ownArgs (xs : Array Arg) : M Unit :=
|
||||
xs.forM ownArg
|
||||
|
||||
def isOwned (x : VarId) : M Bool := do
|
||||
s ← get;
|
||||
pure $ s.owned.contains x.idx
|
||||
currFn ← getCurrFn;
|
||||
s ← get;
|
||||
pure $ s.owned.contains (currFn, x.idx)
|
||||
|
||||
/- Updates `map[k]` using the current set of `owned` variables. -/
|
||||
def updateParamMap (k : Key) : M Unit := do
|
||||
def updateParamMap (k : ParamMap.Key) : M Unit := do
|
||||
currFn ← getCurrFn;
|
||||
s ← get;
|
||||
match s.map.find? k with
|
||||
match s.paramMap.find? k with
|
||||
| some ps => do
|
||||
ps ← ps.mapM $ fun (p : Param) =>
|
||||
if p.borrow && s.owned.contains p.x.idx then do
|
||||
markModifiedParamMap; pure { borrow := false, .. p }
|
||||
else
|
||||
pure p;
|
||||
modify $ fun s => { map := s.map.insert k ps, .. s }
|
||||
if !p.borrow then pure p
|
||||
else condM (isOwned p.x)
|
||||
(do markModified; pure { borrow := false, .. p })
|
||||
(pure p);
|
||||
modify $ fun s => { paramMap := s.paramMap.insert k ps, .. s }
|
||||
| none => pure ()
|
||||
|
||||
def getParamInfo (k : Key) : M (Array Param) := do
|
||||
def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
|
||||
s ← get;
|
||||
match s.map.find? k with
|
||||
match s.paramMap.find? k with
|
||||
| some ps => pure ps
|
||||
| none =>
|
||||
match k with
|
||||
| (Key.decl fn) => do
|
||||
| ParamMap.Key.decl fn => do
|
||||
ctx ← read;
|
||||
match findEnvDecl ctx.env fn with
|
||||
| some decl => pure decl.params
|
||||
| none => pure #[] -- unreachable if well-formed input
|
||||
| _ => pure #[] -- unreachable if well-formed input
|
||||
| none => unreachable!
|
||||
| _ => unreachable!
|
||||
|
||||
/- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
|
||||
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
|
||||
|
|
@ -214,8 +235,7 @@ xs.size.forM $ fun i => do
|
|||
def f (x y : obj) :=
|
||||
let z := ctor_1 x y;
|
||||
ret z
|
||||
```
|
||||
-/
|
||||
``` -/
|
||||
def ownArgsIfParam (xs : Array Arg) : M Unit := do
|
||||
ctx ← read;
|
||||
xs.forM $ fun x =>
|
||||
|
|
@ -227,8 +247,10 @@ def collectExpr (z : VarId) : Expr → M Unit
|
|||
| Expr.reset _ x => ownVar z *> ownVar x
|
||||
| Expr.reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys
|
||||
| Expr.ctor _ xs => ownVar z *> ownArgsIfParam xs
|
||||
| Expr.proj _ x => whenM (isOwned z) $ ownVar x
|
||||
| Expr.fap g xs => do ps ← getParamInfo (Key.decl g);
|
||||
| Expr.proj _ x => do
|
||||
whenM (isOwned x) $ ownVar z;
|
||||
whenM (isOwned z) $ ownVar x
|
||||
| Expr.fap g xs => do ps ← getParamInfo (ParamMap.Key.decl g);
|
||||
-- dbgTrace ("collectExpr: " ++ toString g ++ " " ++ toString (formatParams ps)) $ fun _ =>
|
||||
ownVar z *> ownArgsUsingParams xs ps
|
||||
| Expr.ap x ys => ownVar z *> ownVar x *> ownArgs ys
|
||||
|
|
@ -241,7 +263,7 @@ match v, b with
|
|||
| (Expr.fap g ys), (FnBody.ret (Arg.var z)) =>
|
||||
when (ctx.currFn == g && x == z) $ do
|
||||
-- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do
|
||||
ps ← getParamInfo (Key.decl g);
|
||||
ps ← getParamInfo (ParamMap.Key.decl g);
|
||||
ownParamsUsingArgs ys ps
|
||||
| _, _ => pure ()
|
||||
|
||||
|
|
@ -252,58 +274,45 @@ partial def collectFnBody : FnBody → M Unit
|
|||
| FnBody.jdecl j ys v b => do
|
||||
adaptReader (fun ctx => updateParamSet ctx ys) (collectFnBody v);
|
||||
ctx ← read;
|
||||
updateParamMap (Key.jp ctx.currFn j);
|
||||
updateParamMap (ParamMap.Key.jp ctx.currFn j);
|
||||
collectFnBody b
|
||||
| FnBody.vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b
|
||||
| FnBody.jmp j ys => do
|
||||
ctx ← read;
|
||||
ps ← getParamInfo (Key.jp ctx.currFn j);
|
||||
ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j);
|
||||
ownArgsUsingParams ys ps; -- for making sure the join point can reuse
|
||||
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
|
||||
| FnBody.case _ _ _ alts => alts.forM $ fun alt => collectFnBody alt.body
|
||||
| e => unless (e.isTerminal) $ collectFnBody e.body
|
||||
|
||||
@[specialize] partial def whileModifingOwnedAux (x : M Unit) : Unit → M Unit
|
||||
| _ => do
|
||||
modify $ fun s => { modifiedOwned := false, .. s };
|
||||
x;
|
||||
s ← get;
|
||||
if s.modifiedOwned then whileModifingOwnedAux ()
|
||||
else pure ()
|
||||
|
||||
/- Keep executing `x` while it modifies ownedSet -/
|
||||
@[inline] def whileModifingOwned (x : M Unit) : M Unit :=
|
||||
whileModifingOwnedAux x ()
|
||||
|
||||
partial def collectDecl : Decl → M Unit
|
||||
| Decl.fdecl f ys _ b =>
|
||||
adaptReader (fun ctx => let ctx := updateParamSet ctx ys; { currFn := f, .. ctx }) $ do
|
||||
modify $ fun (s : BorrowInfState) => { owned := {}, .. s };
|
||||
whileModifingOwned (collectFnBody b);
|
||||
updateParamMap (Key.decl f)
|
||||
collectFnBody b;
|
||||
updateParamMap (ParamMap.Key.decl f)
|
||||
| _ => pure ()
|
||||
|
||||
@[specialize] partial def whileModifingParamMapAux (x : M Unit) : Unit → M Unit
|
||||
@[specialize] partial def whileModifingAux (x : M Unit) : Unit → M Unit
|
||||
| _ => do
|
||||
modify $ fun s => { modifiedParamMap := false, .. s };
|
||||
s ← get;
|
||||
modify $ fun s => { modified := false, .. s };
|
||||
-- s ← get;
|
||||
-- dbgTrace (toString s.map) $ fun _ => do
|
||||
x;
|
||||
s ← get;
|
||||
if s.modifiedParamMap then whileModifingParamMapAux ()
|
||||
if s.modified then whileModifingAux ()
|
||||
else pure ()
|
||||
|
||||
/- Keep executing `x` while it modifies paramMap -/
|
||||
@[inline] def whileModifingParamMap (x : M Unit) : M Unit :=
|
||||
whileModifingParamMapAux x ()
|
||||
/- Keep executing `x` until it reaches a fixpoint -/
|
||||
@[inline] def whileModifing (x : M Unit) : M Unit :=
|
||||
whileModifingAux x ()
|
||||
|
||||
def collectDecls (decls : Array Decl) : M ParamMap := do
|
||||
whileModifingParamMap (decls.forM collectDecl);
|
||||
whileModifing (decls.forM collectDecl);
|
||||
s ← get;
|
||||
pure s.map
|
||||
pure s.paramMap
|
||||
|
||||
def infer (env : Environment) (decls : Array Decl) : ParamMap :=
|
||||
(collectDecls decls { env := env }).run' { map := mkInitParamMap env decls }
|
||||
(collectDecls decls { env := env }).run' { paramMap := mkInitParamMap env decls }
|
||||
|
||||
end Borrow
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue