120 lines
4 KiB
C++
120 lines
4 KiB
C++
/*
|
|
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
|
|
Released under Apache 2.0 license as described in the file LICENSE.
|
|
|
|
Author: Leonardo de Moura
|
|
*/
|
|
#include "util/numerics/mpz.h"
|
|
#include "kernel/expr.h"
|
|
#include "library/constants.h"
|
|
#include "library/num.h"
|
|
#include "library/kernel_serializer.h"
|
|
#include "library/compiler/compiler_step_visitor.h"
|
|
|
|
namespace lean {
|
|
static expr * g_nat = nullptr;
|
|
static name * g_nat_macro = nullptr;
|
|
static std::string * g_nat_opcode = nullptr;
|
|
|
|
/** \brief Auxiliary macro used during compilation */
|
|
class nat_value_macro : public macro_definition_cell {
|
|
mpz m_value;
|
|
public:
|
|
nat_value_macro(mpz const & v):m_value(v) {}
|
|
virtual bool lt(macro_definition_cell const & d) const {
|
|
return m_value < static_cast<nat_value_macro const &>(d).m_value;
|
|
}
|
|
virtual name get_name() const { return *g_nat_macro; }
|
|
|
|
virtual expr check_type(expr const &, abstract_type_context &, bool) const {
|
|
return *g_nat;
|
|
}
|
|
|
|
static expr convert(mpz const & n, expr const & one, expr const & add) {
|
|
if (n == 0) {
|
|
return mk_constant(get_nat_zero_name());
|
|
} else if (n == 1) {
|
|
return one;
|
|
} else {
|
|
expr r = convert(n/2, one, add);
|
|
r = mk_app(add, r, r);
|
|
if (n % mpz(2) == 1)
|
|
return mk_app(add, r, one);
|
|
else
|
|
return r;
|
|
}
|
|
}
|
|
|
|
virtual optional<expr> expand(expr const &, abstract_type_context &) const {
|
|
expr one = mk_app(mk_constant(get_nat_succ_name()), mk_constant(get_nat_zero_name()));
|
|
expr add = mk_constant(get_nat_add_name());
|
|
expr r = convert(m_value, one, add);
|
|
return optional<expr>(r);
|
|
}
|
|
|
|
virtual unsigned trust_level() const { return 0; }
|
|
virtual bool operator==(macro_definition_cell const & other) const {
|
|
nat_value_macro const * other_ptr = dynamic_cast<nat_value_macro const *>(&other);
|
|
return other_ptr && m_value == other_ptr->m_value;
|
|
}
|
|
virtual void display(std::ostream & out) const {
|
|
out << m_value;
|
|
}
|
|
virtual bool is_atomic_pp(bool, bool) const { return true; }
|
|
virtual unsigned hash() const { return m_value.hash(); }
|
|
virtual void write(serializer & s) const {
|
|
s << *g_nat_opcode << m_value;
|
|
}
|
|
mpz const & get_value() const { return m_value; }
|
|
};
|
|
|
|
expr mk_nat_value(mpz const & v) {
|
|
return mk_macro(macro_definition(new nat_value_macro(v)));
|
|
}
|
|
|
|
bool is_nat_value(expr const & e) {
|
|
return is_macro(e) && dynamic_cast<nat_value_macro const *>(macro_def(e).raw()) != nullptr;
|
|
}
|
|
|
|
mpz const & get_nat_value_value(expr const & e) {
|
|
lean_assert(is_nat_value(e));
|
|
return static_cast<nat_value_macro const *>(macro_def(e).raw())->get_value();
|
|
}
|
|
|
|
class find_nat_values_fn : public compiler_step_visitor {
|
|
expr visit_app(expr const & e) override {
|
|
if (optional<mpz> v = to_num(e)) {
|
|
expr type = ctx().whnf(ctx().infer(e));
|
|
if (is_constant(type, get_nat_name())) {
|
|
return mk_nat_value(*v);
|
|
}
|
|
}
|
|
return compiler_step_visitor::visit_app(e);
|
|
}
|
|
public:
|
|
find_nat_values_fn(environment const & env):compiler_step_visitor(env) {}
|
|
};
|
|
|
|
expr find_nat_values(environment const & env, expr const & e) {
|
|
return find_nat_values_fn(env)(e);
|
|
}
|
|
|
|
void initialize_nat_value() {
|
|
g_nat_macro = new name("nat_value_macro");
|
|
g_nat = new expr(Const(get_nat_name()));
|
|
g_nat_opcode = new std::string("CNatM");
|
|
register_macro_deserializer(*g_nat_opcode,
|
|
[](deserializer & d, unsigned num, expr const *) {
|
|
if (num != 0)
|
|
throw corrupted_stream_exception();
|
|
mpz v = read_mpz(d);
|
|
return mk_nat_value(v);
|
|
});
|
|
}
|
|
|
|
void finalize_nat_value() {
|
|
delete g_nat_opcode;
|
|
delete g_nat_macro;
|
|
delete g_nat;
|
|
}
|
|
}
|