From e286f2017998e6a2be1f1ca178c7ef05679bf2bb Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 15 Jul 2025 20:42:26 -0700 Subject: [PATCH] perf: avoid `inferType` at `simpArith` (#9398) This PR avoids the expensive `inferType` call in `simpArith`. It also cleans up some of the code and removes anti-patterns. --- src/Lean/Meta/Tactic/Simp/Arith.lean | 4 +- src/Lean/Meta/Tactic/Simp/Arith/Int/Simp.lean | 26 +++----- src/Lean/Meta/Tactic/Simp/Arith/Nat/Simp.lean | 49 +++++++------- src/Lean/Meta/Tactic/Simp/Arith/Util.lean | 65 ++++++++++++------- src/Lean/Meta/Tactic/Simp/Rewrite.lean | 18 ++--- 5 files changed, 81 insertions(+), 81 deletions(-) diff --git a/src/Lean/Meta/Tactic/Simp/Arith.lean b/src/Lean/Meta/Tactic/Simp/Arith.lean index bfd993203b..70c84f9ad8 100644 --- a/src/Lean/Meta/Tactic/Simp/Arith.lean +++ b/src/Lean/Meta/Tactic/Simp/Arith.lean @@ -9,9 +9,9 @@ import Lean.Meta.Tactic.Simp.Arith.Int namespace Lean.Meta.Simp.Arith -def parentIsTarget (parent? : Option Expr) (isNatExpr : Bool) : Bool := +def parentIsTarget (parent? : Option Expr) : Bool := match parent? with | none => false - | some parent => isLinearTerm parent isNatExpr || isLinearCnstr parent || isDvdCnstr parent + | some parent => isLinearTerm parent || isLinearCnstr parent || isDvdCnstr parent end Lean.Meta.Simp.Arith diff --git a/src/Lean/Meta/Tactic/Simp/Arith/Int/Simp.lean b/src/Lean/Meta/Tactic/Simp/Arith/Int/Simp.lean index 7c95609690..d9c8298c16 100644 --- a/src/Lean/Meta/Tactic/Simp/Arith/Int/Simp.lean +++ b/src/Lean/Meta/Tactic/Simp/Arith/Int/Simp.lean @@ -105,36 +105,30 @@ def simpLe? (e : Expr) (checkIfModified : Bool) : MetaM (Option (Expr × Expr)) def simpRel? (e : Expr) : MetaM (Option (Expr × Expr)) := do if let some arg := e.not? then - let mut eNew? := none - let mut thmName := Name.anonymous + let mut eNew? := none + let mut h₁ := default match_expr arg with | LE.le α _ lhs rhs => let_expr Int ← α | pure () eNew? := some (mkIntLE (mkIntAdd rhs (mkIntLit 1)) lhs) - thmName := ``Int.not_le_eq + h₁ := mkApp2 (mkConst ``Int.not_le_eq) lhs rhs | GE.ge α _ lhs rhs => let_expr Int ← α | pure () eNew? := some (mkIntLE (mkIntAdd lhs (mkIntLit 1)) rhs) - thmName := ``Int.not_ge_eq + h₁ := mkApp2 (mkConst ``Int.not_ge_eq) lhs rhs | LT.lt α _ lhs rhs => let_expr Int ← α | pure () eNew? := some (mkIntLE rhs lhs) - thmName := ``Int.not_lt_eq + h₁ := mkApp2 (mkConst ``Int.not_lt_eq) lhs rhs | GT.gt α _ lhs rhs => let_expr Int ← α | pure () eNew? := some (mkIntLE lhs rhs) - thmName := ``Int.not_gt_eq + h₁ := mkApp2 (mkConst ``Int.not_gt_eq) lhs rhs | _ => pure () - if let some eNew := eNew? then - let h₁ := mkApp2 (mkConst thmName) (arg.getArg! 2) (arg.getArg! 3) - -- Already modified - if let some (eNew', h₂) ← simpLe? eNew (checkIfModified := false) then - let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂ - return some (eNew', h) - else - return some (eNew, h₁) - else - return none + let some eNew := eNew? | return none + let some (eNew', h₂) ← simpLe? eNew (checkIfModified := false) | return (eNew, h₁) + let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂ + return some (eNew', h) else simpLe? e (checkIfModified := true) diff --git a/src/Lean/Meta/Tactic/Simp/Arith/Nat/Simp.lean b/src/Lean/Meta/Tactic/Simp/Arith/Nat/Simp.lean index 1c04bb152e..48f8226d79 100644 --- a/src/Lean/Meta/Tactic/Simp/Arith/Nat/Simp.lean +++ b/src/Lean/Meta/Tactic/Simp/Arith/Nat/Simp.lean @@ -35,35 +35,30 @@ def simpCnstrPos? (e : Expr) : MetaM (Option (Expr × Expr)) := do def simpCnstr? (e : Expr) : MetaM (Option (Expr × Expr)) := do if let some arg := e.not? then - let mut eNew? := none - let mut thmName := Name.anonymous + let mut eNew? := none + let mut h₁ := default match_expr arg with - | LE.le α _ _ _ => - if α.isConstOf ``Nat then - eNew? := some (mkNatLE (mkNatAdd (arg.getArg! 3) (mkNatLit 1)) (arg.getArg! 2)) - thmName := ``Nat.not_le_eq - | GE.ge α _ _ _ => - if α.isConstOf ``Nat then - eNew? := some (mkNatLE (mkNatAdd (arg.getArg! 2) (mkNatLit 1)) (arg.getArg! 3)) - thmName := ``Nat.not_ge_eq - | LT.lt α _ _ _ => - if α.isConstOf ``Nat then - eNew? := some (mkNatLE (arg.getArg! 3) (arg.getArg! 2)) - thmName := ``Nat.not_lt_eq - | GT.gt α _ _ _ => - if α.isConstOf ``Nat then - eNew? := some (mkNatLE (arg.getArg! 2) (arg.getArg! 3)) - thmName := ``Nat.not_gt_eq + | LE.le α _ a b => + let_expr Nat ← α | pure () + eNew? := some (mkNatLE (mkNatAdd b (mkNatLit 1)) a) + h₁ := mkApp2 (mkConst ``Nat.not_le_eq) a b + | GE.ge α _ a b => + let_expr Nat ← α | pure () + eNew? := some (mkNatLE (mkNatAdd a (mkNatLit 1)) b) + h₁ := mkApp2 (mkConst ``Nat.not_ge_eq) a b + | LT.lt α _ a b => + let_expr Nat ← α | pure () + eNew? := some (mkNatLE b a) + h₁ := mkApp2 (mkConst ``Nat.not_lt_eq) a b + | GT.gt α _ a b => + let_expr Nat ← α | pure () + eNew? := some (mkNatLE a b) + h₁ := mkApp2 (mkConst ``Nat.not_gt_eq) a b | _ => pure () - if let some eNew := eNew? then - let h₁ := mkApp2 (mkConst thmName) (arg.getArg! 2) (arg.getArg! 3) - if let some (eNew', h₂) ← simpCnstrPos? eNew then - let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂ - return some (eNew', h) - else - return some (eNew, h₁) - else - return none + let some eNew := eNew? | return none + let some (eNew', h₂) ← simpCnstrPos? eNew | return (eNew, h₁) + let h := mkApp6 (mkConst ``Eq.trans [levelOne]) (mkSort levelZero) e eNew eNew' h₁ h₂ + return some (eNew', h) else simpCnstrPos? e diff --git a/src/Lean/Meta/Tactic/Simp/Arith/Util.lean b/src/Lean/Meta/Tactic/Simp/Arith/Util.lean index 729c018b42..dbba1ecb0a 100644 --- a/src/Lean/Meta/Tactic/Simp/Arith/Util.lean +++ b/src/Lean/Meta/Tactic/Simp/Arith/Util.lean @@ -33,33 +33,48 @@ def withAbstractAtoms (atoms : Array Expr) (type : Name) (k : Array Expr → Met return some (r, p) go 0 #[] #[] #[] -/-- Quick filter for linear terms. -/ -def isLinearTerm (e : Expr) (isNatExpr : Bool) : Bool := - let f := e.getAppFn - if !f.isConst then - false - else - let n := f.constName! - n == ``HAdd.hAdd || n == ``HMul.hMul || n == ``Neg.neg || n == ``Nat.succ - || n == ``Add.add || n == ``Mul.mul - -- Recall that `Nat.sub` is truncated - || (!isNatExpr && (n == ``HSub.hSub || n == ``Sub.sub)) +private def isSupportedType (type : Expr) : Bool := + match_expr type with + | Nat => true + | Int => true + | _ => false -/-- Quick filter for linear constraints. -/ -partial def isLinearCnstr (e : Expr) : Bool := - let f := e.getAppFn - if !f.isConst then - false - else - let n := f.constName! - if n == ``Eq || n == ``LT.lt || n == ``LE.le || n == ``GT.gt || n == ``GE.ge || n == ``Ne then - true - else if n == ``Not && e.getAppNumArgs == 1 then - isLinearCnstr e.appArg! - else - false +private def isSupportedCommRingType (type : Expr) : Bool := + match_expr type with + | Int => true + | _ => false + +/-- Quick filter for linear terms. -/ +def isLinearTerm? (e : Expr) : Option Expr := + match_expr e with + | HAdd.hAdd α _ _ _ _ _ => .guard isSupportedType α + | HMul.hMul α _ _ _ _ _ => .guard isSupportedType α + | HSub.hSub α _ _ _ _ _ => .guard isSupportedCommRingType α + | Neg.neg α _ _ => .guard isSupportedCommRingType α + | Nat.succ _ => some Nat.mkType + | _ => none + +def isLinearTerm (e : Expr) : Bool := + isLinearTerm? e |>.isSome + +def isLinearPosCnstr (e : Expr) : Bool := + match_expr e with + | Eq α _ _ => isSupportedType α + | Ne α _ _ => isSupportedType α + | LE.le α _ _ _ => isSupportedType α + | LT.lt α _ _ _ => isSupportedType α + | GT.gt α _ _ _ => isSupportedType α + | GE.ge α _ _ _ => isSupportedType α + | _ => false + +def isLinearCnstr (e : Expr) : Bool := + match_expr e with + | Not p => isLinearPosCnstr p + | _ => isLinearPosCnstr e def isDvdCnstr (e : Expr) : Bool := - e.isAppOfArity ``Dvd.dvd 4 + match_expr e with + | Dvd.dvd α _ _ _ => isSupportedType α + | _ => false end Lean.Meta.Simp.Arith diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index 6e39d5a536..0c2b0d3081 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -289,11 +289,6 @@ where catch _ => return .continue -private def isNatExpr (e : Expr) : MetaM Bool := do - let type ← inferType e - let_expr Nat ← type | return false - return true - def simpArith (e : Expr) : SimpM Step := do unless (← getConfig).arith do return .continue @@ -305,18 +300,19 @@ def simpArith (e : Expr) : SimpM Step := do if let some (e', h) ← Arith.Int.simpEq? e then return .visit { expr := e', proof? := h } return .continue - let isNat ← isNatExpr e - if Arith.isLinearTerm e isNat then - if Arith.parentIsTarget (← getContext).parent? isNat then + if let some α := Arith.isLinearTerm? e then + if Arith.parentIsTarget (← getContext).parent? then -- We mark `cache := false` to ensure we do not miss simplifications. return .continue (some { expr := e, cache := false }) - if isNat then + match_expr α with + | Nat => let some (e', h) ← Arith.Nat.simpExpr? e | pure () return .visit { expr := e', proof? := h } - else + | Int => let some (e', h) ← Arith.Int.simpExpr? e | pure () return .visit { expr := e', proof? := h } - return .continue + | _ => + return .continue if Arith.isDvdCnstr e then let some (e', h) ← Arith.Int.simpDvd? e | pure () return .visit { expr := e', proof? := h }