feat(library/compiler/erase_irrelevant): add support for IO monad.bind

This commit is contained in:
Leonardo de Moura 2016-05-24 18:14:39 -07:00
parent e40c54013a
commit 63ed0c0056
3 changed files with 38 additions and 7 deletions

View file

@ -282,9 +282,28 @@ class erase_irrelevant_fn : public compiler_step_visitor {
return add_args(r, 3, args);
}
expr visit_monad_bind(expr const & e, buffer<expr> const & args) {
if (args.size() == 6 && is_constant(args[1], get_monadIO_name())) {
/* IO bind */
expr v = visit(args[4]);
expr b = visit(args[5]);
/* We just convert it into a let-expression */
if (is_lambda(b)) {
return mk_let(binding_name(b), mk_neutral_expr(), v, binding_body(b));
} else {
lean_assert(closed(b));
return mk_let(mk_fresh_name(), mk_neutral_expr(), v, mk_app(b, mk_var(0)));
}
} else {
return compiler_step_visitor::visit_app(e);
}
}
virtual expr visit_app(expr const & e) override {
if (is_comp_irrelevant(ctx(), e))
return *g_neutral_expr;
if (auto n = to_nat_value(ctx(), e))
return *n;
buffer<expr> args;
expr const & fn = get_app_args(e, args);
if (is_lambda(fn)) {
@ -301,6 +320,8 @@ class erase_irrelevant_fn : public compiler_step_visitor {
return visit_quot_mk(args);
} else if (n == get_subtype_rec_name()) {
return visit_subtype_rec(args);
} else if (n == get_monad_bind_name()) {
return visit_monad_bind(e, args);
} else if (is_cases_on_recursor(env(), n)) {
return visit_cases_on(fn, args);
} else if (inductive::is_elim_rule(env(), n)) {

View file

@ -81,14 +81,20 @@ mpz const & get_nat_value_value(expr const & e) {
return static_cast<nat_value_macro const *>(macro_def(e).raw())->get_value();
}
optional<expr> to_nat_value(type_context & ctx, expr const & e) {
if (optional<mpz> v = to_num(e)) {
expr type = ctx.whnf(ctx.infer(e));
if (is_constant(type, get_nat_name())) {
return some_expr(mk_nat_value(*v));
}
}
return none_expr();
}
class find_nat_values_fn : public compiler_step_visitor {
expr visit_app(expr const & e) override {
if (optional<mpz> v = to_num(e)) {
expr type = ctx().whnf(ctx().infer(e));
if (is_constant(type, get_nat_name())) {
return mk_nat_value(*v);
}
}
if (auto v = to_nat_value(ctx(), e))
return *v;
return compiler_step_visitor::visit_app(e);
}
public:

View file

@ -5,7 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "kernel/environment.h"
#include "library/type_context.h"
namespace lean {
/** \brief Replace nat numerals encoded using bit0, bit1, one with an auxiliary nat_value macro.
@ -18,6 +18,10 @@ bool is_nat_value(expr const & e);
/** \brief Return the mpz stored in the nat_value macro.
\pre is_nat_value(e) */
mpz const & get_nat_value_value(expr const & e);
/** \brief If \c e encodes a nat numeral, then convert it into a nat_value macro */
optional<expr> to_nat_value(type_context & ctx, expr const & e);
void initialize_nat_value();
void finalize_nat_value();
}