diff --git a/library/init/logic.lean b/library/init/logic.lean index 34c97043f9..ca20bd9b35 100644 --- a/library/init/logic.lean +++ b/library/init/logic.lean @@ -54,15 +54,15 @@ lemma proof_irrel {a : Prop} (h₁ h₂ : a) : h₁ = h₂ := rfl @[simp] lemma id.def {α : Sort u} (a : α) : id a = a := rfl -@[inline] def eq.mp {α β : Sort u} : (α = β) → α → β := -eq.rec_on +@[inline] def eq.mp {α β : Sort u} (h₁ : α = β) (h₂ : α) : β := +eq.rec_on h₁ h₂ @[inline] def eq.mpr {α β : Sort u} : (α = β) → β → α := λ h₁ h₂, eq.rec_on (eq.symm h₁) h₂ @[elab_as_eliminator] -lemma eq.substr {α : Sort u} {p : α → Prop} {a b : α} (h₁ : b = a) : p a → p b := -eq.subst (eq.symm h₁) +lemma eq.substr {α : Sort u} {p : α → Prop} {a b : α} (h₁ : b = a) (h₂ : p a) : p b := +eq.subst (eq.symm h₁) h₂ lemma congr {α : Sort u} {β : Sort v} {f₁ f₂ : α → β} {a₁ a₂ : α} (h₁ : f₁ = f₂) (h₂ : a₁ = a₂) : f₁ a₁ = f₂ a₂ := eq.subst h₁ (eq.subst h₂ rfl) @@ -70,8 +70,8 @@ eq.subst h₁ (eq.subst h₂ rfl) lemma congr_fun {α : Sort u} {β : α → Sort v} {f g : Π x, β x} (h : f = g) (a : α) : f a = g a := eq.subst h (eq.refl (f a)) -lemma congr_arg {α : Sort u} {β : Sort v} {a₁ a₂ : α} (f : α → β) : a₁ = a₂ → f a₁ = f a₂ := -congr rfl +lemma congr_arg {α : Sort u} {β : Sort v} {a₁ a₂ : α} (f : α → β) (h : a₁ = a₂) : f a₁ = f a₂ := +congr rfl h lemma trans_rel_left {α : Sort u} {a b c : α} (r : α → α → Prop) (h₁ : r a b) (h₂ : b = c) : r a c := h₂ ▸ h₁ @@ -133,11 +133,11 @@ attribute [refl] heq.refl section variables {α β φ : Sort u} {a a' : α} {b b' : β} {c : φ} -lemma heq.elim {α : Sort u} {a : α} {p : α → Sort v} {b : α} (h₁ : a == b) -: p a → p b := eq.rec_on (eq_of_heq h₁) +lemma heq.elim {α : Sort u} {a : α} {p : α → Sort v} {b : α} (h₁ : a == b) (h₂ : p a) : p b := +eq.rec_on (eq_of_heq h₁) h₂ -lemma heq.subst {p : ∀ T : Sort u, T → Prop} : a == b → p α a → p β b := -heq.rec_on +lemma heq.subst {p : ∀ T : Sort u, T → Prop} (h₁ : a == b) (h₂ : p α a) : p β b := +heq.rec_on h₁ h₂ @[symm] lemma heq.symm (h : a == b) : b == a := heq.rec_on h (heq.refl a) diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index f20dfabfa4..cfb77418bd 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -703,6 +703,17 @@ optional elaborator::mk_coercion(expr const & e, expr e_type, expr type, e } bool elaborator::is_def_eq(expr const & e1, expr const & e2) { + type_context_old::fo_unif_approx_scope scope1(m_ctx); + type_context_old::ctx_unif_approx_scope scope2(m_ctx); + try { + return m_ctx.is_def_eq(e1, e2); + } catch (exception &) { + return false; + } +} + +/* Check `e1 =?= e2` using all unifier approximation: first-order, context-compression and quasi-patterns. */ +bool elaborator::is_def_eq_all_approx(expr const & e1, expr const & e2) { type_context_old::approximate_scope scope(m_ctx); try { return m_ctx.is_def_eq(e1, e2); @@ -1127,7 +1138,7 @@ expr elaborator::visit_elim_app(expr const & fn, elim_info const & info, buffer< trace_elab_debug(tout() << "motive:\n " << instantiate_mvars(motive) << "\n";); expr motive_arg = new_args[info.m_motive_idx]; - if (!is_def_eq(motive_arg, motive)) { + if (!is_def_eq_all_approx(motive_arg, motive)) { throw elaborator_exception(ref, "\"eliminator\" elaborator failed to compute the motive"); } @@ -1140,7 +1151,7 @@ expr elaborator::visit_elim_app(expr const & fn, elim_info const & info, buffer< expr new_arg = visit(*arg, some_expr(new_arg_type)); if (!is_def_eq(new_args[i], new_arg)) { throw elaborator_exception(ref, format("\"eliminator\" elaborator type mismatch, term") + - pp_type_mismatch(new_arg, infer_type(new_arg), new_arg_type)); + pp_type_mismatch(new_arg, infer_type(new_arg), new_arg_type)); } else { new_args[i] = new_arg; } diff --git a/src/frontends/lean/elaborator.h b/src/frontends/lean/elaborator.h index ca373ee711..290bff7c1c 100644 --- a/src/frontends/lean/elaborator.h +++ b/src/frontends/lean/elaborator.h @@ -117,6 +117,7 @@ private: expr whnf(expr const & e) { return m_ctx.whnf(e); } expr try_to_pi(expr const & e) { return m_ctx.try_to_pi(e); } bool is_def_eq(expr const & e1, expr const & e2); + bool is_def_eq_all_approx(expr const & e1, expr const & e2); bool try_is_def_eq(expr const & e1, expr const & e2); bool is_uvar_assigned(level const & l) const { return m_ctx.is_assigned(l); } bool is_mvar_assigned(expr const & e) const { return m_ctx.is_assigned(e); } diff --git a/src/library/type_context.cpp b/src/library/type_context.cpp index 0806098051..66e80be6e6 100644 --- a/src/library/type_context.cpp +++ b/src/library/type_context.cpp @@ -160,7 +160,7 @@ type_context_old::type_context_old(type_context_old && src): m_cache(src.m_cache == &src.m_dummy_cache ? &m_dummy_cache : src.m_cache), m_local_instances(src.m_local_instances), m_transparency_mode(src.m_transparency_mode), - m_approximate(src.m_approximate), + m_unifier_cfg(src.m_unifier_cfg), m_zeta(src.m_zeta), m_smart_unfolding(src.m_smart_unfolding) { lean_assert(!src.m_tmp_data); @@ -1648,10 +1648,6 @@ optional type_context_old::is_delta(expr const & e) { } } -bool type_context_old::approximate() { - return in_tmp_mode() || m_approximate; -} - /* If \c e is a let local-decl, then unfold it, otherwise return e. */ expr type_context_old::try_zeta(expr const & e) { if (!is_local_decl_ref(e)) @@ -1709,7 +1705,7 @@ Now, we consider some workarounds/approximations. (precise) solution: unfold `x` in `t`. A2) Suppose some `a_i` is in `C` (failed condition 2) - (approximated) solution (when approximate() predicate returns true) : + (approximated) solution (when fo_unif_approx() predicate returns true) : ignore condition and also use ?M := fun a_1 ... a_n, t @@ -1811,7 +1807,7 @@ Now, we consider some workarounds/approximations. If `?M'` is assigned, the workaround is precise, and we just unfold `?M'`. A5) If some `a_i` is not a local constant, - then we use first-order unification (if approximate() is true) + then we use first-order unification (if fo_unif_approx() is true) ?M a_1 ... a_i a_{i+1} ... a_{i+k} =?= f b_1 ... b_k @@ -1827,7 +1823,7 @@ Now, we consider some workarounds/approximations. ?M a_1 ... a_n =?= ?M b_1 ... b_k - then we use first-order unification (if approximate() is true) + then we use first-order unification (if fo_unif_approx() is true) */ bool type_context_old::process_assignment(expr const & m, expr const & v) { lean_trace(name({"type_context", "is_def_eq_detail"}), @@ -1864,7 +1860,7 @@ bool type_context_old::process_assignment(expr const & m, expr const & v) { args[i] = arg; if (!is_local_decl_ref(arg)) { /* m is of the form (?M ... t ...) where t is not a local constant. */ - if (approximate()) { + if (fo_unif_approx()) { /* workaround A5 */ use_fo = true; add_locals = false; @@ -1875,9 +1871,12 @@ bool type_context_old::process_assignment(expr const & m, expr const & v) { if (std::any_of(locals.begin(), locals.end(), [&](expr const & local) { return mlocal_name(local) == mlocal_name(arg); })) { /* m is of the form (?M ... l ... l ...) where l is a local constant. */ - if (approximate()) { + if (quasi_pattern_unif_approx()) { /* workaround A3 */ add_locals = false; + } else if (fo_unif_approx()) { + use_fo = true; + add_locals = false; } else { return false; } @@ -1888,18 +1887,28 @@ bool type_context_old::process_assignment(expr const & m, expr const & v) { if (in_tmp_mode()) { if (m_tmp_data->m_mvar_lctx.find_local_decl(arg)) { /* m is of the form (?M@C ... l ...) where l is a local constant in C */ - if (!approximate()) + if (quasi_pattern_unif_approx()) { + if (add_locals) + in_ctx_locals.push_back(arg); + } else if (fo_unif_approx()) { + use_fo = true; + add_locals = false; + } else { return false; - if (add_locals) - in_ctx_locals.push_back(arg); + } } } else { if (mvar_decl->get_context().find_local_decl(arg)) { /* m is of the form (?M@C ... l ...) where l is a local constant in C. */ - if (!approximate()) + if (quasi_pattern_unif_approx()) { + if (add_locals) + in_ctx_locals.push_back(arg); + } else if (fo_unif_approx()) { + use_fo = true; + add_locals = false; + } else { return false; - if (add_locals) - in_ctx_locals.push_back(arg); + } } } } @@ -1909,7 +1918,7 @@ bool type_context_old::process_assignment(expr const & m, expr const & v) { expr new_v = instantiate_mvars(v); /* enforce A4 */ - if (approximate() && !locals.empty() && get_app_fn(new_v) == mvar) { + if (fo_unif_approx() && !locals.empty() && get_app_fn(new_v) == mvar) { /* A6 */ use_fo = true; } @@ -1919,7 +1928,7 @@ bool type_context_old::process_assignment(expr const & m, expr const & v) { if (optional new_new_v = check_assignment(locals, in_ctx_locals, mvar, new_v)) new_v = *new_new_v; - else if (approximate() && !args.empty()) + else if (fo_unif_approx() && !args.empty()) return process_assignment_fo_approx(mvar, args, new_v); else return false; @@ -1953,7 +1962,6 @@ bool type_context_old::process_assignment(expr const & m, expr const & v) { } if (!in_ctx_locals.empty()) { - lean_assert(approximate()); try { /* We need to type check new_v because abstraction using `mk_lambda` may have produced a type incorrect term. See discussion at A2. @@ -2197,7 +2205,7 @@ struct check_assignment_fn : public replace_visitor { if (m_ctx.in_tmp_mode()) { if (!m_in_ctx_locals.empty()) { /* In tmp mode, we (usually) do not use approximate unification/matching. - Moreover, m_in_ctx_locals is empty if !approximate(). + Moreover, m_in_ctx_locals is empty if we are not approximating Remark: all temporary metavariables share the same local context. Then, if a local in `m_in_ctx_locals` is in the local context of `mvar`, @@ -2235,7 +2243,7 @@ struct check_assignment_fn : public replace_visitor { if (is_subset_of(e_lctx, mvar_lctx, delayed_locals)) return e; - if (m_ctx.approximate() && mvar_lctx.is_subset_of(e_lctx)) { + if (m_ctx.ctx_unif_approx() && mvar_lctx.is_subset_of(e_lctx)) { expr e_type = e_decl->get_type(); if (mvar_lctx.well_formed(e_type)) { /* Restrict context of the ?M' */ @@ -2267,7 +2275,7 @@ struct check_assignment_fn : public replace_visitor { scope_trace_env scope(m_ctx.env(), m_ctx); tout() << "failed to assign " << m_mvar << " :=\n" << m_value << "\n" << "value contains metavariable " << e; - if (!m_ctx.approximate()) { + if (!m_ctx.ctx_unif_approx()) { tout() << " that was declared in a local context that is not a " << "subset of the one in the metavariable being assigned, " << "and local context restriction is disabled\n"; @@ -4030,6 +4038,8 @@ static void instantiate_replacements(type_context_old & ctx, optional type_context_old::mk_class_instance(expr const & type_0) { expr type = instantiate_mvars(type_0); scope S(*this); + fo_unif_approx_scope as1(*this); + ctx_unif_approx_scope as2(*this); optional result; buffer u_replacements; buffer e_replacements; diff --git a/src/library/type_context.h b/src/library/type_context.h index f622196993..3809bf15c8 100644 --- a/src/library/type_context.h +++ b/src/library/type_context.h @@ -34,6 +34,16 @@ bool is_at_least_instances(transparency_mode m); transparency_mode ensure_semireducible_mode(transparency_mode m); transparency_mode ensure_instances_mode(transparency_mode m); +/* Approximation configuration object. */ +struct unifier_config { + bool m_fo_approx{false}; + bool m_ctx_approx{false}; + bool m_quasi_pattern_approx{false}; + unifier_config() {} + unifier_config(bool fo_approx, bool ctx_approx, bool qp_approx): + m_fo_approx(fo_approx), m_ctx_approx(ctx_approx), m_quasi_pattern_approx(qp_approx) {} +}; + class type_context_old : public abstract_type_context { typedef buffer> tmp_uassignment; typedef buffer> tmp_eassignment; @@ -394,14 +404,12 @@ private: /* Stack of backtracking point (aka scope) */ scopes m_scopes; tmp_data * m_tmp_data{nullptr}; - /* If m_approximate == true, then enable approximate higher-order unification - even if we are not in tmp_mode + /* Higher-order unification approximation options. - Users: + Modules that use approximations: - elaborator - - apply and rewrite tactics use it by default (it can be disabled). - */ - bool m_approximate{false}; + - apply and rewrite tactics use it by default (it can be disabled). */ + unifier_config m_unifier_cfg; /* If m_zeta, then use zeta-reduction (i.e., expand let-expressions at whnf) */ bool m_zeta{true}; @@ -743,9 +751,25 @@ public: } }; - struct approximate_scope : public flet { + /* Enable/disable all unifier approximations. */ + struct approximate_scope : public flet { approximate_scope(type_context_old & ctx, bool approx = true): - flet(ctx.m_approximate, approx) {} + flet(ctx.m_unifier_cfg, unifier_config(approx, approx, approx)) {} + }; + + struct fo_unif_approx_scope : public flet { + fo_unif_approx_scope(type_context_old & ctx, bool approx = true): + flet(ctx.m_unifier_cfg.m_fo_approx, approx) {} + }; + + struct ctx_unif_approx_scope : public flet { + ctx_unif_approx_scope(type_context_old & ctx, bool approx = true): + flet(ctx.m_unifier_cfg.m_ctx_approx, approx) {} + }; + + struct quasi_pattern_unif_approx_scope : public flet { + quasi_pattern_unif_approx_scope(type_context_old & ctx, bool approx = true): + flet(ctx.m_unifier_cfg.m_quasi_pattern_approx, approx) {} }; struct zeta_scope : public flet { @@ -781,7 +805,7 @@ public: -------------------------- */ public: struct tmp_mode_scope { - type_context_old & m_ctx; + type_context_old & m_ctx; buffer> m_tmp_uassignment; buffer> m_tmp_eassignment; tmp_data * m_old_data; @@ -930,7 +954,10 @@ private: void commit() { m_postponed_sz = m_owner.m_postponed.size(); m_owner.commit_scope(); m_keep = true; } }; bool process_postponed(scope const & s); - bool approximate(); + bool fo_unif_approx() const { return m_unifier_cfg.m_fo_approx; } + bool ctx_unif_approx() const { return m_unifier_cfg.m_ctx_approx; } + bool quasi_pattern_unif_approx() const { return m_unifier_cfg.m_quasi_pattern_approx; } + bool approximate() const { return fo_unif_approx() || ctx_unif_approx() || quasi_pattern_unif_approx(); } expr try_zeta(expr const & e); expr expand_let_decls(expr const & e); friend struct check_assignment_fn; diff --git a/tests/lean/run/quasi_pattern_unification_approx_issue.lean b/tests/lean/run/quasi_pattern_unification_approx_issue.lean new file mode 100644 index 0000000000..f0b4682c7e --- /dev/null +++ b/tests/lean/run/quasi_pattern_unification_approx_issue.lean @@ -0,0 +1,46 @@ +variables {δ σ : Type} + +def foo1 : state_t δ (state_t σ id) σ := +monad_lift (get : state_t σ id σ) +/- +In Lean3, we used to use the quasi-pattern approximation during elaboration. +The example above demonstrates why it produces counterintuitive behavior. +We have the `monad-lift` application: + +@monad_lift ?m ?n ?c ?α (get : state_t σ id σ) : ?n ?α + +It produces the following unification problem when we process the expected type: + +?n ?α =?= state_t δ (state_t σ id) σ +==> (approximate using first-order unification) +?n := state_t δ (state_t σ id) +?α := σ + +Then, we need to solve: + +?m ?α =?= state_t σ id σ +==> instantiate metavars +?m σ =?= state_t σ id σ +==> (approximate since it is a quasi-pattern unification constraint) +?m := λ σ, state_t σ id σ + +Remark: the constraint is not a Milner pattern because σ is in +the local context of `?m`. We are ignoring the other possible solutions: +?m := λ σ', state_t σ id σ +?m := λ σ', state_t σ' id σ +?m := λ σ', state_t σ id σ' + +We need the quasi-pattern approximation for elaborating recursors. +One option is to enable this kind of approximation only when +elaborating recursors and executing induction-like tactics. + +If we had use first-order unification, then we would have produced +the right answer: `?m := state_t σ id` + +Haskell would work on this example since it always uses +first-order unification. +-/ + +def foo2 : state_t δ (state_t σ id) σ := +do s : σ ← monad_lift (get : state_t σ id σ), + return s