From f2e06ead5436e01e585a15dcc795ccb96bf3d182 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 28 Jun 2025 15:41:22 -0700 Subject: [PATCH] feat: support for `LawfulEqCmp` in `grind` (#9069) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements support for the type class `LawfulEqCmp`. Examples: ```lean example (a b c : Vector (List Nat) n) : b = c → a.compareLex (List.compareLex compare) b = o → o = .eq → a = c := by grind example [Ord α] [Std.LawfulEqCmp (compare : α → α → Ordering)] (a b c : Array (Vector (List α) n)) : b = c → o = .eq → a.compareLex (Vector.compareLex (List.compareLex compare)) b = o → a = c := by grind ``` --- src/Lean/Meta/Tactic/Grind.lean | 1 + src/Lean/Meta/Tactic/Grind/LawfulEqCmp.lean | 53 +++++++++++++++++++++ src/Lean/Meta/Tactic/Grind/Main.lean | 7 ++- src/Lean/Meta/Tactic/Grind/Types.lean | 10 ++++ src/Lean/Meta/Tactic/Grind/Util.lean | 6 +++ tests/lean/run/grind_lawful_eq_cmp.lean | 17 +++++++ 6 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 src/Lean/Meta/Tactic/Grind/LawfulEqCmp.lean create mode 100644 tests/lean/run/grind_lawful_eq_cmp.lean diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index 7e33907547..cee3370e91 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -31,6 +31,7 @@ import Lean.Meta.Tactic.Grind.MatchDiscrOnly import Lean.Meta.Tactic.Grind.Diseq import Lean.Meta.Tactic.Grind.MBTC import Lean.Meta.Tactic.Grind.Lookahead +import Lean.Meta.Tactic.Grind.LawfulEqCmp namespace Lean diff --git a/src/Lean/Meta/Tactic/Grind/LawfulEqCmp.lean b/src/Lean/Meta/Tactic/Grind/LawfulEqCmp.lean new file mode 100644 index 0000000000..6bafda34bd --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/LawfulEqCmp.lean @@ -0,0 +1,53 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Types + +/-! +Support for type class `LawfulEqCmp`. +-/ +/- +Note: we will have similar support for `Associative` and `Commutative`. In the future, we should have +a mechanism for letting users to install their own handlers. +-/ + +namespace Lean.Meta.Grind + +/-- +If `op` implements `LawfulEqCmp`, then returns the proof term for +`∀ a b, op a b = .eq → a = b` +-/ +def getLawfulEqCmpThm? (op : Expr) : GrindM (Option Expr) := do + if let some thm? := (← get).lawfulEqCmpMap.find? { expr := op } then + return thm? + let thm? ← go? + modify fun s => { s with lawfulEqCmpMap := s.lawfulEqCmpMap.insert { expr := op } thm? } + return thm? +where + go? : MetaM (Option Expr) := do + unless (← getEnv).contains ``Std.LawfulEqCmp do return none + let opType ← whnf (← inferType op) + let .forallE _ α b _ := opType | return none + if b.hasLooseBVars then return none + let .forallE _ α' b _ ← whnf b | return none + unless b.isConstOf ``Ordering do return none + unless (← isDefEq α α') do return none + let u ← getLevel α + let some u ← decLevel? u | return none + let lawfulEqCmp := mkApp2 (mkConst ``Std.LawfulEqCmp [u]) α op + let .some lawfulEqCmpInst ← trySynthInstance lawfulEqCmp | return none + return some <| mkApp3 (mkConst ``Std.LawfulEqCmp.eq_of_compare [u]) α op lawfulEqCmpInst + +def propagateLawfulEqCmp (e : Expr) : GoalM Unit := do + let some op := getBinOp e | return () + let some thm ← getLawfulEqCmpThm? op | return () + let oeq ← getOrderingEqExpr + unless (← isEqv e oeq) do return () + let a := e.appFn!.appArg! + let b := e.appArg! + pushEq a b <| mkApp3 thm a b (← mkEqProof e oeq) + +end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Main.lean b/src/Lean/Meta/Tactic/Grind/Main.lean index 24434de5b7..b40b2a73f5 100644 --- a/src/Lean/Meta/Tactic/Grind/Main.lean +++ b/src/Lean/Meta/Tactic/Grind/Main.lean @@ -20,6 +20,7 @@ import Lean.Meta.Tactic.Grind.Split import Lean.Meta.Tactic.Grind.Solve import Lean.Meta.Tactic.Grind.SimpUtil import Lean.Meta.Tactic.Grind.Cases +import Lean.Meta.Tactic.Grind.LawfulEqCmp namespace Lean.Meta.Grind @@ -49,6 +50,7 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do prop e propagateDown := fun e => do propagateForallPropDown e + propagateLawfulEqCmp e let .const declName _ := e.getAppFn | return () if let some prop := builtinPropagators.down[declName]? then prop e @@ -72,11 +74,12 @@ def GrindM.run (x : GrindM α) (params : Params) (fallback : Fallback) : MetaM let (bfalseExpr, scState) := shareCommonAlpha (mkConst ``Bool.false) scState let (btrueExpr, scState) := shareCommonAlpha (mkConst ``Bool.true) scState let (natZExpr, scState) := shareCommonAlpha (mkNatLit 0) scState + let (ordEqExpr, scState) := shareCommonAlpha (mkConst ``Ordering.eq) scState let simprocs := params.normProcs let simpMethods := Simp.mkMethods simprocs discharge? (wellBehavedDischarge := true) let simp := params.norm let config := params.config - x (← mkMethods fallback).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr } + x (← mkMethods fallback).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr } |>.run' { scState } private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do @@ -93,6 +96,7 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do let btrueExpr ← getBoolTrueExpr let bfalseExpr ← getBoolFalseExpr let natZeroExpr ← getNatZeroExpr + let ordEqExpr ← getOrderingEqExpr let thmMap := params.ematch let casesTypes := params.casesTypes let clean ← mkCleanState mvarId params @@ -102,6 +106,7 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do mkENodeCore btrueExpr (interpreted := false) (ctor := true) (generation := 0) mkENodeCore bfalseExpr (interpreted := false) (ctor := true) (generation := 0) mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0) + mkENodeCore ordEqExpr (interpreted := false) (ctor := true) (generation := 0) for thm in params.extra do activateTheorem thm 0 diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 83bb66281a..f7ab618d41 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -111,6 +111,7 @@ structure Context where natZExpr : Expr btrueExpr : Expr bfalseExpr : Expr + ordEqExpr : Expr -- `Ordering.eq` /-- Key for the congruence theorem cache. -/ structure CongrTheoremCacheKey where @@ -188,6 +189,11 @@ structure State where counters : Counters := {} /-- Split diagnostic information. This information is only collected when `set_option diagnostics true` -/ splitDiags : PArray SplitDiagInfo := {} + /-- + Mapping from binary functions `f` to a theorem `thm : ∀ a b, f a b = .eq → a = b` + if it implements the `LawfulEqCmp` type class. + -/ + lawfulEqCmpMap : PHashMap ExprPtr (Option Expr) := {} private opaque MethodsRefPointed : NonemptyType.{0} private def MethodsRef : Type := MethodsRefPointed.type @@ -236,6 +242,10 @@ def getBoolFalseExpr : GrindM Expr := do def getNatZeroExpr : GrindM Expr := do return (← readThe Context).natZExpr +/-- Returns the internalized `Ordering.eq`. -/ +def getOrderingEqExpr : GrindM Expr := do + return (← readThe Context).ordEqExpr + def cheapCasesOnly : GrindM Bool := return (← readThe Context).cheapCases diff --git a/src/Lean/Meta/Tactic/Grind/Util.lean b/src/Lean/Meta/Tactic/Grind/Util.lean index bafcbe5fc5..933f36afe1 100644 --- a/src/Lean/Meta/Tactic/Grind/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/Util.lean @@ -222,4 +222,10 @@ def isIte (e : Expr) := def isDIte (e : Expr) := e.isAppOf ``dite && e.getAppNumArgs >= 5 +def getBinOp (e : Expr) : Option Expr := + if !e.isApp then none else + let f := e.appFn! + if !f.isApp then none else + some f.appFn! + end Lean.Meta.Grind diff --git a/tests/lean/run/grind_lawful_eq_cmp.lean b/tests/lean/run/grind_lawful_eq_cmp.lean new file mode 100644 index 0000000000..9b313ee3d0 --- /dev/null +++ b/tests/lean/run/grind_lawful_eq_cmp.lean @@ -0,0 +1,17 @@ +import Std + +example (f : α → α → Ordering) [Std.LawfulEqCmp f] (a b c : α) : b = c → f a b = o → o = .eq → a = c := by + grind + +example (a b c : Vector (List Nat) 10) : b = c → a.compareLex (List.compareLex compare) b = o → o = .eq → a = c := by + grind + +example (a b c : Vector (List Nat) 10) : b = c → o = .eq → a.compareLex (List.compareLex compare) b = o → a = c := by + grind + +example (a b c : Array (Vector (List Nat) n)) : b = c → o = .eq → a.compareLex (Vector.compareLex (List.compareLex compare)) b = o → a = c := by + grind + +example [Ord α] [Std.LawfulEqCmp (compare : α → α → Ordering)] (a b c : Array (Vector (List α) n)) + : b = c → o = .eq → a.compareLex (Vector.compareLex (List.compareLex compare)) b = o → a = c := by + grind