perf: avoid inferType at simpArith (#9398)
This PR avoids the expensive `inferType` call in `simpArith`. It also cleans up some of the code and removes anti-patterns.
This commit is contained in:
parent
d4afa3caaa
commit
e286f20179
5 changed files with 81 additions and 81 deletions
|
|
@ -9,9 +9,9 @@ import Lean.Meta.Tactic.Simp.Arith.Int
|
|||
|
||||
namespace Lean.Meta.Simp.Arith
|
||||
|
||||
def parentIsTarget (parent? : Option Expr) (isNatExpr : Bool) : Bool :=
|
||||
def parentIsTarget (parent? : Option Expr) : Bool :=
|
||||
match parent? with
|
||||
| none => false
|
||||
| some parent => isLinearTerm parent isNatExpr || isLinearCnstr parent || isDvdCnstr parent
|
||||
| some parent => isLinearTerm parent || isLinearCnstr parent || isDvdCnstr parent
|
||||
|
||||
end Lean.Meta.Simp.Arith
|
||||
|
|
|
|||
|
|
@ -105,36 +105,30 @@ def simpLe? (e : Expr) (checkIfModified : Bool) : MetaM (Option (Expr × Expr))
|
|||
|
||||
def simpRel? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
if let some arg := e.not? then
|
||||
let mut eNew? := none
|
||||
let mut thmName := Name.anonymous
|
||||
let mut eNew? := none
|
||||
let mut h₁ := default
|
||||
match_expr arg with
|
||||
| LE.le α _ lhs rhs =>
|
||||
let_expr Int ← α | pure ()
|
||||
eNew? := some (mkIntLE (mkIntAdd rhs (mkIntLit 1)) lhs)
|
||||
thmName := ``Int.not_le_eq
|
||||
h₁ := mkApp2 (mkConst ``Int.not_le_eq) lhs rhs
|
||||
| GE.ge α _ lhs rhs =>
|
||||
let_expr Int ← α | pure ()
|
||||
eNew? := some (mkIntLE (mkIntAdd lhs (mkIntLit 1)) rhs)
|
||||
thmName := ``Int.not_ge_eq
|
||||
h₁ := mkApp2 (mkConst ``Int.not_ge_eq) lhs rhs
|
||||
| LT.lt α _ lhs rhs =>
|
||||
let_expr Int ← α | pure ()
|
||||
eNew? := some (mkIntLE rhs lhs)
|
||||
thmName := ``Int.not_lt_eq
|
||||
h₁ := mkApp2 (mkConst ``Int.not_lt_eq) lhs rhs
|
||||
| GT.gt α _ lhs rhs =>
|
||||
let_expr Int ← α | pure ()
|
||||
eNew? := some (mkIntLE lhs rhs)
|
||||
thmName := ``Int.not_gt_eq
|
||||
h₁ := mkApp2 (mkConst ``Int.not_gt_eq) lhs rhs
|
||||
| _ => pure ()
|
||||
if let some eNew := eNew? then
|
||||
let h₁ := mkApp2 (mkConst thmName) (arg.getArg! 2) (arg.getArg! 3)
|
||||
-- Already modified
|
||||
if let some (eNew', h₂) ← simpLe? eNew (checkIfModified := false) then
|
||||
let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂
|
||||
return some (eNew', h)
|
||||
else
|
||||
return some (eNew, h₁)
|
||||
else
|
||||
return none
|
||||
let some eNew := eNew? | return none
|
||||
let some (eNew', h₂) ← simpLe? eNew (checkIfModified := false) | return (eNew, h₁)
|
||||
let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂
|
||||
return some (eNew', h)
|
||||
else
|
||||
simpLe? e (checkIfModified := true)
|
||||
|
||||
|
|
|
|||
|
|
@ -35,35 +35,30 @@ def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
|||
|
||||
def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do
|
||||
if let some arg := e.not? then
|
||||
let mut eNew? := none
|
||||
let mut thmName := Name.anonymous
|
||||
let mut eNew? := none
|
||||
let mut h₁ := default
|
||||
match_expr arg with
|
||||
| LE.le α _ _ _ =>
|
||||
if α.isConstOf ``Nat then
|
||||
eNew? := some (mkNatLE (mkNatAdd (arg.getArg! 3) (mkNatLit 1)) (arg.getArg! 2))
|
||||
thmName := ``Nat.not_le_eq
|
||||
| GE.ge α _ _ _ =>
|
||||
if α.isConstOf ``Nat then
|
||||
eNew? := some (mkNatLE (mkNatAdd (arg.getArg! 2) (mkNatLit 1)) (arg.getArg! 3))
|
||||
thmName := ``Nat.not_ge_eq
|
||||
| LT.lt α _ _ _ =>
|
||||
if α.isConstOf ``Nat then
|
||||
eNew? := some (mkNatLE (arg.getArg! 3) (arg.getArg! 2))
|
||||
thmName := ``Nat.not_lt_eq
|
||||
| GT.gt α _ _ _ =>
|
||||
if α.isConstOf ``Nat then
|
||||
eNew? := some (mkNatLE (arg.getArg! 2) (arg.getArg! 3))
|
||||
thmName := ``Nat.not_gt_eq
|
||||
| LE.le α _ a b =>
|
||||
let_expr Nat ← α | pure ()
|
||||
eNew? := some (mkNatLE (mkNatAdd b (mkNatLit 1)) a)
|
||||
h₁ := mkApp2 (mkConst ``Nat.not_le_eq) a b
|
||||
| GE.ge α _ a b =>
|
||||
let_expr Nat ← α | pure ()
|
||||
eNew? := some (mkNatLE (mkNatAdd a (mkNatLit 1)) b)
|
||||
h₁ := mkApp2 (mkConst ``Nat.not_ge_eq) a b
|
||||
| LT.lt α _ a b =>
|
||||
let_expr Nat ← α | pure ()
|
||||
eNew? := some (mkNatLE b a)
|
||||
h₁ := mkApp2 (mkConst ``Nat.not_lt_eq) a b
|
||||
| GT.gt α _ a b =>
|
||||
let_expr Nat ← α | pure ()
|
||||
eNew? := some (mkNatLE a b)
|
||||
h₁ := mkApp2 (mkConst ``Nat.not_gt_eq) a b
|
||||
| _ => pure ()
|
||||
if let some eNew := eNew? then
|
||||
let h₁ := mkApp2 (mkConst thmName) (arg.getArg! 2) (arg.getArg! 3)
|
||||
if let some (eNew', h₂) ← simpCnstrPos? eNew then
|
||||
let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂
|
||||
return some (eNew', h)
|
||||
else
|
||||
return some (eNew, h₁)
|
||||
else
|
||||
return none
|
||||
let some eNew := eNew? | return none
|
||||
let some (eNew', h₂) ← simpCnstrPos? eNew | return (eNew, h₁)
|
||||
let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂
|
||||
return some (eNew', h)
|
||||
else
|
||||
simpCnstrPos? e
|
||||
|
||||
|
|
|
|||
|
|
@ -33,33 +33,48 @@ def withAbstractAtoms (atoms : Array Expr) (type : Name) (k : Array Expr → Met
|
|||
return some (r, p)
|
||||
go 0 #[] #[] #[]
|
||||
|
||||
/-- Quick filter for linear terms. -/
|
||||
def isLinearTerm (e : Expr) (isNatExpr : Bool) : Bool :=
|
||||
let f := e.getAppFn
|
||||
if !f.isConst then
|
||||
false
|
||||
else
|
||||
let n := f.constName!
|
||||
n == ``HAdd.hAdd || n == ``HMul.hMul || n == ``Neg.neg || n == ``Nat.succ
|
||||
|| n == ``Add.add || n == ``Mul.mul
|
||||
-- Recall that `Nat.sub` is truncated
|
||||
|| (!isNatExpr && (n == ``HSub.hSub || n == ``Sub.sub))
|
||||
private def isSupportedType (type : Expr) : Bool :=
|
||||
match_expr type with
|
||||
| Nat => true
|
||||
| Int => true
|
||||
| _ => false
|
||||
|
||||
/-- Quick filter for linear constraints. -/
|
||||
partial def isLinearCnstr (e : Expr) : Bool :=
|
||||
let f := e.getAppFn
|
||||
if !f.isConst then
|
||||
false
|
||||
else
|
||||
let n := f.constName!
|
||||
if n == ``Eq || n == ``LT.lt || n == ``LE.le || n == ``GT.gt || n == ``GE.ge || n == ``Ne then
|
||||
true
|
||||
else if n == ``Not && e.getAppNumArgs == 1 then
|
||||
isLinearCnstr e.appArg!
|
||||
else
|
||||
false
|
||||
private def isSupportedCommRingType (type : Expr) : Bool :=
|
||||
match_expr type with
|
||||
| Int => true
|
||||
| _ => false
|
||||
|
||||
/-- Quick filter for linear terms. -/
|
||||
def isLinearTerm? (e : Expr) : Option Expr :=
|
||||
match_expr e with
|
||||
| HAdd.hAdd α _ _ _ _ _ => .guard isSupportedType α
|
||||
| HMul.hMul α _ _ _ _ _ => .guard isSupportedType α
|
||||
| HSub.hSub α _ _ _ _ _ => .guard isSupportedCommRingType α
|
||||
| Neg.neg α _ _ => .guard isSupportedCommRingType α
|
||||
| Nat.succ _ => some Nat.mkType
|
||||
| _ => none
|
||||
|
||||
def isLinearTerm (e : Expr) : Bool :=
|
||||
isLinearTerm? e |>.isSome
|
||||
|
||||
def isLinearPosCnstr (e : Expr) : Bool :=
|
||||
match_expr e with
|
||||
| Eq α _ _ => isSupportedType α
|
||||
| Ne α _ _ => isSupportedType α
|
||||
| LE.le α _ _ _ => isSupportedType α
|
||||
| LT.lt α _ _ _ => isSupportedType α
|
||||
| GT.gt α _ _ _ => isSupportedType α
|
||||
| GE.ge α _ _ _ => isSupportedType α
|
||||
| _ => false
|
||||
|
||||
def isLinearCnstr (e : Expr) : Bool :=
|
||||
match_expr e with
|
||||
| Not p => isLinearPosCnstr p
|
||||
| _ => isLinearPosCnstr e
|
||||
|
||||
def isDvdCnstr (e : Expr) : Bool :=
|
||||
e.isAppOfArity ``Dvd.dvd 4
|
||||
match_expr e with
|
||||
| Dvd.dvd α _ _ _ => isSupportedType α
|
||||
| _ => false
|
||||
|
||||
end Lean.Meta.Simp.Arith
|
||||
|
|
|
|||
|
|
@ -289,11 +289,6 @@ where
|
|||
catch _ =>
|
||||
return .continue
|
||||
|
||||
private def isNatExpr (e : Expr) : MetaM Bool := do
|
||||
let type ← inferType e
|
||||
let_expr Nat ← type | return false
|
||||
return true
|
||||
|
||||
def simpArith (e : Expr) : SimpM Step := do
|
||||
unless (← getConfig).arith do
|
||||
return .continue
|
||||
|
|
@ -305,18 +300,19 @@ def simpArith (e : Expr) : SimpM Step := do
|
|||
if let some (e', h) ← Arith.Int.simpEq? e then
|
||||
return .visit { expr := e', proof? := h }
|
||||
return .continue
|
||||
let isNat ← isNatExpr e
|
||||
if Arith.isLinearTerm e isNat then
|
||||
if Arith.parentIsTarget (← getContext).parent? isNat then
|
||||
if let some α := Arith.isLinearTerm? e then
|
||||
if Arith.parentIsTarget (← getContext).parent? then
|
||||
-- We mark `cache := false` to ensure we do not miss simplifications.
|
||||
return .continue (some { expr := e, cache := false })
|
||||
if isNat then
|
||||
match_expr α with
|
||||
| Nat =>
|
||||
let some (e', h) ← Arith.Nat.simpExpr? e | pure ()
|
||||
return .visit { expr := e', proof? := h }
|
||||
else
|
||||
| Int =>
|
||||
let some (e', h) ← Arith.Int.simpExpr? e | pure ()
|
||||
return .visit { expr := e', proof? := h }
|
||||
return .continue
|
||||
| _ =>
|
||||
return .continue
|
||||
if Arith.isDvdCnstr e then
|
||||
let some (e', h) ← Arith.Int.simpDvd? e | pure ()
|
||||
return .visit { expr := e', proof? := h }
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue