diff --git a/library/init/compiler/ir.lean b/library/init/compiler/ir.lean index d92fcc5172..57028929db 100644 --- a/library/init/compiler/ir.lean +++ b/library/init/compiler/ir.lean @@ -97,10 +97,8 @@ structure decl := (n : fid) (as : list arg) (rs : list result) (bs : list block) /- -Check whether every variable is defined only once. +SSA validator -/ -section ssa_validation - @[reducible] def ssa_check : Type := except_t unit (state var_set) unit @@ -153,7 +151,51 @@ def valid_ssa (d : decl) : bool := let (e, _) := d.valid_ssa.run.run mk_var_set in e.to_bool -end ssa_validation +/- Collect used variables -/ +@[reducible] def collector : Type := +state var_set unit + +def var.collect (x : var) : collector := +do s ← get, put (s.insert x) + +def instr.collect_vars : instr → collector +| (instr.lit x _ _) := x.collect +| (instr.cast x _ y) := x.collect >> y.collect +| (instr.unop x _ _ y) := x.collect >> y.collect +| (instr.binop x _ _ y z) := x.collect >> y.collect >> z.collect +| (instr.call xs _ ys) := xs.mmap' var.collect >> ys.mmap' var.collect +| (instr.phi x ps) := x.collect >> ps.mmap' (λ ⟨y, _⟩, y.collect) +| (instr.cnstr o _ _ _) := o.collect +| (instr.set o _ x) := o.collect >> x.collect +| (instr.get x y _) := x.collect >> y.collect +| (instr.sets o _ x) := o.collect >> x.collect +| (instr.gets x _ y _) := x.collect >> y.collect +| (instr.closure x _ ys) := x.collect >> ys.mmap' var.collect +| (instr.apply x ys) := x.collect >> ys.mmap' var.collect +| (instr.array a sz c) := a.collect >> sz.collect >> c.collect +| (instr.write a i v) := a.collect >> i.collect >> v.collect +| (instr.read x a i) := x.collect >> a.collect >> i.collect +| (instr.sarray x _ sz c) := x.collect >> sz.collect >> c.collect +| (instr.swrite a i v) := a.collect >> i.collect >> v.collect +| (instr.sread x _ a i) := x.collect >> a.collect >> i.collect +| (instr.inc x) := x.collect +| (instr.decs x) := x.collect +| (instr.dealloc x) := x.collect +| (instr.dec x) := x.collect + +def arg.collect : arg → collector +| {n := x, ..} := x.collect + +def block.collect_vars : block → collector +| {instrs := is, ..} := is.mmap' instr.collect_vars + +def decl.collect_vars : decl → collector +| {as := as, bs := bs, ..} := + as.mmap' arg.collect >> bs.mmap' block.collect_vars + +def collect_vars (d : decl) : var_set := +let (_, r) := d.collect_vars.run mk_var_set +in r /- TODO: