diff --git a/src/library/compiler/erase_irrelevant.cpp b/src/library/compiler/erase_irrelevant.cpp index 0b57111cec..62bda6e510 100644 --- a/src/library/compiler/erase_irrelevant.cpp +++ b/src/library/compiler/erase_irrelevant.cpp @@ -282,9 +282,28 @@ class erase_irrelevant_fn : public compiler_step_visitor { return add_args(r, 3, args); } + expr visit_monad_bind(expr const & e, buffer const & args) { + if (args.size() == 6 && is_constant(args[1], get_monadIO_name())) { + /* IO bind */ + expr v = visit(args[4]); + expr b = visit(args[5]); + /* We just convert it into a let-expression */ + if (is_lambda(b)) { + return mk_let(binding_name(b), mk_neutral_expr(), v, binding_body(b)); + } else { + lean_assert(closed(b)); + return mk_let(mk_fresh_name(), mk_neutral_expr(), v, mk_app(b, mk_var(0))); + } + } else { + return compiler_step_visitor::visit_app(e); + } + } + virtual expr visit_app(expr const & e) override { if (is_comp_irrelevant(ctx(), e)) return *g_neutral_expr; + if (auto n = to_nat_value(ctx(), e)) + return *n; buffer args; expr const & fn = get_app_args(e, args); if (is_lambda(fn)) { @@ -301,6 +320,8 @@ class erase_irrelevant_fn : public compiler_step_visitor { return visit_quot_mk(args); } else if (n == get_subtype_rec_name()) { return visit_subtype_rec(args); + } else if (n == get_monad_bind_name()) { + return visit_monad_bind(e, args); } else if (is_cases_on_recursor(env(), n)) { return visit_cases_on(fn, args); } else if (inductive::is_elim_rule(env(), n)) { diff --git a/src/library/compiler/nat_value.cpp b/src/library/compiler/nat_value.cpp index d27d9950e5..5e101a32a6 100644 --- a/src/library/compiler/nat_value.cpp +++ b/src/library/compiler/nat_value.cpp @@ -81,14 +81,20 @@ mpz const & get_nat_value_value(expr const & e) { return static_cast(macro_def(e).raw())->get_value(); } +optional to_nat_value(type_context & ctx, expr const & e) { + if (optional v = to_num(e)) { + expr type = ctx.whnf(ctx.infer(e)); + if (is_constant(type, get_nat_name())) { + return some_expr(mk_nat_value(*v)); + } + } + return none_expr(); +} + class find_nat_values_fn : public compiler_step_visitor { expr visit_app(expr const & e) override { - if (optional v = to_num(e)) { - expr type = ctx().whnf(ctx().infer(e)); - if (is_constant(type, get_nat_name())) { - return mk_nat_value(*v); - } - } + if (auto v = to_nat_value(ctx(), e)) + return *v; return compiler_step_visitor::visit_app(e); } public: diff --git a/src/library/compiler/nat_value.h b/src/library/compiler/nat_value.h index 7423f3084e..5133fa4a2e 100644 --- a/src/library/compiler/nat_value.h +++ b/src/library/compiler/nat_value.h @@ -5,7 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #pragma once -#include "kernel/environment.h" +#include "library/type_context.h" namespace lean { /** \brief Replace nat numerals encoded using bit0, bit1, one with an auxiliary nat_value macro. @@ -18,6 +18,10 @@ bool is_nat_value(expr const & e); /** \brief Return the mpz stored in the nat_value macro. \pre is_nat_value(e) */ mpz const & get_nat_value_value(expr const & e); + +/** \brief If \c e encodes a nat numeral, then convert it into a nat_value macro */ +optional to_nat_value(type_context & ctx, expr const & e); + void initialize_nat_value(); void finalize_nat_value(); }