feat(library/init/compiler/ir): use run_state and run_reader
This commit is contained in:
parent
9389176380
commit
1fd399d06f
3 changed files with 83 additions and 46 deletions
|
|
@ -117,6 +117,19 @@ instance [monad m] : monad_reader_adapter ρ ρ' (reader_t ρ m) (reader_t ρ' m
|
|||
⟨λ α, reader_t.adapt⟩
|
||||
end
|
||||
|
||||
|
||||
instance (ρ : Type u) (m out) [monad_run out m] : monad_run (λ α, ρ → out α) (reader_t ρ m) :=
|
||||
⟨λ α x, run ∘ x.run⟩
|
||||
|
||||
class monad_reader_runner (ρ : Type u) (m m' : Type u → Type u) :=
|
||||
(run_reader {} {α : Type u} : m α → ρ → m' α)
|
||||
export monad_reader_runner (run_reader)
|
||||
|
||||
section
|
||||
variables {ρ ρ' : Type u} {m m' : Type u → Type u}
|
||||
|
||||
instance monad_reader_runner_trans {n n' : Type u → Type u} [monad_functor m m' n n'] [monad_reader_runner ρ m m'] : monad_reader_runner ρ n n' :=
|
||||
⟨λ α x r, monad_map (λ α (y : m α), (run_reader y r : m' α)) x⟩
|
||||
|
||||
instance reader_t.monad_state_runner [monad m] : monad_reader_runner ρ (reader_t ρ m) m :=
|
||||
⟨λ α x r, x.run r⟩
|
||||
end
|
||||
|
|
|
|||
|
|
@ -170,6 +170,19 @@ instance [monad m] : monad_state_adapter σ σ' (state_t σ m) (state_t σ' m) :
|
|||
⟨λ σ'' α, state_t.adapt⟩
|
||||
end
|
||||
|
||||
|
||||
instance (σ m out) [monad_run out m] : monad_run (λ α, σ → out (α × σ)) (state_t σ m) :=
|
||||
⟨λ α x, run ∘ (λ σ, x.run σ)⟩
|
||||
|
||||
class monad_state_runner (σ : Type u) (m m' : Type u → Type u) :=
|
||||
(run_state {} {α : Type u} : m α → σ → m' α)
|
||||
export monad_state_runner (run_state)
|
||||
|
||||
section
|
||||
variables {σ σ' : Type u} {m m' : Type u → Type u}
|
||||
|
||||
instance monad_state_runner_trans {n n' : Type u → Type u} [monad_functor m m' n n'] [monad_state_runner σ m m'] : monad_state_runner σ n n' :=
|
||||
⟨λ α x s, monad_map (λ α (y : m α), (run_state y s : m' α)) x⟩
|
||||
|
||||
instance state_t.monad_state_runner [monad m] : monad_state_runner σ (state_t σ m) m :=
|
||||
⟨λ α x s, prod.fst <$> x.run s⟩
|
||||
end
|
||||
|
|
|
|||
|
|
@ -47,10 +47,10 @@ show has_lt string, from infer_instance
|
|||
def var_set := rbtree var (<)
|
||||
def blockid_set := rbtree blockid (<)
|
||||
def context := rbmap var type (<)
|
||||
def mk_var_set := mk_rbtree var (<)
|
||||
def mk_blockid_set := mk_rbtree blockid (<)
|
||||
def var2blockid := rbmap var blockid (<)
|
||||
def mk_var2blockid := mk_rbmap var blockid (<)
|
||||
def mk_var_set : var_set := mk_rbtree var (<)
|
||||
def mk_blockid_set : blockid_set := mk_rbtree blockid (<)
|
||||
def mk_var2blockid : var2blockid := mk_rbmap var blockid (<)
|
||||
|
||||
inductive instr
|
||||
| lit (x : var) (ty : type) (lit : literal) -- x : ty := lit
|
||||
|
|
@ -107,12 +107,19 @@ SSA validator
|
|||
@[reducible] def ssa_check : Type :=
|
||||
except_t string (state (var2blockid × var_set)) unit
|
||||
|
||||
def var.declare_at (b : blockid) (x : var) : ssa_check :=
|
||||
do (m, s) ← get,
|
||||
if m.contains x then throw ("variable has already been defined '" ++ x ++ "'")
|
||||
else put (m.insert x b, s)
|
||||
inductive ssa_error
|
||||
| already_defined (v : var)
|
||||
| undefined (v : var)
|
||||
| no_block
|
||||
|
||||
def instr.declare_vars_at (b : blockid) : instr → ssa_check
|
||||
@[reducible] def ssa_decl_m := except_t ssa_error (state_t var2blockid id)
|
||||
|
||||
def var.declare_at (b : blockid) (x : var) : ssa_decl_m unit :=
|
||||
do m ← get,
|
||||
if m.contains x then throw $ ssa_error.already_defined x
|
||||
else put (m.insert x b)
|
||||
|
||||
def instr.declare_vars_at (b : blockid) : instr → ssa_decl_m unit
|
||||
| (instr.lit x _ _) := x.declare_at b
|
||||
| (instr.cast x _ _) := x.declare_at b
|
||||
| (instr.unop x _ _ _) := x.declare_at b
|
||||
|
|
@ -131,57 +138,59 @@ def instr.declare_vars_at (b : blockid) : instr → ssa_check
|
|||
| (instr.sread x _ _ _) := x.declare_at b
|
||||
| _ := return ()
|
||||
|
||||
def phi.declare_at (b : blockid) : phi → ssa_check
|
||||
def phi.declare_at (b : blockid) : phi → ssa_decl_m unit
|
||||
| {x := x, ..} := x.declare_at b
|
||||
|
||||
def block.declare_vars : block → ssa_check
|
||||
def block.declare_vars : block → ssa_decl_m unit
|
||||
| {id := b, phis := ps, instrs := is, ..} :=
|
||||
ps.mmap' (phi.declare_at b) >>
|
||||
is.mmap' (instr.declare_vars_at b)
|
||||
|
||||
def arg.declare_at (b : blockid) : arg → ssa_check
|
||||
def arg.declare_at (b : blockid) : arg → ssa_decl_m unit
|
||||
| {n := x, ..} := x.declare_at b
|
||||
|
||||
/- Collect where each variable is declared, and
|
||||
check whether each variable was declared at most once. -/
|
||||
def decl.declare_vars : decl → ssa_check
|
||||
def decl.declare_vars : decl → ssa_decl_m unit
|
||||
| {as := as, bs := b::bs, ..} :=
|
||||
/- We assume that arguments are declared in the first basic block.
|
||||
TODO: check whether this assumption matches LLVM or not -/
|
||||
as.mmap' (arg.declare_at b.id) >>
|
||||
b.declare_vars >>
|
||||
bs.mmap' block.declare_vars
|
||||
| _ := throw "declaration must have at least one block"
|
||||
| _ := throw ssa_error.no_block
|
||||
|
||||
/- Generate the mapping from variable to blockid for the given declaration.
|
||||
This function assumes `d` is in SSA. -/
|
||||
def decl.var2blockid (d : decl) : var2blockid :=
|
||||
let (_, (m, _)) := d.declare_vars.run.run (mk_var2blockid, mk_var_set) in
|
||||
m
|
||||
def decl.var2blockid (d : decl) : except_t ssa_error id var2blockid :=
|
||||
run_state (d.declare_vars >> get) mk_var2blockid
|
||||
|
||||
@[reducible] def ssa_valid_m := except_t ssa_error (reader_t var2blockid (state_t var_set id))
|
||||
|
||||
/- Given, x := phi ys,
|
||||
check whether every ys is declared at the var2blockid mapping,
|
||||
and update the set of already defined variables in the basic block with `x`.
|
||||
|
||||
TODO: check whether the SSA validation rules here match the ones used in LLVM. -/
|
||||
def phi.valid_ssa : phi → ssa_check
|
||||
def phi.valid_ssa : phi → ssa_valid_m unit
|
||||
| {x := x, ys := ys, ..} := do
|
||||
(m, s) ← get,
|
||||
m ← read,
|
||||
ys.mmap' (λ y, if m.contains y then return ()
|
||||
else throw ("undefined '" ++ y ++ "'")),
|
||||
put (m, s.insert x)
|
||||
else throw $ ssa_error.undefined y),
|
||||
s ← get,
|
||||
put (s.insert x)
|
||||
|
||||
/- Check whether `x` has been already defined in the current basic block or not. -/
|
||||
def var.defined (x : var) : ssa_check :=
|
||||
do (_, s) ← get,
|
||||
def var.defined (x : var) : ssa_valid_m unit :=
|
||||
do s ← get,
|
||||
if s.contains x then return ()
|
||||
else throw ("undefined variable '" ++ x ++ "'")
|
||||
else throw $ ssa_error.undefined x
|
||||
|
||||
/- Mark `x` as a variable defined in the current basic block. -/
|
||||
def var.define (x : var) : ssa_check :=
|
||||
do (m, s) ← get, put (m, s.insert x)
|
||||
def var.define (x : var) : ssa_valid_m unit :=
|
||||
do s ← get, put (s.insert x)
|
||||
|
||||
def instr.valid_ssa : instr → ssa_check
|
||||
def instr.valid_ssa : instr → ssa_valid_m unit
|
||||
| (instr.lit x _ _) := x.define
|
||||
| (instr.cast x _ y) := x.define >> y.defined
|
||||
| (instr.unop x _ _ y) := x.define >> y.defined
|
||||
|
|
@ -205,63 +214,65 @@ def instr.valid_ssa : instr → ssa_check
|
|||
| (instr.dealloc x) := x.defined
|
||||
| (instr.dec x) := x.defined
|
||||
|
||||
def terminator.valid_ssa : terminator → ssa_check
|
||||
def terminator.valid_ssa : terminator → ssa_valid_m unit
|
||||
| (terminator.ret ys) := ys.mmap' var.defined
|
||||
| (terminator.case x _) := x.defined
|
||||
| (terminator.jmp _) := return ()
|
||||
|
||||
def block.valid_ssa : block → ssa_check
|
||||
def block.valid_ssa_core : block → ssa_valid_m unit
|
||||
| {phis := ps, instrs := is, term := r, ..} :=
|
||||
do (m, s) ← get,
|
||||
put (m, mk_var_set), -- reset set of variables defined in the current block
|
||||
ps.mmap' phi.valid_ssa,
|
||||
do ps.mmap' phi.valid_ssa,
|
||||
is.mmap' instr.valid_ssa,
|
||||
r.valid_ssa
|
||||
|
||||
def block.valid_ssa (b : block) : except_t ssa_error (reader_t var2blockid id) unit :=
|
||||
run_state b.valid_ssa_core mk_var_set
|
||||
|
||||
/-
|
||||
We first check whether every variable `x` was declared only once
|
||||
and store the blockid where `x` is defined (action: `decl.declare_vars`).
|
||||
Then, we check whether every used variable in basic block has been
|
||||
defined before being used.
|
||||
-/
|
||||
def decl.valid_ssa : decl → ssa_check
|
||||
def decl.valid_ssa : decl → except_t ssa_error id var2blockid
|
||||
| d@{as := as, bs := bs, ..} :=
|
||||
d.declare_vars >> bs.mmap' block.valid_ssa
|
||||
do m ← d.var2blockid,
|
||||
bs.mmap' (λ b : block, run_reader b.valid_ssa m),
|
||||
return m
|
||||
|
||||
def valid_ssa (d : decl) : bool :=
|
||||
let (e, _) := d.valid_ssa.run.run (mk_var2blockid, mk_var_set)
|
||||
in e.to_bool
|
||||
d.valid_ssa.run.to_bool
|
||||
|
||||
/- Check blockids -/
|
||||
@[reducible] def blockid_check : Type :=
|
||||
except_t string (state blockid_set) unit
|
||||
@[reducible] def blockid_check_m :=
|
||||
except_t string (state blockid_set)
|
||||
|
||||
def block.declare : block → blockid_check
|
||||
def block.declare : block → blockid_check_m unit
|
||||
| {id := id, ..} :=
|
||||
do s ← get,
|
||||
if s.contains id then throw ("blockid '" ++ id ++ "' has already been used")
|
||||
else put (s.insert id)
|
||||
|
||||
def blockid.defined (bid : blockid) : blockid_check :=
|
||||
def blockid.defined (bid : blockid) : blockid_check_m unit :=
|
||||
do s ← get,
|
||||
if s.contains bid then return ()
|
||||
else throw ("unknown blockid '" ++ bid ++ "'")
|
||||
|
||||
def terminator.check_blockids : terminator → blockid_check
|
||||
def terminator.check_blockids : terminator → blockid_check_m unit
|
||||
| (terminator.ret ys) := return ()
|
||||
| (terminator.case _ bids) := bids.mmap' blockid.defined
|
||||
| (terminator.jmp bid) := bid.defined
|
||||
|
||||
def block.check_blockids : block → blockid_check
|
||||
def block.check_blockids : block → blockid_check_m unit
|
||||
| {term := r, ..} := r.check_blockids
|
||||
|
||||
def decl.check_blockids : decl → blockid_check
|
||||
def decl.check_blockids : decl → blockid_check_m unit
|
||||
| {bs := bs, ..} :=
|
||||
bs.mmap' block.declare >> bs.mmap' block.check_blockids
|
||||
|
||||
def check_blockids (d : decl) : bool :=
|
||||
let (e, _) := d.check_blockids.run.run mk_blockid_set
|
||||
in e.to_bool
|
||||
let r : except_t string id unit := run_state d.check_blockids mk_blockid_set in
|
||||
r.run.to_bool
|
||||
|
||||
/-
|
||||
TODO: type inference
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue