From d4b17b9fd2ceac558dad498d978eb17d7d496484 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 12 Jun 2025 20:21:35 -0400 Subject: [PATCH] feat: counterexamples for `grind linarith` module (#8756) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements counterexamples for grind linarith. Example: ```lean example [CommRing α] [LinearOrder α] [Ring.IsOrdered α] (a b c d : α) : b ≥ 0 → c > b → d > b → a ≠ b + c → a > b + c → a < b + d → False := by grind ``` produces the counterexample ``` a := 7/2 b := 1 c := 2 d := 3 ``` ```lean example [IntModule α] [LinearOrder α] [IntModule.IsOrdered α] (a b c d : α) : a ≤ b → a - c ≥ 0 + d → d ≤ 0 → b = c → a ≠ b → False := by grind ``` generates the counterexample ``` a := 0 b := 1 c := 1 d := -1 ``` --- .../Meta/Tactic/Grind/Arith/Cutsat/Model.lean | 66 +--------- src/Lean/Meta/Tactic/Grind/Arith/Linear.lean | 3 + .../Meta/Tactic/Grind/Arith/Linear/Model.lean | 42 ++++++ .../Meta/Tactic/Grind/Arith/Linear/PP.lean | 31 +++++ .../Meta/Tactic/Grind/Arith/ModelUtil.lean | 122 ++++++++++++++++++ src/Lean/Meta/Tactic/Grind/PP.lean | 7 + tests/lean/run/grind_linarith_2.lean | 24 ++++ 7 files changed, 232 insertions(+), 63 deletions(-) create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/Linear/Model.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/Linear/PP.lean create mode 100644 src/Lean/Meta/Tactic/Grind/Arith/ModelUtil.lean diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Model.lean b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Model.lean index 763fc570f1..df79391106 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Model.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Cutsat/Model.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ prelude import Lean.Meta.Tactic.Grind.Types +import Lean.Meta.Tactic.Grind.Arith.ModelUtil namespace Lean.Meta.Grind.Arith.Cutsat @@ -24,51 +25,11 @@ private def getCutsatAssignment? (goal : Goal) (node : ENode) : Option Rat := Id else return none -private partial def satisfyDiseqs (goal : Goal) (a : Std.HashMap Expr Rat) (e : Expr) (v : Int) : Bool := Id.run do - let some parents := goal.parents.find? { expr := e } | return true - for parent in parents do - let_expr Eq _ lhs rhs := parent | continue - let some root := goal.getRoot? parent | continue - if root.isConstOf ``False then - let some lhsRoot := goal.getRoot? lhs | continue - let some rhsRoot := goal.getRoot? rhs | continue - if lhsRoot == e && !checkDiseq rhsRoot then return false - if rhsRoot == e && !checkDiseq lhsRoot then return false - return true -where - checkDiseq (other : Expr) : Bool := - if let some v' := a[other]? then - v' != v - else - true - -private partial def pickUnusedValue (goal : Goal) (a : Std.HashMap Expr Rat) (e : Expr) (next : Int) (alreadyUsed : Std.HashSet Int) : Int := - go next -where - go (next : Int) : Int := - if alreadyUsed.contains next then - go (next+1) - else if satisfyDiseqs goal a e next then - next - else - go (next + 1) - -def isInterpretedTerm (e : Expr) : Bool := - isNatNum e || isIntNum e || e.isAppOf ``HAdd.hAdd || e.isAppOf ``HMul.hMul || e.isAppOf ``HSub.hSub - || e.isAppOf ``Neg.neg || e.isAppOf ``HDiv.hDiv || e.isAppOf ``HMod.hMod - || e.isAppOf ``NatCast.natCast || e.isIte || e.isDIte - private def natCast? (e : Expr) : Option Expr := let_expr NatCast.natCast _ inst a := e | none let_expr instNatCastInt := inst | none some a -private def assignEqc (goal : Goal) (e : Expr) (v : Rat) (a : Std.HashMap Expr Rat) : Std.HashMap Expr Rat := Id.run do - let mut a := a - for e in goal.getEqc e do - a := a.insert e v - return a - def getAssignment? (goal : Goal) (e : Expr) : MetaM (Option Rat) := do let node ← goal.getENode (← goal.getRoot e) if let some v := getCutsatAssignment? goal node then @@ -89,8 +50,6 @@ Remark: it uses rational numbers because cutsat may have failed to build an integer model. -/ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do - let mut used : Std.HashSet Int := {} - let mut nextVal : Int := 0 let mut model := {} -- Assign on expressions associated with cutsat terms or interpreted terms for e in goal.exprs do @@ -98,7 +57,6 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do if node.isRoot then if (← isIntNatENode node) then if let some v ← getAssignment? goal node.self then - if v.den == 1 then used := used.insert v.num model := assignEqc goal node.self v model -- Assign cast terms for e in goal.exprs do @@ -108,26 +66,8 @@ def mkModel (goal : Goal) : MetaM (Array (Expr × Rat)) := do if model[n]?.isNone then let some v := model[i]? | pure () model := assignEqc goal n v model - -- Assign the remaining ones with values not used by cutsat - for e in goal.exprs do - let node ← goal.getENode e - if node.isRoot then - if (← isIntNatENode node) then - if model[node.self]?.isNone then - let v := pickUnusedValue goal model node.self nextVal used - model := assignEqc goal node.self v model - used := used.insert v - let mut r := #[] - for (e, v) in model do - unless isInterpretedTerm e do - r := r.push (e, v) - r := r.qsort fun (e₁, _) (e₂, _) => - let g₁ := goal.getGeneration e₁ - let g₂ := goal.getGeneration e₂ - if g₁ != g₂ then g₁ < g₂ else e₁.lt e₂ - if (← isTracingEnabledFor `grind.cutsat.model) then - for (x, v) in r do - trace[grind.cutsat.model] "{quoteIfArithTerm x} := {v}" + let r ← finalizeModel goal isIntNatENode model + traceModel `grind.cutsat.model r return r end Lean.Meta.Grind.Arith.Cutsat diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear.lean index cc1c15abf2..474823f514 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Linear.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear.lean @@ -17,12 +17,15 @@ import Lean.Meta.Tactic.Grind.Arith.Linear.SearchM import Lean.Meta.Tactic.Grind.Arith.Linear.Search import Lean.Meta.Tactic.Grind.Arith.Linear.PropagateEq import Lean.Meta.Tactic.Grind.Arith.Linear.Internalize +import Lean.Meta.Tactic.Grind.Arith.Linear.Model +import Lean.Meta.Tactic.Grind.Arith.Linear.PP namespace Lean builtin_initialize registerTraceClass `grind.linarith builtin_initialize registerTraceClass `grind.linarith.internalize builtin_initialize registerTraceClass `grind.linarith.assert +builtin_initialize registerTraceClass `grind.linarith.model builtin_initialize registerTraceClass `grind.linarith.assert.unsat (inherited := true) builtin_initialize registerTraceClass `grind.linarith.assert.trivial (inherited := true) builtin_initialize registerTraceClass `grind.linarith.assert.store (inherited := true) diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/Model.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Model.lean new file mode 100644 index 0000000000..cea3935708 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/Model.lean @@ -0,0 +1,42 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Types +import Lean.Meta.Tactic.Grind.Arith.ModelUtil + +namespace Lean.Meta.Grind.Arith.Linear + +def getAssignment? (s : Struct) (e : Expr) : Option Rat := Id.run do + let some x := s.varMap.find? { expr := e } | return none + if h : x < s.assignment.size then + return some s.assignment[x] + else + return none + +private def hasType (type : Expr) (n : ENode): MetaM Bool := + withDefault do + let type' ← inferType n.self + isDefEq type' type + +/-- +Construct a model that satisfies all constraints in the linarith model for the structure with id `structId`. +It also assigns values to (integer) terms that have not been internalized by the linarith model. +-/ +def mkModel (goal : Goal) (structId : Nat) : MetaM (Array (Expr × Rat)) := do + let mut model := {} + let s := goal.arith.linear.structs[structId]! + -- Assign on expressions associated with cutsat terms or interpreted terms + for e in goal.exprs do + let node ← goal.getENode e + if node.isRoot then + if (← hasType s.type node) then + if let some v := getAssignment? s node.self then + model := assignEqc goal node.self v model + let r ← finalizeModel goal (hasType s.type) model + traceModel `grind.linarith.model r + return r + +end Lean.Meta.Grind.Arith.Linear diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Linear/PP.lean b/src/Lean/Meta/Tactic/Grind/Arith/Linear/PP.lean new file mode 100644 index 0000000000..80579d3383 --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/Linear/PP.lean @@ -0,0 +1,31 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.Grind.Arith.Linear.Model + +namespace Lean.Meta.Grind.Arith.Linear + +def ppStruct? (goal : Goal) (s : Struct) : MetaM (Option MessageData) := do + let model ← mkModel goal s.id + if model.isEmpty then return none + let mut ms := #[] + for (e, val) in model do + ms := ms.push <| .trace { cls := `assign } m!"{Arith.quoteIfArithTerm e} := {val}" #[] + return some (.trace { cls := `linarith } m!"Linarith assignment for `{s.type}`" ms) + +def pp? (goal : Goal) : MetaM (Option MessageData) := do + let mut msgs := #[] + for struct in goal.arith.linear.structs do + let some msg ← ppStruct? goal struct | pure () + msgs := msgs.push msg + if msgs.isEmpty then + return none + else if h : msgs.size = 1 then + return some msgs[0] + else + return some (.trace { cls := `linarith } "Linarith" msgs) + +end Lean.Meta.Grind.Arith.Linear diff --git a/src/Lean/Meta/Tactic/Grind/Arith/ModelUtil.lean b/src/Lean/Meta/Tactic/Grind/Arith/ModelUtil.lean new file mode 100644 index 0000000000..8872ad4f5a --- /dev/null +++ b/src/Lean/Meta/Tactic/Grind/Arith/ModelUtil.lean @@ -0,0 +1,122 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Std.Internal.Rat +import Lean.Meta.Tactic.Grind.Types + +namespace Lean.Meta.Grind.Arith +open Std.Internal +/-! +Helper functions for constructing counterexamples in the `linarith` and `cutsat` modules +-/ + +/-- +Returns `true` if adding the assignment `e := v` to `a` will falsify any asserted disequality in core. +-/ +private partial def satisfyDiseqs (goal : Goal) (a : Std.HashMap Expr Rat) (e : Expr) (v : Int) : Bool := Id.run do + let some parents := goal.parents.find? { expr := e } | return true + for parent in parents do + let_expr Eq _ lhs rhs := parent | continue + let some root := goal.getRoot? parent | continue + if root.isConstOf ``False then + let some lhsRoot := goal.getRoot? lhs | continue + let some rhsRoot := goal.getRoot? rhs | continue + if lhsRoot == e && !checkDiseq rhsRoot then return false + if rhsRoot == e && !checkDiseq lhsRoot then return false + return true +where + checkDiseq (other : Expr) : Bool := + if let some v' := a[other]? then + v' != v + else + true + +/-- +Returns an integer value `i` for assigning to `e` s.t. adding `e := i` to `a` will not falsify any disequality +and `i` is not in `alreadyUsed`. +-/ +partial def pickUnusedValue (goal : Goal) (a : Std.HashMap Expr Rat) (e : Expr) (next : Int) (alreadyUsed : Std.HashSet Int) : Int := + go next +where + go (next : Int) : Int := + if alreadyUsed.contains next then + go (next+1) + else if satisfyDiseqs goal a e next then + next + else + go (next + 1) + +/-- +Returns `true` if `e` should be treated as an interpreted value by the arithmetic modules. +-/ +def isInterpretedTerm (e : Expr) : Bool := + isNatNum e || isIntNum e || e.isAppOf ``HAdd.hAdd || e.isAppOf ``HMul.hMul || e.isAppOf ``HSub.hSub + || e.isAppOf ``Neg.neg || e.isAppOf ``HDiv.hDiv || e.isAppOf ``HMod.hMod || e.isAppOf ``One.one || e.isAppOf ``Zero.zero + || e.isAppOf ``NatCast.natCast || e.isIte || e.isDIte || e.isAppOf ``OfNat.ofNat + +/-- +Adds the assignments `e' := v` to `a` for each `e'` in the equivalence class os `e`. +-/ +def assignEqc (goal : Goal) (e : Expr) (v : Rat) (a : Std.HashMap Expr Rat) : Std.HashMap Expr Rat := Id.run do + let mut a := a + for e in goal.getEqc e do + a := a.insert e v + return a + +/-- +Assigns terms in the goal that satisfy `isTarget`. +Recall that not all terms are communicated to `linarith` and `cutsat` modules if they do not appear in relevant constraints. +The idea is to assign unused integer values that have not been used in the model and do not falsify equalities and disequalities +in core. +-/ +private def assignUnassigned (goal : Goal) (isTarget : ENode → MetaM Bool) (model : Std.HashMap Expr Rat) : MetaM (Std.HashMap Expr Rat) := do + let mut nextVal : Int := 0 + -- Collect used values + let mut used : Std.HashSet Int := {} + for (_, v) in model do + if v.den == 1 then + used := used.insert v.num + let mut model := model + -- Assign the remaining ones with values not used by cutsat + for e in goal.exprs do + let node ← goal.getENode e + if node.isRoot then + if (← isTarget node) then + if model[node.self]?.isNone then + let v := pickUnusedValue goal model node.self nextVal used + model := assignEqc goal node.self v model + used := used.insert v + nextVal := v + 1 + return model + +/-- Sorts assignment first by expression generation and then `Expr.lt` -/ +private def sortModel (goal : Goal) (m : Array (Expr × Rat)) : Array (Expr × Rat) := + m.qsort fun (e₁, _) (e₂, _) => + let g₁ := goal.getGeneration e₁ + let g₂ := goal.getGeneration e₂ + if g₁ != g₂ then g₁ < g₂ else e₁.lt e₂ + +/-- +Converts the given model into a sorted array of pairs `(e, v)` representing assignments `e := v`. +`isTarget` is a predicate used to detect terms that must be in the model but have not been assigned a value (see: `assignUnassigned`) +The pairs are sorted using `e`s generation and then `Expr.lt`. +Only terms s.t. `isInterpretedTerm e = false` are included into the resulting array. +-/ +def finalizeModel (goal : Goal) (isTarget : ENode → MetaM Bool) (model : Std.HashMap Expr Rat) : MetaM (Array (Expr × Rat)) := do + let model ← assignUnassigned goal isTarget model + let mut r := #[] + for (e, v) in model do + unless isInterpretedTerm e do + r := r.push (e, v) + return sortModel goal r + +/-- If the given trace class is enabled, trace the model using the class. -/ +def traceModel (traceClass : Name) (model : Array (Expr × Rat)) : MetaM Unit := do + if (← isTracingEnabledFor traceClass) then + for (x, v) in model do + addTrace traceClass m!"{quoteIfArithTerm x} := {v}" + +end Lean.Meta.Grind.Arith diff --git a/src/Lean/Meta/Tactic/Grind/PP.lean b/src/Lean/Meta/Tactic/Grind/PP.lean index 0bfd5dca0c..084b65bb7f 100644 --- a/src/Lean/Meta/Tactic/Grind/PP.lean +++ b/src/Lean/Meta/Tactic/Grind/PP.lean @@ -9,6 +9,7 @@ import Init.Grind.PP import Lean.Meta.Tactic.Grind.Types import Lean.Meta.Tactic.Grind.Arith.Model import Lean.Meta.Tactic.Grind.Arith.CommRing.PP +import Lean.Meta.Tactic.Grind.Arith.Linear.PP namespace Lean.Meta.Grind @@ -147,6 +148,11 @@ private def ppCommRing : M Unit := do let some msg ← Arith.CommRing.pp? goal | return () pushMsg msg +private def ppLinarith : M Unit := do + let goal ← read + let some msg ← Arith.Linear.pp? goal | return () + pushMsg msg + private def ppThresholds (c : Grind.Config) : M Unit := do let goal ← read let maxGen := goal.exprs.foldl (init := 0) fun g e => @@ -194,6 +200,7 @@ where ppActiveTheoremPatterns ppOffset ppCutsat + ppLinarith ppCommRing ppThresholds config diff --git a/tests/lean/run/grind_linarith_2.lean b/tests/lean/run/grind_linarith_2.lean index c5f74a4d31..8c7fe1bf7b 100644 --- a/tests/lean/run/grind_linarith_2.lean +++ b/tests/lean/run/grind_linarith_2.lean @@ -64,3 +64,27 @@ example [CommRing α] [LinearOrder α] [Ring.IsOrdered α] (a b c : α) example [CommRing α] [LinearOrder α] [Ring.IsOrdered α] (a b c : α) : c = a → a + b ≤ 3 → 3 ≤ b + c → a + b = 3 := by grind + +/-- +trace: [grind.linarith.model] a := 7/2 +[grind.linarith.model] b := 1 +[grind.linarith.model] c := 2 +[grind.linarith.model] d := 3 +-/ +#guard_msgs (drop error, trace) in +set_option trace.grind.linarith.model true in +example [CommRing α] [LinearOrder α] [Ring.IsOrdered α] (a b c d : α) + : b ≥ 0 → c > b → d > b → a ≠ b + c → a > b + c → a < b + d → False := by + grind + +/-- +trace: [grind.linarith.model] a := 0 +[grind.linarith.model] b := 1 +[grind.linarith.model] c := 1 +[grind.linarith.model] d := -1 +-/ +#guard_msgs (drop error, trace) in +set_option trace.grind.linarith.model true in +example [IntModule α] [LinearOrder α] [IntModule.IsOrdered α] (a b c d : α) + : a ≤ b → a - c ≥ 0 + d → d ≤ 0 → b = c → a ≠ b → False := by + grind