From cd3eb9125cabcc9cfb063ee0075ca07c9612f3de Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 8 Feb 2025 20:32:54 -0800 Subject: [PATCH] feat: linear integer arith normalizer (#7002) This PR implements the normalizer for linear integer arithmetic expressions. It is not connect to `simp +arith` yet because of some spurious `[simp]` attributes. --- src/Init/Data/Int/Linear.lean | 60 +++++- src/Lean/Expr.lean | 94 ++++++++- src/Lean/Meta/IntInstTesters.lean | 57 +++++ src/Lean/Meta/Tactic/LinearArith/Basic.lean | 15 ++ src/Lean/Meta/Tactic/LinearArith/Int.lean | 8 + .../Meta/Tactic/LinearArith/Int/Basic.lean | 195 ++++++++++++++++++ .../Meta/Tactic/LinearArith/Int/Simp.lean | 80 +++++++ .../Meta/Tactic/LinearArith/Nat/Basic.lean | 5 +- .../Meta/Tactic/LinearArith/Nat/Simp.lean | 26 +-- tests/lean/run/liaByRefl.lean | 5 + 10 files changed, 508 insertions(+), 37 deletions(-) create mode 100644 src/Lean/Meta/IntInstTesters.lean create mode 100644 src/Lean/Meta/Tactic/LinearArith/Int.lean create mode 100644 src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean create mode 100644 src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean index 32a20e8600..1878572e36 100644 --- a/src/Init/Data/Int/Linear.lean +++ b/src/Init/Data/Int/Linear.lean @@ -25,13 +25,15 @@ inductive Expr where | var (i : Var) | add (a b : Expr) | sub (a b : Expr) + | neg (a : Expr) | mulL (k : Int) (a : Expr) | mulR (a : Expr) (k : Int) - deriving Inhabited + deriving Inhabited, BEq def Expr.denote (ctx : Context) : Expr → Int | .add a b => Int.add (denote ctx a) (denote ctx b) | .sub a b => Int.sub (denote ctx a) (denote ctx b) + | .neg a => Int.neg (denote ctx a) | .num k => k | .var v => v.denote ctx | .mulL k e => Int.mul k (denote ctx e) @@ -40,7 +42,7 @@ def Expr.denote (ctx : Context) : Expr → Int inductive Poly where | num (k : Int) | add (k : Int) (v : Var) (p : Poly) - deriving BEq, Repr + deriving BEq def Poly.denote (ctx : Context) (p : Poly) : Int := match p with @@ -81,6 +83,7 @@ where | .sub a b => go coeff a ∘ go (-coeff) b | .mulL k a | .mulR a k => bif k == 0 then id else go (Int.mul coeff k) a + | .neg a => go (-coeff) a def Expr.toPoly (e : Expr) : Poly := e.toPoly'.norm @@ -88,7 +91,7 @@ def Expr.toPoly (e : Expr) : Poly := inductive PolyCnstr where | eq (p : Poly) | le (p : Poly) - deriving BEq, Repr + deriving BEq def PolyCnstr.denote (ctx : Context) : PolyCnstr → Prop | .eq p => p.denote ctx = 0 @@ -101,7 +104,7 @@ def PolyCnstr.norm : PolyCnstr → PolyCnstr inductive ExprCnstr where | eq (p₁ p₂ : Expr) | le (p₁ p₂ : Expr) - deriving Inhabited + deriving Inhabited, BEq def ExprCnstr.denote (ctx : Context) : ExprCnstr → Prop | .eq e₁ e₂ => e₁.denote ctx = e₂.denote ctx @@ -137,8 +140,9 @@ theorem Poly.denote_norm (ctx : Context) (p : Poly) : p.norm.denote ctx = p.deno attribute [local simp] Poly.denote_norm private theorem sub_fold (a b : Int) : a.sub b = a - b := rfl +private theorem neg_fold (a : Int) : a.neg = -a := rfl -attribute [local simp] sub_fold +attribute [local simp] sub_fold neg_fold attribute [local simp] ExprCnstr.denote ExprCnstr.toPoly PolyCnstr.denote Expr.denote theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) : @@ -163,6 +167,7 @@ theorem Expr.denote_toPoly'_go (ctx : Context) (e : Expr) : simp at ih rw [ih] rw [Int.mul_assoc, Int.mul_comm k'] + | case7 k a ih => simp [toPoly'.go, ih] theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.denote ctx := by simp [toPoly, toPoly', Expr.denote_toPoly'_go] @@ -209,4 +214,49 @@ theorem ExprCnstr.eq_of_toPoly_eq (ctx : Context) (c c' : ExprCnstr) (h : c.toPo rw [denote_toPoly, denote_toPoly] at h assumption +def PolyCnstr.isUnsat : PolyCnstr → Bool + | .eq (.num k) => k != 0 + | .eq _ => false + | .le (.num k) => k > 0 + | .le _ => false + +theorem PolyCnstr.eq_false_of_isUnsat (ctx : Context) (p : PolyCnstr) : p.isUnsat → p.denote ctx = False := by + unfold isUnsat <;> split <;> simp <;> try contradiction + apply Int.not_le_of_gt + +theorem ExprCnstr.eq_false_of_isUnsat (ctx : Context) (c : ExprCnstr) (h : c.toPoly.isUnsat) : c.denote ctx = False := by + have := PolyCnstr.eq_false_of_isUnsat ctx (c.toPoly) h + rw [ExprCnstr.denote_toPoly] at this + assumption + +def PolyCnstr.isValid : PolyCnstr → Bool + | .eq (.num k) => k == 0 + | .eq _ => false + | .le (.num k) => k ≤ 0 + | .le _ => false + +theorem PolyCnstr.eq_true_of_isValid (ctx : Context) (p : PolyCnstr) : p.isValid → p.denote ctx = True := by + unfold isValid <;> split <;> simp + +theorem ExprCnstr.eq_true_of_isValid (ctx : Context) (c : ExprCnstr) (h : c.toPoly.isValid) : c.denote ctx = True := by + have := PolyCnstr.eq_true_of_isValid ctx (c.toPoly) h + rw [ExprCnstr.denote_toPoly] at this + assumption + end Int.Linear + +theorem Int.not_le_eq (a b : Int) : (¬a ≤ b) = (b + 1 ≤ a) := by + apply propext; constructor + · intro h; have h := Int.add_one_le_of_lt (Int.lt_of_not_ge h); assumption + · intro h; apply Int.not_le_of_gt; exact h + +theorem Int.not_ge_eq (a b : Int) : (¬a ≥ b) = (a + 1 ≤ b) := by + apply Int.not_le_eq + +theorem Int.not_lt_eq (a b : Int) : (¬a < b) = (b ≤ a) := by + apply propext; constructor + · intro h; simp [Int.not_lt] at h; assumption + · intro h; apply Int.not_le_of_gt; simp [Int.lt_add_one_iff, *] + +theorem Int.not_gt_eq (a b : Int) : (¬a > b) = (a ≤ b) := by + apply Int.not_lt_eq diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 53f43b58c4..20ae50e900 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ prelude import Init.Data.Hashable +import Init.Data.Int import Lean.Data.KVMap import Lean.Data.SMap import Lean.Level @@ -2140,16 +2141,13 @@ def mkInstLE : Expr := mkConst ``instLENat end Nat private def natAddFn : Expr := - let nat := mkConst ``Nat - mkApp4 (mkConst ``HAdd.hAdd [0, 0, 0]) nat nat nat Nat.mkInstHAdd + mkApp4 (mkConst ``HAdd.hAdd [0, 0, 0]) Nat.mkType Nat.mkType Nat.mkType Nat.mkInstHAdd private def natSubFn : Expr := - let nat := mkConst ``Nat - mkApp4 (mkConst ``HSub.hSub [0, 0, 0]) nat nat nat Nat.mkInstHSub + mkApp4 (mkConst ``HSub.hSub [0, 0, 0]) Nat.mkType Nat.mkType Nat.mkType Nat.mkInstHSub private def natMulFn : Expr := - let nat := mkConst ``Nat - mkApp4 (mkConst ``HMul.hMul [0, 0, 0]) nat nat nat Nat.mkInstHMul + mkApp4 (mkConst ``HMul.hMul [0, 0, 0]) Nat.mkType Nat.mkType Nat.mkType Nat.mkInstHMul /-- Given `a : Nat`, returns `Nat.succ a` -/ def mkNatSucc (a : Expr) : Expr := @@ -2168,17 +2166,97 @@ def mkNatMul (a b : Expr) : Expr := mkApp2 natMulFn a b private def natLEPred : Expr := - mkApp2 (mkConst ``LE.le [0]) (mkConst ``Nat) Nat.mkInstLE + mkApp2 (mkConst ``LE.le [0]) Nat.mkType Nat.mkInstLE /-- Given `a b : Nat`, return `a ≤ b` -/ def mkNatLE (a b : Expr) : Expr := mkApp2 natLEPred a b private def natEqPred : Expr := - mkApp (mkConst ``Eq [1]) (mkConst ``Nat) + mkApp (mkConst ``Eq [1]) Nat.mkType /-- Given `a b : Nat`, return `a = b` -/ def mkNatEq (a b : Expr) : Expr := mkApp2 natEqPred a b +/-! Constants for Int typeclasses. -/ +namespace Int + +protected def mkType : Expr := mkConst ``Int + +def mkInstNeg : Expr := mkConst ``Int.instNegInt + +def mkInstAdd : Expr := mkConst ``Int.instAdd +def mkInstHAdd : Expr := mkApp2 (mkConst ``instHAdd [levelZero]) Int.mkType mkInstAdd + +def mkInstSub : Expr := mkConst ``Int.instSub +def mkInstHSub : Expr := mkApp2 (mkConst ``instHSub [levelZero]) Int.mkType mkInstSub + +def mkInstMul : Expr := mkConst ``Int.instMul +def mkInstHMul : Expr := mkApp2 (mkConst ``instHMul [levelZero]) Int.mkType mkInstMul + +def mkInstDiv : Expr := mkConst ``Int.instDiv +def mkInstHDiv : Expr := mkApp2 (mkConst ``instHDiv [levelZero]) Int.mkType mkInstDiv + +def mkInstMod : Expr := mkConst ``Int.instMod +def mkInstHMod : Expr := mkApp2 (mkConst ``instHMod [levelZero]) Int.mkType mkInstMod + +def mkInstPow : Expr := mkConst ``Int.instNatPow +def mkInstPowNat : Expr := mkApp2 (mkConst ``instPowNat [levelZero]) Int.mkType mkInstPow +def mkInstHPow : Expr := mkApp3 (mkConst ``instHPow [levelZero, levelZero]) Int.mkType Nat.mkType mkInstPowNat + +def mkInstLT : Expr := mkConst ``Int.instLTInt +def mkInstLE : Expr := mkConst ``Int.instLEInt + +end Int + +private def intNegFn : Expr := + mkApp2 (mkConst ``Neg.neg [0]) Int.mkType Int.mkInstNeg + +private def intAddFn : Expr := + mkApp4 (mkConst ``HAdd.hAdd [0, 0, 0]) Int.mkType Int.mkType Int.mkType Int.mkInstHAdd + +private def intSubFn : Expr := + mkApp4 (mkConst ``HSub.hSub [0, 0, 0]) Int.mkType Int.mkType Int.mkType Int.mkInstHSub + +private def intMulFn : Expr := + mkApp4 (mkConst ``HMul.hMul [0, 0, 0]) Int.mkType Int.mkType Int.mkType Int.mkInstHMul + +/-- Given `a : Int`, returns `- a` -/ +def mkIntNeg (a : Expr) : Expr := + mkApp intNegFn a + +/-- Given `a b : Int`, returns `a + b` -/ +def mkIntAdd (a b : Expr) : Expr := + mkApp2 intAddFn a b + +/-- Given `a b : Int`, returns `a - b` -/ +def mkIntSub (a b : Expr) : Expr := + mkApp2 intSubFn a b + +/-- Given `a b : Int`, returns `a * b` -/ +def mkIntMul (a b : Expr) : Expr := + mkApp2 intMulFn a b + +private def intLEPred : Expr := + mkApp2 (mkConst ``LE.le [0]) Int.mkType Int.mkInstLE + +/-- Given `a b : Int`, return `a ≤ b` -/ +def mkIntLE (a b : Expr) : Expr := + mkApp2 intLEPred a b + +private def intEqPred : Expr := + mkApp (mkConst ``Eq [1]) Int.mkType + +/-- Given `a b : Int`, return `a = b` -/ +def mkIntEq (a b : Expr) : Expr := + mkApp2 intEqPred a b + +def mkIntLit (n : Nat) : Expr := + let r := mkRawNatLit n + mkApp3 (mkConst ``OfNat.ofNat [levelZero]) Int.mkType r (mkApp (mkConst ``instOfNat) r) + +def reflBoolTrue : Expr := + mkApp2 (mkConst ``Eq.refl [levelOne]) (mkConst ``Bool) (mkConst ``Bool.true) + end Lean diff --git a/src/Lean/Meta/IntInstTesters.lean b/src/Lean/Meta/IntInstTesters.lean new file mode 100644 index 0000000000..4dfe86c12d --- /dev/null +++ b/src/Lean/Meta/IntInstTesters.lean @@ -0,0 +1,57 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Basic + +namespace Lean.Meta +/-! +Functions for testing whether expressions are canonical `Int` instances. +-/ + +def isInstOfNatInt (e : Expr) : MetaM Bool := do + let_expr instOfNat _ ← e | return false + return true +def isInstNegInt (e : Expr) : MetaM Bool := do + let_expr Int.instNegInt ← e | return false + return true +def isInstAddInt (e : Expr) : MetaM Bool := do + let_expr Int.instAdd ← e | return false + return true +def isInstSubInt (e : Expr) : MetaM Bool := do + let_expr Int.instSub ← e | return false + return true +def isInstMulInt (e : Expr) : MetaM Bool := do + let_expr Int.instMul ← e | return false + return true +def isInstDivInt (e : Expr) : MetaM Bool := do + let_expr Int.instDiv ← e | return false + return true +def isInstModInt (e : Expr) : MetaM Bool := do + let_expr Int.instMod ← e | return false + return true +def isInstHAddInt (e : Expr) : MetaM Bool := do + let_expr instHAdd _ i ← e | return false + isInstAddInt i +def isInstHSubInt (e : Expr) : MetaM Bool := do + let_expr instHSub _ i ← e | return false + isInstSubInt i +def isInstHMulInt (e : Expr) : MetaM Bool := do + let_expr instHMul _ i ← e | return false + isInstMulInt i +def isInstHDivInt (e : Expr) : MetaM Bool := do + let_expr instHDiv _ i ← e | return false + isInstDivInt i +def isInstHModInt (e : Expr) : MetaM Bool := do + let_expr instHMod _ i ← e | return false + isInstModInt i +def isInstLTInt (e : Expr) : MetaM Bool := do + let_expr Int.instLTInt ← e | return false + return true +def isInstLEInt (e : Expr) : MetaM Bool := do + let_expr Int.instLEInt ← e | return false + return true + +end Lean.Meta diff --git a/src/Lean/Meta/Tactic/LinearArith/Basic.lean b/src/Lean/Meta/Tactic/LinearArith/Basic.lean index 0811a0a71a..f5e1fe999d 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Basic.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Basic.lean @@ -4,9 +4,24 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude +import Lean.Meta.Basic import Lean.Expr namespace Lean.Meta.Linear +/- +To prevent the kernel from accidentially reducing the atoms in the equation while typechecking, +we abstract over them. +-/ +def withAbstractAtoms (atoms : Array Expr) (type : Name) (k : Array Expr → MetaM (Option (Expr × Expr))) : + MetaM (Option (Expr × Expr)) := do + let atoms := atoms + let decls : Array (Name × (Array Expr → MetaM Expr)) ← atoms.mapM fun _ => do + return ((← mkFreshUserName `x), fun _ => pure (mkConst type)) + withLocalDeclsD decls fun ctxt => do + let some (r, p) ← k ctxt | return none + let r := (← mkLambdaFVars ctxt r).beta atoms + let p := mkAppN (← mkLambdaFVars ctxt p) atoms + return some (r, p) /-- Quick filter for linear terms. -/ def isLinearTerm (e : Expr) : Bool := diff --git a/src/Lean/Meta/Tactic/LinearArith/Int.lean b/src/Lean/Meta/Tactic/LinearArith/Int.lean new file mode 100644 index 0000000000..9b687fa8b1 --- /dev/null +++ b/src/Lean/Meta/Tactic/LinearArith/Int.lean @@ -0,0 +1,8 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.LinearArith.Int.Basic +import Lean.Meta.Tactic.LinearArith.Int.Simp diff --git a/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean b/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean new file mode 100644 index 0000000000..a359f59a25 --- /dev/null +++ b/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean @@ -0,0 +1,195 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Data.Int.Linear +import Lean.Meta.Check +import Lean.Meta.Offset +import Lean.Meta.IntInstTesters +import Lean.Meta.AppBuilder +import Lean.Meta.KExprMap +import Lean.Data.RArray + +namespace Int.Linear + +/-- Converts the linear polynomial into the "simplified" expression -/ +def Poly.toExpr (p : Poly) : Expr := + go none p +where + go : Option Expr → Poly → Expr + | none, .num k => .num k + | some e, .num 0 => e + | some e, .num k => .add e (.num k) + | none, .add 1 x p => go (some (.var x)) p + | none, .add k x p => go (some (.mulL k (.var x))) p + | some e, .add 1 x p => go (some (.add e (.var x))) p + | some e, .add k x p => go (some (.add e (.mulL k (.var x)))) p + +def PolyCnstr.toExprCnstr : PolyCnstr → ExprCnstr + | .eq p => .eq p.toExpr (.num 0) + | .le p => .le p.toExpr (.num 0) + +end Int.Linear + +namespace Lean.Meta.Linear.Int + +deriving instance Repr for Int.Linear.Poly +deriving instance Repr for Int.Linear.Expr +deriving instance Repr for Int.Linear.ExprCnstr +deriving instance Repr for Int.Linear.PolyCnstr + +abbrev LinearExpr := Int.Linear.Expr +abbrev LinearCnstr := Int.Linear.ExprCnstr +abbrev PolyExpr := Int.Linear.Poly + +def LinearExpr.toExpr (e : LinearExpr) : Expr := + open Int.Linear.Expr in + match e with + | .num v => mkApp (mkConst ``num) (Lean.toExpr v) + | .var i => mkApp (mkConst ``var) (mkNatLit i) + | .neg a => mkApp (mkConst ``neg) (toExpr a) + | .add a b => mkApp2 (mkConst ``add) (toExpr a) (toExpr b) + | .sub a b => mkApp2 (mkConst ``sub) (toExpr a) (toExpr b) + | .mulL k a => mkApp2 (mkConst ``mulL) (Lean.toExpr k) (toExpr a) + | .mulR a k => mkApp2 (mkConst ``mulR) (toExpr a) (Lean.toExpr k) + +instance : ToExpr LinearExpr where + toExpr a := a.toExpr + toTypeExpr := mkConst ``Int.Linear.Expr + +protected def LinearCnstr.toExpr (c : LinearCnstr) : Expr := + open Int.Linear.ExprCnstr in + match c with + | .eq e₁ e₂ => mkApp2 (mkConst ``eq) (toExpr e₁) (toExpr e₂) + | .le e₁ e₂ => mkApp2 (mkConst ``le) (toExpr e₁) (toExpr e₂) + +instance : ToExpr LinearCnstr where + toExpr a := a.toExpr + toTypeExpr := mkConst ``Int.Linear.ExprCnstr + +open Int.Linear.Expr in +def LinearExpr.toArith (ctx : Array Expr) (e : LinearExpr) : MetaM Expr := do + match e with + | .num v => return Lean.toExpr v + | .var i => return ctx[i]! + | .neg a => return mkIntNeg (← toArith ctx a) + | .add a b => return mkIntAdd (← toArith ctx a) (← toArith ctx b) + | .sub a b => return mkIntSub (← toArith ctx a) (← toArith ctx b) + | .mulL k a => return mkIntMul (Lean.toExpr k) (← toArith ctx a) + | .mulR a k => return mkIntMul (← toArith ctx a) (Lean.toExpr k) + +def LinearCnstr.toArith (ctx : Array Expr) (c : LinearCnstr) : MetaM Expr := do + match c with + | .eq e₁ e₂ => return mkIntEq (← LinearExpr.toArith ctx e₁) (← LinearExpr.toArith ctx e₂) + | .le e₁ e₂ => return mkIntLE (← LinearExpr.toArith ctx e₁) (← LinearExpr.toArith ctx e₂) + +namespace ToLinear + +structure State where + varMap : KExprMap Nat := {} -- It should be fine to use `KExprMap` here because the mapping should be small and few HeadIndex collisions. + vars : Array Expr := #[] + +abbrev M := StateRefT State MetaM + +open Int.Linear.Expr + +def addAsVar (e : Expr) : M LinearExpr := do + if let some x ← (← get).varMap.find? e then + return var x + else + let x := (← get).vars.size + let s ← get + set { varMap := (← s.varMap.insert e x), vars := s.vars.push e : State } + return var x + +private def toInt? (e : Expr) : MetaM (Option Int) := do + let_expr OfNat.ofNat _ n i ← e | return none + unless (← isInstOfNatInt i) do return none + let some n ← evalNat n |>.run | return none + return some (Int.ofNat n) + +partial def toLinearExpr (e : Expr) : M LinearExpr := do + match e with + | .mdata _ e => toLinearExpr e + | .app .. => visit e + | .mvar .. => visit e + | _ => addAsVar e +where + visit (e : Expr) : M LinearExpr := do + let mul (a b : Expr) := do + match (← toInt? a) with + | some k => return .mulL k (← toLinearExpr b) + | none => match (← toInt? b) with + | some k => return .mulR (← toLinearExpr a) k + | none => addAsVar e + match_expr e with + | OfNat.ofNat _ n i => + if (← isInstOfNatInt i) then toLinearExpr n + else addAsVar e + | Int.neg a => return .neg (← toLinearExpr a) + | Neg.neg _ i a => + if (← isInstNegInt i) then return .neg (← toLinearExpr a) + else addAsVar e + | Int.add a b => return .add (← toLinearExpr a) (← toLinearExpr b) + | Add.add _ i a b => + if (← isInstAddInt i) then return .add (← toLinearExpr a) (← toLinearExpr b) + else addAsVar e + | HAdd.hAdd _ _ _ i a b => + if (← isInstHAddInt i) then return .add (← toLinearExpr a) (← toLinearExpr b) + else addAsVar e + | Int.sub a b => return .sub (← toLinearExpr a) (← toLinearExpr b) + | Sub.sub _ i a b => + if (← isInstSubInt i) then return .sub (← toLinearExpr a) (← toLinearExpr b) + else addAsVar e + | HSub.hSub _ _ _ i a b => + if (← isInstSubInt i) then return .sub (← toLinearExpr a) (← toLinearExpr b) + else addAsVar e + | Int.mul a b => mul a b + | Mul.mul _ i a b => + if (← isInstMulInt i) then mul a b + else addAsVar e + | HMul.hMul _ _ _ i a b => + if (← isInstHMulInt i) then mul a b + else addAsVar e + | _ => addAsVar e + +partial def toLinearCnstr? (e : Expr) : M (Option LinearCnstr) := OptionT.run do + match_expr e with + | Eq α a b => + let_expr Int ← α | failure + return .eq (← toLinearExpr a) (← toLinearExpr b) + | Int.le a b => + return .le (← toLinearExpr a) (← toLinearExpr b) + | Int.lt a b => + return .le (.add (← toLinearExpr a) (.num 1)) (← toLinearExpr b) + | LE.le _ i a b => + guard (← isInstLENat i) + return .le (← toLinearExpr a) (← toLinearExpr b) + | LT.lt _ i a b => + guard (← isInstLTInt i) + return .le (.add (← toLinearExpr a) (.num 1)) (← toLinearExpr b) + | GE.ge _ i a b => + guard (← isInstLEInt i) + return .le (← toLinearExpr b) (← toLinearExpr a) + | GT.gt _ i a b => + guard (← isInstLTInt i) + return .le (.add (← toLinearExpr b) (.num 1)) (← toLinearExpr a) + | _ => failure + +def run (x : M α) : MetaM (α × Array Expr) := do + let (a, s) ← x.run {} + return (a, s.vars) + +end ToLinear + +export ToLinear (toLinearCnstr? toLinearExpr) + +def toContextExpr (ctx : Array Expr) : Expr := + if h : 0 < ctx.size then + RArray.toExpr (mkConst ``Int) id (RArray.ofArray ctx h) + else + RArray.toExpr (mkConst ``Int) id (RArray.leaf (mkIntLit 0)) + +end Lean.Meta.Linear.Int diff --git a/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean new file mode 100644 index 0000000000..c2a385d828 --- /dev/null +++ b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean @@ -0,0 +1,80 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Meta.Tactic.LinearArith.Basic +import Lean.Meta.Tactic.LinearArith.Int.Basic + +namespace Lean.Meta.Linear.Int + +def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do + let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none + withAbstractAtoms atoms ``Int fun ctx => do + let lhs ← c.toArith ctx + let p := c.toPoly + if p.isUnsat then + let r := mkConst ``False + let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr ctx) (toExpr c) reflBoolTrue + return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) + else if p.isValid then + let r := mkConst ``True + let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr ctx) (toExpr c) reflBoolTrue + return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) + else + let c' : LinearCnstr := p.toExprCnstr + if c != c' then + let r ← c'.toArith ctx + let p := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr ctx) (toExpr c) (toExpr c') reflBoolTrue + return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) + else + return none + +def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do + if let some arg := e.not? then + let mut eNew? := none + let mut thmName := Name.anonymous + match_expr arg with + | LE.le α _ lhs rhs => + if α.isConstOf ``Int then + eNew? := some (mkIntLE (mkIntAdd rhs (mkIntLit 1)) lhs) + thmName := ``Int.not_le_eq + | GE.ge α _ lhs rhs => + if α.isConstOf ``Int then + eNew? := some (mkIntLE (mkIntAdd lhs (mkIntLit 1)) rhs) + thmName := ``Int.not_ge_eq + | LT.lt α _ lhs rhs => + if α.isConstOf ``Int then + eNew? := some (mkIntLE rhs lhs) + thmName := ``Int.not_lt_eq + | GT.gt α _ lhs rhs => + if α.isConstOf ``Int then + eNew? := some (mkIntLE lhs rhs) + thmName := ``Int.not_gt_eq + | _ => pure () + if let some eNew := eNew? then + let h₁ := mkApp2 (mkConst thmName) (arg.getArg! 2) (arg.getArg! 3) + if let some (eNew', h₂) ← simpCnstrPos? eNew then + let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂ + return some (eNew', h) + else + return some (eNew, h₁) + else + return none + else + simpCnstrPos? e + +def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do + let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e) + let p := e.toPoly + let e' := p.toExpr + if e != e' then + -- We only return some if monomials were fused + let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr ctx) (toExpr e) (toExpr e') reflBoolTrue + let r ← LinearExpr.toArith ctx e' + return some (r, p) + else + return none + +end Lean.Meta.Linear.Int diff --git a/src/Lean/Meta/Tactic/LinearArith/Nat/Basic.lean b/src/Lean/Meta/Tactic/LinearArith/Nat/Basic.lean index 23fd3dfa6a..ae354569c4 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Nat/Basic.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Nat/Basic.lean @@ -148,7 +148,4 @@ def toContextExpr (ctx : Array Expr) : Expr := else RArray.toExpr (mkConst ``Nat) id (RArray.leaf (mkNatLit 0)) -def reflTrue : Expr := - mkApp2 (mkConst ``Eq.refl [levelOne]) (mkConst ``Bool) (mkConst ``Bool.true) - -namespace Lean.Meta.Linear.Nat +end Lean.Meta.Linear.Nat diff --git a/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean index 7e282174ec..f9cc4183a5 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean @@ -4,44 +4,30 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude +import Lean.Meta.Tactic.LinearArith.Basic import Lean.Meta.Tactic.LinearArith.Nat.Basic namespace Lean.Meta.Linear.Nat -/- -To prevent the kernel from accidentially reducing the atoms in the equation while typechecking, -we abstract over them. --/ -def withAbstractAtoms (atoms : Array Expr) (k : Array Expr → MetaM (Option (Expr × Expr))) : - MetaM (Option (Expr × Expr)) := do - let atoms := atoms - let decls : Array (Name × (Array Expr → MetaM Expr)) ← atoms.mapM fun _ => do - return ((← mkFreshUserName `x), fun _ => pure (mkConst ``Nat)) - withLocalDeclsD decls fun ctxt => do - let some (r, p) ← k ctxt | return none - let r := (← mkLambdaFVars ctxt r).beta atoms - let p := mkAppN (← mkLambdaFVars ctxt p) atoms - return some (r, p) - def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none - withAbstractAtoms atoms fun ctx => do + withAbstractAtoms atoms ``Nat fun ctx => do let lhs ← c.toArith ctx let c₁ := c.toPoly let c₂ := c₁.norm if c₂.isUnsat then let r := mkConst ``False - let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr ctx) (toExpr c) reflTrue + let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr ctx) (toExpr c) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) else if c₂.isValid then let r := mkConst ``True - let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr ctx) (toExpr c) reflTrue + let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr ctx) (toExpr c) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) else let c₂ : LinearCnstr := c₂.toExpr let r ← c₂.toArith ctx if r != lhs then - let p := mkApp4 (mkConst ``Nat.Linear.ExprCnstr.eq_of_toNormPoly_eq) (toContextExpr ctx) (toExpr c) (toExpr c₂) reflTrue + let p := mkApp4 (mkConst ``Nat.Linear.ExprCnstr.eq_of_toNormPoly_eq) (toContextExpr ctx) (toExpr c) (toExpr c₂) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) else return none @@ -87,7 +73,7 @@ def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do if p'.length < p.length then -- We only return some if monomials were fused let e' : LinearExpr := p'.toExpr - let p := mkApp4 (mkConst ``Nat.Linear.Expr.eq_of_toNormPoly_eq) (toContextExpr ctx) (toExpr e) (toExpr e') reflTrue + let p := mkApp4 (mkConst ``Nat.Linear.Expr.eq_of_toNormPoly_eq) (toContextExpr ctx) (toExpr e) (toExpr e') reflBoolTrue let r ← e'.toArith ctx return some (r, p) else diff --git a/tests/lean/run/liaByRefl.lean b/tests/lean/run/liaByRefl.lean index 4143ec02d4..6629f3dc4b 100644 --- a/tests/lean/run/liaByRefl.lean +++ b/tests/lean/run/liaByRefl.lean @@ -17,6 +17,11 @@ example (x₁ x₂ : Int) : x₁ + x₂ + 3 := rfl +example (x₁ x₂ : Int) : + Poly.denote #R[x₁, x₂] (.add 1 0 (.add 3 1 (.num 4))) + = + 1 * x₁ + ((3 * x₂) + 4) := + rfl example (x₁ x₂ : Int) : Expr.denote #R[x₁, x₂] (.sub (.add (.mulR (.var 0) 4) (.mulL 2 (.var 1))) (.num 3))