diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 67215c5976..f27329176d 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -884,8 +884,7 @@ def congrDefault (e : Expr) : SimpM Result := do if let some result ← tryAutoCongrTheorem? e then result.mkEqTrans (← visitFn result.expr) else - withParent e <| e.withApp fun f args => do - congrArgs (← simp f) args + withParent e <| simpAppUsingCongr e /-- Process the given congruence theorem hypothesis. Return true if it made "progress". -/ def processCongrHypothesis (h : Expr) (hType : Expr) : SimpM Bool := do diff --git a/src/Lean/Meta/Tactic/Simp/Types.lean b/src/Lean/Meta/Tactic/Simp/Types.lean index b5da990960..c8f7ce971b 100644 --- a/src/Lean/Meta/Tactic/Simp/Types.lean +++ b/src/Lean/Meta/Tactic/Simp/Types.lean @@ -647,6 +647,91 @@ def congrArgs (r : Result) (args : Array Expr) : SimpM Result := do i := i + 1 return r +/-- Helper function for `simpAppUsingCongr` -/ +private def mkCongrFun' (e : Expr) (r : Result) (a : Expr) : MetaM Result := do + let e' := e.updateApp! r.expr a + match r.proof? with + | none => return { expr := e', proof? := none } + | some hf => + let α ← inferType a + let u ← getLevel α + let v ← getLevel (← inferType e) + let f := e.appFn! + let .forallE x _ βx _ ← whnfD (← inferType f) + | throwError "failed to build congruence proof, function expected{indentExpr f}" + let β := Lean.mkLambda x .default α βx + return { expr := e', proof? := mkApp6 (mkConst ``congrFun [u, v]) α β f r.expr hf a } + +/-- Helper function for `simpAppUsingCongr` -/ +private def mkCongrPrefix (declName : Name) (e : Expr) : MetaM Expr := do + let α ← inferType e.appArg! + let u ← getLevel α + let β ← inferType e + let v ← getLevel β + return mkApp2 (mkConst declName [u, v]) α β + +/-- Helper function for `simpAppUsingCongr` -/ +private def mkCongrArg' (e : Expr) (f : Expr) (r : Result) : MetaM Result := do + let e' := e.updateApp! f r.expr + match r.proof? with + | none => return { expr := e', proof? := none } + | some ha => + let h ← mkCongrPrefix ``congrArg e + return { expr := e', proof? := mkApp4 h e.appArg! r.expr f ha } + +/-- Helper function for `simpAppUsingCongr` -/ +private def mkCongr' (e : Expr) (r₁ r₂ : Result) : MetaM Result := do + let e' := e.updateApp! r₁.expr r₂.expr + match r₁.proof?, r₂.proof? with + | none, none => return { expr := e', proof? := none } + | some hf, none => + let h ← mkCongrPrefix ``congrFun' e + return { expr := e', proof? := mkApp4 h e.appFn! r₁.expr hf r₂.expr } + | none, some ha => + let h ← mkCongrPrefix ``congrArg e + return { expr := e', proof? := mkApp4 h e.appArg! r₂.expr r₁.expr ha } + | some hf, some ha => + let h ← mkCongrPrefix ``congr e + return { expr := e', proof? := mkApp6 h e.appFn! r₁.expr e.appArg! r₂.expr hf ha } + +/-- +Given an application `e`, recursively simplifies its function and arguments and constructs a proof +using `congrArg`, `congrFun`, `congrFun'` and `congr`. +-/ +def simpAppUsingCongr (e : Expr) : SimpM Result := do + let f := e.getAppFn + let numArgs := e.getAppNumArgs + let cfg ← getConfig + let infos := (← getFunInfoNArgs f numArgs).paramInfo + let rec visit (e : Expr) (i : Nat) : SimpM Result := do + if i == 0 then + simp f + else + let i := i - 1 + let .app f a := e | unreachable! + let fr ← visit f i + if h : i < infos.size then + let info := infos[i] + trace[Debug.Meta.Tactic.simp] "app [{i}] {infos.size} {a} hasFwdDeps: {info.hasFwdDeps}" + if cfg.ground && info.isInstImplicit then + -- We don't visit instance implicit arguments when we are reducing ground terms. + -- Motivation: many instance implicit arguments are ground, and it does not make sense + -- to reduce them if the parent term is not ground. + -- TODO: consider using it as the default behavior. + -- We have considered it at https://github.com/leanprover/lean4/pull/3151 + mkCongrFun' e fr a + else if !info.hasFwdDeps then + mkCongr' e fr (← simp a) + else if (← whnfD (← inferType f)).isArrow then + mkCongr' e fr (← simp a) + else + mkCongrFun' e fr (← dsimp a) + else if (← whnfD (← inferType f)).isArrow then + mkCongr' e fr (← simp a) + else + mkCongrFun' e fr (← dsimp a) + visit e numArgs + /-- Retrieve auto-generated congruence lemma for `f`. diff --git a/tests/lean/run/793.lean b/tests/lean/run/793.lean index 9760b82d72..cacc486d27 100644 --- a/tests/lean/run/793.lean +++ b/tests/lean/run/793.lean @@ -10,7 +10,7 @@ foo test /-- info: theorem test : ∀ (x : Foo✝), f✝ x = 42 := -fun x => of_eq_true (Eq.trans (congrArg (fun x => x = 42) (Foo.prop✝ x)) (eq_self 42)) +fun x => of_eq_true (Eq.trans (congrFun' (congrArg Eq (Foo.prop✝ x)) 42) (eq_self 42)) -/ #guard_msgs in #print test diff --git a/tests/lean/run/ack.lean b/tests/lean/run/ack.lean index 1490be58c1..7ed74ff009 100644 --- a/tests/lean/run/ack.lean +++ b/tests/lean/run/ack.lean @@ -16,10 +16,10 @@ trace: [simp] Diagnostics use `set_option diagnostics.threshold ` to control threshold for reporting counters --- trace: [diag] Diagnostics - [kernel] unfolded declarations (max: 147, num: 3): - [kernel] OfNat.ofNat ↦ 147 - [kernel] Add.add ↦ 61 - [kernel] HAdd.hAdd ↦ 61 + [kernel] unfolded declarations (max: 176, num: 3): + [kernel] OfNat.ofNat ↦ 176 + [kernel] Add.add ↦ 60 + [kernel] HAdd.hAdd ↦ 60 use `set_option diagnostics.threshold ` to control threshold for reporting counters -/ #guard_msgs in @@ -42,10 +42,10 @@ trace: [simp] Diagnostics use `set_option diagnostics.threshold ` to control threshold for reporting counters --- trace: [diag] Diagnostics - [kernel] unfolded declarations (max: 145, num: 3): - [kernel] OfNat.ofNat ↦ 145 - [kernel] Add.add ↦ 59 - [kernel] HAdd.hAdd ↦ 59 + [kernel] unfolded declarations (max: 174, num: 3): + [kernel] OfNat.ofNat ↦ 174 + [kernel] Add.add ↦ 58 + [kernel] HAdd.hAdd ↦ 58 use `set_option diagnostics.threshold ` to control threshold for reporting counters -/ #guard_msgs in @@ -68,10 +68,10 @@ trace: [simp] Diagnostics use `set_option diagnostics.threshold ` to control threshold for reporting counters --- trace: [diag] Diagnostics - [kernel] unfolded declarations (max: 145, num: 3): - [kernel] OfNat.ofNat ↦ 145 - [kernel] Add.add ↦ 59 - [kernel] HAdd.hAdd ↦ 59 + [kernel] unfolded declarations (max: 174, num: 3): + [kernel] OfNat.ofNat ↦ 174 + [kernel] Add.add ↦ 58 + [kernel] HAdd.hAdd ↦ 58 use `set_option diagnostics.threshold ` to control threshold for reporting counters -/ #guard_msgs in @@ -91,12 +91,12 @@ trace: [simp] Diagnostics use `set_option diagnostics.threshold ` to control threshold for reporting counters --- trace: [diag] Diagnostics - [def_eq] heuristic for solving `f a =?= f b` (max: 103, num: 1): - [def_eq] ack ↦ 103 - [kernel] unfolded declarations (max: 145, num: 3): - [kernel] OfNat.ofNat ↦ 145 - [kernel] Add.add ↦ 59 - [kernel] HAdd.hAdd ↦ 59 + [def_eq] heuristic for solving `f a =?= f b` (max: 60, num: 1): + [def_eq] ack ↦ 60 + [kernel] unfolded declarations (max: 174, num: 3): + [kernel] OfNat.ofNat ↦ 174 + [kernel] Add.add ↦ 58 + [kernel] HAdd.hAdd ↦ 58 use `set_option diagnostics.threshold ` to control threshold for reporting counters -/ #guard_msgs in diff --git a/tests/lean/run/implicitRflProofs.lean b/tests/lean/run/implicitRflProofs.lean index bdc3ee73a2..d99eabb251 100644 --- a/tests/lean/run/implicitRflProofs.lean +++ b/tests/lean/run/implicitRflProofs.lean @@ -9,7 +9,7 @@ theorem ex1 : f (f (x + 1)) = x + 3 := by info: theorem ex1 : ∀ {x : Nat}, f (f (x + 1)) = x + 3 := fun {x} => of_eq_true - (Eq.trans (congrArg (fun x_1 => x_1 = x + 3) (Eq.trans (congrArg f (f_eq x)) (f_eq (x + 1)))) (eq_self (x + 1 + 2))) + (Eq.trans (congrFun' (congrArg Eq (Eq.trans (congrArg f (f_eq x)) (f_eq (x + 1)))) (x + 3)) (eq_self (x + 1 + 2))) -/ #guard_msgs in #print ex1 diff --git a/tests/lean/run/safeExp.lean b/tests/lean/run/safeExp.lean index f17132aa03..9cac3b713a 100644 --- a/tests/lean/run/safeExp.lean +++ b/tests/lean/run/safeExp.lean @@ -24,8 +24,6 @@ h : k = 2008 ^ 2 + 2 ^ 2008 ⊢ ((4032064 + 2 ^ 2008) ^ 2 + 2 ^ (4032064 + 2 ^ 2008)) % 10 = 6 --- warning: declaration uses `sorry` ---- -error: (kernel) deep recursion detected -/ #guard_msgs in example (k : Nat) (h : k = 2008^2 + 2^2008) : (k^2 + 2^k)%10 = 6 := by diff --git a/tests/lean/run/simp5.lean b/tests/lean/run/simp5.lean index 5d77b52d00..ebb49a87a2 100644 --- a/tests/lean/run/simp5.lean +++ b/tests/lean/run/simp5.lean @@ -10,7 +10,7 @@ theorem ex1 (a b c : α) : f (f a b) c = a := by info: theorem ex1.{u_1} : ∀ {α : Sort u_1} (a b c : α), f (f a b) c = a := fun {α} a b c => of_eq_true - (Eq.trans (congrArg (fun x => x = a) (Eq.trans (congrArg (fun x => f x c) (f_Eq a b)) (f_Eq a c))) (eq_self a)) + (Eq.trans (congrFun' (congrArg Eq (Eq.trans (congrFun' (congrArg f (f_Eq a b)) c) (f_Eq a c))) a) (eq_self a)) -/ #guard_msgs in #print ex1 @@ -33,7 +33,7 @@ info: theorem ex2 : ∀ (p : Nat → Bool) (x : Nat), p x = true → (if p x = t fun p x h => of_eq_true (Eq.trans - (congrArg (fun x => x = 1) (ite_cond_eq_true 1 2 (Eq.trans (congrArg (fun x => x = true) h) (eq_self true)))) + (congrFun' (congrArg Eq (ite_cond_eq_true 1 2 (Eq.trans (congrFun' (congrArg Eq h) true) (eq_self true)))) 1) (eq_self 1)) -/ #guard_msgs in diff --git a/tests/lean/run/simp6.lean b/tests/lean/run/simp6.lean index 4cd447f770..39b0ab159a 100644 --- a/tests/lean/run/simp6.lean +++ b/tests/lean/run/simp6.lean @@ -15,7 +15,7 @@ theorem ex5 : (10 = 20) = False := /-- info: theorem ex5 : (10 = 20) = False := -of_eq_true (Eq.trans (congrArg (fun x => x = False) (eq_false_of_decide (Eq.refl false))) (eq_self False)) +of_eq_true (Eq.trans (congrFun' (congrArg Eq (eq_false_of_decide (Eq.refl false))) False) (eq_self False)) -/ #guard_msgs in #print ex5 diff --git a/tests/lean/run/simp_int_arith.lean b/tests/lean/run/simp_int_arith.lean index f0c3d1cdb2..5231493fdd 100644 --- a/tests/lean/run/simp_int_arith.lean +++ b/tests/lean/run/simp_int_arith.lean @@ -286,13 +286,15 @@ info: theorem ex3 : ∀ (a b : Int), 6 ∣ a + (21 - a) + 3 * (a + 2 * b) + 12 fun a b => of_eq_true (Eq.trans - (congrArg (fun x => x ↔ 2 ∣ a + 2 * b + 11) - (id - (norm_dvd_gcd (RArray.branch 1 (RArray.leaf b) (RArray.leaf a)) 6 - ((((Expr.var 1).add ((Expr.num 21).sub (Expr.var 1))).add - (Expr.mulL 3 ((Expr.var 1).add (Expr.mulL 2 (Expr.var 0))))).add - (Expr.num 12)) - 2 (Poly.add 1 1 (Poly.add 2 0 (Poly.num 11))) 3 (eagerReduce (Eq.refl true))))) + (congrFun' + (congrArg Iff + (id + (norm_dvd_gcd (RArray.branch 1 (RArray.leaf b) (RArray.leaf a)) 6 + ((((Expr.var 1).add ((Expr.num 21).sub (Expr.var 1))).add + (Expr.mulL 3 ((Expr.var 1).add (Expr.mulL 2 (Expr.var 0))))).add + (Expr.num 12)) + 2 (Poly.add 1 1 (Poly.add 2 0 (Poly.num 11))) 3 (eagerReduce (Eq.refl true))))) + (2 ∣ a + 2 * b + 11)) (iff_self (2 ∣ a + 2 * b + 11))) -/ #guard_msgs (info) in diff --git a/tests/lean/simpZetaFalse.lean.expected.out b/tests/lean/simpZetaFalse.lean.expected.out index 12652d3f77..029aa9d975 100644 --- a/tests/lean/simpZetaFalse.lean.expected.out +++ b/tests/lean/simpZetaFalse.lean.expected.out @@ -11,13 +11,15 @@ theorem ex0 : ∀ (x : Nat), fun x h => Eq.mpr (id - (congrArg (fun x => x = 1) - (id + (congrFun' + (congrArg Eq (id - (have_congr' (Nat.zero_add (x * x)) fun y => - ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a => - Eq.refl (y + 1)))))) - (of_eq_true (Eq.trans (congrArg (fun x => x = 1) (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) (eq_self 1))) + (id + (have_congr' (Nat.zero_add (x * x)) fun y => + ite_congr (Eq.trans (congrFun' (congrArg Eq h) x) (eq_self x)) (fun a => Eq.refl 1) fun a => + Eq.refl (y + 1))))) + 1)) + (of_eq_true (Eq.trans (congrFun' (congrArg Eq (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) 1) (eq_self 1))) x : Nat h : f (f x) = x ⊢ (have y := x * x; @@ -31,13 +33,15 @@ theorem ex1 : ∀ (x : Nat), fun x h => Eq.mpr (id - (congrArg (fun x => x = 1) - (id + (congrFun' + (congrArg Eq (id - (have_body_congr' (x * x) fun y => - ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a => - Eq.refl (y + 1)))))) - (of_eq_true (Eq.trans (congrArg (fun x => x = 1) (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) (eq_self 1))) + (id + (have_body_congr' (x * x) fun y => + ite_congr (Eq.trans (congrFun' (congrArg Eq h) x) (eq_self x)) (fun a => Eq.refl 1) fun a => + Eq.refl (y + 1))))) + 1)) + (of_eq_true (Eq.trans (congrFun' (congrArg Eq (ite_cond_eq_true 1 (x * x + 1) (Eq.refl True))) 1) (eq_self 1))) x z : Nat h : f (f x) = x h' : z = x @@ -51,7 +55,7 @@ theorem ex2 : ∀ (x z : Nat), y) = z := fun x z h h' => - Eq.mpr (id (congrArg (fun x => x = z) (id (id (have_val_congr' h))))) + Eq.mpr (id (congrFun' (congrArg Eq (id (id (have_val_congr' h)))) z)) (of_eq_true (Eq.trans (congrArg (Eq x) h') (eq_self x))) x z : Nat ⊢ (let α := Nat; @@ -69,5 +73,5 @@ theorem ex4 : ∀ (p : Prop), fun z => p := fun p h => Eq.mpr - (id (congrArg (fun x => x = fun z => p) (id (id (have_body_congr_dep' 10 fun n => funext fun x => eq_self x))))) + (id (congrFun' (congrArg Eq (id (id (have_body_congr_dep' 10 fun n => funext fun x => eq_self x)))) fun z => p)) (of_eq_true (Eq.trans (congrArg (Eq fun x => True) (funext fun z => eq_true h)) (eq_self fun x => True)))