diff --git a/src/Init/Lean/Compiler/IR/Borrow.lean b/src/Init/Lean/Compiler/IR/Borrow.lean index e3bbb698a2..c2e5ed5b14 100644 --- a/src/Init/Lean/Compiler/IR/Borrow.lean +++ b/src/Init/Lean/Compiler/IR/Borrow.lean @@ -12,39 +12,54 @@ import Init.Lean.Compiler.IR.NormIds namespace Lean namespace IR namespace Borrow + +namespace OwnedSet +abbrev Key := FunId × Index + +def beq : Key → Key → Bool +| (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂ +instance : HasBeq Key := ⟨beq⟩ + +def getHash : Key → USize +| (f, x) => mixHash (hash f) (hash x) +instance : Hashable Key := ⟨getHash⟩ +end OwnedSet + +abbrev OwnedSet := HashMap OwnedSet.Key Unit +def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet := s.insert k () +def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool := s.contains k + /- We perform borrow inference in a block of mutually recursive functions. Join points are viewed as local functions, and are identified using their local id + the name of the surrounding function. - We keep a mapping from function and joint points to parameters (`Array Param`). - Recall that `Param` contains the field `borrow`. - The type `Key` is the the key of this map. -/ + Recall that `Param` contains the field `borrow`. -/ +namespace ParamMap inductive Key | decl (name : FunId) | jp (name : FunId) (jpid : JoinPointId) -namespace Key def beq : Key → Key → Bool -| decl n₁, decl n₂ => n₁ == n₂ -| jp n₁ id₁, jp n₂ id₂ => n₁ == n₂ && id₁ == id₂ -| _, _ => false +| Key.decl n₁, Key.decl n₂ => n₁ == n₂ +| Key.jp n₁ id₁, Key.jp n₂ id₂ => n₁ == n₂ && id₁ == id₂ +| _, _ => false instance : HasBeq Key := ⟨beq⟩ def getHash : Key → USize -| decl n => hash n -| jp n id => mixHash (hash n) (hash id) +| Key.decl n => hash n +| Key.jp n id => mixHash (hash n) (hash id) instance : Hashable Key := ⟨getHash⟩ -end Key +end ParamMap -abbrev ParamMap := HashMap Key (Array Param) +abbrev ParamMap := HashMap ParamMap.Key (Array Param) def ParamMap.fmt (map : ParamMap) : Format := let fmts := map.fold (fun fmt k ps => let k := match k with - | Key.decl n => format n - | Key.jp n id => format n ++ ":" ++ format id; + | ParamMap.Key.decl n => format n + | ParamMap.Key.jp n id => format n ++ ":" ++ format id; fmt ++ Format.line ++ k ++ " -> " ++ formatParams ps) Format.nil; "{" ++ (Format.nest 1 fmts) ++ "}" @@ -53,7 +68,6 @@ instance : HasFormat ParamMap := ⟨ParamMap.fmt⟩ instance : HasToString ParamMap := ⟨fun m => Format.pretty (format m)⟩ namespace InitParamMap - /- Mark parameters that take a reference as borrow -/ def initBorrow (ps : Array Param) : Array Param := ps.map $ fun p => { borrow := p.ty.isObj, .. p } @@ -63,17 +77,17 @@ ps.map $ fun p => { borrow := p.ty.isObj, .. p } These wrappers use smart pointers such as `object_ref`. When writing a new wrapper we need to know whether an argument is a borrow inference or not. - We can revise this decision when we implement code for generating the wrappers automatically. -/ def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param := if exported then ps else initBorrow ps partial def visitFnBody (fnid : FunId) : FnBody → StateM ParamMap Unit -| FnBody.jdecl j xs v b => do - modify $ fun m => m.insert (Key.jp fnid j) (initBorrow xs); +| FnBody.jdecl j xs v b => do + modify $ fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs); visitFnBody v; visitFnBody b +| FnBody.case _ _ _ alts => alts.forM $ fun alt => visitFnBody alt.body | e => unless (e.isTerminal) $ do let (instr, b) := e.split; @@ -83,7 +97,7 @@ def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit : decls.forM $ fun decl => match decl with | Decl.fdecl f xs _ b => do let exported := isExport env f; - modify $ fun m => m.insert (Key.decl f) (initBorrowIfNotExported exported xs); + modify $ fun m => m.insert (ParamMap.Key.decl f) (initBorrowIfNotExported exported xs); visitFnBody f b | _ => pure () end InitParamMap @@ -95,27 +109,29 @@ def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap := recursive functions. -/ namespace ApplyParamMap -partial def visitFnBody : FnBody → FunId → ParamMap → FnBody -| FnBody.jdecl j xs v b, fnid, map => - let v := visitFnBody v fnid map; - let b := visitFnBody b fnid map; - match map.find? (Key.jp fnid j) with +partial def visitFnBody (fn : FunId) (paramMap : ParamMap) : FnBody → FnBody +| FnBody.jdecl j xs v b => + let v := visitFnBody v; + let b := visitFnBody b; + match paramMap.find? (ParamMap.Key.jp fn j) with | some ys => FnBody.jdecl j ys v b - | none => FnBody.jdecl j xs v b -| e, fnid, map => + | none => unreachable! +| FnBody.case tid x xType alts => + FnBody.case tid x xType $ alts.map $ fun alt => alt.modifyBody visitFnBody +| e => if e.isTerminal then e else let (instr, b) := e.split; - let b := visitFnBody b fnid map; + let b := visitFnBody b; instr.setBody b -def visitDecls (decls : Array Decl) (map : ParamMap) : Array Decl := +def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl := decls.map $ fun decl => match decl with | Decl.fdecl f xs ty b => - let b := visitFnBody b f map; - match map.find? (Key.decl f) with + let b := visitFnBody f paramMap b; + match paramMap.find? (ParamMap.Key.decl f) with | some xs => Decl.fdecl f xs ty b - | none => Decl.fdecl f xs ty b + | none => unreachable! | other => other end ApplyParamMap @@ -125,66 +141,71 @@ def applyParamMap (decls : Array Decl) (map : ParamMap) : Array Decl := ApplyParamMap.visitDecls decls map structure BorrowInfCtx := -(env : Environment) -(currFn : FunId := arbitrary _) -- Function being analyzed. +(env : Environment) +(currFn : FunId := arbitrary _) -- Function being analyzed. (paramSet : IndexSet := {}) -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams` structure BorrowInfState := -/- `map` is a mapping storing the inferred borrow annotations for all functions (and joint points) in a mutually recursive declaration. -/ -(map : ParamMap) /- Set of variables that must be `owned`. -/ -(owned : IndexSet := {}) -(modifiedOwned : Bool := false) -(modifiedParamMap : Bool := false) +(owned : OwnedSet := {}) +(modified : Bool := false) +(paramMap : ParamMap) abbrev M := ReaderT BorrowInfCtx (StateM BorrowInfState) -def markModifiedParamMap : M Unit := -modify $ fun s => { modifiedParamMap := true, .. s } +def getCurrFn : M FunId := do +ctx ← read; +pure ctx.currFn -def ownVar (x : VarId) : M Unit := +def markModified : M Unit := +modify $ fun s => { modified := true, .. s } + +def ownVar (x : VarId) : M Unit := do -- dbgTrace ("ownVar " ++ toString x) $ fun _ => +currFn ← getCurrFn; modify $ fun s => - if s.owned.contains x.idx then s - else { owned := s.owned.insert x.idx, modifiedOwned := true, .. s } + if s.owned.contains (currFn, x.idx) then s + else { owned := s.owned.insert (currFn, x.idx), modified := true, .. s } def ownArg (x : Arg) : M Unit := match x with -| (Arg.var x) => ownVar x -| _ => pure () +| Arg.var x => ownVar x +| _ => pure () def ownArgs (xs : Array Arg) : M Unit := xs.forM ownArg def isOwned (x : VarId) : M Bool := do -s ← get; -pure $ s.owned.contains x.idx +currFn ← getCurrFn; +s ← get; +pure $ s.owned.contains (currFn, x.idx) /- Updates `map[k]` using the current set of `owned` variables. -/ -def updateParamMap (k : Key) : M Unit := do +def updateParamMap (k : ParamMap.Key) : M Unit := do +currFn ← getCurrFn; s ← get; -match s.map.find? k with +match s.paramMap.find? k with | some ps => do ps ← ps.mapM $ fun (p : Param) => - if p.borrow && s.owned.contains p.x.idx then do - markModifiedParamMap; pure { borrow := false, .. p } - else - pure p; - modify $ fun s => { map := s.map.insert k ps, .. s } + if !p.borrow then pure p + else condM (isOwned p.x) + (do markModified; pure { borrow := false, .. p }) + (pure p); + modify $ fun s => { paramMap := s.paramMap.insert k ps, .. s } | none => pure () -def getParamInfo (k : Key) : M (Array Param) := do +def getParamInfo (k : ParamMap.Key) : M (Array Param) := do s ← get; -match s.map.find? k with +match s.paramMap.find? k with | some ps => pure ps | none => match k with - | (Key.decl fn) => do + | ParamMap.Key.decl fn => do ctx ← read; match findEnvDecl ctx.env fn with | some decl => pure decl.params - | none => pure #[] -- unreachable if well-formed input - | _ => pure #[] -- unreachable if well-formed input + | none => unreachable! + | _ => unreachable! /- For each ps[i], if ps[i] is owned, then mark xs[i] as owned. -/ def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit := @@ -214,8 +235,7 @@ xs.size.forM $ fun i => do def f (x y : obj) := let z := ctor_1 x y; ret z - ``` --/ + ``` -/ def ownArgsIfParam (xs : Array Arg) : M Unit := do ctx ← read; xs.forM $ fun x => @@ -227,8 +247,10 @@ def collectExpr (z : VarId) : Expr → M Unit | Expr.reset _ x => ownVar z *> ownVar x | Expr.reuse x _ _ ys => ownVar z *> ownVar x *> ownArgsIfParam ys | Expr.ctor _ xs => ownVar z *> ownArgsIfParam xs -| Expr.proj _ x => whenM (isOwned z) $ ownVar x -| Expr.fap g xs => do ps ← getParamInfo (Key.decl g); +| Expr.proj _ x => do + whenM (isOwned x) $ ownVar z; + whenM (isOwned z) $ ownVar x +| Expr.fap g xs => do ps ← getParamInfo (ParamMap.Key.decl g); -- dbgTrace ("collectExpr: " ++ toString g ++ " " ++ toString (formatParams ps)) $ fun _ => ownVar z *> ownArgsUsingParams xs ps | Expr.ap x ys => ownVar z *> ownVar x *> ownArgs ys @@ -241,7 +263,7 @@ match v, b with | (Expr.fap g ys), (FnBody.ret (Arg.var z)) => when (ctx.currFn == g && x == z) $ do -- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do - ps ← getParamInfo (Key.decl g); + ps ← getParamInfo (ParamMap.Key.decl g); ownParamsUsingArgs ys ps | _, _ => pure () @@ -252,58 +274,45 @@ partial def collectFnBody : FnBody → M Unit | FnBody.jdecl j ys v b => do adaptReader (fun ctx => updateParamSet ctx ys) (collectFnBody v); ctx ← read; - updateParamMap (Key.jp ctx.currFn j); + updateParamMap (ParamMap.Key.jp ctx.currFn j); collectFnBody b | FnBody.vdecl x _ v b => collectFnBody b *> collectExpr x v *> preserveTailCall x v b | FnBody.jmp j ys => do ctx ← read; - ps ← getParamInfo (Key.jp ctx.currFn j); + ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j); ownArgsUsingParams ys ps; -- for making sure the join point can reuse ownParamsUsingArgs ys ps -- for making sure the tail call is preserved | FnBody.case _ _ _ alts => alts.forM $ fun alt => collectFnBody alt.body | e => unless (e.isTerminal) $ collectFnBody e.body -@[specialize] partial def whileModifingOwnedAux (x : M Unit) : Unit → M Unit -| _ => do - modify $ fun s => { modifiedOwned := false, .. s }; - x; - s ← get; - if s.modifiedOwned then whileModifingOwnedAux () - else pure () - -/- Keep executing `x` while it modifies ownedSet -/ -@[inline] def whileModifingOwned (x : M Unit) : M Unit := -whileModifingOwnedAux x () - partial def collectDecl : Decl → M Unit | Decl.fdecl f ys _ b => adaptReader (fun ctx => let ctx := updateParamSet ctx ys; { currFn := f, .. ctx }) $ do - modify $ fun (s : BorrowInfState) => { owned := {}, .. s }; - whileModifingOwned (collectFnBody b); - updateParamMap (Key.decl f) + collectFnBody b; + updateParamMap (ParamMap.Key.decl f) | _ => pure () -@[specialize] partial def whileModifingParamMapAux (x : M Unit) : Unit → M Unit +@[specialize] partial def whileModifingAux (x : M Unit) : Unit → M Unit | _ => do - modify $ fun s => { modifiedParamMap := false, .. s }; - s ← get; + modify $ fun s => { modified := false, .. s }; + -- s ← get; -- dbgTrace (toString s.map) $ fun _ => do x; s ← get; - if s.modifiedParamMap then whileModifingParamMapAux () + if s.modified then whileModifingAux () else pure () -/- Keep executing `x` while it modifies paramMap -/ -@[inline] def whileModifingParamMap (x : M Unit) : M Unit := -whileModifingParamMapAux x () +/- Keep executing `x` until it reaches a fixpoint -/ +@[inline] def whileModifing (x : M Unit) : M Unit := +whileModifingAux x () def collectDecls (decls : Array Decl) : M ParamMap := do -whileModifingParamMap (decls.forM collectDecl); +whileModifing (decls.forM collectDecl); s ← get; -pure s.map +pure s.paramMap def infer (env : Environment) (decls : Array Decl) : ParamMap := -(collectDecls decls { env := env }).run' { map := mkInitParamMap env decls } +(collectDecls decls { env := env }).run' { paramMap := mkInitParamMap env decls } end Borrow