feat: cleanups to ACI and Identity classes (#3195)

This makes changes to the definitions of Associativity, Commutativity,
Idempotence and Identity classes to be more aligned with Mathlib's
versions.

The changes are:
*  Move classes are moved from `Lean` to root namespace.
* Drop `Is` prefix from names.
* Rename `IsNeutral` to `LawfulIdentity` and add Left and Right
subclasses.
* Change neutral/identity element to outParam.
* Introduce `HasIdentity` for operations not intended for proofs to
implement

The identity changes are to make this compatible with
[Mathlib](718042db9d/Mathlib/Init/Algebra/Classes.lean)
and to enable nicer fold operations in Std that can use type classes to
infer the identity/initial element on binary operations.

---------

Co-authored-by: Kyle Miller <kmill31415@gmail.com>
This commit is contained in:
Joe Hendrix 2024-01-24 13:46:58 -08:00 committed by GitHub
parent 2beb948a3b
commit 8293fd4e09
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 150 additions and 74 deletions

View file

@ -1680,40 +1680,92 @@ So, you are mainly losing the capability of type checking your development using
-/
axiom ofReduceNat (a b : Nat) (h : reduceNat a = b) : a = b
end Lean
namespace Std
variable {α : Sort u}
/--
`IsAssociative op` says that `op` is an associative operation,
i.e. `(a ∘ b) ∘ c = a ∘ (b ∘ c)`. It is used by the `ac_rfl` tactic.
`Associative op` indicates `op` is an associative operation,
i.e. `(a ∘ b) ∘ c = a ∘ (b ∘ c)`.
-/
class IsAssociative {α : Sort u} (op : ααα) where
class Associative (op : ααα) : Prop where
/-- An associative operation satisfies `(a ∘ b) ∘ c = a ∘ (b ∘ c)`. -/
assoc : (a b c : α) → op (op a b) c = op a (op b c)
/--
`IsCommutative op` says that `op` is a commutative operation,
i.e. `a ∘ b = b ∘ a`. It is used by the `ac_rfl` tactic.
`Commutative op` says that `op` is a commutative operation,
i.e. `a ∘ b = b ∘ a`.
-/
class IsCommutative {α : Sort u} (op : ααα) where
class Commutative (op : ααα) : Prop where
/-- A commutative operation satisfies `a ∘ b = b ∘ a`. -/
comm : (a b : α) → op a b = op b a
/--
`IsIdempotent op` says that `op` is an idempotent operation,
i.e. `a ∘ a = a`. It is used by the `ac_rfl` tactic
(which also simplifies up to idempotence when available).
`IdempotentOp op` indicates `op` is an idempotent binary operation.
i.e. `a ∘ a = a`.
-/
class IsIdempotent {α : Sort u} (op : ααα) where
class IdempotentOp (op : ααα) : Prop where
/-- An idempotent operation satisfies `a ∘ a = a`. -/
idempotent : (x : α) → op x x = x
/--
`IsNeutral op e` says that `e` is a neutral operation for `op`,
i.e. `a ∘ e = a = e ∘ a`. It is used by the `ac_rfl` tactic
(which also simplifies neutral elements when available).
-/
class IsNeutral {α : Sort u} (op : ααα) (neutral : α) where
/-- A neutral element can be cancelled on the left: `e ∘ a = a`. -/
left_neutral : (a : α) → op neutral a = a
/-- A neutral element can be cancelled on the right: `a ∘ e = a`. -/
right_neutral : (a : α) → op a neutral = a
`LeftIdentify op o` indicates `o` is a left identity of `op`.
end Lean
This class does not require a proof that `o` is an identity, and
is used primarily for infering the identity using class resoluton.
-/
class LeftIdentity (op : α → β → β) (o : outParam α) : Prop
/--
`LawfulLeftIdentify op o` indicates `o` is a verified left identity of
`op`.
-/
class LawfulLeftIdentity (op : α → β → β) (o : outParam α) extends LeftIdentity op o : Prop where
/-- Left identity `o` is an identity. -/
left_id : ∀ a, op o a = a
/--
`RightIdentify op o` indicates `o` is a right identity `o` of `op`.
This class does not require a proof that `o` is an identity, and is used
primarily for infering the identity using class resoluton.
-/
class RightIdentity (op : α → β → α) (o : outParam β) : Prop
/--
`LawfulRightIdentify op o` indicates `o` is a verified right identity of
`op`.
-/
class LawfulRightIdentity (op : α → β → α) (o : outParam β) extends RightIdentity op o : Prop where
/-- Right identity `o` is an identity. -/
right_id : ∀ a, op a o = a
/--
`Identity op o` indicates `o` is a left and right identity of `op`.
This class does not require a proof that `o` is an identity, and is used
primarily for infering the identity using class resoluton.
-/
class Identity (op : ααα) (o : outParam α) extends LeftIdentity op o, RightIdentity op o : Prop
/--
`LawfulIdentity op o` indicates `o` is a verified left and right
identity of `op`.
-/
class LawfulIdentity (op : ααα) (o : outParam α) extends Identity op o, LawfulLeftIdentity op o, LawfulRightIdentity op o : Prop
/--
`LawfulCommIdentity` can simplify defining instances of `LawfulIdentity`
on commutative functions by requiring only a left or right identity
proof.
This class is intended for simplifying defining instances of
`LawfulIdentity` and functions needed commutative operations with
identity should just add a `LawfulIdentity` constraint.
-/
class LawfulCommIdentity (op : ααα) (o : outParam α) [hc : Commutative op] extends LawfulIdentity op o : Prop where
left_id a := Eq.trans (hc.comm o a) (right_id a)
right_id a := Eq.trans (hc.comm a o) (left_id a)
end Std

View file

@ -14,15 +14,17 @@ inductive Expr
| op (lhs rhs : Expr)
deriving Inhabited, Repr, BEq
open Std
structure Variable {α : Sort u} (op : ααα) : Type u where
value : α
neutral : Option $ IsNeutral op value
neutral : Option $ PLift (LawfulIdentity op value)
structure Context (α : Sort u) where
op : ααα
assoc : IsAssociative op
comm : Option $ IsCommutative op
idem : Option $ IsIdempotent op
assoc : Associative op
comm : Option $ PLift $ Commutative op
idem : Option $ PLift $ IdempotentOp op
vars : List (Variable op)
arbitrary : α
@ -128,7 +130,14 @@ theorem Context.mergeIdem_head2 (h : x ≠ y) : mergeIdem (x :: y :: ys) = x ::
simp [mergeIdem, mergeIdem.loop, h]
theorem Context.evalList_mergeIdem (ctx : Context α) (h : ContextInformation.isIdem ctx) (e : List Nat) : evalList α ctx (mergeIdem e) = evalList α ctx e := by
have h : IsIdempotent ctx.op := by simp [ContextInformation.isIdem, Option.isSome] at h; cases h₂ : ctx.idem <;> simp [h₂] at h; assumption
have h : IdempotentOp ctx.op := by
simp [ContextInformation.isIdem, Option.isSome] at h;
match h₂ : ctx.idem with
| none =>
simp [h₂] at h
| some val =>
simp [h₂] at h
exact val.down
induction e using List.two_step_induction with
| empty => rfl
| single => rfl
@ -169,7 +178,7 @@ theorem Context.sort_loop_nonEmpty (xs : List Nat) (h : xs ≠ []) : sort.loop x
theorem Context.evalList_insert
(ctx : Context α)
(h : IsCommutative ctx.op)
(h : Commutative ctx.op)
(x : Nat)
(xs : List Nat)
: evalList α ctx (insert x xs) = evalList α ctx (x::xs) := by
@ -190,7 +199,7 @@ theorem Context.evalList_insert
theorem Context.evalList_sort_congr
(ctx : Context α)
(h : IsCommutative ctx.op)
(h : Commutative ctx.op)
(h₂ : evalList α ctx a = evalList α ctx b)
(h₃ : a ≠ [])
(h₄ : b ≠ [])
@ -209,7 +218,7 @@ theorem Context.evalList_sort_congr
theorem Context.evalList_sort_loop_swap
(ctx : Context α)
(h : IsCommutative ctx.op)
(h : Commutative ctx.op)
(xs ys : List Nat)
: evalList α ctx (sort.loop xs (y::ys)) = evalList α ctx (sort.loop (y::xs) ys) := by
induction ys generalizing y xs with
@ -224,7 +233,7 @@ theorem Context.evalList_sort_loop_swap
theorem Context.evalList_sort_cons
(ctx : Context α)
(h : IsCommutative ctx.op)
(h : Commutative ctx.op)
(x : Nat)
(xs : List Nat)
: evalList α ctx (sort (x :: xs)) = evalList α ctx (x :: sort xs) := by
@ -247,7 +256,14 @@ theorem Context.evalList_sort_cons
all_goals simp [insert_nonEmpty]
theorem Context.evalList_sort (ctx : Context α) (h : ContextInformation.isComm ctx) (e : List Nat) : evalList α ctx (sort e) = evalList α ctx e := by
have h : IsCommutative ctx.op := by simp [ContextInformation.isComm, Option.isSome] at h; cases h₂ : ctx.comm <;> simp [h₂] at h; assumption
have h : Commutative ctx.op := by
simp [ContextInformation.isComm, Option.isSome] at h
match h₂ : ctx.comm with
| none =>
simp only [h₂] at h
| some val =>
simp [h₂] at h
exact val.down
induction e using List.two_step_induction with
| empty => rfl
| single => rfl
@ -269,10 +285,12 @@ theorem Context.toList_nonEmpty (e : Expr) : e.toList ≠ [] := by
theorem Context.unwrap_isNeutral
{ctx : Context α}
{x : Nat}
: ContextInformation.isNeutral ctx x = true → IsNeutral (EvalInformation.evalOp ctx) (EvalInformation.evalVar (β := α) ctx x) := by
: ContextInformation.isNeutral ctx x = true → LawfulIdentity (EvalInformation.evalOp ctx) (EvalInformation.evalVar (β := α) ctx x) := by
simp [ContextInformation.isNeutral, Option.isSome, EvalInformation.evalOp, EvalInformation.evalVar]
match (var ctx x).neutral with
| some hn => intro; assumption
| some hn =>
intro
exact hn.down
| none => intro; contradiction
theorem Context.evalList_removeNeutrals (ctx : Context α) (e : List Nat) : evalList α ctx (removeNeutrals ctx e) = evalList α ctx e := by
@ -283,10 +301,12 @@ theorem Context.evalList_removeNeutrals (ctx : Context α) (e : List Nat) : eval
case h_1 => rfl
case h_2 h => split at h <;> simp_all
| step x y ys ih =>
cases h₁ : ContextInformation.isNeutral ctx x <;> cases h₂ : ContextInformation.isNeutral ctx y <;> cases h₃ : removeNeutrals.loop ctx ys
cases h₁ : ContextInformation.isNeutral ctx x <;>
cases h₂ : ContextInformation.isNeutral ctx y <;>
cases h₃ : removeNeutrals.loop ctx ys
<;> simp [removeNeutrals, removeNeutrals.loop, h₁, h₂, h₃, evalList, ←ih]
<;> (try simp [unwrap_isNeutral h₂ |>.2])
<;> (try simp [unwrap_isNeutral h₁ |>.1])
<;> (try simp [unwrap_isNeutral h₂ |>.right_id])
<;> (try simp [unwrap_isNeutral h₁ |>.left_id])
theorem Context.evalList_append
(ctx : Context α)

View file

@ -268,8 +268,8 @@ macro "rfl'" : tactic => `(tactic| set_option smartUnfolding false in with_unfol
/--
`ac_rfl` proves equalities up to application of an associative and commutative operator.
```
instance : IsAssociative (α := Nat) (.+.) := ⟨Nat.add_assoc⟩
instance : IsCommutative (α := Nat) (.+.) := ⟨Nat.add_comm⟩
instance : Associative (α := Nat) (.+.) := ⟨Nat.add_assoc⟩
instance : Commutative (α := Nat) (.+.) := ⟨Nat.add_comm⟩
example (a b c d : Nat) : a + b + c + d = d + (b + c) + a := by ac_rfl
```

View file

@ -11,6 +11,7 @@ import Lean.Elab.Tactic.Rewrite
namespace Lean.Meta.AC
open Lean.Data.AC
open Lean.Elab.Tactic
open Std
abbrev ACExpr := Lean.Data.AC.Expr
@ -43,13 +44,13 @@ def getInstance (cls : Name) (exprs : Array Expr) : MetaM (Option Expr) := do
| _ => return none
def preContext (expr : Expr) : MetaM (Option PreContext) := do
if let some assoc := ←getInstance ``IsAssociative #[expr] then
if let some assoc := ←getInstance ``Associative #[expr] then
return some
{ assoc,
op := expr
id := 0
comm := ←getInstance ``IsCommutative #[expr]
idem := ←getInstance ``IsIdempotent #[expr] }
comm := ←getInstance ``Commutative #[expr]
idem := ←getInstance ``IdempotentOp #[expr] }
return none
@ -99,13 +100,14 @@ where
mkContext (α : Expr) (u : Level) (vars : Array Expr) : MetaM (Array Bool × Expr) := do
let arbitrary := vars[0]!
let zero := mkLevelZeroEx ()
let noneE := mkApp (mkConst ``Option.none [zero])
let someE := mkApp2 (mkConst ``Option.some [zero])
let plift := mkApp (mkConst ``PLift [zero])
let pliftUp := mkApp2 (mkConst ``PLift.up [zero])
let noneE tp := mkApp (mkConst ``Option.none [zero]) (plift tp)
let someE tp v := mkApp2 (mkConst ``Option.some [zero]) (plift tp) (pliftUp tp v)
let vars ← vars.mapM fun x => do
let isNeutral :=
let isNeutralClass := mkApp3 (mkConst ``IsNeutral [u]) α preContext.op x
match ←getInstance ``IsNeutral #[preContext.op, x] with
let isNeutralClass := mkApp3 (mkConst ``LawfulIdentity [u]) α preContext.op x
match ←getInstance ``LawfulIdentity #[preContext.op, x] with
| none => (false, noneE isNeutralClass)
| some isNeutral => (true, someE isNeutralClass isNeutral)
@ -116,13 +118,13 @@ where
let vars ← mkListLit (mkApp2 (mkConst ``Variable [u]) α preContext.op) vars
let comm :=
let commClass := mkApp2 (mkConst ``IsCommutative [u]) α preContext.op
let commClass := mkApp2 (mkConst ``Commutative [u]) α preContext.op
match preContext.comm with
| none => noneE commClass
| some comm => someE commClass comm
let idem :=
let idemClass := mkApp2 (mkConst ``IsIdempotent [u]) α preContext.op
let idemClass := mkApp2 (mkConst ``IdempotentOp [u]) α preContext.op
match preContext.idem with
| none => noneE idemClass
| some idem => someE idemClass idem
@ -130,12 +132,12 @@ where
return (isNeutrals, mkApp7 (mkConst ``Lean.Data.AC.Context.mk [u]) α preContext.op preContext.assoc comm idem vars arbitrary)
convert : ACExpr → Expr
| Data.AC.Expr.op l r => mkApp2 (mkConst ``Data.AC.Expr.op) (convert l) (convert r)
| Data.AC.Expr.var x => mkApp (mkConst ``Data.AC.Expr.var) $ mkNatLit x
| .op l r => mkApp2 (mkConst ``Data.AC.Expr.op) (convert l) (convert r)
| .var x => mkApp (mkConst ``Data.AC.Expr.var) $ mkNatLit x
convertTarget (vars : Array Expr) : ACExpr → Expr
| Data.AC.Expr.op l r => mkApp2 preContext.op (convertTarget vars l) (convertTarget vars r)
| Data.AC.Expr.var x => vars[x]!
| .op l r => mkApp2 preContext.op (convertTarget vars l) (convertTarget vars r)
| .var x => vars[x]!
def rewriteUnnormalized (mvarId : MVarId) : MetaM Unit := do
let simpCtx :=

View file

@ -1,8 +1,9 @@
opaque f : Bool → Bool → Bool
axiom f_comm (a b : Bool) : f a b = f b a
axiom f_assoc (a b c : Bool) : f (f a b) c = f a (f b c)
instance : Lean.IsCommutative f := ⟨f_comm⟩
instance : Lean.IsAssociative f := ⟨f_assoc⟩
open Std
instance : Commutative f := ⟨f_comm⟩
instance : Associative f := ⟨f_assoc⟩
example (a b c : Bool) : f (f a b) c = f (f a c) b :=
by ac_rfl -- good

View file

@ -1,12 +1,12 @@
open Lean
open Std
instance : IsAssociative (α := Nat) HAdd.hAdd := ⟨Nat.add_assoc⟩
instance : IsCommutative (α := Nat) HAdd.hAdd := ⟨Nat.add_comm⟩
instance : IsNeutral HAdd.hAdd 0 := ⟨Nat.zero_add, Nat.add_zero⟩
instance : Associative (α := Nat) HAdd.hAdd := ⟨Nat.add_assoc⟩
instance : Commutative (α := Nat) HAdd.hAdd := ⟨Nat.add_comm⟩
instance : LawfulCommIdentity HAdd.hAdd 0 where right_id := Nat.add_zero
instance : IsAssociative (α := Nat) HMul.hMul := ⟨Nat.mul_assoc⟩
instance : IsCommutative (α := Nat) HMul.hMul := ⟨Nat.mul_comm⟩
instance : IsNeutral HMul.hMul 1 := ⟨Nat.one_mul, Nat.mul_one⟩
instance : Associative (α := Nat) HMul.hMul := ⟨Nat.mul_assoc⟩
instance : Commutative (α := Nat) HMul.hMul := ⟨Nat.mul_comm⟩
instance : LawfulCommIdentity HMul.hMul 1 where right_id := Nat.mul_one
@[simp] theorem succ_le_succ_iff {x y : Nat} : x.succ ≤ y.succ ↔ x ≤ y :=
⟨Nat.le_of_succ_le_succ, Nat.succ_le_succ⟩
@ -52,22 +52,23 @@ theorem Nat.zero_max (n : Nat) : max 0 n = n := by simp [Nat.max_def]
theorem Nat.max_zero (n : Nat) : max n 0 = n := by rw [max_comm, Nat.zero_max]
instance : IsAssociative (α := Nat) max := ⟨max_assoc⟩
instance : IsCommutative (α := Nat) max := ⟨max_comm⟩
instance : IsIdempotent (α := Nat) max := ⟨max_idem⟩
instance : IsNeutral max 0 := ⟨Nat.zero_max, Nat.max_zero⟩
instance : Associative (α := Nat) max := ⟨max_assoc⟩
instance : Commutative (α := Nat) max := ⟨max_comm⟩
instance : IdempotentOp (α := Nat) max := ⟨max_idem⟩
instance : LawfulCommIdentity max 0 where
right_id := Nat.max_zero
instance : IsAssociative And := ⟨λ p q r => propext ⟨λ ⟨⟨hp, hq⟩, hr⟩ => ⟨hp, hq, hr⟩, λ ⟨hp, hq, hr⟩ => ⟨⟨hp, hq⟩, hr⟩⟩⟩
instance : IsCommutative And := ⟨λ p q => propext ⟨λ ⟨hp, hq⟩ => ⟨hq, hp⟩, λ ⟨hq, hp⟩ => ⟨hp, hq⟩⟩⟩
instance : IsIdempotent And := ⟨λ p => propext ⟨λ ⟨hp, _⟩ => hp, λ hp => ⟨hp, hp⟩⟩⟩
instance : IsNeutral And True :=
⟨λ p => propext ⟨λ ⟨_, hp⟩ => hp, λ hp => ⟨True.intro, hp⟩⟩, λ p => propext ⟨λ ⟨hp, _⟩ => hp, λ hp => ⟨hp, True.intro⟩⟩
instance : Associative And := ⟨λ _p _q _r => propext ⟨λ ⟨⟨hp, hq⟩, hr⟩ => ⟨hp, hq, hr⟩, λ ⟨hp, hq, hr⟩ => ⟨⟨hp, hq⟩, hr⟩⟩⟩
instance : Commutative And := ⟨λ _p _q => propext ⟨λ ⟨hp, hq⟩ => ⟨hq, hp⟩, λ ⟨hq, hp⟩ => ⟨hp, hq⟩⟩⟩
instance : IdempotentOp And := ⟨λ _p => propext ⟨λ ⟨hp, _⟩ => hp, λ hp => ⟨hp, hp⟩⟩⟩
instance : LawfulCommIdentity And True where
right_id _p := propext ⟨λ ⟨hp, _⟩ => hp, λ hp => ⟨hp, True.intro⟩⟩
instance : IsAssociative Or := ⟨by simp [or_assoc]⟩
instance : IsCommutative Or := ⟨λ p q => propext ⟨λ hpq => hpq.elim Or.inr Or.inl, λ hqp => hqp.elim Or.inr Or.inl⟩⟩
instance : IsIdempotent Or := ⟨λ p => propext ⟨λ hp => hp.elim id id, Or.inl⟩⟩
instance : IsNeutral Or False :=
⟨λ p => propext ⟨λ hfp => hfp.elim False.elim id, Or.inr⟩, λ p => propext ⟨λ hpf => hpf.elim id False.elim, Or.inl⟩
instance : Associative Or := ⟨by simp [or_assoc]⟩
instance : Commutative Or := ⟨λ_p _q => propext ⟨λ hpq => hpq.elim Or.inr Or.inl, λ hqp => hqp.elim Or.inr Or.inl⟩⟩
instance : IdempotentOp Or := ⟨λ_p => propext ⟨λ hp => hp.elim id id, Or.inl⟩⟩
instance : LawfulCommIdentity Or False where
right_id _p := propext ⟨λ hpf => hpf.elim id False.elim, Or.inl⟩
example (x y z : Nat) : x + y + 0 + z = z + (x + y) := by ac_rfl