lean4-htt/src/Lean/Compiler/IR/LiveVars.lean
2020-10-25 09:54:07 -07:00

169 lines
6.8 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.IR.Basic
import Lean.Compiler.IR.FreeVars
namespace Lean.IR
/- Remark: in the paper "Counting Immutable Beans" the concepts of
free and live variables coincide because the paper does *not* consider
join points. For example, consider the function body `B`
```
let x := ctor_0;
jmp block_1 x
```
in a context where we have the join point `block_1` defined as
```
block_1 (x : obj) : obj :=
let z := ctor_0 x y;
ret z
```
The variable `y` is live in the function body `B` since it occurs in
`block_1` which is "invoked" by `B`.
-/
namespace IsLive
/-
We use `StateM LocalContext` instead of `ReaderT LocalContext Id` because we remove
non-local join points from the context whenever we visit them instead of
maintaining a set of visited non-local join points.
Remark: we don't need to track local join points because we assume there is
no variable or join point shadowing in our IR.
-/
/- Monad used by the liveness check: a state monad whose state is the `LocalContext`.
   `visitFnBody` reads this state to look up join point bodies and erases the
   declaration of every join point whose body it follows. -/
abbrev M := StateM LocalContext
/- Lifted versions of the pure `HasIndex.*` checks, so that they can be combined with
   `<||>` inside `M`. Each one only wraps the `Bool` result with `pure`; none of them
   reads or modifies the `LocalContext` state.
   `w` is the index being searched for. NOTE(review): presumably the result is `true`
   when `w` occurs in the visited item; confirm against `HasIndex`. -/
abbrev visitVar (w : Index) (x : VarId) : M Bool := pure (HasIndex.visitVar w x)
abbrev visitJP (w : Index) (x : JoinPointId) : M Bool := pure (HasIndex.visitJP w x)
abbrev visitArg (w : Index) (a : Arg) : M Bool := pure (HasIndex.visitArg w a)
abbrev visitArgs (w : Index) (as : Array Arg) : M Bool := pure (HasIndex.visitArgs w as)
abbrev visitExpr (w : Index) (e : Expr) : M Bool := pure (HasIndex.visitExpr w e)
/- Search the function body for the index `w`. The checks of the individual pieces of
   each instruction are combined with the monadic or `<||>`.

   Besides the occurrences in the body itself, a `jmp j ys` also searches the body of `j`
   whenever the `LocalContext` state still contains a declaration for `j`. -/
partial def visitFnBody (w : Index) : FnBody → M Bool
  | FnBody.unreachable       => pure false
  | FnBody.ret x             => visitArg w x
  | FnBody.mdata _ b         => visitFnBody w b
  | FnBody.vdecl _ _ v b     => visitExpr w v <||> visitFnBody w b
  | FnBody.jdecl _ _ v b     => visitFnBody w v <||> visitFnBody w b
  | FnBody.set x _ y b       => visitVar w x <||> visitArg w y <||> visitFnBody w b
  | FnBody.uset x _ y b      => visitVar w x <||> visitVar w y <||> visitFnBody w b
  | FnBody.sset x _ _ y _ b  => visitVar w x <||> visitVar w y <||> visitFnBody w b
  | FnBody.setTag x _ b      => visitVar w x <||> visitFnBody w b
  | FnBody.inc x _ _ _ b     => visitVar w x <||> visitFnBody w b
  | FnBody.dec x _ _ _ b     => visitVar w x <||> visitFnBody w b
  | FnBody.del x b           => visitVar w x <||> visitFnBody w b
  | FnBody.case _ x _ alts   => visitVar w x <||> alts.anyM (fun alt => visitFnBody w alt.body)
  | FnBody.jmp j ys          => visitArgs w ys <||> do
    let ctx ← get
    match ctx.getJPBody j with
    | none =>
      -- `j` must be a local join point. Nothing to do: its body has already been visited.
      pure false
    | some b => do
      -- `j` is not a local join point, since we assume join point declarations cannot be shadowed.
      -- Instead of marking the join points that have already been visited, we permanently
      -- remove `j` from the context before searching its body.
      set (ctx.eraseJoinPointDecl j)
      visitFnBody w b
end IsLive
/- `b.hasLiveVar ctx x` is `true` when `x` is live in the function body `b` under the context `ctx`.
   Remark: the context only has to provide the declarations of the join points that are free in `b`.
   Recall that a join point `j` is free in `b` if `b` contains `FnBody.jmp j ys` and `j` is not local. -/
def FnBody.hasLiveVar (b : FnBody) (ctx : LocalContext) (x : VarId) : Bool :=
  -- Build the stateful search for the index of `x`, then run it with `ctx` as the
  -- initial state, keeping only the Boolean answer.
  let search : IsLive.M Bool := IsLive.visitFnBody x.idx b;
  search.run' ctx
-- A set of live variables, represented as a `VarIdSet`.
abbrev LiveVarSet := VarIdSet
-- Map from join points to live variable sets; keys are ordered by their `idx` field.
-- `LiveVars.collectFnBody` stores in it, for each declared join point, the variables collected
-- from the join point body minus the join point parameters.
abbrev JPLiveVarMap := Std.RBMap JoinPointId LiveVarSet (fun j₁ j₂ => j₁.idx < j₂.idx)
-- The default `LiveVarSet` is the empty set.
instance : Inhabited LiveVarSet := ⟨{}⟩
/- Build the live variable set whose only element is `x`. -/
def mkLiveVarSet (x : VarId) : LiveVarSet :=
Std.RBTree.empty.insert x
namespace LiveVars
/- A `Collector` is a function that transforms a live variable set.
   Collectors are chained with `∘`, so in `f ∘ g` the collector `g` is applied to the set first. -/
abbrev Collector := LiveVarSet → LiveVarSet
/- The collector that leaves the live variable set unchanged. -/
@[inline] private def skip : Collector := id
/- The collector that adds the variable `x` to the live variable set. -/
@[inline] private def collectVar (x : VarId) : Collector := fun live => live.insert x
/- Collector for a single argument: a variable argument adds that variable to the set,
   any other argument leaves the set unchanged. -/
private def collectArg : Arg → Collector
  | Arg.var x => collectVar x
  -- Explicit wildcard. An unqualified identifier in this position would be bound as a
  -- fresh pattern variable (i.e. a catch-all) rather than matched as a constructor,
  -- so spell the catch-all out instead of hiding it behind a constructor-looking name.
  | _         => skip
/- Apply the collector `f a` for every element `a` of `as`, using a left fold that
   threads the live variable set through the successive applications. -/
@[specialize] private def collectArray {α : Type} (as : Array α) (f : α → Collector) : Collector := fun init =>
  as.foldl (fun acc elem => f elem acc) init
/- Collector for an array of arguments: applies `collectArg` to every element of `as`
   (via `collectArray`). -/
private def collectArgs (as : Array Arg) : Collector :=
collectArray as collectArg
/- The collector that inserts every element of `s'` into the live variable set. -/
private def accumulate (s' : LiveVarSet) : Collector := fun acc =>
  s'.fold (fun acc x => collectVar x acc) acc
/- Collector for a jump to `j`: adds the live variables recorded for `j` in `m`.
   When `m` has no entry for `j` the set is left unchanged. -/
private def collectJP (m : JPLiveVarMap) (j : JoinPointId) : Collector :=
  match m.find? j with
  | none      => skip -- unreachable for well-formed code
  | some live => accumulate live
/- The collector that removes the variable `x` from the live variable set.
   It is used for the variable introduced by a declaration. -/
private def bindVar (x : VarId) : Collector := fun live => live.erase x
/- The collector that removes the variable `p.x` of every parameter `p` in `ps`
   from the live variable set. -/
private def bindParams (ps : Array Param) : Collector := fun live =>
  ps.foldl (fun acc p => bindVar p.x acc) live
/- Collector that adds every variable occurring as an operand of the given expression.
   The alternatives are grouped by the shape of the operands. -/
def collectExpr : Expr → Collector
  -- No variable operand.
  | Expr.lit _          => skip
  -- A single variable operand.
  | Expr.reset _ x      => collectVar x
  | Expr.proj _ x       => collectVar x
  | Expr.uproj _ x      => collectVar x
  | Expr.sproj _ _ x    => collectVar x
  | Expr.box _ x        => collectVar x
  | Expr.unbox x        => collectVar x
  | Expr.isShared x     => collectVar x
  | Expr.isTaggedPtr x  => collectVar x
  -- Only an array of arguments.
  | Expr.ctor _ ys      => collectArgs ys
  | Expr.fap _ ys       => collectArgs ys
  | Expr.pap _ ys       => collectArgs ys
  -- A variable operand plus an array of arguments: the arguments are collected first.
  | Expr.reuse x _ _ ys => fun s => collectVar x (collectArgs ys s)
  | Expr.ap x ys        => fun s => collectVar x (collectArgs ys s)
/- Collector for a function body. `m` provides the live variable sets of the join points
   that are in scope; it is extended at every join point declaration.

   Since collectors are chained with `∘`, the continuation of an instruction is processed
   first and the instruction itself is applied to the resulting set. -/
partial def collectFnBody : FnBody → JPLiveVarMap → Collector
  -- Terminal instructions.
  | FnBody.unreachable,       _ => skip
  | FnBody.ret x,             _ => collectArg x
  | FnBody.jmp j xs,          m => collectJP m j ∘ collectArgs xs
  -- Declarations.
  | FnBody.vdecl x _ v b,     m => collectExpr v ∘ bindVar x ∘ collectFnBody b m
  | FnBody.jdecl j ys v b,    m =>
    -- Variables collected from the join point body, minus the join point parameters,
    -- starting from the empty set.
    let jLiveVars := bindParams ys (collectFnBody v m {});
    collectFnBody b (m.insert j jLiveVars)
  -- Instructions that use variables and then continue with `b`.
  | FnBody.set x _ y b,       m => collectVar x ∘ collectArg y ∘ collectFnBody b m
  | FnBody.uset x _ y b,      m => collectVar x ∘ collectVar y ∘ collectFnBody b m
  | FnBody.sset x _ _ y _ b,  m => collectVar x ∘ collectVar y ∘ collectFnBody b m
  | FnBody.setTag x _ b,      m => collectVar x ∘ collectFnBody b m
  | FnBody.inc x _ _ _ b,     m => collectVar x ∘ collectFnBody b m
  | FnBody.dec x _ _ _ b,     m => collectVar x ∘ collectFnBody b m
  | FnBody.del x b,           m => collectVar x ∘ collectFnBody b m
  | FnBody.mdata _ b,         m => collectFnBody b m
  -- Branching: the scrutinee plus everything collected from each alternative.
  | FnBody.case _ x _ alts,   m => collectVar x ∘ collectArray alts (fun alt => collectFnBody alt.body m)
/- Extend `m` with an entry for the join point `j`, whose parameters are `ys` and whose body is `v`.
   The stored set is obtained by collecting `v` (using `m`) starting from the empty set
   and then removing the parameters `ys`. -/
def updateJPLiveVarMap (j : JoinPointId) (ys : Array Param) (v : FnBody) (m : JPLiveVarMap) : JPLiveVarMap :=
  m.insert j (bindParams ys (collectFnBody v m {}))
end LiveVars
/- Add to `v` the variables collected from the expression `e` (see `LiveVars.collectExpr`). -/
def updateLiveVars (e : Expr) (v : LiveVarSet) : LiveVarSet :=
LiveVars.collectExpr e v
/- Apply the collector of the function body `b` to the initial set `v` (empty by default).
   `m` supplies the live variable sets of the join points that `b` may jump to. -/
def collectLiveVars (b : FnBody) (m : JPLiveVarMap) (v : LiveVarSet := {}) : LiveVarSet :=
LiveVars.collectFnBody b m v
export LiveVars (updateJPLiveVarMap)
end Lean.IR