diff --git a/src/library/inductive_compiler/nested.cpp b/src/library/inductive_compiler/nested.cpp index 164547e4b9..7915fa7b69 100644 --- a/src/library/inductive_compiler/nested.cpp +++ b/src/library/inductive_compiler/nested.cpp @@ -1269,6 +1269,7 @@ class add_nested_inductive_decl_fn { defeq_can_state dcs; simplify_fn simplifier(tctx, dcs, all_lemmas, max_steps, contextual, lift_eq, canonize_instances, canonize_proofs, use_axioms); + simplifier.set_use_matcher(false); // hack auto thm_pr = simplifier.prove_by_simp(get_eq_name(), thm); if (!thm_pr) { formatter_factory const & fmtf = get_global_ios().get_formatter_factory(); diff --git a/src/library/tactic/simplify.cpp b/src/library/tactic/simplify.cpp index 0d15d126b6..a4b45ffcc0 100644 --- a/src/library/tactic/simplify.cpp +++ b/src/library/tactic/simplify.cpp @@ -497,14 +497,79 @@ simp_result simplify_core_fn::rewrite(expr const & e) { return simp_result(e); } +struct match_fn { + tmp_type_context & m_ctx; + name const & m_id; + buffer> m_postponed; + + match_fn(tmp_type_context & ctx, name const & id):m_ctx(ctx), m_id(id) {} + + bool match(expr const & p, expr const & t) { + if (m_ctx.ctx().is_mvar(p)) + if (auto v = m_ctx.ctx().get_assignment(p)) + return match(*v, t); + if (is_app(p) && is_app(t)) { + expr const & fn = get_app_fn(p); + if (m_ctx.is_def_eq(fn, get_app_fn(t))) { + buffer p_args; + buffer t_args; + get_app_args(p, p_args); + get_app_args(t, t_args); + fun_info finfo = get_fun_info(m_ctx.ctx(), fn); + if (p_args.size() != t_args.size()) + return false; + auto it = finfo.get_params_info(); + for (unsigned i = 0; i < p_args.size(); i++) { + if (it && head(it).is_inst_implicit()) { + m_postponed.emplace_back(p_args[i], t_args[i], true); + } else if (it && head(it).is_implicit()) { + m_postponed.emplace_back(p_args[i], t_args[i], false); + } else if (!match(p_args[i], t_args[i])) { + return false; + } + if (it) it = tail(it); + } + return true; + } + } + return m_ctx.is_def_eq(p, t); + } + + bool operator()(expr const & p, expr const & t) { + if (!match(p, t)) return false; + + for (auto const & e : m_postponed) { + expr p1, t1; bool implicit; + std::tie(p1, t1, implicit) = e; + p1 = m_ctx.instantiate_mvars(p1); + if (implicit) + p1 = m_ctx.ctx().complete_instance(p1); + if (!match(p1, t1)) { + lean_trace(name({"simplify", "implicit_failure"}), scope_trace_env scope(m_ctx.env(), m_ctx.ctx()); + tout() << "fail to match '" << m_id << "':\n"; + tout() << p << "\n=?=\n" << t << "\nbecause the following implicit match\n"; + tout() << p1 << "\n=?=\n" << t1 << "\n";); + return false; + } + } + return true; + } +}; + +bool simplify_core_fn::match(tmp_type_context & ctx, simp_lemma const & sl, expr const & t) { + if (m_use_matcher) + return match_fn(ctx, sl.get_id())(sl.get_lhs(), t); + else + return ctx.is_def_eq(t, sl.get_lhs()); +} + simp_result simplify_core_fn::rewrite_core(expr const & e, simp_lemma const & sl) { tmp_type_context tmp_ctx(m_ctx, sl.get_num_umeta(), sl.get_num_emeta()); - if (!tmp_ctx.is_def_eq(e, sl.get_lhs())) { + + if (!match(tmp_ctx, sl, e)) { lean_trace_d(name({"debug", "simplify", "try_rewrite"}), - tout() << "fail to unify '" << sl.get_id() - << "':\n------------------------------------------------\n" - << e << "\n=?=\n" << sl.get_lhs() - << "\n------------------------------------------------\n";); + tout() << "fail to unify '" << sl.get_id() << "':\n" + << e << "\n=?=\n" << sl.get_lhs() << "\n--------------\n";); return simp_result(e); } @@ -1158,6 +1223,7 @@ vm_obj tactic_ext_simplify_core(unsigned DEBUG_CODE(num), vm_obj const * args) { void initialize_simplify() { register_trace_class("simplify"); register_trace_class(name({"simplify", "failure"})); + register_trace_class(name({"simplify", "implicit_failure"})); register_trace_class(name({"simplify", "context"})); register_trace_class(name({"simplify", "canonize"})); register_trace_class(name({"simplify", "congruence"})); diff --git a/src/library/tactic/simplify.h b/src/library/tactic/simplify.h index 0502b67bd9..7e8fc0bfd1 100644 --- a/src/library/tactic/simplify.h +++ b/src/library/tactic/simplify.h @@ -51,6 +51,9 @@ protected: bool m_lift_eq; bool m_canonize_instances; bool m_canonize_proofs; + /* The following option should be removed as soon as we + refactor the inductive compiler. */ + bool m_use_matcher{true}; simp_result join(simp_result const & r1, simp_result const & r2); void inc_num_steps(); @@ -100,6 +103,8 @@ protected: simp_result simplify(expr const & e); + bool match(tmp_type_context & ctx, simp_lemma const & sl, expr const & t); + public: simplify_core_fn(type_context & ctx, defeq_canonizer::state & dcs, simp_lemmas const & slss, unsigned max_steps, bool contextual, bool lift_eq, @@ -108,6 +113,8 @@ public: environment const & env() const; simp_result operator()(name const & rel, expr const & e); + void set_use_matcher(bool flag) { m_use_matcher = flag; } + optional prove_by_simp(name const & rel, expr const & e); }; diff --git a/src/library/type_context.cpp b/src/library/type_context.cpp index 4713b6fcca..f459071efc 100644 --- a/src/library/type_context.cpp +++ b/src/library/type_context.cpp @@ -2220,73 +2220,18 @@ bool type_context::is_def_eq_args(expr const & e1, expr const & e2) { return false; fun_info finfo = get_fun_info(*this, fn, args1.size()); unsigned i = 0; - /* - Try to solve unification constraint - (f a_1 ... a_n) =?= (f b_1 ... b_n) - by solving - a_i =?= b_i - - We add i to postponed, if a_i or b_i is a numeral and - the i-th argument of f has forward dependencies. - - The goal is to be able to handle unification constraints - coming from fixed size vector problems. - - In this kind of problem, we have constraints such as - - @to_list ?α (?m - ?n) (@dropn ?α ?m ?n ?v) =?= @to_list bool 6 (@dropn bool 8 2 v) - - In the first pass, we have the constraint - - ?m - ?n =?= 6 - - which cannot be solved. However, after we unify the next argument, we have - - ?m := 8 and ?n := 2 - - and the constraint above becomes - - 8 - 2 =?= 6 - - which can be solved. - */ - buffer postponed; - bool progress = false; for (param_info const & pinfo : finfo.get_params_info()) { if (pinfo.is_inst_implicit()) { args1[i] = complete_instance(args1[i]); args2[i] = complete_instance(args2[i]); } - if (is_def_eq_core(args1[i], args2[i])) { - progress = true; - } else if (pinfo.has_fwd_deps() && (to_small_num(args1[i]) || to_small_num(args2[i]))) { - postponed.push_back(i); - } else { + if (!is_def_eq_core(args1[i], args2[i])) return false; - } i++; } for (; i < args1.size(); i++) { - if (is_def_eq_core(args1[i], args2[i])) { - progress = true; - } else { + if (!is_def_eq_core(args1[i], args2[i])) return false; - } - } - while (true) { - if (postponed.empty()) return true; - if (!progress) return false; - progress = false; - unsigned j = 0; - for (unsigned i = 0; i < postponed.size(); i++) { - if (is_def_eq_core(instantiate_mvars(args1[postponed[i]]), instantiate_mvars(args2[postponed[i]]))) { - progress = true; - } else { - postponed[j] = postponed[i]; - j++; - } - } - postponed.shrink(j); } return true; } @@ -3675,34 +3620,34 @@ expr type_context::eta_expand(expr const & e) { return locals.mk_lambda(r); } -tmp_type_context::tmp_type_context(type_context & tctx, unsigned num_umeta, unsigned num_emeta): m_tctx(tctx) { +tmp_type_context::tmp_type_context(type_context & ctx, unsigned num_umeta, unsigned num_emeta): m_ctx(ctx) { m_tmp_uassignment.resize(num_umeta, none_level()); m_tmp_eassignment.resize(num_emeta, none_expr()); } bool tmp_type_context::is_def_eq(expr const & e1, expr const & e2) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.is_def_eq(e1, e2); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.is_def_eq(e1, e2); } expr tmp_type_context::infer(expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.infer(e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.infer(e); } expr tmp_type_context::whnf(expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.whnf(e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.whnf(e); } level tmp_type_context::mk_tmp_univ_mvar() { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.mk_tmp_univ_mvar(); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.mk_tmp_univ_mvar(); } expr tmp_type_context::mk_tmp_mvar(expr const & type) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.mk_tmp_mvar(type); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.mk_tmp_mvar(type); } bool tmp_type_context::is_uassigned(unsigned i) const { @@ -3720,48 +3665,48 @@ void tmp_type_context::clear_eassignment() { } expr tmp_type_context::instantiate_mvars(expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.instantiate_mvars(e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.instantiate_mvars(e); } void tmp_type_context::assign(expr const & m, expr const & v) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - m_tctx.assign(m, v); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + m_ctx.assign(m, v); } expr tmp_type_context::mk_lambda(buffer const & locals, expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.mk_lambda(locals, e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.mk_lambda(locals, e); } expr tmp_type_context::mk_pi(buffer const & locals, expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.mk_pi(locals, e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.mk_pi(locals, e); } expr tmp_type_context::mk_lambda(expr const & local, expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.mk_lambda(local, e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.mk_lambda(local, e); } expr tmp_type_context::mk_pi(expr const & local, expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.mk_pi(local, e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.mk_pi(local, e); } expr tmp_type_context::mk_lambda(std::initializer_list const & locals, expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.mk_lambda(locals, e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.mk_lambda(locals, e); } expr tmp_type_context::mk_pi(std::initializer_list const & locals, expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.mk_pi(locals, e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.mk_pi(locals, e); } bool tmp_type_context::is_prop(expr const & e) { - type_context::tmp_mode_scope_with_buffers tmp_scope(m_tctx, m_tmp_uassignment, m_tmp_eassignment); - return m_tctx.is_prop(e); + type_context::tmp_mode_scope_with_buffers tmp_scope(m_ctx, m_tmp_uassignment, m_tmp_eassignment); + return m_ctx.is_prop(e); } /** \brief Helper class for pretty printing terms that contain local_decl_ref's and metavar_decl_ref's */ diff --git a/src/library/type_context.h b/src/library/type_context.h index 451199484c..1bb7ef9256 100644 --- a/src/library/type_context.h +++ b/src/library/type_context.h @@ -418,6 +418,9 @@ public: expr eta_expand(expr const & e); + /* Try to assign metavariables occuring in e using type class resolution */ + expr complete_instance(expr const & e); + struct transparency_scope : public flet { transparency_scope(type_context & ctx, transparency_mode m): flet(ctx.m_transparency_mode, m) { @@ -618,7 +621,6 @@ private: bool is_def_eq_core_core(expr const & t, expr const & s); bool is_def_eq_core(expr const & t, expr const & s); bool is_def_eq_binding(expr e1, expr e2); - expr complete_instance(expr const & e); expr try_to_unstuck_using_complete_instance(expr const & e); bool is_def_eq_args(expr const & e1, expr const & e2); bool is_def_eq_eta(expr const & e1, expr const & e2); @@ -703,15 +705,15 @@ public: }; class tmp_type_context : public abstract_type_context { - type_context & m_tctx; + type_context & m_ctx; buffer> m_tmp_uassignment; buffer> m_tmp_eassignment; public: - tmp_type_context(type_context & tctx, unsigned num_umeta = 0, unsigned num_emeta = 0); - type_context & tctx() const { return m_tctx; } + tmp_type_context(type_context & ctx, unsigned num_umeta = 0, unsigned num_emeta = 0); + type_context & ctx() const { return m_ctx; } - virtual environment const & env() const override { return m_tctx.env(); } + virtual environment const & env() const override { return m_ctx.env(); } virtual expr infer(expr const & e) override; virtual expr whnf(expr const & e) override; virtual bool is_def_eq(expr const & e1, expr const & e2) override; diff --git a/tests/lean/run/mk_byte.lean b/tests/lean/run/mk_byte.lean new file mode 100644 index 0000000000..7d8987a0ae --- /dev/null +++ b/tests/lean/run/mk_byte.lean @@ -0,0 +1,19 @@ +import data.bitvec + +open vector + +def byte_type := bitvec 8 + +-- A byte is formed from concatenating two bits and a 6-bit field. +def mk_byte (a b : bool) (l : bitvec 6) : byte_type := a :: b :: l + +-- Get the third component +def get_data (byte : byte_type) : bitvec 6 := vector.dropn 2 byte + +lemma get_data_mk_byte {a b : bool} {l : bitvec 6} : get_data (mk_byte a b l) = l := +begin + apply vector.eq, + unfold mk_byte, + unfold get_data, + simp [to_list_dropn, to_list_cons, list.dropn] +end