feat: simplify nested arith expressions

This commit is contained in:
Leonardo de Moura 2022-02-27 08:59:55 -08:00
parent c5baf759e2
commit 89f88b1caa
6 changed files with 54 additions and 10 deletions

View file

@ -137,8 +137,11 @@ def Expr.toPoly : Expr → Poly
| Expr.mulL k a => a.toPoly.mul k
| Expr.mulR a k => a.toPoly.mul k
def Poly.norm (p : Poly) : Poly :=
p.sort.fuse
def Expr.toNormPoly (e : Expr) : Poly :=
e.toPoly.sort.fuse
e.toPoly.norm
def Expr.inc (e : Expr) : Expr :=
Expr.add e (Expr.num 1)
@ -490,7 +493,7 @@ theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.
attribute [local simp] Expr.denote_toPoly
theorem Expr.eq_of_toNormPoly (ctx : Context) (a b : Expr) (h : a.toNormPoly = b.toNormPoly) : a.denote ctx = b.denote ctx := by
simp [toNormPoly] at h
simp [toNormPoly, Poly.norm] at h
have h := congrArg (Poly.denote ctx) h
simp at h
assumption
@ -498,13 +501,13 @@ theorem Expr.eq_of_toNormPoly (ctx : Context) (a b : Expr) (h : a.toNormPoly = b
theorem Expr.of_cancel_eq (ctx : Context) (a b c d : Expr) (h : Poly.cancel a.toNormPoly b.toNormPoly = (c.toPoly, d.toPoly)) : (a.denote ctx = b.denote ctx) = (c.denote ctx = d.denote ctx) := by
have := Poly.denote_eq_cancel_eq ctx a.toNormPoly b.toNormPoly
rw [h] at this
simp [toNormPoly, Poly.denote_eq] at this
simp [toNormPoly, Poly.norm, Poly.denote_eq] at this
exact this.symm
theorem Expr.of_cancel_le (ctx : Context) (a b c d : Expr) (h : Poly.cancel a.toNormPoly b.toNormPoly = (c.toPoly, d.toPoly)) : (a.denote ctx ≤ b.denote ctx) = (c.denote ctx ≤ d.denote ctx) := by
have := Poly.denote_le_cancel_eq ctx a.toNormPoly b.toNormPoly
rw [h] at this
simp [toNormPoly, Poly.denote_le] at this
simp [toNormPoly, Poly.norm,Poly.denote_le] at this
exact this.symm
theorem Expr.of_cancel_lt (ctx : Context) (a b c d : Expr) (h : Poly.cancel a.inc.toNormPoly b.toNormPoly = (c.inc.toPoly, d.toPoly)) : (a.denote ctx < b.denote ctx) = (c.denote ctx < d.denote ctx) :=
@ -526,8 +529,8 @@ theorem ExprCnstr.denote_toNormPoly (ctx : Context) (c : ExprCnstr) : c.toNormPo
cases c; rename_i eq lhs rhs
simp [ExprCnstr.denote, PolyCnstr.denote, ExprCnstr.toNormPoly]
by_cases h : eq = true <;> simp [h]
. rw [Poly.denote_eq_cancel_eq]; simp [Poly.denote_eq, Expr.toNormPoly]
. rw [Poly.denote_le_cancel_eq]; simp [Poly.denote_le, Expr.toNormPoly]
. rw [Poly.denote_eq_cancel_eq]; simp [Poly.denote_eq, Expr.toNormPoly, Poly.norm]
. rw [Poly.denote_le_cancel_eq]; simp [Poly.denote_le, Expr.toNormPoly, Poly.norm]
attribute [local simp] ExprCnstr.denote_toNormPoly
@ -671,4 +674,9 @@ theorem ExprCnstr.eq_of_toNormPoly_eq (ctx : Context) (c d : ExprCnstr) (h : c.t
simp at h
assumption
theorem Expr.eq_of_toNormPoly_eq (ctx : Context) (e e' : Expr) (h : e.toNormPoly == e'.toPoly) : e.denote ctx = e'.denote ctx := by
have h := congrArg (Poly.denote ctx) (eq_of_beq h)
simp [Expr.toNormPoly, Poly.norm] at h
assumption
end Nat.Linear

View file

@ -200,4 +200,17 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
else
simpCnstrPos? e
def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e)
let p := e.toPoly
let p' := p.norm
if p'.length < p.length then
-- We only return some if monomials were fused
let e' : LinearExpr := p'.toExpr
let p := mkApp4 (mkConst ``Nat.Linear.Expr.eq_of_toNormPoly_eq) (← toContextExpr ctx) (toExpr e) (toExpr e') reflTrue
let r ← e'.toArith ctx
return some (r, p)
else
return none
end Lean.Meta.Linear.Nat

View file

@ -7,6 +7,15 @@ import Lean.Meta.Tactic.LinearArith.Nat
namespace Lean.Meta.Linear
/-- Quick filter simpExpr? -/
private partial def isSimpExprTarget (e : Expr) : Bool :=
let f := e.getAppFn
if !f.isConst then
false
else
let n := f.constName!
n == ``HAdd.hAdd || n == ``HMul.hMul || n == ``HSub.hSub || n == ``Nat.succ
/-- Quick filter simpCnstr? -/
private partial def isSimpCnstrTarget (e : Expr) : Bool :=
let f := e.getAppFn
@ -21,10 +30,18 @@ private partial def isSimpCnstrTarget (e : Expr) : Bool :=
else
false
def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
private def parentIsTarget (parent? : Option Expr) : Bool :=
match parent? with
| none => false
| some parent => isSimpExprTarget parent || isSimpCnstrTarget parent
def simp? (e : Expr) (parent? : Option Expr) : MetaM (Option (Expr × Expr)) := do
-- TODO: add support for `Int` and arbitrary ordered comm rings
if isSimpCnstrTarget e then
-- TODO: add support for `Int` and arbitrary ordered comm rings
Nat.simpCnstr? e
else if isSimpExprTarget e && !parentIsTarget parent? then
trace[Meta.Tactic.simp] "arith expr: {e}"
Nat.simpExpr? e
else
return none

View file

@ -199,7 +199,7 @@ def rewriteUsingDecide? (e : Expr) : MetaM (Option Result) := withReducibleAndIn
def simpArith? (e : Expr) : SimpM (Option Step) := do
if !(← read).config.arith then return none
let some (e', h) ← Linear.simpCnstr? e | return none
let some (e', h) ← Linear.simp? e (← read).parent? | return none
return Step.visit { expr := e', proof? := h }
def rewritePre (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM Step := do

View file

@ -16,3 +16,9 @@ theorem ex5 (h : a + d + b > b + 1 + (a + (c + c) + d)) : False := by
simp_arith at h
#print ex5
theorem ex6 (p : Nat → Prop) (h : p (a + 1 + a + 2 + b)) : p (2*a + b + 3) := by
simp_arith at h
assumption
#print ex6

View file

@ -4,7 +4,7 @@ open Lean in open Lean.Meta in
def test (declName : Name) : MetaM Unit := do
let info ← getConstInfo declName
forallTelescope info.type fun _ e => do
let some (e', p) ← Linear.simpCnstr? e | throwError "failed to simplify{indentExpr e}"
let some (e', p) ← Linear.simp? e none | throwError "failed to simplify{indentExpr e}"
check p
unless (← isDefEq (← inferType p) (← mkEq e e')) do
throwError "invalid proof"