From 0a67679afb8c7f30d2b140dc176df245d91f0443 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 2 Sep 2013 12:24:29 -0700 Subject: [PATCH] Add natural numbers. Fix how coercions and overloads interact (switch to approach used in C++). Add notation for natural and integer arithmetic. Rename m and u universe variables to M and U. Signed-off-by: Leonardo de Moura --- src/frontends/lean/lean_elaborator.cpp | 72 ++++---- src/frontends/lean/lean_notation.cpp | 23 +++ src/frontends/lean/lean_parser.cpp | 37 ++-- src/frontends/lean/lean_parser.h | 4 +- src/frontends/lean/lean_scanner.cpp | 4 +- src/frontends/lean/lean_scanner.h | 2 +- src/kernel/arith/arith.cpp | 213 +++++++++++++++++++--- src/kernel/arith/arith.h | 32 ++++ src/kernel/builtin.cpp | 11 +- src/kernel/builtin.h | 3 - src/tests/frontends/lean/lean_scanner.cpp | 10 +- tests/lean/arith1.lean | 17 ++ tests/lean/arith1.lean.expected.out | 17 ++ tests/lean/tst11.lean.expected.out | 2 +- tests/lean/tst15.lean | 30 +-- tests/lean/tst15.lean.expected.out | 32 ++-- tests/lean/tst4.lean.expected.out | 2 +- tests/lean/tst7.lean | 2 +- 18 files changed, 384 insertions(+), 129 deletions(-) create mode 100644 tests/lean/arith1.lean create mode 100644 tests/lean/arith1.lean.expected.out diff --git a/src/frontends/lean/lean_elaborator.cpp b/src/frontends/lean/lean_elaborator.cpp index 234001211e..cba746a304 100644 --- a/src/frontends/lean/lean_elaborator.cpp +++ b/src/frontends/lean/lean_elaborator.cpp @@ -201,42 +201,50 @@ class elaborator::imp { buffer good_choices; unsigned num_choices = f_choices.size(); unsigned num_args = args.size(); - for (unsigned j = 0; j < num_choices; j++) { - expr f_t = f_choice_types[j]; - try { - unsigned i = 1; - for (; i < num_args; i++) { - f_t = check_pi(f_t, ctx, src, ctx); - expr expected = abst_domain(f_t); - expr given = types[i]; - if (!has_metavar(expected) && !has_metavar(given)) { - if (!is_convertible(expected, given, ctx) && - !m_frontend.get_coercion(given, expected)) - break; // failed to use this overload + for (unsigned round = 0; round < 2; round++) { + // In the first round we only select perfect matches without considering + // overloads. This is the same approach used in C++. + // If a perfect match does not exist, then we try again using coercions. + for (unsigned j = 0; j < num_choices; j++) { + expr f_t = f_choice_types[j]; + try { + unsigned i = 1; + for (; i < num_args; i++) { + f_t = check_pi(f_t, ctx, src, ctx); + expr expected = abst_domain(f_t); + expr given = types[i]; + if (!has_metavar(expected) && !has_metavar(given)) { + if (!is_convertible(expected, given, ctx) && + // remark, we only consider coercions in the second round + (round == 0 || !m_frontend.get_coercion(given, expected))) + break; // failed to use this overload + } + f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]); } - f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]); - } - if (i == num_args) { - if (good_choices.empty()) { - // first good choice - args[0] = f_choices[j]; - types[0] = f_choice_types[j]; + if (i == num_args) { + if (good_choices.empty()) { + // first good choice + args[0] = f_choices[j]; + types[0] = f_choice_types[j]; + } + good_choices.push_back(j); } - good_choices.push_back(j); + } catch (exception & ex) { + // candidate failed + // do nothing } - } catch (exception & ex) { - // candidate failed - // do nothing } - } - if (good_choices.size() == 0) { - // TODO add information to the exception - throw exception("none of the overloads are good"); - } else if (good_choices.size() == 1) { - // found overload - } else { - // TODO add information to the exception - throw exception("ambiguous overload"); + if (good_choices.size() == 0) { + // TODO add information to the exception + if (round == 1) + throw exception("none of the overloads are good"); + } else if (good_choices.size() == 1) { + // found overload + return; + } else { + // TODO add information to the exception + throw exception("ambiguous overload"); + } } } diff --git a/src/frontends/lean/lean_notation.cpp b/src/frontends/lean/lean_notation.cpp index 68d7f20178..98921fd2d6 100644 --- a/src/frontends/lean/lean_notation.cpp +++ b/src/frontends/lean/lean_notation.cpp @@ -7,6 +7,7 @@ Author: Leonardo de Moura #include "builtin.h" #include "basic_thms.h" #include "lean_frontend.h" +#include "arith.h" namespace lean { /** @@ -29,6 +30,28 @@ void init_builtin_notation(frontend & f) { f.add_infixr("<=>", 25, mk_iff_fn()); // "<=>" f.add_infixr("\u21D4", 25, mk_iff_fn()); // "⇔" + f.add_infixl("+", 65, mk_nat_add_fn()); + f.add_infixl("*", 70, mk_nat_mul_fn()); + f.add_infix("<=", 50, mk_nat_le_fn()); + f.add_infix("\u2264", 50, mk_nat_le_fn()); // ≤ + f.add_infix(">=", 50, mk_nat_ge_fn()); + f.add_infix("\u2265", 50, mk_nat_ge_fn()); // ≥ + f.add_infix("<", 50, mk_nat_lt_fn()); + f.add_infix(">", 50, mk_nat_gt_fn()); + + f.add_infixl("+", 65, mk_int_add_fn()); + f.add_infixl("-", 65, mk_int_sub_fn()); + f.add_infixl("*", 70, mk_int_mul_fn()); + f.add_infixl("/", 70, mk_int_div_fn()); + f.add_infix("<=", 50, mk_int_le_fn()); + f.add_infix("\u2264", 50, mk_int_le_fn()); // ≤ + f.add_infix(">=", 50, mk_int_ge_fn()); + f.add_infix("\u2265", 50, mk_int_ge_fn()); // ≥ + f.add_infix("<", 50, mk_int_lt_fn()); + f.add_infix(">", 50, mk_int_gt_fn()); + + f.add_coercion(mk_nat_to_int_fn()); + // implicit arguments for builtin axioms f.mark_implicit_arguments(mk_mp_fn(), {true, true, false, false}); f.mark_implicit_arguments(mk_discharge_fn(), {true, true, false}); diff --git a/src/frontends/lean/lean_parser.cpp b/src/frontends/lean/lean_parser.cpp index 9d32cfe1be..9cdc501e91 100644 --- a/src/frontends/lean/lean_parser.cpp +++ b/src/frontends/lean/lean_parser.cpp @@ -199,8 +199,8 @@ class parser::imp { bool curr_is_identifier() const { return curr() == scanner::token::Id; } /** \brief Return true iff the current token is a '_" */ bool curr_is_placeholder() const { return curr() == scanner::token::Placeholder; } - /** \brief Return true iff the current token is an integer */ - bool curr_is_int() const { return curr() == scanner::token::IntVal; } + /** \brief Return true iff the current token is a natural number */ + bool curr_is_nat() const { return curr() == scanner::token::NatVal; } /** \brief Return true iff the current token is a '(' */ bool curr_is_lparen() const { return curr() == scanner::token::LeftParen; } /** \brief Return true iff the current token is a '{' */ @@ -254,10 +254,11 @@ class parser::imp { m_builtins["\u22A4"] = True; m_builtins["\u22A5"] = False; m_builtins["Int"] = Int; + m_builtins["Nat"] = Nat; } unsigned parse_unsigned(char const * msg) { - lean_assert(curr_is_int()); + lean_assert(curr_is_nat()); mpz pval = curr_num().get_numerator(); if (!pval.is_unsigned_int()) { throw parser_error(msg, pos()); @@ -285,7 +286,7 @@ class parser::imp { auto p = pos(); next(); buffer lvls; - while (curr_is_identifier() || curr_is_int()) { + while (curr_is_identifier() || curr_is_nat()) { lvls.push_back(parse_level()); } if (lvls.size() < 2) @@ -318,7 +319,7 @@ class parser::imp { level parse_level_nud() { switch (curr()) { case scanner::token::Id: return parse_level_nud_id(); - case scanner::token::IntVal: return parse_level_nud_int(); + case scanner::token::NatVal: return parse_level_nud_int(); default: throw parser_error("invalid level expression", pos()); } @@ -853,10 +854,10 @@ class parser::imp { } } - /** \brief Parse an integer value. */ - expr parse_int() { + /** \brief Parse a natural number value. */ + expr parse_nat() { auto p = pos(); - expr r = save(mk_int_value(m_scanner.get_num_val().get_numerator()), p); + expr r = save(mk_nat_value(m_scanner.get_num_val().get_numerator()), p); next(); return r; } @@ -873,7 +874,7 @@ class parser::imp { expr parse_type() { auto p = pos(); next(); - if (curr_is_identifier() || curr_is_int()) { + if (curr_is_identifier() || curr_is_nat()) { return save(mk_type(parse_level()), p); } else { return Type(); @@ -899,7 +900,7 @@ class parser::imp { case scanner::token::Forall: return parse_forall(); case scanner::token::Exists: return parse_exists(); case scanner::token::Let: return parse_let(); - case scanner::token::IntVal: return parse_int(); + case scanner::token::NatVal: return parse_nat(); case scanner::token::DecimalVal: return parse_decimal(); case scanner::token::StringVal: return parse_string(); case scanner::token::Placeholder: return parse_placeholder(); @@ -927,7 +928,7 @@ class parser::imp { case scanner::token::Eq: return parse_eq(left); case scanner::token::Arrow: return parse_arrow(left); case scanner::token::LeftParen: return mk_app_left(left, parse_lparen()); - case scanner::token::IntVal: return mk_app_left(left, parse_int()); + case scanner::token::NatVal: return mk_app_left(left, parse_nat()); case scanner::token::DecimalVal: return mk_app_left(left, parse_decimal()); case scanner::token::StringVal: return mk_app_left(left, parse_string()); case scanner::token::Placeholder: return mk_app_left(left, parse_placeholder()); @@ -954,7 +955,7 @@ class parser::imp { } case scanner::token::Eq : return g_eq_precedence; case scanner::token::Arrow : return g_arrow_precedence; - case scanner::token::LeftParen: case scanner::token::IntVal: case scanner::token::DecimalVal: + case scanner::token::LeftParen: case scanner::token::NatVal: case scanner::token::DecimalVal: case scanner::token::StringVal: case scanner::token::Type: case scanner::token::Placeholder: return 1; default: @@ -1119,7 +1120,7 @@ class parser::imp { name opt_id = curr_name(); next(); if (opt_id == g_env_kwd) { - if (curr_is_int()) { + if (curr_is_nat()) { unsigned i = parse_unsigned("invalid argument, value does not fit in a machine integer"); auto end = m_frontend.end_objects(); auto beg = m_frontend.begin_objects(); @@ -1155,7 +1156,7 @@ class parser::imp { /** \brief Return the (optional) precedence of a user-defined operator. */ unsigned parse_precedence() { - if (curr_is_int()) { + if (curr_is_nat()) { return parse_unsigned("invalid operator definition, precedence does not fit in a machine integer"); } else { return 0; @@ -1296,7 +1297,7 @@ class parser::imp { m_frontend.set_option(id, curr_string()); next(); break; - case scanner::token::IntVal: + case scanner::token::NatVal: if (k != IntOption && k != UnsignedOption) throw parser_error("invalid option value, given option is not an integer", pos()); m_frontend.set_option(id, parse_unsigned("invalid option value, value does not fit in a machine integer")); @@ -1450,7 +1451,7 @@ class parser::imp { } public: - imp(frontend & fe, std::istream & in, bool use_exceptions, bool interactive): + imp(frontend const & fe, std::istream & in, bool use_exceptions, bool interactive): m_frontend(fe), m_scanner(in), m_elaborator(fe), @@ -1543,7 +1544,7 @@ public: } }; -parser::parser(frontend fe, std::istream & in, bool use_exceptions, bool interactive) { +parser::parser(frontend const & fe, std::istream & in, bool use_exceptions, bool interactive) { parser::imp::show_prompt(interactive, fe); m_ptr.reset(new imp(fe, in, use_exceptions, interactive)); } @@ -1563,7 +1564,7 @@ expr parser::parse_expr() { return m_ptr->parse_expr_main(); } -shell::shell(frontend & fe):m_frontend(fe) { +shell::shell(frontend const & fe):m_frontend(fe) { } shell::~shell() { diff --git a/src/frontends/lean/lean_parser.h b/src/frontends/lean/lean_parser.h index 150e578a1b..e9b2f6d01c 100644 --- a/src/frontends/lean/lean_parser.h +++ b/src/frontends/lean/lean_parser.h @@ -15,7 +15,7 @@ class parser { class imp; std::unique_ptr m_ptr; public: - parser(frontend fe, std::istream & in, bool use_exceptions = true, bool interactive = false); + parser(frontend const & fe, std::istream & in, bool use_exceptions = true, bool interactive = false); ~parser(); /** \brief Parse a sequence of commands */ @@ -34,7 +34,7 @@ class shell { frontend m_frontend; interruptable_ptr m_parser; public: - shell(frontend & fe); + shell(frontend const & fe); ~shell(); bool operator()(); diff --git a/src/frontends/lean/lean_scanner.cpp b/src/frontends/lean/lean_scanner.cpp index 501978fae4..7de1348706 100644 --- a/src/frontends/lean/lean_scanner.cpp +++ b/src/frontends/lean/lean_scanner.cpp @@ -302,7 +302,7 @@ scanner::token scanner::read_number() { } if (is_decimal) m_num_val /= q; - return is_decimal ? token::DecimalVal : token::IntVal; + return is_decimal ? token::DecimalVal : token::NatVal; } scanner::token scanner::read_string() { @@ -389,7 +389,7 @@ std::ostream & operator<<(std::ostream & out, scanner::token const & t) { case scanner::token::In: out << "in"; break; case scanner::token::Id: out << "Id"; break; case scanner::token::CommandId: out << "CId"; break; - case scanner::token::IntVal: out << "Int"; break; + case scanner::token::NatVal: out << "Nat"; break; case scanner::token::DecimalVal: out << "Dec"; break; case scanner::token::StringVal: out << "String"; break; case scanner::token::Eq: out << "="; break; diff --git a/src/frontends/lean/lean_scanner.h b/src/frontends/lean/lean_scanner.h index 5e95d87049..8a02555c1d 100644 --- a/src/frontends/lean/lean_scanner.h +++ b/src/frontends/lean/lean_scanner.h @@ -19,7 +19,7 @@ class scanner { public: enum class token { LeftParen, RightParen, LeftCurlyBracket, RightCurlyBracket, Colon, Comma, Period, Lambda, Pi, Arrow, - Let, In, Forall, Exists, Id, CommandId, IntVal, DecimalVal, StringVal, Eq, Assign, Type, Placeholder, Eof + Let, In, Forall, Exists, Id, CommandId, NatVal, DecimalVal, StringVal, Eq, Assign, Type, Placeholder, Eof }; protected: int m_spos; // position in the current line of the stream diff --git a/src/kernel/arith/arith.cpp b/src/kernel/arith/arith.cpp index a0bad8b7e2..35a7f096a3 100644 --- a/src/kernel/arith/arith.cpp +++ b/src/kernel/arith/arith.cpp @@ -10,16 +10,130 @@ Author: Leonardo de Moura #include "environment.h" namespace lean { - -class int_type_value : public value { +/** \brief Base class for Nat, Int and Real types */ +class num_type_value : public value { + name m_name; public: - virtual ~int_type_value() {} + num_type_value(char const * name):m_name(name) {} + virtual ~num_type_value() {} virtual expr get_type() const { return Type(); } virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { return false; } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; + +// ======================================= +// Natural numbers +class nat_type_value : public num_type_value { +public: + nat_type_value():num_type_value("Nat") {} + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } +}; +expr const Nat = mk_value(*(new nat_type_value())); +expr mk_nat_type() { return Nat; } + +class nat_value_value : public value { + mpz m_val; +public: + nat_value_value(mpz const & v):m_val(v) { lean_assert(v >= 0); } + virtual ~nat_value_value() {} + virtual expr get_type() const { return Nat; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { return false; } + virtual bool operator==(value const & other) const { + nat_value_value const * _other = dynamic_cast(&other); + return _other && _other->m_val == m_val; + } + virtual void display(std::ostream & out) const { out << m_val; } + virtual format pp() const { return format(m_val); } + virtual unsigned hash() const { return m_val.hash(); } + mpz const & get_num() const { return m_val; } +}; + +expr mk_nat_value(mpz const & v) { + return mk_value(*(new nat_value_value(v))); +} + +bool is_nat_value(expr const & e) { + return is_value(e) && dynamic_cast(&to_value(e)) != nullptr; +} + +mpz const & nat_value_numeral(expr const & e) { + lean_assert(is_nat_value(e)); + return static_cast(to_value(e)).get_num(); +} + +template +class nat_bin_op : public value { + expr m_type; + name m_name; +public: + nat_bin_op() { + m_type = Nat >> (Nat >> Nat); + m_name = name("Nat", Name); + } + virtual ~nat_bin_op() {} + virtual expr get_type() const { return m_type; } + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + if (num_args == 3 && is_nat_value(args[1]) && is_nat_value(args[2])) { + r = mk_nat_value(F()(nat_value_numeral(args[1]), nat_value_numeral(args[2]))); + return true; + } else { + return false; + } + } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; + +constexpr char nat_add_name[] = "add"; +struct nat_add_eval { mpz operator()(mpz const & v1, mpz const & v2) { return v1 + v2; }; }; +typedef nat_bin_op nat_add_value; +MK_BUILTIN(nat_add_fn, nat_add_value); + +constexpr char nat_mul_name[] = "mul"; +struct nat_mul_eval { mpz operator()(mpz const & v1, mpz const & v2) { return v1 * v2; }; }; +typedef nat_bin_op nat_mul_value; +MK_BUILTIN(nat_mul_fn, nat_mul_value); + +class nat_le_value : public value { + expr m_type; + name m_name; +public: + nat_le_value() { + m_type = Nat >> (Nat >> Bool); + m_name = name{"Nat", "le"}; + } + virtual ~nat_le_value() {} + virtual expr get_type() const { return m_type; } + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + if (num_args == 3 && is_nat_value(args[1]) && is_nat_value(args[2])) { + r = mk_bool_value(nat_value_numeral(args[1]) <= nat_value_numeral(args[2])); + return true; + } else { + return false; + } + } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; +MK_BUILTIN(nat_le_fn, nat_le_value); + +MK_CONSTANT(nat_ge_fn, name(name("Nat"), "ge")); +MK_CONSTANT(nat_lt_fn, name(name("Nat"), "lt")); +MK_CONSTANT(nat_gt_fn, name(name("Nat"), "gt")); +// ======================================= + +// ======================================= +// Integers +class int_type_value : public num_type_value { +public: + int_type_value():num_type_value("Int") {} virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } - virtual void display(std::ostream & out) const { out << "Int"; } - virtual format pp() const { return format("Int"); } - virtual unsigned hash() const { return 41; } }; expr const Int = mk_value(*(new int_type_value())); expr mk_int_type() { return Int; } @@ -54,12 +168,14 @@ mpz const & int_value_numeral(expr const & e) { return static_cast(to_value(e)).get_num(); } -template +template class int_bin_op : public value { - expr m_type; + expr m_type; + name m_name; public: int_bin_op() { m_type = Int >> (Int >> Int); + m_name = name("Int", Name); } virtual ~int_bin_op() {} virtual expr get_type() const { return m_type; } @@ -72,36 +188,45 @@ public: return false; } } - virtual void display(std::ostream & out) const { out << Name; } - virtual format pp() const { return format(Name); } - virtual unsigned hash() const { return Hash; } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } }; -constexpr char int_add_name[] = "+"; +constexpr char int_add_name[] = "add"; struct int_add_eval { mpz operator()(mpz const & v1, mpz const & v2) { return v1 + v2; }; }; -typedef int_bin_op int_add_value; +typedef int_bin_op int_add_value; MK_BUILTIN(int_add_fn, int_add_value); -constexpr char int_sub_name[] = "-"; +constexpr char int_sub_name[] = "sub"; struct int_sub_eval { mpz operator()(mpz const & v1, mpz const & v2) { return v1 - v2; }; }; -typedef int_bin_op int_sub_value; +typedef int_bin_op int_sub_value; MK_BUILTIN(int_sub_fn, int_sub_value); -constexpr char int_mul_name[] = "*"; +constexpr char int_mul_name[] = "mul"; struct int_mul_eval { mpz operator()(mpz const & v1, mpz const & v2) { return v1 * v2; }; }; -typedef int_bin_op int_mul_value; +typedef int_bin_op int_mul_value; MK_BUILTIN(int_mul_fn, int_mul_value); constexpr char int_div_name[] = "div"; -struct int_div_eval { mpz operator()(mpz const & v1, mpz const & v2) { return v1 / v2; }; }; -typedef int_bin_op int_div_value; +struct int_div_eval { + mpz operator()(mpz const & v1, mpz const & v2) { + if (v2.is_zero()) + return v2; + else + return v1 / v2; + }; +}; +typedef int_bin_op int_div_value; MK_BUILTIN(int_div_fn, int_div_value); class int_le_value : public value { expr m_type; + name m_name; public: int_le_value() { m_type = Int >> (Int >> Bool); + m_name = name{"Int", "le"}; } virtual ~int_le_value() {} virtual expr get_type() const { return m_type; } @@ -114,20 +239,54 @@ public: return false; } } - virtual void display(std::ostream & out) const { out << "Le"; } - virtual format pp() const { return format("Le"); } - virtual unsigned hash() const { return 67; } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } }; MK_BUILTIN(int_le_fn, int_le_value); +MK_CONSTANT(int_ge_fn, name(name("Int"), "ge")); +MK_CONSTANT(int_lt_fn, name(name("Int"), "lt")); +MK_CONSTANT(int_gt_fn, name(name("Int"), "gt")); -MK_CONSTANT(int_ge_fn, name(name("int"), "Ge")); -MK_CONSTANT(int_lt_fn, name(name("int"), "Lt")); -MK_CONSTANT(int_gt_fn, name(name("int"), "Gt")); +class nat_to_int_value : public value { + expr m_type; + name m_name; +public: + nat_to_int_value() { + m_type = Nat >> Int; + m_name = "nat_to_int"; + } + virtual ~nat_to_int_value() {} + virtual expr get_type() const { return m_type; } + virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } + virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { + if (num_args == 2 && is_nat_value(args[1])) { + r = mk_int_value(nat_value_numeral(args[1])); + return true; + } else { + return false; + } + } + virtual void display(std::ostream & out) const { out << m_name; } + virtual format pp() const { return format(m_name); } + virtual unsigned hash() const { return m_name.hash(); } +}; +MK_BUILTIN(nat_to_int_fn, nat_to_int_value); + +// ======================================= void add_int_theory(environment & env) { - expr p = Int >> (Int >> Bool); + expr p_ii = Int >> (Int >> Bool); + expr p_nn = Nat >> (Nat >> Bool); expr x = Const("x"); expr y = Const("y"); - env.add_definition(int_ge_fn_name, p, Fun({{x, Int}, {y, Int}}, iLe(y, x))); + + env.add_definition(nat_ge_fn_name, p_nn, Fun({{x, Nat}, {y, Nat}}, nLe(y, x))); + env.add_definition(nat_lt_fn_name, p_nn, Fun({{x, Nat}, {y, Nat}}, Not(nLe(y, x)))); + env.add_definition(nat_gt_fn_name, p_nn, Fun({{x, Nat}, {y, Nat}}, Not(nLe(x, y)))); + + env.add_definition(int_ge_fn_name, p_ii, Fun({{x, Int}, {y, Int}}, iLe(y, x))); + env.add_definition(int_lt_fn_name, p_ii, Fun({{x, Int}, {y, Int}}, Not(iLe(y, x)))); + env.add_definition(int_gt_fn_name, p_ii, Fun({{x, Int}, {y, Int}}, Not(iLe(x, y)))); } } diff --git a/src/kernel/arith/arith.h b/src/kernel/arith/arith.h index 1d7895bf33..ec32ff0cfa 100644 --- a/src/kernel/arith/arith.h +++ b/src/kernel/arith/arith.h @@ -11,6 +11,35 @@ Author: Leonardo de Moura #include "mpq.h" namespace lean { +expr mk_nat_type(); +extern expr const Nat; + +expr mk_nat_value(mpz const & v); +inline expr mk_nat_value(unsigned v) { return mk_nat_value(mpz(v)); } +inline expr nVal(unsigned v) { return mk_nat_value(v); } +bool is_nat_value(expr const & e); +mpz const & nat_value_numeral(expr const & e); + +expr mk_nat_add_fn(); +inline expr nAdd(expr const & e1, expr const & e2) { return mk_app(mk_nat_add_fn(), e1, e2); } + +expr mk_nat_mul_fn(); +inline expr nMul(expr const & e1, expr const & e2) { return mk_app(mk_nat_mul_fn(), e1, e2); } + +expr mk_nat_le_fn(); +inline expr nLe(expr const & e1, expr const & e2) { return mk_app(mk_nat_le_fn(), e1, e2); } + +expr mk_nat_ge_fn(); +inline expr nGe(expr const & e1, expr const & e2) { return mk_app(mk_nat_ge_fn(), e1, e2); } + +expr mk_nat_lt_fn(); +inline expr nLt(expr const & e1, expr const & e2) { return mk_app(mk_nat_lt_fn(), e1, e2); } + +expr mk_nat_gt_fn(); +inline expr nGt(expr const & e1, expr const & e2) { return mk_app(mk_nat_gt_fn(), e1, e2); } + +inline expr nIf(expr const & c, expr const & t, expr const & e) { return mk_if(Nat, c, t, e); } + expr mk_int_type(); extern expr const Int; @@ -46,6 +75,9 @@ inline expr iGt(expr const & e1, expr const & e2) { return mk_app(mk_int_gt_fn() inline expr iIf(expr const & c, expr const & t, expr const & e) { return mk_if(Int, c, t, e); } +expr mk_nat_to_int_fn(); +inline expr n2i(expr const & e) { return mk_app(mk_nat_to_int_fn(), e); } + class environment; void add_int_theory(environment & env); } diff --git a/src/kernel/builtin.cpp b/src/kernel/builtin.cpp index 308fb3b3fa..f0aa04cbe3 100644 --- a/src/kernel/builtin.cpp +++ b/src/kernel/builtin.cpp @@ -48,8 +48,8 @@ expr mk_bin_lop(expr const & op, expr const & unit, std::initializer_list // ======================================= // Bultin universe variables m and u -static level m_lvl(name("m")); -static level u_lvl(name("u")); +static level m_lvl(name("M")); +static level u_lvl(name("U")); expr const TypeM = Type(m_lvl); expr const TypeU = Type(u_lvl); // ======================================= @@ -57,6 +57,7 @@ expr const TypeU = Type(u_lvl); // ======================================= // Boolean Type static char const * g_Bool_str = "Bool"; +static name g_Bool_name(g_Bool_str); static format g_Bool_fmt(g_Bool_str); class bool_type_value : public value { public: @@ -64,9 +65,9 @@ public: virtual expr get_type() const { return Type(); } virtual bool normalize(unsigned num_args, expr const * args, expr & r) const { return false; } virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } - virtual void display(std::ostream & out) const { out << g_Bool_str; } + virtual void display(std::ostream & out) const { out << g_Bool_name; } virtual format pp() const { return g_Bool_fmt; } - virtual unsigned hash() const { return 17; } + virtual unsigned hash() const { return g_Bool_name.hash(); } }; expr const Bool = mk_value(*(new bool_type_value())); expr mk_bool_type() { return Bool; } @@ -155,7 +156,7 @@ public: virtual bool operator==(value const & other) const { return dynamic_cast(&other) != nullptr; } virtual void display(std::ostream & out) const { out << g_ite_name; } virtual format pp() const { return g_ite_fmt; } - virtual unsigned hash() const { return 27; } + virtual unsigned hash() const { return g_ite_name.hash(); } }; MK_BUILTIN(ite_fn, ite_fn_value); // ======================================= diff --git a/src/kernel/builtin.h b/src/kernel/builtin.h index 8b6721ff3b..0aaf3db3ff 100644 --- a/src/kernel/builtin.h +++ b/src/kernel/builtin.h @@ -160,9 +160,6 @@ expr mk_##Name() { \ static thread_local expr r = mk_value(*(new ClassName())); \ return r; \ } \ -bool is_##Name(expr const & e) { \ - return is_value(e) && dynamic_cast(&to_value(e)) != nullptr; \ -} /** \brief Helper macro for generating "defined" constants. diff --git a/src/tests/frontends/lean/lean_scanner.cpp b/src/tests/frontends/lean/lean_scanner.cpp index 65c3e3c8d2..b5cd99a18d 100644 --- a/src/tests/frontends/lean/lean_scanner.cpp +++ b/src/tests/frontends/lean/lean_scanner.cpp @@ -24,7 +24,7 @@ static void scan(char const * str, list const & cmds = list()) { std::cout << t; if (t == st::Id || t == st::CommandId) std::cout << "[" << s.get_name_val() << "]"; - else if (t == st::IntVal || t == st::DecimalVal) + else if (t == st::NatVal || t == st::DecimalVal) std::cout << "[" << s.get_num_val() << "]"; else if (t == st::StringVal) std::cout << "[\"" << escaped(s.get_str_val().c_str()) << "\"]"; @@ -78,11 +78,11 @@ static void tst2() { check("x+y", {st::Id, st::Id, st::Id}); check("(* testing *)", {}); check(" 2.31 ", {st::DecimalVal}); - check(" 333 22", {st::IntVal, st::IntVal}); + check(" 333 22", {st::NatVal, st::NatVal}); check("Int -> Int", {st::Id, st::Arrow, st::Id}); check("Int --> Int", {st::Id, st::Id, st::Id}); - check("x := 10", {st::Id, st::Assign, st::IntVal}); - check("(x+1):Int", {st::LeftParen, st::Id, st::Id, st::IntVal, st::RightParen, st::Colon, st::Id}); + check("x := 10", {st::Id, st::Assign, st::NatVal}); + check("(x+1):Int", {st::LeftParen, st::Id, st::Id, st::NatVal, st::RightParen, st::Colon, st::Id}); check("{x}", {st::LeftCurlyBracket, st::Id, st::RightCurlyBracket}); check("\u03BB \u03A0 \u2192", {st::Lambda, st::Pi, st::Arrow}); scan("++\u2295++x\u2296\u2296"); @@ -91,7 +91,7 @@ static void tst2() { check_name("x10", name("x10")); check_name("x::10", name(name("x"), 10)); check_name("x::10::bla::0", name(name(name(name("x"), 10), "bla"), 0u)); - check("0::1", {st::IntVal, st::Colon, st::Colon, st::IntVal}); + check("0::1", {st::NatVal, st::Colon, st::Colon, st::NatVal}); check_name("\u2296\u2296", name("\u2296\u2296")); try { scan("x::1000000000000000000"); diff --git a/tests/lean/arith1.lean b/tests/lean/arith1.lean new file mode 100644 index 0000000000..c6d69929cd --- /dev/null +++ b/tests/lean/arith1.lean @@ -0,0 +1,17 @@ +Check 10 + 20 +Check 10 +Check 10 - 20 +Eval 10 - 20 +Eval 15 + 10 - 20 +Check 15 + 10 - 20 +Variable x : Int +Variable n : Nat +Variable m : Nat +Show n + m +Show n + x + m +Set lean::pp::coercion true +Show n + x + m + 10 +Show x + n + m + 10 +Show n + m + 10 + x +Set lean::pp::notation false +Show n + m + 10 + x diff --git a/tests/lean/arith1.lean.expected.out b/tests/lean/arith1.lean.expected.out new file mode 100644 index 0000000000..8751c4bb4e --- /dev/null +++ b/tests/lean/arith1.lean.expected.out @@ -0,0 +1,17 @@ +Nat +Nat +Int +-10 +5 +Int + Assumed: x + Assumed: n + Assumed: m +n + m +n + x + m + Set option: lean::pp::coercion +(nat_to_int n) + x + (nat_to_int m) + (nat_to_int 10) +x + (nat_to_int n) + (nat_to_int m) + (nat_to_int 10) +(nat_to_int (n + m + 10)) + x + Set option: lean::pp::notation +Int::add (nat_to_int (Nat::add (Nat::add n m) 10)) x diff --git a/tests/lean/tst11.lean.expected.out b/tests/lean/tst11.lean.expected.out index 74b68aad25..81af0941a5 100644 --- a/tests/lean/tst11.lean.expected.out +++ b/tests/lean/tst11.lean.expected.out @@ -5,7 +5,7 @@ ⊤ Assumed: a a ⊕ a ⊕ a -Π (A : Type u) (a b : A) (P : A → Bool) (H1 : P a) (H2 : a = b), P b +Π (A : Type U) (a b : A) (P : A → Bool) (H1 : P a) (H2 : a = b), P b Proved: EM2 Π a : Bool, a ∨ ¬ a a ∨ ¬ a diff --git a/tests/lean/tst15.lean b/tests/lean/tst15.lean index 2c8eb7ba2c..78ff5ce3c0 100644 --- a/tests/lean/tst15.lean +++ b/tests/lean/tst15.lean @@ -1,21 +1,21 @@ Set pp::colors false -Variable x : Type max u+1+2 m+1 m+2 3 +Variable x : Type max U+1+2 M+1 M+2 3 Check x -Variable f : Type u+10 -> Type +Variable f : Type U+10 -> Type Check f Check f x Check Type 4 Check x -Check Type max u m -Show Type u+3 -Check Type u+3 -Check Type u ⊔ m -Check Type u ⊔ m ⊔ 3 -Show Type u+1 ⊔ m ⊔ 3 -Check Type u+1 ⊔ m ⊔ 3 -Show Type u -> Type 5 -Check Type u -> Type 5 -Check Type m ⊔ 3 -> Type u+5 -Show Type m ⊔ 3 -> Type u -> Type 5 -Check Type m ⊔ 3 -> Type u -> Type 5 -Show Type u +Check Type max U M +Show Type U+3 +Check Type U+3 +Check Type U ⊔ M +Check Type U ⊔ M ⊔ 3 +Show Type U+1 ⊔ M ⊔ 3 +Check Type U+1 ⊔ M ⊔ 3 +Show Type U -> Type 5 +Check Type U -> Type 5 +Check Type M ⊔ 3 -> Type U+5 +Show Type M ⊔ 3 -> Type U -> Type 5 +Check Type M ⊔ 3 -> Type U -> Type 5 +Show Type U diff --git a/tests/lean/tst15.lean.expected.out b/tests/lean/tst15.lean.expected.out index 4fe2e8eaf0..5ccd32ddae 100644 --- a/tests/lean/tst15.lean.expected.out +++ b/tests/lean/tst15.lean.expected.out @@ -1,21 +1,21 @@ Set option: pp::colors Assumed: x -Type u+3 ⊔ m+2 ⊔ 3 +Type U+3 ⊔ M+2 ⊔ 3 Assumed: f -Type u+10 → Type +Type U+10 → Type Type Type 5 -Type u+3 ⊔ m+2 ⊔ 3 -Type u+1 ⊔ m+1 -Type u+3 -Type u+4 -Type u+1 ⊔ m+1 -Type u+1 ⊔ m+1 ⊔ 4 -Type u+1 ⊔ m ⊔ 3 -Type u+2 ⊔ m+1 ⊔ 4 -Type u → Type 5 -Type u+1 ⊔ 6 -Type m+1 ⊔ 4 ⊔ u+6 -Type m ⊔ 3 → Type u → Type 5 -Type m+1 ⊔ 6 ⊔ u+1 -Type u +Type U+3 ⊔ M+2 ⊔ 3 +Type U+1 ⊔ M+1 +Type U+3 +Type U+4 +Type U+1 ⊔ M+1 +Type U+1 ⊔ M+1 ⊔ 4 +Type U+1 ⊔ M ⊔ 3 +Type U+2 ⊔ M+1 ⊔ 4 +Type U → Type 5 +Type U+1 ⊔ 6 +Type M+1 ⊔ 4 ⊔ U+6 +Type M ⊔ 3 → Type U → Type 5 +Type M+1 ⊔ 6 ⊔ U+1 +Type U diff --git a/tests/lean/tst4.lean.expected.out b/tests/lean/tst4.lean.expected.out index f6f3e57ebe..ac27ac3f91 100644 --- a/tests/lean/tst4.lean.expected.out +++ b/tests/lean/tst4.lean.expected.out @@ -9,7 +9,7 @@ f::explicit ((N → N) → N → N) (λ x : Set option: pp::colors EqNice::explicit N n1 n2 N -Π (A : Type u) (B : A → Type u) (f g : Π x : A, B x) (a b : A) (H1 : f = g) (H2 : a = b), (f a) = (g b) +Π (A : Type U) (B : A → Type U) (f g : Π x : A, B x) (a b : A) (H1 : f = g) (H2 : a = b), (f a) = (g b) f::explicit N n1 n2 Assumed: a Assumed: b diff --git a/tests/lean/tst7.lean b/tests/lean/tst7.lean index 5292774c45..f4fdde0234 100644 --- a/tests/lean/tst7.lean +++ b/tests/lean/tst7.lean @@ -4,7 +4,7 @@ Show fun (A B : Type) (a : _), f B a (* The following one should produce an error *) Show fun (A : Type) (a : _) (B : Type), f B a -Variable myeq : Pi (A : Type u), A -> A -> Bool +Variable myeq : Pi (A : Type U), A -> A -> Bool Show myeq _ (fun (A : Type) (a : _), a) (fun (B : Type) (b : B), b) Check myeq _ (fun (A : Type) (a : _), a) (fun (B : Type) (b : B), b)