feat(library/compiler/erase_irrelevant): add support for IO monad.bind
This commit is contained in:
parent
e40c54013a
commit
63ed0c0056
3 changed files with 38 additions and 7 deletions
|
|
@ -282,9 +282,28 @@ class erase_irrelevant_fn : public compiler_step_visitor {
|
|||
return add_args(r, 3, args);
|
||||
}
|
||||
|
||||
expr visit_monad_bind(expr const & e, buffer<expr> const & args) {
|
||||
if (args.size() == 6 && is_constant(args[1], get_monadIO_name())) {
|
||||
/* IO bind */
|
||||
expr v = visit(args[4]);
|
||||
expr b = visit(args[5]);
|
||||
/* We just convert it into a let-expression */
|
||||
if (is_lambda(b)) {
|
||||
return mk_let(binding_name(b), mk_neutral_expr(), v, binding_body(b));
|
||||
} else {
|
||||
lean_assert(closed(b));
|
||||
return mk_let(mk_fresh_name(), mk_neutral_expr(), v, mk_app(b, mk_var(0)));
|
||||
}
|
||||
} else {
|
||||
return compiler_step_visitor::visit_app(e);
|
||||
}
|
||||
}
|
||||
|
||||
virtual expr visit_app(expr const & e) override {
|
||||
if (is_comp_irrelevant(ctx(), e))
|
||||
return *g_neutral_expr;
|
||||
if (auto n = to_nat_value(ctx(), e))
|
||||
return *n;
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
if (is_lambda(fn)) {
|
||||
|
|
@ -301,6 +320,8 @@ class erase_irrelevant_fn : public compiler_step_visitor {
|
|||
return visit_quot_mk(args);
|
||||
} else if (n == get_subtype_rec_name()) {
|
||||
return visit_subtype_rec(args);
|
||||
} else if (n == get_monad_bind_name()) {
|
||||
return visit_monad_bind(e, args);
|
||||
} else if (is_cases_on_recursor(env(), n)) {
|
||||
return visit_cases_on(fn, args);
|
||||
} else if (inductive::is_elim_rule(env(), n)) {
|
||||
|
|
|
|||
|
|
@ -81,14 +81,20 @@ mpz const & get_nat_value_value(expr const & e) {
|
|||
return static_cast<nat_value_macro const *>(macro_def(e).raw())->get_value();
|
||||
}
|
||||
|
||||
optional<expr> to_nat_value(type_context & ctx, expr const & e) {
|
||||
if (optional<mpz> v = to_num(e)) {
|
||||
expr type = ctx.whnf(ctx.infer(e));
|
||||
if (is_constant(type, get_nat_name())) {
|
||||
return some_expr(mk_nat_value(*v));
|
||||
}
|
||||
}
|
||||
return none_expr();
|
||||
}
|
||||
|
||||
class find_nat_values_fn : public compiler_step_visitor {
|
||||
expr visit_app(expr const & e) override {
|
||||
if (optional<mpz> v = to_num(e)) {
|
||||
expr type = ctx().whnf(ctx().infer(e));
|
||||
if (is_constant(type, get_nat_name())) {
|
||||
return mk_nat_value(*v);
|
||||
}
|
||||
}
|
||||
if (auto v = to_nat_value(ctx(), e))
|
||||
return *v;
|
||||
return compiler_step_visitor::visit_app(e);
|
||||
}
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include "kernel/environment.h"
|
||||
#include "library/type_context.h"
|
||||
|
||||
namespace lean {
|
||||
/** \brief Replace nat numerals encoded using bit0, bit1, one with an auxiliary nat_value macro.
|
||||
|
|
@ -18,6 +18,10 @@ bool is_nat_value(expr const & e);
|
|||
/** \brief Return the mpz stored in the nat_value macro.
|
||||
\pre is_nat_value(e) */
|
||||
mpz const & get_nat_value_value(expr const & e);
|
||||
|
||||
/** \brief If \c e encodes a nat numeral, then convert it into a nat_value macro */
|
||||
optional<expr> to_nat_value(type_context & ctx, expr const & e);
|
||||
|
||||
void initialize_nat_value();
|
||||
void finalize_nat_value();
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue