lean4-htt/library/init/lean/compiler/ir/boxing.lean
2019-05-30 07:30:07 -07:00

276 lines
10 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.control.estate
import init.control.reader
import init.lean.extern
import init.lean.compiler.ir.basic
import init.lean.compiler.ir.compilerm
import init.lean.compiler.ir.freevars
namespace Lean
namespace IR
namespace ExplicitBoxing
/-
Add explicit boxing and unboxing instructions.
Recall that the Lean to λ_pure compiler produces code without these instructions.
Assumptions:
- This transformation is applied before explicit RC instructions (`inc`, `dec`) are inserted.
- This transformation is applied before `FnBody.case` has been simplified and `Alt.default` is used.
Reason: if there is no `Alt.default` branch, then we can decide whether `x` at `FnBody.case x alts` is an
enumeration type by simply inspecting the `CtorInfo` values at `alts`.
- This transformation is applied before the lower-level optimizations that use
`Expr.isShared`, `Expr.isTaggedPtr`, and `FnBody.set` are applied.
- This transformation is applied after `reset` and `reuse` instructions have been added.
Reason: `resetreuse.lean` ignores `box` and `unbox` instructions.
-/
/- Name of the auxiliary boxed variant of the function `n`:
   the string component `_boxed` is appended to `n`. -/
def mkBoxedName (n : Name) : Name :=
n.mkString "_boxed"
abbrev N := State Nat
/- Produce a fresh variable in `N`: the current counter value becomes the
   variable index, and the counter is advanced by one. -/
private def mkFresh : N VarId :=
do next ← getModify (+1),
   pure { idx := next }
/- A declaration needs a boxed variant when it has at least one parameter and
   at least one of the following holds: it returns a scalar, it has a scalar or
   borrowed parameter, or it is an `extern` declaration. -/
def requiresBoxedVersion (env : Environment) (decl : Decl) : Bool :=
let params := decl.params in
let needsWrapper :=
  decl.resultType.isScalar
  || params.any (λ p, p.ty.isScalar || p.borrow)
  || isExtern env decl.name in
params.size > 0 && needsWrapper
/- Build the boxed variant of `decl`, named `mkBoxedName decl.name`.
   The variant has one fresh parameter per original parameter; each has type
   `object` and `borrow := false`. Its result type is `object`. Its body:
   1. unboxes each argument whose original parameter type is scalar
      (all other arguments are forwarded unchanged);
   2. calls the original function `decl.name` with these arguments;
   3. boxes the call result when `decl.resultType` is scalar, then returns it.
   Fresh variables come from the `N` counter. The order in which `mkFresh` is
   called below determines the generated variable indices. -/
def mkBoxedVersionAux (decl : Decl) : N Decl :=
do
let ps := decl.params,
-- One fresh `object`, non-borrowed parameter per original parameter.
qs ← ps.mmap (λ _, do x ← mkFresh, pure { Param . x := x, ty := IRType.object, borrow := false }),
-- Accumulate (unboxing declarations, call arguments), visiting parameters by index.
(newVDecls, xs) ← qs.size.mfold
(λ i (r : Array FnBody × Array Arg),
let (newVDecls, xs) := r in
let p := ps.get i in
let q := qs.get i in
-- Non-scalar parameter: forward the boxed parameter as-is.
if !p.ty.isScalar then pure (newVDecls, xs.push (Arg.var q.x))
else do
-- Scalar parameter: unbox the boxed parameter into a fresh variable of the original type.
x ← mkFresh,
-- NOTE(review): the `default _` continuation looks like a placeholder that `reshape` replaces when chaining the declarations — confirm.
pure (newVDecls.push (FnBody.vdecl x p.ty (Expr.unbox q.x) (default _)), xs.push (Arg.var x)))
(Array.empty, Array.empty),
-- Call the original function.
r ← mkFresh,
let newVDecls := newVDecls.push (FnBody.vdecl r decl.resultType (Expr.fap decl.name xs) (default _)),
body ←
if !decl.resultType.isScalar then do {
pure $ reshape newVDecls (FnBody.ret (Arg.var r))
} else do {
-- Scalar result: box it before returning, because the boxed variant returns `object`.
newR ← mkFresh,
let newVDecls := newVDecls.push (FnBody.vdecl newR IRType.object (Expr.box decl.resultType r) (default _)),
pure $ reshape newVDecls (FnBody.ret (Arg.var newR))
},
pure $ Decl.fdecl (mkBoxedName decl.name) qs IRType.object body
/- Build the boxed variant of `decl` (see `mkBoxedVersionAux`), numbering the
   fresh variables starting at index 1. -/
def mkBoxedVersion (decl : Decl) : Decl :=
let firstIdx : Nat := 1 in
(mkBoxedVersionAux decl).run' firstIdx
/- Return `decls` followed by the boxed variants of the declarations that need
   one (as decided by `requiresBoxedVersion`). The boxed variants appear in the
   same relative order as their originals. -/
def addBoxedVersions (env : Environment) (decls : Array Decl) : Array Decl :=
let step := λ (acc : Array Decl) (decl : Decl),
  if requiresBoxedVersion env decl then acc.push (mkBoxedVersion decl) else acc in
decls ++ decls.foldl step Array.empty
/- Add the boxed variant of `decl` to `env` when `decl` needs one;
   otherwise return `env` unchanged. -/
@[export lean.ir.add_boxed_version_core]
def addBoxedVersion (env : Environment) (decl : Decl) : Environment :=
match requiresBoxedVersion env decl with
| true  := addDeclAux env (mkBoxedVersion decl)
| false := env
/- Infer the type of a `case` scrutinee from the case's alternatives.
   This is valid only when `alts` contains no `Alt.default _` alternative.
   The scrutinee is treated as a scalar when there are at least two alternatives
   and every alternative is a constructor whose `CtorInfo` is scalar. In that
   case the unsigned width is chosen from the number of alternatives.
   Otherwise the scrutinee is an `object`. -/
def getScrutineeType (alts : Array Alt) : IRType :=
let isScalarCtor : Alt → Bool := (λ alt, match alt with
  | Alt.ctor c _  := c.isScalar
  | Alt.default _ := false) in
let n := alts.size in
-- Recall that we encode Unit and PUnit using `object`, hence the `n > 1` requirement.
if n > 1 && alts.all isScalarCtor then
  (if n < 256 then IRType.uint8
   else if n < 65536 then IRType.uint16
   else if n < 4294967296 then IRType.uint32
   else IRType.object) -- in practice this last branch should be unreachable
else IRType.object
/- Two IR types need no cast between them when both are non-scalar,
   or when they are the same scalar type. -/
def eqvTypes (t₁ t₂ : IRType) : Bool :=
if t₁.isScalar then t₂.isScalar && t₁ == t₂
else !t₂.isScalar
/- Read-only context of the boxing pass:
   - `localCtx`: variables and join points in scope. The pass queries it for
     variable types and join-point parameters, and extends it while descending
     into function bodies.
   - `resultType`: result type of the function being processed. It is the
     expected type of returned values.
   - `decls`: the declarations being compiled; they are searched together with
     `env` when resolving a callee.
   - `env`: the current environment. -/
structure BoxingContext :=
(localCtx : LocalContext := {}) (resultType : IRType := IRType.irrelevant) (decls : Array Decl) (env : Environment)
-- Monad of the pass: reads a `BoxingContext` and threads an `Index` counter used to create fresh variables.
abbrev M := ReaderT BoxingContext (StateT Index Id)
/- Produce a fresh variable. The state holds the next unused index;
   it is returned as the variable index and then incremented. -/
def mkFresh : M VarId :=
(λ i, { VarId . idx := i }) <$> getModify (+1)
/- Accessors for the individual fields of the ambient `BoxingContext`. -/
def getEnv : M Environment := do ctx ← read, pure ctx.env
def getLocalContext : M LocalContext := do ctx ← read, pure ctx.localCtx
def getResultType : M IRType := do ctx ← read, pure ctx.resultType
/- Type of the variable `x` according to the local context.
   Falls back to `object` when `x` is unknown. -/
def getVarType (x : VarId) : M IRType :=
do ctx ← read,
   pure (match ctx.localCtx.getType x with
   | some t := t
   | none   := IRType.object) -- unreachable, we assume the code is well formed
/- Parameters of the join point `j` according to the local context.
   Falls back to an empty array when `j` is unknown. -/
def getJPParams (j : JoinPointId) : M (Array Param) :=
do ctx ← read,
   pure (match ctx.localCtx.getJPParams j with
   | some ys := ys
   | none    := Array.empty) -- unreachable, we assume the code is well formed
/- Resolve the declaration of `fid` via `findEnvDecl'`, searching the
   environment and the declarations being compiled. Falls back to
   `default _` when no declaration is found. -/
def getDecl (fid : FunId) : M Decl :=
do ctx ← read,
   let found := findEnvDecl' ctx.env fid ctx.decls,
   pure (match found with
   | some decl := decl
   | none      := default _) -- unreachable if well-formed
/- Run `k` under a context whose local context has been transformed by `f`. -/
@[inline] private def withLocalCtx {α : Type} (f : LocalContext → LocalContext) (k : M α) : M α :=
adaptReader (λ ctx : BoxingContext, { localCtx := f ctx.localCtx, .. ctx }) k
/- Run `k` with the parameters `xs` added to the local context. -/
@[inline] def withParams {α : Type} (xs : Array Param) (k : M α) : M α :=
withLocalCtx (λ (lctx : LocalContext), lctx.addParams xs) k
/- Run `k` with the local variable `x : ty := v` added to the local context. -/
@[inline] def withVDecl {α : Type} (x : VarId) (ty : IRType) (v : Expr) (k : M α) : M α :=
withLocalCtx (λ (lctx : LocalContext), lctx.addLocal x ty v) k
/- Run `k` with the join point `j` (parameters `xs`, body `v`) added to the local context. -/
@[inline] def withJDecl {α : Type} (j : JoinPointId) (xs : Array Param) (v : FnBody) (k : M α) : M α :=
withLocalCtx (λ (lctx : LocalContext), lctx.addJP j xs v) k
/- Cast expression for `x`, used when the expected type does not match `x`'s
   actual type `xType`. A scalar `x` is boxed; a non-scalar `x` is unboxed. -/
def mkCast (x : VarId) (xType : IRType) : Expr :=
match xType.isScalar with
| true  := Expr.box xType x
| false := Expr.unbox x
/- Pass `x` to the continuation `k` when its type is compatible with `expected`
   (see `eqvTypes`). Otherwise bind a fresh variable `y : expected` to the cast
   of `x`, pass `y` to `k`, and prepend that declaration to `k`'s result. -/
@[inline] def castVarIfNeeded (x : VarId) (expected : IRType) (k : VarId → M FnBody) : M FnBody :=
do xType ← getVarType x,
   if eqvTypes xType expected then k x
   else do
     y ← mkFresh,
     rest ← k y,
     pure (FnBody.vdecl y expected (mkCast x xType) rest)
/- `castVarIfNeeded` lifted to arguments. Non-variable arguments
   (`Arg.irrelevant`) are passed to `k` unchanged. -/
@[inline] def castArgIfNeeded (x : Arg) (expected : IRType) (k : Arg → M FnBody) : M FnBody :=
match x with
| Arg.var v := castVarIfNeeded v expected (λ y, k (Arg.var y))
| other     := k other
/- Cast every argument of `xs` to the type `typeFromIdx i` expected at its
   position `i`. Returns the resulting arguments together with the cast
   declarations. Each declaration binds a fresh variable and has the
   continuation `FnBody.nil`. `Arg.irrelevant` arguments and already-compatible
   variables are kept as-is. -/
@[specialize] def castArgsIfNeededAux (xs : Array Arg) (typeFromIdx : Nat → IRType) : M (Array Arg × Array FnBody) :=
xs.miterate (Array.empty, Array.empty) $ λ i (arg : Arg) (acc : Array Arg × Array FnBody),
let (args, casts) := acc in
match arg with
| Arg.irrelevant := pure (args.push Arg.irrelevant, casts)
| Arg.var x := do
  xType ← getVarType x,
  let expected := typeFromIdx i.val,
  if eqvTypes xType expected then pure (args.push arg, casts)
  else do
    y ← mkFresh,
    let castDecl := FnBody.vdecl y expected (mkCast x xType) FnBody.nil,
    pure (args.push (Arg.var y), casts.push castDecl)
/- Cast each argument in `xs` to the type of the parameter at the same index
   in `ps`. The cast declarations are prepended, via `reshape`, to the body
   produced by `k`. -/
@[inline] def castArgsIfNeeded (xs : Array Arg) (ps : Array Param) (k : Array Arg → M FnBody) : M FnBody :=
do (ys, casts) ← castArgsIfNeededAux xs (λ i, (ps.get i).ty),
   reshape casts <$> k ys
/- Cast every argument in `xs` to `object`, i.e. box the scalar ones. The cast
   declarations are prepended, via `reshape`, to the body produced by `k`. -/
@[inline] def boxArgsIfNeeded (xs : Array Arg) (k : Array Arg → M FnBody) : M FnBody :=
do (ys, casts) ← castArgsIfNeededAux xs (λ _, IRType.object),
   reshape casts <$> k ys
/- Bind `x : ty` to the value of `e`, followed by `b`. `e` is treated as
   producing an `object` (in this file it is used for `Expr.ap`). When `ty` is
   scalar, `e` is first bound to a fresh `object` variable, which is then
   unboxed into `x`. -/
def unboxResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
match ty.isScalar with
| false := pure (FnBody.vdecl x ty e b)
| true  := do
  y ← mkFresh,
  let unboxed := FnBody.vdecl x ty (Expr.unbox y) b,
  pure (FnBody.vdecl y IRType.object e unboxed)
/- Bind `x : ty` to the value of `e` (of type `eType`), followed by `b`.
   When the two types are incompatible (see `eqvTypes`), `e` is first bound to
   a fresh variable of type `eType`, which is then cast into `x`. -/
def castResultIfNeeded (x : VarId) (ty : IRType) (e : Expr) (eType : IRType) (b : FnBody) : M FnBody :=
match eqvTypes ty eType with
| true  := pure (FnBody.vdecl x ty e b)
| false := do
  y ← mkFresh,
  let converted := FnBody.vdecl x ty (mkCast y eType) b,
  pure (FnBody.vdecl y eType e converted)
/- Rebuild the declaration `x : ty := e` with continuation `b`, inserting the
   boxing/unboxing instructions that `e` requires. `b` has already been
   processed by the caller. -/
def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody :=
match e with
-- Constructor: a scalar constructor bound to a scalar `x` becomes its index as a
-- numeric literal; otherwise all constructor arguments are boxed.
| Expr.ctor c ys :=
if c.isScalar && ty.isScalar then
pure $ FnBody.vdecl x ty (Expr.lit (LitVal.num c.cidx)) b
else
boxArgsIfNeeded ys $ λ ys, pure $ FnBody.vdecl x ty (Expr.ctor c ys) b
-- Reuse: arguments are boxed, as for `Expr.ctor`.
| Expr.reuse w c u ys :=
boxArgsIfNeeded ys $ λ ys, pure $ FnBody.vdecl x ty (Expr.reuse w c u ys) b
-- Full application: cast the arguments to the callee's parameter types and
-- cast the result from the callee's result type to `ty`.
| Expr.fap f ys := do
decl ← getDecl f,
castArgsIfNeeded ys decl.params $ λ ys,
castResultIfNeeded x ty (Expr.fap f ys) decl.resultType b
-- Partial application: refer to the boxed variant of `f` when it needs one,
-- and box all arguments.
| Expr.pap f ys := do
env ← getEnv,
decl ← getDecl f,
let f := if requiresBoxedVersion env decl then mkBoxedName f else f,
boxArgsIfNeeded ys $ λ ys, pure $ FnBody.vdecl x ty (Expr.pap f ys) b
-- Closure application: box all arguments; the result is unboxed when `ty` is scalar.
| Expr.ap f ys :=
boxArgsIfNeeded ys $ λ ys,
unboxResultIfNeeded x ty (Expr.ap f ys) b
-- Any other expression is kept unchanged.
| other :=
pure $ FnBody.vdecl x ty e b
/- Insert explicit boxing/unboxing instructions into a function body.
   In each case, the nested bodies (continuation, join-point body, or case
   alternatives) are processed first. Any needed cast is then placed in front of
   the instruction that uses the value.
   An instruction that reaches the final `other` case is returned unchanged, and
   its continuation (if it has one) is not visited. The assumptions at the top of
   this namespace state that `inc`/`dec` and `FnBody.set` are not present yet. -/
partial def visitFnBody : FnBody → M FnBody
-- Variable declaration: visit the continuation with `x` in scope, then handle the expression.
| (FnBody.vdecl x t v b) := do
b ← withVDecl x t v (visitFnBody b),
visitVDeclExpr x t v b
-- Join point: visit its body with its parameters in scope, and the continuation with the join point in scope.
| (FnBody.jdecl j xs v b) := do
v ← withParams xs (visitFnBody v),
b ← withJDecl j xs v (visitFnBody b),
pure $ FnBody.jdecl j xs v b
-- `uset`: the stored variable must have type `usize`.
| (FnBody.uset x i y b) := do
b ← visitFnBody b,
castVarIfNeeded y IRType.usize $ λ y,
pure $ FnBody.uset x i y b
-- `sset`: the stored variable must have the field type `ty`.
| (FnBody.sset x i o y ty b) := do
b ← visitFnBody b,
castVarIfNeeded y ty $ λ y,
pure $ FnBody.sset x i o y ty b
| (FnBody.mdata d b) :=
FnBody.mdata d <$> visitFnBody b
-- `case`: visit every alternative, then cast the scrutinee to the type inferred from the alternatives.
| (FnBody.case tid x alts) := do
let expected := getScrutineeType alts,
alts ← alts.mmap $ λ alt, alt.mmodifyBody visitFnBody,
castVarIfNeeded x expected $ λ x,
pure $ FnBody.case tid x alts
-- `ret`: the returned argument must have the function's result type.
| (FnBody.ret x) := do
expected ← getResultType,
castArgIfNeeded x expected (λ x, pure $ FnBody.ret x)
-- `jmp`: the arguments must match the join point's parameter types.
| (FnBody.jmp j ys) := do
ps ← getJPParams j,
castArgsIfNeeded ys ps (λ ys, pure $ FnBody.jmp j ys)
| other :=
pure other
/- Run the boxing pass on every function declaration in `decls`. Other
   declarations are kept unchanged. The boxed variants required by the
   resulting declarations are then appended (see `addBoxedVersions`).
   Within a function, fresh variable indices start right after the largest
   index the function already uses. -/
def run (env : Environment) (decls : Array Decl) : Array Decl :=
let ctx : BoxingContext := { decls := decls, env := env } in
let visitDecl : Decl → Decl := (λ decl, match decl with
  | Decl.fdecl f xs t b :=
    let startIdx := decl.maxIndex + 1 in
    let body := (withParams xs (visitFnBody b) { resultType := t, .. ctx }).run' startIdx in
    Decl.fdecl f xs t body
  | other := other) in
addBoxedVersions env (decls.map visitDecl)
end ExplicitBoxing
/- Compiler pass: insert explicit boxing/unboxing instructions into `decls`
   and append the required boxed variants, using the current environment. -/
def explicitBoxing (decls : Array Decl) : CompilerM (Array Decl) :=
(λ env, ExplicitBoxing.run env decls) <$> getEnv
end IR
end Lean