From b87c01b1c02e42d1a1657405cbab057aef7fafa6 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 11 Feb 2025 15:37:30 -0800 Subject: [PATCH] feat: `simp +arith` sorts linear atoms (#7040) This PR ensures that terms such as `f (2*x + y)` and `f (y + x + x)` have the same normal form when using `simp +arith` --- src/Init/Data/Nat/Linear.lean | 32 +++++++-------- .../Meta/Tactic/LinearArith/Int/Basic.lean | 38 ++++++++++++++++- .../Meta/Tactic/LinearArith/Int/Simp.lean | 8 ++-- .../Meta/Tactic/LinearArith/Nat/Basic.lean | 41 ++++++++++++++++++- .../Meta/Tactic/LinearArith/Nat/Simp.lean | 5 ++- src/Lean/Util.lean | 1 + src/Lean/Util/SortExprs.lean | 23 +++++++++++ tests/lean/run/simp_int_arith.lean | 25 +++++++---- tests/lean/run/simp_nat_arith.lean | 8 ++++ 9 files changed, 148 insertions(+), 33 deletions(-) create mode 100644 src/Lean/Util/SortExprs.lean create mode 100644 tests/lean/run/simp_nat_arith.lean diff --git a/src/Init/Data/Nat/Linear.lean b/src/Init/Data/Nat/Linear.lean index 7ab2fae41f..3100e065e7 100644 --- a/src/Init/Data/Nat/Linear.lean +++ b/src/Init/Data/Nat/Linear.lean @@ -35,11 +35,11 @@ inductive Expr where deriving Inhabited def Expr.denote (ctx : Context) : Expr → Nat - | Expr.add a b => Nat.add (denote ctx a) (denote ctx b) - | Expr.num k => k - | Expr.var v => v.denote ctx - | Expr.mulL k e => Nat.mul k (denote ctx e) - | Expr.mulR e k => Nat.mul (denote ctx e) k + | .add a b => Nat.add (denote ctx a) (denote ctx b) + | .num k => k + | .var v => v.denote ctx + | .mulL k e => Nat.mul k (denote ctx e) + | .mulR e k => Nat.mul (denote ctx e) k abbrev Poly := List (Nat × Var) @@ -146,17 +146,17 @@ where -- Implementation note: This assembles the result using difference lists -- to avoid `++` on lists. go (coeff : Nat) : Expr → (Poly → Poly) - | Expr.num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·) - | Expr.var i => ((coeff, i) :: ·) - | Expr.add a b => go coeff a ∘ go coeff b - | Expr.mulL k a - | Expr.mulR a k => bif k == 0 then id else go (coeff * k) a + | .num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·) + | .var i => ((coeff, i) :: ·) + | .add a b => go coeff a ∘ go coeff b + | .mulL k a + | .mulR a k => bif k == 0 then id else go (coeff * k) a def Expr.toNormPoly (e : Expr) : Poly := e.toPoly.norm def Expr.inc (e : Expr) : Expr := - Expr.add e (Expr.num 1) + .add e (.num 1) structure PolyCnstr where eq : Bool @@ -244,21 +244,21 @@ def Certificate.denote (ctx : Context) (c : Certificate) : Prop := def monomialToExpr (k : Nat) (v : Var) : Expr := bif v == fixedVar then - Expr.num k + .num k else bif k == 1 then - Expr.var v + .var v else - Expr.mulL k (Expr.var v) + .mulL k (.var v) def Poly.toExpr (p : Poly) : Expr := match p with - | [] => Expr.num 0 + | [] => .num 0 | (k, v) :: p => go (monomialToExpr k v) p where go (e : Expr) (p : Poly) : Expr := match p with | [] => e - | (k, v) :: p => go (Expr.add e (monomialToExpr k v)) p + | (k, v) :: p => go (.add e (monomialToExpr k v)) p def PolyCnstr.toExpr (c : PolyCnstr) : ExprCnstr := { c with lhs := c.lhs.toExpr, rhs := c.rhs.toExpr } diff --git a/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean b/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean index 11b5850eea..c7a7585075 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Int/Basic.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ prelude import Init.Data.Int.Linear +import Lean.Util.SortExprs import Lean.Meta.Check import Lean.Meta.Offset import Lean.Meta.IntInstTesters @@ -31,6 +32,24 @@ def PolyCnstr.toExprCnstr : PolyCnstr → ExprCnstr | .eq p => .eq p.toExpr (.num 0) | .le p => .le p.toExpr (.num 0) +/-- Applies the given variable permutation to `e` -/ +def Expr.applyPerm (perm : Lean.Perm) (e : Expr) : Expr := + go e +where + go : Expr → Expr + | .num v => .num v + | .var i => .var (perm[(i : Nat)]?.getD i) + | .neg a => .neg (go a) + | .add a b => .add (go a) (go b) + | .sub a b => .sub (go a) (go b) + | .mulL k a => .mulL k (go a) + | .mulR a k => .mulR (go a) k + +/-- Applies the given variable permutation to the given expression constraint. -/ +def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr → ExprCnstr + | .eq a b => .eq (a.applyPerm perm) (b.applyPerm perm) + | .le a b => .le (a.applyPerm perm) (b.applyPerm perm) + end Int.Linear namespace Lean.Meta.Linear.Int @@ -187,7 +206,24 @@ def run (x : M α) : MetaM (α × Array Expr) := do end ToLinear -export ToLinear (toLinearCnstr? toLinearExpr) +def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do + let (e, atoms) ← ToLinear.run (ToLinear.toLinearExpr e) + if atoms.size == 1 then + return (e, atoms) + else + let (atoms, perm) := sortExprs atoms + let e := e.applyPerm perm + return (e, atoms) + +def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do + let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) + | return none + if atoms.size <= 1 then + return some (c, atoms) + else + let (atoms, perm) := sortExprs atoms + let c := c.applyPerm perm + return some (c, atoms) def toContextExpr (ctx : Array Expr) : Expr := if h : 0 < ctx.size then diff --git a/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean index 4ee6a3ce33..72b29abdbf 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean @@ -44,7 +44,7 @@ def Int.Linear.PolyCnstr.getConst : PolyCnstr → Int namespace Lean.Meta.Linear.Int def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do - let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none + let some (c, atoms) ← toLinearCnstr? e | return none withAbstractAtoms atoms ``Int fun atoms => do let lhs ← c.toArith atoms let p := c.toPoly @@ -127,13 +127,13 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do simpCnstrPos? e def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do - let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e) + let (e, atoms) ← toLinearExpr e let p := e.toPoly let e' := p.toExpr if e != e' then -- We only return some if monomials were fused - let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr ctx) (toExpr e) (toExpr e') reflBoolTrue - let r ← LinearExpr.toArith ctx e' + let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr e) (toExpr e') reflBoolTrue + let r ← LinearExpr.toArith atoms e' return some (r, p) else return none diff --git a/src/Lean/Meta/Tactic/LinearArith/Nat/Basic.lean b/src/Lean/Meta/Tactic/LinearArith/Nat/Basic.lean index ae354569c4..73ac8c148b 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Nat/Basic.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Nat/Basic.lean @@ -4,12 +4,32 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude +import Lean.Util.SortExprs import Lean.Meta.Check import Lean.Meta.Offset import Lean.Meta.AppBuilder import Lean.Meta.KExprMap import Lean.Data.RArray +namespace Nat.Linear + +/-- Applies the given variable permutation to `e` -/ +def Expr.applyPerm (perm : Lean.Perm) (e : Expr) : Expr := + go e +where + go : Expr → Expr + | .num v => .num v + | .var i => .var (perm[(i : Nat)]?.getD i) + | .add a b => .add (go a) (go b) + | .mulL k a => .mulL k (go a) + | .mulR a k => .mulR (go a) k + +/-- Applies the given variable permutation to the given expression constraint. -/ +def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr → ExprCnstr + | { eq, lhs, rhs } => { eq, lhs := lhs.applyPerm perm, rhs := rhs.applyPerm perm } + +end Nat.Linear + namespace Lean.Meta.Linear.Nat deriving instance Repr for Nat.Linear.Expr @@ -140,7 +160,24 @@ def run (x : M α) : MetaM (α × Array Expr) := do end ToLinear -export ToLinear (toLinearCnstr? toLinearExpr) +def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do + let (e, atoms) ← ToLinear.run (ToLinear.toLinearExpr e) + if atoms.size == 1 then + return (e, atoms) + else + let (atoms, perm) := sortExprs atoms + let e := e.applyPerm perm + return (e, atoms) + +def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do + let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) + | return none + if atoms.size <= 1 then + return some (c, atoms) + else + let (atoms, perm) := sortExprs atoms + let c := c.applyPerm perm + return some (c, atoms) def toContextExpr (ctx : Array Expr) : Expr := if h : 0 < ctx.size then @@ -148,4 +185,4 @@ def toContextExpr (ctx : Array Expr) : Expr := else RArray.toExpr (mkConst ``Nat) id (RArray.leaf (mkNatLit 0)) -end Lean.Meta.Linear.Nat +namespace Lean.Meta.Linear.Nat diff --git a/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean index ffcb5127c7..2e8dea1ba9 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean @@ -10,7 +10,8 @@ import Lean.Meta.Tactic.LinearArith.Nat.Basic namespace Lean.Meta.Linear.Nat def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do - let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none + let some (c, atoms) ← toLinearCnstr? e + | return none withAbstractAtoms atoms ``Nat fun atoms => do let lhs ← c.toArith atoms let c₁ := c.toPoly @@ -67,7 +68,7 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do simpCnstrPos? e def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do - let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e) + let (e, ctx) ← toLinearExpr e let p := e.toPoly let p' := p.norm if p'.length < p.length then diff --git a/src/Lean/Util.lean b/src/Lean/Util.lean index 8c84487851..360a784f5b 100644 --- a/src/Lean/Util.lean +++ b/src/Lean/Util.lean @@ -35,3 +35,4 @@ import Lean.Util.SafeExponentiation import Lean.Util.NumObjs import Lean.Util.NumApps import Lean.Util.FVarSubset +import Lean.Util.SortExprs diff --git a/src/Lean/Util/SortExprs.lean b/src/Lean/Util/SortExprs.lean new file mode 100644 index 0000000000..1a91d643f9 --- /dev/null +++ b/src/Lean/Util/SortExprs.lean @@ -0,0 +1,23 @@ +/- +Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Lean.Expr + +namespace Lean + +abbrev Perm := Std.HashMap Nat Nat + +/-- +Sorts the given expressions using `Expr.lt`, and creates a "permutation map" storing the new position of each expression. +-/ +def sortExprs (es : Array Expr) : Array Expr × Perm := + let es := es.mapIdx fun i e => (e, i) + let es := es.qsort fun (e₁, _) (e₂, _) => e₁.lt e₂ + let (_, perm) := es.foldl (init := (0, Std.HashMap.empty)) fun (i, perm) (_, j) => (i+1, perm.insert j i) + let es := es.map (·.1) + (es, perm) + +end Lean diff --git a/tests/lean/run/simp_int_arith.lean b/tests/lean/run/simp_int_arith.lean index bd9c0a657b..8409f382ea 100644 --- a/tests/lean/run/simp_int_arith.lean +++ b/tests/lean/run/simp_int_arith.lean @@ -171,18 +171,18 @@ fun x y z f => id (Int.Linear.ExprCnstr.eq_true_of_isValid (Lean.RArray.branch 1 (Lean.RArray.leaf x) - (Lean.RArray.branch 2 (Lean.RArray.leaf x_1) (Lean.RArray.leaf z))) + (Lean.RArray.branch 2 (Lean.RArray.leaf z) (Lean.RArray.leaf x_1))) (Int.Linear.ExprCnstr.le - ((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 1)).add (Int.Linear.Expr.num 2)).add - (Int.Linear.Expr.var 1)).add - (Int.Linear.Expr.var 2)).add - (Int.Linear.Expr.var 2)) - (((((((Int.Linear.Expr.var 1).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 2))).add + ((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 2)).add (Int.Linear.Expr.num 2)).add + (Int.Linear.Expr.var 2)).add + (Int.Linear.Expr.var 1)).add + (Int.Linear.Expr.var 1)) + (((((((Int.Linear.Expr.var 2).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 1))).add (Int.Linear.Expr.num 1)).add (Int.Linear.Expr.num 1)).add (Int.Linear.Expr.var 0)).add - (Int.Linear.Expr.var 1)).sub - (Int.Linear.Expr.var 2))) + (Int.Linear.Expr.var 2)).sub + (Int.Linear.Expr.var 1))) (Eq.refl true))) (f y)) -/ @@ -256,3 +256,12 @@ example (x : Int) : (11*x ≤ 10) ↔ (x ≤ 0) := by example (x : Int) : (11*x > 10) ↔ (x ≥ 1) := by simp +arith only + +example (x y : Int) : (2*x + y + y = 4) ↔ (y + x = 2) := by + simp +arith + +example (x y : Int) : (2*x + y + y ≤ 3) ↔ (y + x ≤ 1) := by + simp +arith + +example (f : Int → Int) (x y : Int) : f (2*x + y) = f (y + x + x) := by + simp +arith diff --git a/tests/lean/run/simp_nat_arith.lean b/tests/lean/run/simp_nat_arith.lean new file mode 100644 index 0000000000..9bda711cb5 --- /dev/null +++ b/tests/lean/run/simp_nat_arith.lean @@ -0,0 +1,8 @@ +example (x y : Nat) : (2*x + y = 4) ↔ (y + x + x = 4) := by + simp +arith + +example (x y : Nat) : (2*x + y ≤ 3) ↔ (y + x + x ≤ 3) := by + simp +arith + +example (f : Nat → Nat) (x y : Nat) : f (2*x + y) = f (y + x + x) := by + simp +arith