perf: optimize simp congruence proofs (#11892)

This PR optimizes the construction on congruence proofs in `simp`.
It uses some of the ideas used in `Sym.simp`.
This commit is contained in:
Leonardo de Moura 2026-01-04 11:37:21 -08:00 committed by GitHub
parent 609d99e860
commit cf36ac986d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 136 additions and 48 deletions

View file

@ -884,8 +884,7 @@ def congrDefault (e : Expr) : SimpM Result := do
if let some result ← tryAutoCongrTheorem? e then
result.mkEqTrans (← visitFn result.expr)
else
withParent e <| e.withApp fun f args => do
congrArgs (← simp f) args
withParent e <| simpAppUsingCongr e
/-- Process the given congruence theorem hypothesis. Return true if it made "progress". -/
def processCongrHypothesis (h : Expr) (hType : Expr) : SimpM Bool := do

View file

@ -647,6 +647,91 @@ def congrArgs (r : Result) (args : Array Expr) : SimpM Result := do
i := i + 1
return r
/-- Helper function for `simpAppUsingCongr` -/
private def mkCongrFun' (e : Expr) (r : Result) (a : Expr) : MetaM Result := do
let e' := e.updateApp! r.expr a
match r.proof? with
| none => return { expr := e', proof? := none }
| some hf =>
let α ← inferType a
let u ← getLevel α
let v ← getLevel (← inferType e)
let f := e.appFn!
let .forallE x _ βx _ ← whnfD (← inferType f)
| throwError "failed to build congruence proof, function expected{indentExpr f}"
let β := Lean.mkLambda x .default α βx
return { expr := e', proof? := mkApp6 (mkConst ``congrFun [u, v]) α β f r.expr hf a }
/-- Helper function for `simpAppUsingCongr` -/
private def mkCongrPrefix (declName : Name) (e : Expr) : MetaM Expr := do
let α ← inferType e.appArg!
let u ← getLevel α
let β ← inferType e
let v ← getLevel β
return mkApp2 (mkConst declName [u, v]) α β
/-- Helper function for `simpAppUsingCongr` -/
private def mkCongrArg' (e : Expr) (f : Expr) (r : Result) : MetaM Result := do
let e' := e.updateApp! f r.expr
match r.proof? with
| none => return { expr := e', proof? := none }
| some ha =>
let h ← mkCongrPrefix ``congrArg e
return { expr := e', proof? := mkApp4 h e.appArg! r.expr f ha }
/-- Helper function for `simpAppUsingCongr` -/
private def mkCongr' (e : Expr) (r₁ r₂ : Result) : MetaM Result := do
let e' := e.updateApp! r₁.expr r₂.expr
match r₁.proof?, r₂.proof? with
| none, none => return { expr := e', proof? := none }
| some hf, none =>
let h ← mkCongrPrefix ``congrFun' e
return { expr := e', proof? := mkApp4 h e.appFn! r₁.expr hf r₂.expr }
| none, some ha =>
let h ← mkCongrPrefix ``congrArg e
return { expr := e', proof? := mkApp4 h e.appArg! r₂.expr r₁.expr ha }
| some hf, some ha =>
let h ← mkCongrPrefix ``congr e
return { expr := e', proof? := mkApp6 h e.appFn! r₁.expr e.appArg! r₂.expr hf ha }
/--
Given an application `e`, recursively simplifies its function and arguments and constructs a proof
using `congrArg`, `congrFun`, `congrFun'` and `congr`.
-/
def simpAppUsingCongr (e : Expr) : SimpM Result := do
let f := e.getAppFn
let numArgs := e.getAppNumArgs
let cfg ← getConfig
let infos := (← getFunInfoNArgs f numArgs).paramInfo
let rec visit (e : Expr) (i : Nat) : SimpM Result := do
if i == 0 then
simp f
else
let i := i - 1
let .app f a := e | unreachable!
let fr ← visit f i
if h : i < infos.size then
let info := infos[i]
trace[Debug.Meta.Tactic.simp] "app [{i}] {infos.size} {a} hasFwdDeps: {info.hasFwdDeps}"
if cfg.ground && info.isInstImplicit then
-- We don't visit instance implicit arguments when we are reducing ground terms.
-- Motivation: many instance implicit arguments are ground, and it does not make sense
-- to reduce them if the parent term is not ground.
-- TODO: consider using it as the default behavior.
-- We have considered it at https://github.com/leanprover/lean4/pull/3151
mkCongrFun' e fr a
else if !info.hasFwdDeps then
mkCongr' e fr (← simp a)
else if (← whnfD (← inferType f)).isArrow then
mkCongr' e fr (← simp a)
else
mkCongrFun' e fr (← dsimp a)
else if (← whnfD (← inferType f)).isArrow then
mkCongr' e fr (← simp a)
else
mkCongrFun' e fr (← dsimp a)
visit e numArgs
/--
Retrieve auto-generated congruence lemma for `f`.

View file

@ -10,7 +10,7 @@ foo test
/--
info: theorem test : ∀ (x : Foo✝), f✝ x = 42 :=
fun x => of_eq_true (Eq.trans (congrArg (fun x => x = 42) (Foo.prop✝ x)) (eq_self 42))
fun x => of_eq_true (Eq.trans (congrFun' (congrArg Eq (Foo.prop✝ x)) 42) (eq_self 42))
-/
#guard_msgs in
#print test

View file

@ -16,10 +16,10 @@ trace: [simp] Diagnostics
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
trace: [diag] Diagnostics
[kernel] unfolded declarations (max: 147, num: 3):
[kernel] OfNat.ofNat ↦ 147
[kernel] Add.add ↦ 61
[kernel] HAdd.hAdd ↦ 61
[kernel] unfolded declarations (max: 176, num: 3):
[kernel] OfNat.ofNat ↦ 176
[kernel] Add.add ↦ 60
[kernel] HAdd.hAdd ↦ 60
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
-/
#guard_msgs in
@ -42,10 +42,10 @@ trace: [simp] Diagnostics
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
trace: [diag] Diagnostics
[kernel] unfolded declarations (max: 145, num: 3):
[kernel] OfNat.ofNat ↦ 145
[kernel] Add.add ↦ 59
[kernel] HAdd.hAdd ↦ 59
[kernel] unfolded declarations (max: 174, num: 3):
[kernel] OfNat.ofNat ↦ 174
[kernel] Add.add ↦ 58
[kernel] HAdd.hAdd ↦ 58
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
-/
#guard_msgs in
@ -68,10 +68,10 @@ trace: [simp] Diagnostics
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
trace: [diag] Diagnostics
[kernel] unfolded declarations (max: 145, num: 3):
[kernel] OfNat.ofNat ↦ 145
[kernel] Add.add ↦ 59
[kernel] HAdd.hAdd ↦ 59
[kernel] unfolded declarations (max: 174, num: 3):
[kernel] OfNat.ofNat ↦ 174
[kernel] Add.add ↦ 58
[kernel] HAdd.hAdd ↦ 58
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
-/
#guard_msgs in
@ -91,12 +91,12 @@ trace: [simp] Diagnostics
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
trace: [diag] Diagnostics
[def_eq] heuristic for solving `f a =?= f b` (max: 103, num: 1):
[def_eq] ack ↦ 103
[kernel] unfolded declarations (max: 145, num: 3):
[kernel] OfNat.ofNat ↦ 145
[kernel] Add.add ↦ 59
[kernel] HAdd.hAdd ↦ 59
[def_eq] heuristic for solving `f a =?= f b` (max: 60, num: 1):
[def_eq] ack ↦ 60
[kernel] unfolded declarations (max: 174, num: 3):
[kernel] OfNat.ofNat ↦ 174
[kernel] Add.add ↦ 58
[kernel] HAdd.hAdd ↦ 58
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
-/
#guard_msgs in

View file

@ -9,7 +9,7 @@ theorem ex1 : f (f (x + 1)) = x + 3 := by
info: theorem ex1 : ∀ {x : Nat}, f (f (x + 1)) = x + 3 :=
fun {x} =>
of_eq_true
(Eq.trans (congrArg (fun x_1 => x_1 = x + 3) (Eq.trans (congrArg f (f_eq x)) (f_eq (x + 1)))) (eq_self (x + 1 + 2)))
(Eq.trans (congrFun' (congrArg Eq (Eq.trans (congrArg f (f_eq x)) (f_eq (x + 1)))) (x + 3)) (eq_self (x + 1 + 2)))
-/
#guard_msgs in
#print ex1

View file

@ -24,8 +24,6 @@ h : k = 2008 ^ 2 + 2 ^ 2008
⊢ ((4032064 + 2 ^ 2008) ^ 2 + 2 ^ (4032064 + 2 ^ 2008)) % 10 = 6
---
warning: declaration uses `sorry`
---
error: (kernel) deep recursion detected
-/
#guard_msgs in
example (k : Nat) (h : k = 2008^2 + 2^2008) : (k^2 + 2^k)%10 = 6 := by

View file

@ -10,7 +10,7 @@ theorem ex1 (a b c : α) : f (f a b) c = a := by
info: theorem ex1.{u_1} : ∀ {α : Sort u_1} (a b c : α), f (f a b) c = a :=
fun {α} a b c =>
of_eq_true
(Eq.trans (congrArg (fun x => x = a) (Eq.trans (congrArg (fun x => f x c) (f_Eq a b)) (f_Eq a c))) (eq_self a))
(Eq.trans (congrFun' (congrArg Eq (Eq.trans (congrFun' (congrArg f (f_Eq a b)) c) (f_Eq a c))) a) (eq_self a))
-/
#guard_msgs in
#print ex1
@ -33,7 +33,7 @@ info: theorem ex2 : ∀ (p : Nat → Bool) (x : Nat), p x = true → (if p x = t
fun p x h =>
of_eq_true
(Eq.trans
(congrArg (fun x => x = 1) (ite_cond_eq_true 1 2 (Eq.trans (congrArg (fun x => x = true) h) (eq_self true))))
(congrFun' (congrArg Eq (ite_cond_eq_true 1 2 (Eq.trans (congrFun' (congrArg Eq h) true) (eq_self true)))) 1)
(eq_self 1))
-/
#guard_msgs in

View file

@ -15,7 +15,7 @@ theorem ex5 : (10 = 20) = False :=
/--
info: theorem ex5 : (10 = 20) = False :=
of_eq_true (Eq.trans (congrArg (fun x => x = False) (eq_false_of_decide (Eq.refl false))) (eq_self False))
of_eq_true (Eq.trans (congrFun' (congrArg Eq (eq_false_of_decide (Eq.refl false))) False) (eq_self False))
-/
#guard_msgs in
#print ex5

View file

@ -286,13 +286,15 @@ info: theorem ex3 : ∀ (a b : Int), 6 a + (21 - a) + 3 * (a + 2 * b) + 12
fun a b =>
of_eq_true
(Eq.trans
(congrArg (fun x => x ↔ 2 a + 2 * b + 11)
(id
(norm_dvd_gcd (RArray.branch 1 (RArray.leaf b) (RArray.leaf a)) 6
((((Expr.var 1).add ((Expr.num 21).sub (Expr.var 1))).add
(Expr.mulL 3 ((Expr.var 1).add (Expr.mulL 2 (Expr.var 0))))).add
(Expr.num 12))
2 (Poly.add 1 1 (Poly.add 2 0 (Poly.num 11))) 3 (eagerReduce (Eq.refl true)))))
(congrFun'
(congrArg Iff
(id
(norm_dvd_gcd (RArray.branch 1 (RArray.leaf b) (RArray.leaf a)) 6
((((Expr.var 1).add ((Expr.num 21).sub (Expr.var 1))).add
(Expr.mulL 3 ((Expr.var 1).add (Expr.mulL 2 (Expr.var 0))))).add
(Expr.num 12))
2 (Poly.add 1 1 (Poly.add 2 0 (Poly.num 11))) 3 (eagerReduce (Eq.refl true)))))
(2 a + 2 * b + 11))
(iff_self (2 a + 2 * b + 11)))
-/
#guard_msgs (info) in

View file

@ -11,13 +11,15 @@ theorem ex0 : ∀ (x : Nat),
fun x h =>
Eq.mpr
(id
(congrArg (fun x => x = 1)
(id
(congrFun'
(congrArg Eq
(id
(have_congr' (Nat.zero_add (x * x)) fun y =>
ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a =>
Eq.refl (y + 1))))))
(of_eq_true (Eq.trans (congrArg (fun x => x = 1) (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) (eq_self 1)))
(id
(have_congr' (Nat.zero_add (x * x)) fun y =>
ite_congr (Eq.trans (congrFun' (congrArg Eq h) x) (eq_self x)) (fun a => Eq.refl 1) fun a =>
Eq.refl (y + 1)))))
1))
(of_eq_true (Eq.trans (congrFun' (congrArg Eq (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) 1) (eq_self 1)))
x : Nat
h : f (f x) = x
⊢ (have y := x * x;
@ -31,13 +33,15 @@ theorem ex1 : ∀ (x : Nat),
fun x h =>
Eq.mpr
(id
(congrArg (fun x => x = 1)
(id
(congrFun'
(congrArg Eq
(id
(have_body_congr' (x * x) fun y =>
ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a =>
Eq.refl (y + 1))))))
(of_eq_true (Eq.trans (congrArg (fun x => x = 1) (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) (eq_self 1)))
(id
(have_body_congr' (x * x) fun y =>
ite_congr (Eq.trans (congrFun' (congrArg Eq h) x) (eq_self x)) (fun a => Eq.refl 1) fun a =>
Eq.refl (y + 1)))))
1))
(of_eq_true (Eq.trans (congrFun' (congrArg Eq (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) 1) (eq_self 1)))
x z : Nat
h : f (f x) = x
h' : z = x
@ -51,7 +55,7 @@ theorem ex2 : ∀ (x z : Nat),
y) =
z :=
fun x z h h' =>
Eq.mpr (id (congrArg (fun x => x = z) (id (id (have_val_congr' h)))))
Eq.mpr (id (congrFun' (congrArg Eq (id (id (have_val_congr' h)))) z))
(of_eq_true (Eq.trans (congrArg (Eq x) h') (eq_self x)))
x z : Nat
⊢ (let α := Nat;
@ -69,5 +73,5 @@ theorem ex4 : ∀ (p : Prop),
fun z => p :=
fun p h =>
Eq.mpr
(id (congrArg (fun x => x = fun z => p) (id (id (have_body_congr_dep' 10 fun n => funext fun x => eq_self x)))))
(id (congrFun' (congrArg Eq (id (id (have_body_congr_dep' 10 fun n => funext fun x => eq_self x)))) fun z => p))
(of_eq_true (Eq.trans (congrArg (Eq fun x => True) (funext fun z => eq_true h)) (eq_self fun x => True)))