feat: beta reduce while elaborating applications (#13807)

This PR modifies the app elaborator to beta reduce arguments while
substituting them into expected types for later arguments. This makes it
consistent with `inferType` and `instantiateMVars`, which both beta
reduce substitutions. In particular, this change ensures that the app
elaborator behaves as if it creates metavariables for each parameter and
assigns elaborated arguments to the metavariables. **Breaking change:**
tactic proofs may need to be modified to remove unnecessary steps, e.g.
`dsimp only` steps that were previously for beta reductions.
This commit is contained in:
Kyle Miller 2026-05-21 00:26:00 -07:00 committed by GitHub
parent acfe1d1a4b
commit 0db4ac18e5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 29 additions and 30 deletions

View file

@ -71,7 +71,7 @@ theorem eq_toDyadic_of_precision_le {q : Rat} {y : Dyadic} {prec : Int}
-- Multiplied form: `y.toRat * 2 ^ prec` equals its own floor cast back. -- Multiplied form: `y.toRat * 2 ^ prec` equals its own floor cast back.
have hL : y.toRat * (2 : Rat) ^ prec = (((y.toRat * 2 ^ prec).floor : Int) : Rat) := by have hL : y.toRat * (2 : Rat) ^ prec = (((y.toRat * 2 ^ prec).floor : Int) : Rat) := by
have := congrArg (· * (2 : Rat) ^ prec) hcan have := congrArg (· * (2 : Rat) ^ prec) hcan
simp only at this try simp only at this -- TODO(kmill) remove after stage0 update
rwa [Rat.div_mul_cancel h2ne] at this rwa [Rat.div_mul_cancel h2ne] at this
-- Multiply `h1`, `h2` by `2 ^ prec`. -- Multiply `h1`, `h2` by `2 ^ prec`.
have h1' : y.toRat * 2 ^ prec ≤ q * 2 ^ prec := have h1' : y.toRat * 2 ^ prec ≤ q * 2 ^ prec :=

View file

@ -27,7 +27,7 @@ theorem dvd_of_mul_dvd {a b c : Int} (w : a * b a * c) (h : 0 < a) : b c
obtain ⟨z, w⟩ := w obtain ⟨z, w⟩ := w
refine ⟨z, ?_⟩ refine ⟨z, ?_⟩
replace w := congrArg (· / a) w replace w := congrArg (· / a) w
dsimp at w try dsimp at w -- TODO(kmill): remove after stage0 update
rwa [Int.mul_ediv_cancel_left _ (Int.ne_of_gt h), Int.mul_assoc, rwa [Int.mul_ediv_cancel_left _ (Int.ne_of_gt h), Int.mul_assoc,
Int.mul_ediv_cancel_left _ (Int.ne_of_gt h)] at w Int.mul_ediv_cancel_left _ (Int.ne_of_gt h)] at w

View file

@ -364,7 +364,8 @@ protected theorem subNatNat_eq_coe {m n : Nat} : subNatNat m n = ↑m - ↑n :=
rw [← Int.subNatNat_eq_coe] rw [← Int.subNatNat_eq_coe]
refine subNatNat_elim m n (fun m n i => toNat i = m - n) (fun i n => ?_) (fun i n => ?_) refine subNatNat_elim m n (fun m n i => toNat i = m - n) (fun i n => ?_) (fun i n => ?_)
· exact (Nat.add_sub_cancel_left ..).symm · exact (Nat.add_sub_cancel_left ..).symm
· dsimp; rw [Nat.add_assoc, Nat.sub_eq_zero_of_le (Nat.le_add_right ..)]; rfl · try dsimp -- TODO(kmill) remove after stage0 update
rw [Nat.add_assoc, Nat.sub_eq_zero_of_le (Nat.le_add_right ..)]; rfl
theorem toNat_of_nonpos : ∀ {z : Int}, z ≤ 0 → z.toNat = 0 theorem toNat_of_nonpos : ∀ {z : Int}, z ≤ 0 → z.toNat = 0
| 0, _ => rfl | 0, _ => rfl

View file

@ -1757,11 +1757,11 @@ private theorem ex_of_dvd {α β a b d x : Int}
rw [one_emod_eq_one h₀] at h₂ rw [one_emod_eq_one h₀] at h₂
assumption assumption
have : ((α * a) * x) % d = (- α * b) % d := by have : ((α * a) * x) % d = (- α * b) % d := by
replace h₁ := congrArg (α * ·) h₁; simp only at h₁ replace h₁ := congrArg (α * ·) h₁; try simp only at h₁ -- TODO(kmill): remove simp after stage0 update
rw [Int.mul_add] at h₁ rw [Int.mul_add] at h₁
replace h₁ := congrArg (· - α * b) h₁; simp only [Int.add_sub_cancel] at h₁ replace h₁ := congrArg (· - α * b) h₁; simp only [Int.add_sub_cancel] at h₁
rw [← Int.mul_assoc, Int.mul_left_comm, Int.sub_eq_add_neg] at h₁ rw [← Int.mul_assoc, Int.mul_left_comm, Int.sub_eq_add_neg] at h₁
replace h₁ := congrArg (· % d) h₁; simp only at h₁ replace h₁ := congrArg (· % d) h₁; try simp only at h₁ -- TODO(kmill): remove simp after stage0 update
rw [Int.add_emod, Int.mul_emod_right, Int.zero_add, Int.emod_emod, ← Int.neg_mul] at h₁ rw [Int.add_emod, Int.mul_emod_right, Int.zero_add, Int.emod_emod, ← Int.neg_mul] at h₁
assumption assumption
have : x % d = (- α * b) % d := by have : x % d = (- α * b) % d := by

View file

@ -384,25 +384,25 @@ theorem parseFirstByte_eq_invalid_of_isInvalidContinuationByte_eq_false {b : UIn
| .done => | .done =>
rw [toBitVec_eq_of_parseFirstByte_eq_done h] at hb rw [toBitVec_eq_of_parseFirstByte_eq_done h] at hb
have := congrArg (·[7]) hb have := congrArg (·[7]) hb
simp only at this try simp only at this -- TODO(kmill) remove after stage0 update
rw [BitVec.getElem_append, BitVec.getElem_append] at this rw [BitVec.getElem_append, BitVec.getElem_append] at this
simp at this simp at this
| .oneMore => | .oneMore =>
rw [toBitVec_eq_of_parseFirstByte_eq_oneMore h] at hb rw [toBitVec_eq_of_parseFirstByte_eq_oneMore h] at hb
have := congrArg (·[6]) hb have := congrArg (·[6]) hb
simp only at this try simp only at this -- TODO(kmill) remove after stage0 update
rw [BitVec.getElem_append, BitVec.getElem_append] at this rw [BitVec.getElem_append, BitVec.getElem_append] at this
simp at this simp at this
| .twoMore => | .twoMore =>
rw [toBitVec_eq_of_parseFirstByte_eq_twoMore h] at hb rw [toBitVec_eq_of_parseFirstByte_eq_twoMore h] at hb
have := congrArg (·[6]) hb have := congrArg (·[6]) hb
simp only at this try simp only at this -- TODO(kmill) remove after stage0 update
rw [BitVec.getElem_append, BitVec.getElem_append] at this rw [BitVec.getElem_append, BitVec.getElem_append] at this
simp at this simp at this
| .threeMore => | .threeMore =>
rw [toBitVec_eq_of_parseFirstByte_eq_threeMore h] at hb rw [toBitVec_eq_of_parseFirstByte_eq_threeMore h] at hb
have := congrArg (·[6]) hb have := congrArg (·[6]) hb
simp only at this try simp only at this -- TODO(kmill) remove after stage0 update
rw [BitVec.getElem_append, BitVec.getElem_append] at this rw [BitVec.getElem_append, BitVec.getElem_append] at this
simp at this simp at this
| .invalid => rfl | .invalid => rfl

View file

@ -33,7 +33,8 @@ private theorem nonzero_helper {α} [Field α] {z : Int} {n m : Nat} (hn : (n :
have : z.natAbs.gcd (n * m) (n * m) := Nat.gcd_dvd_right z.natAbs (n * m) have : z.natAbs.gcd (n * m) (n * m) := Nat.gcd_dvd_right z.natAbs (n * m)
obtain ⟨k, hk⟩ := this obtain ⟨k, hk⟩ := this
replace hk := congrArg (fun x : Nat => (x : α)) hk replace hk := congrArg (fun x : Nat => (x : α)) hk
dsimp at hk -- TODO(kmill): remove after stage0 update
try dsimp at hk
rw [Semiring.natCast_mul, Semiring.natCast_mul, h, Semiring.zero_mul] at hk rw [Semiring.natCast_mul, Semiring.natCast_mul, h, Semiring.zero_mul] at hk
replace hk := Field.of_mul_eq_zero hk replace hk := Field.of_mul_eq_zero hk
simp_all simp_all

View file

@ -47,7 +47,7 @@ theorem r_trans {a b c : α × α} : r α a b → r α b c → r α a c := by
simp [r] simp [r]
intro k₁ h₁ k₂ h₂ intro k₁ h₁ k₂ h₂
refine ⟨(k₁ + k₂ + b₁ + b₂), ?_⟩ refine ⟨(k₁ + k₂ + b₁ + b₂), ?_⟩
replace h₁ := congrArg (· + (b₁ + c₂ + k₂)) h₁; simp at h₁ replace h₁ := congrArg (· + (b₁ + c₂ + k₂)) h₁; try simp at h₁ -- TODO(kmill) remove simp after stage0 update
have haux₁ : a₁ + b₂ + k₁ + (b₁ + c₂ + k₂) = (a₁ + c₂) + (k₁ + k₂ + b₁ + b₂) := by ac_rfl have haux₁ : a₁ + b₂ + k₁ + (b₁ + c₂ + k₂) = (a₁ + c₂) + (k₁ + k₂ + b₁ + b₂) := by ac_rfl
have haux₂ : a₂ + b₁ + k₁ + (b₁ + c₂ + k₂) = (a₂ + c₁) + (k₁ + k₂ + b₁ + b₂) := by rw [h₂]; ac_rfl have haux₂ : a₂ + b₁ + k₁ + (b₁ + c₂ + k₂) = (a₂ + c₁) + (k₁ + k₂ + b₁ + b₂) := by rw [h₂]; ac_rfl
rw [haux₁, haux₂] at h₁ rw [haux₁, haux₂] at h₁

View file

@ -229,7 +229,7 @@ instance [Ring R] [LE R] [LT R] [LawfulOrderLT R] [IsPreorder R] [OrderedRing R]
next => rfl next => rfl
next x => next x =>
rw [Semiring.ofNat_succ] at h rw [Semiring.ofNat_succ] at h
replace h := congrArg (· - 1) h; simp at h replace h := congrArg (· - 1) h; try simp at h -- TODO(kmill): remove simp after stage0 update
rw [Ring.sub_eq_add_neg, Semiring.add_assoc, AddCommGroup.add_neg_cancel, rw [Ring.sub_eq_add_neg, Semiring.add_assoc, AddCommGroup.add_neg_cancel,
Ring.sub_eq_add_neg, AddCommMonoid.zero_add, Semiring.add_zero] at h Ring.sub_eq_add_neg, AddCommMonoid.zero_add, Semiring.add_zero] at h
have h₁ : (OfNat.ofNat x : R) < 0 := by have h₁ : (OfNat.ofNat x : R) < 0 := by

View file

@ -601,7 +601,8 @@ theorem no_int_zero_divisors {α : Type u} [IntModule α] [NoNatZeroDivisors α]
rw [IntModule.neg_zsmul] rw [IntModule.neg_zsmul]
intro _ h intro _ h
replace h := congrArg (-·) h replace h := congrArg (-·) h
dsimp only at h -- TODO(kmill): remove after stage0 update
try dsimp only at h
rw [neg_neg, neg_zero] at h rw [neg_neg, neg_zero] at h
rw [IntModule.zsmul_natCast_eq_nsmul] at h rw [IntModule.zsmul_natCast_eq_nsmul] at h
exact NoNatZeroDivisors.eq_zero_of_mul_eq_zero (Nat.succ_ne_zero _) h exact NoNatZeroDivisors.eq_zero_of_mul_eq_zero (Nat.succ_ne_zero _) h

View file

@ -1915,11 +1915,11 @@ theorem eq_normEq0 {α} [CommRing α] (ctx : Context α) (c : Nat) (p₁ p₂ p
theorem gcd_eq_0 [CommRing α] (g n m a b : Int) (h : g = a * n + b * m) theorem gcd_eq_0 [CommRing α] (g n m a b : Int) (h : g = a * n + b * m)
(h₁ : Int.cast (R := α) n = 0) (h₂ : Int.cast (R := α) m = 0) : Int.cast (R := α) g = 0 := by (h₁ : Int.cast (R := α) n = 0) (h₂ : Int.cast (R := α) m = 0) : Int.cast (R := α) g = 0 := by
rw [← Ring.intCast_ofNat] at * rw [← Ring.intCast_ofNat] at *
replace h₁ := congrArg (Int.cast (R := α) a * ·) h₁; simp at h₁ replace h₁ := congrArg (Int.cast (R := α) a * ·) h₁; try simp at h₁ -- TODO(kmill): remove simp after stage0 update
rw [← Ring.intCast_mul, Ring.intCast_zero, Semiring.mul_zero] at h₁ rw [← Ring.intCast_mul, Ring.intCast_zero, Semiring.mul_zero] at h₁
replace h₂ := congrArg (Int.cast (R := α) b * ·) h₂; simp at h₂ replace h₂ := congrArg (Int.cast (R := α) b * ·) h₂; try simp at h₂ -- TODO(kmill): remove simp after stage0 update
rw [← Ring.intCast_mul, Ring.intCast_zero, Semiring.mul_zero] at h₂ rw [← Ring.intCast_mul, Ring.intCast_zero, Semiring.mul_zero] at h₂
replace h₁ := congrArg (· + Int.cast (b * m)) h₁; simp at h₁ replace h₁ := congrArg (· + Int.cast (b * m)) h₁; try simp at h₁ -- TODO(kmill): remove simp after stage0 update
rw [← Ring.intCast_add, h₂, zero_add, ← h] at h₁ rw [← Ring.intCast_add, h₂, zero_add, ← h] at h₁
rw [Ring.intCast_zero, h₁] rw [Ring.intCast_zero, h₁]

View file

@ -50,7 +50,7 @@ theorem r_trans {a b c : α × α} : r α a b → r α b c → r α a c := by
simp [r] simp [r]
intro k₁ h₁ k₂ h₂ intro k₁ h₁ k₂ h₂
refine ⟨(k₁ + k₂ + b₁ + b₂), ?_⟩ refine ⟨(k₁ + k₂ + b₁ + b₂), ?_⟩
replace h₁ := congrArg (· + (b₁ + c₂ + k₂)) h₁; simp at h₁ replace h₁ := congrArg (· + (b₁ + c₂ + k₂)) h₁; try simp at h₁ -- TODO(kmill): remove simp after stage0 update
have haux₁ : a₁ + b₂ + k₁ + (b₁ + c₂ + k₂) = (a₁ + c₂) + (k₁ + k₂ + b₁ + b₂) := by ac_rfl have haux₁ : a₁ + b₂ + k₁ + (b₁ + c₂ + k₂) = (a₁ + c₂) + (k₁ + k₂ + b₁ + b₂) := by ac_rfl
have haux₂ : a₂ + b₁ + k₁ + (b₁ + c₂ + k₂) = (a₂ + c₁) + (k₁ + k₂ + b₁ + b₂) := by rw [h₂]; ac_rfl have haux₂ : a₂ + b₁ + k₁ + (b₁ + c₂ + k₂) = (a₂ + c₁) + (k₁ + k₂ + b₁ + b₂) := by rw [h₂]; ac_rfl
rw [haux₁, haux₂] at h₁ rw [haux₁, haux₂] at h₁

View file

@ -90,7 +90,8 @@ theorem inv_eq_zero_iff {a : α} : a⁻¹ = 0 ↔ a = 0 := by
· subst h · subst h
rfl rfl
· have := congrArg (fun x => x * a) w · have := congrArg (fun x => x * a) w
dsimp at this -- TODO(kmill): remove after stage0 update
try dsimp at this
rw [Semiring.zero_mul, inv_mul_cancel h] at this rw [Semiring.zero_mul, inv_mul_cancel h] at this
exfalso exfalso
exact zero_ne_one this.symm exact zero_ne_one this.symm
@ -122,7 +123,7 @@ theorem inv_mul (a b : α) : (a*b)⁻¹ = a⁻¹*b⁻¹ := by
replace h₁ := Field.inv_mul_cancel h₁ replace h₁ := Field.inv_mul_cancel h₁
replace h₂ := Field.inv_mul_cancel h₂ replace h₂ := Field.inv_mul_cancel h₂
replace h₃ := Field.mul_inv_cancel h₃ replace h₃ := Field.mul_inv_cancel h₃
replace h₃ := congrArg (b⁻¹*a⁻¹* ·) h₃; simp at h₃ replace h₃ := congrArg (b⁻¹*a⁻¹* ·) h₃; try simp at h₃ -- TODO(kmill): remove simp after stage0 update
rw [Semiring.mul_assoc, Semiring.mul_assoc, ← Semiring.mul_assoc (a⁻¹), h₁, Semiring.one_mul, rw [Semiring.mul_assoc, Semiring.mul_assoc, ← Semiring.mul_assoc (a⁻¹), h₁, Semiring.one_mul,
← Semiring.mul_assoc, h₂, Semiring.one_mul, Semiring.mul_one, CommRing.mul_comm (b⁻¹)] at h₃ ← Semiring.mul_assoc, h₂, Semiring.one_mul, Semiring.mul_one, CommRing.mul_comm (b⁻¹)] at h₃
assumption assumption
@ -135,7 +136,7 @@ theorem of_pow_eq_zero (a : α) (n : Nat) : a^n = 0 → a = 0 := by
apply Classical.byContradiction apply Classical.byContradiction
intro hne intro hne
have := Field.mul_inv_cancel hne have := Field.mul_inv_cancel hne
replace h := congrArg (· * a⁻¹) h; simp at h replace h := congrArg (· * a⁻¹) h; try simp at h -- TODO(kmill): remove simp after stage0 update
rw [Semiring.mul_assoc, this, Semiring.mul_one, Semiring.zero_mul] at h rw [Semiring.mul_assoc, this, Semiring.mul_one, Semiring.zero_mul] at h
have := ih h have := ih h
contradiction contradiction

View file

@ -226,7 +226,7 @@ structure State where
/-- Gets `s.fType` with all loose bvars instantiated. -/ /-- Gets `s.fType` with all loose bvars instantiated. -/
@[inline] private def State.getFType (s : State) : Expr := @[inline] private def State.getFType (s : State) : Expr :=
s.fType.instantiateRevRange 0 s.fArgs.size s.fArgs s.fType.instantiateBetaRevRange 0 s.fArgs.size s.fArgs
abbrev M := ReaderT Context (StateRefT State TermElabM) abbrev M := ReaderT Context (StateRefT State TermElabM)
@ -240,7 +240,7 @@ private def fTypeIsForall : M Bool := do
if let Expr.forallE n d b bi := s.fType then if let Expr.forallE n d b bi := s.fType then
-- Ensure the domain is instantiated, to ensure validity of `getParamType` -- Ensure the domain is instantiated, to ensure validity of `getParamType`
if d.hasLooseBVars then if d.hasLooseBVars then
let d := d.instantiateRevRange 0 s.fArgs.size s.fArgs let d := d.instantiateBetaRevRange 0 s.fArgs.size s.fArgs
set { s with fType := Expr.forallE n d b bi } set { s with fType := Expr.forallE n d b bi }
return true return true
else else
@ -1288,7 +1288,7 @@ partial def main : M Expr := do
let .forallE binderName binderType body binderInfo ← whnfForall (← get).fType | let .forallE binderName binderType body binderInfo ← whnfForall (← get).fType |
finalize finalize
let addArgAndContinue (arg : Expr) : M Expr := do let addArgAndContinue (arg : Expr) : M Expr := do
modify fun s => { s with idx := s.idx + 1, f := mkApp s.f arg, fType := body.instantiate1 arg } modify fun s => { s with idx := s.idx + 1, f := mkApp s.f arg, fType := body.instantiateBetaRevRange 0 1 #[arg] }
saveArgInfo arg binderName saveArgInfo arg binderName
main main
let idx := (← get).idx let idx := (← get).idx

View file

@ -109,7 +109,7 @@ theorem Small.of_surjective (α : Type v) {β : Type w} (f : α → β) [Small.{
instance {α : Type v} {β : Type w} {f : α → β} [Small.{u} α] : instance {α : Type v} {β : Type w} {f : α → β} [Small.{u} α] :
Small.{u} { b : β // ∃ a, f a = b } := .of_surjective α (fun a => ⟨f a, a, rfl⟩) Small.{u} { b : β // ∃ a, f a = b } := .of_surjective α (fun a => ⟨f a, a, rfl⟩)
(fun b => ⟨b.2.choose, by simp; ext; exact b.2.choose_spec⟩) (fun b => ⟨b.2.choose, by ext; exact b.2.choose_spec⟩)
theorem Small.map {α : Type v} {β : Type w} (P : α → Prop) (f : (a : α) → P a → β) theorem Small.map {α : Type v} {β : Type w} (P : α → Prop) (f : (a : α) → P a → β)
[Small.{u} { a // P a }] : [Small.{u} { a // P a }] :

View file

@ -7,7 +7,6 @@ open CommAddSemigroup
theorem addComm3 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := by { theorem addComm3 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := by {
have h : b + c = c + b := addComm; have h : b + c = c + b := addComm;
have h' := congrArg (a + ·) h; have h' := congrArg (a + ·) h;
simp at h';
rw [←addAssoc] at h'; rw [←addAssoc] at h';
rw [←addAssoc (a := a)] at h'; rw [←addAssoc (a := a)] at h';
exact h'; exact h';
@ -21,7 +20,6 @@ theorem addComm4 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := b
theorem addComm5 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := by { theorem addComm5 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := by {
have h : b + c = c + b := addComm; have h : b + c = c + b := addComm;
have h' := congrArg (a + ·) h; have h' := congrArg (a + ·) h;
simp at h';
rw [←addAssoc] at h'; rw [←addAssoc] at h';
rw [←addAssoc (a := a)] at h'; rw [←addAssoc (a := a)] at h';
exact h'; exact h';
@ -30,7 +28,6 @@ theorem addComm5 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := b
theorem addComm6 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := by { theorem addComm6 [CommAddSemigroup α] {a b c : α} : a + b + c = a + c + b := by {
have h : b + c = c + b := addComm; have h : b + c = c + b := addComm;
have h' := congrArg (a + ·) h; have h' := congrArg (a + ·) h;
simp at h';
rw [←addAssoc] at h'; rw [←addAssoc] at h';
rw [←addAssoc] at h'; rw [←addAssoc] at h';
exact h'; exact h';

View file

@ -17,7 +17,7 @@ case refine'_4
⊢ ?refine'_1 ⊢ ?refine'_1
case refine'_5 case refine'_5
⊢ ¬(fun x => ?m.9) ?refine'_3 = (fun x => ?m.9) ?refine'_4 ⊢ ¬?m.10 ?refine'_3 = ?m.10 ?refine'_4
-/ -/
#guard_msgs in #guard_msgs in
example : False := by example : False := by

View file

@ -142,7 +142,6 @@ theorem addIdemIffZero [AddGroup α] {a : α} : a + a = a ↔ a = 0 := by
focus focus
intro h intro h
have h' := congrArg (λ x => x + -a) h have h' := congrArg (λ x => x + -a) h
simp at h'
rw [addAssoc, addNeg, addZero] at h' rw [addAssoc, addNeg, addZero] at h'
exact h' exact h'
focus focus

View file

@ -119,7 +119,6 @@ theorem addIdemIffZero [AddGroup α] {a : α} : a + a = a ↔ a = 0 := by
focus focus
intro h intro h
have h' := congrArg (λ x => x + -a) h have h' := congrArg (λ x => x + -a) h
simp at h'
rw [addAssoc, addNeg, addZero] at h' rw [addAssoc, addNeg, addZero] at h'
exact h' exact h'
focus focus