diff --git a/library/init/category/reader.lean b/library/init/category/reader.lean index d40a16c115..b13b12910c 100644 --- a/library/init/category/reader.lean +++ b/library/init/category/reader.lean @@ -117,6 +117,19 @@ instance [monad m] : monad_reader_adapter ρ ρ' (reader_t ρ m) (reader_t ρ' m ⟨λ α, reader_t.adapt⟩ end - instance (ρ : Type u) (m out) [monad_run out m] : monad_run (λ α, ρ → out α) (reader_t ρ m) := ⟨λ α x, run ∘ x.run⟩ + +class monad_reader_runner (ρ : Type u) (m m' : Type u → Type u) := +(run_reader {} {α : Type u} : m α → ρ → m' α) +export monad_reader_runner (run_reader) + +section +variables {ρ ρ' : Type u} {m m' : Type u → Type u} + +instance monad_reader_runner_trans {n n' : Type u → Type u} [monad_functor m m' n n'] [monad_reader_runner ρ m m'] : monad_reader_runner ρ n n' := +⟨λ α x r, monad_map (λ α (y : m α), (run_reader y r : m' α)) x⟩ + +instance reader_t.monad_state_runner [monad m] : monad_reader_runner ρ (reader_t ρ m) m := +⟨λ α x r, x.run r⟩ +end diff --git a/library/init/category/state.lean b/library/init/category/state.lean index 916de397ff..ce8d244904 100644 --- a/library/init/category/state.lean +++ b/library/init/category/state.lean @@ -170,6 +170,19 @@ instance [monad m] : monad_state_adapter σ σ' (state_t σ m) (state_t σ' m) : ⟨λ σ'' α, state_t.adapt⟩ end - instance (σ m out) [monad_run out m] : monad_run (λ α, σ → out (α × σ)) (state_t σ m) := ⟨λ α x, run ∘ (λ σ, x.run σ)⟩ + +class monad_state_runner (σ : Type u) (m m' : Type u → Type u) := +(run_state {} {α : Type u} : m α → σ → m' α) +export monad_state_runner (run_state) + +section +variables {σ σ' : Type u} {m m' : Type u → Type u} + +instance monad_state_runner_trans {n n' : Type u → Type u} [monad_functor m m' n n'] [monad_state_runner σ m m'] : monad_state_runner σ n n' := +⟨λ α x s, monad_map (λ α (y : m α), (run_state y s : m' α)) x⟩ + +instance state_t.monad_state_runner [monad m] : monad_state_runner σ (state_t σ m) m := +⟨λ α x s, prod.fst <$> x.run s⟩ +end diff --git a/library/init/compiler/ir.lean b/library/init/compiler/ir.lean index 9abbd943c4..250313d92e 100644 --- a/library/init/compiler/ir.lean +++ b/library/init/compiler/ir.lean @@ -47,10 +47,10 @@ show has_lt string, from infer_instance def var_set := rbtree var (<) def blockid_set := rbtree blockid (<) def context := rbmap var type (<) -def mk_var_set := mk_rbtree var (<) -def mk_blockid_set := mk_rbtree blockid (<) def var2blockid := rbmap var blockid (<) -def mk_var2blockid := mk_rbmap var blockid (<) +def mk_var_set : var_set := mk_rbtree var (<) +def mk_blockid_set : blockid_set := mk_rbtree blockid (<) +def mk_var2blockid : var2blockid := mk_rbmap var blockid (<) inductive instr | lit (x : var) (ty : type) (lit : literal) -- x : ty := lit @@ -107,12 +107,19 @@ SSA validator @[reducible] def ssa_check : Type := except_t string (state (var2blockid × var_set)) unit -def var.declare_at (b : blockid) (x : var) : ssa_check := -do (m, s) ← get, - if m.contains x then throw ("variable has already been defined '" ++ x ++ "'") - else put (m.insert x b, s) +inductive ssa_error +| already_defined (v : var) +| undefined (v : var) +| no_block -def instr.declare_vars_at (b : blockid) : instr → ssa_check +@[reducible] def ssa_decl_m := except_t ssa_error (state_t var2blockid id) + +def var.declare_at (b : blockid) (x : var) : ssa_decl_m unit := +do m ← get, + if m.contains x then throw $ ssa_error.already_defined x + else put (m.insert x b) + +def instr.declare_vars_at (b : blockid) : instr → ssa_decl_m unit | (instr.lit x _ _) := x.declare_at b | (instr.cast x _ _) := x.declare_at b | (instr.unop x _ _ _) := x.declare_at b @@ -131,57 +138,59 @@ def instr.declare_vars_at (b : blockid) : instr → ssa_check | (instr.sread x _ _ _) := x.declare_at b | _ := return () -def phi.declare_at (b : blockid) : phi → ssa_check +def phi.declare_at (b : blockid) : phi → ssa_decl_m unit | {x := x, ..} := x.declare_at b -def block.declare_vars : block → ssa_check +def block.declare_vars : block → ssa_decl_m unit | {id := b, phis := ps, instrs := is, ..} := ps.mmap' (phi.declare_at b) >> is.mmap' (instr.declare_vars_at b) -def arg.declare_at (b : blockid) : arg → ssa_check +def arg.declare_at (b : blockid) : arg → ssa_decl_m unit | {n := x, ..} := x.declare_at b /- Collect where each variable is declared, and check whether each variable was declared at most once. -/ -def decl.declare_vars : decl → ssa_check +def decl.declare_vars : decl → ssa_decl_m unit | {as := as, bs := b::bs, ..} := /- We assume that arguments are declared in the first basic block. TODO: check whether this assumption matches LLVM or not -/ as.mmap' (arg.declare_at b.id) >> b.declare_vars >> bs.mmap' block.declare_vars -| _ := throw "declaration must have at least one block" +| _ := throw ssa_error.no_block /- Generate the mapping from variable to blockid for the given declaration. This function assumes `d` is in SSA. -/ -def decl.var2blockid (d : decl) : var2blockid := -let (_, (m, _)) := d.declare_vars.run.run (mk_var2blockid, mk_var_set) in -m +def decl.var2blockid (d : decl) : except_t ssa_error id var2blockid := +run_state (d.declare_vars >> get) mk_var2blockid + +@[reducible] def ssa_valid_m := except_t ssa_error (reader_t var2blockid (state_t var_set id)) /- Given, x := phi ys, check whether every ys is declared at the var2blockid mapping, and update the set of already defined variables in the basic block with `x`. TODO: check whether the SSA validation rules here match the ones used in LLVM. -/ -def phi.valid_ssa : phi → ssa_check +def phi.valid_ssa : phi → ssa_valid_m unit | {x := x, ys := ys, ..} := do - (m, s) ← get, + m ← read, ys.mmap' (λ y, if m.contains y then return () - else throw ("undefined '" ++ y ++ "'")), - put (m, s.insert x) + else throw $ ssa_error.undefined y), + s ← get, + put (s.insert x) /- Check whether `x` has been already defined in the current basic block or not. -/ -def var.defined (x : var) : ssa_check := -do (_, s) ← get, +def var.defined (x : var) : ssa_valid_m unit := +do s ← get, if s.contains x then return () - else throw ("undefined variable '" ++ x ++ "'") + else throw $ ssa_error.undefined x /- Mark `x` as a variable defined in the current basic block. -/ -def var.define (x : var) : ssa_check := -do (m, s) ← get, put (m, s.insert x) +def var.define (x : var) : ssa_valid_m unit := +do s ← get, put (s.insert x) -def instr.valid_ssa : instr → ssa_check +def instr.valid_ssa : instr → ssa_valid_m unit | (instr.lit x _ _) := x.define | (instr.cast x _ y) := x.define >> y.defined | (instr.unop x _ _ y) := x.define >> y.defined @@ -205,63 +214,65 @@ def instr.valid_ssa : instr → ssa_check | (instr.dealloc x) := x.defined | (instr.dec x) := x.defined -def terminator.valid_ssa : terminator → ssa_check +def terminator.valid_ssa : terminator → ssa_valid_m unit | (terminator.ret ys) := ys.mmap' var.defined | (terminator.case x _) := x.defined | (terminator.jmp _) := return () -def block.valid_ssa : block → ssa_check +def block.valid_ssa_core : block → ssa_valid_m unit | {phis := ps, instrs := is, term := r, ..} := - do (m, s) ← get, - put (m, mk_var_set), -- reset set of variables defined in the current block - ps.mmap' phi.valid_ssa, + do ps.mmap' phi.valid_ssa, is.mmap' instr.valid_ssa, r.valid_ssa +def block.valid_ssa (b : block) : except_t ssa_error (reader_t var2blockid id) unit := +run_state b.valid_ssa_core mk_var_set + /- We first check whether every variable `x` was declared only once and store the blockid where `x` is defined (action: `decl.declare_vars`). Then, we check whether every used variable in basic block has been defined before being used. -/ -def decl.valid_ssa : decl → ssa_check +def decl.valid_ssa : decl → except_t ssa_error id var2blockid | d@{as := as, bs := bs, ..} := - d.declare_vars >> bs.mmap' block.valid_ssa + do m ← d.var2blockid, + bs.mmap' (λ b : block, run_reader b.valid_ssa m), + return m def valid_ssa (d : decl) : bool := -let (e, _) := d.valid_ssa.run.run (mk_var2blockid, mk_var_set) -in e.to_bool +d.valid_ssa.run.to_bool /- Check blockids -/ -@[reducible] def blockid_check : Type := -except_t string (state blockid_set) unit +@[reducible] def blockid_check_m := +except_t string (state blockid_set) -def block.declare : block → blockid_check +def block.declare : block → blockid_check_m unit | {id := id, ..} := do s ← get, if s.contains id then throw ("blockid '" ++ id ++ "' has already been used") else put (s.insert id) -def blockid.defined (bid : blockid) : blockid_check := +def blockid.defined (bid : blockid) : blockid_check_m unit := do s ← get, if s.contains bid then return () else throw ("unknown blockid '" ++ bid ++ "'") -def terminator.check_blockids : terminator → blockid_check +def terminator.check_blockids : terminator → blockid_check_m unit | (terminator.ret ys) := return () | (terminator.case _ bids) := bids.mmap' blockid.defined | (terminator.jmp bid) := bid.defined -def block.check_blockids : block → blockid_check +def block.check_blockids : block → blockid_check_m unit | {term := r, ..} := r.check_blockids -def decl.check_blockids : decl → blockid_check +def decl.check_blockids : decl → blockid_check_m unit | {bs := bs, ..} := bs.mmap' block.declare >> bs.mmap' block.check_blockids def check_blockids (d : decl) : bool := -let (e, _) := d.check_blockids.run.run mk_blockid_set -in e.to_bool +let r : except_t string id unit := run_state d.check_blockids mk_blockid_set in +r.run.to_bool /- TODO: type inference