feat: eliminate equations in grind linarith (#8810)
This PR implements equality elimination in `grind linarith`. The current implementation supports only `IntModule` and `IntModule` + `NoNatZeroDivisors`
This commit is contained in:
parent
7b67727067
commit
4e96a4ff45
14 changed files with 504 additions and 119 deletions
|
|
@ -7,6 +7,7 @@ module
|
|||
prelude
|
||||
import Init.Grind.Ordered.Module
|
||||
import Init.Grind.Ordered.Ring
|
||||
import Init.Grind.CommRing.Field
|
||||
import all Init.Data.Ord
|
||||
import all Init.Data.AC
|
||||
import Init.Data.RArray
|
||||
|
|
@ -83,6 +84,11 @@ theorem Poly.denote'_go_eq_denote {α} [IntModule α] (ctx : Context α) (p : Po
|
|||
theorem Poly.denote'_eq_denote {α} [IntModule α] (ctx : Context α) (p : Poly) : p.denote' ctx = p.denote ctx := by
|
||||
unfold denote' <;> split <;> simp [denote, denote'_go_eq_denote] <;> ac_rfl
|
||||
|
||||
def Poly.coeff (p : Poly) (x : Var) : Int :=
|
||||
match p with
|
||||
| .add a y p => bif x == y then a else coeff p x
|
||||
| .nil => 0
|
||||
|
||||
def Poly.insert (k : Int) (v : Var) (p : Poly) : Poly :=
|
||||
match p with
|
||||
| .nil => .add k v .nil
|
||||
|
|
@ -283,12 +289,6 @@ theorem le_lt_combine {α} [IntModule α] [Preorder α] [IntModule.IsOrdered α]
|
|||
replace h₂ := hmul_neg (↑p₁.leadCoeff.natAbs) h₂ |>.mp hp
|
||||
exact le_add_lt h₁ h₂
|
||||
|
||||
theorem le_eq_combine {α} [IntModule α] [Preorder α] [IntModule.IsOrdered α] (ctx : Context α) (p₁ p₂ p₃ : Poly)
|
||||
: le_le_combine_cert p₁ p₂ p₃ → p₁.denote' ctx ≤ 0 → p₂.denote' ctx = 0 → p₃.denote' ctx ≤ 0 := by
|
||||
simp [le_le_combine_cert]; intro _ h₁ h₂; subst p₃; simp [h₂]
|
||||
replace h₁ := hmul_nonpos (coe_natAbs_nonneg p₂.leadCoeff) h₁
|
||||
assumption
|
||||
|
||||
def lt_lt_combine_cert (p₁ p₂ p₃ : Poly) : Bool :=
|
||||
let a₁ := p₁.leadCoeff.natAbs
|
||||
let a₂ := p₂.leadCoeff.natAbs
|
||||
|
|
@ -301,21 +301,6 @@ theorem lt_lt_combine {α} [IntModule α] [Preorder α] [IntModule.IsOrdered α]
|
|||
replace h₂ := hmul_neg (↑p₁.leadCoeff.natAbs) h₂ |>.mp hp₂
|
||||
exact lt_add_lt h₁ h₂
|
||||
|
||||
def lt_eq_combine_cert (p₁ p₂ p₃ : Poly) : Bool :=
|
||||
let a₁ := p₁.leadCoeff.natAbs
|
||||
let a₂ := p₂.leadCoeff.natAbs
|
||||
a₂ > (0 : Int) && p₃ == (p₁.mul a₂ |>.combine (p₂.mul a₁))
|
||||
|
||||
theorem lt_eq_combine {α} [IntModule α] [Preorder α] [IntModule.IsOrdered α] (ctx : Context α) (p₁ p₂ p₃ : Poly)
|
||||
: lt_eq_combine_cert p₁ p₂ p₃ → p₁.denote' ctx < 0 → p₂.denote' ctx = 0 → p₃.denote' ctx < 0 := by
|
||||
simp [-Int.natAbs_pos, -Int.ofNat_pos, lt_eq_combine_cert]; intro hp₁ _ h₁ h₂; subst p₃; simp [h₂]
|
||||
replace h₁ := hmul_neg (↑p₂.leadCoeff.natAbs) h₁ |>.mp hp₁
|
||||
assumption
|
||||
|
||||
theorem eq_eq_combine {α} [IntModule α] (ctx : Context α) (p₁ p₂ p₃ : Poly)
|
||||
: le_le_combine_cert p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx = 0 → p₃.denote' ctx = 0 := by
|
||||
simp [le_le_combine_cert]; intro _ h₁ h₂; subst p₃; simp [h₁, h₂]
|
||||
|
||||
def diseq_split_cert (p₁ p₂ : Poly) : Bool :=
|
||||
p₂ == p₁.mul (-1)
|
||||
|
||||
|
|
@ -335,30 +320,6 @@ theorem diseq_split_resolve {α} [IntModule α] [LinearOrder α] [IntModule.IsOr
|
|||
intro h₁ h₂ h₃
|
||||
exact (diseq_split ctx p₁ p₂ h₁ h₂).resolve_left h₃
|
||||
|
||||
def eq_diseq_combine_cert (p₁ p₂ p₃ : Poly) : Bool :=
|
||||
let a₁ := p₁.leadCoeff.natAbs
|
||||
let a₂ := p₂.leadCoeff.natAbs
|
||||
a₁ ≠ 0 && p₃ == (p₁.mul a₂ |>.combine (p₂.mul a₁))
|
||||
|
||||
theorem eq_diseq_combine {α} [IntModule α] [NoNatZeroDivisors α] (ctx : Context α) (p₁ p₂ p₃ : Poly)
|
||||
: eq_diseq_combine_cert p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≠ 0 → p₃.denote' ctx ≠ 0 := by
|
||||
simp [- Int.natAbs_eq_zero, -Int.natCast_eq_zero, eq_diseq_combine_cert]; intro hne _ h₁ h₂; subst p₃
|
||||
simp [h₁, h₂]; intro h
|
||||
have := no_nat_zero_divisors (p₁.leadCoeff.natAbs) (p₂.denote ctx) hne h
|
||||
contradiction
|
||||
|
||||
def eq_diseq_combine_cert' (p₁ p₂ p₃ : Poly) (k : Int) : Bool :=
|
||||
p₃ == (p₁.mul k |>.combine p₂)
|
||||
|
||||
/-
|
||||
Special case of `eq_diseq_combine` where leading coefficient `c₁` of `p₁` is `-k*c₂`, where
|
||||
`c₂` is the leading coefficient of `p₂`.
|
||||
-/
|
||||
theorem eq_diseq_combine' {α} [IntModule α] (ctx : Context α) (p₁ p₂ p₃ : Poly) (k : Int)
|
||||
: eq_diseq_combine_cert' p₁ p₂ p₃ k → p₁.denote' ctx = 0 → p₂.denote' ctx ≠ 0 → p₃.denote' ctx ≠ 0 := by
|
||||
simp [eq_diseq_combine_cert']; intro _ h₁ h₂; subst p₃
|
||||
simp [h₁, h₂]
|
||||
|
||||
/-!
|
||||
Helper theorems for internalizing facts into the linear arithmetic procedure
|
||||
-/
|
||||
|
|
@ -464,17 +425,57 @@ theorem zero_lt_one {α} [Ring α] [Preorder α] [Ring.IsOrdered α] (ctx : Cont
|
|||
simp [zero_lt_one_cert]; intro _ h; subst p; simp [Poly.denote, h, One.one, neg_hmul]
|
||||
rw [neg_lt_iff, neg_zero]; apply Ring.IsOrdered.zero_lt_one
|
||||
|
||||
def zero_ne_one_cert (p : Poly) : Bool :=
|
||||
p == .add 1 0 .nil
|
||||
|
||||
theorem zero_ne_one_of_ord_ring {α} [Ring α] [Preorder α] [Ring.IsOrdered α] (ctx : Context α) (p : Poly)
|
||||
: zero_ne_one_cert p → (0 : Var).denote ctx = One.one → p.denote' ctx ≠ 0 := by
|
||||
simp [zero_ne_one_cert]; intro _ h; subst p; simp [Poly.denote, h, One.one]
|
||||
intro h; have := Ring.IsOrdered.zero_lt_one (R := α); simp [h, Preorder.lt_irrefl] at this
|
||||
|
||||
theorem zero_ne_one_of_field {α} [Field α] (ctx : Context α) (p : Poly)
|
||||
: zero_ne_one_cert p → (0 : Var).denote ctx = One.one → p.denote' ctx ≠ 0 := by
|
||||
simp [zero_ne_one_cert]; intro _ h; subst p; simp [Poly.denote, h, One.one]
|
||||
intro h; have := Field.zero_ne_one (α := α); simp [h] at this
|
||||
|
||||
theorem zero_ne_one_of_char0 {α} [Ring α] [IsCharP α 0] (ctx : Context α) (p : Poly)
|
||||
: zero_ne_one_cert p → (0 : Var).denote ctx = One.one → p.denote' ctx ≠ 0 := by
|
||||
simp [zero_ne_one_cert]; intro _ h; subst p; simp [Poly.denote, h, One.one]
|
||||
intro h; have := IsCharP.intCast_eq_zero_iff (α := α) 0 1; simp [Ring.intCast_one] at this
|
||||
contradiction
|
||||
|
||||
def zero_ne_one_of_charC_cert (c : Nat) (p : Poly) : Bool :=
|
||||
(c:Int) > 1 && p == .add 1 0 .nil
|
||||
|
||||
theorem zero_ne_one_of_charC {α c} [Ring α] [IsCharP α c] (ctx : Context α) (p : Poly)
|
||||
: zero_ne_one_of_charC_cert c p → (0 : Var).denote ctx = One.one → p.denote' ctx ≠ 0 := by
|
||||
simp [zero_ne_one_of_charC_cert]; intro hc _ h; subst p; simp [Poly.denote, h, One.one]
|
||||
intro h; have h' := IsCharP.intCast_eq_zero_iff (α := α) c 1; simp [Ring.intCast_one] at h'
|
||||
replace h' := h'.mp h
|
||||
have := Int.emod_eq_of_lt (by decide) hc
|
||||
simp [this] at h'
|
||||
|
||||
/-!
|
||||
Coefficient normalization
|
||||
-/
|
||||
|
||||
def coeff_cert (p₁ p₂ : Poly) (k : Nat) :=
|
||||
k > 0 && p₁ == p₂.mul k
|
||||
def eq_neg_cert (p₁ p₂ : Poly) :=
|
||||
p₂ == p₁.mul (-1)
|
||||
|
||||
theorem eq_neg {α} [IntModule α] (ctx : Context α) (p₁ p₂ : Poly)
|
||||
: eq_neg_cert p₁ p₂ → p₁.denote' ctx = 0 → p₂.denote' ctx = 0 := by
|
||||
simp [eq_neg_cert]; intros; simp [*]
|
||||
|
||||
def eq_coeff_cert (p₁ p₂ : Poly) (k : Nat) :=
|
||||
k != 0 && p₁ == p₂.mul k
|
||||
|
||||
theorem eq_coeff {α} [IntModule α] [NoNatZeroDivisors α] (ctx : Context α) (p₁ p₂ : Poly) (k : Nat)
|
||||
: coeff_cert p₁ p₂ k → p₁.denote' ctx = 0 → p₂.denote' ctx = 0 := by
|
||||
simp [coeff_cert]; intro h _; subst p₁; simp
|
||||
exact no_nat_zero_divisors k (p₂.denote ctx) (Nat.ne_zero_of_lt h)
|
||||
: eq_coeff_cert p₁ p₂ k → p₁.denote' ctx = 0 → p₂.denote' ctx = 0 := by
|
||||
simp [eq_coeff_cert]; intro h _; subst p₁; simp [*]
|
||||
exact no_nat_zero_divisors k (p₂.denote ctx) h
|
||||
|
||||
def coeff_cert (p₁ p₂ : Poly) (k : Nat) :=
|
||||
k > 0 && p₁ == p₂.mul k
|
||||
|
||||
theorem le_coeff {α} [IntModule α] [LinearOrder α] [IntModule.IsOrdered α] (ctx : Context α) (p₁ p₂ : Poly) (k : Nat)
|
||||
: coeff_cert p₁ p₂ k → p₁.denote' ctx ≤ 0 → p₂.denote' ctx ≤ 0 := by
|
||||
|
|
@ -499,4 +500,65 @@ theorem diseq_neg {α} [IntModule α] (ctx : Context α) (p p' : Poly) : p' == p
|
|||
intro h; replace h := congrArg (- ·) h; simp [neg_neg, neg_zero] at h
|
||||
contradiction
|
||||
|
||||
/-!
|
||||
Substitution
|
||||
-/
|
||||
|
||||
def eq_diseq_subst_cert (k₁ k₂ : Int) (p₁ p₂ p₃ : Poly) : Bool :=
|
||||
k₁.natAbs ≠ 0 && p₃ == (p₁.mul k₂ |>.combine (p₂.mul k₁))
|
||||
|
||||
theorem eq_diseq_subst {α} [IntModule α] [NoNatZeroDivisors α] (ctx : Context α) (k₁ k₂ : Int) (p₁ p₂ p₃ : Poly)
|
||||
: eq_diseq_subst_cert k₁ k₂ p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≠ 0 → p₃.denote' ctx ≠ 0 := by
|
||||
simp [eq_diseq_subst_cert, - Int.natAbs_eq_zero, -Int.natCast_eq_zero]; intro hne _ h₁ h₂; subst p₃
|
||||
simp [h₁, h₂]; intro h₃
|
||||
have : k₁.natAbs * Poly.denote ctx p₂ = 0 := by
|
||||
have : (k₁.natAbs : Int) * Poly.denote ctx p₂ = 0 := by
|
||||
cases Int.natAbs_eq_iff.mp (Eq.refl k₁.natAbs)
|
||||
next h => rw [← h]; assumption
|
||||
next h => replace h := congrArg (- ·) h; simp at h; rw [← h, IntModule.neg_hmul, h₃, IntModule.neg_zero]
|
||||
exact this
|
||||
have := no_nat_zero_divisors (k₁.natAbs) (p₂.denote ctx) hne this
|
||||
contradiction
|
||||
|
||||
def eq_diseq_subst1_cert (k : Int) (p₁ p₂ p₃ : Poly) : Bool :=
|
||||
p₃ == (p₁.mul k |>.combine p₂)
|
||||
|
||||
/-
|
||||
Special case of `diseq_eq_subst` where leading coefficient `c₁` of `p₁` is `-k*c₂`, where
|
||||
`c₂` is the leading coefficient of `p₂`.
|
||||
-/
|
||||
theorem eq_diseq_subst1 {α} [IntModule α] (ctx : Context α) (k : Int) (p₁ p₂ p₃ : Poly)
|
||||
: eq_diseq_subst1_cert k p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≠ 0 → p₃.denote' ctx ≠ 0 := by
|
||||
simp [eq_diseq_subst1_cert]; intro _ h₁ h₂; subst p₃
|
||||
simp [h₁, h₂]
|
||||
|
||||
def eq_le_subst_cert (x : Var) (p₁ p₂ p₃ : Poly) :=
|
||||
let a := p₁.coeff x
|
||||
let b := p₂.coeff x
|
||||
a ≥ 0 && p₃ == (p₂.mul a |>.combine (p₁.mul (-b)))
|
||||
|
||||
theorem eq_le_subst {α} [IntModule α] [Preorder α] [IntModule.IsOrdered α] (ctx : Context α) (x : Var) (p₁ p₂ p₃ : Poly)
|
||||
: eq_le_subst_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx ≤ 0 → p₃.denote' ctx ≤ 0 := by
|
||||
simp [eq_le_subst_cert]; intro h _ h₁ h₂; subst p₃; simp [h₁]
|
||||
exact hmul_nonpos h h₂
|
||||
|
||||
def eq_lt_subst_cert (x : Var) (p₁ p₂ p₃ : Poly) :=
|
||||
let a := p₁.coeff x
|
||||
let b := p₂.coeff x
|
||||
a > 0 && p₃ == (p₂.mul a |>.combine (p₁.mul (-b)))
|
||||
|
||||
theorem eq_lt_subst {α} [IntModule α] [Preorder α] [IntModule.IsOrdered α] (ctx : Context α) (x : Var) (p₁ p₂ p₃ : Poly)
|
||||
: eq_lt_subst_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx < 0 → p₃.denote' ctx < 0 := by
|
||||
simp [eq_lt_subst_cert]; intro h _ h₁ h₂; subst p₃; simp [h₁]
|
||||
exact IsOrdered.hmul_neg (p₁.coeff x) h₂ |>.mp h
|
||||
|
||||
def eq_eq_subst_cert (x : Var) (p₁ p₂ p₃ : Poly) :=
|
||||
let a := p₁.coeff x
|
||||
let b := p₂.coeff x
|
||||
p₃ == (p₂.mul a |>.combine (p₁.mul (-b)))
|
||||
|
||||
theorem eq_eq_subst {α} [IntModule α] (ctx : Context α) (x : Var) (p₁ p₂ p₃ : Poly)
|
||||
: eq_eq_subst_cert x p₁ p₂ p₃ → p₁.denote' ctx = 0 → p₂.denote' ctx = 0 → p₃.denote' ctx = 0 := by
|
||||
simp [eq_eq_subst_cert]; intro _ h₁ h₂; subst p₃; simp [h₁, h₂]
|
||||
|
||||
end Lean.Grind.Linarith
|
||||
|
|
|
|||
|
|
@ -129,22 +129,8 @@ where
|
|||
let commRing := mkApp (mkConst ``Grind.CommRing [u]) type
|
||||
let .some commRingInst ← trySynthInstance commRing | return none
|
||||
trace_goal[grind.ring] "new ring: {type}"
|
||||
let charInst? ← withNewMCtxDepth do
|
||||
let n ← mkFreshExprMVar (mkConst ``Nat)
|
||||
let charType := mkApp3 (mkConst ``Grind.IsCharP [u]) type ringInst n
|
||||
let .some charInst ← trySynthInstance charType | pure none
|
||||
let n ← instantiateMVars n
|
||||
let some n ← evalNat n |>.run
|
||||
| trace_goal[grind.ring] "found instance for{indentExpr charType}\nbut characteristic is not a natural number"; pure none
|
||||
trace_goal[grind.ring] "characteristic: {n}"
|
||||
pure <| some (charInst, n)
|
||||
let noZeroDivInst? ← withNewMCtxDepth do
|
||||
let zeroType := mkApp (mkConst ``Zero [u]) type
|
||||
let .some zeroInst ← trySynthInstance zeroType | return none
|
||||
let hmulType := mkApp3 (mkConst ``HMul [0, u, u]) (mkConst ``Nat []) type type
|
||||
let .some hmulInst ← trySynthInstance hmulType | return none
|
||||
let noZeroDivType := mkApp3 (mkConst ``Grind.NoNatZeroDivisors [u]) type zeroInst hmulInst
|
||||
LOption.toOption <$> trySynthInstance noZeroDivType
|
||||
let charInst? ← getIsCharInst? u type ringInst
|
||||
let noZeroDivInst? ← getNoZeroDivInst? u type
|
||||
trace_goal[grind.ring] "NoNatZeroDivisors available: {noZeroDivInst?.isSome}"
|
||||
let field := mkApp (mkConst ``Grind.Field [u]) type
|
||||
let fieldInst? : Option Expr ← LOption.toOption <$> trySynthInstance field
|
||||
|
|
|
|||
|
|
@ -120,16 +120,8 @@ private def updateDvdCnstr (a : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM U
|
|||
let c' ← c'.applyEq a x c b
|
||||
c'.assert
|
||||
|
||||
private def split (x : Var) (cs : PArray LeCnstr) : GoalM (PArray LeCnstr × Array (Int × LeCnstr)) := do
|
||||
let mut cs' := {}
|
||||
let mut todo := #[]
|
||||
for c in cs do
|
||||
let b := c.p.coeff x
|
||||
if b == 0 then
|
||||
cs' := cs'.push c
|
||||
else
|
||||
todo := todo.push (b, c)
|
||||
return (cs', todo)
|
||||
private def splitLeCnstrs (x : Var) (cs : PArray LeCnstr) : PArray LeCnstr × Array (Int × LeCnstr) :=
|
||||
split cs fun c => c.p.coeff x
|
||||
|
||||
/--
|
||||
Given an equation `c₁` containing `a*x`, eliminate `x` from the inequalities in `todo`.
|
||||
|
|
@ -146,7 +138,7 @@ Given an equation `c₁` containing `a*x`, eliminate `x` from lower bound inequa
|
|||
-/
|
||||
private def updateLowers (a : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM Unit := do
|
||||
if (← inconsistent) then return ()
|
||||
let (lowers', todo) ← split x (← get').lowers[y]!
|
||||
let (lowers', todo) := splitLeCnstrs x (← get').lowers[y]!
|
||||
modify' fun s => { s with lowers := s.lowers.set y lowers' }
|
||||
updateLeCnstrs a x c todo
|
||||
|
||||
|
|
@ -155,24 +147,16 @@ Given an equation `c₁` containing `a*x`, eliminate `x` from upper bound inequa
|
|||
-/
|
||||
private def updateUppers (a : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM Unit := do
|
||||
if (← inconsistent) then return ()
|
||||
let (uppers', todo) ← split x (← get').uppers[y]!
|
||||
let (uppers', todo) := splitLeCnstrs x (← get').uppers[y]!
|
||||
modify' fun s => { s with uppers := s.uppers.set y uppers' }
|
||||
updateLeCnstrs a x c todo
|
||||
|
||||
private def splitDiseqs (x : Var) (cs : PArray DiseqCnstr) : GoalM (PArray DiseqCnstr × Array (Int × DiseqCnstr)) := do
|
||||
let mut cs' := {}
|
||||
let mut todo := #[]
|
||||
for c in cs do
|
||||
let b := c.p.coeff x
|
||||
if b == 0 then
|
||||
cs' := cs'.push c
|
||||
else
|
||||
todo := todo.push (b, c)
|
||||
return (cs', todo)
|
||||
private def splitDiseqs (x : Var) (cs : PArray DiseqCnstr) : PArray DiseqCnstr × Array (Int × DiseqCnstr) :=
|
||||
split cs fun c => c.p.coeff x
|
||||
|
||||
private def updateDiseqs (a : Int) (x : Var) (c : EqCnstr) (y : Var) : GoalM Unit := do
|
||||
if (← inconsistent) then return ()
|
||||
let (diseqs', todo) ← splitDiseqs x (← get').diseqs[y]!
|
||||
let (diseqs', todo) := splitDiseqs x (← get').diseqs[y]!
|
||||
modify' fun s => { s with diseqs := s.diseqs.set y diseqs' }
|
||||
for (b, c₂) in todo do
|
||||
let c₂ ← c₂.applyEq a x c b
|
||||
|
|
|
|||
|
|
@ -30,11 +30,13 @@ builtin_initialize registerTraceClass `grind.linarith.model
|
|||
builtin_initialize registerTraceClass `grind.linarith.assert.unsat (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.linarith.assert.trivial (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.linarith.assert.store (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.linarith.assert.ignored (inherited := true)
|
||||
|
||||
builtin_initialize registerTraceClass `grind.debug.linarith.search
|
||||
builtin_initialize registerTraceClass `grind.debug.linarith.search.conflict (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.debug.linarith.search.assign (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.debug.linarith.search.split (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.debug.linarith.search.backtrack (inherited := true)
|
||||
builtin_initialize registerTraceClass `grind.debug.linarith.subst
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -59,6 +59,9 @@ private def denoteIneq (p : Poly) (strict : Bool) : M Expr := do
|
|||
def IneqCnstr.denoteExpr (c : IneqCnstr) : M Expr := do
|
||||
denoteIneq c.p c.strict
|
||||
|
||||
def EqCnstr.denoteExpr (c : EqCnstr) : M Expr := do
|
||||
mkEq (← c.p.denoteExpr) (← getStruct).ofNatZero
|
||||
|
||||
private def denoteNum (k : Int) : LinearM Expr := do
|
||||
return mkApp2 (← getStruct).hmulFn (mkIntLit k) (← getOne)
|
||||
|
||||
|
|
|
|||
|
|
@ -125,6 +125,14 @@ private def mkIntModThmPrefix (declName : Name) : ProofM Expr := do
|
|||
let s ← getStruct
|
||||
return mkApp3 (mkConst declName [s.u]) s.type s.intModuleInst (← getContext)
|
||||
|
||||
/--
|
||||
Returns the prefix of a theorem with name `declName` where the first three arguments are
|
||||
`{α} [IntModule α] [NoNatZeroDivisors α] (ctx : Context α)`
|
||||
-/
|
||||
private def mkIntModNoNatDivThmPrefix (declName : Name) : ProofM Expr := do
|
||||
let s ← getStruct
|
||||
return mkApp4 (mkConst declName [s.u]) s.type s.intModuleInst (← getNoNatDivInst) (← getContext)
|
||||
|
||||
/--
|
||||
Returns the prefix of a theorem with name `declName` where the first four arguments are
|
||||
`{α} [IntModule α] [Preorder α] (ctx : Context α)`
|
||||
|
|
@ -237,6 +245,32 @@ partial def DiseqCnstr.toExprProof (c' : DiseqCnstr) : ProofM Expr := caching c'
|
|||
| .neg c =>
|
||||
let h ← mkIntModThmPrefix ``Grind.Linarith.diseq_neg
|
||||
return mkApp4 h (← mkPolyDecl c.p) (← mkPolyDecl c'.p) reflBoolTrue (← c.toExprProof)
|
||||
| .subst k₁ k₂ c₁ c₂ =>
|
||||
let h ← mkIntModNoNatDivThmPrefix ``Grind.Linarith.eq_diseq_subst
|
||||
return mkApp8 h (toExpr k₁) (toExpr k₂) (← mkPolyDecl c₁.p) (← mkPolyDecl c₂.p) (← mkPolyDecl c'.p)
|
||||
reflBoolTrue (← c₁.toExprProof) (← c₂.toExprProof)
|
||||
| .subst1 k c₁ c₂ =>
|
||||
let h ← mkIntModThmPrefix ``Grind.Linarith.eq_diseq_subst1
|
||||
return mkApp7 h (toExpr k) (← mkPolyDecl c₁.p) (← mkPolyDecl c₂.p) (← mkPolyDecl c'.p) reflBoolTrue
|
||||
(← c₁.toExprProof) (← c₂.toExprProof)
|
||||
| .oneNeZero => throwError "NIY"
|
||||
|
||||
partial def EqCnstr.toExprProof (c' : EqCnstr) : ProofM Expr := caching c' do
|
||||
match c'.h with
|
||||
| .core a b lhs rhs =>
|
||||
let h ← mkIntModThmPrefix ``Grind.Linarith.eq_norm
|
||||
return mkApp5 h (← mkExprDecl lhs) (← mkExprDecl rhs) (← mkPolyDecl c'.p) reflBoolTrue (← mkEqProof a b)
|
||||
| .coreCommRing a b ra rb p lhs' => throwError "NIY"
|
||||
| .neg c =>
|
||||
let h ← mkIntModThmPrefix ``Grind.Linarith.eq_neg
|
||||
return mkApp4 h (← mkPolyDecl c.p) (← mkPolyDecl c'.p) reflBoolTrue (← c.toExprProof)
|
||||
| .coeff k c =>
|
||||
let h ← mkIntModNoNatDivThmPrefix ``Grind.Linarith.eq_coeff
|
||||
return mkApp5 h (← mkPolyDecl c.p) (← mkPolyDecl c'.p) (toExpr k) reflBoolTrue (← c.toExprProof)
|
||||
| .subst x c₁ c₂ =>
|
||||
let h ← mkIntModThmPrefix ``Grind.Linarith.eq_eq_subst
|
||||
return mkApp7 h (toExpr x) (← mkPolyDecl c₁.p) (← mkPolyDecl c₂.p) (← mkPolyDecl c'.p) reflBoolTrue
|
||||
(← c₁.toExprProof) (← c₂.toExprProof)
|
||||
|
||||
partial def UnsatProof.toExprProofCore (h : UnsatProof) : ProofM Expr := do
|
||||
match h with
|
||||
|
|
@ -271,13 +305,21 @@ partial def IneqCnstr.collectDecVars (c' : IneqCnstr) : CollectDecVarsM Unit :=
|
|||
| .norm c₁ _ => c₁.collectDecVars
|
||||
| .dec h => markAsFound h
|
||||
| .ofDiseqSplit (decVars := decVars) .. => decVars.forM markAsFound
|
||||
| .subst _ c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
|
||||
|
||||
-- `DiseqCnstr` is currently mutually recursive with `IneqCnstr`, but it will be in the future.
|
||||
-- Actually, it cannot even contain decision variables in the current implementation.
|
||||
partial def DiseqCnstr.collectDecVars (c' : DiseqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c') do
|
||||
match c'.h with
|
||||
| .core .. | .coreCommRing .. => return ()
|
||||
| .core .. | .coreCommRing .. | .oneNeZero => return ()
|
||||
| .neg c => c.collectDecVars
|
||||
| .subst _ _ c₁ c₂ | .subst1 _ c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
|
||||
|
||||
partial def EqCnstr.collectDecVars (c' : EqCnstr) : CollectDecVarsM Unit := do unless (← alreadyVisited c') do
|
||||
match c'.h with
|
||||
| .subst _ c₁ c₂ => c₁.collectDecVars; c₂.collectDecVars
|
||||
| .core .. | .coreCommRing .. => return ()
|
||||
| .neg c | .coeff _ c => c.collectDecVars
|
||||
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,32 @@ import Lean.Meta.Tactic.Grind.Arith.Linear.DenoteExpr
|
|||
import Lean.Meta.Tactic.Grind.Arith.Linear.Proof
|
||||
|
||||
namespace Lean.Meta.Grind.Arith.Linear
|
||||
|
||||
private def _root_.Lean.Grind.Linarith.Poly.substVar (p : Poly) : LinearM (Option (Var × EqCnstr × Poly)) := do
|
||||
let some (a, x, c) ← p.findVarToSubst | return none
|
||||
let b := c.p.coeff x
|
||||
let p' := p.mul (-b) |>.combine (c.p.mul a)
|
||||
trace[grind.debug.linarith.subst] "{← p.denoteExpr}, {a}, {← getVar x}, {← c.denoteExpr}, {b}, {← p'.denoteExpr}"
|
||||
return some (x, c, p')
|
||||
|
||||
/--
|
||||
Given an equation `c₁` containing the monomial `a*x`, and a disequality constraint `c₂`
|
||||
containing the monomial `b*x`, eliminate `x` by applying substitution.
|
||||
-/
|
||||
def DiseqCnstr.applyEq? (a : Int) (x : Var) (c₁ : EqCnstr) (b : Int) (c₂ : DiseqCnstr) : LinearM (Option DiseqCnstr) := do
|
||||
trace[grind.linarith.subst] "{← getVar x}, {← c₁.denoteExpr}, {← c₂.denoteExpr}"
|
||||
let p := c₁.p
|
||||
let q := c₂.p
|
||||
if b % a == 0 then
|
||||
let k := - b / a
|
||||
let p := p.mul k |>.combine q
|
||||
return some { p, h := .subst1 k c₁ c₂ }
|
||||
else if (← hasNoNatZeroDivisors) then
|
||||
let p := p.mul b |>.combine (q.mul (-a))
|
||||
return some { p, h := .subst (-a) b c₁ c₂ }
|
||||
else
|
||||
return none
|
||||
|
||||
/-- Returns `some structId` if `a` and `b` are elements of the same structure. -/
|
||||
def inSameStruct? (a b : Expr) : GoalM (Option Nat) := do
|
||||
let some structId ← getTermStructId? a | return none
|
||||
|
|
@ -22,7 +48,7 @@ def inSameStruct? (a b : Expr) : GoalM (Option Nat) := do
|
|||
unless structId == structId' do return none -- This can happen when we have heterogeneous equalities
|
||||
return structId
|
||||
|
||||
private def processNewCommRingEq (a b : Expr) : LinearM Unit := do
|
||||
private def processNewCommRingEq' (a b : Expr) : LinearM Unit := do
|
||||
let some lhs ← withRingM <| CommRing.reify? a (skipVar := false) | return ()
|
||||
let some rhs ← withRingM <| CommRing.reify? b (skipVar := false) | return ()
|
||||
let gen := max (← getGeneration a) (← getGeneration b)
|
||||
|
|
@ -40,7 +66,7 @@ private def processNewCommRingEq (a b : Expr) : LinearM Unit := do
|
|||
let c₂ : IneqCnstr := { p, strict := false, h := .ofCommRingEq b a rhs lhs p' lhs' }
|
||||
c₂.assert
|
||||
|
||||
private def processNewIntModuleEq (a b : Expr) : LinearM Unit := do
|
||||
private def processNewIntModuleEq' (a b : Expr) : LinearM Unit := do
|
||||
let some lhs ← reify? a (skipVar := false) | return ()
|
||||
let some rhs ← reify? b (skipVar := false) | return ()
|
||||
let p := (lhs.sub rhs).norm
|
||||
|
|
@ -51,21 +77,83 @@ private def processNewIntModuleEq (a b : Expr) : LinearM Unit := do
|
|||
let c₂ : IneqCnstr := { p, strict := false, h := .ofEq b a rhs lhs }
|
||||
c₂.assert
|
||||
|
||||
@[export lean_process_linarith_eq]
|
||||
def processNewEqImpl (a b : Expr) : GoalM Unit := do
|
||||
if isSameExpr a b then return () -- TODO: check why this is needed
|
||||
let some structId ← inSameStruct? a b | return ()
|
||||
LinearM.run structId do
|
||||
-- TODO: support non ordered case
|
||||
unless (← isOrdered) do return ()
|
||||
trace_goal[grind.linarith.assert] "{← mkEq a b}"
|
||||
if (← isCommRing) then
|
||||
processNewCommRingEq a b
|
||||
else
|
||||
processNewIntModuleEq a b
|
||||
def EqCnstr.norm (c : EqCnstr) : LinearM (Nat × Var × EqCnstr) := do
|
||||
let mut c := c
|
||||
if (← hasNoNatZeroDivisors) then
|
||||
let k := c.p.gcdCoeffs
|
||||
if k != 1 then
|
||||
c := { p := c.p.div k, h := .coeff k c }
|
||||
let some (k, x) := c.p.pickVarToElim? | unreachable!
|
||||
if k < 0 then
|
||||
c := { p := c.p.mul (-1), h := .neg c }
|
||||
return (k.natAbs, x, c)
|
||||
|
||||
partial def EqCnstr.applySubsts (c : EqCnstr) : LinearM EqCnstr := withIncRecDepth do
|
||||
let some (x, c₁, p) ← c.p.substVar | return c
|
||||
trace[grind.debug.linarith.subst] "{← getVar x}, {← c.denoteExpr}, {← c₁.denoteExpr}"
|
||||
applySubsts { p, h := .subst x c₁ c : EqCnstr }
|
||||
|
||||
/--
|
||||
Given an equation `c₁` containing the monomial `a*x`, and an inequality constraint `c₂`
|
||||
containing the monomial `b*x`, eliminate `x` by applying substitution.
|
||||
-/
|
||||
def IneqCnstr.applyEq (a : Nat) (x : Var) (c₁ : EqCnstr) (b : Int) (c₂ : IneqCnstr) : LinearM IneqCnstr := do
|
||||
let p := c₁.p
|
||||
let q := c₂.p
|
||||
let p := q.mul a |>.combine (p.mul (-b))
|
||||
trace[grind.linarith.subst] "{← getVar x}, {← c₁.denoteExpr}, {← c₂.denoteExpr}"
|
||||
return { p, h := .subst x c₁ c₂, strict := c₂.strict }
|
||||
|
||||
/--
|
||||
Given an equation `c₁` containing `a*x`, eliminate `x` from the inequalities in `todo`.
|
||||
`todo` contains pairs of the form `(b, c₂)` where `b` is the coefficient of `x` in `c₂`.
|
||||
-/
|
||||
private def updateLeCnstrs (a : Nat) (x : Var) (c₁ : EqCnstr) (todo : Array (Int × IneqCnstr)) : LinearM Unit := do
|
||||
for (b, c₂) in todo do
|
||||
let c₂ ← c₂.applyEq a x c₁ b
|
||||
c₂.assert
|
||||
if (← inconsistent) then return ()
|
||||
|
||||
private def splitIneqCnstrs (x : Var) (cs : PArray IneqCnstr) : PArray IneqCnstr × Array (Int × IneqCnstr) :=
|
||||
split cs fun c => c.p.coeff x
|
||||
|
||||
/--
|
||||
Given an equation `c₁` containing `a*x`, eliminate `x` from lower bound inequalities of `y`.
|
||||
-/
|
||||
private def updateLowers (a : Nat) (x : Var) (c : EqCnstr) (y : Var) : LinearM Unit := do
|
||||
if (← inconsistent) then return ()
|
||||
let (lowers', todo) := splitIneqCnstrs x (← getStruct).lowers[y]!
|
||||
modifyStruct fun s => { s with lowers := s.lowers.set y lowers' }
|
||||
updateLeCnstrs a x c todo
|
||||
|
||||
/--
|
||||
Given an equation `c₁` containing `a*x`, eliminate `x` from upper bound inequalities of `y`.
|
||||
-/
|
||||
private def updateUppers (a : Nat) (x : Var) (c : EqCnstr) (y : Var) : LinearM Unit := do
|
||||
if (← inconsistent) then return ()
|
||||
let (uppers', todo) := splitIneqCnstrs x (← getStruct).uppers[y]!
|
||||
modifyStruct fun s => { s with uppers := s.uppers.set y uppers' }
|
||||
updateLeCnstrs a x c todo
|
||||
|
||||
def DiseqCnstr.ignore (c : DiseqCnstr) : LinearM Unit := do
|
||||
-- Remark: we filter duplicates before displaying diagnostics to users
|
||||
trace[grind.linarith.assert.ignored] "{← c.denoteExpr}"
|
||||
let diseq ← c.denoteExpr
|
||||
modifyStruct fun s => { s with ignored := s.ignored.push diseq }
|
||||
|
||||
partial def DiseqCnstr.applySubsts? (c₂ : DiseqCnstr) : LinearM (Option DiseqCnstr) := withIncRecDepth do
|
||||
let some (b, x, c₁) ← c₂.p.findVarToSubst | return some c₂
|
||||
let a := c₁.p.coeff x
|
||||
if let some c₂ ← c₂.applyEq? a x c₁ b then
|
||||
c₂.applySubsts?
|
||||
else
|
||||
-- Failed to eliminate
|
||||
c₂.ignore
|
||||
return none
|
||||
|
||||
def DiseqCnstr.assert (c : DiseqCnstr) : LinearM Unit := do
|
||||
trace[grind.linarith.assert] "{← c.denoteExpr}"
|
||||
let some c ← c.applySubsts? | return ()
|
||||
match c.p with
|
||||
| .nil =>
|
||||
trace[grind.linarith.unsat] "{← c.denoteExpr}"
|
||||
|
|
@ -77,6 +165,77 @@ def DiseqCnstr.assert (c : DiseqCnstr) : LinearM Unit := do
|
|||
if (← c.satisfied) == .false then
|
||||
resetAssignmentFrom x
|
||||
|
||||
private def splitDiseqs (x : Var) (cs : PArray DiseqCnstr) : PArray DiseqCnstr × Array (Int × DiseqCnstr) :=
|
||||
split cs fun c => c.p.coeff x
|
||||
|
||||
private def updateDiseqs (a : Int) (x : Var) (c : EqCnstr) (y : Var) : LinearM Unit := do
|
||||
if (← inconsistent) then return ()
|
||||
let (diseqs', todo) := splitDiseqs x (← getStruct).diseqs[y]!
|
||||
modifyStruct fun s => { s with diseqs := s.diseqs.set y diseqs' }
|
||||
for (b, c₂) in todo do
|
||||
if let some c₂ ← c₂.applyEq? a x c b then
|
||||
c₂.assert
|
||||
if (← inconsistent) then return ()
|
||||
else
|
||||
-- Failed to eliminate
|
||||
c₂.ignore
|
||||
|
||||
private def updateOccsAt (a : Nat) (x : Var) (c : EqCnstr) (y : Var) : LinearM Unit := do
|
||||
updateLowers a x c y
|
||||
updateUppers a x c y
|
||||
updateDiseqs a x c y
|
||||
|
||||
private def updateOccs (a : Nat) (x : Var) (c : EqCnstr) : LinearM Unit := do
|
||||
let ys := (← getStruct).occurs[x]!
|
||||
modifyStruct fun s => { s with occurs := s.occurs.set x {} }
|
||||
updateOccsAt a x c x
|
||||
for y in ys do
|
||||
updateOccsAt a x c y
|
||||
|
||||
def EqCnstr.assert (c : EqCnstr) : LinearM Unit := do
|
||||
trace[grind.linarith.assert] "{← c.denoteExpr}"
|
||||
let c ← c.applySubsts
|
||||
if c.p == .nil then
|
||||
trace[grind.linarith.trivial] "{← c.denoteExpr}"
|
||||
return ()
|
||||
let (a, x, c) ← c.norm
|
||||
trace[grind.debug.linarith.subst] ">> {← getVar x}, {← c.denoteExpr}"
|
||||
trace[grind.linarith.assert.store] "{← c.denoteExpr}"
|
||||
modifyStruct fun s => { s with
|
||||
elimEqs := s.elimEqs.set x (some c)
|
||||
elimStack := x :: s.elimStack
|
||||
}
|
||||
updateOccs a x c
|
||||
|
||||
private def processNewCommRingEq (a b : Expr) : LinearM Unit := do
|
||||
trace[Meta.debug] "{a}, {b}"
|
||||
-- TODO
|
||||
|
||||
private def processNewIntModuleEq (a b : Expr) : LinearM Unit := do
|
||||
let some lhs ← reify? a (skipVar := false) | return ()
|
||||
let some rhs ← reify? b (skipVar := false) | return ()
|
||||
let p := (lhs.sub rhs).norm
|
||||
if p == .nil then return ()
|
||||
let c : EqCnstr := { p, h := .core a b lhs rhs }
|
||||
c.assert
|
||||
|
||||
@[export lean_process_linarith_eq]
|
||||
def processNewEqImpl (a b : Expr) : GoalM Unit := do
|
||||
if isSameExpr a b then return () -- TODO: check why this is needed
|
||||
let some structId ← inSameStruct? a b | return ()
|
||||
LinearM.run structId do
|
||||
if (← isOrdered) then
|
||||
trace_goal[grind.linarith.assert] "{← mkEq a b}"
|
||||
if (← isCommRing) then
|
||||
processNewCommRingEq' a b
|
||||
else
|
||||
processNewIntModuleEq' a b
|
||||
else
|
||||
if (← isCommRing) then
|
||||
processNewCommRingEq a b
|
||||
else
|
||||
processNewIntModuleEq a b
|
||||
|
||||
private def processNewCommRingDiseq (a b : Expr) : LinearM Unit := do
|
||||
let some lhs ← withRingM <| CommRing.reify? a (skipVar := false) | return ()
|
||||
let some rhs ← withRingM <| CommRing.reify? b (skipVar := false) | return ()
|
||||
|
|
|
|||
|
|
@ -38,6 +38,23 @@ private def ensureDefEq (a b : Expr) : MetaM Unit := do
|
|||
unless (← withDefault <| isDefEq a b) do
|
||||
throwError (← mkExpectedDefEqMsg a b)
|
||||
|
||||
private def addZeroLtOne (one : Var) : LinearM Unit := do
|
||||
let p := Poly.add (-1) one .nil
|
||||
modifyStruct fun s => { s with
|
||||
lowers := s.lowers.modify one fun cs => cs.push { p, h := .oneGtZero, strict := true }
|
||||
}
|
||||
|
||||
private def addZeroNeOne (one : Var) : LinearM Unit := do
|
||||
let p := Poly.add 1 one .nil
|
||||
modifyStruct fun s => { s with
|
||||
diseqs := s.diseqs.modify one fun cs => cs.push { p, h := .oneNeZero }
|
||||
}
|
||||
|
||||
private def isNonTrivialIsCharInst (isCharInst? : Option (Expr × Nat)) : Bool :=
|
||||
match isCharInst? with
|
||||
| some (_, c) => c != 1
|
||||
| none => false
|
||||
|
||||
def getStructId? (type : Expr) : GoalM (Option Nat) := do
|
||||
unless (← getConfig).linarith do return none
|
||||
if (← getConfig).cutsat && Cutsat.isSupportedType type then
|
||||
|
|
@ -144,6 +161,7 @@ where
|
|||
let hsmulNatFn? ← getHSMulNatFn?
|
||||
let ringId? ← CommRing.getRingId? type
|
||||
let ringInst? ← getInst? ``Grind.Ring
|
||||
let fieldInst? ← getInst? ``Grind.Field
|
||||
let getOne? : GoalM (Option Expr) := do
|
||||
let some oneInst ← getInst? ``One | return none
|
||||
let one ← internalizeConst <| mkApp2 (mkConst ``One.one [u]) type oneInst
|
||||
|
|
@ -161,6 +179,7 @@ where
|
|||
return none
|
||||
return some inst
|
||||
let ringIsOrdInst? ← getRingIsOrdInst?
|
||||
let charInst? ← if let some ringInst := ringInst? then getIsCharInst? u type ringInst else pure none
|
||||
let getNoNatZeroDivInst? : GoalM (Option Expr) := do
|
||||
let hmulNat := mkApp3 (mkConst ``HMul [0, u, u]) Nat.mkType type type
|
||||
let .some hmulInst ← trySynthInstance hmulNat | return none
|
||||
|
|
@ -171,18 +190,21 @@ where
|
|||
let struct : Struct := {
|
||||
id, type, u, intModuleInst, preorderInst?, isOrdInst?, partialInst?, linearInst?, noNatDivInst?
|
||||
leFn?, ltFn?, addFn, subFn, negFn, hmulFn, hmulNatFn, hsmulFn?, hsmulNatFn?, zero, one?
|
||||
ringInst?, commRingInst?, ringIsOrdInst?, ringId?, ofNatZero
|
||||
ringInst?, commRingInst?, ringIsOrdInst?, charInst?, ringId?, fieldInst?, ofNatZero
|
||||
}
|
||||
modify' fun s => { s with structs := s.structs.push struct }
|
||||
if let some one := one? then
|
||||
if ringInst?.isSome then LinearM.run id do
|
||||
-- Create `1` variable, and assert strict lower bound `0 < 1`
|
||||
let x ← mkVar one (mark := false)
|
||||
let p := Poly.add (-1) x .nil
|
||||
p.updateOccs
|
||||
modifyStruct fun s => { s with
|
||||
lowers := s.lowers.modify x fun cs => cs.push { p, h := .oneGtZero, strict := true }
|
||||
}
|
||||
if ringIsOrdInst?.isSome then
|
||||
-- Create `1` variable, and assert strict lower bound `0 < 1` and `0 ≠ 1`
|
||||
let x ← mkVar one (mark := false)
|
||||
addZeroLtOne x
|
||||
addZeroNeOne x
|
||||
else if fieldInst?.isSome || isNonTrivialIsCharInst charInst? then
|
||||
-- Create `1` variable, and assert `0 ≠ 1`
|
||||
let x ← mkVar one (mark := false)
|
||||
addZeroNeOne x
|
||||
|
||||
return some id
|
||||
|
||||
end Lean.Meta.Grind.Arith.Linear
|
||||
|
|
|
|||
|
|
@ -25,7 +25,11 @@ structure EqCnstr where
|
|||
h : EqCnstrProof
|
||||
|
||||
inductive EqCnstrProof where
|
||||
| rfl -- TODO
|
||||
| core (a b : Expr) (lhs rhs : LinExpr)
|
||||
| coreCommRing (a b : Expr) (ra rb : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr)
|
||||
| neg (c : EqCnstr)
|
||||
| coeff (k : Nat) (c : EqCnstr)
|
||||
| subst (x : Var) (c₁ : EqCnstr) (c₂ : EqCnstr)
|
||||
|
||||
/-- An inequality constraint and its justification/proof. -/
|
||||
structure IneqCnstr where
|
||||
|
|
@ -47,6 +51,7 @@ inductive IneqCnstrProof where
|
|||
ofEq (a b : Expr) (la lb : LinExpr)
|
||||
| /-- `a ≤ b` from an equality `a = b` coming from the core. -/
|
||||
ofCommRingEq (a b : Expr) (ra rb : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr)
|
||||
| subst (x : Var) (c₁ : EqCnstr) (c₂ : IneqCnstr)
|
||||
|
||||
structure DiseqCnstr where
|
||||
p : Poly
|
||||
|
|
@ -56,6 +61,9 @@ inductive DiseqCnstrProof where
|
|||
| core (a b : Expr) (lhs rhs : LinExpr)
|
||||
| coreCommRing (a b : Expr) (ra rb : Grind.CommRing.Expr) (p : Grind.CommRing.Poly) (lhs' : LinExpr)
|
||||
| neg (c : DiseqCnstr)
|
||||
| subst (k₁ k₂ : Int) (c₁ : EqCnstr) (c₂ : DiseqCnstr)
|
||||
| subst1 (k : Int) (c₁ : EqCnstr) (c₂ : DiseqCnstr)
|
||||
| oneNeZero
|
||||
|
||||
inductive UnsatProof where
|
||||
| diseq (c : DiseqCnstr)
|
||||
|
|
@ -66,6 +74,9 @@ end
|
|||
instance : Inhabited DiseqCnstr where
|
||||
default := { p := .nil, h := .core default default .zero .zero }
|
||||
|
||||
instance : Inhabited EqCnstr where
|
||||
default := { p := .nil, h := .core default default .zero .zero }
|
||||
|
||||
abbrev VarSet := RBTree Var compare
|
||||
|
||||
/--
|
||||
|
|
@ -98,6 +109,10 @@ structure Struct where
|
|||
commRingInst? : Option Expr
|
||||
/-- `Ring.IsOrdered` instance with `Preorder` -/
|
||||
ringIsOrdInst? : Option Expr
|
||||
/-- `Field` instance -/
|
||||
fieldInst? : Option Expr
|
||||
/-- `IsCharP` instance for `type` if available. -/
|
||||
charInst? : Option (Expr × Nat)
|
||||
zero : Expr
|
||||
ofNatZero : Expr
|
||||
one? : Option Expr
|
||||
|
|
|
|||
|
|
@ -110,6 +110,11 @@ def setTermStructId (e : Expr) : LinearM Unit := do
|
|||
return ()
|
||||
modify' fun s => { s with exprToStructId := s.exprToStructId.insert { expr := e } structId }
|
||||
|
||||
def getNoNatDivInst : LinearM Expr := do
|
||||
let some inst := (← getStruct).noNatDivInst?
|
||||
| throwError "`grind linarith` internal error, structure does not implement `NoNatZeroDivisors`"
|
||||
return inst
|
||||
|
||||
def getPreorderInst : LinearM Expr := do
|
||||
let some inst := (← getStruct).preorderInst?
|
||||
| throwError "`grind linarith` internal error, structure is not a preorder"
|
||||
|
|
@ -221,4 +226,46 @@ partial def _root_.Lean.Grind.Linarith.Poly.updateOccs (p : Poly) : LinearM Unit
|
|||
addOcc x y; go p
|
||||
go p
|
||||
|
||||
/--
|
||||
Given a polynomial `p`, returns `some (x, k, c)` if `p` contains the monomial `k*x`,
|
||||
and `x` has been eliminated using the equality `c`.
|
||||
-/
|
||||
def _root_.Lean.Grind.Linarith.Poly.findVarToSubst (p : Poly) : LinearM (Option (Int × Var × EqCnstr)) := do
|
||||
match p with
|
||||
| .nil => return none
|
||||
| .add k x p =>
|
||||
if let some c := (← getStruct).elimEqs[x]! then
|
||||
return some (k, x, c)
|
||||
else
|
||||
findVarToSubst p
|
||||
|
||||
def _root_.Lean.Grind.Linarith.Poly.gcdCoeffsAux : Poly → Nat → Nat
|
||||
| .nil, k => k
|
||||
| .add k' _ p, k => gcdCoeffsAux p (Int.gcd k' k)
|
||||
|
||||
def _root_.Lean.Grind.Linarith.Poly.gcdCoeffs (p : Poly) : Nat :=
|
||||
match p with
|
||||
| .add k _ p => p.gcdCoeffsAux k.natAbs
|
||||
| .nil => 1
|
||||
|
||||
def _root_.Lean.Grind.Linarith.Poly.div (p : Poly) (k : Int) : Poly :=
|
||||
match p with
|
||||
| .add a x p => .add (a / k) x (p.div k)
|
||||
| .nil => .nil
|
||||
|
||||
/--
|
||||
Selects the variable in the given linear polynomial whose coefficient has the smallest absolute value.
|
||||
-/
|
||||
def _root_.Lean.Grind.Linarith.Poly.pickVarToElim? (p : Poly) : Option (Int × Var) :=
|
||||
match p with
|
||||
| .nil => none
|
||||
| .add k x p => go k x p
|
||||
where
|
||||
go (k : Int) (x : Var) (p : Poly) : Int × Var :=
|
||||
if k == 1 || k == -1 then
|
||||
(k, x)
|
||||
else match p with
|
||||
| .nil => (k, x)
|
||||
| .add k' x' p => if k'.natAbs < k.natAbs then go k' x' p else go k x p
|
||||
|
||||
end Lean.Meta.Grind.Arith.Linear
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Expr
|
||||
import Lean.Message
|
||||
import Init.Grind.CommRing.Basic
|
||||
import Lean.Meta.SynthInstance
|
||||
import Lean.Meta.Basic
|
||||
import Std.Internal.Rat
|
||||
|
||||
namespace Lean.Meta.Grind.Arith
|
||||
|
|
@ -151,4 +152,32 @@ def isIntModuleVirtualParent (parent? : Option Expr) : Bool :=
|
|||
| none => false
|
||||
| some parent => parent == getIntModuleVirtualParent
|
||||
|
||||
def getIsCharInst? (u : Level) (type : Expr) (ringInst : Expr) : MetaM (Option (Expr × Nat)) := do withNewMCtxDepth do
|
||||
let n ← mkFreshExprMVar (mkConst ``Nat)
|
||||
let charType := mkApp3 (mkConst ``Grind.IsCharP [u]) type ringInst n
|
||||
let .some charInst ← trySynthInstance charType | pure none
|
||||
let n ← instantiateMVars n
|
||||
let some n ← evalNat n |>.run
|
||||
| pure none
|
||||
pure <| some (charInst, n)
|
||||
|
||||
def getNoZeroDivInst? (u : Level) (type : Expr) : MetaM (Option Expr) := do
|
||||
let zeroType := mkApp (mkConst ``Zero [u]) type
|
||||
let .some zeroInst ← trySynthInstance zeroType | return none
|
||||
let hmulType := mkApp3 (mkConst ``HMul [0, u, u]) (mkConst ``Nat []) type type
|
||||
let .some hmulInst ← trySynthInstance hmulType | return none
|
||||
let noZeroDivType := mkApp3 (mkConst ``Grind.NoNatZeroDivisors [u]) type zeroInst hmulInst
|
||||
LOption.toOption <$> trySynthInstance noZeroDivType
|
||||
|
||||
@[specialize] def split (cs : PArray α) (getCoeff : α → Int) : PArray α × Array (Int × α) := Id.run do
|
||||
let mut cs' := {}
|
||||
let mut todo := #[]
|
||||
for c in cs do
|
||||
let b := getCoeff c
|
||||
if b == 0 then
|
||||
cs' := cs'.push c
|
||||
else
|
||||
todo := todo.push (b, c)
|
||||
return (cs', todo)
|
||||
|
||||
end Lean.Meta.Grind.Arith
|
||||
|
|
|
|||
37
tests/lean/run/grind_module_eqs.lean
Normal file
37
tests/lean/run/grind_module_eqs.lean
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
open Lean Grind
|
||||
|
||||
example [IntModule α] (x y : α) : x - y ≠ 0 - 2*y → x + y = 0 → False := by
|
||||
grind
|
||||
|
||||
example [IntModule α] (x y : α) : 2*x + 2*y ≠ 0 → x + y = 0 → False := by
|
||||
grind
|
||||
|
||||
example [IntModule α] (x y : α) : 2*x + 2*y ≠ 0 → 2*x + 2*y = 0 → False := by
|
||||
grind
|
||||
|
||||
example [IntModule α] [NoNatZeroDivisors α] (x y : α) : x + y ≠ 0 → 2*x + 2*y = 0 → False := by
|
||||
grind
|
||||
|
||||
example [IntModule α] [NoNatZeroDivisors α] (x y z : α) : x + y + z ≠ 0 → 2*x + 3*y = 0 → y = 2*z → False := by
|
||||
grind
|
||||
|
||||
example [IntModule α] [NoNatZeroDivisors α] (x y z : α) : x + y + z ≠ 0 → -3*y = 2*x → y = 2*z → False := by
|
||||
grind
|
||||
|
||||
example [IntModule α] (x y : α) : x + y = 0 → x - y = 0 - 2*y := by
|
||||
grind
|
||||
|
||||
example [IntModule α] (x y : α) : x + y = 0 → 2*x + 2*y = 0 := by
|
||||
grind
|
||||
|
||||
example [IntModule α] (x y : α) : 2*x + 2*y = 0 → 2*x = 0 - 2*y := by
|
||||
grind
|
||||
|
||||
example [IntModule α] [NoNatZeroDivisors α] (x y : α) : 2*x + 2*y = 0 → x = -y := by
|
||||
grind
|
||||
|
||||
example [IntModule α] [NoNatZeroDivisors α] (x y z : α) : 2*x + 3*y = 0 → y = 2*z → x + y + z = 0 := by
|
||||
grind
|
||||
|
||||
example [IntModule α] [NoNatZeroDivisors α] (x y z : α) : -3*y = 2*x → y = 2*z → x + y + z = 0 := by
|
||||
grind
|
||||
|
|
@ -22,6 +22,6 @@ example (a b c : R) (h : 2 * a + 2 * b = 4 * c) : 3 * a + c = 5 * c - b + (-b) +
|
|||
|
||||
-- In a `RatModule` we can clear common divisors.
|
||||
example (a : R) (h : a + a = 0) : a = 0 := by grind
|
||||
example (a b c : R) (h : 2 * a + 2 * b = 4 * c) : 3 * a + c = 5 * c - 3 * b := by grind
|
||||
example (a b c : R) (h : 2 * a + 2 * b = 4 * c) : 3 * a + c = 7 * c - 3 * b := by grind
|
||||
|
||||
end RatModule
|
||||
|
|
@ -16,7 +16,6 @@ example (x : UInt8) : (x + 16)*(x - 16) = x^2 := by
|
|||
|
||||
/--
|
||||
trace: [grind.ring] new ring: Int
|
||||
[grind.ring] characteristic: 0
|
||||
[grind.ring] NoNatZeroDivisors available: true
|
||||
-/
|
||||
#guard_msgs (trace) in
|
||||
|
|
@ -29,10 +28,8 @@ example (x : BitVec 8) : (x + 16)*(x - 16) = x^2 := by
|
|||
|
||||
/--
|
||||
trace: [grind.ring] new ring: Int
|
||||
[grind.ring] characteristic: 0
|
||||
[grind.ring] NoNatZeroDivisors available: true
|
||||
[grind.ring] new ring: BitVec 8
|
||||
[grind.ring] characteristic: 256
|
||||
[grind.ring] NoNatZeroDivisors available: false
|
||||
-/
|
||||
#guard_msgs (trace) in
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue