fix(library/vm,library/compiler,frontends/lean): IO monad support
This commit is contained in:
parent
6d37c26b5d
commit
bf2d2b9feb
9 changed files with 105 additions and 16 deletions
|
|
@ -668,7 +668,8 @@ static environment vm_eval_cmd(parser & p) {
|
|||
expr e; level_param_names ls;
|
||||
std::tie(e, ls) = parse_local_expr(p);
|
||||
type_checker tc(p.env());
|
||||
expr type = tc.infer(e);
|
||||
expr type = tc.whnf(tc.infer(e));
|
||||
bool is_IO = is_constant(get_app_fn(type), get_IO_name());
|
||||
environment new_env = p.env();
|
||||
name main("_main");
|
||||
auto cd = check(new_env, mk_definition(new_env, main, ls, type, e));
|
||||
|
|
@ -677,7 +678,16 @@ static environment vm_eval_cmd(parser & p) {
|
|||
vm_state s(new_env);
|
||||
{
|
||||
timeit timer(p.ios().get_diagnostic_stream(), "vm_eval time");
|
||||
if (is_IO) s.push(mk_vm_simple(0)); // push the "RealWorld" state
|
||||
s.invoke_fn(main);
|
||||
if (is_IO) {
|
||||
vm_decl d = *s.get_decl(main);
|
||||
if (d.get_arity() == 0) {
|
||||
/* main returned a closure, it did not process RealWorld yet.
|
||||
So, we force the execution. */
|
||||
s.apply();
|
||||
}
|
||||
}
|
||||
}
|
||||
vm_obj r = s.get(0);
|
||||
display(p.ios().get_regular_stream(), r);
|
||||
|
|
|
|||
|
|
@ -284,16 +284,35 @@ class erase_irrelevant_fn : public compiler_step_visitor {
|
|||
|
||||
expr visit_monad_bind(expr const & e, buffer<expr> const & args) {
|
||||
if (args.size() == 6 && is_constant(args[1], get_monadIO_name())) {
|
||||
/* IO monad bind */
|
||||
expr v = visit(args[4]);
|
||||
expr b = visit(args[5]);
|
||||
/* We just convert it into a let-expression */
|
||||
if (is_lambda(b)) {
|
||||
return mk_let(binding_name(b), mk_neutral_expr(), v, binding_body(b));
|
||||
} else {
|
||||
lean_assert(closed(b));
|
||||
return mk_let(mk_fresh_name(), mk_neutral_expr(), v, mk_app(b, mk_var(0)));
|
||||
}
|
||||
/* Remark: morally the IO monad is
|
||||
|
||||
IO a := State -> (a, State)
|
||||
|
||||
and the (monad.bind v b) is
|
||||
|
||||
fun S, let p := v S
|
||||
in b (pr1 p) (pr2 p)
|
||||
|
||||
However, the State is a fiction. It is a unit at runtime.
|
||||
The IO a is a really just a thunk.
|
||||
|
||||
IO a := unit -> a
|
||||
|
||||
So, in this version, we have a simpler (monad.bind v b)
|
||||
|
||||
bind v b :=
|
||||
fun s, let a := v unit in
|
||||
b a unit
|
||||
|
||||
We use a let-expression to make sure that `v unit` is not erased.
|
||||
*/
|
||||
expr v = visit(args[4]);
|
||||
expr u = mk_neutral_expr();
|
||||
expr vu = mk_app(v, u);
|
||||
expr b = visit(args[5]);
|
||||
expr bau = beta_reduce(mk_app(b, mk_var(0), u));
|
||||
expr let = mk_let("a", u, vu, bau);
|
||||
return mk_lambda("S", u, let);
|
||||
} else {
|
||||
return compiler_step_visitor::visit_app(e);
|
||||
}
|
||||
|
|
@ -301,8 +320,15 @@ class erase_irrelevant_fn : public compiler_step_visitor {
|
|||
|
||||
expr visit_monad_return(expr const & e, buffer<expr> const & args) {
|
||||
if (args.size() == 4 && is_constant(args[1], get_monadIO_name())) {
|
||||
/* IO monad return */
|
||||
return visit(args[3]);
|
||||
/* IO monad return
|
||||
return v := fun s, v
|
||||
|
||||
Remark: we do not return the state explicility.
|
||||
*/
|
||||
expr u = mk_neutral_expr();
|
||||
expr s = mk_var(0);
|
||||
expr v = visit(args[3]);
|
||||
return mk_lambda("S", u, v);
|
||||
} else {
|
||||
return compiler_step_visitor::visit_app(e);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ name const * g_implies = nullptr;
|
|||
name const * g_implies_of_if_neg = nullptr;
|
||||
name const * g_implies_of_if_pos = nullptr;
|
||||
name const * g_implies_resolve = nullptr;
|
||||
name const * g_IO = nullptr;
|
||||
name const * g_is_trunc_is_prop = nullptr;
|
||||
name const * g_is_trunc_is_prop_elim = nullptr;
|
||||
name const * g_is_trunc_is_set = nullptr;
|
||||
|
|
@ -398,6 +399,7 @@ void initialize_constants() {
|
|||
g_implies_of_if_neg = new name{"implies_of_if_neg"};
|
||||
g_implies_of_if_pos = new name{"implies_of_if_pos"};
|
||||
g_implies_resolve = new name{"implies", "resolve"};
|
||||
g_IO = new name{"IO"};
|
||||
g_is_trunc_is_prop = new name{"is_trunc", "is_prop"};
|
||||
g_is_trunc_is_prop_elim = new name{"is_trunc", "is_prop", "elim"};
|
||||
g_is_trunc_is_set = new name{"is_trunc", "is_set"};
|
||||
|
|
@ -707,6 +709,7 @@ void finalize_constants() {
|
|||
delete g_implies_of_if_neg;
|
||||
delete g_implies_of_if_pos;
|
||||
delete g_implies_resolve;
|
||||
delete g_IO;
|
||||
delete g_is_trunc_is_prop;
|
||||
delete g_is_trunc_is_prop_elim;
|
||||
delete g_is_trunc_is_set;
|
||||
|
|
@ -1015,6 +1018,7 @@ name const & get_implies_name() { return *g_implies; }
|
|||
name const & get_implies_of_if_neg_name() { return *g_implies_of_if_neg; }
|
||||
name const & get_implies_of_if_pos_name() { return *g_implies_of_if_pos; }
|
||||
name const & get_implies_resolve_name() { return *g_implies_resolve; }
|
||||
name const & get_IO_name() { return *g_IO; }
|
||||
name const & get_is_trunc_is_prop_name() { return *g_is_trunc_is_prop; }
|
||||
name const & get_is_trunc_is_prop_elim_name() { return *g_is_trunc_is_prop_elim; }
|
||||
name const & get_is_trunc_is_set_name() { return *g_is_trunc_is_set; }
|
||||
|
|
|
|||
|
|
@ -92,6 +92,7 @@ name const & get_implies_name();
|
|||
name const & get_implies_of_if_neg_name();
|
||||
name const & get_implies_of_if_pos_name();
|
||||
name const & get_implies_resolve_name();
|
||||
name const & get_IO_name();
|
||||
name const & get_is_trunc_is_prop_name();
|
||||
name const & get_is_trunc_is_prop_elim_name();
|
||||
name const & get_is_trunc_is_set_name();
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ implies
|
|||
implies_of_if_neg
|
||||
implies_of_if_pos
|
||||
implies.resolve
|
||||
IO
|
||||
is_trunc.is_prop
|
||||
is_trunc.is_prop.elim
|
||||
is_trunc.is_set
|
||||
|
|
|
|||
|
|
@ -1157,10 +1157,34 @@ void vm_state::invoke_fn(unsigned fn_idx) {
|
|||
run();
|
||||
}
|
||||
|
||||
void vm_state::execute(vm_instr const * code) {
|
||||
m_call_stack.emplace_back(m_code, m_fn_idx, 0, 0, m_bp);
|
||||
m_code = code;
|
||||
m_fn_idx = -1;
|
||||
m_pc = 0;
|
||||
m_bp = m_stack.size();
|
||||
run();
|
||||
}
|
||||
|
||||
void vm_state::apply(unsigned n) {
|
||||
buffer<vm_instr> code;
|
||||
for (unsigned i = 0; i < n; i++)
|
||||
code.push_back(mk_apply_instr());
|
||||
code.push_back(mk_ret_instr());
|
||||
execute(code.data());
|
||||
}
|
||||
|
||||
void vm_state::display(std::ostream & out, vm_obj const & o) const {
|
||||
::lean::display(out, o, [&](unsigned idx) { return optional<name>(m_decls[idx].get_name()); });
|
||||
}
|
||||
|
||||
optional<vm_decl> vm_state::get_decl(name const & n) const {
|
||||
if (auto idx = m_fn_name2idx.find(n))
|
||||
return optional<vm_decl>(m_decls[*idx]);
|
||||
else
|
||||
return optional<vm_decl>();
|
||||
}
|
||||
|
||||
void display_vm_code(std::ostream & out, environment const & env, unsigned code_sz, vm_instr const * code) {
|
||||
vm_decls const & ext = get_extension(env);
|
||||
auto idx2name = [&](unsigned idx) {
|
||||
|
|
|
|||
|
|
@ -427,6 +427,7 @@ class vm_state {
|
|||
void invoke_builtin(vm_decl const & d);
|
||||
void invoke_global(vm_decl const & d);
|
||||
void run();
|
||||
void execute(vm_instr const * code);
|
||||
|
||||
public:
|
||||
vm_state(environment const & env);
|
||||
|
|
@ -443,9 +444,23 @@ public:
|
|||
return m_stack[m_bp + idx];
|
||||
}
|
||||
|
||||
vm_obj const & top() const { return m_stack.back(); }
|
||||
|
||||
optional<vm_decl> get_decl(name const & n) const;
|
||||
|
||||
void invoke_fn(name const & fn);
|
||||
void invoke_fn(unsigned fn_idx);
|
||||
|
||||
/** Given a stack of the form
|
||||
|
||||
v_n
|
||||
...
|
||||
v_1
|
||||
(closure ...)
|
||||
|
||||
execute n function applications. */
|
||||
void apply(unsigned n = 1);
|
||||
|
||||
void display_stack(std::ostream & out) const;
|
||||
void display(std::ostream & out, vm_obj const & o) const;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -52,9 +52,9 @@ static void get_line(vm_state & s) {
|
|||
}
|
||||
|
||||
void initialize_vm_io() {
|
||||
declare_vm_builtin(get_put_str_name(), 1, put_str);
|
||||
declare_vm_builtin(get_put_nat_name(), 1, put_nat);
|
||||
declare_vm_builtin(get_get_line_name(), 0, get_line);
|
||||
declare_vm_builtin(get_put_str_name(), 2, put_str);
|
||||
declare_vm_builtin(get_put_nat_name(), 2, put_nat);
|
||||
declare_vm_builtin(get_get_line_name(), 1, get_line);
|
||||
}
|
||||
|
||||
void finalize_vm_io() {
|
||||
|
|
|
|||
8
tests/lean/run/whenIO.lean
Normal file
8
tests/lean/run/whenIO.lean
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import system.IO
|
||||
open bool unit
|
||||
|
||||
definition when (b : bool) (a : IO unit) : IO unit :=
|
||||
cond b a (return star)
|
||||
|
||||
vm_eval when tt (put_str "hello\n")
|
||||
vm_eval when ff (put_str "error\n")
|
||||
Loading…
Add table
Reference in a new issue