lean4-htt/library/init/lean/ir/ssa_check.lean
Leonardo de Moura af1a5fe874 feat(library/init/lean/ir): add x : ty := y instruction
It is useful when we are not producing IR in SSA.
2018-05-17 15:44:13 -07:00

189 lines
7.1 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2018 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import init.lean.ir.instances init.lean.ir.format
namespace lean
namespace ir
/- Monad used while collecting variable declarations:
   errors are `format` values and the state is a `var2blockid` map. -/
@[reducible] def ssa_pre_m :=
except_t format (state_t var2blockid id)
/- Record `x` in the `var2blockid` state, mapped to the block id supplied by the reader.
   Fails with "already defined" when `x` is already present in the map. -/
def var.declare (x : var) : reader_t blockid ssa_pre_m unit :=
do seen ← get,
   if seen.contains x then
     throw ("already defined " ++ to_fmt x)
   else
     read >>= λ bid, put (seen.insert x bid)
/- For each instruction form listed below, declare its first field
   (a single variable, or a list of variables for `instr.call`).
   Any instruction not matched explicitly declares nothing. -/
def instr.declare_vars : instr → reader_t blockid ssa_pre_m unit
| (instr.assign dst _ _)           := var.declare dst
| (instr.assign_lit dst _ _)       := var.declare dst
| (instr.assign_unop dst _ _ _)    := var.declare dst
| (instr.assign_binop dst _ _ _ _) := var.declare dst
| (instr.call dsts _ _)            := dsts.mfor var.declare
| (instr.cnstr dst _ _ _)          := var.declare dst
| (instr.get dst _ _)              := var.declare dst
| (instr.sget dst _ _ _)           := var.declare dst
| (instr.closure dst _ _)          := var.declare dst
| (instr.apply dst _)              := var.declare dst
| (instr.array dst _ _)            := var.declare dst
| (instr.sarray dst _ _ _)         := var.declare dst
| _                                := return ()
/- Declare the variable `p.x` of a phi node.
   The action is wrapped in `p.decorate_error` (defined outside this file). -/
def phi.declare (p : phi) : reader_t blockid ssa_pre_m unit :=
p.decorate_error $ var.declare p.x
/- Declare the variables of every phi node of `b`, then of every instruction of `b`,
   using `b.id` as the block id they are recorded under. -/
def block.declare_vars (b : block) : ssa_pre_m unit :=
let collect : reader_t blockid ssa_pre_m unit :=
      b.phis.mfor phi.declare >> b.instrs.mfor instr.declare_vars in
b.decorate_error $ collect.run b.id
/- Declare the variable `a.n` of an argument. -/
def arg.declare (a : arg) : reader_t blockid ssa_pre_m unit :=
var.declare a.n
/- Record the block in which each variable of a declaration is declared, and
   fail if some variable is declared more than once.
   A `decl.defn` with an empty block list is rejected; any declaration that is
   not a `decl.defn` is accepted without recording anything. -/
def decl.declare_vars : decl → ssa_pre_m unit
| (decl.defn _ []) := throw "declaration must have at least one basic block"
| (decl.defn hdr (blk::blks)) :=
  /- Arguments are recorded as being declared in the first basic block. -/
  hdr.decorate_error $
    (hdr.args.mfor arg.declare).run blk.id
    >> blk.declare_vars
    >> blks.mfor block.declare_vars
| _ := return ()
/- Generate the mapping from variable to blockid for the given declaration.
   Runs `d.declare_vars` starting from the initial map `mk_var2blockid` and
   returns the resulting map; any error raised by `declare_vars`
   (for example a variable declared twice) is returned as the `except` error.
   This function assumes `d` is in SSA. -/
def decl.var2blockid (d : decl) : except format var2blockid :=
run (d.declare_vars >> get) mk_var2blockid
/- Monad used by the validity check: errors are `format` values, the
   `var2blockid` map is a read-only environment, and the state is a `var_set`
   holding the variables defined so far. -/
@[reducible] def ssa_valid_m :=
except_t format (reader_t var2blockid (state_t var_set id))
/- Run the action `a` with `m` as the variable-to-block environment and
   `mk_var_set` as the initial set of defined variables. -/
def ssa_valid_m.run {α} (a : ssa_valid_m α) (m : var2blockid) : except format α :=
run a m mk_var_set
/- Add `x` to the set of defined variables held in the state. -/
def var.define (x : var) : ssa_valid_m unit :=
modify (λ defs, defs.insert x)
/- Mark the variable `a.n` of an argument as defined. -/
def arg.define (a : arg) : ssa_valid_m unit :=
var.define a.n
/- Succeed when `x` is in the set of defined variables held in the state;
   otherwise fail with an "undefined" error naming `x`. -/
def var.defined (x : var) : ssa_valid_m unit :=
do defs ← get,
   unless (defs.contains x) $ throw ("undefined '" ++ to_fmt x ++ "'")
/- Given `x := phi ys`:
   check that every variable in `ys` occurs in the `var2blockid` environment
   (i.e. is declared somewhere), then add `x` to the set of defined variables. -/
def phi.valid_ssa (p : phi) : ssa_valid_m unit :=
p.decorate_error $
do decls ← read,
   p.ys.mfor (λ src,
     if decls.contains src then return ()
     else throw ("undefined '" ++ to_fmt src ++ "'")),
   var.define p.x
/- Check one instruction against the set of variables defined so far.

   Every variable the instruction passes to `var.defined` must already be in the
   set; the variable(s) it passes to `var.define` are added to the set.

   The `defined` checks run *before* the `define` updates. `var.declare` rejects
   a second declaration of the same variable, so a variable that is both defined
   and checked by one instruction (e.g. `x := x`) has no earlier definition.
   Checking first makes such an instruction fail with "undefined" instead of
   letting it satisfy its own check.

   The whole action is wrapped in `ins.decorate_error` (defined outside this file). -/
def instr.valid_ssa (ins : instr) : ssa_valid_m unit :=
ins.decorate_error $
match ins with
| (instr.assign x _ y)           := y.defined >> x.define
| (instr.assign_lit x _ _)       := x.define
| (instr.assign_unop x _ _ y)    := y.defined >> x.define
| (instr.assign_binop x _ _ y z) := y.defined >> z.defined >> x.define
| (instr.unop _ x)               := x.defined
| (instr.call xs _ ys)           := ys.mfor var.defined >> xs.mfor var.define
| (instr.cnstr o _ _ _)          := o.define
| (instr.set o _ x)              := o.defined >> x.defined
| (instr.get x y _)              := y.defined >> x.define
| (instr.sset o _ x)             := o.defined >> x.defined
| (instr.sget x _ y _)           := y.defined >> x.define
| (instr.closure x _ ys)         := ys.mfor var.defined >> x.define
| (instr.apply x ys)             := ys.mfor var.defined >> x.define
| (instr.array a sz c)           := sz.defined >> c.defined >> a.define
| (instr.sarray x _ sz c)        := sz.defined >> c.defined >> x.define
| (instr.array_write a i v)      := a.defined >> i.defined >> v.defined
/- Check the variables used by a block terminator:
   `ret` requires all returned variables to be defined, `case` requires the
   scrutinee to be defined, and `jmp` uses no variables. -/
def terminator.valid_ssa (term : terminator) : ssa_valid_m unit :=
term.decorate_error $
match term with
| (terminator.jmp _)       := return ()
| (terminator.case scrut _) := var.defined scrut
| (terminator.ret results) := results.mfor var.defined
/- Collect the set of block ids in which the sources `p.ys` of a phi node are
   declared, according to the `var2blockid` environment.
   Fails if a source is not in the environment ("undefined"), or if two sources
   map to the same block id ("multiple predecessors"). -/
def phi.predecessors (p : phi) : ssa_valid_m blockid_set :=
do decls ← read,
   p.ys.mfoldl
     (λ seen y,
       match decls.find y with
       | none     := throw ("undefined '" ++ to_fmt y ++ "' at '" ++ to_fmt p ++ "'")
       | some bid := if seen.contains bid
                     then throw ("multiple predecessors at '" ++ to_fmt p ++ "'")
                     else return (seen.insert bid))
     mk_blockid_set
/- Check that all phi nodes in `ps` have the same predecessor set
   (as computed by `phi.predecessors`, compared with `seteq`).
   The set of the first phi node is the reference; each later phi node whose
   set differs raises a "missing predecessor" error. An empty list succeeds. -/
def phis.check_predecessors (ps : list phi) : ssa_valid_m unit :=
ps.mfoldl
  (λ (expected : option blockid_set) (p : phi),
    p.decorate_error $
    p.predecessors >>= λ preds,
      match expected with
      | none      := return (some preds)
      | (some s0) := if s0.seteq preds then return expected
                     else throw ("missing predecessor '" ++ to_fmt p.x ++ "' at '" ++ to_fmt p ++ "'"))
  none
>> return ()
/- Validate one basic block, in this order: predecessor consistency of its phi
   nodes, each phi node, each instruction, and finally the terminator. -/
def block.valid_ssa_core (b : block) : ssa_valid_m unit :=
b.decorate_error $
  phis.check_predecessors b.phis
  >> b.phis.mfor phi.valid_ssa
  >> b.instrs.mfor instr.valid_ssa
  >> b.term.valid_ssa
/-
Two-step check of a declaration.

1. `decl.var2blockid` checks that every variable is declared only once and
   builds the map from each variable to the block where it is declared.
2. For a `decl.defn` with at least one block, the blocks are checked with
   `block.valid_ssa_core`, which verifies that variables are defined before
   being used:
   - the first block is checked in a run whose defined set starts with the
     function arguments;
   - the remaining blocks are checked together in a second run.

On success the variable-to-block map is returned.

NOTE(review): each `ssa_valid_m.run` starts from `mk_var_set`, so in the second
run the arguments and the variables defined by the first block are not in the
defined set. Confirm that this is intended.
-/
def decl.valid_ssa (d : decl) : except format var2blockid :=
d.decorate_error $
do var_map ← d.var2blockid,
   match d with
   | decl.defn {args:=params, ..} (entry::others) :=
     do (params.mfor arg.define >> block.valid_ssa_core entry).run var_map,
        (others.mfor block.valid_ssa_core).run var_map,
        return var_map
   | _ := return var_map
/- Block id checking.
   Monad for this check: errors are `format` values and the state is the
   `blockid_set` of block ids declared so far. -/
@[reducible] def blockid_check_m := except_t format (state blockid_set)
/- Run the action `a` starting from the initial block id set `mk_blockid_set`. -/
def blockid_check_m.run {α} (a : blockid_check_m α) : except format α :=
run a mk_blockid_set
/- Add `b.id` to the set of declared block ids.
   Fails if that id is already in the set. -/
def block.declare (b : block) : blockid_check_m unit :=
do labels ← get,
   if labels.contains b.id then
     throw ("block label '" ++ to_fmt b.id ++ "' has been used more than once")
   else
     put (labels.insert b.id)
/- Succeed when `bid` is in the set of declared block ids;
   otherwise fail with an "unknown basic block" error. -/
def blockid.defined (bid : blockid) : blockid_check_m unit :=
do labels ← get,
   unless (labels.contains bid) $ throw ("unknown basic block '" ++ to_fmt bid ++ "'")
/- Check that every block id a terminator refers to has been declared:
   all targets of a `case`, the target of a `jmp`; `ret` refers to none. -/
def terminator.check_blockids (term : terminator) : blockid_check_m unit :=
term.decorate_error $
match term with
| (terminator.jmp target)     := blockid.defined target
| (terminator.case _ targets) := targets.mfor blockid.defined
| (terminator.ret _)          := return ()
/- Check the block ids referenced by the terminator of `b`. -/
def block.check_blockids (b : block) : blockid_check_m unit :=
terminator.check_blockids b.term
/- For a `decl.defn`: first declare the ids of all blocks (rejecting
   duplicates), then check that every terminator only refers to declared ids.
   Other declarations are accepted without any check. -/
def decl.check_blockids : decl → blockid_check_m unit
| (decl.defn hdr blks) :=
  hdr.decorate_error $
  do blks.mfor block.declare,
     blks.mfor block.check_blockids
| _ := return ()
/- Run the block id check on `d`, starting from the initial block id set, and
   return the resulting set of declared block ids on success. -/
def check_blockids (d : decl) : except format blockid_set :=
blockid_check_m.run (d.check_blockids >> get)
end ir
end lean