feat: Nat div/mod in cutsat (#7502)

This PR implements support for `Nat` div and mod in the cutsat
procedure.
This commit is contained in:
Leonardo de Moura 2025-03-15 17:29:43 -07:00 committed by GitHub
parent b7354aacaa
commit ae81567fbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 45 additions and 43 deletions

View file

@ -2245,10 +2245,10 @@ private def intMulFn : Expr :=
mkApp4 (mkConst ``HMul.hMul [0, 0, 0]) Int.mkType Int.mkType Int.mkType Int.mkInstHMul
private def intDivFn : Expr :=
mkApp4 (mkConst ``HDiv.hDiv [0, 0, 0]) Int.mkType Int.mkType Int.mkType Int.mkInstHMul
mkApp4 (mkConst ``HDiv.hDiv [0, 0, 0]) Int.mkType Int.mkType Int.mkType Int.mkInstHDiv
private def intModFn : Expr :=
mkApp4 (mkConst ``HMod.hMod [0, 0, 0]) Int.mkType Int.mkType Int.mkType Int.mkInstHMul
mkApp4 (mkConst ``HMod.hMod [0, 0, 0]) Int.mkType Int.mkType Int.mkType Int.mkInstHMod
private def intNatCastFn : Expr :=
mkApp2 (mkConst ``NatCast.natCast [0]) Int.mkType Int.mkInstNatCast

View file

@ -16,7 +16,6 @@ import Lean.Meta.Tactic.Grind.Arith.Cutsat.Var
import Lean.Meta.Tactic.Grind.Arith.Cutsat.EqCnstr
import Lean.Meta.Tactic.Grind.Arith.Cutsat.SearchM
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Model
import Lean.Meta.Tactic.Grind.Arith.Cutsat.DivMod
namespace Lean

View file

@ -1,36 +0,0 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.PropagatorAttr
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
import Lean.Meta.Tactic.Grind.Canon
namespace Lean.Meta.Grind.Arith.Cutsat
private def expandDivMod (a : Expr) (b : Int) : GoalM Unit := do
if b == 0 then return ()
if (← get').divMod.contains (a, b) then return ()
modify' fun s => { s with divMod := s.divMod.insert (a, b) }
let n : Int := 1 - b.natAbs
let b := mkIntLit b
pushNewProof <| mkApp2 (mkConst ``Int.Linear.ediv_emod) a b
pushNewProof <| mkApp3 (mkConst ``Int.Linear.emod_nonneg) a b reflBoolTrue
pushNewProof <| mkApp4 (mkConst ``Int.Linear.emod_le) a b (toExpr n) reflBoolTrue
builtin_grind_propagator propagateDiv ↑HDiv.hDiv := fun e => do
let_expr HDiv.hDiv _ _ _ inst a b ← e | return ()
if (← isInstHDivInt inst) then
let some b ← getIntValue? b | return ()
-- Remark: we currently do not consider the case where `b` is in the equivalence class of a numeral.
expandDivMod a b
builtin_grind_propagator propagateMod ↑HMod.hMod := fun e => do
let_expr HMod.hMod _ _ _ inst a b ← e | return ()
if (← isInstHModInt inst) then
let some b ← getIntValue? b | return ()
expandDivMod a b
end Lean.Meta.Grind.Arith.Cutsat

View file

@ -247,7 +247,7 @@ private def processNewNatEq (a b : Expr) : GoalM Unit := do
let rhs' ← toLinearExpr (rhs.denoteAsIntExpr ctx) gen
let p := lhs'.sub rhs' |>.norm
let c := { p, h := .coreNat a b ctx lhs rhs lhs' rhs' : EqCnstr }
trace[grind.cutsat.assert.eq] "{← c.pp}"
trace[grind.debug.cutsat.nat] "{← c.pp}"
c.assert
@[export lean_process_cutsat_eq]
@ -293,7 +293,7 @@ private def processNewNatDiseq (a b : Expr) : GoalM Unit := do
let rhs' ← toLinearExpr (rhs.denoteAsIntExpr ctx) gen
let p := lhs'.sub rhs' |>.norm
let c := { p, h := .coreNat a b ctx lhs rhs lhs' rhs' : DiseqCnstr }
trace[grind.cutsat.assert.eq] "{← c.pp}"
trace[grind.debug.cutsat.nat] "{← c.pp}"
c.assert
@[export lean_process_cutsat_diseq]
@ -306,12 +306,14 @@ def processNewDiseqImpl (a b : Expr) : GoalM Unit := do
/-- Different kinds of terms internalized by this module. -/
private inductive SupportedTermKind where
| add | mul | num
| add | mul | num | div | mod
private def getKindAndType? (e : Expr) : Option (SupportedTermKind × Expr) :=
match_expr e with
| HAdd.hAdd α _ _ _ _ _ => some (.add, α)
| HMul.hMul α _ _ _ _ _ => some (.mul, α)
| HDiv.hDiv α _ _ _ _ _ => some (.div, α)
| HMod.hMod α _ _ _ _ _ => some (.mod, α)
| OfNat.ofNat α _ _ => some (.num, α)
| Neg.neg α _ a =>
let_expr OfNat.ofNat _ _ _ := a | none
@ -319,6 +321,7 @@ private def getKindAndType? (e : Expr) : Option (SupportedTermKind × Expr) :=
| _ => none
private def isForbiddenParent (parent? : Option Expr) (k : SupportedTermKind) : Bool := Id.run do
if k matches .div | .mod then return false
let some parent := parent? | return false
let .const declName _ := parent.getAppFn | return false
if declName == ``HAdd.hAdd || declName == ``LE.le || declName == ``Dvd.dvd then return true
@ -326,6 +329,7 @@ private def isForbiddenParent (parent? : Option Expr) (k : SupportedTermKind) :
| .add => return false
| .mul => return declName == ``HMul.hMul
| .num => return declName == ``HMul.hMul || declName == ``Eq
| _ => unreachable!
private def internalizeInt (e : Expr) : GoalM Unit := do
if (← get').terms.contains { expr := e } then return ()
@ -334,6 +338,29 @@ private def internalizeInt (e : Expr) : GoalM Unit := do
trace[grind.cutsat.internalize] "{aquote e}:= {← p.pp}"
modify' fun s => { s with terms := s.terms.insert { expr := e } p }
private def expandDivMod (a : Expr) (b : Int) : GoalM Unit := do
if b == 0 then return ()
if (← get').divMod.contains (a, b) then return ()
modify' fun s => { s with divMod := s.divMod.insert (a, b) }
let n : Int := 1 - b.natAbs
let b := mkIntLit b
pushNewProof <| mkApp2 (mkConst ``Int.Linear.ediv_emod) a b
pushNewProof <| mkApp3 (mkConst ``Int.Linear.emod_nonneg) a b reflBoolTrue
pushNewProof <| mkApp4 (mkConst ``Int.Linear.emod_le) a b (toExpr n) reflBoolTrue
private def propagateDiv (e : Expr) : GoalM Unit := do
let_expr HDiv.hDiv _ _ _ inst a b ← e | return ()
if (← isInstHDivInt inst) then
let some b ← getIntValue? b | return ()
-- Remark: we currently do not consider the case where `b` is in the equivalence class of a numeral.
expandDivMod a b
private def propagateMod (e : Expr) : GoalM Unit := do
let_expr HMod.hMod _ _ _ inst a b ← e | return ()
if (← isInstHModInt inst) then
let some b ← getIntValue? b | return ()
expandDivMod a b
/--
Internalizes an integer (and `Nat`) expression. Here are the different cases that are handled.
@ -348,7 +375,10 @@ def internalize (e : Expr) (parent? : Option Expr) : GoalM Unit := do
if isForbiddenParent parent? k then return ()
trace[grind.debug.cutsat.internalize] "{e} : {type}"
if type.isConstOf ``Int then
internalizeInt e
match k with
| .div => propagateDiv e
| .mod => propagateMod e
| _ => internalizeInt e
else if type.isConstOf ``Nat then
markForeignTerm e .nat

View file

@ -14,3 +14,12 @@ example (a b c : Nat) : a + 2*b = 0 → b + c + b = 0 → a = c := by
example (a : Nat) : a ≤ 2 → a ≠ 0 → a ≠ 1 → a ≠ 2 → False := by
grind
example (x y : Nat) : x / 2 + y = 3 → x = 5 → y = 1 := by
grind
example (x y : Nat) : x % 2 + y = 3 → x = 5 → y = 2 := by
grind
example (x y : Nat) : x = y / 2 → y % 2 = 0 → y = 2*x := by
grind