feat: support for LawfulEqCmp in grind (#9069)
This PR implements support for the type class `LawfulEqCmp`. Examples:
```lean
example (a b c : Vector (List Nat) n)
: b = c → a.compareLex (List.compareLex compare) b = o → o = .eq → a = c := by
grind
example [Ord α] [Std.LawfulEqCmp (compare : α → α → Ordering)] (a b c : Array (Vector (List α) n))
: b = c → o = .eq → a.compareLex (Vector.compareLex (List.compareLex compare)) b = o → a = c := by
grind
```
This commit is contained in:
parent
f6bb524406
commit
f2e06ead54
6 changed files with 93 additions and 1 deletions
|
|
@ -31,6 +31,7 @@ import Lean.Meta.Tactic.Grind.MatchDiscrOnly
|
|||
import Lean.Meta.Tactic.Grind.Diseq
|
||||
import Lean.Meta.Tactic.Grind.MBTC
|
||||
import Lean.Meta.Tactic.Grind.Lookahead
|
||||
import Lean.Meta.Tactic.Grind.LawfulEqCmp
|
||||
|
||||
namespace Lean
|
||||
|
||||
|
|
|
|||
53
src/Lean/Meta/Tactic/Grind/LawfulEqCmp.lean
Normal file
53
src/Lean/Meta/Tactic/Grind/LawfulEqCmp.lean
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Meta.Tactic.Grind.Types
|
||||
|
||||
/-!
|
||||
Support for type class `LawfulEqCmp`.
|
||||
-/
|
||||
/-
|
||||
Note: we will have similar support for `Associative` and `Commutative`. In the future, we should have
|
||||
a mechanism for letting users to install their own handlers.
|
||||
-/
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
/--
|
||||
If `op` implements `LawfulEqCmp`, then returns the proof term for
|
||||
`∀ a b, op a b = .eq → a = b`
|
||||
-/
|
||||
def getLawfulEqCmpThm? (op : Expr) : GrindM (Option Expr) := do
|
||||
if let some thm? := (← get).lawfulEqCmpMap.find? { expr := op } then
|
||||
return thm?
|
||||
let thm? ← go?
|
||||
modify fun s => { s with lawfulEqCmpMap := s.lawfulEqCmpMap.insert { expr := op } thm? }
|
||||
return thm?
|
||||
where
|
||||
go? : MetaM (Option Expr) := do
|
||||
unless (← getEnv).contains ``Std.LawfulEqCmp do return none
|
||||
let opType ← whnf (← inferType op)
|
||||
let .forallE _ α b _ := opType | return none
|
||||
if b.hasLooseBVars then return none
|
||||
let .forallE _ α' b _ ← whnf b | return none
|
||||
unless b.isConstOf ``Ordering do return none
|
||||
unless (← isDefEq α α') do return none
|
||||
let u ← getLevel α
|
||||
let some u ← decLevel? u | return none
|
||||
let lawfulEqCmp := mkApp2 (mkConst ``Std.LawfulEqCmp [u]) α op
|
||||
let .some lawfulEqCmpInst ← trySynthInstance lawfulEqCmp | return none
|
||||
return some <| mkApp3 (mkConst ``Std.LawfulEqCmp.eq_of_compare [u]) α op lawfulEqCmpInst
|
||||
|
||||
def propagateLawfulEqCmp (e : Expr) : GoalM Unit := do
|
||||
let some op := getBinOp e | return ()
|
||||
let some thm ← getLawfulEqCmpThm? op | return ()
|
||||
let oeq ← getOrderingEqExpr
|
||||
unless (← isEqv e oeq) do return ()
|
||||
let a := e.appFn!.appArg!
|
||||
let b := e.appArg!
|
||||
pushEq a b <| mkApp3 thm a b (← mkEqProof e oeq)
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
@ -20,6 +20,7 @@ import Lean.Meta.Tactic.Grind.Split
|
|||
import Lean.Meta.Tactic.Grind.Solve
|
||||
import Lean.Meta.Tactic.Grind.SimpUtil
|
||||
import Lean.Meta.Tactic.Grind.Cases
|
||||
import Lean.Meta.Tactic.Grind.LawfulEqCmp
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
|
|
@ -49,6 +50,7 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
|
|||
prop e
|
||||
propagateDown := fun e => do
|
||||
propagateForallPropDown e
|
||||
propagateLawfulEqCmp e
|
||||
let .const declName _ := e.getAppFn | return ()
|
||||
if let some prop := builtinPropagators.down[declName]? then
|
||||
prop e
|
||||
|
|
@ -72,11 +74,12 @@ def GrindM.run (x : GrindM α) (params : Params) (fallback : Fallback) : MetaM
|
|||
let (bfalseExpr, scState) := shareCommonAlpha (mkConst ``Bool.false) scState
|
||||
let (btrueExpr, scState) := shareCommonAlpha (mkConst ``Bool.true) scState
|
||||
let (natZExpr, scState) := shareCommonAlpha (mkNatLit 0) scState
|
||||
let (ordEqExpr, scState) := shareCommonAlpha (mkConst ``Ordering.eq) scState
|
||||
let simprocs := params.normProcs
|
||||
let simpMethods := Simp.mkMethods simprocs discharge? (wellBehavedDischarge := true)
|
||||
let simp := params.norm
|
||||
let config := params.config
|
||||
x (← mkMethods fallback).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr }
|
||||
x (← mkMethods fallback).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr }
|
||||
|>.run' { scState }
|
||||
|
||||
private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do
|
||||
|
|
@ -93,6 +96,7 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
|
|||
let btrueExpr ← getBoolTrueExpr
|
||||
let bfalseExpr ← getBoolFalseExpr
|
||||
let natZeroExpr ← getNatZeroExpr
|
||||
let ordEqExpr ← getOrderingEqExpr
|
||||
let thmMap := params.ematch
|
||||
let casesTypes := params.casesTypes
|
||||
let clean ← mkCleanState mvarId params
|
||||
|
|
@ -102,6 +106,7 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
|
|||
mkENodeCore btrueExpr (interpreted := false) (ctor := true) (generation := 0)
|
||||
mkENodeCore bfalseExpr (interpreted := false) (ctor := true) (generation := 0)
|
||||
mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
mkENodeCore ordEqExpr (interpreted := false) (ctor := true) (generation := 0)
|
||||
for thm in params.extra do
|
||||
activateTheorem thm 0
|
||||
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ structure Context where
|
|||
natZExpr : Expr
|
||||
btrueExpr : Expr
|
||||
bfalseExpr : Expr
|
||||
ordEqExpr : Expr -- `Ordering.eq`
|
||||
|
||||
/-- Key for the congruence theorem cache. -/
|
||||
structure CongrTheoremCacheKey where
|
||||
|
|
@ -188,6 +189,11 @@ structure State where
|
|||
counters : Counters := {}
|
||||
/-- Split diagnostic information. This information is only collected when `set_option diagnostics true` -/
|
||||
splitDiags : PArray SplitDiagInfo := {}
|
||||
/--
|
||||
Mapping from binary functions `f` to a theorem `thm : ∀ a b, f a b = .eq → a = b`
|
||||
if it implements the `LawfulEqCmp` type class.
|
||||
-/
|
||||
lawfulEqCmpMap : PHashMap ExprPtr (Option Expr) := {}
|
||||
|
||||
private opaque MethodsRefPointed : NonemptyType.{0}
|
||||
private def MethodsRef : Type := MethodsRefPointed.type
|
||||
|
|
@ -236,6 +242,10 @@ def getBoolFalseExpr : GrindM Expr := do
|
|||
def getNatZeroExpr : GrindM Expr := do
|
||||
return (← readThe Context).natZExpr
|
||||
|
||||
/-- Returns the internalized `Ordering.eq`. -/
|
||||
def getOrderingEqExpr : GrindM Expr := do
|
||||
return (← readThe Context).ordEqExpr
|
||||
|
||||
def cheapCasesOnly : GrindM Bool :=
|
||||
return (← readThe Context).cheapCases
|
||||
|
||||
|
|
|
|||
|
|
@ -222,4 +222,10 @@ def isIte (e : Expr) :=
|
|||
def isDIte (e : Expr) :=
|
||||
e.isAppOf ``dite && e.getAppNumArgs >= 5
|
||||
|
||||
def getBinOp (e : Expr) : Option Expr :=
|
||||
if !e.isApp then none else
|
||||
let f := e.appFn!
|
||||
if !f.isApp then none else
|
||||
some f.appFn!
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
|
|||
17
tests/lean/run/grind_lawful_eq_cmp.lean
Normal file
17
tests/lean/run/grind_lawful_eq_cmp.lean
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import Std
|
||||
|
||||
example (f : α → α → Ordering) [Std.LawfulEqCmp f] (a b c : α) : b = c → f a b = o → o = .eq → a = c := by
|
||||
grind
|
||||
|
||||
example (a b c : Vector (List Nat) 10) : b = c → a.compareLex (List.compareLex compare) b = o → o = .eq → a = c := by
|
||||
grind
|
||||
|
||||
example (a b c : Vector (List Nat) 10) : b = c → o = .eq → a.compareLex (List.compareLex compare) b = o → a = c := by
|
||||
grind
|
||||
|
||||
example (a b c : Array (Vector (List Nat) n)) : b = c → o = .eq → a.compareLex (Vector.compareLex (List.compareLex compare)) b = o → a = c := by
|
||||
grind
|
||||
|
||||
example [Ord α] [Std.LawfulEqCmp (compare : α → α → Ordering)] (a b c : Array (Vector (List α) n))
|
||||
: b = c → o = .eq → a.compareLex (Vector.compareLex (List.compareLex compare)) b = o → a = c := by
|
||||
grind
|
||||
Loading…
Add table
Reference in a new issue