diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean
index a95106244d..aaa6f36b1d 100644
--- a/src/Lean/Meta/ExprDefEq.lean
+++ b/src/Lean/Meta/ExprDefEq.lean
@@ -16,6 +16,35 @@ import Lean.Meta.UnificationHint
namespace Lean.Meta
+/--
+ Return true `b` is of the form `mk a.1 ... a.n`.
+-/
+private def isDefEqEtaStruct (a b : Expr) : MetaM Bool :=
+ matchConstCtor b.getAppFn (fun _ => return false) fun ctorVal _ => do
+ if ctorVal.numParams + ctorVal.numFields != b.getAppNumArgs then
+ trace[Meta.isDefEq.eta.struct] "failed, insufficient number of arguments at{indentExpr b}"
+ return false
+ else
+ let inductVal ← getConstInfoInduct ctorVal.induct
+ if inductVal.nctors != 1 then
+ trace[Meta.isDefEq.eta.struct] "failed, type is not a structure{indentExpr b}"
+ return false
+ else checkpointDefEq do
+ let args := b.getAppArgs
+ for i in [ctorVal.numParams : args.size] do
+ match (← whnf args[i]) with
+ | Expr.proj _ j e _ =>
+ unless ctorVal.numParams + j == i do
+ trace[Meta.isDefEq.eta.struct] "failed, unexpect arg #{i}, unexpected projection #{j}, at{indentExpr b}"
+ return false
+ unless (← isDefEq e a) do
+ trace[Meta.isDefEq.eta.struct] "failed, unexpect arg #{i}, argument{e}\nis not defeq to{indentExpr a}"
+ return false
+ | e =>
+ trace[Meta.isDefEq.eta.struct] "failed, projection expected{indentExpr e}"
+ return false
+ return true
+
/--
Try to solve `a := (fun x => t) =?= b` by eta-expanding `b`.
@@ -35,7 +64,7 @@ private def isDefEqEta (a b : Expr) : MetaM Bool := do
checkpointDefEq <| Meta.isExprDefEqAux a b'
| _ => pure false
else
- pure false
+ return false
/-- Support for `Lean.reduceBool` and `Lean.reduceNat` -/
def isDefEqNative (s t : Expr) : MetaM LBool := do
@@ -1436,6 +1465,8 @@ private def isDefEqApp (t s : Expr) : MetaM Bool := do
private def isExprDefEqExpensive (t : Expr) (s : Expr) : MetaM Bool := do
if (← (isDefEqEta t s <||> isDefEqEta s t)) then pure true else
+ -- TODO: investigate whether this is the place for putting this check
+ if (← (isDefEqEtaStruct t s <||> isDefEqEtaStruct s t)) then pure true else
if (← isDefEqProj t s) then pure true else
whenUndefDo (isDefEqNative t s) do
whenUndefDo (isDefEqNat t s) do
diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp
index 2b693885f1..d2b887c894 100644
--- a/src/kernel/type_checker.cpp
+++ b/src/kernel/type_checker.cpp
@@ -749,6 +749,27 @@ bool type_checker::try_eta_expansion_core(expr const & t, expr const & s) {
}
}
+/** \brief check whether \c s is of the form mk t.1 ... t.n */
+bool type_checker::try_eta_struct_core(expr const & t, expr const & s) {
+ expr f = get_app_fn(s);
+ if (!is_constant(f)) return false;
+ constant_info f_info = env().get(const_name(f));
+ if (!f_info.is_constructor()) return false;
+ constructor_val f_val = f_info.to_constructor_val();
+ if (get_app_num_args(s) != f_val.get_nparams() + f_val.get_nfields()) return false;
+ inductive_val I_val = env().get(f_val.get_induct()).to_inductive_val();
+ if (I_val.get_ncnstrs() != 1) return 1;
+ buffer s_args;
+ get_app_args(s, s_args);
+ for (unsigned i = f_val.get_nparams(); i < s_args.size(); i++) {
+ expr s_arg = whnf(s_args[i]);
+ if (!is_proj(s_arg)) return false;
+ if (proj_idx(s_arg) + nat(f_val.get_nparams()) != nat(i)) return false;
+ if (!is_def_eq(t, proj_expr(s_arg))) return false;
+ }
+ return true;
+}
+
/** \brief Return true if \c t and \c s are definitionally equal because they are applications of the form
(f a_1 ... a_n) (g b_1 ... b_n), and \c f and \c g are definitionally equal, and
\c a_i and \c b_i are also definitionally equal for every 1 <= i <= n.
@@ -961,6 +982,9 @@ bool type_checker::is_def_eq_core(expr const & t, expr const & s) {
if (try_eta_expansion(t_n, s_n))
return true;
+ if (try_eta_struct(t_n, s_n))
+ return true;
+
r = try_string_lit_expansion(t_n, s_n);
if (r != l_undef) return r == l_true;
diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h
index a5fcd4612e..45b510f782 100644
--- a/src/kernel/type_checker.h
+++ b/src/kernel/type_checker.h
@@ -78,6 +78,10 @@ private:
bool try_eta_expansion(expr const & t, expr const & s) {
return try_eta_expansion_core(t, s) || try_eta_expansion_core(s, t);
}
+ bool try_eta_struct_core(expr const & t, expr const & s);
+ bool try_eta_struct(expr const & t, expr const & s) {
+ return try_eta_struct_core(t, s) || try_eta_struct_core(s, t);
+ }
lbool try_string_lit_expansion_core(expr const & t, expr const & s);
lbool try_string_lit_expansion(expr const & t, expr const & s);
bool is_def_eq_app(expr const & t, expr const & s);
diff --git a/tests/lean/run/etaStruct.lean b/tests/lean/run/etaStruct.lean
new file mode 100644
index 0000000000..689cb6f75d
--- /dev/null
+++ b/tests/lean/run/etaStruct.lean
@@ -0,0 +1,35 @@
+example (x : α × β) : x = (x.1, x.2) :=
+ rfl -- Should work with eta for structures
+
+example (x : Unit) : x = ⟨⟩ :=
+ rfl -- Should work with eta for structures
+
+structure Equiv (α : Sort u) (β : Sort v) where
+ toFun : α → β
+ invFun : β → α
+ left_inv : ∀ x, invFun (toFun x) = x
+ right_inv : ∀ x, toFun (invFun x) = x
+
+infix:50 "≃" => Equiv
+
+def Equiv.symm (e : α ≃ β) : β ≃ α :=
+ { toFun := e.invFun
+ invFun := e.toFun
+ left_inv := e.right_inv
+ right_inv := e.left_inv }
+
+theorem Equiv.symm.symm (e : α ≃ β) : e.symm.symm = e :=
+ rfl -- Should work with eta for structures
+
+structure Bla where
+ x : Nat
+
+def Bla.toNat (b : Bla) : Nat := b.x
+def Nat.toBla (x : Nat) : Bla := { x }
+
+example (b : Bla) : b.toNat.toBla = b :=
+ rfl -- Should work with eta for structures
+
+example (b : Bla) : b.toNat.toBla = b := by
+ cases b
+ rfl