From 0db4ac18e5bad56717125c916f5b87e90d430526 Mon Sep 17 00:00:00 2001 From: Kyle Miller Date: Thu, 21 May 2026 00:26:00 -0700 Subject: [PATCH] feat: beta reduce while elaborating applications (#13807) This PR modifies the app elaborator to beta reduce arguments while substituting them into expected types for later arguments. This makes it consistent with `inferType` and `instantiateMVars`, which both beta reduce substitutions. In particular, this change ensures that the app elaborator behaves as if it creates metavariables for each parameter and assigns elaborated arguments to the metavariables. **Breaking change:** tactic proofs may need to be modified to remove unnecessary steps, e.g. `dsimp only` steps that were previously for beta reductions. --- src/Init/Data/Dyadic/Inv.lean | 2 +- src/Init/Data/Int/Cooper.lean | 2 +- src/Init/Data/Int/Lemmas.lean | 3 ++- src/Init/Data/Int/Linear.lean | 4 ++-- src/Init/Data/String/Decode.lean | 8 ++++---- src/Init/Grind/FieldNormNum.lean | 3 ++- src/Init/Grind/Module/Envelope.lean | 2 +- src/Init/Grind/Ordered/Ring.lean | 2 +- src/Init/Grind/Ring/Basic.lean | 3 ++- src/Init/Grind/Ring/CommSolver.lean | 6 +++--- src/Init/Grind/Ring/Envelope.lean | 2 +- src/Init/Grind/Ring/Field.lean | 7 ++++--- src/Lean/Elab/App.lean | 6 +++--- src/Std/Data/Iterators/Lemmas/Equivalence/HetT.lean | 2 +- tests/elab/270.lean | 3 --- tests/elab/4144.lean | 2 +- tests/elab/KyleAlg.lean | 1 - tests/elab/KyleAlgAbbrev.lean | 1 - 18 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/Init/Data/Dyadic/Inv.lean b/src/Init/Data/Dyadic/Inv.lean index 5f26d2d03f..84ec9c3a88 100644 --- a/src/Init/Data/Dyadic/Inv.lean +++ b/src/Init/Data/Dyadic/Inv.lean @@ -71,7 +71,7 @@ theorem eq_toDyadic_of_precision_le {q : Rat} {y : Dyadic} {prec : Int} -- Multiplied form: `y.toRat * 2 ^ prec` equals its own floor cast back. have hL : y.toRat * (2 : Rat) ^ prec = (((y.toRat * 2 ^ prec).floor : Int) : Rat) := by have := congrArg (· * (2 : Rat) ^ prec) hcan - simp only at this + try simp only at this -- TODO(kmill) remove after stage0 update rwa [Rat.div_mul_cancel h2ne] at this -- Multiply `h1`, `h2` by `2 ^ prec`. have h1' : y.toRat * 2 ^ prec ≤ q * 2 ^ prec := diff --git a/src/Init/Data/Int/Cooper.lean b/src/Init/Data/Int/Cooper.lean index 1646aedc22..15f49a11fe 100644 --- a/src/Init/Data/Int/Cooper.lean +++ b/src/Init/Data/Int/Cooper.lean @@ -27,7 +27,7 @@ theorem dvd_of_mul_dvd {a b c : Int} (w : a * b ∣ a * c) (h : 0 < a) : b ∣ c obtain ⟨z, w⟩ := w refine ⟨z, ?_⟩ replace w := congrArg (· / a) w - dsimp at w + try dsimp at w -- TODO(kmill): remove after stage0 update rwa [Int.mul_ediv_cancel_left _ (Int.ne_of_gt h), Int.mul_assoc, Int.mul_ediv_cancel_left _ (Int.ne_of_gt h)] at w diff --git a/src/Init/Data/Int/Lemmas.lean b/src/Init/Data/Int/Lemmas.lean index 8a7f1327e6..be4123f91d 100644 --- a/src/Init/Data/Int/Lemmas.lean +++ b/src/Init/Data/Int/Lemmas.lean @@ -364,7 +364,8 @@ protected theorem subNatNat_eq_coe {m n : Nat} : subNatNat m n = ↑m - ↑n := rw [← Int.subNatNat_eq_coe] refine subNatNat_elim m n (fun m n i => toNat i = m - n) (fun i n => ?_) (fun i n => ?_) · exact (Nat.add_sub_cancel_left ..).symm - · dsimp; rw [Nat.add_assoc, Nat.sub_eq_zero_of_le (Nat.le_add_right ..)]; rfl + · try dsimp -- TODO(kmill) remove after stage0 update + rw [Nat.add_assoc, Nat.sub_eq_zero_of_le (Nat.le_add_right ..)]; rfl theorem toNat_of_nonpos : ∀ {z : Int}, z ≤ 0 → z.toNat = 0 | 0, _ => rfl diff --git a/src/Init/Data/Int/Linear.lean b/src/Init/Data/Int/Linear.lean index 2324fba2eb..49e4740b42 100644 --- a/src/Init/Data/Int/Linear.lean +++ b/src/Init/Data/Int/Linear.lean @@ -1757,11 +1757,11 @@ private theorem ex_of_dvd {α β a b d x : Int} rw [one_emod_eq_one h₀] at h₂ assumption have : ((α * a) * x) % d = (- α * b) % d := by - replace h₁ := congrArg (α * ·) h₁; simp only at h₁ + replace h₁ := congrArg (α * ·) h₁; try simp only at h₁ -- TODO(kmill): remove simp after stage0 update rw [Int.mul_add] at h₁ replace h₁ := congrArg (· - α * b) h₁; simp only [Int.add_sub_cancel] at h₁ rw [← Int.mul_assoc, Int.mul_left_comm, Int.sub_eq_add_neg] at h₁ - replace h₁ := congrArg (· % d) h₁; simp only at h₁ + replace h₁ := congrArg (· % d) h₁; try simp only at h₁ -- TODO(kmill): remove simp after stage0 update rw [Int.add_emod, Int.mul_emod_right, Int.zero_add, Int.emod_emod, ← Int.neg_mul] at h₁ assumption have : x % d = (- α * b) % d := by diff --git a/src/Init/Data/String/Decode.lean b/src/Init/Data/String/Decode.lean index 6b780e2e42..51be24496d 100644 --- a/src/Init/Data/String/Decode.lean +++ b/src/Init/Data/String/Decode.lean @@ -384,25 +384,25 @@ theorem parseFirstByte_eq_invalid_of_isInvalidContinuationByte_eq_false {b : UIn | .done => rw [toBitVec_eq_of_parseFirstByte_eq_done h] at hb have := congrArg (·[7]) hb - simp only at this + try simp only at this -- TODO(kmill) remove after stage0 update rw [BitVec.getElem_append, BitVec.getElem_append] at this simp at this | .oneMore => rw [toBitVec_eq_of_parseFirstByte_eq_oneMore h] at hb have := congrArg (·[6]) hb - simp only at this + try simp only at this -- TODO(kmill) remove after stage0 update rw [BitVec.getElem_append, BitVec.getElem_append] at this simp at this | .twoMore => rw [toBitVec_eq_of_parseFirstByte_eq_twoMore h] at hb have := congrArg (·[6]) hb - simp only at this + try simp only at this -- TODO(kmill) remove after stage0 update rw [BitVec.getElem_append, BitVec.getElem_append] at this simp at this | .threeMore => rw [toBitVec_eq_of_parseFirstByte_eq_threeMore h] at hb have := congrArg (·[6]) hb - simp only at this + try simp only at this -- TODO(kmill) remove after stage0 update rw [BitVec.getElem_append, BitVec.getElem_append] at this simp at this | .invalid => rfl diff --git a/src/Init/Grind/FieldNormNum.lean b/src/Init/Grind/FieldNormNum.lean index d5e1cec6d4..f72e8e4a90 100644 --- a/src/Init/Grind/FieldNormNum.lean +++ b/src/Init/Grind/FieldNormNum.lean @@ -33,7 +33,8 @@ private theorem nonzero_helper {α} [Field α] {z : Int} {n m : Nat} (hn : (n : have : z.natAbs.gcd (n * m) ∣ (n * m) := Nat.gcd_dvd_right z.natAbs (n * m) obtain ⟨k, hk⟩ := this replace hk := congrArg (fun x : Nat => (x : α)) hk - dsimp at hk + -- TODO(kmill): remove after stage0 update + try dsimp at hk rw [Semiring.natCast_mul, Semiring.natCast_mul, h, Semiring.zero_mul] at hk replace hk := Field.of_mul_eq_zero hk simp_all diff --git a/src/Init/Grind/Module/Envelope.lean b/src/Init/Grind/Module/Envelope.lean index 1c88171701..288c718a6c 100644 --- a/src/Init/Grind/Module/Envelope.lean +++ b/src/Init/Grind/Module/Envelope.lean @@ -47,7 +47,7 @@ theorem r_trans {a b c : α × α} : r α a b → r α b c → r α a c := by simp [r] intro k₁ h₁ k₂ h₂ refine ⟨(k₁ + k₂ + b₁ + b₂), ?_⟩ - replace h₁ := congrArg (· + (b₁ + c₂ + k₂)) h₁; simp at h₁ + replace h₁ := congrArg (· + (b₁ + c₂ + k₂)) h₁; try simp at h₁ -- TODO(kmill) remove simp after stage0 update have haux₁ : a₁ + b₂ + k₁ + (b₁ + c₂ + k₂) = (a₁ + c₂) + (k₁ + k₂ + b₁ + b₂) := by ac_rfl have haux₂ : a₂ + b₁ + k₁ + (b₁ + c₂ + k₂) = (a₂ + c₁) + (k₁ + k₂ + b₁ + b₂) := by rw [h₂]; ac_rfl rw [haux₁, haux₂] at h₁ diff --git a/src/Init/Grind/Ordered/Ring.lean b/src/Init/Grind/Ordered/Ring.lean index 036c94a9e5..b24e73b66a 100644 --- a/src/Init/Grind/Ordered/Ring.lean +++ b/src/Init/Grind/Ordered/Ring.lean @@ -229,7 +229,7 @@ instance [Ring R] [LE R] [LT R] [LawfulOrderLT R] [IsPreorder R] [OrderedRing R] next => rfl next x => rw [Semiring.ofNat_succ] at h - replace h := congrArg (· - 1) h; simp at h + replace h := congrArg (· - 1) h; try simp at h -- TODO(kmill): remove simp after stage0 update rw [Ring.sub_eq_add_neg, Semiring.add_assoc, AddCommGroup.add_neg_cancel, Ring.sub_eq_add_neg, AddCommMonoid.zero_add, Semiring.add_zero] at h have h₁ : (OfNat.ofNat x : R) < 0 := by diff --git a/src/Init/Grind/Ring/Basic.lean b/src/Init/Grind/Ring/Basic.lean index 88e430ad35..dcedc3ebd1 100644 --- a/src/Init/Grind/Ring/Basic.lean +++ b/src/Init/Grind/Ring/Basic.lean @@ -601,7 +601,8 @@ theorem no_int_zero_divisors {α : Type u} [IntModule α] [NoNatZeroDivisors α] rw [IntModule.neg_zsmul] intro _ h replace h := congrArg (-·) h - dsimp only at h + -- TODO(kmill): remove after stage0 update + try dsimp only at h rw [neg_neg, neg_zero] at h rw [IntModule.zsmul_natCast_eq_nsmul] at h exact NoNatZeroDivisors.eq_zero_of_mul_eq_zero (Nat.succ_ne_zero _) h diff --git a/src/Init/Grind/Ring/CommSolver.lean b/src/Init/Grind/Ring/CommSolver.lean index 300e30781e..103f1fef41 100644 --- a/src/Init/Grind/Ring/CommSolver.lean +++ b/src/Init/Grind/Ring/CommSolver.lean @@ -1915,11 +1915,11 @@ theorem eq_normEq0 {α} [CommRing α] (ctx : Context α) (c : Nat) (p₁ p₂ p theorem gcd_eq_0 [CommRing α] (g n m a b : Int) (h : g = a * n + b * m) (h₁ : Int.cast (R := α) n = 0) (h₂ : Int.cast (R := α) m = 0) : Int.cast (R := α) g = 0 := by rw [← Ring.intCast_ofNat] at * - replace h₁ := congrArg (Int.cast (R := α) a * ·) h₁; simp at h₁ + replace h₁ := congrArg (Int.cast (R := α) a * ·) h₁; try simp at h₁ -- TODO(kmill): remove simp after stage0 update rw [← Ring.intCast_mul, Ring.intCast_zero, Semiring.mul_zero] at h₁ - replace h₂ := congrArg (Int.cast (R := α) b * ·) h₂; simp at h₂ + replace h₂ := congrArg (Int.cast (R := α) b * ·) h₂; try simp at h₂ -- TODO(kmill): remove simp after stage0 update rw [← Ring.intCast_mul, Ring.intCast_zero, Semiring.mul_zero] at h₂ - replace h₁ := congrArg (· + Int.cast (b * m)) h₁; simp at h₁ + replace h₁ := congrArg (· + Int.cast (b * m)) h₁; try simp at h₁ -- TODO(kmill): remove simp after stage0 update rw [← Ring.intCast_add, h₂, zero_add, ← h] at h₁ rw [Ring.intCast_zero, h₁] diff --git a/src/Init/Grind/Ring/Envelope.lean b/src/Init/Grind/Ring/Envelope.lean index bb22f3f561..004160d5b1 100644 --- a/src/Init/Grind/Ring/Envelope.lean +++ b/src/Init/Grind/Ring/Envelope.lean @@ -50,7 +50,7 @@ theorem r_trans {a b c : α × α} : r α a b → r α b c → r α a c := by simp [r] intro k₁ h₁ k₂ h₂ refine ⟨(k₁ + k₂ + b₁ + b₂), ?_⟩ - replace h₁ := congrArg (· + (b₁ + c₂ + k₂)) h₁; simp at h₁ + replace h₁ := congrArg (· + (b₁ + c₂ + k₂)) h₁; try simp at h₁ -- TODO(kmill): remove simp after stage0 update have haux₁ : a₁ + b₂ + k₁ + (b₁ + c₂ + k₂) = (a₁ + c₂) + (k₁ + k₂ + b₁ + b₂) := by ac_rfl have haux₂ : a₂ + b₁ + k₁ + (b₁ + c₂ + k₂) = (a₂ + c₁) + (k₁ + k₂ + b₁ + b₂) := by rw [h₂]; ac_rfl rw [haux₁, haux₂] at h₁ diff --git a/src/Init/Grind/Ring/Field.lean b/src/Init/Grind/Ring/Field.lean index 9f7f466b39..c8006a4a86 100644 --- a/src/Init/Grind/Ring/Field.lean +++ b/src/Init/Grind/Ring/Field.lean @@ -90,7 +90,8 @@ theorem inv_eq_zero_iff {a : α} : a⁻¹ = 0 ↔ a = 0 := by · subst h rfl · have := congrArg (fun x => x * a) w - dsimp at this + -- TODO(kmill): remove after stage0 update + try dsimp at this rw [Semiring.zero_mul, inv_mul_cancel h] at this exfalso exact zero_ne_one this.symm @@ -122,7 +123,7 @@ theorem inv_mul (a b : α) : (a*b)⁻¹ = a⁻¹*b⁻¹ := by replace h₁ := Field.inv_mul_cancel h₁ replace h₂ := Field.inv_mul_cancel h₂ replace h₃ := Field.mul_inv_cancel h₃ - replace h₃ := congrArg (b⁻¹*a⁻¹* ·) h₃; simp at h₃ + replace h₃ := congrArg (b⁻¹*a⁻¹* ·) h₃; try simp at h₃ -- TODO(kmill): remove simp after stage0 update rw [Semiring.mul_assoc, Semiring.mul_assoc, ← Semiring.mul_assoc (a⁻¹), h₁, Semiring.one_mul, ← Semiring.mul_assoc, h₂, Semiring.one_mul, Semiring.mul_one, CommRing.mul_comm (b⁻¹)] at h₃ assumption @@ -135,7 +136,7 @@ theorem of_pow_eq_zero (a : α) (n : Nat) : a^n = 0 → a = 0 := by apply Classical.byContradiction intro hne have := Field.mul_inv_cancel hne - replace h := congrArg (· * a⁻¹) h; simp at h + replace h := congrArg (· * a⁻¹) h; try simp at h -- TODO(kmill): remove simp after stage0 update rw [Semiring.mul_assoc, this, Semiring.mul_one, Semiring.zero_mul] at h have := ih h contradiction diff --git a/src/Lean/Elab/App.lean b/src/Lean/Elab/App.lean index 59274eea48..9dd08544e4 100644 --- a/src/Lean/Elab/App.lean +++ b/src/Lean/Elab/App.lean @@ -226,7 +226,7 @@ structure State where /-- Gets `s.fType` with all loose bvars instantiated. -/ @[inline] private def State.getFType (s : State) : Expr := - s.fType.instantiateRevRange 0 s.fArgs.size s.fArgs + s.fType.instantiateBetaRevRange 0 s.fArgs.size s.fArgs abbrev M := ReaderT Context (StateRefT State TermElabM) @@ -240,7 +240,7 @@ private def fTypeIsForall : M Bool := do if let Expr.forallE n d b bi := s.fType then -- Ensure the domain is instantiated, to ensure validity of `getParamType` if d.hasLooseBVars then - let d := d.instantiateRevRange 0 s.fArgs.size s.fArgs + let d := d.instantiateBetaRevRange 0 s.fArgs.size s.fArgs set { s with fType := Expr.forallE n d b bi } return true else @@ -1288,7 +1288,7 @@ partial def main : M Expr := do let .forallE binderName binderType body binderInfo ← whnfForall (← get).fType | finalize let addArgAndContinue (arg : Expr) : M Expr := do - modify fun s => { s with idx := s.idx + 1, f := mkApp s.f arg, fType := body.instantiate1 arg } + modify fun s => { s with idx := s.idx + 1, f := mkApp s.f arg, fType := body.instantiateBetaRevRange 0 1 #[arg] } saveArgInfo arg binderName main let idx := (← get).idx diff --git a/src/Std/Data/Iterators/Lemmas/Equivalence/HetT.lean b/src/Std/Data/Iterators/Lemmas/Equivalence/HetT.lean index 5a1d3208dd..3ecc8fd0f2 100644 --- a/src/Std/Data/Iterators/Lemmas/Equivalence/HetT.lean +++ b/src/Std/Data/Iterators/Lemmas/Equivalence/HetT.lean @@ -109,7 +109,7 @@ theorem Small.of_surjective (α : Type v) {β : Type w} (f : α → β) [Small.{ instance {α : Type v} {β : Type w} {f : α → β} [Small.{u} α] : Small.{u} { b : β // ∃ a, f a = b } := .of_surjective α (fun a => ⟨f a, a, rfl⟩) - (fun b => ⟨b.2.choose, by simp; ext; exact b.2.choose_spec⟩) + (fun b => ⟨b.2.choose, by ext; exact b.2.choose_spec⟩) theorem Small.map {α : Type v} {β : Type w} (P : α → Prop) (f : (a : α) → P a → β) [Small.{u} { a // P a }] : diff --git a/tests/elab/270.lean b/tests/elab/270.lean index 99a6bd8747..a00e0702fd 100644 --- a/tests/elab/270.lean +++ b/tests/elab/270.lean @@ -7,7 +7,6 @@ open CommAddSemigroup theorem addComm3 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := by { have h : b + c = c + b := addComm; have h' := congrArg (a + ·) h; - simp at h'; rw [←addAssoc] at h'; rw [←addAssoc (a := a)] at h'; exact h'; @@ -21,7 +20,6 @@ theorem addComm4 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := b theorem addComm5 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := by { have h : b + c = c + b := addComm; have h' := congrArg (a + ·) h; - simp at h'; rw [←addAssoc] at h'; rw [←addAssoc (a := a)] at h'; exact h'; @@ -30,7 +28,6 @@ theorem addComm5 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := b theorem addComm6 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := by { have h : b + c = c + b := addComm; have h' := congrArg (a + ·) h; - simp at h'; rw [←addAssoc] at h'; rw [←addAssoc] at h'; exact h'; diff --git a/tests/elab/4144.lean b/tests/elab/4144.lean index f2e51810d6..031563e3fe 100644 --- a/tests/elab/4144.lean +++ b/tests/elab/4144.lean @@ -17,7 +17,7 @@ case refine'_4 ⊢ ?refine'_1 case refine'_5 -⊢ ¬(fun x => ?m.9) ?refine'_3 = (fun x => ?m.9) ?refine'_4 +⊢ ¬?m.10 ?refine'_3 = ?m.10 ?refine'_4 -/ #guard_msgs in example : False := by diff --git a/tests/elab/KyleAlg.lean b/tests/elab/KyleAlg.lean index 9c3b567f29..6931733c7a 100644 --- a/tests/elab/KyleAlg.lean +++ b/tests/elab/KyleAlg.lean @@ -142,7 +142,6 @@ theorem addIdemIffZero [AddGroup α] {a : α} : a + a = a ↔ a = 0 := by focus intro h have h' := congrArg (λ x => x + -a) h - simp at h' rw [addAssoc, addNeg, addZero] at h' exact h' focus diff --git a/tests/elab/KyleAlgAbbrev.lean b/tests/elab/KyleAlgAbbrev.lean index e22fb7f68c..6548943a21 100644 --- a/tests/elab/KyleAlgAbbrev.lean +++ b/tests/elab/KyleAlgAbbrev.lean @@ -119,7 +119,6 @@ theorem addIdemIffZero [AddGroup α] {a : α} : a + a = a ↔ a = 0 := by focus intro h have h' := congrArg (λ x => x + -a) h - simp at h' rw [addAssoc, addNeg, addZero] at h' exact h' focus