feat: eta for structures

This commit is contained in:
Leonardo de Moura 2021-11-22 17:26:30 -08:00
parent dd146d50cf
commit d685c545b4
4 changed files with 95 additions and 1 deletions

View file

@ -16,6 +16,35 @@ import Lean.Meta.UnificationHint
namespace Lean.Meta
/--
Return true `b` is of the form `mk a.1 ... a.n`.
-/
private def isDefEqEtaStruct (a b : Expr) : MetaM Bool :=
matchConstCtor b.getAppFn (fun _ => return false) fun ctorVal _ => do
if ctorVal.numParams + ctorVal.numFields != b.getAppNumArgs then
trace[Meta.isDefEq.eta.struct] "failed, insufficient number of arguments at{indentExpr b}"
return false
else
let inductVal ← getConstInfoInduct ctorVal.induct
if inductVal.nctors != 1 then
trace[Meta.isDefEq.eta.struct] "failed, type is not a structure{indentExpr b}"
return false
else checkpointDefEq do
let args := b.getAppArgs
for i in [ctorVal.numParams : args.size] do
match (← whnf args[i]) with
| Expr.proj _ j e _ =>
unless ctorVal.numParams + j == i do
trace[Meta.isDefEq.eta.struct] "failed, unexpect arg #{i}, unexpected projection #{j}, at{indentExpr b}"
return false
unless (← isDefEq e a) do
trace[Meta.isDefEq.eta.struct] "failed, unexpect arg #{i}, argument{e}\nis not defeq to{indentExpr a}"
return false
| e =>
trace[Meta.isDefEq.eta.struct] "failed, projection expected{indentExpr e}"
return false
return true
/--
Try to solve `a := (fun x => t) =?= b` by eta-expanding `b`.
@ -35,7 +64,7 @@ private def isDefEqEta (a b : Expr) : MetaM Bool := do
checkpointDefEq <| Meta.isExprDefEqAux a b'
| _ => pure false
else
pure false
return false
/-- Support for `Lean.reduceBool` and `Lean.reduceNat` -/
def isDefEqNative (s t : Expr) : MetaM LBool := do
@ -1436,6 +1465,8 @@ private def isDefEqApp (t s : Expr) : MetaM Bool := do
private def isExprDefEqExpensive (t : Expr) (s : Expr) : MetaM Bool := do
if (← (isDefEqEta t s <||> isDefEqEta s t)) then pure true else
-- TODO: investigate whether this is the place for putting this check
if (← (isDefEqEtaStruct t s <||> isDefEqEtaStruct s t)) then pure true else
if (← isDefEqProj t s) then pure true else
whenUndefDo (isDefEqNative t s) do
whenUndefDo (isDefEqNat t s) do

View file

@ -749,6 +749,27 @@ bool type_checker::try_eta_expansion_core(expr const & t, expr const & s) {
}
}
/** \brief check whether \c s is of the form <tt>mk t.1 ... t.n</tt> */
bool type_checker::try_eta_struct_core(expr const & t, expr const & s) {
expr f = get_app_fn(s);
if (!is_constant(f)) return false;
constant_info f_info = env().get(const_name(f));
if (!f_info.is_constructor()) return false;
constructor_val f_val = f_info.to_constructor_val();
if (get_app_num_args(s) != f_val.get_nparams() + f_val.get_nfields()) return false;
inductive_val I_val = env().get(f_val.get_induct()).to_inductive_val();
if (I_val.get_ncnstrs() != 1) return 1;
buffer<expr> s_args;
get_app_args(s, s_args);
for (unsigned i = f_val.get_nparams(); i < s_args.size(); i++) {
expr s_arg = whnf(s_args[i]);
if (!is_proj(s_arg)) return false;
if (proj_idx(s_arg) + nat(f_val.get_nparams()) != nat(i)) return false;
if (!is_def_eq(t, proj_expr(s_arg))) return false;
}
return true;
}
/** \brief Return true if \c t and \c s are definitionally equal because they are applications of the form
<tt>(f a_1 ... a_n)</tt> <tt>(g b_1 ... b_n)</tt>, and \c f and \c g are definitionally equal, and
\c a_i and \c b_i are also definitionally equal for every 1 <= i <= n.
@ -961,6 +982,9 @@ bool type_checker::is_def_eq_core(expr const & t, expr const & s) {
if (try_eta_expansion(t_n, s_n))
return true;
if (try_eta_struct(t_n, s_n))
return true;
r = try_string_lit_expansion(t_n, s_n);
if (r != l_undef) return r == l_true;

View file

@ -78,6 +78,10 @@ private:
bool try_eta_expansion(expr const & t, expr const & s) {
return try_eta_expansion_core(t, s) || try_eta_expansion_core(s, t);
}
bool try_eta_struct_core(expr const & t, expr const & s);
bool try_eta_struct(expr const & t, expr const & s) {
return try_eta_struct_core(t, s) || try_eta_struct_core(s, t);
}
lbool try_string_lit_expansion_core(expr const & t, expr const & s);
lbool try_string_lit_expansion(expr const & t, expr const & s);
bool is_def_eq_app(expr const & t, expr const & s);

View file

@ -0,0 +1,35 @@
example (x : α × β) : x = (x.1, x.2) :=
rfl -- Should work with eta for structures
example (x : Unit) : x = ⟨⟩ :=
rfl -- Should work with eta for structures
structure Equiv (α : Sort u) (β : Sort v) where
toFun : α → β
invFun : β → α
left_inv : ∀ x, invFun (toFun x) = x
right_inv : ∀ x, toFun (invFun x) = x
infix:50 "≃" => Equiv
def Equiv.symm (e : α ≃ β) : β ≃ α :=
{ toFun := e.invFun
invFun := e.toFun
left_inv := e.right_inv
right_inv := e.left_inv }
theorem Equiv.symm.symm (e : α ≃ β) : e.symm.symm = e :=
rfl -- Should work with eta for structures
structure Bla where
x : Nat
def Bla.toNat (b : Bla) : Nat := b.x
def Nat.toBla (x : Nat) : Bla := { x }
example (b : Bla) : b.toNat.toBla = b :=
rfl -- Should work with eta for structures
example (b : Bla) : b.toNat.toBla = b := by
cases b
rfl