From 89f88b1caa4dc3542baa50b28064a9b7b1b76aa0 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 27 Feb 2022 08:59:55 -0800 Subject: [PATCH] feat: simplify nested arith expressions --- src/Init/Data/Nat/Linear.lean | 20 ++++++++++++++------ src/Lean/Meta/Tactic/LinearArith/Nat.lean | 13 +++++++++++++ src/Lean/Meta/Tactic/LinearArith/Simp.lean | 21 +++++++++++++++++++-- src/Lean/Meta/Tactic/Simp/Rewrite.lean | 2 +- tests/lean/run/simpArith1.lean | 6 ++++++ tests/lean/run/simpCnstr1.lean | 2 +- 6 files changed, 54 insertions(+), 10 deletions(-) diff --git a/src/Init/Data/Nat/Linear.lean b/src/Init/Data/Nat/Linear.lean index 544ac13113..73c5cbf809 100644 --- a/src/Init/Data/Nat/Linear.lean +++ b/src/Init/Data/Nat/Linear.lean @@ -137,8 +137,11 @@ def Expr.toPoly : Expr → Poly | Expr.mulL k a => a.toPoly.mul k | Expr.mulR a k => a.toPoly.mul k +def Poly.norm (p : Poly) : Poly := + p.sort.fuse + def Expr.toNormPoly (e : Expr) : Poly := - e.toPoly.sort.fuse + e.toPoly.norm def Expr.inc (e : Expr) : Expr := Expr.add e (Expr.num 1) @@ -490,7 +493,7 @@ theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e. attribute [local simp] Expr.denote_toPoly theorem Expr.eq_of_toNormPoly (ctx : Context) (a b : Expr) (h : a.toNormPoly = b.toNormPoly) : a.denote ctx = b.denote ctx := by - simp [toNormPoly] at h + simp [toNormPoly, Poly.norm] at h have h := congrArg (Poly.denote ctx) h simp at h assumption @@ -498,13 +501,13 @@ theorem Expr.eq_of_toNormPoly (ctx : Context) (a b : Expr) (h : a.toNormPoly = b theorem Expr.of_cancel_eq (ctx : Context) (a b c d : Expr) (h : Poly.cancel a.toNormPoly b.toNormPoly = (c.toPoly, d.toPoly)) : (a.denote ctx = b.denote ctx) = (c.denote ctx = d.denote ctx) := by have := Poly.denote_eq_cancel_eq ctx a.toNormPoly b.toNormPoly rw [h] at this - simp [toNormPoly, Poly.denote_eq] at this + simp [toNormPoly, Poly.norm, Poly.denote_eq] at this exact this.symm theorem Expr.of_cancel_le (ctx : Context) (a b c d : Expr) (h : Poly.cancel a.toNormPoly b.toNormPoly = (c.toPoly, d.toPoly)) : (a.denote ctx ≤ b.denote ctx) = (c.denote ctx ≤ d.denote ctx) := by have := Poly.denote_le_cancel_eq ctx a.toNormPoly b.toNormPoly rw [h] at this - simp [toNormPoly, Poly.denote_le] at this + simp [toNormPoly, Poly.norm,Poly.denote_le] at this exact this.symm theorem Expr.of_cancel_lt (ctx : Context) (a b c d : Expr) (h : Poly.cancel a.inc.toNormPoly b.toNormPoly = (c.inc.toPoly, d.toPoly)) : (a.denote ctx < b.denote ctx) = (c.denote ctx < d.denote ctx) := @@ -526,8 +529,8 @@ theorem ExprCnstr.denote_toNormPoly (ctx : Context) (c : ExprCnstr) : c.toNormPo cases c; rename_i eq lhs rhs simp [ExprCnstr.denote, PolyCnstr.denote, ExprCnstr.toNormPoly] by_cases h : eq = true <;> simp [h] - . rw [Poly.denote_eq_cancel_eq]; simp [Poly.denote_eq, Expr.toNormPoly] - . rw [Poly.denote_le_cancel_eq]; simp [Poly.denote_le, Expr.toNormPoly] + . rw [Poly.denote_eq_cancel_eq]; simp [Poly.denote_eq, Expr.toNormPoly, Poly.norm] + . rw [Poly.denote_le_cancel_eq]; simp [Poly.denote_le, Expr.toNormPoly, Poly.norm] attribute [local simp] ExprCnstr.denote_toNormPoly @@ -671,4 +674,9 @@ theorem ExprCnstr.eq_of_toNormPoly_eq (ctx : Context) (c d : ExprCnstr) (h : c.t simp at h assumption +theorem Expr.eq_of_toNormPoly_eq (ctx : Context) (e e' : Expr) (h : e.toNormPoly == e'.toPoly) : e.denote ctx = e'.denote ctx := by + have h := congrArg (Poly.denote ctx) (eq_of_beq h) + simp [Expr.toNormPoly, Poly.norm] at h + assumption + end Nat.Linear diff --git a/src/Lean/Meta/Tactic/LinearArith/Nat.lean b/src/Lean/Meta/Tactic/LinearArith/Nat.lean index 723a0cc9ca..0d59167217 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Nat.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Nat.lean @@ -200,4 +200,17 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do else simpCnstrPos? e +def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do + let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e) + let p := e.toPoly + let p' := p.norm + if p'.length < p.length then + -- We only return some if monomials were fused + let e' : LinearExpr := p'.toExpr + let p := mkApp4 (mkConst ``Nat.Linear.Expr.eq_of_toNormPoly_eq) (← toContextExpr ctx) (toExpr e) (toExpr e') reflTrue + let r ← e'.toArith ctx + return some (r, p) + else + return none + end Lean.Meta.Linear.Nat diff --git a/src/Lean/Meta/Tactic/LinearArith/Simp.lean b/src/Lean/Meta/Tactic/LinearArith/Simp.lean index fcc2e1fe10..a7ca71b874 100644 --- a/src/Lean/Meta/Tactic/LinearArith/Simp.lean +++ b/src/Lean/Meta/Tactic/LinearArith/Simp.lean @@ -7,6 +7,15 @@ import Lean.Meta.Tactic.LinearArith.Nat namespace Lean.Meta.Linear +/-- Quick filter simpExpr? -/ +private partial def isSimpExprTarget (e : Expr) : Bool := + let f := e.getAppFn + if !f.isConst then + false + else + let n := f.constName! + n == ``HAdd.hAdd || n == ``HMul.hMul || n == ``HSub.hSub || n == ``Nat.succ + /-- Quick filter simpCnstr? -/ private partial def isSimpCnstrTarget (e : Expr) : Bool := let f := e.getAppFn @@ -21,10 +30,18 @@ private partial def isSimpCnstrTarget (e : Expr) : Bool := else false -def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do +private def parentIsTarget (parent? : Option Expr) : Bool := + match parent? with + | none => false + | some parent => isSimpExprTarget parent || isSimpCnstrTarget parent + +def simp? (e : Expr) (parent? : Option Expr) : MetaM (Option (Expr × Expr)) := do + -- TODO: add support for `Int` and arbitrary ordered comm rings if isSimpCnstrTarget e then - -- TODO: add support for `Int` and arbitrary ordered comm rings Nat.simpCnstr? e + else if isSimpExprTarget e && !parentIsTarget parent? then + trace[Meta.Tactic.simp] "arith expr: {e}" + Nat.simpExpr? e else return none diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index fa83201996..7a6fa1050d 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -199,7 +199,7 @@ def rewriteUsingDecide? (e : Expr) : MetaM (Option Result) := withReducibleAndIn def simpArith? (e : Expr) : SimpM (Option Step) := do if !(← read).config.arith then return none - let some (e', h) ← Linear.simpCnstr? e | return none + let some (e', h) ← Linear.simp? e (← read).parent? | return none return Step.visit { expr := e', proof? := h } def rewritePre (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM Step := do diff --git a/tests/lean/run/simpArith1.lean b/tests/lean/run/simpArith1.lean index f3ca3b4b36..31c2a1dc4f 100644 --- a/tests/lean/run/simpArith1.lean +++ b/tests/lean/run/simpArith1.lean @@ -16,3 +16,9 @@ theorem ex5 (h : a + d + b > b + 1 + (a + (c + c) + d)) : False := by simp_arith at h #print ex5 + +theorem ex6 (p : Nat → Prop) (h : p (a + 1 + a + 2 + b)) : p (2*a + b + 3) := by + simp_arith at h + assumption + +#print ex6 diff --git a/tests/lean/run/simpCnstr1.lean b/tests/lean/run/simpCnstr1.lean index 7e5bdcecfd..e55b1074b7 100644 --- a/tests/lean/run/simpCnstr1.lean +++ b/tests/lean/run/simpCnstr1.lean @@ -4,7 +4,7 @@ open Lean in open Lean.Meta in def test (declName : Name) : MetaM Unit := do let info ← getConstInfo declName forallTelescope info.type fun _ e => do - let some (e', p) ← Linear.simpCnstr? e | throwError "failed to simplify{indentExpr e}" + let some (e', p) ← Linear.simp? e none | throwError "failed to simplify{indentExpr e}" check p unless (← isDefEq (← inferType p) (← mkEq e e')) do throwError "invalid proof"