From bcffbdd3a14a5aa21e52ec5594bd3f2a6651cdbe Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 9 Feb 2025 14:46:09 -0800 Subject: [PATCH] chore: improve `withAbstractAtoms` (#7012) We should not abstract free variables --- src/Lean/Meta/Tactic/LinearArith/Basic.lean | 26 ++++++--- .../Meta/Tactic/LinearArith/Int/Simp.lean | 20 +++---- .../Meta/Tactic/LinearArith/Nat/Simp.lean | 12 ++-- tests/lean/run/simp_int_arith.lean | 55 +++++++++++++++++++ 4 files changed, 89 insertions(+), 24 deletions(-) diff --git a/src/Lean/Meta/Tactic/LinearArith/Basic.lean b/src/Lean/Meta/Tactic/LinearArith/Basic.lean index 25ec2bd3f1..3be0041b5c 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Basic.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Basic.lean @@ -14,14 +14,24 @@ we abstract over them. -/ def withAbstractAtoms (atoms : Array Expr) (type : Name) (k : Array Expr → MetaM (Option (Expr × Expr))) : MetaM (Option (Expr × Expr)) := do - let atoms := atoms - let decls : Array (Name × (Array Expr → MetaM Expr)) ← atoms.mapM fun _ => do - return ((← mkFreshUserName `x), fun _ => pure (mkConst type)) - withLocalDeclsD decls fun ctxt => do - let some (r, p) ← k ctxt | return none - let r := (← mkLambdaFVars ctxt r).beta atoms - let p := mkAppN (← mkLambdaFVars ctxt p) atoms - return some (r, p) + let type := mkConst type + let rec go (i : Nat) (atoms' : Array Expr) (xs : Array Expr) (args : Array Expr) : MetaM (Option (Expr × Expr)) := do + if h : i < atoms.size then + let atom := atoms[i] + if atom.isFVar then + go (i+1) (atoms'.push atom) xs args + else + withLocalDeclD (← mkFreshUserName `x) type fun x => + go (i+1) (atoms'.push x) (xs.push x) (args.push atom) + else + if xs.isEmpty then + k atoms' + else + let some (r, p) ← k atoms' | return none + let r := (← mkLambdaFVars xs r).beta args + let p := mkAppN (← mkLambdaFVars xs p) args + return some (r, p) + go 0 #[] #[] #[] /-- Quick filter for linear terms. -/ def isLinearTerm (e : Expr) : Bool := diff --git a/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean index 3f3c8cacfd..e8642fdf38 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Int/Simp.lean @@ -11,32 +11,32 @@ namespace Lean.Meta.Linear.Int def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none - withAbstractAtoms atoms ``Int fun ctx => do - let lhs ← c.toArith ctx + withAbstractAtoms atoms ``Int fun atoms => do + let lhs ← c.toArith atoms let p := c.toPoly if p.isUnsat then let r := mkConst ``False - let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr ctx) (toExpr c) reflBoolTrue + let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) else if p.isValid then let r := mkConst ``True - let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr ctx) (toExpr c) reflBoolTrue + let p := mkApp3 (mkConst ``Int.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr atoms) (toExpr c) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) else let c' : LinearCnstr := p.toExprCnstr if c != c' then match p with | .eq (.add 1 x (.add (-1) y (.num 0))) => - let r := mkIntEq ctx[x]! ctx[y]! - let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_var) (toContextExpr ctx) (toExpr x) (toExpr y) (toExpr c) reflBoolTrue + let r := mkIntEq atoms[x]! atoms[y]! + let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_var) (toContextExpr atoms) (toExpr x) (toExpr y) (toExpr c) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) | .eq (.add 1 x (.num k)) => - let r := mkIntEq ctx[x]! (toExpr (-k)) - let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_const) (toContextExpr ctx) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue + let r := mkIntEq atoms[x]! (toExpr (-k)) + let p := mkApp5 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq_const) (toContextExpr atoms) (toExpr x) (toExpr (-k)) (toExpr c) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) | _ => - let r ← c'.toArith ctx - let p := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr ctx) (toExpr c) (toExpr c') reflBoolTrue + let r ← c'.toArith atoms + let p := mkApp4 (mkConst ``Int.Linear.ExprCnstr.eq_of_toPoly_eq) (toContextExpr atoms) (toExpr c) (toExpr c') reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) else return none diff --git a/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean index f9cc4183a5..ffcb5127c7 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Nat/Simp.lean @@ -11,23 +11,23 @@ namespace Lean.Meta.Linear.Nat def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do let (some c, atoms) ← ToLinear.run (ToLinear.toLinearCnstr? e) | return none - withAbstractAtoms atoms ``Nat fun ctx => do - let lhs ← c.toArith ctx + withAbstractAtoms atoms ``Nat fun atoms => do + let lhs ← c.toArith atoms let c₁ := c.toPoly let c₂ := c₁.norm if c₂.isUnsat then let r := mkConst ``False - let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr ctx) (toExpr c) reflBoolTrue + let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_false_of_isUnsat) (toContextExpr atoms) (toExpr c) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) else if c₂.isValid then let r := mkConst ``True - let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr ctx) (toExpr c) reflBoolTrue + let p := mkApp3 (mkConst ``Nat.Linear.ExprCnstr.eq_true_of_isValid) (toContextExpr atoms) (toExpr c) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) else let c₂ : LinearCnstr := c₂.toExpr - let r ← c₂.toArith ctx + let r ← c₂.toArith atoms if r != lhs then - let p := mkApp4 (mkConst ``Nat.Linear.ExprCnstr.eq_of_toNormPoly_eq) (toContextExpr ctx) (toExpr c) (toExpr c₂) reflBoolTrue + let p := mkApp4 (mkConst ``Nat.Linear.ExprCnstr.eq_of_toNormPoly_eq) (toContextExpr atoms) (toExpr c) (toExpr c₂) reflBoolTrue return some (r, ← mkExpectedTypeHint p (← mkEq lhs r)) else return none diff --git a/tests/lean/run/simp_int_arith.lean b/tests/lean/run/simp_int_arith.lean index 06f205a49d..4078df421f 100644 --- a/tests/lean/run/simp_int_arith.lean +++ b/tests/lean/run/simp_int_arith.lean @@ -133,3 +133,58 @@ example (x : Int) (h : False) : x > x := by simp +arith only guard_target = False assumption + +theorem ex₁ (x y z : Int) : x + y + 2 + y + z + z ≤ y + 3*z + 1 + 1 + x + y - z := by + simp +arith only + +/-- +info: theorem ex₁ : ∀ (x y z : Int), x + y + 2 + y + z + z ≤ y + 3 * z + 1 + 1 + x + y - z := +fun x y z => + of_eq_true + (id + (Int.Linear.ExprCnstr.eq_true_of_isValid + (Lean.RArray.branch 1 (Lean.RArray.leaf x) (Lean.RArray.branch 2 (Lean.RArray.leaf y) (Lean.RArray.leaf z))) + (Int.Linear.ExprCnstr.le + ((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 1)).add (Int.Linear.Expr.num 2)).add + (Int.Linear.Expr.var 1)).add + (Int.Linear.Expr.var 2)).add + (Int.Linear.Expr.var 2)) + (((((((Int.Linear.Expr.var 1).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 2))).add + (Int.Linear.Expr.num 1)).add + (Int.Linear.Expr.num 1)).add + (Int.Linear.Expr.var 0)).add + (Int.Linear.Expr.var 1)).sub + (Int.Linear.Expr.var 2))) + (Eq.refl true))) +-/ +#guard_msgs (info) in +#print ex₁ + +theorem ex₂ (x y z : Int) (f : Int → Int) : x + f y + 2 + f y + z + z ≤ f y + 3*z + 1 + 1 + x + f y - z := by + simp +arith only + +/-- +info: theorem ex₂ : ∀ (x y z : Int) (f : Int → Int), x + f y + 2 + f y + z + z ≤ f y + 3 * z + 1 + 1 + x + f y - z := +fun x y z f => + of_eq_true + ((fun x_1 => + id + (Int.Linear.ExprCnstr.eq_true_of_isValid + (Lean.RArray.branch 1 (Lean.RArray.leaf x) + (Lean.RArray.branch 2 (Lean.RArray.leaf x_1) (Lean.RArray.leaf z))) + (Int.Linear.ExprCnstr.le + ((((((Int.Linear.Expr.var 0).add (Int.Linear.Expr.var 1)).add (Int.Linear.Expr.num 2)).add + (Int.Linear.Expr.var 1)).add + (Int.Linear.Expr.var 2)).add + (Int.Linear.Expr.var 2)) + (((((((Int.Linear.Expr.var 1).add (Int.Linear.Expr.mulL 3 (Int.Linear.Expr.var 2))).add + (Int.Linear.Expr.num 1)).add + (Int.Linear.Expr.num 1)).add + (Int.Linear.Expr.var 0)).add + (Int.Linear.Expr.var 1)).sub + (Int.Linear.Expr.var 2))) + (Eq.refl true))) + (f y)) +-/ +#guard_msgs (info) in +#print ex₂