perf: optimize simp congruence proofs (#11892)
This PR optimizes the construction on congruence proofs in `simp`. It uses some of the ideas used in `Sym.simp`.
This commit is contained in:
parent
609d99e860
commit
cf36ac986d
10 changed files with 136 additions and 48 deletions
|
|
@ -884,8 +884,7 @@ def congrDefault (e : Expr) : SimpM Result := do
|
|||
if let some result ← tryAutoCongrTheorem? e then
|
||||
result.mkEqTrans (← visitFn result.expr)
|
||||
else
|
||||
withParent e <| e.withApp fun f args => do
|
||||
congrArgs (← simp f) args
|
||||
withParent e <| simpAppUsingCongr e
|
||||
|
||||
/-- Process the given congruence theorem hypothesis. Return true if it made "progress". -/
|
||||
def processCongrHypothesis (h : Expr) (hType : Expr) : SimpM Bool := do
|
||||
|
|
|
|||
|
|
@ -647,6 +647,91 @@ def congrArgs (r : Result) (args : Array Expr) : SimpM Result := do
|
|||
i := i + 1
|
||||
return r
|
||||
|
||||
/-- Helper function for `simpAppUsingCongr` -/
|
||||
private def mkCongrFun' (e : Expr) (r : Result) (a : Expr) : MetaM Result := do
|
||||
let e' := e.updateApp! r.expr a
|
||||
match r.proof? with
|
||||
| none => return { expr := e', proof? := none }
|
||||
| some hf =>
|
||||
let α ← inferType a
|
||||
let u ← getLevel α
|
||||
let v ← getLevel (← inferType e)
|
||||
let f := e.appFn!
|
||||
let .forallE x _ βx _ ← whnfD (← inferType f)
|
||||
| throwError "failed to build congruence proof, function expected{indentExpr f}"
|
||||
let β := Lean.mkLambda x .default α βx
|
||||
return { expr := e', proof? := mkApp6 (mkConst ``congrFun [u, v]) α β f r.expr hf a }
|
||||
|
||||
/-- Helper function for `simpAppUsingCongr` -/
|
||||
private def mkCongrPrefix (declName : Name) (e : Expr) : MetaM Expr := do
|
||||
let α ← inferType e.appArg!
|
||||
let u ← getLevel α
|
||||
let β ← inferType e
|
||||
let v ← getLevel β
|
||||
return mkApp2 (mkConst declName [u, v]) α β
|
||||
|
||||
/-- Helper function for `simpAppUsingCongr` -/
|
||||
private def mkCongrArg' (e : Expr) (f : Expr) (r : Result) : MetaM Result := do
|
||||
let e' := e.updateApp! f r.expr
|
||||
match r.proof? with
|
||||
| none => return { expr := e', proof? := none }
|
||||
| some ha =>
|
||||
let h ← mkCongrPrefix ``congrArg e
|
||||
return { expr := e', proof? := mkApp4 h e.appArg! r.expr f ha }
|
||||
|
||||
/-- Helper function for `simpAppUsingCongr` -/
|
||||
private def mkCongr' (e : Expr) (r₁ r₂ : Result) : MetaM Result := do
|
||||
let e' := e.updateApp! r₁.expr r₂.expr
|
||||
match r₁.proof?, r₂.proof? with
|
||||
| none, none => return { expr := e', proof? := none }
|
||||
| some hf, none =>
|
||||
let h ← mkCongrPrefix ``congrFun' e
|
||||
return { expr := e', proof? := mkApp4 h e.appFn! r₁.expr hf r₂.expr }
|
||||
| none, some ha =>
|
||||
let h ← mkCongrPrefix ``congrArg e
|
||||
return { expr := e', proof? := mkApp4 h e.appArg! r₂.expr r₁.expr ha }
|
||||
| some hf, some ha =>
|
||||
let h ← mkCongrPrefix ``congr e
|
||||
return { expr := e', proof? := mkApp6 h e.appFn! r₁.expr e.appArg! r₂.expr hf ha }
|
||||
|
||||
/--
|
||||
Given an application `e`, recursively simplifies its function and arguments and constructs a proof
|
||||
using `congrArg`, `congrFun`, `congrFun'` and `congr`.
|
||||
-/
|
||||
def simpAppUsingCongr (e : Expr) : SimpM Result := do
|
||||
let f := e.getAppFn
|
||||
let numArgs := e.getAppNumArgs
|
||||
let cfg ← getConfig
|
||||
let infos := (← getFunInfoNArgs f numArgs).paramInfo
|
||||
let rec visit (e : Expr) (i : Nat) : SimpM Result := do
|
||||
if i == 0 then
|
||||
simp f
|
||||
else
|
||||
let i := i - 1
|
||||
let .app f a := e | unreachable!
|
||||
let fr ← visit f i
|
||||
if h : i < infos.size then
|
||||
let info := infos[i]
|
||||
trace[Debug.Meta.Tactic.simp] "app [{i}] {infos.size} {a} hasFwdDeps: {info.hasFwdDeps}"
|
||||
if cfg.ground && info.isInstImplicit then
|
||||
-- We don't visit instance implicit arguments when we are reducing ground terms.
|
||||
-- Motivation: many instance implicit arguments are ground, and it does not make sense
|
||||
-- to reduce them if the parent term is not ground.
|
||||
-- TODO: consider using it as the default behavior.
|
||||
-- We have considered it at https://github.com/leanprover/lean4/pull/3151
|
||||
mkCongrFun' e fr a
|
||||
else if !info.hasFwdDeps then
|
||||
mkCongr' e fr (← simp a)
|
||||
else if (← whnfD (← inferType f)).isArrow then
|
||||
mkCongr' e fr (← simp a)
|
||||
else
|
||||
mkCongrFun' e fr (← dsimp a)
|
||||
else if (← whnfD (← inferType f)).isArrow then
|
||||
mkCongr' e fr (← simp a)
|
||||
else
|
||||
mkCongrFun' e fr (← dsimp a)
|
||||
visit e numArgs
|
||||
|
||||
/--
|
||||
Retrieve auto-generated congruence lemma for `f`.
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ foo test
|
|||
|
||||
/--
|
||||
info: theorem test : ∀ (x : Foo✝), f✝ x = 42 :=
|
||||
fun x => of_eq_true (Eq.trans (congrArg (fun x => x = 42) (Foo.prop✝ x)) (eq_self 42))
|
||||
fun x => of_eq_true (Eq.trans (congrFun' (congrArg Eq (Foo.prop✝ x)) 42) (eq_self 42))
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print test
|
||||
|
|
|
|||
|
|
@ -16,10 +16,10 @@ trace: [simp] Diagnostics
|
|||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
trace: [diag] Diagnostics
|
||||
[kernel] unfolded declarations (max: 147, num: 3):
|
||||
[kernel] OfNat.ofNat ↦ 147
|
||||
[kernel] Add.add ↦ 61
|
||||
[kernel] HAdd.hAdd ↦ 61
|
||||
[kernel] unfolded declarations (max: 176, num: 3):
|
||||
[kernel] OfNat.ofNat ↦ 176
|
||||
[kernel] Add.add ↦ 60
|
||||
[kernel] HAdd.hAdd ↦ 60
|
||||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
-/
|
||||
#guard_msgs in
|
||||
|
|
@ -42,10 +42,10 @@ trace: [simp] Diagnostics
|
|||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
trace: [diag] Diagnostics
|
||||
[kernel] unfolded declarations (max: 145, num: 3):
|
||||
[kernel] OfNat.ofNat ↦ 145
|
||||
[kernel] Add.add ↦ 59
|
||||
[kernel] HAdd.hAdd ↦ 59
|
||||
[kernel] unfolded declarations (max: 174, num: 3):
|
||||
[kernel] OfNat.ofNat ↦ 174
|
||||
[kernel] Add.add ↦ 58
|
||||
[kernel] HAdd.hAdd ↦ 58
|
||||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
-/
|
||||
#guard_msgs in
|
||||
|
|
@ -68,10 +68,10 @@ trace: [simp] Diagnostics
|
|||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
trace: [diag] Diagnostics
|
||||
[kernel] unfolded declarations (max: 145, num: 3):
|
||||
[kernel] OfNat.ofNat ↦ 145
|
||||
[kernel] Add.add ↦ 59
|
||||
[kernel] HAdd.hAdd ↦ 59
|
||||
[kernel] unfolded declarations (max: 174, num: 3):
|
||||
[kernel] OfNat.ofNat ↦ 174
|
||||
[kernel] Add.add ↦ 58
|
||||
[kernel] HAdd.hAdd ↦ 58
|
||||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
-/
|
||||
#guard_msgs in
|
||||
|
|
@ -91,12 +91,12 @@ trace: [simp] Diagnostics
|
|||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
trace: [diag] Diagnostics
|
||||
[def_eq] heuristic for solving `f a =?= f b` (max: 103, num: 1):
|
||||
[def_eq] ack ↦ 103
|
||||
[kernel] unfolded declarations (max: 145, num: 3):
|
||||
[kernel] OfNat.ofNat ↦ 145
|
||||
[kernel] Add.add ↦ 59
|
||||
[kernel] HAdd.hAdd ↦ 59
|
||||
[def_eq] heuristic for solving `f a =?= f b` (max: 60, num: 1):
|
||||
[def_eq] ack ↦ 60
|
||||
[kernel] unfolded declarations (max: 174, num: 3):
|
||||
[kernel] OfNat.ofNat ↦ 174
|
||||
[kernel] Add.add ↦ 58
|
||||
[kernel] HAdd.hAdd ↦ 58
|
||||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
-/
|
||||
#guard_msgs in
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ theorem ex1 : f (f (x + 1)) = x + 3 := by
|
|||
info: theorem ex1 : ∀ {x : Nat}, f (f (x + 1)) = x + 3 :=
|
||||
fun {x} =>
|
||||
of_eq_true
|
||||
(Eq.trans (congrArg (fun x_1 => x_1 = x + 3) (Eq.trans (congrArg f (f_eq x)) (f_eq (x + 1)))) (eq_self (x + 1 + 2)))
|
||||
(Eq.trans (congrFun' (congrArg Eq (Eq.trans (congrArg f (f_eq x)) (f_eq (x + 1)))) (x + 3)) (eq_self (x + 1 + 2)))
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print ex1
|
||||
|
|
|
|||
|
|
@ -24,8 +24,6 @@ h : k = 2008 ^ 2 + 2 ^ 2008
|
|||
⊢ ((4032064 + 2 ^ 2008) ^ 2 + 2 ^ (4032064 + 2 ^ 2008)) % 10 = 6
|
||||
---
|
||||
warning: declaration uses `sorry`
|
||||
---
|
||||
error: (kernel) deep recursion detected
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (k : Nat) (h : k = 2008^2 + 2^2008) : (k^2 + 2^k)%10 = 6 := by
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ theorem ex1 (a b c : α) : f (f a b) c = a := by
|
|||
info: theorem ex1.{u_1} : ∀ {α : Sort u_1} (a b c : α), f (f a b) c = a :=
|
||||
fun {α} a b c =>
|
||||
of_eq_true
|
||||
(Eq.trans (congrArg (fun x => x = a) (Eq.trans (congrArg (fun x => f x c) (f_Eq a b)) (f_Eq a c))) (eq_self a))
|
||||
(Eq.trans (congrFun' (congrArg Eq (Eq.trans (congrFun' (congrArg f (f_Eq a b)) c) (f_Eq a c))) a) (eq_self a))
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print ex1
|
||||
|
|
@ -33,7 +33,7 @@ info: theorem ex2 : ∀ (p : Nat → Bool) (x : Nat), p x = true → (if p x = t
|
|||
fun p x h =>
|
||||
of_eq_true
|
||||
(Eq.trans
|
||||
(congrArg (fun x => x = 1) (ite_cond_eq_true 1 2 (Eq.trans (congrArg (fun x => x = true) h) (eq_self true))))
|
||||
(congrFun' (congrArg Eq (ite_cond_eq_true 1 2 (Eq.trans (congrFun' (congrArg Eq h) true) (eq_self true)))) 1)
|
||||
(eq_self 1))
|
||||
-/
|
||||
#guard_msgs in
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ theorem ex5 : (10 = 20) = False :=
|
|||
|
||||
/--
|
||||
info: theorem ex5 : (10 = 20) = False :=
|
||||
of_eq_true (Eq.trans (congrArg (fun x => x = False) (eq_false_of_decide (Eq.refl false))) (eq_self False))
|
||||
of_eq_true (Eq.trans (congrFun' (congrArg Eq (eq_false_of_decide (Eq.refl false))) False) (eq_self False))
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print ex5
|
||||
|
|
|
|||
|
|
@ -286,13 +286,15 @@ info: theorem ex3 : ∀ (a b : Int), 6 ∣ a + (21 - a) + 3 * (a + 2 * b) + 12
|
|||
fun a b =>
|
||||
of_eq_true
|
||||
(Eq.trans
|
||||
(congrArg (fun x => x ↔ 2 ∣ a + 2 * b + 11)
|
||||
(id
|
||||
(norm_dvd_gcd (RArray.branch 1 (RArray.leaf b) (RArray.leaf a)) 6
|
||||
((((Expr.var 1).add ((Expr.num 21).sub (Expr.var 1))).add
|
||||
(Expr.mulL 3 ((Expr.var 1).add (Expr.mulL 2 (Expr.var 0))))).add
|
||||
(Expr.num 12))
|
||||
2 (Poly.add 1 1 (Poly.add 2 0 (Poly.num 11))) 3 (eagerReduce (Eq.refl true)))))
|
||||
(congrFun'
|
||||
(congrArg Iff
|
||||
(id
|
||||
(norm_dvd_gcd (RArray.branch 1 (RArray.leaf b) (RArray.leaf a)) 6
|
||||
((((Expr.var 1).add ((Expr.num 21).sub (Expr.var 1))).add
|
||||
(Expr.mulL 3 ((Expr.var 1).add (Expr.mulL 2 (Expr.var 0))))).add
|
||||
(Expr.num 12))
|
||||
2 (Poly.add 1 1 (Poly.add 2 0 (Poly.num 11))) 3 (eagerReduce (Eq.refl true)))))
|
||||
(2 ∣ a + 2 * b + 11))
|
||||
(iff_self (2 ∣ a + 2 * b + 11)))
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
|
|
|
|||
|
|
@ -11,13 +11,15 @@ theorem ex0 : ∀ (x : Nat),
|
|||
fun x h =>
|
||||
Eq.mpr
|
||||
(id
|
||||
(congrArg (fun x => x = 1)
|
||||
(id
|
||||
(congrFun'
|
||||
(congrArg Eq
|
||||
(id
|
||||
(have_congr' (Nat.zero_add (x * x)) fun y =>
|
||||
ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a =>
|
||||
Eq.refl (y + 1))))))
|
||||
(of_eq_true (Eq.trans (congrArg (fun x => x = 1) (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) (eq_self 1)))
|
||||
(id
|
||||
(have_congr' (Nat.zero_add (x * x)) fun y =>
|
||||
ite_congr (Eq.trans (congrFun' (congrArg Eq h) x) (eq_self x)) (fun a => Eq.refl 1) fun a =>
|
||||
Eq.refl (y + 1)))))
|
||||
1))
|
||||
(of_eq_true (Eq.trans (congrFun' (congrArg Eq (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) 1) (eq_self 1)))
|
||||
x : Nat
|
||||
h : f (f x) = x
|
||||
⊢ (have y := x * x;
|
||||
|
|
@ -31,13 +33,15 @@ theorem ex1 : ∀ (x : Nat),
|
|||
fun x h =>
|
||||
Eq.mpr
|
||||
(id
|
||||
(congrArg (fun x => x = 1)
|
||||
(id
|
||||
(congrFun'
|
||||
(congrArg Eq
|
||||
(id
|
||||
(have_body_congr' (x * x) fun y =>
|
||||
ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a =>
|
||||
Eq.refl (y + 1))))))
|
||||
(of_eq_true (Eq.trans (congrArg (fun x => x = 1) (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) (eq_self 1)))
|
||||
(id
|
||||
(have_body_congr' (x * x) fun y =>
|
||||
ite_congr (Eq.trans (congrFun' (congrArg Eq h) x) (eq_self x)) (fun a => Eq.refl 1) fun a =>
|
||||
Eq.refl (y + 1)))))
|
||||
1))
|
||||
(of_eq_true (Eq.trans (congrFun' (congrArg Eq (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) 1) (eq_self 1)))
|
||||
x z : Nat
|
||||
h : f (f x) = x
|
||||
h' : z = x
|
||||
|
|
@ -51,7 +55,7 @@ theorem ex2 : ∀ (x z : Nat),
|
|||
y) =
|
||||
z :=
|
||||
fun x z h h' =>
|
||||
Eq.mpr (id (congrArg (fun x => x = z) (id (id (have_val_congr' h)))))
|
||||
Eq.mpr (id (congrFun' (congrArg Eq (id (id (have_val_congr' h)))) z))
|
||||
(of_eq_true (Eq.trans (congrArg (Eq x) h') (eq_self x)))
|
||||
x z : Nat
|
||||
⊢ (let α := Nat;
|
||||
|
|
@ -69,5 +73,5 @@ theorem ex4 : ∀ (p : Prop),
|
|||
fun z => p :=
|
||||
fun p h =>
|
||||
Eq.mpr
|
||||
(id (congrArg (fun x => x = fun z => p) (id (id (have_body_congr_dep' 10 fun n => funext fun x => eq_self x)))))
|
||||
(id (congrFun' (congrArg Eq (id (id (have_body_congr_dep' 10 fun n => funext fun x => eq_self x)))) fun z => p))
|
||||
(of_eq_true (Eq.trans (congrArg (Eq fun x => True) (funext fun z => eq_true h)) (eq_self fun x => True)))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue