From d685c545b46b5d026475789d77be031bfb64f88d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 22 Nov 2021 17:26:30 -0800 Subject: [PATCH] feat: eta for structures --- src/Lean/Meta/ExprDefEq.lean | 33 ++++++++++++++++++++++++++++++++- src/kernel/type_checker.cpp | 24 ++++++++++++++++++++++++ src/kernel/type_checker.h | 4 ++++ tests/lean/run/etaStruct.lean | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 tests/lean/run/etaStruct.lean diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index a95106244d..aaa6f36b1d 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -16,6 +16,35 @@ import Lean.Meta.UnificationHint namespace Lean.Meta +/-- + Return true `b` is of the form `mk a.1 ... a.n`. +-/ +private def isDefEqEtaStruct (a b : Expr) : MetaM Bool := + matchConstCtor b.getAppFn (fun _ => return false) fun ctorVal _ => do + if ctorVal.numParams + ctorVal.numFields != b.getAppNumArgs then + trace[Meta.isDefEq.eta.struct] "failed, insufficient number of arguments at{indentExpr b}" + return false + else + let inductVal ← getConstInfoInduct ctorVal.induct + if inductVal.nctors != 1 then + trace[Meta.isDefEq.eta.struct] "failed, type is not a structure{indentExpr b}" + return false + else checkpointDefEq do + let args := b.getAppArgs + for i in [ctorVal.numParams : args.size] do + match (← whnf args[i]) with + | Expr.proj _ j e _ => + unless ctorVal.numParams + j == i do + trace[Meta.isDefEq.eta.struct] "failed, unexpect arg #{i}, unexpected projection #{j}, at{indentExpr b}" + return false + unless (← isDefEq e a) do + trace[Meta.isDefEq.eta.struct] "failed, unexpect arg #{i}, argument{e}\nis not defeq to{indentExpr a}" + return false + | e => + trace[Meta.isDefEq.eta.struct] "failed, projection expected{indentExpr e}" + return false + return true + /-- Try to solve `a := (fun x => t) =?= b` by eta-expanding `b`. @@ -35,7 +64,7 @@ private def isDefEqEta (a b : Expr) : MetaM Bool := do checkpointDefEq <| Meta.isExprDefEqAux a b' | _ => pure false else - pure false + return false /-- Support for `Lean.reduceBool` and `Lean.reduceNat` -/ def isDefEqNative (s t : Expr) : MetaM LBool := do @@ -1436,6 +1465,8 @@ private def isDefEqApp (t s : Expr) : MetaM Bool := do private def isExprDefEqExpensive (t : Expr) (s : Expr) : MetaM Bool := do if (← (isDefEqEta t s <||> isDefEqEta s t)) then pure true else + -- TODO: investigate whether this is the place for putting this check + if (← (isDefEqEtaStruct t s <||> isDefEqEtaStruct s t)) then pure true else if (← isDefEqProj t s) then pure true else whenUndefDo (isDefEqNative t s) do whenUndefDo (isDefEqNat t s) do diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 2b693885f1..d2b887c894 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -749,6 +749,27 @@ bool type_checker::try_eta_expansion_core(expr const & t, expr const & s) { } } +/** \brief check whether \c s is of the form mk t.1 ... t.n */ +bool type_checker::try_eta_struct_core(expr const & t, expr const & s) { + expr f = get_app_fn(s); + if (!is_constant(f)) return false; + constant_info f_info = env().get(const_name(f)); + if (!f_info.is_constructor()) return false; + constructor_val f_val = f_info.to_constructor_val(); + if (get_app_num_args(s) != f_val.get_nparams() + f_val.get_nfields()) return false; + inductive_val I_val = env().get(f_val.get_induct()).to_inductive_val(); + if (I_val.get_ncnstrs() != 1) return 1; + buffer s_args; + get_app_args(s, s_args); + for (unsigned i = f_val.get_nparams(); i < s_args.size(); i++) { + expr s_arg = whnf(s_args[i]); + if (!is_proj(s_arg)) return false; + if (proj_idx(s_arg) + nat(f_val.get_nparams()) != nat(i)) return false; + if (!is_def_eq(t, proj_expr(s_arg))) return false; + } + return true; +} + /** \brief Return true if \c t and \c s are definitionally equal because they are applications of the form (f a_1 ... a_n) (g b_1 ... b_n), and \c f and \c g are definitionally equal, and \c a_i and \c b_i are also definitionally equal for every 1 <= i <= n. @@ -961,6 +982,9 @@ bool type_checker::is_def_eq_core(expr const & t, expr const & s) { if (try_eta_expansion(t_n, s_n)) return true; + if (try_eta_struct(t_n, s_n)) + return true; + r = try_string_lit_expansion(t_n, s_n); if (r != l_undef) return r == l_true; diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index a5fcd4612e..45b510f782 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -78,6 +78,10 @@ private: bool try_eta_expansion(expr const & t, expr const & s) { return try_eta_expansion_core(t, s) || try_eta_expansion_core(s, t); } + bool try_eta_struct_core(expr const & t, expr const & s); + bool try_eta_struct(expr const & t, expr const & s) { + return try_eta_struct_core(t, s) || try_eta_struct_core(s, t); + } lbool try_string_lit_expansion_core(expr const & t, expr const & s); lbool try_string_lit_expansion(expr const & t, expr const & s); bool is_def_eq_app(expr const & t, expr const & s); diff --git a/tests/lean/run/etaStruct.lean b/tests/lean/run/etaStruct.lean new file mode 100644 index 0000000000..689cb6f75d --- /dev/null +++ b/tests/lean/run/etaStruct.lean @@ -0,0 +1,35 @@ +example (x : α × β) : x = (x.1, x.2) := + rfl -- Should work with eta for structures + +example (x : Unit) : x = ⟨⟩ := + rfl -- Should work with eta for structures + +structure Equiv (α : Sort u) (β : Sort v) where + toFun : α → β + invFun : β → α + left_inv : ∀ x, invFun (toFun x) = x + right_inv : ∀ x, toFun (invFun x) = x + +infix:50 "≃" => Equiv + +def Equiv.symm (e : α ≃ β) : β ≃ α := + { toFun := e.invFun + invFun := e.toFun + left_inv := e.right_inv + right_inv := e.left_inv } + +theorem Equiv.symm.symm (e : α ≃ β) : e.symm.symm = e := + rfl -- Should work with eta for structures + +structure Bla where + x : Nat + +def Bla.toNat (b : Bla) : Nat := b.x +def Nat.toBla (x : Nat) : Bla := { x } + +example (b : Bla) : b.toNat.toBla = b := + rfl -- Should work with eta for structures + +example (b : Bla) : b.toNat.toBla = b := by + cases b + rfl