/* Copyright (c) 2015 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include #include #include #include "util/priority_queue.h" #include "util/sstream.h" #include "util/flet.h" #include "util/list_fn.h" #include "util/name_hash_map.h" #include "kernel/error_msgs.h" #include "kernel/find_fn.h" #include "kernel/instantiate.h" #include "library/trace.h" #include "library/reducible.h" #include "library/num.h" #include "library/cache_helper.h" #include "library/scoped_ext.h" #include "library/attribute_manager.h" #include "library/type_context.h" #include "library/vm/vm_expr.h" #include "library/vm/vm_list.h" #include "library/vm/vm_option.h" #include "library/vm/vm_name.h" #include "library/tactic/simp_result.h" #include "library/tactic/tactic_state.h" #include "library/tactic/simplifier/ceqv.h" #include "library/tactic/simplifier/simp_lemmas.h" namespace lean { /* Validation */ LEAN_THREAD_VALUE(bool, g_throw_ex, false); void validate_simp(type_context & tctx, name const & n); void validate_congr(type_context & tctx, name const & n); void on_add_simp_lemma(environment const & env, name const & c, bool) { type_context tctx(env); validate_simp(tctx, c); } void on_add_congr_lemma(environment const & env, name const & c, bool) { type_context tctx(env); validate_congr(tctx, c); } /* Getters/checkers */ static void report_failure(sstream const & strm) { if (g_throw_ex){ throw exception(strm); } else { lean_trace(name({"simplifier", "failure"}), tout() << strm.str() << "\n";); } } simp_lemmas add_core(tmp_type_context & tmp_tctx, simp_lemmas const & s, name const & id, levels const & univ_metas, expr const & e, expr const & h, unsigned priority) { list ceqvs = to_ceqvs(tmp_tctx, e, h); if (is_nil(ceqvs)) { report_failure(sstream() << "invalid [simp] lemma '" << id << "' : " << e); return s; } environment const & env = tmp_tctx.tctx().env(); simp_lemmas new_s = s; for (expr_pair const & p : ceqvs) { /* We only clear the eassignment since we want to reuse the temporary universe metavariables associated with the declaration. */ tmp_tctx.clear_eassignment(); expr rule = tmp_tctx.whnf(p.first); expr proof = tmp_tctx.whnf(p.second); bool is_perm = is_permutation_ceqv(env, rule); buffer emetas; buffer instances; while (is_pi(rule)) { expr mvar = tmp_tctx.mk_tmp_mvar(binding_domain(rule)); emetas.push_back(mvar); instances.push_back(binding_info(rule).is_inst_implicit()); rule = tmp_tctx.whnf(instantiate(binding_body(rule), mvar)); proof = mk_app(proof, mvar); } expr rel, lhs, rhs; if (is_simp_relation(env, rule, rel, lhs, rhs) && is_constant(rel)) { new_s.insert(const_name(rel), simp_lemma(id, univ_metas, reverse_to_list(emetas), reverse_to_list(instances), lhs, rhs, proof, is_perm, priority)); } } return new_s; } static simp_lemmas add_core(tmp_type_context & tmp_tctx, simp_lemmas const & s, name const & cname, unsigned priority) { declaration const & d = tmp_tctx.tctx().env().get(cname); buffer us; unsigned num_univs = d.get_num_univ_params(); for (unsigned i = 0; i < num_univs; i++) { us.push_back(tmp_tctx.mk_tmp_univ_mvar()); } levels ls = to_list(us); expr e = tmp_tctx.whnf(instantiate_type_univ_params(d, ls)); expr h = mk_constant(cname, ls); return add_core(tmp_tctx, s, cname, ls, e, h, priority); } // Return true iff lhs is of the form (B (x : ?m1), ?m2) or (B (x : ?m1), ?m2 x), // where B is lambda or Pi static bool is_valid_congr_rule_binding_lhs(expr const & lhs, name_set & found_mvars) { lean_assert(is_binding(lhs)); expr const & d = binding_domain(lhs); expr const & b = binding_body(lhs); if (!is_metavar(d)) return false; if (is_metavar(b) && b != d) { found_mvars.insert(mlocal_name(b)); found_mvars.insert(mlocal_name(d)); return true; } if (is_app(b) && is_metavar(app_fn(b)) && is_var(app_arg(b), 0) && app_fn(b) != d) { found_mvars.insert(mlocal_name(app_fn(b))); found_mvars.insert(mlocal_name(d)); return true; } return false; } // Return true iff all metavariables in e are in found_mvars static bool only_found_mvars(expr const & e, name_set const & found_mvars) { return !find(e, [&](expr const & m, unsigned) { return is_metavar(m) && !found_mvars.contains(mlocal_name(m)); }); } // Check whether rhs is of the form (mvar l_1 ... l_n) where mvar is a metavariable, // and l_i's are local constants, and mvar does not occur in found_mvars. // If it is return true and update found_mvars static bool is_valid_congr_hyp_rhs(expr const & rhs, name_set & found_mvars) { buffer rhs_args; expr const & rhs_fn = get_app_args(rhs, rhs_args); if (!is_metavar(rhs_fn) || found_mvars.contains(mlocal_name(rhs_fn))) return false; for (expr const & arg : rhs_args) if (!is_local(arg)) return false; found_mvars.insert(mlocal_name(rhs_fn)); return true; } simp_lemmas add_congr_core(tmp_type_context & tmp_tctx, simp_lemmas const & s, name const & n, unsigned prio) { declaration const & d = tmp_tctx.tctx().env().get(n); buffer us; unsigned num_univs = d.get_num_univ_params(); for (unsigned i = 0; i < num_univs; i++) { us.push_back(tmp_tctx.mk_tmp_univ_mvar()); } levels ls = to_list(us); expr rule = tmp_tctx.whnf(instantiate_type_univ_params(d, ls)); expr proof = mk_constant(n, ls); buffer emetas; buffer instances, explicits; while (is_pi(rule)) { expr mvar = tmp_tctx.mk_tmp_mvar(binding_domain(rule)); emetas.push_back(mvar); explicits.push_back(is_explicit(binding_info(rule))); instances.push_back(binding_info(rule).is_inst_implicit()); rule = tmp_tctx.whnf(instantiate(binding_body(rule), mvar)); proof = mk_app(proof, mvar); } expr rel, lhs, rhs; if (!is_simp_relation(tmp_tctx.tctx().env(), rule, rel, lhs, rhs) || !is_constant(rel)) { report_failure(sstream() << "invalid [congr] lemma, '" << n << "' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation"); } name_set found_mvars; buffer lhs_args, rhs_args; expr const & lhs_fn = get_app_args(lhs, lhs_args); expr const & rhs_fn = get_app_args(rhs, rhs_args); if (is_constant(lhs_fn)) { if (!is_constant(rhs_fn) || const_name(lhs_fn) != const_name(rhs_fn) || lhs_args.size() != rhs_args.size()) { report_failure(sstream() << "invalid [congr] lemma, '" << n << "' resulting type is not of the form (" << const_name(lhs_fn) << " ...) " << "~ (" << const_name(lhs_fn) << " ...), where ~ is '" << const_name(rel) << "'"); } for (expr const & lhs_arg : lhs_args) { if (is_sort(lhs_arg)) continue; if (!is_metavar(lhs_arg) || found_mvars.contains(mlocal_name(lhs_arg))) { report_failure(sstream() << "invalid [congr] lemma, '" << n << "' the left-hand-side of the congruence resulting type must be of the form (" << const_name(lhs_fn) << " x_1 ... x_n), where each x_i is a distinct variable or a sort"); } found_mvars.insert(mlocal_name(lhs_arg)); } } else if (is_binding(lhs)) { if (lhs.kind() != rhs.kind()) { report_failure(sstream() << "invalid [congr] lemma, '" << n << "' kinds of the left-hand-side and right-hand-side of " << "the congruence resulting type do not match"); } if (!is_valid_congr_rule_binding_lhs(lhs, found_mvars)) { report_failure(sstream() << "invalid [congr] lemma, '" << n << "' left-hand-side of the congruence resulting type must " << "be of the form (fun/Pi (x : A), B x)"); } } else { report_failure(sstream() << "invalid [congr] lemma, '" << n << "' left-hand-side is not an application nor a binding"); } buffer congr_hyps; lean_assert(emetas.size() == explicits.size()); for (unsigned i = 0; i < emetas.size(); i++) { expr const & mvar = emetas[i]; if (explicits[i] && !found_mvars.contains(mlocal_name(mvar))) { buffer locals; expr type = mlocal_type(mvar); type_context::tmp_locals local_factory(tmp_tctx.tctx()); while (is_pi(type)) { expr local = local_factory.push_local_from_binding(type); locals.push_back(local); type = instantiate(binding_body(type), local); } expr h_rel, h_lhs, h_rhs; if (!is_simp_relation(tmp_tctx.tctx().env(), type, h_rel, h_lhs, h_rhs) || !is_constant(h_rel)) continue; unsigned j = 0; for (expr const & local : locals) { j++; if (!only_found_mvars(mlocal_type(local), found_mvars)) { report_failure(sstream() << "invalid [congr] lemma, '" << n << "' argument #" << j << " of parameter #" << (i+1) << " contains " << "unresolved parameters"); } } if (!only_found_mvars(h_lhs, found_mvars)) { report_failure(sstream() << "invalid [congr] lemma, '" << n << "' argument #" << (i+1) << " is not a valid hypothesis, the left-hand-side contains " << "unresolved parameters"); } if (!is_valid_congr_hyp_rhs(h_rhs, found_mvars)) { report_failure(sstream() << "invalid [congr] lemma, '" << n << "' argument #" << (i+1) << " is not a valid hypothesis, the right-hand-side must be " << "of the form (m l_1 ... l_n) where m is parameter that was not " << "'assigned/resolved' yet and l_i's are locals"); } found_mvars.insert(mlocal_name(mvar)); congr_hyps.push_back(mvar); } } simp_lemmas new_s = s; new_s.insert(const_name(rel), user_congr_lemma(n, ls, reverse_to_list(emetas), reverse_to_list(instances), lhs, rhs, proof, to_list(congr_hyps), prio)); return new_s; } simp_lemmas add_poly(type_context & tctx, simp_lemmas const & s, name const & id, unsigned priority) { tmp_type_context tmp_tctx(tctx); return add_core(tmp_tctx, s, id, priority); } simp_lemmas add(type_context & tctx, simp_lemmas const & s, name const & id, expr const & e, expr const & h, unsigned priority) { tmp_type_context tmp_tctx(tctx); return add_core(tmp_tctx, s, id, list(), e, h, priority); } simp_lemmas join(simp_lemmas const & s1, simp_lemmas const & s2) { if (s1.empty()) return s2; if (s2.empty()) return s1; simp_lemmas new_s1 = s1; buffer> slemmas; s2.for_each_simp([&](name const & eqv, simp_lemma const & r) { slemmas.push_back({eqv, r}); }); for (unsigned i = slemmas.size() - 1; i + 1 > 0; --i) new_s1.insert(slemmas[i].first, slemmas[i].second); buffer> clemmas; s2.for_each_congr([&](name const & eqv, user_congr_lemma const & r) { clemmas.push_back({eqv, r}); }); for (unsigned i = clemmas.size() - 1; i + 1 > 0; --i) new_s1.insert(clemmas[i].first, clemmas[i].second); return new_s1; } void validate_simp(type_context & tctx, name const & n) { simp_lemmas s; flet set_ex(g_throw_ex, true); tmp_type_context tmp_tctx(tctx); add_core(tmp_tctx, s, n, LEAN_DEFAULT_PRIORITY); } void validate_congr(type_context & tctx, name const & n) { simp_lemmas s; flet set_ex(g_throw_ex, true); tmp_type_context tmp_tctx(tctx); add_congr_core(tmp_tctx, s, n, LEAN_DEFAULT_PRIORITY); } simp_lemma_core::simp_lemma_core(name const & id, levels const & umetas, list const & emetas, list const & instances, expr const & lhs, expr const & rhs, expr const & proof, unsigned priority): m_id(id), m_umetas(umetas), m_emetas(emetas), m_instances(instances), m_lhs(lhs), m_rhs(rhs), m_proof(proof), m_priority(priority) {} simp_lemma::simp_lemma(name const & id, levels const & umetas, list const & emetas, list const & instances, expr const & lhs, expr const & rhs, expr const & proof, bool is_perm, unsigned priority): simp_lemma_core(id, umetas, emetas, instances, lhs, rhs, proof, priority), m_is_permutation(is_perm) {} bool operator==(simp_lemma const & r1, simp_lemma const & r2) { return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs; } format simp_lemma::pp(formatter const & fmt) const { format r; r += format("[") + format(get_id()) + format("]") + space(); r += format("#") + format(get_num_emeta()); if (m_priority != LEAN_DEFAULT_PRIORITY) r += space() + paren(format(m_priority)); if (m_is_permutation) r += space() + format("perm"); format r1 = comma() + space() + fmt(get_lhs()); r1 += space() + format("↦") + pp_indent_expr(fmt, get_rhs()); r += group(r1); return r; } user_congr_lemma::user_congr_lemma(name const & id, levels const & umetas, list const & emetas, list const & instances, expr const & lhs, expr const & rhs, expr const & proof, list const & congr_hyps, unsigned priority): simp_lemma_core(id, umetas, emetas, instances, lhs, rhs, proof, priority), m_congr_hyps(congr_hyps) {} bool operator==(user_congr_lemma const & r1, user_congr_lemma const & r2) { return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs && r1.m_congr_hyps == r2.m_congr_hyps; } format user_congr_lemma::pp(formatter const & fmt) const { format r; r += format("[") + format(get_id()) + format("]") + space(); r += format("#") + format(get_num_emeta()); if (m_priority != LEAN_DEFAULT_PRIORITY) r += space() + paren(format(m_priority)); format r1; for (expr const & h : m_congr_hyps) { r1 += space() + paren(fmt(mlocal_type(h))); } r += group(r1); r += space() + format(":") + space(); format r2 = paren(fmt(get_lhs())); r2 += space() + format("↦") + space() + paren(fmt(get_rhs())); r += group(r2); return r; } simp_lemmas_for::simp_lemmas_for(name const & eqv): m_eqv(eqv) {} void simp_lemmas_for::insert(simp_lemma const & r) { m_simp_set.insert(r.get_lhs(), r); } void simp_lemmas_for::erase(simp_lemma const & r) { m_simp_set.erase(r.get_lhs(), r); } void simp_lemmas_for::insert(user_congr_lemma const & r) { m_congr_set.insert(r.get_lhs(), r); } void simp_lemmas_for::erase(user_congr_lemma const & r) { m_congr_set.erase(r.get_lhs(), r); } list const * simp_lemmas_for::find_simp(head_index const & h) const { return m_simp_set.find(h); } void simp_lemmas_for::for_each_simp(std::function const & fn) const { m_simp_set.for_each_entry([&](head_index const &, simp_lemma const & r) { fn(r); }); } list const * simp_lemmas_for::find_congr(head_index const & h) const { return m_congr_set.find(h); } void simp_lemmas_for::for_each_congr(std::function const & fn) const { m_congr_set.for_each_entry([&](head_index const &, user_congr_lemma const & r) { fn(r); }); } void simp_lemmas_for::erase_simp(name_set const & ids) { // This method is not very smart and doesn't use any indexing or caching. // So, it may be a bottleneck in the future buffer to_delete; for_each_simp([&](simp_lemma const & r) { if (ids.contains(r.get_id())) { to_delete.push_back(r); } }); for (simp_lemma const & r : to_delete) { erase(r); } } void simp_lemmas_for::erase_simp(buffer const & ids) { erase_simp(to_name_set(ids)); } template void simp_lemmas::insert_core(name const & eqv, R const & r) { simp_lemmas_for s(eqv); if (auto const * curr = m_sets.find(eqv)) { s = *curr; } s.insert(r); m_sets.insert(eqv, s); } template void simp_lemmas::erase_core(name const & eqv, R const & r) { if (auto const * curr = m_sets.find(eqv)) { simp_lemmas_for s = *curr; s.erase(r); if (s.empty()) m_sets.erase(eqv); else m_sets.insert(eqv, s); } } void simp_lemmas::insert(name const & eqv, simp_lemma const & r) { return insert_core(eqv, r); } void simp_lemmas::erase(name const & eqv, simp_lemma const & r) { return erase_core(eqv, r); } void simp_lemmas::insert(name const & eqv, user_congr_lemma const & r) { return insert_core(eqv, r); } void simp_lemmas::erase(name const & eqv, user_congr_lemma const & r) { return erase_core(eqv, r); } void simp_lemmas::get_relations(buffer & rs) const { m_sets.for_each([&](name const & r, simp_lemmas_for const &) { rs.push_back(r); }); } void simp_lemmas::erase_simp(name_set const & ids) { name_map new_sets; m_sets.for_each([&](name const & n, simp_lemmas_for const & s) { simp_lemmas_for new_s = s; new_s.erase_simp(ids); new_sets.insert(n, new_s); }); m_sets = new_sets; } void simp_lemmas::erase_simp(buffer const & ids) { erase_simp(to_name_set(ids)); } simp_lemmas_for const * simp_lemmas::find(name const & eqv) const { return m_sets.find(eqv); } list const * simp_lemmas::find_simp(name const & eqv, head_index const & h) const { if (auto const * s = m_sets.find(eqv)) return s->find_simp(h); return nullptr; } list const * simp_lemmas::find_congr(name const & eqv, head_index const & h) const { if (auto const * s = m_sets.find(eqv)) return s->find_congr(h); return nullptr; } void simp_lemmas::for_each_simp(std::function const & fn) const { m_sets.for_each([&](name const & eqv, simp_lemmas_for const & s) { s.for_each_simp([&](simp_lemma const & r) { fn(eqv, r); }); }); } void simp_lemmas::for_each_congr(std::function const & fn) const { m_sets.for_each([&](name const & eqv, simp_lemmas_for const & s) { s.for_each_congr([&](user_congr_lemma const & r) { fn(eqv, r); }); }); } format simp_lemmas::pp(formatter const & fmt, format const & header, bool simp, bool congr) const { format r; if (simp) { name prev_eqv; for_each_simp([&](name const & eqv, simp_lemma const & rw) { if (prev_eqv != eqv) { r += format("simplification rules for ") + format(eqv); r += header; r += line(); prev_eqv = eqv; } r += rw.pp(fmt) + line(); }); } if (congr) { name prev_eqv; for_each_congr([&](name const & eqv, user_congr_lemma const & cr) { if (prev_eqv != eqv) { r += format("congruencec rules for ") + format(eqv) + line(); prev_eqv = eqv; } r += cr.pp(fmt) + line(); }); } return r; } format simp_lemmas::pp_simp(formatter const & fmt, format const & header) const { return pp(fmt, header, true, false); } format simp_lemmas::pp_simp(formatter const & fmt) const { return pp(fmt, format(), true, false); } format simp_lemmas::pp_congr(formatter const & fmt) const { return pp(fmt, format(), false, true); } format simp_lemmas::pp(formatter const & fmt) const { return pp(fmt, format(), true, true); } simp_lemmas get_simp_lemmas_from_attribute(type_context & ctx, name const & attr_name, simp_lemmas result) { auto const & attr = get_attribute(ctx.env(), attr_name); buffer simp_lemmas; attr.get_instances(ctx.env(), simp_lemmas); unsigned i = simp_lemmas.size(); while (i > 0) { i--; name const & id = simp_lemmas[i]; // TODO(Leo): fix the following hack tmp_type_context tmp_tctx(ctx); result = add_core(tmp_tctx, result, id, attr.get_prio(ctx.env(), id)); } return result; } simp_lemmas get_congr_lemmas_from_attribute(type_context & ctx, name const & attr_name, simp_lemmas result) { auto const & attr = get_attribute(ctx.env(), attr_name); buffer congr_lemmas; attr.get_instances(ctx.env(), congr_lemmas); unsigned i = congr_lemmas.size(); while (i > 0) { i--; name const & id = congr_lemmas[i]; tmp_type_context tmp_tctx(ctx); result = add_congr_core(tmp_tctx, result, id, attr.get_prio(ctx.env(), id)); } return result; } struct simp_lemmas_config { std::vector m_simp_attrs; std::vector m_congr_attrs; }; static std::vector * g_simp_lemmas_configs = nullptr; static name_map * g_name2simp_token = nullptr; static simp_lemmas_token g_default_token; simp_lemmas_token register_simp_attribute(name const & user_name, std::initializer_list const & simp_attrs, std::initializer_list const & congr_attrs) { simp_lemmas_config cfg; for (name const & attr_name : simp_attrs) { cfg.m_simp_attrs.push_back(attr_name); if (!is_system_attribute(attr_name)) { register_system_attribute(basic_attribute::with_check(attr_name, "simplification lemma", on_add_simp_lemma)); } } for (name const & attr_name : congr_attrs) { cfg.m_congr_attrs.push_back(attr_name); if (!is_system_attribute(attr_name)) { register_system_attribute(basic_attribute::with_check(attr_name, "congruence lemma", on_add_congr_lemma)); } } simp_lemmas_token tk = g_simp_lemmas_configs->size(); g_simp_lemmas_configs->push_back(cfg); g_name2simp_token->insert(user_name, tk); return tk; } simp_lemmas_config const & get_simp_lemmas_config(simp_lemmas_token tk) { lean_assert(tk < g_simp_lemmas_configs->size()); return (*g_simp_lemmas_configs)[tk]; } /* This is the cache for internally used simp_lemma collections */ class simp_lemmas_cache { struct entry { environment m_env; std::vector m_fingerprints; unsigned m_reducibility_fingerprint; optional m_lemmas; entry(environment const & env): m_env(env), m_reducibility_fingerprint(0) {} }; std::vector m_entries[4]; public: void expand(environment const & env, transparency_mode m, unsigned new_sz) { unsigned midx = static_cast(m); for (unsigned tk = m_entries[midx].size(); tk < new_sz; tk++) { auto & cfg = get_simp_lemmas_config(tk); m_entries[midx].emplace_back(env); auto & entry = m_entries[midx].back(); entry.m_fingerprints.resize(cfg.m_simp_attrs.size() + cfg.m_congr_attrs.size()); } } simp_lemmas mk_lemmas(environment const & env, transparency_mode m, entry & C, simp_lemmas_token tk) { lean_trace("simp_lemmas_cache", tout() << "make simp lemmas [" << tk << "]\n";); type_context ctx(env, m); C.m_env = env; auto & cfg = get_simp_lemmas_config(tk); simp_lemmas lemmas; unsigned i = 0; for (name const & attr_name : cfg.m_simp_attrs) { lemmas = get_simp_lemmas_from_attribute(ctx, attr_name, lemmas); C.m_fingerprints[i] = get_attribute_fingerprint(env, attr_name); i++; } for (name const & attr_name : cfg.m_congr_attrs) { lemmas = get_congr_lemmas_from_attribute(ctx, attr_name, lemmas); C.m_fingerprints[i] = get_attribute_fingerprint(env, attr_name); i++; } C.m_lemmas = lemmas; C.m_reducibility_fingerprint = get_reducibility_fingerprint(env); return lemmas; } simp_lemmas lemmas_of(entry const & C, simp_lemmas_token tk) { lean_trace("simp_lemmas_cache", tout() << "reusing cached simp lemmas [" << tk << "]\n";); return *C.m_lemmas; } bool is_compatible(entry const & C, environment const & env, simp_lemmas_token tk) { if (!env.is_descendant(C.m_env)) return false; if (get_reducibility_fingerprint(env) != C.m_reducibility_fingerprint) return false; auto & cfg = get_simp_lemmas_config(tk); unsigned i = 0; for (name const & attr_name : cfg.m_simp_attrs) { if (get_attribute_fingerprint(env, attr_name) != C.m_fingerprints[i]) return false; i++; } for (name const & attr_name : cfg.m_congr_attrs) { if (get_attribute_fingerprint(env, attr_name) != C.m_fingerprints[i]) return false; i++; } return true; } simp_lemmas get(environment const & env, transparency_mode m, simp_lemmas_token tk) { lean_assert(tk < g_simp_lemmas_configs->size()); unsigned midx = static_cast(m); if (tk >= m_entries[midx].size()) expand(env, m, tk+1); lean_assert(tk < m_entries[midx].size()); entry & C = m_entries[midx][tk]; if (!C.m_lemmas) return mk_lemmas(env, m, C, tk); if (is_eqp(env, C.m_env)) return lemmas_of(C, tk); if (!is_compatible(C, env, tk)) { lean_trace("simp_lemmas_cache", tout() << "creating new cache\n";); return mk_lemmas(env, m, C, tk); } return lemmas_of(C, tk); } }; MK_THREAD_LOCAL_GET_DEF(simp_lemmas_cache, get_cache); simp_lemmas get_simp_lemmas(environment const & env, transparency_mode m, simp_lemmas_token tk) { return get_cache().get(env, m, tk); } simp_lemmas get_default_simp_lemmas(environment const & env, transparency_mode m) { return get_simp_lemmas(env, m, g_default_token); } simp_lemmas get_simp_lemmas(environment const & env, transparency_mode m, name const & tk_name) { if (simp_lemmas_token const * tk = g_name2simp_token->find(tk_name)) return get_simp_lemmas(env, m, *tk); else throw exception(sstream() << "unknown simp_lemmas collection '" << tk_name << "'"); } struct vm_simp_lemmas : public vm_external { simp_lemmas m_val; vm_simp_lemmas(simp_lemmas const & v): m_val(v) {} virtual ~vm_simp_lemmas() {} virtual void dealloc() override { this->~vm_simp_lemmas(); get_vm_allocator().deallocate(sizeof(vm_simp_lemmas), this); } }; simp_lemmas const & to_simp_lemmas(vm_obj const & o) { lean_assert(is_external(o)); lean_assert(dynamic_cast(to_external(o))); return static_cast(to_external(o))->m_val; } vm_obj to_obj(simp_lemmas const & idx) { return mk_vm_external(new (get_vm_allocator().allocate(sizeof(vm_simp_lemmas))) vm_simp_lemmas(idx)); } vm_obj tactic_mk_default_simp_lemmas(vm_obj const & m, vm_obj const & s) { LEAN_TACTIC_TRY; environment const & env = to_tactic_state(s).env(); simp_lemmas r = get_default_simp_lemmas(env, to_transparency_mode(m)); return mk_tactic_success(to_obj(r), to_tactic_state(s)); LEAN_TACTIC_CATCH(to_tactic_state(s)); } vm_obj tactic_simp_lemmas_insert(vm_obj const & m, vm_obj const & lemmas, vm_obj const & lemma, vm_obj const & s) { LEAN_TACTIC_TRY; type_context tctx = mk_type_context_for(s, m); expr e = to_expr(lemma); name id; if (is_constant(e)) id = const_name(e); else if (is_local(e)) id = local_pp_name(e); expr e_type = tctx.infer(e); // TODO(dhs): accept priority as an argument // Reason for postponing: better plumbing of numerals through the vm simp_lemmas new_lemmas = add(tctx, to_simp_lemmas(lemmas), id, tctx.infer(e), e, LEAN_DEFAULT_PRIORITY); return mk_tactic_success(to_obj(new_lemmas), to_tactic_state(s)); LEAN_TACTIC_CATCH(to_tactic_state(s)); } vm_obj tactic_simp_lemmas_insert_constant(vm_obj const & m, vm_obj const & lemmas, vm_obj const & lemma_name, vm_obj const & s) { LEAN_TACTIC_TRY; type_context ctx = mk_type_context_for(s, m); simp_lemmas new_lemmas = add_poly(ctx, to_simp_lemmas(lemmas), to_name(lemma_name), LEAN_DEFAULT_PRIORITY); return mk_tactic_success(to_obj(new_lemmas), to_tactic_state(s)); LEAN_TACTIC_CATCH(to_tactic_state(s)); } vm_obj simp_lemmas_mk() { return to_obj(simp_lemmas()); } vm_obj simp_lemmas_join(vm_obj const & lemmas1, vm_obj const & lemmas2) { return to_obj(join(to_simp_lemmas(lemmas1), to_simp_lemmas(lemmas2))); } vm_obj simp_lemmas_erase(vm_obj const & lemmas, vm_obj const & lemma_list) { name_set S; for (name const & l : to_list_name(lemma_list)) S.insert(l); simp_lemmas new_lemmas = to_simp_lemmas(lemmas); new_lemmas.erase_simp(S); return to_obj(new_lemmas); } static optional prove(type_context & ctx, vm_obj const & prove_fn, expr const & e) { tactic_state s = mk_tactic_state_for(ctx.env(), ctx.get_options(), ctx.lctx(), e); vm_obj r_obj = invoke(prove_fn, to_obj(s)); optional s_new = is_tactic_success(r_obj); if (!s_new || s_new->goals()) return none_expr(); metavar_context mctx = s_new->mctx(); expr result = mctx.instantiate_mvars(s_new->main()); if (has_expr_metavar(result)) return none_expr(); ctx.set_mctx(mctx); return some_expr(result); } static bool instantiate_emetas(type_context & ctx, vm_obj const & prove_fn, unsigned num_emeta, list const & emetas, list const & instances) { environment const & env = ctx.env(); bool failed = false; unsigned i = num_emeta; for_each2(emetas, instances, [&](expr const & m, bool const & is_instance) { i--; if (failed) return; expr m_type = ctx.instantiate_mvars(ctx.infer(m)); if (has_metavar(m_type)) { failed = true; return; } if (ctx.get_tmp_mvar_assignment(i)) return; if (is_instance) { if (auto v = ctx.mk_class_instance(m_type)) { if (!ctx.is_def_eq(m, *v)) { lean_trace("simp_lemmas", scope_trace_env scope(env, ctx); tout() << "unable to assign instance for: " << m_type << "\n";); failed = true; return; } } else { lean_trace("simp_lemmas", scope_trace_env scope(env, ctx); tout() << "unable to synthesize instance for: " << m_type << "\n";); failed = true; return; } } if (ctx.get_tmp_mvar_assignment(i)) return; // Note: m_type has no metavars if (ctx.is_prop(m_type)) { if (auto pf = prove(ctx, prove_fn, m_type)) { lean_verify(ctx.is_def_eq(m, *pf)); return; } } lean_trace("simp_lemmas", scope_trace_env scope(env, ctx); tout() << "failed to assign: " << m << " : " << m_type << "\n";); failed = true; return; }); return !failed; } static simp_result simp_lemma_apply(type_context & ctx, simp_lemma const & sl, vm_obj const & prove_fn, expr const & e) { type_context::tmp_mode_scope scope(ctx, sl.get_num_umeta(), sl.get_num_emeta()); if (!ctx.is_def_eq(e, sl.get_lhs())) { lean_trace("simp_lemmas", tout() << "fail to unify: " << sl.get_id() << "\n";); return simp_result(e); } if (!instantiate_emetas(ctx, prove_fn, sl.get_num_emeta(), sl.get_emetas(), sl.get_instances())) { lean_trace("simp_lemmas", tout() << "fail to instantiate emetas: " << sl.get_id() << "\n";); return simp_result(e); } for (unsigned i = 0; i < sl.get_num_umeta(); i++) { if (!ctx.get_tmp_uvar_assignment(i)) return simp_result(e); } expr new_lhs = ctx.instantiate_mvars(sl.get_lhs()); expr new_rhs = ctx.instantiate_mvars(sl.get_rhs()); if (sl.is_perm()) { if (!is_lt(new_rhs, new_lhs, false)) { lean_trace("simp_lemmas", scope_trace_env scope(ctx.env(), ctx); tout() << "perm rejected: " << new_rhs << " !< " << new_lhs << "\n";); return simp_result(e); } } expr pf = ctx.instantiate_mvars(sl.get_proof()); return simp_result(new_rhs, pf); } vm_obj simp_lemmas_apply(transparency_mode const & m, simp_lemmas const & sls, vm_obj const & prove_fn, name const & R, expr const & e, tactic_state const & s) { LEAN_TACTIC_TRY; simp_lemmas_for const * sr = sls.find(R); if (!sr) return mk_tactic_exception("failed to apply simp_lemmas, no lemmas for the given relation", s); list const * srs = sr->find_simp(e); if (!srs) return mk_tactic_exception("failed to apply simp_lemmas, no simp lemma", s); type_context ctx = mk_type_context_for(s, m); for (simp_lemma const & lemma : *srs) { simp_result r = simp_lemma_apply(ctx, lemma, prove_fn, e); if (r.has_proof()) { lean_trace("simp_lemmas", scope_trace_env scope(ctx.env(), ctx); tout() << "[" << lemma.get_id() << "]: " << e << " ==> " << r.get_new() << "\n";); vm_obj new_e = to_obj(r.get_new()); vm_obj new_pr = to_obj(r.get_proof()); return mk_tactic_success(mk_vm_pair(new_e, new_pr), s); } } return mk_tactic_exception("failed to apply simp_lemmas, no simp lemma", s); LEAN_TACTIC_CATCH(s); } static vm_obj tactic_simp_lemmas_apply(vm_obj const & m, vm_obj const & sls, vm_obj const & prove_fn, vm_obj const & R, vm_obj const & e, vm_obj const & s) { return simp_lemmas_apply(to_transparency_mode(m), to_simp_lemmas(sls), prove_fn, to_name(R), to_expr(e), to_tactic_state(s)); } void initialize_simp_lemmas() { g_simp_lemmas_configs = new std::vector(); g_name2simp_token = new name_map(); g_default_token = register_simp_attribute("default", {"simp"}, {"congr"}); register_trace_class("simp_lemmas"); DECLARE_VM_BUILTIN(name({"simp_lemmas", "mk"}), simp_lemmas_mk); DECLARE_VM_BUILTIN(name({"simp_lemmas", "join"}), simp_lemmas_join); DECLARE_VM_BUILTIN(name({"simp_lemmas", "erase"}), simp_lemmas_erase); DECLARE_VM_BUILTIN(name({"tactic", "mk_default_simp_lemmas_core"}), tactic_mk_default_simp_lemmas); DECLARE_VM_BUILTIN(name({"tactic", "simp_lemmas_insert_core"}), tactic_simp_lemmas_insert); DECLARE_VM_BUILTIN(name({"tactic", "simp_lemmas_insert_constant_core"}), tactic_simp_lemmas_insert_constant); DECLARE_VM_BUILTIN(name({"tactic", "simp_lemmas_apply_core"}), tactic_simp_lemmas_apply); } void finalize_simp_lemmas() { delete g_simp_lemmas_configs; delete g_name2simp_token; } }