fix(library/vm,library/compiler,frontends/lean): IO monad support

This commit is contained in:
Leonardo de Moura 2016-05-25 13:01:42 -07:00
parent 6d37c26b5d
commit bf2d2b9feb
9 changed files with 105 additions and 16 deletions

View file

@ -668,7 +668,8 @@ static environment vm_eval_cmd(parser & p) {
expr e; level_param_names ls;
std::tie(e, ls) = parse_local_expr(p);
type_checker tc(p.env());
expr type = tc.infer(e);
expr type = tc.whnf(tc.infer(e));
bool is_IO = is_constant(get_app_fn(type), get_IO_name());
environment new_env = p.env();
name main("_main");
auto cd = check(new_env, mk_definition(new_env, main, ls, type, e));
@ -677,7 +678,16 @@ static environment vm_eval_cmd(parser & p) {
vm_state s(new_env);
{
timeit timer(p.ios().get_diagnostic_stream(), "vm_eval time");
if (is_IO) s.push(mk_vm_simple(0)); // push the "RealWorld" state
s.invoke_fn(main);
if (is_IO) {
vm_decl d = *s.get_decl(main);
if (d.get_arity() == 0) {
/* main returned a closure, it did not process RealWorld yet.
So, we force the execution. */
s.apply();
}
}
}
vm_obj r = s.get(0);
display(p.ios().get_regular_stream(), r);

View file

@ -284,16 +284,35 @@ class erase_irrelevant_fn : public compiler_step_visitor {
expr visit_monad_bind(expr const & e, buffer<expr> const & args) {
if (args.size() == 6 && is_constant(args[1], get_monadIO_name())) {
/* IO monad bind */
expr v = visit(args[4]);
expr b = visit(args[5]);
/* We just convert it into a let-expression */
if (is_lambda(b)) {
return mk_let(binding_name(b), mk_neutral_expr(), v, binding_body(b));
} else {
lean_assert(closed(b));
return mk_let(mk_fresh_name(), mk_neutral_expr(), v, mk_app(b, mk_var(0)));
}
/* Remark: morally the IO monad is
IO a := State -> (a, State)
and the (monad.bind v b) is
fun S, let p := v S
in b (pr1 p) (pr2 p)
However, the State is a fiction. It is a unit at runtime.
The IO a is a really just a thunk.
IO a := unit -> a
So, in this version, we have a simpler (monad.bind v b)
bind v b :=
fun s, let a := v unit in
b a unit
We use a let-expression to make sure that `v unit` is not erased.
*/
expr v = visit(args[4]);
expr u = mk_neutral_expr();
expr vu = mk_app(v, u);
expr b = visit(args[5]);
expr bau = beta_reduce(mk_app(b, mk_var(0), u));
expr let = mk_let("a", u, vu, bau);
return mk_lambda("S", u, let);
} else {
return compiler_step_visitor::visit_app(e);
}
@ -301,8 +320,15 @@ class erase_irrelevant_fn : public compiler_step_visitor {
expr visit_monad_return(expr const & e, buffer<expr> const & args) {
if (args.size() == 4 && is_constant(args[1], get_monadIO_name())) {
/* IO monad return */
return visit(args[3]);
/* IO monad return
return v := fun s, v
Remark: we do not return the state explicility.
*/
expr u = mk_neutral_expr();
expr s = mk_var(0);
expr v = visit(args[3]);
return mk_lambda("S", u, v);
} else {
return compiler_step_visitor::visit_app(e);
}

View file

@ -90,6 +90,7 @@ name const * g_implies = nullptr;
name const * g_implies_of_if_neg = nullptr;
name const * g_implies_of_if_pos = nullptr;
name const * g_implies_resolve = nullptr;
name const * g_IO = nullptr;
name const * g_is_trunc_is_prop = nullptr;
name const * g_is_trunc_is_prop_elim = nullptr;
name const * g_is_trunc_is_set = nullptr;
@ -398,6 +399,7 @@ void initialize_constants() {
g_implies_of_if_neg = new name{"implies_of_if_neg"};
g_implies_of_if_pos = new name{"implies_of_if_pos"};
g_implies_resolve = new name{"implies", "resolve"};
g_IO = new name{"IO"};
g_is_trunc_is_prop = new name{"is_trunc", "is_prop"};
g_is_trunc_is_prop_elim = new name{"is_trunc", "is_prop", "elim"};
g_is_trunc_is_set = new name{"is_trunc", "is_set"};
@ -707,6 +709,7 @@ void finalize_constants() {
delete g_implies_of_if_neg;
delete g_implies_of_if_pos;
delete g_implies_resolve;
delete g_IO;
delete g_is_trunc_is_prop;
delete g_is_trunc_is_prop_elim;
delete g_is_trunc_is_set;
@ -1015,6 +1018,7 @@ name const & get_implies_name() { return *g_implies; }
name const & get_implies_of_if_neg_name() { return *g_implies_of_if_neg; }
name const & get_implies_of_if_pos_name() { return *g_implies_of_if_pos; }
name const & get_implies_resolve_name() { return *g_implies_resolve; }
name const & get_IO_name() { return *g_IO; }
name const & get_is_trunc_is_prop_name() { return *g_is_trunc_is_prop; }
name const & get_is_trunc_is_prop_elim_name() { return *g_is_trunc_is_prop_elim; }
name const & get_is_trunc_is_set_name() { return *g_is_trunc_is_set; }

View file

@ -92,6 +92,7 @@ name const & get_implies_name();
name const & get_implies_of_if_neg_name();
name const & get_implies_of_if_pos_name();
name const & get_implies_resolve_name();
name const & get_IO_name();
name const & get_is_trunc_is_prop_name();
name const & get_is_trunc_is_prop_elim_name();
name const & get_is_trunc_is_set_name();

View file

@ -85,6 +85,7 @@ implies
implies_of_if_neg
implies_of_if_pos
implies.resolve
IO
is_trunc.is_prop
is_trunc.is_prop.elim
is_trunc.is_set

View file

@ -1157,10 +1157,34 @@ void vm_state::invoke_fn(unsigned fn_idx) {
run();
}
void vm_state::execute(vm_instr const * code) {
m_call_stack.emplace_back(m_code, m_fn_idx, 0, 0, m_bp);
m_code = code;
m_fn_idx = -1;
m_pc = 0;
m_bp = m_stack.size();
run();
}
void vm_state::apply(unsigned n) {
buffer<vm_instr> code;
for (unsigned i = 0; i < n; i++)
code.push_back(mk_apply_instr());
code.push_back(mk_ret_instr());
execute(code.data());
}
void vm_state::display(std::ostream & out, vm_obj const & o) const {
::lean::display(out, o, [&](unsigned idx) { return optional<name>(m_decls[idx].get_name()); });
}
optional<vm_decl> vm_state::get_decl(name const & n) const {
if (auto idx = m_fn_name2idx.find(n))
return optional<vm_decl>(m_decls[*idx]);
else
return optional<vm_decl>();
}
void display_vm_code(std::ostream & out, environment const & env, unsigned code_sz, vm_instr const * code) {
vm_decls const & ext = get_extension(env);
auto idx2name = [&](unsigned idx) {

View file

@ -427,6 +427,7 @@ class vm_state {
void invoke_builtin(vm_decl const & d);
void invoke_global(vm_decl const & d);
void run();
void execute(vm_instr const * code);
public:
vm_state(environment const & env);
@ -443,9 +444,23 @@ public:
return m_stack[m_bp + idx];
}
vm_obj const & top() const { return m_stack.back(); }
optional<vm_decl> get_decl(name const & n) const;
void invoke_fn(name const & fn);
void invoke_fn(unsigned fn_idx);
/** Given a stack of the form
v_n
...
v_1
(closure ...)
execute n function applications. */
void apply(unsigned n = 1);
void display_stack(std::ostream & out) const;
void display(std::ostream & out, vm_obj const & o) const;
};

View file

@ -52,9 +52,9 @@ static void get_line(vm_state & s) {
}
void initialize_vm_io() {
declare_vm_builtin(get_put_str_name(), 1, put_str);
declare_vm_builtin(get_put_nat_name(), 1, put_nat);
declare_vm_builtin(get_get_line_name(), 0, get_line);
declare_vm_builtin(get_put_str_name(), 2, put_str);
declare_vm_builtin(get_put_nat_name(), 2, put_nat);
declare_vm_builtin(get_get_line_name(), 1, get_line);
}
void finalize_vm_io() {

View file

@ -0,0 +1,8 @@
import system.IO
open bool unit
definition when (b : bool) (a : IO unit) : IO unit :=
cond b a (return star)
vm_eval when tt (put_str "hello\n")
vm_eval when ff (put_str "error\n")