fix: grind linarith counterexample (#10960)
This PR fixes a bug in the `grind linarith` model/counterexample construction. Closes #10500
This commit is contained in:
parent
a9d6bc60d0
commit
cdaa827b2a
7 changed files with 168 additions and 2 deletions
|
|
@ -91,6 +91,7 @@ def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
|
|||
let some type := getType? e | return ()
|
||||
if isForbiddenParent parent? then return ()
|
||||
if let some structId ← getStructId? type then LinearM.run structId do
|
||||
trace[grind.linarith.internalize] "{e}"
|
||||
setTermStructId e
|
||||
linearExt.markTerm e
|
||||
markVars e
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ Authors: Leonardo de Moura
|
|||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Linear.Types
|
||||
import Lean.Meta.Tactic.Grind.Arith.Linear.Reify
|
||||
import Lean.Meta.Tactic.Grind.Arith.ModelUtil
|
||||
import Init.Grind.Module.Envelope
|
||||
public section
|
||||
|
|
@ -28,6 +29,58 @@ private def toQ? (e : Expr) : Option Expr :=
|
|||
| Grind.IntModule.OfNatModule.toQ _ _ a => some a
|
||||
| _ => none
|
||||
|
||||
/--
|
||||
Helper function for evaluating terms that have been processed by `internalize`, but
|
||||
we did not added them to constraints. See comment at `assignTerms`.
|
||||
-/
|
||||
private partial def evalTermAt? (e : Expr) (s : Struct) (model : Std.HashMap Expr Rat) : MetaM (Option Rat) := do
|
||||
go e
|
||||
where
|
||||
go (e : Expr) : OptionT MetaM Rat := do
|
||||
if let some val := model.get? e then
|
||||
return val
|
||||
match_expr e with
|
||||
| Neg.neg _ i a => if isNegInst s i then return - (← go a) else failure
|
||||
| HAdd.hAdd _ _ _ i a b => if isAddInst s i then return (← go a) + (← go b) else failure
|
||||
| HSub.hSub _ _ _ i a b => if isSubInst s i then return (← go a) - (← go b) else failure
|
||||
| HMul.hMul _ _ _ i a b => if isHomoMulInst s i then return (← go a) * (← go b) else failure
|
||||
| HSMul.hSMul _ _ _ i a b =>
|
||||
if isSMulIntInst s i then
|
||||
let k ← getIntValue? a
|
||||
return k * (← go b)
|
||||
else if isSMulNatInst s i then
|
||||
let k ← getNatValue? a
|
||||
return k * (← go b)
|
||||
else
|
||||
failure
|
||||
| Zero.zero _ i => if isZeroInst s i then return 0 else failure
|
||||
| OfNat.ofNat _ n _ => let k ← getNatValue? n; return k
|
||||
| _ => failure
|
||||
|
||||
/--
|
||||
Assigns terms that do not have an assignment associated with them because they were used only as markers
|
||||
for communicating with the `grind` core.
|
||||
For example, suppose we have `a + b = c`. The `internalize` function marks `a + b`, `a`, and `b`
|
||||
with theory variables. Let's assume we also have `c = 2*b`. In this case, `internalize` marks `2*b`
|
||||
with a theory variable, and `b` is already marked. Then, when both equalities are asserted
|
||||
in the `grind` core, the `linarith` module is notified that `a + b = 2*b` is true, and it is then
|
||||
normalized as `a + -1*b = 0`. The terms `a` and `b` are assigned rational values by the search
|
||||
procedure, but `a + b` and `2*b` are not. This procedure assigns them using the model computed by the
|
||||
search procedure.
|
||||
|
||||
**Note**: We should reconsider whether to add the equalities `「a+b」 = a + b` and `「2*b」 = 2*b`
|
||||
to force the search procedure to assign interpretations to these terms.
|
||||
-/
|
||||
private def assignTerms (goal : Goal) (structId : Nat) (model : Std.HashMap Expr Rat) : MetaM (Std.HashMap Expr Rat) := do
|
||||
let mut model := model
|
||||
let s ← linearExt.getStateCore goal
|
||||
let struct := s.structs[structId]!
|
||||
for (e, structId') in s.exprToStructIdEntries do
|
||||
if structId == structId' && !model.contains e then
|
||||
if let some v ← evalTermAt? e struct model then
|
||||
model := assignEqc goal e v model
|
||||
return model
|
||||
|
||||
/--
|
||||
Construct a model that satisfies all constraints in the linarith model for the structure with id `structId`.
|
||||
It also assigns values to (integer) terms that have not been internalized by the linarith model.
|
||||
|
|
@ -42,6 +95,7 @@ def mkModel (goal : Goal) (structId : Nat) : MetaM (Array (Expr × Rat)) := do
|
|||
if (← hasType s.type node) then
|
||||
if let some v := getAssignment? s node.self then
|
||||
model := assignEqc goal node.self v model
|
||||
model ← assignTerms goal structId model
|
||||
-- Assign `toQ a` terms
|
||||
for e in goal.exprs do
|
||||
let node ← goal.getENode e
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ def IneqCnstr.throwUnexpected (c : IneqCnstr) : LinearM α := do
|
|||
def DiseqCnstr.throwUnexpected (c : DiseqCnstr) : LinearM α := do
|
||||
throwError "`grind linarith` internal error, unexpected{indentD (← c.denoteExpr)}"
|
||||
|
||||
def EqCnstr.throwUnexpected (c :EqCnstr) : LinearM α := do
|
||||
throwError "`grind linarith` internal error, unexpected{indentD (← c.denoteExpr)}"
|
||||
|
||||
private def checkIsNextVar (x : Var) : LinearM Unit := do
|
||||
if x != (← getStruct).assignment.size then
|
||||
throwError "`grind linarith` internal error, assigning variable out of order"
|
||||
|
|
@ -246,6 +249,28 @@ private def resetDecisionStack : SearchM Unit := do
|
|||
let first := (← get).cases[0]!
|
||||
modifyStruct fun s => { first.saved with assignment := s.assignment }
|
||||
|
||||
/-- Assign eliminated variables using `elimEqs` field. -/
|
||||
private def assignElimVars : SearchM Unit := do
|
||||
if (← inconsistent) then return ()
|
||||
go (← getStruct).elimStack
|
||||
where
|
||||
go (xs : List Var) : SearchM Unit := do
|
||||
match xs with
|
||||
| [] => return ()
|
||||
| x :: xs =>
|
||||
let some c := (← getStruct).elimEqs[x]!
|
||||
| throwError "`grind` internal error, eliminated variable must have equation associated with it"
|
||||
-- `x` may not be the max variable
|
||||
let a := c.p.coeff x
|
||||
if a == 0 then c.throwUnexpected
|
||||
-- ensure `x` is 0 when evaluating `c.p`
|
||||
modifyStruct fun s => { s with assignment := s.assignment.set x 0 }
|
||||
let some v ← c.p.eval? | c.throwUnexpected
|
||||
let v := (-v) / a
|
||||
traceAssignment x v
|
||||
modifyStruct fun s => { s with assignment := s.assignment.set x v }
|
||||
go xs
|
||||
|
||||
/-- Search for an assignment/model for the linear constraints. -/
|
||||
private def searchAssignmentMain : SearchM Unit := do
|
||||
repeat
|
||||
|
|
@ -253,6 +278,7 @@ private def searchAssignmentMain : SearchM Unit := do
|
|||
checkSystem "linarith"
|
||||
if (← hasAssignment) then
|
||||
trace[grind.debug.linarith.search] "found assignment"
|
||||
assignElimVars
|
||||
return ()
|
||||
if (← isInconsistent) then
|
||||
-- `grind` state is inconsistent
|
||||
|
|
|
|||
|
|
@ -245,6 +245,8 @@ structure State where
|
|||
typeIdOf : PHashMap ExprPtr (Option Nat) := {}
|
||||
/- Mapping from expressions/terms to their structure ids. -/
|
||||
exprToStructId : PHashMap ExprPtr Nat := {}
|
||||
/-- `exprToStructId` content as an array for traversal. -/
|
||||
exprToStructIdEntries : PArray (Expr × Nat) := {}
|
||||
/--
|
||||
Some types are unordered rings, so we do not process them in `linarith`.
|
||||
When such types are detected in `getStructId?`, we add them to the set
|
||||
|
|
|
|||
|
|
@ -40,7 +40,10 @@ def setTermStructId (e : Expr) : LinearM Unit := do
|
|||
unless structId' == structId do
|
||||
reportIssue! "expression in two different structure in linarith module{indentExpr e}"
|
||||
return ()
|
||||
modify' fun s => { s with exprToStructId := s.exprToStructId.insert { expr := e } structId }
|
||||
modify' fun s => { s with
|
||||
exprToStructId := s.exprToStructId.insert { expr := e } structId
|
||||
exprToStructIdEntries := s.exprToStructIdEntries.push (e, structId)
|
||||
}
|
||||
|
||||
def getNoNatDivInst : LinearM Expr := do
|
||||
let some inst := (← getStruct).noNatDivInst?
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ Returns `true` if `e` should be treated as an interpreted value by the arithmeti
|
|||
def isInterpretedTerm (e : Expr) : Bool :=
|
||||
isNatNum e || isIntNum e || e.isAppOf ``HAdd.hAdd || e.isAppOf ``HMul.hMul || e.isAppOf ``HSub.hSub || e.isAppOf ``HSMul.hSMul
|
||||
|| e.isAppOf ``Neg.neg || e.isAppOf ``HDiv.hDiv || e.isAppOf ``HMod.hMod || e.isAppOf ``One.one || e.isAppOf ``Zero.zero
|
||||
|| e.isAppOf ``NatCast.natCast || e.isIte || e.isDIte || e.isAppOf ``OfNat.ofNat || e.isAppOf ``Grind.ToInt.toInt
|
||||
|| e.isAppOf ``Inv.inv || e.isAppOf ``NatCast.natCast || e.isIte || e.isDIte || e.isAppOf ``OfNat.ofNat || e.isAppOf ``Grind.ToInt.toInt
|
||||
|| e.isAppOf ``Grind.IntModule.OfNatModule.toQ || e matches .lit (.natVal _)
|
||||
|
||||
/--
|
||||
|
|
|
|||
80
tests/lean/run/grind_10500.lean
Normal file
80
tests/lean/run/grind_10500.lean
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
open Lean Grind Std
|
||||
|
||||
set_option warn.sorry false
|
||||
|
||||
namespace Ex1
|
||||
variable [Field Q] [LT Q] [LE Q] [LawfulOrderLT Q] [IsLinearOrder Q] [OrderedRing Q]
|
||||
|
||||
variable (a₁ a₂ a₃ a₄ a₅ : Q)
|
||||
variable (α L ν : Q)
|
||||
|
||||
/--
|
||||
trace: [grind.linarith.model] a₁ := 0
|
||||
[grind.linarith.model] a₂ := 0
|
||||
[grind.linarith.model] a₃ := 0
|
||||
[grind.linarith.model] a₄ := 0
|
||||
[grind.linarith.model] a₅ := 0
|
||||
[grind.linarith.model] α := 0
|
||||
[grind.linarith.model] L := 0
|
||||
[grind.linarith.model] ν := 2
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.grind.linarith.model true in
|
||||
theorem upper_bound
|
||||
(hL : L = α) (hL1 : L ≤ 1)
|
||||
(ha₁ : 0 ≤ a₁) (ha₂ : 0 ≤ a₂) (ha₃ : 0 ≤ a₃) (ha₄ : 0 ≤ a₄) (ha₅ : 0 ≤ a₅)
|
||||
(hα : α = a₁ + a₂ + a₃ + a₄ + a₅) :
|
||||
ν ≤ 9/10 := by
|
||||
fail_if_success grind
|
||||
sorry
|
||||
|
||||
end Ex1
|
||||
|
||||
namespace Ex2
|
||||
|
||||
variable (a₁ a₂ a₃ a₄ a₅ : Rat)
|
||||
variable (α L ν : Rat)
|
||||
|
||||
/--
|
||||
trace: [grind.linarith.model] a₁ := 0
|
||||
[grind.linarith.model] a₂ := 0
|
||||
[grind.linarith.model] a₃ := 0
|
||||
[grind.linarith.model] a₄ := 0
|
||||
[grind.linarith.model] a₅ := 0
|
||||
[grind.linarith.model] α := 0
|
||||
[grind.linarith.model] L := 0
|
||||
[grind.linarith.model] ν := 2
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.grind.linarith.model true in
|
||||
theorem upper_bound
|
||||
(hL : L = α) (hL1 : L ≤ 1)
|
||||
(ha₁ : 0 ≤ a₁) (ha₂ : 0 ≤ a₂) (ha₃ : 0 ≤ a₃) (ha₄ : 0 ≤ a₄) (ha₅ : 0 ≤ a₅)
|
||||
(hα : α = a₁ + a₂ + a₃ + a₄ + a₅) :
|
||||
ν ≤ 9/10 := by
|
||||
fail_if_success grind
|
||||
sorry
|
||||
|
||||
end Ex2
|
||||
|
||||
/--
|
||||
trace: [grind.linarith.model] a := 0
|
||||
[grind.linarith.model] b := 0
|
||||
[grind.linarith.model] c := 0
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.grind.linarith.model true in
|
||||
example [Field α] [LE α] [LT α] [Std.IsPreorder α] [OrderedRing α] (a b c : α) (h : a = b + c) : False := by
|
||||
fail_if_success grind
|
||||
sorry
|
||||
|
||||
/--
|
||||
trace: [grind.linarith.model] a := 0
|
||||
[grind.linarith.model] b := 0
|
||||
[grind.linarith.model] c := 0
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.grind.linarith.model true in
|
||||
example (a b c : Rat) (h : a = b + c) : False := by
|
||||
fail_if_success grind
|
||||
sorry
|
||||
Loading…
Add table
Reference in a new issue