feat: simp +arith sorts linear atoms (#7040)
This PR ensures that terms such as `f (2*x + y)` and `f (y + x + x)` have the same normal form when using `simp +arith`
This commit is contained in:
parent
0f1133fe69
commit
b87c01b1c0
9 changed files with 148 additions and 33 deletions
|
|
@ -35,11 +35,11 @@ inductive Expr where
|
|||
deriving Inhabited
|
||||
|
||||
def Expr.denote (ctx : Context) : Expr → Nat
|
||||
| Expr.add a b => Nat.add (denote ctx a) (denote ctx b)
|
||||
| Expr.num k => k
|
||||
| Expr.var v => v.denote ctx
|
||||
| Expr.mulL k e => Nat.mul k (denote ctx e)
|
||||
| Expr.mulR e k => Nat.mul (denote ctx e) k
|
||||
| .add a b => Nat.add (denote ctx a) (denote ctx b)
|
||||
| .num k => k
|
||||
| .var v => v.denote ctx
|
||||
| .mulL k e => Nat.mul k (denote ctx e)
|
||||
| .mulR e k => Nat.mul (denote ctx e) k
|
||||
|
||||
abbrev Poly := List (Nat × Var)
|
||||
|
||||
|
|
@ -146,17 +146,17 @@ where
|
|||
-- Implementation note: This assembles the result using difference lists
|
||||
-- to avoid `++` on lists.
|
||||
go (coeff : Nat) : Expr → (Poly → Poly)
|
||||
| Expr.num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·)
|
||||
| Expr.var i => ((coeff, i) :: ·)
|
||||
| Expr.add a b => go coeff a ∘ go coeff b
|
||||
| Expr.mulL k a
|
||||
| Expr.mulR a k => bif k == 0 then id else go (coeff * k) a
|
||||
| .num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·)
|
||||
| .var i => ((coeff, i) :: ·)
|
||||
| .add a b => go coeff a ∘ go coeff b
|
||||
| .mulL k a
|
||||
| .mulR a k => bif k == 0 then id else go (coeff * k) a
|
||||
|
||||
def Expr.toNormPoly (e : Expr) : Poly :=
|
||||
e.toPoly.norm
|
||||
|
||||
def Expr.inc (e : Expr) : Expr :=
|
||||
Expr.add e (Expr.num 1)
|
||||
.add e (.num 1)
|
||||
|
||||
structure PolyCnstr where
|
||||
eq : Bool
|
||||
|
|
@ -244,21 +244,21 @@ def Certificate.denote (ctx : Context) (c : Certificate) : Prop :=
|
|||
|
||||
def monomialToExpr (k : Nat) (v : Var) : Expr :=
|
||||
bif v == fixedVar then
|
||||
Expr.num k
|
||||
.num k
|
||||
else bif k == 1 then
|
||||
Expr.var v
|
||||
.var v
|
||||
else
|
||||
Expr.mulL k (Expr.var v)
|
||||
.mulL k (.var v)
|
||||
|
||||
def Poly.toExpr (p : Poly) : Expr :=
|
||||
match p with
|
||||
| [] => Expr.num 0
|
||||
| [] => .num 0
|
||||
| (k, v) :: p => go (monomialToExpr k v) p
|
||||
where
|
||||
go (e : Expr) (p : Poly) : Expr :=
|
||||
match p with
|
||||
| [] => e
|
||||
| (k, v) :: p => go (Expr.add e (monomialToExpr k v)) p
|
||||
| (k, v) :: p => go (.add e (monomialToExpr k v)) p
|
||||
|
||||
def PolyCnstr.toExpr (c : PolyCnstr) : ExprCnstr :=
|
||||
{ c with lhs := c.lhs.toExpr, rhs := c.rhs.toExpr }
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Authors: Leonardo de Moura
|
|||
-/
|
||||
prelude
|
||||
import Init.Data.Int.Linear
|
||||
import Lean.Util.SortExprs
|
||||
import Lean.Meta.Check
|
||||
import Lean.Meta.Offset
|
||||
import Lean.Meta.IntInstTesters
|
||||
|
|
@ -31,6 +32,24 @@ def PolyCnstr.toExprCnstr : PolyCnstr → ExprCnstr
|
|||
| .eq p => .eq p.toExpr (.num 0)
|
||||
| .le p => .le p.toExpr (.num 0)
|
||||
|
||||
/-- Applies the given variable permutation to `e` -/
|
||||
def Expr.applyPerm (perm : Lean.Perm) (e : Expr) : Expr :=
|
||||
go e
|
||||
where
|
||||
go : Expr → Expr
|
||||
| .num v => .num v
|
||||
| .var i => .var (perm[(i : Nat)]?.getD i)
|
||||
| .neg a => .neg (go a)
|
||||
| .add a b => .add (go a) (go b)
|
||||
| .sub a b => .sub (go a) (go b)
|
||||
| .mulL k a => .mulL k (go a)
|
||||
| .mulR a k => .mulR (go a) k
|
||||
|
||||
/-- Applies the given variable permutation to the given expression constraint. -/
|
||||
def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr → ExprCnstr
|
||||
| .eq a b => .eq (a.applyPerm perm) (b.applyPerm perm)
|
||||
| .le a b => .le (a.applyPerm perm) (b.applyPerm perm)
|
||||
|
||||
end Int.Linear
|
||||
|
||||
namespace Lean.Meta.Linear.Int
|
||||
|
|
@ -187,7 +206,24 @@ def run (x : M α) : MetaM (α × Array Expr) := do
|
|||
|
||||
end ToLinear
|
||||
|
||||
export ToLinear (toLinearCnstr? toLinearExpr)
|
||||
def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do
|
||||
let (e, atoms) ← ToLinear.run (ToLinear.toLinearExpr e)
|
||||
if atoms.size == 1 then
|
||||
return (e, atoms)
|
||||
else
|
||||
let (atoms, perm) := sortExprs atoms
|
||||
let e := e.applyPerm perm
|
||||
return (e, atoms)
|
||||
|
||||
def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do
|
||||
let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e)
|
||||
| return none
|
||||
if atoms.size <= 1 then
|
||||
return some (c, atoms)
|
||||
else
|
||||
let (atoms, perm) := sortExprs atoms
|
||||
let c := c.applyPerm perm
|
||||
return some (c, atoms)
|
||||
|
||||
def toContextExpr (ctx : Array Expr) : Expr :=
|
||||
if h : 0 < ctx.size then
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ def Int.Linear.PolyCnstr.getConst : PolyCnstr → Int
|
|||
namespace Lean.Meta.Linear.Int
|
||||
|
||||
def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none
|
||||
let some (c, atoms) ← toLinearCnstr? e | return none
|
||||
withAbstractAtoms atoms ``Int fun atoms => do
|
||||
let lhs ← c.toArith atoms
|
||||
let p := c.toPoly
|
||||
|
|
@ -127,13 +127,13 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
|||
simpCnstrPos? e
|
||||
|
||||
def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e)
|
||||
let (e, atoms) ← toLinearExpr e
|
||||
let p := e.toPoly
|
||||
let e' := p.toExpr
|
||||
if e != e' then
|
||||
-- We only return some if monomials were fused
|
||||
let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr ctx) (toExpr e) (toExpr e') reflBoolTrue
|
||||
let r ← LinearExpr.toArith ctx e'
|
||||
let p := mkApp4 (mkConst ``Int.Linear.Expr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr e) (toExpr e') reflBoolTrue
|
||||
let r ← LinearExpr.toArith atoms e'
|
||||
return some (r, p)
|
||||
else
|
||||
return none
|
||||
|
|
|
|||
|
|
@ -4,12 +4,32 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Util.SortExprs
|
||||
import Lean.Meta.Check
|
||||
import Lean.Meta.Offset
|
||||
import Lean.Meta.AppBuilder
|
||||
import Lean.Meta.KExprMap
|
||||
import Lean.Data.RArray
|
||||
|
||||
namespace Nat.Linear
|
||||
|
||||
/-- Applies the given variable permutation to `e` -/
|
||||
def Expr.applyPerm (perm : Lean.Perm) (e : Expr) : Expr :=
|
||||
go e
|
||||
where
|
||||
go : Expr → Expr
|
||||
| .num v => .num v
|
||||
| .var i => .var (perm[(i : Nat)]?.getD i)
|
||||
| .add a b => .add (go a) (go b)
|
||||
| .mulL k a => .mulL k (go a)
|
||||
| .mulR a k => .mulR (go a) k
|
||||
|
||||
/-- Applies the given variable permutation to the given expression constraint. -/
|
||||
def ExprCnstr.applyPerm (perm : Lean.Perm) : ExprCnstr → ExprCnstr
|
||||
| { eq, lhs, rhs } => { eq, lhs := lhs.applyPerm perm, rhs := rhs.applyPerm perm }
|
||||
|
||||
end Nat.Linear
|
||||
|
||||
namespace Lean.Meta.Linear.Nat
|
||||
|
||||
deriving instance Repr for Nat.Linear.Expr
|
||||
|
|
@ -140,7 +160,24 @@ def run (x : M α) : MetaM (α × Array Expr) := do
|
|||
|
||||
end ToLinear
|
||||
|
||||
export ToLinear (toLinearCnstr? toLinearExpr)
|
||||
def toLinearExpr (e : Expr) : MetaM (LinearExpr × Array Expr) := do
|
||||
let (e, atoms) ← ToLinear.run (ToLinear.toLinearExpr e)
|
||||
if atoms.size == 1 then
|
||||
return (e, atoms)
|
||||
else
|
||||
let (atoms, perm) := sortExprs atoms
|
||||
let e := e.applyPerm perm
|
||||
return (e, atoms)
|
||||
|
||||
def toLinearCnstr? (e : Expr) : MetaM (Option (LinearCnstr × Array Expr)) := do
|
||||
let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e)
|
||||
| return none
|
||||
if atoms.size <= 1 then
|
||||
return some (c, atoms)
|
||||
else
|
||||
let (atoms, perm) := sortExprs atoms
|
||||
let c := c.applyPerm perm
|
||||
return some (c, atoms)
|
||||
|
||||
def toContextExpr (ctx : Array Expr) : Expr :=
|
||||
if h : 0 < ctx.size then
|
||||
|
|
@ -148,4 +185,4 @@ def toContextExpr (ctx : Array Expr) : Expr :=
|
|||
else
|
||||
RArray.toExpr (mkConst ``Nat) id (RArray.leaf (mkNatLit 0))
|
||||
|
||||
end Lean.Meta.Linear.Nat
|
||||
namespace Lean.Meta.Linear.Nat
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ import Lean.Meta.Tactic.LinearArith.Nat.Basic
|
|||
namespace Lean.Meta.Linear.Nat
|
||||
|
||||
def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none
|
||||
let some (c, atoms) ← toLinearCnstr? e
|
||||
| return none
|
||||
withAbstractAtoms atoms ``Nat fun atoms => do
|
||||
let lhs ← c.toArith atoms
|
||||
let c₁ := c.toPoly
|
||||
|
|
@ -67,7 +68,7 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
|||
simpCnstrPos? e
|
||||
|
||||
def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e)
|
||||
let (e, ctx) ← toLinearExpr e
|
||||
let p := e.toPoly
|
||||
let p' := p.norm
|
||||
if p'.length < p.length then
|
||||
|
|
|
|||
|
|
@ -35,3 +35,4 @@ import Lean.Util.SafeExponentiation
|
|||
import Lean.Util.NumObjs
|
||||
import Lean.Util.NumApps
|
||||
import Lean.Util.FVarSubset
|
||||
import Lean.Util.SortExprs
|
||||
|
|
|
|||
23
src/Lean/Util/SortExprs.lean
Normal file
23
src/Lean/Util/SortExprs.lean
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
prelude
|
||||
import Lean.Expr
|
||||
|
||||
namespace Lean
|
||||
|
||||
abbrev Perm := Std.HashMap Nat Nat
|
||||
|
||||
/--
|
||||
Sorts the given expressions using `Expr.lt`, and creates a "permutation map" storing the new position of each expression.
|
||||
-/
|
||||
def sortExprs (es : Array Expr) : Array Expr × Perm :=
|
||||
let es := es.mapIdx fun i e => (e, i)
|
||||
let es := es.qsort fun (e₁, _) (e₂, _) => e₁.lt e₂
|
||||
let (_, perm) := es.foldl (init := (0, Std.HashMap.empty)) fun (i, perm) (_, j) => (i+1, perm.insert j i)
|
||||
let es := es.map (·.1)
|
||||
(es, perm)
|
||||
|
||||
end Lean
|
||||
|
|
@ -171,18 +171,18 @@ fun x y z f =>
|
|||
id
|
||||
(Int.Linear.ExprCnstr.eq_true_of_isValid
|
||||
(Lean.RArray.branch 1 (Lean.RArray.leaf x)
|
||||
(Lean.RArray.branch 2 (Lean.RArray.leaf x_1) (Lean.RArray.leaf z)))
|
||||
(Lean.RArray.branch 2 (Lean.RArray.leaf z) (Lean.RArray.leaf x_1)))
|
||||
(Int.Linear.ExprCnstr.le
|
||||
((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 1)).add (Int.Linear.Expr.num 2)).add
|
||||
(Int.Linear.Expr.var 1)).add
|
||||
(Int.Linear.Expr.var 2)).add
|
||||
(Int.Linear.Expr.var 2))
|
||||
(((((((Int.Linear.Expr.var 1).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 2))).add
|
||||
((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 2)).add (Int.Linear.Expr.num 2)).add
|
||||
(Int.Linear.Expr.var 2)).add
|
||||
(Int.Linear.Expr.var 1)).add
|
||||
(Int.Linear.Expr.var 1))
|
||||
(((((((Int.Linear.Expr.var 2).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 1))).add
|
||||
(Int.Linear.Expr.num 1)).add
|
||||
(Int.Linear.Expr.num 1)).add
|
||||
(Int.Linear.Expr.var 0)).add
|
||||
(Int.Linear.Expr.var 1)).sub
|
||||
(Int.Linear.Expr.var 2)))
|
||||
(Int.Linear.Expr.var 2)).sub
|
||||
(Int.Linear.Expr.var 1)))
|
||||
(Eq.refl true)))
|
||||
(f y))
|
||||
-/
|
||||
|
|
@ -256,3 +256,12 @@ example (x : Int) : (11*x ≤ 10) ↔ (x ≤ 0) := by
|
|||
|
||||
example (x : Int) : (11*x > 10) ↔ (x ≥ 1) := by
|
||||
simp +arith only
|
||||
|
||||
example (x y : Int) : (2*x + y + y = 4) ↔ (y + x = 2) := by
|
||||
simp +arith
|
||||
|
||||
example (x y : Int) : (2*x + y + y ≤ 3) ↔ (y + x ≤ 1) := by
|
||||
simp +arith
|
||||
|
||||
example (f : Int → Int) (x y : Int) : f (2*x + y) = f (y + x + x) := by
|
||||
simp +arith
|
||||
|
|
|
|||
8
tests/lean/run/simp_nat_arith.lean
Normal file
8
tests/lean/run/simp_nat_arith.lean
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
example (x y : Nat) : (2*x + y = 4) ↔ (y + x + x = 4) := by
|
||||
simp +arith
|
||||
|
||||
example (x y : Nat) : (2*x + y ≤ 3) ↔ (y + x + x ≤ 3) := by
|
||||
simp +arith
|
||||
|
||||
example (f : Nat → Nat) (x y : Nat) : f (2*x + y) = f (y + x + x) := by
|
||||
simp +arith
|
||||
Loading…
Add table
Reference in a new issue