feat: add divisibility constraint solver to grind (#7122)

This PR implements the divisibility constraint solver for the cutsat
procedure in the `grind` tactic.
This commit is contained in:
Leonardo de Moura 2025-02-17 18:43:35 -08:00 committed by GitHub
parent ca253ae4cf
commit 97fb0b82bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 214 additions and 42 deletions

View file

@ -5,13 +5,25 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Util.Trace
import Lean.Meta.Tactic.Grind.Arith.Cutsat.DvdCnstr
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Inv
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Proof
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Types
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Var
namespace Lean
builtin_initialize registerTraceClass `grind.cutsat
builtin_initialize registerTraceClass `grind.cutsat.assert
builtin_initialize registerTraceClass `grind.cutsat.assert.dvd
builtin_initialize registerTraceClass `grind.cutsat.dvd
builtin_initialize registerTraceClass `grind.cutsat.dvd.update (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.unsat (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.trivial (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve.combine (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.dvd.solve.elim (inherited := true)
builtin_initialize registerTraceClass `grind.cutsat.internalize
builtin_initialize registerTraceClass `grind.cutsat.internalize.term (inherited := true)

View file

@ -4,19 +4,102 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Simp.Arith.Int
import Lean.Meta.Tactic.Grind.PropagatorAttr
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Var
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Proof
namespace Lean.Meta.Grind.Arith.Cutsat
/--
`gcdExt a b` returns the triple `(g, α, β)` such that
- `g = gcd a b` (with `g ≥ 0`), and
- `g = α * a + β * β`.
-/
partial def gcdExt (a b : Int) : Int × Int × Int :=
if b = 0 then
(a.natAbs, if a = 0 then 0 else a / a.natAbs, 0)
else
let (g, α, β) := gcdExt b (a % b)
(g, β, α - (a / b) * β)
def assertDvdCnstr (e : Expr) : GoalM Unit := do
abbrev DvdCnstrWithProof.isUnsat (cₚ : DvdCnstrWithProof) : Bool :=
cₚ.c.isUnsat
abbrev DvdCnstrWithProof.isTrivial (cₚ : DvdCnstrWithProof) : Bool :=
cₚ.c.isTrivial
def DvdCnstrWithProof.norm (cₚ : DvdCnstrWithProof) : DvdCnstrWithProof :=
let cₚ := if cₚ.c.isSorted then cₚ else { cₚ with c.p := cₚ.c.p.norm, h := .norm cₚ }
let g := cₚ.c.p.gcdCoeffs cₚ.c.k
if cₚ.c.p.getConst % g == 0 then
{ cₚ with c := cₚ.c.div g, h := .divCoeffs cₚ }
else
cₚ
/-- Asserts divisibility constraint. -/
partial def assertDvdCnstr (cₚ : DvdCnstrWithProof) : GoalM Unit := withIncRecDepth do
if (← isInconsistent) then return ()
let cₚ := cₚ.norm
if cₚ.isUnsat then
trace[grind.cutsat.dvd.unsat] "{← cₚ.denoteExpr}"
withProofContext do
let h ← cₚ.toExprProof
let heq := mkApp3 (mkConst ``Int.Linear.DvdCnstr.eq_false_of_isUnsat) (← getContext) (toExpr cₚ.c) reflBoolTrue
let c ← cₚ.denoteExpr
let heq ← mkExpectedTypeHint heq (← mkEq c (← getFalseExpr))
closeGoal (← mkEqMP heq h)
else if cₚ.isTrivial then
trace[grind.cutsat.dvd.trivial] "{← cₚ.denoteExpr}"
return ()
else
let d₁ := cₚ.c.k
let .add a₁ x p₁ := cₚ.c.p
| throwError "internal `grind` error, unexpected divisibility constraint {indentExpr (← cₚ.denoteExpr)}"
if let some cₚ' := (← get').dvdCnstrs[x]! then
trace[grind.cutsat.dvd.solve] "{← cₚ.denoteExpr}, {← cₚ'.denoteExpr}"
let d₂ := cₚ'.c.k
let .add a₂ _ p₂ := cₚ'.c.p
| throwError "internal `grind` error, unexpected divisibility constraint {indentExpr (← cₚ'.denoteExpr)}"
let (d, α, β) := gcdExt (a₁*d₂) (a₂*d₁)
/-
We have that
`d = α*a₁*d₂ + β*a₂*d₁`
`d = gcd (a₁*d₂) (a₂*d₁)`
and two implied divisibility constraints:
- `d₁*d₂ d*x + α*d₂*p₁ + β*d₁*p₂`
- `d a₂*p₁ - a₁*p₂`
-/
let α_d₂_p₁ := p₁.mul (α*d₂)
let β_d₁_p₂ := p₂.mul (β*d₁)
let combine := { c.k := d₁*d₂, c.p := .add d x (α_d₂_p₁.combine β_d₁_p₂), h := .solveCombine cₚ cₚ' }
trace[grind.cutsat.dvd.solve.combine] "{← combine.denoteExpr}"
modify' fun s => { s with dvdCnstrs := s.dvdCnstrs.set x none}
assertDvdCnstr combine
let a₂_p₁ := p₁.mul a₂
let a₁_p₂ := p₂.mul (-a₁)
let elim := { c.k := d, c.p := a₂_p₁.combine a₁_p₂, h := .solveElim cₚ cₚ' }
trace[grind.cutsat.dvd.solve.elim] "{← elim.denoteExpr}"
assertDvdCnstr elim
else
trace[grind.cutsat.dvd.update] "{← cₚ.denoteExpr}"
modify' fun s => { s with dvdCnstrs := s.dvdCnstrs.set x (some cₚ) }
builtin_grind_propagator propagateDvd ↓Dvd.dvd := fun e => do
let_expr Dvd.dvd _ inst a b ← e | return ()
unless (← isInstDvdInt inst) do return ()
let some k ← getIntValue? a
| reportIssue! "non-linear divisibility constraint found{indentExpr e}"
let p ← toPoly b
let c : DvdCnstr := { k, p }
trace[grind.cutsat.assert.dvd] "{e}, {repr c}"
-- TODO
return ()
return ()
if (← isEqTrue e) then
let p ← toPoly b
let cₚ := { c.k := k, c.p := p, h := .expr (← mkOfEqTrue (← mkEqTrueProof e)) }
trace[grind.cutsat.assert.dvd] "{← cₚ.denoteExpr}"
assertDvdCnstr cₚ
else if (← isEqFalse e) then
/-
TODO: we have `¬ a b`, we should assert
`∃ x z, b = a*x + z ∧ 1 ≤ z < a`
-/
return ()
end Lean.Meta.Grind.Arith.Cutsat

View file

@ -5,20 +5,9 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Var
namespace Int.Linear
/--
Returns `true` if the variables in the given polynomial are sorted
in decreasing order.
-/
def Poly.isSorted (p : Poly) : Bool :=
go none p
where
go : Option Var → Poly → Bool
| _, .num _ => true
| none, .add _ y p => go (some y) p
| some x, .add _ y p => x > y && go (some y) p
/-- Returns `true` if all coefficients are not `0`. -/
def Poly.checkCoeffs : Poly → Bool
| .num _ => true

View file

@ -0,0 +1,15 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
namespace Lean.Meta.Grind.Arith.Cutsat
def DvdCnstrWithProof.toExprProof (cₚ : DvdCnstrWithProof) : ProofM Expr := do
-- TODO
mkSorry (← cₚ.denoteExpr) false
end Lean.Meta.Grind.Arith.Cutsat

View file

@ -13,15 +13,18 @@ namespace Lean.Meta.Grind.Arith.Cutsat
export Int.Linear (Var Poly RelCnstr DvdCnstr)
-- TODO: include RelCnstrWithProof and RelCnstrProof
mutual
/-- A divisibility constraint and its justification/proof. -/
structure DvdCnstrWithProof where
c : DvdCnstr
p : DvdCnstrProof
h : DvdCnstrProof
inductive DvdCnstrProof where
| expr (h : Expr)
| solveCombine (c₁ c₂ : DvdCnstrWithProof) (α β : Int)
| norm (c : DvdCnstrWithProof)
| divCoeffs (c : DvdCnstrWithProof)
| solveCombine (c₁ c₂ : DvdCnstrWithProof)
| solveElim (c₁ c₂ : DvdCnstrWithProof)
end

View file

@ -6,6 +6,56 @@ Authors: Leonardo de Moura
prelude
import Lean.Meta.Tactic.Grind.Types
namespace Int.Linear
def Poly.isZero : Poly → Bool
| .num 0 => true
| _ => false
/--
Returns `true` if the variables in the given polynomial are sorted
in decreasing order.
-/
def Poly.isSorted (p : Poly) : Bool :=
go none p
where
go : Option Var → Poly → Bool
| _, .num _ => true
| none, .add _ y p => go (some y) p
| some x, .add _ y p => x > y && go (some y) p
/--
If both `p₁.isSorted` and `p₂.isSorted`, returns a new
polynomial that is also sorted and `(p₁.combine p₂).denote ctx = p₁.denote ctx + p₂.denote ctx`.
-/
def Poly.combine (p₁ p₂ : Poly) : Poly :=
match _:p₁, _:p₂ with
| .num k₁, .num k₂ => .num (k₁+k₂)
| .num _, .add a x p => .add a x (combine p₁ p)
| .add a x p, .num _ => .add a x (combine p p₂)
| .add a₁ x₁ p₁', .add a₂ x₂ p₂' =>
if x₁ == x₂ then
let a := a₁ + a₂
if a == 0 then
combine p₁' p₂'
else
.add a x₁ (combine p₁' p₂')
else if x₁ > x₂ then
.add a₁ x₁ (combine p₁' p₂)
else
.add a₂ x₂ (combine p₁ p₂')
termination_by sizeOf p₁ + sizeOf p₂
def DvdCnstr.isSorted (c : DvdCnstr) : Bool :=
c.p.isSorted
def DvdCnstr.isTrivial (c : DvdCnstr) : Bool :=
match c.p with
| .num k' => k' % c.k == 0
| _ => c.k == 1
end Int.Linear
namespace Lean.Meta.Grind.Arith.Cutsat
def get' : GoalM State := do
@ -14,4 +64,28 @@ def get' : GoalM State := do
@[inline] def modify' (f : State → State) : GoalM Unit := do
modify fun s => { s with arith.cutsat := f s.arith.cutsat }
def getVars : GoalM (PArray Expr) :=
return (← get').vars
def DvdCnstrWithProof.denoteExpr (cₚ : DvdCnstrWithProof) : GoalM Expr := do
let vars ← getVars
cₚ.c.denoteExpr (vars[·]!)
def toContextExpr : GoalM Expr := do
let vars ← getVars
if h : 0 < vars.size then
return RArray.toExpr (mkConst ``Int) id (RArray.ofFn (vars[·]) h)
else
return RArray.toExpr (mkConst ``Int) id (RArray.leaf (mkIntLit 0))
/-- Auxiliary monad for constructing cutsat proofs. -/
abbrev ProofM := ReaderT Expr GoalM
/-- Returns a Lean expression representing the variable context used to construct cutsat proofs. -/
abbrev getContext : ProofM Expr := do
read
abbrev withProofContext (x : ProofM α) : GoalM α := do
x (← toContextExpr)
end Lean.Meta.Grind.Arith.Cutsat

View file

@ -7,10 +7,6 @@ prelude
import Lean.Meta.IntInstTesters
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
def Int.Linear.Poly.isZero : Poly → Bool
| .num 0 => true
| _ => false
namespace Lean.Meta.Grind.Arith.Cutsat
/-- Creates a new variable in the cutsat module. -/

View file

@ -125,43 +125,43 @@ instance : ToExpr Int.Linear.RawDvdCnstr where
toExpr a := ofRawDvdCnstr a
toTypeExpr := mkConst ``Int.Linear.RawDvdCnstr
def _root_.Int.Linear.Expr.denoteExpr (ctx : Array Expr) (e : Int.Linear.Expr) : MetaM Expr := do
def _root_.Int.Linear.Expr.denoteExpr (ctx : Nat → Expr) (e : Int.Linear.Expr) : MetaM Expr := do
match e with
| .num v => return Lean.toExpr v
| .var i => return ctx[i]!
| .var x => return ctx x
| .neg a => return mkIntNeg (← denoteExpr ctx a)
| .add a b => return mkIntAdd (← denoteExpr ctx a) (← denoteExpr ctx b)
| .sub a b => return mkIntSub (← denoteExpr ctx a) (← denoteExpr ctx b)
| .mulL k a => return mkIntMul (toExpr k) (← denoteExpr ctx a)
| .mulR a k => return mkIntMul (← denoteExpr ctx a) (toExpr k)
def _root_.Int.Linear.RawRelCnstr.denoteExpr (ctx : Array Expr) (c : Int.Linear.RawRelCnstr) : MetaM Expr := do
def _root_.Int.Linear.RawRelCnstr.denoteExpr (ctx : Nat → Expr) (c : Int.Linear.RawRelCnstr) : MetaM Expr := do
match c with
| .eq e₁ e₂ => return mkIntEq (← e₁.denoteExpr ctx) (← e₂.denoteExpr ctx)
| .le e₁ e₂ => return mkIntLE (← e₁.denoteExpr ctx) (← e₂.denoteExpr ctx)
def _root_.Int.Linear.RawDvdCnstr.denoteExpr (ctx : Array Expr) (c : Int.Linear.RawDvdCnstr) : MetaM Expr := do
def _root_.Int.Linear.RawDvdCnstr.denoteExpr (ctx : Nat → Expr) (c : Int.Linear.RawDvdCnstr) : MetaM Expr := do
return mkIntDvd (mkIntLit c.k) (← c.e.denoteExpr ctx)
def _root_.Int.Linear.Poly.denoteExpr (ctx : Array Expr) (p : Int.Linear.Poly) : MetaM Expr := do
def _root_.Int.Linear.Poly.denoteExpr (ctx : Nat → Expr) (p : Int.Linear.Poly) : MetaM Expr := do
match p with
| .num k => return toExpr k
| .add 1 x p => go ctx[x]! p
| .add k x p => go (mkIntMul (toExpr k) ctx[x]!) p
| .add 1 x p => go (ctx x) p
| .add k x p => go (mkIntMul (toExpr k) (ctx x)) p
where
go (r : Expr) (p : Int.Linear.Poly) : MetaM Expr :=
match p with
| .num 0 => return r
| .num k => return mkIntAdd r (toExpr k)
| .add 1 x p => go (mkIntAdd r ctx[x]!) p
| .add k x p => go (mkIntAdd r (mkIntMul (toExpr k) ctx[x]!)) p
| .add 1 x p => go (mkIntAdd r (ctx x)) p
| .add k x p => go (mkIntAdd r (mkIntMul (toExpr k) (ctx x))) p
def _root_.Int.Linear.RelCnstr.denoteExpr (ctx : Array Expr) (c : Int.Linear.RelCnstr) : MetaM Expr := do
def _root_.Int.Linear.RelCnstr.denoteExpr (ctx : Nat → Expr) (c : Int.Linear.RelCnstr) : MetaM Expr := do
match c with
| .eq p => return mkIntEq (← p.denoteExpr ctx) (mkIntLit 0)
| .le p => return mkIntLE (← p.denoteExpr ctx) (mkIntLit 0)
def _root_.Int.Linear.DvdCnstr.denoteExpr (ctx : Array Expr) (c : Int.Linear.DvdCnstr) : MetaM Expr := do
def _root_.Int.Linear.DvdCnstr.denoteExpr (ctx : Nat → Expr) (c : Int.Linear.DvdCnstr) : MetaM Expr := do
return mkIntDvd (mkIntLit c.k) (← c.p.denoteExpr ctx)
namespace ToLinear

View file

@ -46,7 +46,7 @@ namespace Lean.Meta.Simp.Arith.Int
def simpRelCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let some (c, atoms) ← toRawRelCnstr? e | return none
withAbstractAtoms atoms ``Int fun atoms => do
let lhs ← c.denoteExpr atoms
let lhs ← c.denoteExpr (atoms[·]!)
let c' := c.norm
if c'.isUnsat then
let r := mkConst ``False
@ -70,12 +70,12 @@ def simpRelCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
| _ =>
let k := c'.gcdCoeffs
if k == 1 then
let r ← c'.denoteExpr atoms
let r ← c'.denoteExpr (atoms[·]!)
let h := mkApp4 (mkConst ``Int.Linear.RawRelCnstr.eq_of_norm_eq) (toContextExpr atoms) (toExpr c) (toExpr c') reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else if c'.getConst % k == 0 then
let c' := c'.div k
let r ← c'.denoteExpr atoms
let r ← c'.denoteExpr (atoms[·]!)
let h := mkApp5 (mkConst ``Int.Linear.RawRelCnstr.eq_of_divBy) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else if c'.isEq then
@ -85,7 +85,7 @@ def simpRelCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
else
-- `p.isLe`: tighten the bound
let c' := c'.div k
let r ← c'.denoteExpr atoms
let r ← c'.denoteExpr (atoms[·]!)
let h := mkApp5 (mkConst ``Int.Linear.RawRelCnstr.eq_of_divByLe) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr (Int.ofNat k)) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else
@ -129,14 +129,14 @@ def simpDvdCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let some (c, atoms) ← toRawDvdCnstr? e | return none
if c.k == 0 then return none
withAbstractAtoms atoms ``Int fun atoms => do
let lhs ← c.denoteExpr atoms
let lhs ← c.denoteExpr (atoms[·]!)
let c' := c.norm
let k := c'.p.gcdCoeffs c'.k
if c'.p.getConst % k == 0 then
let c' := c'.div k
if c == c'.toRaw then
return none
let r ← c'.denoteExpr atoms
let r ← c'.denoteExpr (atoms[·]!)
let h := mkApp5 (mkConst ``Int.Linear.RawDvdCnstr.eq_of_isEqv) (toContextExpr atoms) (toExpr c) (toExpr c') (toExpr k) reflBoolTrue
return some (r, ← mkExpectedTypeHint h (← mkEq lhs r))
else
@ -151,7 +151,7 @@ def simpExpr? (input : Expr) : MetaM (Option (Expr × Expr)) := do
if e != e' then
-- We only return some if monomials were fused
let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr e) (toExpr e') reflBoolTrue
let r ← e'.denoteExpr atoms
let r ← e'.denoteExpr (atoms[·]!)
return some (r, ← mkExpectedTypeHint p (← mkEq input r))
else
return none