From bf2d2b9feb49d0ae2b2d167a17476543b7cb499d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 25 May 2016 13:01:42 -0700 Subject: [PATCH] fix(library/vm,library/compiler,frontends/lean): IO monad support --- src/frontends/lean/builtin_cmds.cpp | 12 +++++- src/library/compiler/erase_irrelevant.cpp | 50 +++++++++++++++++------ src/library/constants.cpp | 4 ++ src/library/constants.h | 1 + src/library/constants.txt | 1 + src/library/vm/vm.cpp | 24 +++++++++++ src/library/vm/vm.h | 15 +++++++ src/library/vm/vm_io.cpp | 6 +-- tests/lean/run/whenIO.lean | 8 ++++ 9 files changed, 105 insertions(+), 16 deletions(-) create mode 100644 tests/lean/run/whenIO.lean diff --git a/src/frontends/lean/builtin_cmds.cpp b/src/frontends/lean/builtin_cmds.cpp index be039a002f..d41867b17a 100644 --- a/src/frontends/lean/builtin_cmds.cpp +++ b/src/frontends/lean/builtin_cmds.cpp @@ -668,7 +668,8 @@ static environment vm_eval_cmd(parser & p) { expr e; level_param_names ls; std::tie(e, ls) = parse_local_expr(p); type_checker tc(p.env()); - expr type = tc.infer(e); + expr type = tc.whnf(tc.infer(e)); + bool is_IO = is_constant(get_app_fn(type), get_IO_name()); environment new_env = p.env(); name main("_main"); auto cd = check(new_env, mk_definition(new_env, main, ls, type, e)); @@ -677,7 +678,16 @@ static environment vm_eval_cmd(parser & p) { vm_state s(new_env); { timeit timer(p.ios().get_diagnostic_stream(), "vm_eval time"); + if (is_IO) s.push(mk_vm_simple(0)); // push the "RealWorld" state s.invoke_fn(main); + if (is_IO) { + vm_decl d = *s.get_decl(main); + if (d.get_arity() == 0) { + /* main returned a closure, it did not process RealWorld yet. + So, we force the execution. */ + s.apply(); + } + } } vm_obj r = s.get(0); display(p.ios().get_regular_stream(), r); diff --git a/src/library/compiler/erase_irrelevant.cpp b/src/library/compiler/erase_irrelevant.cpp index 1be1c4912b..b227146834 100644 --- a/src/library/compiler/erase_irrelevant.cpp +++ b/src/library/compiler/erase_irrelevant.cpp @@ -284,16 +284,35 @@ class erase_irrelevant_fn : public compiler_step_visitor { expr visit_monad_bind(expr const & e, buffer const & args) { if (args.size() == 6 && is_constant(args[1], get_monadIO_name())) { - /* IO monad bind */ - expr v = visit(args[4]); - expr b = visit(args[5]); - /* We just convert it into a let-expression */ - if (is_lambda(b)) { - return mk_let(binding_name(b), mk_neutral_expr(), v, binding_body(b)); - } else { - lean_assert(closed(b)); - return mk_let(mk_fresh_name(), mk_neutral_expr(), v, mk_app(b, mk_var(0))); - } + /* Remark: morally the IO monad is + + IO a := State -> (a, State) + + and the (monad.bind v b) is + + fun S, let p := v S + in b (pr1 p) (pr2 p) + + However, the State is a fiction. It is a unit at runtime. + The IO a is a really just a thunk. + + IO a := unit -> a + + So, in this version, we have a simpler (monad.bind v b) + + bind v b := + fun s, let a := v unit in + b a unit + + We use a let-expression to make sure that `v unit` is not erased. + */ + expr v = visit(args[4]); + expr u = mk_neutral_expr(); + expr vu = mk_app(v, u); + expr b = visit(args[5]); + expr bau = beta_reduce(mk_app(b, mk_var(0), u)); + expr let = mk_let("a", u, vu, bau); + return mk_lambda("S", u, let); } else { return compiler_step_visitor::visit_app(e); } @@ -301,8 +320,15 @@ class erase_irrelevant_fn : public compiler_step_visitor { expr visit_monad_return(expr const & e, buffer const & args) { if (args.size() == 4 && is_constant(args[1], get_monadIO_name())) { - /* IO monad return */ - return visit(args[3]); + /* IO monad return + return v := fun s, v + + Remark: we do not return the state explicility. + */ + expr u = mk_neutral_expr(); + expr s = mk_var(0); + expr v = visit(args[3]); + return mk_lambda("S", u, v); } else { return compiler_step_visitor::visit_app(e); } diff --git a/src/library/constants.cpp b/src/library/constants.cpp index eee5d8dd46..5e8d60977a 100644 --- a/src/library/constants.cpp +++ b/src/library/constants.cpp @@ -90,6 +90,7 @@ name const * g_implies = nullptr; name const * g_implies_of_if_neg = nullptr; name const * g_implies_of_if_pos = nullptr; name const * g_implies_resolve = nullptr; +name const * g_IO = nullptr; name const * g_is_trunc_is_prop = nullptr; name const * g_is_trunc_is_prop_elim = nullptr; name const * g_is_trunc_is_set = nullptr; @@ -398,6 +399,7 @@ void initialize_constants() { g_implies_of_if_neg = new name{"implies_of_if_neg"}; g_implies_of_if_pos = new name{"implies_of_if_pos"}; g_implies_resolve = new name{"implies", "resolve"}; + g_IO = new name{"IO"}; g_is_trunc_is_prop = new name{"is_trunc", "is_prop"}; g_is_trunc_is_prop_elim = new name{"is_trunc", "is_prop", "elim"}; g_is_trunc_is_set = new name{"is_trunc", "is_set"}; @@ -707,6 +709,7 @@ void finalize_constants() { delete g_implies_of_if_neg; delete g_implies_of_if_pos; delete g_implies_resolve; + delete g_IO; delete g_is_trunc_is_prop; delete g_is_trunc_is_prop_elim; delete g_is_trunc_is_set; @@ -1015,6 +1018,7 @@ name const & get_implies_name() { return *g_implies; } name const & get_implies_of_if_neg_name() { return *g_implies_of_if_neg; } name const & get_implies_of_if_pos_name() { return *g_implies_of_if_pos; } name const & get_implies_resolve_name() { return *g_implies_resolve; } +name const & get_IO_name() { return *g_IO; } name const & get_is_trunc_is_prop_name() { return *g_is_trunc_is_prop; } name const & get_is_trunc_is_prop_elim_name() { return *g_is_trunc_is_prop_elim; } name const & get_is_trunc_is_set_name() { return *g_is_trunc_is_set; } diff --git a/src/library/constants.h b/src/library/constants.h index b8beee04a1..fc1bcbe0e0 100644 --- a/src/library/constants.h +++ b/src/library/constants.h @@ -92,6 +92,7 @@ name const & get_implies_name(); name const & get_implies_of_if_neg_name(); name const & get_implies_of_if_pos_name(); name const & get_implies_resolve_name(); +name const & get_IO_name(); name const & get_is_trunc_is_prop_name(); name const & get_is_trunc_is_prop_elim_name(); name const & get_is_trunc_is_set_name(); diff --git a/src/library/constants.txt b/src/library/constants.txt index 9dd71a8d7b..85c3f2722e 100644 --- a/src/library/constants.txt +++ b/src/library/constants.txt @@ -85,6 +85,7 @@ implies implies_of_if_neg implies_of_if_pos implies.resolve +IO is_trunc.is_prop is_trunc.is_prop.elim is_trunc.is_set diff --git a/src/library/vm/vm.cpp b/src/library/vm/vm.cpp index df9dca7b6a..db9650f346 100644 --- a/src/library/vm/vm.cpp +++ b/src/library/vm/vm.cpp @@ -1157,10 +1157,34 @@ void vm_state::invoke_fn(unsigned fn_idx) { run(); } +void vm_state::execute(vm_instr const * code) { + m_call_stack.emplace_back(m_code, m_fn_idx, 0, 0, m_bp); + m_code = code; + m_fn_idx = -1; + m_pc = 0; + m_bp = m_stack.size(); + run(); +} + +void vm_state::apply(unsigned n) { + buffer code; + for (unsigned i = 0; i < n; i++) + code.push_back(mk_apply_instr()); + code.push_back(mk_ret_instr()); + execute(code.data()); +} + void vm_state::display(std::ostream & out, vm_obj const & o) const { ::lean::display(out, o, [&](unsigned idx) { return optional(m_decls[idx].get_name()); }); } +optional vm_state::get_decl(name const & n) const { + if (auto idx = m_fn_name2idx.find(n)) + return optional(m_decls[*idx]); + else + return optional(); +} + void display_vm_code(std::ostream & out, environment const & env, unsigned code_sz, vm_instr const * code) { vm_decls const & ext = get_extension(env); auto idx2name = [&](unsigned idx) { diff --git a/src/library/vm/vm.h b/src/library/vm/vm.h index f63087eba0..c88d685d8d 100644 --- a/src/library/vm/vm.h +++ b/src/library/vm/vm.h @@ -427,6 +427,7 @@ class vm_state { void invoke_builtin(vm_decl const & d); void invoke_global(vm_decl const & d); void run(); + void execute(vm_instr const * code); public: vm_state(environment const & env); @@ -443,9 +444,23 @@ public: return m_stack[m_bp + idx]; } + vm_obj const & top() const { return m_stack.back(); } + + optional get_decl(name const & n) const; + void invoke_fn(name const & fn); void invoke_fn(unsigned fn_idx); + /** Given a stack of the form + + v_n + ... + v_1 + (closure ...) + + execute n function applications. */ + void apply(unsigned n = 1); + void display_stack(std::ostream & out) const; void display(std::ostream & out, vm_obj const & o) const; }; diff --git a/src/library/vm/vm_io.cpp b/src/library/vm/vm_io.cpp index 77f7794400..5c6c612175 100644 --- a/src/library/vm/vm_io.cpp +++ b/src/library/vm/vm_io.cpp @@ -52,9 +52,9 @@ static void get_line(vm_state & s) { } void initialize_vm_io() { - declare_vm_builtin(get_put_str_name(), 1, put_str); - declare_vm_builtin(get_put_nat_name(), 1, put_nat); - declare_vm_builtin(get_get_line_name(), 0, get_line); + declare_vm_builtin(get_put_str_name(), 2, put_str); + declare_vm_builtin(get_put_nat_name(), 2, put_nat); + declare_vm_builtin(get_get_line_name(), 1, get_line); } void finalize_vm_io() { diff --git a/tests/lean/run/whenIO.lean b/tests/lean/run/whenIO.lean new file mode 100644 index 0000000000..f99c5b2379 --- /dev/null +++ b/tests/lean/run/whenIO.lean @@ -0,0 +1,8 @@ +import system.IO +open bool unit + +definition when (b : bool) (a : IO unit) : IO unit := +cond b a (return star) + +vm_eval when tt (put_str "hello\n") +vm_eval when ff (put_str "error\n")