feat: simplify nested arith expressions
This commit is contained in:
parent
c5baf759e2
commit
89f88b1caa
6 changed files with 54 additions and 10 deletions
|
|
@ -137,8 +137,11 @@ def Expr.toPoly : Expr → Poly
|
|||
| Expr.mulL k a => a.toPoly.mul k
|
||||
| Expr.mulR a k => a.toPoly.mul k
|
||||
|
||||
def Poly.norm (p : Poly) : Poly :=
|
||||
p.sort.fuse
|
||||
|
||||
def Expr.toNormPoly (e : Expr) : Poly :=
|
||||
e.toPoly.sort.fuse
|
||||
e.toPoly.norm
|
||||
|
||||
def Expr.inc (e : Expr) : Expr :=
|
||||
Expr.add e (Expr.num 1)
|
||||
|
|
@ -490,7 +493,7 @@ theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.
|
|||
attribute [local simp] Expr.denote_toPoly
|
||||
|
||||
theorem Expr.eq_of_toNormPoly (ctx : Context) (a b : Expr) (h : a.toNormPoly = b.toNormPoly) : a.denote ctx = b.denote ctx := by
|
||||
simp [toNormPoly] at h
|
||||
simp [toNormPoly, Poly.norm] at h
|
||||
have h := congrArg (Poly.denote ctx) h
|
||||
simp at h
|
||||
assumption
|
||||
|
|
@ -498,13 +501,13 @@ theorem Expr.eq_of_toNormPoly (ctx : Context) (a b : Expr) (h : a.toNormPoly = b
|
|||
theorem Expr.of_cancel_eq (ctx : Context) (a b c d : Expr) (h : Poly.cancel a.toNormPoly b.toNormPoly = (c.toPoly, d.toPoly)) : (a.denote ctx = b.denote ctx) = (c.denote ctx = d.denote ctx) := by
|
||||
have := Poly.denote_eq_cancel_eq ctx a.toNormPoly b.toNormPoly
|
||||
rw [h] at this
|
||||
simp [toNormPoly, Poly.denote_eq] at this
|
||||
simp [toNormPoly, Poly.norm, Poly.denote_eq] at this
|
||||
exact this.symm
|
||||
|
||||
theorem Expr.of_cancel_le (ctx : Context) (a b c d : Expr) (h : Poly.cancel a.toNormPoly b.toNormPoly = (c.toPoly, d.toPoly)) : (a.denote ctx ≤ b.denote ctx) = (c.denote ctx ≤ d.denote ctx) := by
|
||||
have := Poly.denote_le_cancel_eq ctx a.toNormPoly b.toNormPoly
|
||||
rw [h] at this
|
||||
simp [toNormPoly, Poly.denote_le] at this
|
||||
simp [toNormPoly, Poly.norm,Poly.denote_le] at this
|
||||
exact this.symm
|
||||
|
||||
theorem Expr.of_cancel_lt (ctx : Context) (a b c d : Expr) (h : Poly.cancel a.inc.toNormPoly b.toNormPoly = (c.inc.toPoly, d.toPoly)) : (a.denote ctx < b.denote ctx) = (c.denote ctx < d.denote ctx) :=
|
||||
|
|
@ -526,8 +529,8 @@ theorem ExprCnstr.denote_toNormPoly (ctx : Context) (c : ExprCnstr) : c.toNormPo
|
|||
cases c; rename_i eq lhs rhs
|
||||
simp [ExprCnstr.denote, PolyCnstr.denote, ExprCnstr.toNormPoly]
|
||||
by_cases h : eq = true <;> simp [h]
|
||||
. rw [Poly.denote_eq_cancel_eq]; simp [Poly.denote_eq, Expr.toNormPoly]
|
||||
. rw [Poly.denote_le_cancel_eq]; simp [Poly.denote_le, Expr.toNormPoly]
|
||||
. rw [Poly.denote_eq_cancel_eq]; simp [Poly.denote_eq, Expr.toNormPoly, Poly.norm]
|
||||
. rw [Poly.denote_le_cancel_eq]; simp [Poly.denote_le, Expr.toNormPoly, Poly.norm]
|
||||
|
||||
attribute [local simp] ExprCnstr.denote_toNormPoly
|
||||
|
||||
|
|
@ -671,4 +674,9 @@ theorem ExprCnstr.eq_of_toNormPoly_eq (ctx : Context) (c d : ExprCnstr) (h : c.t
|
|||
simp at h
|
||||
assumption
|
||||
|
||||
theorem Expr.eq_of_toNormPoly_eq (ctx : Context) (e e' : Expr) (h : e.toNormPoly == e'.toPoly) : e.denote ctx = e'.denote ctx := by
|
||||
have h := congrArg (Poly.denote ctx) (eq_of_beq h)
|
||||
simp [Expr.toNormPoly, Poly.norm] at h
|
||||
assumption
|
||||
|
||||
end Nat.Linear
|
||||
|
|
|
|||
|
|
@ -200,4 +200,17 @@ def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
|||
else
|
||||
simpCnstrPos? e
|
||||
|
||||
def simpExpr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
let (e, ctx) ← ToLinear.run (ToLinear.toLinearExpr e)
|
||||
let p := e.toPoly
|
||||
let p' := p.norm
|
||||
if p'.length < p.length then
|
||||
-- We only return some if monomials were fused
|
||||
let e' : LinearExpr := p'.toExpr
|
||||
let p := mkApp4 (mkConst ``Nat.Linear.Expr.eq_of_toNormPoly_eq) (← toContextExpr ctx) (toExpr e) (toExpr e') reflTrue
|
||||
let r ← e'.toArith ctx
|
||||
return some (r, p)
|
||||
else
|
||||
return none
|
||||
|
||||
end Lean.Meta.Linear.Nat
|
||||
|
|
|
|||
|
|
@ -7,6 +7,15 @@ import Lean.Meta.Tactic.LinearArith.Nat
|
|||
|
||||
namespace Lean.Meta.Linear
|
||||
|
||||
/-- Quick filter simpExpr? -/
|
||||
private partial def isSimpExprTarget (e : Expr) : Bool :=
|
||||
let f := e.getAppFn
|
||||
if !f.isConst then
|
||||
false
|
||||
else
|
||||
let n := f.constName!
|
||||
n == ``HAdd.hAdd || n == ``HMul.hMul || n == ``HSub.hSub || n == ``Nat.succ
|
||||
|
||||
/-- Quick filter simpCnstr? -/
|
||||
private partial def isSimpCnstrTarget (e : Expr) : Bool :=
|
||||
let f := e.getAppFn
|
||||
|
|
@ -21,10 +30,18 @@ private partial def isSimpCnstrTarget (e : Expr) : Bool :=
|
|||
else
|
||||
false
|
||||
|
||||
def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
private def parentIsTarget (parent? : Option Expr) : Bool :=
|
||||
match parent? with
|
||||
| none => false
|
||||
| some parent => isSimpExprTarget parent || isSimpCnstrTarget parent
|
||||
|
||||
def simp? (e : Expr) (parent? : Option Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
-- TODO: add support for `Int` and arbitrary ordered comm rings
|
||||
if isSimpCnstrTarget e then
|
||||
-- TODO: add support for `Int` and arbitrary ordered comm rings
|
||||
Nat.simpCnstr? e
|
||||
else if isSimpExprTarget e && !parentIsTarget parent? then
|
||||
trace[Meta.Tactic.simp] "arith expr: {e}"
|
||||
Nat.simpExpr? e
|
||||
else
|
||||
return none
|
||||
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ def rewriteUsingDecide? (e : Expr) : MetaM (Option Result) := withReducibleAndIn
|
|||
|
||||
def simpArith? (e : Expr) : SimpM (Option Step) := do
|
||||
if !(← read).config.arith then return none
|
||||
let some (e', h) ← Linear.simpCnstr? e | return none
|
||||
let some (e', h) ← Linear.simp? e (← read).parent? | return none
|
||||
return Step.visit { expr := e', proof? := h }
|
||||
|
||||
def rewritePre (e : Expr) (discharge? : Expr → SimpM (Option Expr)) : SimpM Step := do
|
||||
|
|
|
|||
|
|
@ -16,3 +16,9 @@ theorem ex5 (h : a + d + b > b + 1 + (a + (c + c) + d)) : False := by
|
|||
simp_arith at h
|
||||
|
||||
#print ex5
|
||||
|
||||
theorem ex6 (p : Nat → Prop) (h : p (a + 1 + a + 2 + b)) : p (2*a + b + 3) := by
|
||||
simp_arith at h
|
||||
assumption
|
||||
|
||||
#print ex6
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ open Lean in open Lean.Meta in
|
|||
def test (declName : Name) : MetaM Unit := do
|
||||
let info ← getConstInfo declName
|
||||
forallTelescope info.type fun _ e => do
|
||||
let some (e', p) ← Linear.simpCnstr? e | throwError "failed to simplify{indentExpr e}"
|
||||
let some (e', p) ← Linear.simp? e none | throwError "failed to simplify{indentExpr e}"
|
||||
check p
|
||||
unless (← isDefEq (← inferType p) (← mkEq e e')) do
|
||||
throwError "invalid proof"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue