feat: support for LawfulEqCmp in grind (#9069)

This PR implements support for the type class `LawfulEqCmp`. Examples:
```lean
example (a b c : Vector (List Nat) n)
    : b = c → a.compareLex (List.compareLex compare) b = o → o = .eq → a = c := by
  grind

example [Ord α] [Std.LawfulEqCmp (compare : α → α → Ordering)] (a b c : Array (Vector (List α) n))
    : b = c → o = .eq → a.compareLex (Vector.compareLex (List.compareLex compare)) b = o → a = c := by
  grind
```
This commit is contained in:
Leonardo de Moura 2025-06-28 15:41:22 -07:00 committed by GitHub
parent f6bb524406
commit f2e06ead54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 93 additions and 1 deletions

View file

@ -31,6 +31,7 @@ import Lean.Meta.Tactic.Grind.MatchDiscrOnly
import Lean.Meta.Tactic.Grind.Diseq
import Lean.Meta.Tactic.Grind.MBTC
import Lean.Meta.Tactic.Grind.Lookahead
import Lean.Meta.Tactic.Grind.LawfulEqCmp
namespace Lean

View file

@ -0,0 +1,53 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types
/-!
Support for type class `LawfulEqCmp`.
-/
/-
Note: we will have similar support for `Associative` and `Commutative`. In the future, we should have
a mechanism for letting users to install their own handlers.
-/
namespace Lean.Meta.Grind
/--
If `op` implements `LawfulEqCmp`, then returns the proof term for
`∀ a b, op a b = .eq → a = b`
-/
def getLawfulEqCmpThm? (op : Expr) : GrindM (Option Expr) := do
if let some thm? := (← get).lawfulEqCmpMap.find? { expr := op } then
return thm?
let thm? ← go?
modify fun s => { s with lawfulEqCmpMap := s.lawfulEqCmpMap.insert { expr := op } thm? }
return thm?
where
go? : MetaM (Option Expr) := do
unless (← getEnv).contains ``Std.LawfulEqCmp do return none
let opType ← whnf (← inferType op)
let .forallE _ α b _ := opType | return none
if b.hasLooseBVars then return none
let .forallE _ α' b _ ← whnf b | return none
unless b.isConstOf ``Ordering do return none
unless (← isDefEq α α') do return none
let u ← getLevel α
let some u ← decLevel? u | return none
let lawfulEqCmp := mkApp2 (mkConst ``Std.LawfulEqCmp [u]) α op
let .some lawfulEqCmpInst ← trySynthInstance lawfulEqCmp | return none
return some <| mkApp3 (mkConst ``Std.LawfulEqCmp.eq_of_compare [u]) α op lawfulEqCmpInst
def propagateLawfulEqCmp (e : Expr) : GoalM Unit := do
let some op := getBinOp e | return ()
let some thm ← getLawfulEqCmpThm? op | return ()
let oeq ← getOrderingEqExpr
unless (← isEqv e oeq) do return ()
let a := e.appFn!.appArg!
let b := e.appArg!
pushEq a b <| mkApp3 thm a b (← mkEqProof e oeq)
end Lean.Meta.Grind

View file

@ -20,6 +20,7 @@ import Lean.Meta.Tactic.Grind.Split
import Lean.Meta.Tactic.Grind.Solve
import Lean.Meta.Tactic.Grind.SimpUtil
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.LawfulEqCmp
namespace Lean.Meta.Grind
@ -49,6 +50,7 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
prop e
propagateDown := fun e => do
propagateForallPropDown e
propagateLawfulEqCmp e
let .const declName _ := e.getAppFn | return ()
if let some prop := builtinPropagators.down[declName]? then
prop e
@ -72,11 +74,12 @@ def GrindM.run (x : GrindM α) (params : Params) (fallback : Fallback) : MetaM
let (bfalseExpr, scState) := shareCommonAlpha (mkConst ``Bool.false) scState
let (btrueExpr, scState) := shareCommonAlpha (mkConst ``Bool.true) scState
let (natZExpr, scState) := shareCommonAlpha (mkNatLit 0) scState
let (ordEqExpr, scState) := shareCommonAlpha (mkConst ``Ordering.eq) scState
let simprocs := params.normProcs
let simpMethods := Simp.mkMethods simprocs discharge? (wellBehavedDischarge := true)
let simp := params.norm
let config := params.config
x (← mkMethods fallback).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr }
x (← mkMethods fallback).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr }
|>.run' { scState }
private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do
@ -93,6 +96,7 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
let btrueExpr ← getBoolTrueExpr
let bfalseExpr ← getBoolFalseExpr
let natZeroExpr ← getNatZeroExpr
let ordEqExpr ← getOrderingEqExpr
let thmMap := params.ematch
let casesTypes := params.casesTypes
let clean ← mkCleanState mvarId params
@ -102,6 +106,7 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
mkENodeCore btrueExpr (interpreted := false) (ctor := true) (generation := 0)
mkENodeCore bfalseExpr (interpreted := false) (ctor := true) (generation := 0)
mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore ordEqExpr (interpreted := false) (ctor := true) (generation := 0)
for thm in params.extra do
activateTheorem thm 0

View file

@ -111,6 +111,7 @@ structure Context where
natZExpr : Expr
btrueExpr : Expr
bfalseExpr : Expr
ordEqExpr : Expr -- `Ordering.eq`
/-- Key for the congruence theorem cache. -/
structure CongrTheoremCacheKey where
@ -188,6 +189,11 @@ structure State where
counters : Counters := {}
/-- Split diagnostic information. This information is only collected when `set_option diagnostics true` -/
splitDiags : PArray SplitDiagInfo := {}
/--
Mapping from binary functions `f` to a theorem `thm : ∀ a b, f a b = .eq → a = b`
if it implements the `LawfulEqCmp` type class.
-/
lawfulEqCmpMap : PHashMap ExprPtr (Option Expr) := {}
private opaque MethodsRefPointed : NonemptyType.{0}
private def MethodsRef : Type := MethodsRefPointed.type
@ -236,6 +242,10 @@ def getBoolFalseExpr : GrindM Expr := do
def getNatZeroExpr : GrindM Expr := do
return (← readThe Context).natZExpr
/-- Returns the internalized `Ordering.eq`. -/
def getOrderingEqExpr : GrindM Expr := do
return (← readThe Context).ordEqExpr
def cheapCasesOnly : GrindM Bool :=
return (← readThe Context).cheapCases

View file

@ -222,4 +222,10 @@ def isIte (e : Expr) :=
def isDIte (e : Expr) :=
e.isAppOf ``dite && e.getAppNumArgs >= 5
def getBinOp (e : Expr) : Option Expr :=
if !e.isApp then none else
let f := e.appFn!
if !f.isApp then none else
some f.appFn!
end Lean.Meta.Grind

View file

@ -0,0 +1,17 @@
import Std
example (f : αα → Ordering) [Std.LawfulEqCmp f] (a b c : α) : b = c → f a b = o → o = .eq → a = c := by
grind
example (a b c : Vector (List Nat) 10) : b = c → a.compareLex (List.compareLex compare) b = o → o = .eq → a = c := by
grind
example (a b c : Vector (List Nat) 10) : b = c → o = .eq → a.compareLex (List.compareLex compare) b = o → a = c := by
grind
example (a b c : Array (Vector (List Nat) n)) : b = c → o = .eq → a.compareLex (Vector.compareLex (List.compareLex compare)) b = o → a = c := by
grind
example [Ord α] [Std.LawfulEqCmp (compare : αα → Ordering)] (a b c : Array (Vector (List α) n))
: b = c → o = .eq → a.compareLex (Vector.compareLex (List.compareLex compare)) b = o → a = c := by
grind