fix: Borrow.lean

The following bugs have been fixed
- The implementation for the `Expr.proj` rule did not match our paper nor `RC.lean`.
The correct rule is:
```
| Expr.proj _ x       => whenM (isOwned x) $ ownVar z
```

- We take the OwnsetSet (`O` in our paper) into account when computing
the fixpoint.

- `applyParamMap` and `mkInitParamMap` were not visiting the alternatives of `case` statements.
This commit is contained in:
Leonardo de Moura 2019-12-22 18:21:12 -08:00
parent f5741af39d
commit 7148fc1078

View file

@ -12,39 +12,54 @@ import Init.Lean.Compiler.IR.NormIds
namespace Lean
namespace IR
namespace Borrow
namespace OwnedSet
abbrev Key := FunId × Index
def beq : Key → Key → Bool
| (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂
instance : HasBeq Key := ⟨beq⟩
def getHash : Key → USize
| (f, x) => mixHash (hash f) (hash x)
instance : Hashable Key := ⟨getHash⟩
end OwnedSet
abbrev OwnedSet := HashMap OwnedSet.Key Unit
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := s.insert k ()
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := s.contains k
/- We perform borrow inference in a block of mutually recursive functions.
Join points are viewed as local functions, and are identified using
their local id + the name of the surrounding function.
We keep a mapping from function and joint points to parameters (`Array Param`).
Recall that `Param` contains the field `borrow`.
The type `Key` is the the key of this map. -/
Recall that `Param` contains the field `borrow`. -/
namespace ParamMap
inductive Key
| decl (name : FunId)
| jp (name : FunId) (jpid : JoinPointId)
namespace Key
def beq : Key → Key → Bool
| decl n₁, decl n₂ => n₁ == n₂
| jp n₁ id₁, jp n₂ id₂ => n₁ == n₂ && id₁ == id₂
| _, _ => false
| Key.decl n₁, Key.decl n₂ => n₁ == n₂
| Key.jp n₁ id₁, Key.jp n₂ id₂ => n₁ == n₂ && id₁ == id₂
| _, _ => false
instance : HasBeq Key := ⟨beq⟩
def getHash : Key → USize
| decl n => hash n
| jp n id => mixHash (hash n) (hash id)
| Key.decl n => hash n
| Key.jp n id => mixHash (hash n) (hash id)
instance : Hashable Key := ⟨getHash⟩
end Key
end ParamMap
abbrev ParamMap := HashMap Key (Array Param)
abbrev ParamMap := HashMap ParamMap.Key (Array Param)
def ParamMap.fmt (map : ParamMap) : Format :=
let fmts := map.fold (fun fmt k ps =>
let k := match k with
| Key.decl n => format n
| Key.jp n id => format n ++ ":" ++ format id;
| ParamMap.Key.decl n => format n
| ParamMap.Key.jp n id => format n ++ ":" ++ format id;
fmt ++ Format.line ++ k ++ " -> " ++ formatParams ps)
Format.nil;
"{" ++ (Format.nest 1 fmts) ++ "}"
@ -53,7 +68,6 @@ instance : HasFormat ParamMap := ⟨ParamMap.fmt⟩
instance : HasToString ParamMap := ⟨fun m => Format.pretty (format m)⟩
namespace InitParamMap
/- Mark parameters that take a reference as borrow -/
def initBorrow (ps : Array Param) : Array Param :=
ps.map $ fun p => { borrow := p.ty.isObj, .. p }
@ -63,17 +77,17 @@ ps.map $ fun p => { borrow := p.ty.isObj, .. p }
These wrappers use smart pointers such as `object_ref`.
When writing a new wrapper we need to know whether an argument is a borrow
inference or not.
We can revise this decision when we implement code for generating
the wrappers automatically. -/
def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param :=
if exported then ps else initBorrow ps
partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit
| FnBody.jdecl j xs v b => do
modify $ fun m => m.insert (Key.jp fnid j) (initBorrow xs);
| FnBody.jdecl j xs v b => do
modify $ fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs);
visitFnBody v;
visitFnBody b
| FnBody.case _ _ _ alts => alts.forM $ fun alt => visitFnBody alt.body
| e =>
unless (e.isTerminal) $ do
let (instr, b) := e.split;
@ -83,7 +97,7 @@ def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :
decls.forM $ fun decl => match decl with
| Decl.fdecl f xs _ b => do
let exported := isExport env f;
modify $ fun m => m.insert (Key.decl f) (initBorrowIfNotExported exported xs);
modify $ fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs);
visitFnBody f b
| _ => pure ()
end InitParamMap
@ -95,27 +109,29 @@ def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap :=
recursive functions. -/
namespace ApplyParamMap
partial def visitFnBody : FnBody → FunId → ParamMap → FnBody
| FnBody.jdecl j xs v b, fnid, map =>
let v := visitFnBody v fnid map;
let b := visitFnBody b fnid map;
match map.find? (Key.jp fnid j) with
partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody
| FnBody.jdecl j xs v b =>
let v := visitFnBody v;
let b := visitFnBody b;
match paramMap.find? (ParamMap.Key.jp fn j) with
| some ys => FnBody.jdecl j ys v b
| none => FnBody.jdecl j xs v b
| e, fnid, map =>
| none => unreachable!
| FnBody.case tid x xType alts =>
FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody visitFnBody
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split;
let b := visitFnBody b fnid map;
let b := visitFnBody b;
instr.setBody b
def visitDecls (decls : Array Decl) (map : ParamMap) : Array Decl :=
def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
decls.map $ fun decl => match decl with
| Decl.fdecl f xs ty b =>
let b := visitFnBody b f map;
match map.find? (Key.decl f) with
let b := visitFnBody f paramMap b;
match paramMap.find? (ParamMap.Key.decl f) with
| some xs => Decl.fdecl f xs ty b
| none => Decl.fdecl f xs ty b
| none => unreachable!
| other => other
end ApplyParamMap
@ -125,66 +141,71 @@ def applyParamMap (decls : Array Decl) (map : ParamMap) : Array Decl :=
ApplyParamMap.visitDecls decls map
structure BorrowInfCtx :=
(env : Environment)
(currFn : FunId := arbitrary _) -- Function being analyzed.
(env : Environment)
(currFn : FunId := arbitrary _) -- Function being analyzed.
(paramSet : IndexSet := {}) -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams`
structure BorrowInfState :=
/- `map` is a mapping storing the inferred borrow annotations for all functions (and joint points) in a mutually recursive declaration. -/
(map : ParamMap)
/- Set of variables that must be `owned`. -/
(owned : IndexSet := {})
(modifiedOwned : Bool := false)
(modifiedParamMap : Bool := false)
(owned : OwnedSet := {})
(modified : Bool := false)
(paramMap : ParamMap)
abbrev M := ReaderT BorrowInfCtx (StateM BorrowInfState)
def markModifiedParamMap : M Unit :=
modify $ fun s => { modifiedParamMap := true, .. s }
def getCurrFn : M FunId := do
ctx ← read;
pure ctx.currFn
def ownVar (x : VarId) : M Unit :=
def markModified : M Unit :=
modify $ fun s => { modified := true, .. s }
def ownVar (x : VarId) : M Unit := do
-- dbgTrace ("ownVar " ++ toString x) $ fun _ =>
currFn ← getCurrFn;
modify $ fun s =>
if s.owned.contains x.idx then s
else { owned := s.owned.insert x.idx, modifiedOwned := true, .. s }
if s.owned.contains (currFn, x.idx) then s
else { owned := s.owned.insert (currFn, x.idx), modified := true, .. s }
def ownArg (x : Arg) : M Unit :=
match x with
| (Arg.var x) => ownVar x
| _ => pure ()
| Arg.var x => ownVar x
| _ => pure ()
def ownArgs (xs : Array Arg) : M Unit :=
xs.forM ownArg
def isOwned (x : VarId) : M Bool := do
s ← get;
pure $ s.owned.contains x.idx
currFn ← getCurrFn;
s ← get;
pure $ s.owned.contains (currFn, x.idx)
/- Updates `map[k]` using the current set of `owned` variables. -/
def updateParamMap (k : Key) : M Unit := do
def updateParamMap (k : ParamMap.Key) : M Unit := do
currFn ← getCurrFn;
s ← get;
match s.map.find? k with
match s.paramMap.find? k with
| some ps => do
ps ← ps.mapM $ fun (p : Param) =>
if p.borrow && s.owned.contains p.x.idx then do
markModifiedParamMap; pure { borrow := false, .. p }
else
pure p;
modify $ fun s => { map := s.map.insert k ps, .. s }
if !p.borrow then pure p
else condM (isOwned p.x)
(do markModified; pure { borrow := false, .. p })
(pure p);
modify $ fun s => { paramMap := s.paramMap.insert k ps, .. s }
| none => pure ()
def getParamInfo (k : Key) : M (Array Param) := do
def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
s ← get;
match s.map.find? k with
match s.paramMap.find? k with
| some ps => pure ps
| none =>
match k with
| (Key.decl fn) => do
| ParamMap.Key.decl fn => do
ctx ← read;
match findEnvDecl ctx.env fn with
| some decl => pure decl.params
| none => pure #[] -- unreachable if well-formed input
| _ => pure #[] -- unreachable if well-formed input
| none => unreachable!
| _ => unreachable!
/- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit :=
@ -214,8 +235,7 @@ xs.size.forM $ fun i => do
def f (x y : obj) :=
let z := ctor_1 x y;
ret z
```
-/
``` -/
def ownArgsIfParam (xs : Array Arg) : M Unit := do
ctx ← read;
xs.forM $ fun x =>
@ -227,8 +247,10 @@ def collectExpr (z : VarId) : Expr → M Unit
| Expr.reset _ x => ownVar z *> ownVar x
| Expr.reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys
| Expr.ctor _ xs => ownVar z *> ownArgsIfParam xs
| Expr.proj _ x => whenM (isOwned z) $ ownVar x
| Expr.fap g xs => do ps ← getParamInfo (Key.decl g);
| Expr.proj _ x => do
whenM (isOwned x) $ ownVar z;
whenM (isOwned z) $ ownVar x
| Expr.fap g xs => do ps ← getParamInfo (ParamMap.Key.decl g);
-- dbgTrace ("collectExpr: " ++ toString g ++ " " ++ toString (formatParams ps)) $ fun _ =>
ownVar z *> ownArgsUsingParams xs ps
| Expr.ap x ys => ownVar z *> ownVar x *> ownArgs ys
@ -241,7 +263,7 @@ match v, b with
| (Expr.fap g ys), (FnBody.ret (Arg.var z)) =>
when (ctx.currFn == g && x == z) $ do
-- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do
ps ← getParamInfo (Key.decl g);
ps ← getParamInfo (ParamMap.Key.decl g);
ownParamsUsingArgs ys ps
| _, _ => pure ()
@ -252,58 +274,45 @@ partial def collectFnBody : FnBody → M Unit
| FnBody.jdecl j ys v b => do
adaptReader (fun ctx => updateParamSet ctx ys) (collectFnBody v);
ctx ← read;
updateParamMap (Key.jp ctx.currFn j);
updateParamMap (ParamMap.Key.jp ctx.currFn j);
collectFnBody b
| FnBody.vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b
| FnBody.jmp j ys => do
ctx ← read;
ps ← getParamInfo (Key.jp ctx.currFn j);
ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j);
ownArgsUsingParams ys ps; -- for making sure the join point can reuse
ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
| FnBody.case _ _ _ alts => alts.forM $ fun alt => collectFnBody alt.body
| e => unless (e.isTerminal) $ collectFnBody e.body
@[specialize] partial def whileModifingOwnedAux (x : M Unit) : Unit → M Unit
| _ => do
modify $ fun s => { modifiedOwned := false, .. s };
x;
s ← get;
if s.modifiedOwned then whileModifingOwnedAux ()
else pure ()
/- Keep executing `x` while it modifies ownedSet -/
@[inline] def whileModifingOwned (x : M Unit) : M Unit :=
whileModifingOwnedAux x ()
partial def collectDecl : Decl → M Unit
| Decl.fdecl f ys _ b =>
adaptReader (fun ctx => let ctx := updateParamSet ctx ys; { currFn := f, .. ctx }) $ do
modify $ fun (s : BorrowInfState) => { owned := {}, .. s };
whileModifingOwned (collectFnBody b);
updateParamMap (Key.decl f)
collectFnBody b;
updateParamMap (ParamMap.Key.decl f)
| _ => pure ()
@[specialize] partial def whileModifingParamMapAux (x : M Unit) : Unit → M Unit
@[specialize] partial def whileModifingAux (x : M Unit) : Unit → M Unit
| _ => do
modify $ fun s => { modifiedParamMap := false, .. s };
s ← get;
modify $ fun s => { modified := false, .. s };
-- s ← get;
-- dbgTrace (toString s.map) $ fun _ => do
x;
s ← get;
if s.modifiedParamMap then whileModifingParamMapAux ()
if s.modified then whileModifingAux ()
else pure ()
/- Keep executing `x` while it modifies paramMap -/
@[inline] def whileModifingParamMap (x : M Unit) : M Unit :=
whileModifingParamMapAux x ()
/- Keep executing `x` until it reaches a fixpoint -/
@[inline] def whileModifing (x : M Unit) : M Unit :=
whileModifingAux x ()
def collectDecls (decls : Array Decl) : M ParamMap := do
whileModifingParamMap (decls.forM collectDecl);
whileModifing (decls.forM collectDecl);
s ← get;
pure s.map
pure s.paramMap
def infer (env : Environment) (decls : Array Decl) : ParamMap :=
(collectDecls decls { env := env }).run' { map := mkInitParamMap env decls }
(collectDecls decls { env := env }).run' { paramMap := mkInitParamMap env decls }
end Borrow