292 lines
12 KiB
Text
292 lines
12 KiB
Text
/-
|
||
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Leonardo de Moura
|
||
-/
|
||
prelude
|
||
import Init.Lean.Runtime
|
||
import Init.Lean.Compiler.IR.CompilerM
|
||
import Init.Lean.Compiler.IR.LiveVars
|
||
|
||
namespace Lean
|
||
namespace IR
|
||
namespace ExplicitRC
|
||
/- Insert explicit RC instructions. This pass assumes that the input code contains neither `inc` nor `dec` instructions.
   It is applied before the lower-level optimizations that introduce the instructions `release` and `set`. -/
|
||
|
||
/- Static information tracked for each variable while inserting explicit RC instructions. -/
structure VarInfo :=
  -- `true` if the variable may be a reference (aka pointer) at runtime
  (ref : Bool := true)
  -- `true` if the variable is statically known to be marked as Persistent at runtime
  (persistent : Bool := false)
  -- `true` if the variable's reference count must be "consumed" by the current function
  (consume : Bool := false)
|
||
|
||
/- Map from variables to their `VarInfo`, ordered by variable index. -/
abbrev VarMap := RBMap VarId VarInfo (fun a b => a.idx < b.idx)
|
||
|
||
/- Context threaded through the RC-insertion traversal.
   It is updated functionally (see `updateVarInfo`, `updateVarInfoWithParams`) as variables come into scope. -/
structure Context :=
  -- environment in which `getDecl` looks up declarations
  (env : Environment)
  -- declarations being compiled; also searched by `getDecl`
  (decls : Array Decl)
  -- static information about each variable in scope
  (varMap : VarMap := {})
  -- map: join point => live variables
  (jpLiveVarMap : JPLiveVarMap := {})
  -- we use it to store the join point declarations (read by `getJPParams`)
  (localCtx : LocalContext := {})
|
||
|
||
/- Return the declaration of `fid`, found via `findEnvDecl'` using both `ctx.env` and `ctx.decls`.
   Falls back to `arbitrary _` when nothing is found, which should not happen for well-formed code. -/
def getDecl (ctx : Context) (fid : FunId) : Decl :=
match findEnvDecl' ctx.env fid ctx.decls with
| none   => arbitrary _ -- unreachable if well-formed
| some d => d
|
||
|
||
/- Return the `VarInfo` recorded for `x` in `ctx.varMap`.
   Unknown variables (unreachable in well-formed code) get the default `VarInfo`:
   `ref := true`, `persistent := false`, `consume := false`. -/
def getVarInfo (ctx : Context) (x : VarId) : VarInfo :=
match ctx.varMap.find x with
| none     => {} -- unreachable in well-formed code
| some inf => inf
|
||
|
||
/- Return the parameters of the join point `j` as stored in `ctx.localCtx`,
   or `#[]` when `j` is not registered there. -/
def getJPParams (ctx : Context) (j : JoinPointId) : Array Param :=
match ctx.localCtx.getJPParams j with
| none        => #[] -- `j` not registered in `ctx.localCtx`; unreachable in well-formed code
| some params => params
|
||
|
||
/- Return the variables live at the join point `j`,
   or the empty set when `j` has no entry in `ctx.jpLiveVarMap`. -/
def getJPLiveVars (ctx : Context) (j : JoinPointId) : LiveVarSet :=
match ctx.jpLiveVarMap.find j with
| none      => {}
| some live => live
|
||
|
||
/- `true` iff `x` may be a reference at runtime and its reference count must be consumed. -/
def mustConsume (ctx : Context) (x : VarId) : Bool :=
let xInfo := getVarInfo ctx x;
if xInfo.ref then xInfo.consume else false
|
||
|
||
/- Prepend an `inc x n` instruction to `b`; when `n = 0`, `b` is returned unchanged.
   The `persistent` flag of the instruction is taken from `x`'s `VarInfo`. -/
@[inline] def addInc (ctx : Context) (x : VarId) (b : FnBody) (n := 1) : FnBody :=
if n == 0 then b
else
  let persistent := (getVarInfo ctx x).persistent;
  FnBody.inc x n true persistent b
|
||
|
||
/- Prepend a single `dec x` instruction to `b`, using the `persistent` flag from `x`'s `VarInfo`. -/
@[inline] def addDec (ctx : Context) (x : VarId) (b : FnBody) : FnBody :=
FnBody.dec x 1 true (getVarInfo ctx x).persistent b
|
||
|
||
/- In a `case` alternative for constructor `c`: if `c` is not a reference constructor,
   mark the scrutinee `x` as a non-reference (`ref := false`) in the variable map.
   Variables without an entry are left untouched. -/
private def updateRefUsingCtorInfo (ctx : Context) (x : VarId) (c : CtorInfo) : Context :=
if !c.isRef then
  let vm := ctx.varMap;
  let vm := match vm.find x with
    | none      => vm
    | some info => vm.insert x { ref := false, .. info };
  { varMap := vm, .. ctx }
else ctx
|
||
|
||
/- For a `case` alternative: add a `dec` for every variable that is live before the `case`
   (`caseLiveVars`), is dead in the alternative (`altLiveVars`), and must be consumed. -/
private def addDecForAlt (ctx : Context) (caseLiveVars altLiveVars : LiveVarSet) (b : FnBody) : FnBody :=
caseLiveVars.fold
  (fun acc y =>
    let deadInAlt := !altLiveVars.contains y;
    if deadInAlt && mustConsume ctx y then addDec ctx y acc else acc)
  b
|
||
|
||
/- `isFirstOcc xs i = true` iff the argument `xs[i]` does not occur at any index `j < i`,
   i.e., position `i` holds the first occurrence of that argument in `xs`. -/
private def isFirstOcc (xs : Array Arg) (i : Nat) : Bool :=
let target := xs.get! i;
!(i.any $ fun j => xs.get! j == target)
|
||
|
||
/- Return `true` if `x` occurs in `ys` at a position `i` that is not consumed (`consumeParamPred i = false`),
   that is, `x` is (also) passed as a borrowed reference. -/
@[specialize]
private def isBorrowParamAux (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Bool :=
ys.size.any $ fun i =>
  match ys.get! i with
  | Arg.var y      => x == y && !consumeParamPred i
  | Arg.irrelevant => false
|
||
|
||
/- `true` if `x` is passed in `ys` at a position whose parameter in `ps` is marked `borrow`. -/
private def isBorrowParam (x : VarId) (ys : Array Arg) (ps : Array Param) : Bool :=
isBorrowParamAux x ys (fun i => !(ps.get! i).borrow)
|
||
|
||
/- Return `n`, the number of times `x` is consumed by an instruction with arguments `ys`.
   - `ys` is the sequence of instruction arguments in which we search for `x`.
   - `consumeParamPred i = true` if the parameter at position `i` is consumed. -/
@[specialize]
private def getNumConsumptions (x : VarId) (ys : Array Arg) (consumeParamPred : Nat → Bool) : Nat :=
ys.size.fold
  (fun i count =>
    match ys.get! i with
    | Arg.var y      => if x == y && consumeParamPred i then count + 1 else count
    | Arg.irrelevant => count)
  0
|
||
|
||
/- Insert `inc` instructions into `b`, for an instruction whose arguments are `xs`.
   - `consumeParamPred i = true` if the argument at position `i` is consumed.
   - `liveVarsAfter` is the set of variables live after the instruction.
   Each reference variable `x` is handled once, at its first occurrence in `xs`. It receives one `inc`
   per consuming occurrence, minus one when the reference currently held can be handed over instead.
   That is the case when `x` must be consumed, is dead after the instruction, and is not also
   passed in a borrowed position. -/
@[specialize]
private def addIncBeforeAux (ctx : Context) (xs : Array Arg) (consumeParamPred : Nat → Bool) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
xs.size.fold
  (fun i acc =>
    match xs.get! i with
    | Arg.irrelevant => acc
    | Arg.var x =>
      let info := getVarInfo ctx x;
      if info.ref && isFirstOcc xs i then
        -- number of positions in `xs` where `x` is consumed
        let numConsumptions := getNumConsumptions x xs consumeParamPred;
        let canHandOver :=
          info.consume &&                              -- `x` must be consumed by the current procedure
          !liveVarsAfter.contains x &&                 -- `x` is dead after executing the instruction
          !isBorrowParamAux x xs consumeParamPred;     -- `x` is not also passed as a borrowed reference
        let numIncs := if canHandOver then numConsumptions - 1 else numConsumptions;
        addInc ctx x acc numIncs
      else acc)
  b
|
||
|
||
/- Insert `inc`s for an application or jump whose arguments `xs` are passed to the parameters `ps`;
   the argument at position `i` is consumed iff `ps[i]` is not marked `borrow`. -/
private def addIncBefore (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
addIncBeforeAux ctx xs (fun i => !(ps.get! i).borrow) b liveVarsAfter
|
||
|
||
/- Insert `dec` instructions into the continuation `b` of a full application whose arguments `xs`
   are passed to the parameters `ps`; `bLiveVars` is the set of variables live in `b`.
   See `addIncBeforeAux`/`addIncBefore` for the procedure that inserts `inc` operations before the application. -/
private def addDecAfterFullApp (ctx : Context) (xs : Array Arg) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
xs.size.fold
  (fun i b =>
    match xs.get! i with
    | Arg.irrelevant => b
    | Arg.var x =>
      /- We must add a `dec` if `x` must be consumed, it is NOT alive after the application,
         and it has been borrowed by the application: a borrowed position does not consume `x`,
         so the reference we own must be released here.
         Remark: `x` may occur multiple times in the application (e.g., `f x y x`).
         This is why we check whether it is the first occurrence. -/
      if mustConsume ctx x && isFirstOcc xs i && isBorrowParam x xs ps && !bLiveVars.contains x then
        addDec ctx x b
      else b)
  b
|
||
|
||
/- Insert `inc`s for an instruction that consumes every one of its arguments `xs`. -/
private def addIncBeforeConsumeAll (ctx : Context) (xs : Array Arg) (b : FnBody) (liveVarsAfter : LiveVarSet) : FnBody :=
addIncBeforeAux ctx xs (fun _ => true) b liveVarsAfter
|
||
|
||
/- Add `dec` instructions for parameters that have an object type, are not live in `b`, and are not borrowed.
   That is, we make sure these owned parameters are consumed. -/
private def addDecForDeadParams (ctx : Context) (ps : Array Param) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
ps.foldl
  (fun acc p =>
    let deadOwnedObj := !p.borrow && p.ty.isObj && !bLiveVars.contains p.x;
    if deadOwnedObj then addDec ctx p.x acc else acc)
  b
|
||
|
||
/- A full application with no arguments denotes a global constant, and all global constants are persistent objects. -/
private def isPersistent (e : Expr) : Bool :=
match e with
| Expr.fap _ args => args.isEmpty
| _               => false
|
||
|
||
/- Whether the variable defined by the expression must be consumed.
   We do not need to consume the projection of a variable that is not consumed;
   every other expression (and a projection of a variable missing from `m`) must be consumed. -/
private def consumeExpr (m : VarMap) (e : Expr) : Bool :=
match e with
| Expr.proj _ x =>
  (match m.find x with
   | none      => true
   | some info => info.consume)
| _ => true
|
||
|
||
/- Return `true` iff the value at runtime is a scalar stored in a tagged pointer:
   a constructor without fields (`size`, `ssize`, and `usize` all zero) or a numeric
   literal not exceeding `maxSmallNat`. We do not need RC operations for this kind of value. -/
private def isScalarBoxedInTaggedPtr : Expr → Bool
| Expr.ctor info _        => info.size == 0 && info.ssize == 0 && info.usize == 0
| Expr.lit (LitVal.num k) => k ≤ maxSmallNat
| _                       => false
|
||
|
||
/- Record in `ctx.varMap` the `VarInfo` of the new variable `x : t := v`:
   - `ref`: `t` is an object type and `v` is not a scalar boxed in a tagged pointer;
   - `persistent`: `v` is a global constant (see `isPersistent`);
   - `consume`: computed by `consumeExpr` against the map before `x` is added. -/
private def updateVarInfo (ctx : Context) (x : VarId) (t : IRType) (v : Expr) : Context :=
let info : VarInfo := {
  ref        := t.isObj && !isScalarBoxedInTaggedPtr v,
  persistent := isPersistent v,
  consume    := consumeExpr ctx.varMap v };
{ varMap := ctx.varMap.insert x info, .. ctx }
|
||
|
||
/- Prepend a `dec x` to `b` when `x` must be consumed and is not live in `b`. -/
private def addDecIfNeeded (ctx : Context) (x : VarId) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody :=
if !mustConsume ctx x || bLiveVars.contains x then b
else addDec ctx x b
|
||
|
||
/- Insert the RC instructions required by the declaration `z : t := v` followed by the
   already-processed continuation `b`, whose live variables are `bLiveVars`.
   Returns the new body and the variables live before the declaration. -/
private def processVDecl (ctx : Context) (z : VarId) (t : IRType) (v : Expr) (b : FnBody) (bLiveVars : LiveVarSet) : FnBody × LiveVarSet :=
-- dbgTrace ("processVDecl " ++ toString z ++ " " ++ toString (format v)) $ fun _ =>
let decl := fun (rest : FnBody) => FnBody.vdecl z t v rest;
let newBody := match v with
  | Expr.ctor _ ys      => addIncBeforeConsumeAll ctx ys (decl b) bLiveVars
  | Expr.reuse _ _ _ ys => addIncBeforeConsumeAll ctx ys (decl b) bLiveVars
  | Expr.pap _ ys       => addIncBeforeConsumeAll ctx ys (decl b) bLiveVars
  | Expr.ap x ys        =>
    -- the closure `x` is treated as one more consumed argument
    addIncBeforeConsumeAll ctx (ys.push (Arg.var x)) (decl b) bLiveVars -- TODO: avoid temporary array allocation
  | Expr.fap f ys       =>
    let ps   := (getDecl ctx f).params;
    let rest := addDecAfterFullApp ctx ys ps b bLiveVars;
    addIncBefore ctx ys ps (decl rest) bLiveVars
  | Expr.proj _ x       =>
    let rest := addDecIfNeeded ctx x b bLiveVars;
    let rest := if (getVarInfo ctx x).consume then addInc ctx z rest else rest;
    decl rest
  | Expr.uproj _ x      => decl (addDecIfNeeded ctx x b bLiveVars)
  | Expr.sproj _ _ x    => decl (addDecIfNeeded ctx x b bLiveVars)
  | Expr.unbox x        => decl (addDecIfNeeded ctx x b bLiveVars)
  | _                   => decl b; -- Expr.reset, Expr.box, Expr.lit, ...: no `inc`/`dec` is added here
let liveBefore := (updateLiveVars v bLiveVars).erase z;
(newBody, liveBefore)
|
||
|
||
/- Register the parameters `ps` in the variable map: a parameter may be a reference iff its
   type is an object type, and it must be consumed iff it is not borrowed. -/
def updateVarInfoWithParams (ctx : Context) (ps : Array Param) : Context :=
let addParam := fun (m : VarMap) (p : Param) => m.insert p.x { ref := p.ty.isObj, consume := !p.borrow };
{ varMap := ps.foldl addParam ctx.varMap, .. ctx }
|
||
|
||
/- Insert explicit RC instructions into a function body under the context `ctx`.
   Returns the new body and the set of variables live at its entry. -/
partial def visitFnBody : FnBody → Context → (FnBody × LiveVarSet)
| FnBody.vdecl x t v b, ctx =>
  let ctx := updateVarInfo ctx x t v;
  let (b, bLiveVars) := visitFnBody b ctx;
  processVDecl ctx x t v b bLiveVars
| FnBody.jdecl j xs v b, ctx =>
  -- The join point body `v` is processed with its parameters `xs` registered in the variable map.
  let ctxAtV := updateVarInfoWithParams ctx xs;
  let (v, vLiveVars) := visitFnBody v ctxAtV;
  -- Use `ctxAtV`: the parameters are only registered there, so with `ctx` the lookups done by
  -- `addDec` would hit the "unreachable" fallback of `getVarInfo`.
  let v := addDecForDeadParams ctxAtV xs v vLiveVars;
  -- Register `j` in `localCtx` so that `getJPParams` finds its parameters when processing `jmp j ...`
  -- in `b`. Without it, `getJPParams` returns `#[]`; `addIncBefore` then indexes an empty parameter
  -- array and ignores the `borrow` annotations of the join point parameters.
  let ctx := { localCtx     := ctx.localCtx.addJP j xs v,
               jpLiveVarMap := updateJPLiveVarMap j xs v ctx.jpLiveVarMap,
               .. ctx };
  let (b, bLiveVars) := visitFnBody b ctx;
  (FnBody.jdecl j xs v b, bLiveVars)
| FnBody.uset x i y b, ctx =>
  let (b, s) := visitFnBody b ctx;
  -- We don't need to insert `y` since we only need to track live variables that are references at runtime
  let s := s.insert x;
  (FnBody.uset x i y b, s)
| FnBody.sset x i o y t b, ctx =>
  let (b, s) := visitFnBody b ctx;
  -- We don't need to insert `y` since we only need to track live variables that are references at runtime
  let s := s.insert x;
  (FnBody.sset x i o y t b, s)
| FnBody.mdata m b, ctx =>
  let (b, s) := visitFnBody b ctx;
  (FnBody.mdata m b, s)
| b@(FnBody.case tid x xType alts), ctx =>
  let caseLiveVars := collectLiveVars b ctx.jpLiveVarMap;
  -- Each alternative gets `dec`s for variables that are live before the `case`, dead in the
  -- alternative, and must be consumed (see `addDecForAlt`).
  let alts := alts.map $ fun alt => match alt with
    | Alt.ctor c b =>
      -- `updateRefUsingCtorInfo` marks `x` as a non-reference when `c` is not a reference constructor.
      let ctx := updateRefUsingCtorInfo ctx x c;
      let (b, altLiveVars) := visitFnBody b ctx;
      let b := addDecForAlt ctx caseLiveVars altLiveVars b;
      Alt.ctor c b
    | Alt.default b =>
      let (b, altLiveVars) := visitFnBody b ctx;
      let b := addDecForAlt ctx caseLiveVars altLiveVars b;
      Alt.default b;
  (FnBody.case tid x xType alts, caseLiveVars)
| b@(FnBody.ret x), ctx =>
  (match x with
   | Arg.var x =>
     -- A returned reference that the current function does not consume gets an `inc` first.
     let info := getVarInfo ctx x;
     if info.ref && !info.consume then (addInc ctx x b, mkLiveVarSet x) else (b, mkLiveVarSet x)
   | _ => (b, {}))
| b@(FnBody.jmp j xs), ctx =>
  let jLiveVars := getJPLiveVars ctx j;
  let ps := getJPParams ctx j;
  let b := addIncBefore ctx xs ps b jLiveVars;
  let bLiveVars := collectLiveVars b ctx.jpLiveVarMap;
  (b, bLiveVars)
| FnBody.unreachable, _ => (FnBody.unreachable, {})
| other, _ => (other, {}) -- unreachable if well-formed
|
||
|
||
/- Insert explicit RC instructions into the body of a function declaration.
   Parameters are registered first, and owned object parameters dead in the body receive a `dec`.
   Any other kind of declaration is returned unchanged. -/
partial def visitDecl (env : Environment) (decls : Array Decl) : Decl → Decl
| Decl.fdecl f xs t b =>
  let ctx := updateVarInfoWithParams ({ env := env, decls := decls } : Context) xs;
  let (body, liveVars) := visitFnBody b ctx;
  Decl.fdecl f xs t (addDecForDeadParams ctx xs body liveVars)
| d => d
|
||
|
||
end ExplicitRC
|
||
|
||
/- Insert explicit RC instructions into every declaration in `decls`.
   `decls` is also handed to `visitDecl`, so `getDecl` can resolve calls to functions of the same batch. -/
def explicitRC (decls : Array Decl) : CompilerM (Array Decl) :=
do env ← getEnv;
   let visit := ExplicitRC.visitDecl env decls;
   pure (decls.map visit)
|
||
|
||
end IR
|
||
end Lean
|