From 598daa40bcf5bfc472d13d5d9990df8d795828f1 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 1 Sep 2013 10:24:10 -0700 Subject: [PATCH] Refactor elaborator for supporting overloads Signed-off-by: Leonardo de Moura --- src/frontends/lean/lean_parser.cpp | 17 ++- src/library/elaborator.cpp | 207 +++++++++++++++++++---------- src/library/elaborator.h | 32 ++++- src/library/metavar.cpp | 1 + tests/lean/tst7.lean.expected.out | 2 +- 5 files changed, 175 insertions(+), 84 deletions(-) diff --git a/src/frontends/lean/lean_parser.cpp b/src/frontends/lean/lean_parser.cpp index bebb45edfa..e0259d568a 100644 --- a/src/frontends/lean/lean_parser.cpp +++ b/src/frontends/lean/lean_parser.cpp @@ -377,9 +377,8 @@ class parser::imp { /** \brief Return the function associated with the given operator. - If the operator has been overloaded, it returns an expression - of the form (overload f_k ... (overload f_2 f_1) ...) - where f_i's are different options. + If the operator has been overloaded, it returns a choice expression + of the form (choice f_1 f_2 ... f_k) where f_i's are different options. After we finish parsing, the procedure #elaborate will resolve/decide which f_i should be used. */ @@ -389,9 +388,15 @@ class parser::imp { auto it = fs.begin(); expr r = *it; ++it; - for (; it != fs.end(); ++it) - r = mk_app(mk_overload_marker(), *it, r); - return r; + if (it == fs.end()) { + return r; + } else { + buffer alternatives; + alternatives.push_back(r); + for (; it != fs.end(); ++it) + alternatives.push_back(*it); + return mk_choice(alternatives.size(), alternatives.data()); + } } /** diff --git a/src/library/elaborator.cpp b/src/library/elaborator.cpp index 0398c4b3b6..37481be270 100644 --- a/src/library/elaborator.cpp +++ b/src/library/elaborator.cpp @@ -17,17 +17,28 @@ Author: Leonardo de Moura #include "elaborator_exception.h" namespace lean { -static name g_overload_name(name(name(name(0u), "library"), "overload")); -static expr g_overload = mk_constant(g_overload_name); +static name g_choice_name(name(name(name(0u), "library"), "choice")); +static expr g_choice = mk_constant(g_choice_name); static format g_assignment_fmt = format(":="); static format g_unification_fmt = format("\u2248"); -bool is_overload_marker(expr const & e) { - return e == g_overload; +expr mk_choice(unsigned num_fs, expr const * fs) { + lean_assert(num_fs >= 2); + return mk_eq(g_choice, mk_app(num_fs, fs)); } -expr mk_overload_marker() { - return g_overload; +bool is_choice(expr const & e) { + return is_eq(e) && eq_lhs(e) == g_choice; +} + +unsigned get_num_choices(expr const & e) { + lean_assert(is_choice(e)); + return num_args(eq_rhs(e)); +} + +expr const & get_choice(expr const & e, unsigned i) { + lean_assert(is_choice(e)); + return arg(eq_rhs(e), i); } class elaborator::imp { @@ -82,13 +93,22 @@ class elaborator::imp { volatile bool m_interrupted; + expr mk_metavar(context const & ctx) { + unsigned midx = m_metavars.size(); + expr r = ::lean::mk_metavar(midx); + m_metavars.push_back(metavar_info()); + m_metavars[midx].m_mvar = r; + m_metavars[midx].m_ctx = ctx; + return r; + } + expr metavar_type(expr const & m) { lean_assert(is_metavar(m)); unsigned midx = metavar_idx(m); if (m_metavars[midx].m_type) { return m_metavars[midx].m_type; } else { - expr t = mk_metavar(); + expr t = mk_metavar(m_metavars[midx].m_ctx); m_metavars[midx].m_type = t; return t; } @@ -163,67 +183,139 @@ class elaborator::imp { } } - expr infer(expr const & e, context const & ctx) { + typedef std::pair expr_pair; + + /** + \brief Traverse the expression \c e, and compute + + 1- A new expression that does not contain choice expressions, + coercions have been added when appropriate, and placeholders + have been replaced with metavariables. + + 2- The type of \c e. + + It also populates m_constraints with a set of constraints that + need to be solved to infer the value of the metavariables. + */ + expr_pair process(expr const & e, context const & ctx) { check_interrupted(m_interrupted); switch (e.kind()) { case expr_kind::Constant: - if (is_metavar(e)) { - unsigned midx = metavar_idx(e); - if (!(m_metavars[midx].m_ctx)) { - lean_assert(!(m_metavars[midx].m_mvar)); - m_metavars[midx].m_mvar = e; - m_metavars[midx].m_ctx = ctx; - } - return metavar_type(e); + if (is_placeholder(e)) { + expr m = mk_metavar(ctx); + m_trace[m] = e; + return expr_pair(m, metavar_type(m)); + } else if (is_metavar(e)) { + return expr_pair(e, metavar_type(e)); } else { - return m_env.get_object(const_name(e)).get_type(); + return expr_pair(e, m_env.get_object(const_name(e)).get_type()); } case expr_kind::Var: - return lookup(ctx, var_idx(e)); + return expr_pair(e, lookup(ctx, var_idx(e))); case expr_kind::Type: - return mk_type(ty_level(e) + 1); + return expr_pair(e, mk_type(ty_level(e) + 1)); case expr_kind::Value: - return to_value(e).get_type(); + return expr_pair(e, to_value(e).get_type()); case expr_kind::App: { + buffer args; buffer types; + buffer f_choices; + buffer f_choice_types; unsigned num = num_args(e); - for (unsigned i = 0; i < num; i++) { - types.push_back(infer(arg(e,i), ctx)); - } - // TODO: handle overload in args[0] - expr f_t = types[0]; - if (!f_t) { + unsigned i = 0; + bool modified = false; + expr const & f = arg(e, 0); + if (is_metavar(f)) { throw invalid_placeholder_exception(*m_owner, ctx, e); + } else if (is_choice(f)) { + unsigned num_alts = get_num_choices(f); + for (unsigned j = 0; j < num_alts; j++) { + auto p = process(get_choice(f, j), ctx); + f_choices.push_back(p.first); + f_choice_types.push_back(p.second); + } + args.push_back(expr()); // placeholder + types.push_back(expr()); // placeholder + modified = true; + i++; } + for (; i < num; i++) { + expr const & a_i = arg(e, i); + auto p = process(a_i, ctx); + if (!is_eqp(p.first, a_i)) + modified = true; + args.push_back(p.first); + types.push_back(p.second); + } + // TODO: choose an f from f_choices + expr f_t = types[0]; for (unsigned i = 1; i < num; i++) { f_t = check_pi(f_t, ctx, e, ctx); if (m_add_constraints) add_constraint(abst_domain(f_t), types[i], ctx, e, i); - f_t = instantiate_free_var_mmv(abst_body(f_t), 0, arg(e, i)); + f_t = instantiate_free_var_mmv(abst_body(f_t), 0, args[i]); + } + if (modified) { + expr new_e = mk_app(args.size(), args.data()); + m_trace[new_e] = e; + return expr_pair(new_e, f_t); + } else { + return expr_pair(e, f_t); } - return f_t; } case expr_kind::Eq: { - infer(eq_lhs(e), ctx); - infer(eq_rhs(e), ctx); - return mk_bool_type(); + auto lhs_p = process(eq_lhs(e), ctx); + auto rhs_p = process(eq_rhs(e), ctx); + if (is_eqp(lhs_p.first, eq_lhs(e)) && is_eqp(rhs_p.first, eq_rhs(e))) { + return expr_pair(e, mk_bool_type()); + } else { + expr new_e = mk_eq(lhs_p.first, rhs_p.first); + m_trace[new_e] = e; + return expr_pair(new_e, mk_bool_type()); + } } case expr_kind::Pi: { - expr dt = infer(abst_domain(e), ctx); - expr bt = infer(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); - return mk_type(max(check_universe(dt, ctx, e, ctx), check_universe(bt, ctx, e, ctx))); + auto d_p = process(abst_domain(e), ctx); + auto b_p = process(abst_body(e), extend(ctx, abst_name(e), d_p.first)); + expr t = mk_type(max(check_universe(d_p.second, ctx, e, ctx), check_universe(b_p.second, ctx, e, ctx))); + if (is_eqp(d_p.first, abst_domain(e)) && is_eqp(b_p.first, abst_body(e))) { + return expr_pair(e, t); + } else { + expr new_e = mk_pi(abst_name(e), d_p.first, b_p.first); + m_trace[new_e] = e; + return expr_pair(new_e, t); + } } case expr_kind::Lambda: { - expr dt = infer(abst_domain(e), ctx); - expr bt = infer(abst_body(e), extend(ctx, abst_name(e), abst_domain(e))); - return mk_pi(abst_name(e), abst_domain(e), bt); + auto d_p = process(abst_domain(e), ctx); + auto b_p = process(abst_body(e), extend(ctx, abst_name(e), d_p.first)); + expr t = mk_pi(abst_name(e), d_p.first, b_p.second); + if (is_eqp(d_p.first, abst_domain(e)) && is_eqp(b_p.first, abst_body(e))) { + return expr_pair(e, t); + } else { + expr new_e = mk_lambda(abst_name(e), d_p.first, b_p.first); + m_trace[new_e] = e; + return expr_pair(new_e, t); + } } case expr_kind::Let: { - expr lt = infer(let_value(e), ctx); - return lower_free_vars_mmv(infer(let_body(e), extend(ctx, let_name(e), lt, let_value(e))), 1, 1); + auto v_p = process(let_value(e), ctx); + auto b_p = process(let_body(e), extend(ctx, let_name(e), v_p.second, v_p.first)); + expr t = lower_free_vars_mmv(b_p.second, 1, 1); + if (is_eqp(v_p.first, let_value(e)) && is_eqp(b_p.first, let_body(e))) { + return expr_pair(e, t); + } else { + expr new_e = mk_let(let_name(e), v_p.first, b_p.first); + m_trace[new_e] = e; + return expr_pair(new_e, t); + } }} lean_unreachable(); - return expr(); + return expr_pair(expr(), expr()); + } + + expr infer(expr const & e, context const & ctx) { + return process(e, ctx).second; } bool is_simple_ho_match(expr const & e1, expr const & e2, context const & ctx) { @@ -454,7 +546,8 @@ class elaborator::imp { return replacer(e); } - void solve(unsigned num_meta) { + void solve() { + unsigned num_meta = m_metavars.size(); m_add_constraints = false; while (true) { solve_core(); @@ -493,24 +586,6 @@ class elaborator::imp { } } - expr mk_metavars(expr const & e) { - // replace placeholders with fresh metavars - auto proc = [&](expr const & n, unsigned offset) -> expr { - if (is_placeholder(n)) { - return mk_metavar(); - } else { - return n; - } - }; - auto tracer = [&](expr const & old_e, expr const & new_e) { - if (!is_eqp(new_e, old_e)) { - m_trace[new_e] = old_e; - } - }; - replace_fn replacer(proc, tracer); - return replacer(e); - } - public: imp(environment const & env, name_set const * defs): m_env(env), @@ -519,13 +594,6 @@ public: m_owner = nullptr; } - expr mk_metavar() { - unsigned midx = m_metavars.size(); - expr r = ::lean::mk_metavar(midx); - m_metavars.push_back(metavar_info()); - return r; - } - void clear() { m_trace.clear(); } @@ -560,12 +628,10 @@ public: if (has_placeholder(e)) { m_constraints.clear(); m_metavars.clear(); - m_root = mk_metavars(e); m_owner = &elb; - unsigned num_meta = m_metavars.size(); m_add_constraints = true; - infer(m_root, context()); - solve(num_meta); + m_root = process(e, context()).first; + solve(); return instantiate(m_root); } else { return e; @@ -607,7 +673,6 @@ public: }; elaborator::elaborator(environment const & env):m_ptr(new imp(env, nullptr)) {} elaborator::~elaborator() {} -expr elaborator::mk_metavar() { return m_ptr->mk_metavar(); } expr elaborator::operator()(expr const & e) { return (*m_ptr)(e, *this); } expr const & elaborator::get_original(expr const & e) const { return m_ptr->get_original(e); } void elaborator::set_interrupt(bool flag) { m_ptr->set_interrupt(flag); } diff --git a/src/library/elaborator.h b/src/library/elaborator.h index b8fc4be316..980d11e286 100644 --- a/src/library/elaborator.h +++ b/src/library/elaborator.h @@ -23,8 +23,6 @@ public: explicit elaborator(environment const & env); ~elaborator(); - expr mk_metavar(); - expr operator()(expr const & e); /** @@ -45,8 +43,30 @@ public: void display(std::ostream & out) const; format pp(formatter & f, options const & o) const; }; -/** \brief Return true iff \c e is a special constant used to mark application of overloads. */ -bool is_overload_marker(expr const & e); -/** \brief Return the overload marker */ -expr mk_overload_marker(); +/** + \brief Create a choice expression for the given functions. + It is used to mark which functions can be used in a particular application. + The elaborator decides which one should be used based on the type of the arguments. + + \pre num_fs >= 2 +*/ +expr mk_choice(unsigned num_fs, expr const * fs); +/** + \brief Return true iff \c e is an expression created using \c mk_choice. +*/ +bool is_choice(expr const & e); +/** + \brief Return the number of alternatives in a choice expression. + We have that get_num_choices(mk_choice(n, fs)) == n. + + \pre is_choice(e) +*/ +unsigned get_num_choices(expr const & e); +/** + \brief Return the (i+1)-th alternative of a choice expression. + + \pre is_choice(e) + \pre i < get_num_choices(e) +*/ +expr const & get_choice(expr const & e, unsigned i); } diff --git a/src/library/metavar.cpp b/src/library/metavar.cpp index 5955609cdd..e2a53879f1 100644 --- a/src/library/metavar.cpp +++ b/src/library/metavar.cpp @@ -107,6 +107,7 @@ bool is_subst(expr const & e) { } expr mk_lift_fn(unsigned s, unsigned n) { + lean_assert(n > 0); return mk_constant(name(name(g_lift_prefix, s), n)); } diff --git a/tests/lean/tst7.lean.expected.out b/tests/lean/tst7.lean.expected.out index 882ccb0757..4e384d1e2f 100644 --- a/tests/lean/tst7.lean.expected.out +++ b/tests/lean/tst7.lean.expected.out @@ -5,7 +5,7 @@ Error (line: 4, pos: 40) application type mismatch during term elaboration at te Elaborator state ?M0 := [unassigned] ?M1 := [unassigned] - #0 ≈ lift:0:0 ?M0 + #0 ≈ lift:0:2 ?M0 Assumed: myeq myeq (Π (A : Type) (a : A), A) (λ (A : Type) (a : A), a) (λ (B : Type) (b : B), b) Bool