diff --git a/src/frontends/lean/lean_parser.cpp b/src/frontends/lean/lean_parser.cpp
index bebb45edfa..e0259d568a 100644
--- a/src/frontends/lean/lean_parser.cpp
+++ b/src/frontends/lean/lean_parser.cpp
@@ -377,9 +377,8 @@ class parser::imp {
/**
\brief Return the function associated with the given operator.
- If the operator has been overloaded, it returns an expression
- of the form (overload f_k ... (overload f_2 f_1) ...)
- where f_i's are different options.
+ If the operator has been overloaded, it returns a choice expression
+ of the form (choice f_1 f_2 ... f_k) where f_i's are different options.
After we finish parsing, the procedure #elaborate will
resolve/decide which f_i should be used.
*/
@@ -389,9 +388,15 @@ class parser::imp {
auto it = fs.begin();
expr r = *it;
++it;
- for (; it != fs.end(); ++it)
- r = mk_app(mk_overload_marker(), *it, r);
- return r;
+ if (it == fs.end()) {
+ return r;
+ } else {
+ buffer alternatives;
+ alternatives.push_back(r);
+ for (; it != fs.end(); ++it)
+ alternatives.push_back(*it);
+ return mk_choice(alternatives.size(), alternatives.data());
+ }
}
/**
diff --git a/src/library/elaborator.cpp b/src/library/elaborator.cpp
index 0398c4b3b6..37481be270 100644
--- a/src/library/elaborator.cpp
+++ b/src/library/elaborator.cpp
@@ -17,17 +17,28 @@ Author: Leonardo de Moura
#include "elaborator_exception.h"
namespace lean {
-static name g_overload_name(name(name(name(0u), "library"), "overload"));
-static expr g_overload = mk_constant(g_overload_name);
+static name g_choice_name(name(name(name(0u), "library"), "choice"));
+static expr g_choice = mk_constant(g_choice_name);
static format g_assignment_fmt = format(":=");
static format g_unification_fmt = format("\u2248");
-bool is_overload_marker(expr const & e) {
- return e == g_overload;
+expr mk_choice(unsigned num_fs, expr const * fs) {
+ lean_assert(num_fs >= 2);
+ return mk_eq(g_choice, mk_app(num_fs, fs));
}
-expr mk_overload_marker() {
- return g_overload;
+bool is_choice(expr const & e) {
+ return is_eq(e) && eq_lhs(e) == g_choice;
+}
+
+unsigned get_num_choices(expr const & e) {
+ lean_assert(is_choice(e));
+ return num_args(eq_rhs(e));
+}
+
+expr const & get_choice(expr const & e, unsigned i) {
+ lean_assert(is_choice(e));
+ return arg(eq_rhs(e), i);
}
class elaborator::imp {
@@ -82,13 +93,22 @@ class elaborator::imp {
volatile bool m_interrupted;
+ expr mk_metavar(context const & ctx) {
+ unsigned midx = m_metavars.size();
+ expr r = ::lean::mk_metavar(midx);
+ m_metavars.push_back(metavar_info());
+ m_metavars[midx].m_mvar = r;
+ m_metavars[midx].m_ctx = ctx;
+ return r;
+ }
+
expr metavar_type(expr const & m) {
lean_assert(is_metavar(m));
unsigned midx = metavar_idx(m);
if (m_metavars[midx].m_type) {
return m_metavars[midx].m_type;
} else {
- expr t = mk_metavar();
+ expr t = mk_metavar(m_metavars[midx].m_ctx);
m_metavars[midx].m_type = t;
return t;
}
@@ -163,67 +183,139 @@ class elaborator::imp {
}
}
- expr infer(expr const & e, context const & ctx) {
+ typedef std::pair expr_pair;
+
+ /**
+ \brief Traverse the expression \c e, and compute
+
+ 1- A new expression that does not contain choice expressions,
+ coercions have been added when appropriate, and placeholders
+ have been replaced with metavariables.
+
+ 2- The type of \c e.
+
+ It also populates m_constraints with a set of constraints that
+ need to be solved to infer the value of the metavariables.
+ */
+ expr_pair process(expr const & e, context const & ctx) {
check_interrupted(m_interrupted);
switch (e.kind()) {
case expr_kind::Constant:
- if (is_metavar(e)) {
- unsigned midx = metavar_idx(e);
- if (!(m_metavars[midx].m_ctx)) {
- lean_assert(!(m_metavars[midx].m_mvar));
- m_metavars[midx].m_mvar = e;
- m_metavars[midx].m_ctx = ctx;
- }
- return metavar_type(e);
+ if (is_placeholder(e)) {
+ expr m = mk_metavar(ctx);
+ m_trace[m] = e;
+ return expr_pair(m, metavar_type(m));
+ } else if (is_metavar(e)) {
+ return expr_pair(e, metavar_type(e));
} else {
- return m_env.get_object(const_name(e)).get_type();
+ return expr_pair(e, m_env.get_object(const_name(e)).get_type());
}
case expr_kind::Var:
- return lookup(ctx, var_idx(e));
+ return expr_pair(e, lookup(ctx, var_idx(e)));
case expr_kind::Type:
- return mk_type(ty_level(e) + 1);
+ return expr_pair(e, mk_type(ty_level(e) + 1));
case expr_kind::Value:
- return to_value(e).get_type();
+ return expr_pair(e, to_value(e).get_type());
case expr_kind::App: {
+ buffer args;
buffer types;
+ buffer f_choices;
+ buffer f_choice_types;
unsigned num = num_args(e);
- for (unsigned i = 0; i < num; i++) {
- types.push_back(infer(arg(e,i), ctx));
- }
- // TODO: handle overload in args[0]
- expr f_t = types[0];
- if (!f_t) {
+ unsigned i = 0;
+ bool modified = false;
+ expr const & f = arg(e, 0);
+ if (is_metavar(f)) {
throw invalid_placeholder_exception(*m_owner, ctx, e);
+ } else if (is_choice(f)) {
+ unsigned num_alts = get_num_choices(f);
+ for (unsigned j = 0; j < num_alts; j++) {
+ auto p = process(get_choice(f, j), ctx);
+ f_choices.push_back(p.first);
+ f_choice_types.push_back(p.second);
+ }
+ args.push_back(expr()); // placeholder
+ types.push_back(expr()); // placeholder
+ modified = true;
+ i++;
}
+ for (; i < num; i++) {
+ expr const & a_i = arg(e, i);
+ auto p = process(a_i, ctx);
+ if (!is_eqp(p.first, a_i))
+ modified = true;
+ args.push_back(p.first);
+ types.push_back(p.second);
+ }
+ // TODO: choose an f from f_choices
+ expr f_t = types[0];
for (unsigned i = 1; i < num; i++) {
f_t = check_pi(f_t, ctx, e, ctx);
if (m_add_constraints)
add_constraint(abst_domain(f_t), types[i], ctx, e, i);
- f_t = instantiate_free_var_mmv(abst_body(f_t), 0, arg(e, i));
+ f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]);
+ }
+ if (modified) {
+ expr new_e = mk_app(args.size(), args.data());
+ m_trace[new_e] = e;
+ return expr_pair(new_e, f_t);
+ } else {
+ return expr_pair(e, f_t);
}
- return f_t;
}
case expr_kind::Eq: {
- infer(eq_lhs(e), ctx);
- infer(eq_rhs(e), ctx);
- return mk_bool_type();
+ auto lhs_p = process(eq_lhs(e), ctx);
+ auto rhs_p = process(eq_rhs(e), ctx);
+ if (is_eqp(lhs_p.first, eq_lhs(e)) && is_eqp(rhs_p.first, eq_rhs(e))) {
+ return expr_pair(e, mk_bool_type());
+ } else {
+ expr new_e = mk_eq(lhs_p.first, rhs_p.first);
+ m_trace[new_e] = e;
+ return expr_pair(new_e, mk_bool_type());
+ }
}
case expr_kind::Pi: {
- expr dt = infer(abst_domain(e), ctx);
- expr bt = infer(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
- return mk_type(max(check_universe(dt, ctx, e, ctx), check_universe(bt, ctx, e, ctx)));
+ auto d_p = process(abst_domain(e), ctx);
+ auto b_p = process(abst_body(e), extend(ctx, abst_name(e), d_p.first));
+ expr t = mk_type(max(check_universe(d_p.second, ctx, e, ctx), check_universe(b_p.second, ctx, e, ctx)));
+ if (is_eqp(d_p.first, abst_domain(e)) && is_eqp(b_p.first, abst_body(e))) {
+ return expr_pair(e, t);
+ } else {
+ expr new_e = mk_pi(abst_name(e), d_p.first, b_p.first);
+ m_trace[new_e] = e;
+ return expr_pair(new_e, t);
+ }
}
case expr_kind::Lambda: {
- expr dt = infer(abst_domain(e), ctx);
- expr bt = infer(abst_body(e), extend(ctx, abst_name(e), abst_domain(e)));
- return mk_pi(abst_name(e), abst_domain(e), bt);
+ auto d_p = process(abst_domain(e), ctx);
+ auto b_p = process(abst_body(e), extend(ctx, abst_name(e), d_p.first));
+ expr t = mk_pi(abst_name(e), d_p.first, b_p.second);
+ if (is_eqp(d_p.first, abst_domain(e)) && is_eqp(b_p.first, abst_body(e))) {
+ return expr_pair(e, t);
+ } else {
+ expr new_e = mk_lambda(abst_name(e), d_p.first, b_p.first);
+ m_trace[new_e] = e;
+ return expr_pair(new_e, t);
+ }
}
case expr_kind::Let: {
- expr lt = infer(let_value(e), ctx);
- return lower_free_vars_mmv(infer(let_body(e), extend(ctx, let_name(e), lt, let_value(e))), 1, 1);
+ auto v_p = process(let_value(e), ctx);
+ auto b_p = process(let_body(e), extend(ctx, let_name(e), v_p.second, v_p.first));
+ expr t = lower_free_vars_mmv(b_p.second, 1, 1);
+ if (is_eqp(v_p.first, let_value(e)) && is_eqp(b_p.first, let_body(e))) {
+ return expr_pair(e, t);
+ } else {
+ expr new_e = mk_let(let_name(e), v_p.first, b_p.first);
+ m_trace[new_e] = e;
+ return expr_pair(new_e, t);
+ }
}}
lean_unreachable();
- return expr();
+ return expr_pair(expr(), expr());
+ }
+
+ expr infer(expr const & e, context const & ctx) {
+ return process(e, ctx).second;
}
bool is_simple_ho_match(expr const & e1, expr const & e2, context const & ctx) {
@@ -454,7 +546,8 @@ class elaborator::imp {
return replacer(e);
}
- void solve(unsigned num_meta) {
+ void solve() {
+ unsigned num_meta = m_metavars.size();
m_add_constraints = false;
while (true) {
solve_core();
@@ -493,24 +586,6 @@ class elaborator::imp {
}
}
- expr mk_metavars(expr const & e) {
- // replace placeholders with fresh metavars
- auto proc = [&](expr const & n, unsigned offset) -> expr {
- if (is_placeholder(n)) {
- return mk_metavar();
- } else {
- return n;
- }
- };
- auto tracer = [&](expr const & old_e, expr const & new_e) {
- if (!is_eqp(new_e, old_e)) {
- m_trace[new_e] = old_e;
- }
- };
- replace_fn replacer(proc, tracer);
- return replacer(e);
- }
-
public:
imp(environment const & env, name_set const * defs):
m_env(env),
@@ -519,13 +594,6 @@ public:
m_owner = nullptr;
}
- expr mk_metavar() {
- unsigned midx = m_metavars.size();
- expr r = ::lean::mk_metavar(midx);
- m_metavars.push_back(metavar_info());
- return r;
- }
-
void clear() {
m_trace.clear();
}
@@ -560,12 +628,10 @@ public:
if (has_placeholder(e)) {
m_constraints.clear();
m_metavars.clear();
- m_root = mk_metavars(e);
m_owner = &elb;
- unsigned num_meta = m_metavars.size();
m_add_constraints = true;
- infer(m_root, context());
- solve(num_meta);
+ m_root = process(e, context()).first;
+ solve();
return instantiate(m_root);
} else {
return e;
@@ -607,7 +673,6 @@ public:
};
elaborator::elaborator(environment const & env):m_ptr(new imp(env, nullptr)) {}
elaborator::~elaborator() {}
-expr elaborator::mk_metavar() { return m_ptr->mk_metavar(); }
expr elaborator::operator()(expr const & e) { return (*m_ptr)(e, *this); }
expr const & elaborator::get_original(expr const & e) const { return m_ptr->get_original(e); }
void elaborator::set_interrupt(bool flag) { m_ptr->set_interrupt(flag); }
diff --git a/src/library/elaborator.h b/src/library/elaborator.h
index b8fc4be316..980d11e286 100644
--- a/src/library/elaborator.h
+++ b/src/library/elaborator.h
@@ -23,8 +23,6 @@ public:
explicit elaborator(environment const & env);
~elaborator();
- expr mk_metavar();
-
expr operator()(expr const & e);
/**
@@ -45,8 +43,30 @@ public:
void display(std::ostream & out) const;
format pp(formatter & f, options const & o) const;
};
-/** \brief Return true iff \c e is a special constant used to mark application of overloads. */
-bool is_overload_marker(expr const & e);
-/** \brief Return the overload marker */
-expr mk_overload_marker();
+/**
+ \brief Create a choice expression for the given functions.
+ It is used to mark which functions can be used in a particular application.
+ The elaborator decides which one should be used based on the type of the arguments.
+
+ \pre num_fs >= 2
+*/
+expr mk_choice(unsigned num_fs, expr const * fs);
+/**
+ \brief Return true iff \c e is an expression created using \c mk_choice.
+*/
+bool is_choice(expr const & e);
+/**
+ \brief Return the number of alternatives in a choice expression.
+ We have that get_num_choices(mk_choice(n, fs)) == n.
+
+ \pre is_choice(e)
+*/
+unsigned get_num_choices(expr const & e);
+/**
+ \brief Return the (i+1)-th alternative of a choice expression.
+
+ \pre is_choice(e)
+ \pre i < get_num_choices(e)
+*/
+expr const & get_choice(expr const & e, unsigned i);
}
diff --git a/src/library/metavar.cpp b/src/library/metavar.cpp
index 5955609cdd..e2a53879f1 100644
--- a/src/library/metavar.cpp
+++ b/src/library/metavar.cpp
@@ -107,6 +107,7 @@ bool is_subst(expr const & e) {
}
expr mk_lift_fn(unsigned s, unsigned n) {
+ lean_assert(n > 0);
return mk_constant(name(name(g_lift_prefix, s), n));
}
diff --git a/tests/lean/tst7.lean.expected.out b/tests/lean/tst7.lean.expected.out
index 882ccb0757..4e384d1e2f 100644
--- a/tests/lean/tst7.lean.expected.out
+++ b/tests/lean/tst7.lean.expected.out
@@ -5,7 +5,7 @@ Error (line: 4, pos: 40) application type mismatch during term elaboration at te
Elaborator state
?M0 := [31m[unassigned][0m
?M1 := [31m[unassigned][0m
- #0 ≈ [33mlift[0m:0:0 ?M0
+ #0 ≈ [33mlift[0m:0:2 ?M0
Assumed: myeq
myeq ([33mΠ[0m (A : [36mType[0m) (a : A), A) ([33mλ[0m (A : [36mType[0m) (a : A), a) ([33mλ[0m (B : [36mType[0m) (b : B), b)
Bool