From 751947482f0b090ed2ebbfaa77f96937843f823f Mon Sep 17 00:00:00 2001 From: Kenny Lau Date: Tue, 22 Jul 2025 09:34:14 +0100 Subject: [PATCH] fix: use let rec for Fin.reverseInduction (#9142) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR changes `Fin.reverseInduction` from using well-founded recursion to using `let rec`, which makes it have better definitional equality. Co-authored by @digama0. See the test below: ```lean namespace Fin /-- The new one. -/ @[elab_as_elim] def reverseInduction' {motive : Fin (n + 1) → Sort _} (last : motive (Fin.last n)) (cast : ∀ i : Fin n, motive i.succ → motive (castSucc i)) (i : Fin (n + 1)) : motive i := let rec go (j : Nat) (h) (h2 : i ≤ j) (x : motive ⟨j, h⟩) : motive i := if hi : i.1 = j then (show i = ⟨j, h⟩ by simp [← hi]) ▸ x else match j with | 0 => by omega | j+1 => go j (by omega) (by omega) (cast ⟨j, by omega⟩ x) go _ _ (by omega) last /-- Same code but using reverseInduction'. -/ @[elab_as_elim] def lastCases' {n : Nat} {motive : Fin (n + 1) → Sort _} (last : motive (Fin.last n)) (cast : ∀ i : Fin n, motive (castSucc i)) (i : Fin (n + 1)) : motive i := reverseInduction' last (fun i _ => cast i) i end Fin theorem foo : (Fin.lastCases (-4) (fun i ↦ (i : Int) * 2 + 1) (2 : Fin 3) : Int) = -4 := rfl #eval (Fin.lastCases (-4) (fun i ↦ (i : Int) * 2 + 1) (2 : Fin 3) : Int) theorem foo' : (Fin.lastCases' (-4) (fun i ↦ (i : Int) * 2 + 1) (2 : Fin 3) : Int) = -4 := rfl #eval (Fin.lastCases' (-4) (fun i ↦ (i : Int) * 2 + 1) (2 : Fin 3) : Int) theorem bar : (Fin.reverseInduction (n := 2) (motive := fun _ ↦ Int) (-4) (fun i _ ↦ (i : Int) * 2 + 1) (2 : Fin 3) : Int) = -4 := rfl #eval (Fin.reverseInduction (n := 2) (motive := fun _ ↦ Int) (-4) (fun i _ ↦ (i : Int) * 2 + 1) (2 : Fin 3) : Int) theorem bar' : (Fin.reverseInduction' (n := 2) (motive := fun _ ↦ Int) (-4) (fun i _ ↦ (i : Int) * 2 + 1) (2 : Fin 3) : Int) = -4 := rfl #eval (Fin.reverseInduction' (n := 2) (motive := fun _ ↦ Int) (-4) (fun i _ ↦ (i : Int) * 2 + 1) (2 : Fin 3) : Int) ``` [Link to Lean 4 Web](https://live.lean-lang.org/#project=lean-nightly&codez=HYQwtgpgzgDiDGEAEAxAlsAUJg9AWjyQBUALZYCAdyQHsKA6JPHTAAQG0IAbEAIwH0QUftzRgAukgAmEAGZIAThABuEBVAgBJYFICu8AC5o6AciQBvMDSOqkALlQYkACmBIA1EgCMASiSAkwiQAZRoFAyR+AF8XHihwhysbZGd0YHpY8OAfH0wkPJd4IXikQAAiJDR7RzdgABokRLRbNHoofXgA+utG5MK4oLby7JcKh1SXN09fPwSupvsAXlykLghwpXaAcxoXACtKgDkQAz9nEhOSACZKisATIiQdk4APSobbQAvyHbqSQEvyac6k8oLJZ5NDyEgjcr0LxIeb3JAGMhuZxQEg0agVWEfL7fJC8ACeSCgYhgSHYgATCJDg8R+QAdpEhHsCkNwNPUjvASHDKGgEYy8gAfJAABhhAD5cQSaJANiBeUgBTt3ND5mKtnDnPjaFKQCcNZKINKTr1wljxZr9SAcY8cnlVfwIi5dVq/BlsPhCEFwMh4DQZLjdOFdETgBtFCo1BptHpDMZgCZGMw2JweAIhCIuGJJDJ5BkAMJCaBmcxuByHAzRSyzZCjJyuDzePyBEJhCLRZwZF6VlypdJFJBZa35ZxGyplCFjWr/boFIr9eDtNBDZxjmsTet/V7IEaLPJKVTqLQ6fRGUzLXvOWS6NwVO3KpDDhflbAQHRVbAIiChCBgJCyGjbBwpBgPZxHmGhQC4eAACwnBeV5IIAZYTDJU2jHEgABUSBXJMJxXNWbgAMx/ChfiwlBCyKLIXCYAAxCoIBcF2QG5vm4HOFBMGXoCiFLshwCoRhWH1i4uFVEghG8ccmDvp+36/jQZgAd2zFgWYbHQS4sFcUhDjEehmF1r4wmVGM4k6XxJFMJB5EKJRNF0QxgFpMpBYQep56cRU3EQrpAkGThxlOKZSDEW+ZAybiIAKJUjn0Lu4YHlGx5IsWsIXCcG7kZpdqISF+SuRxcHZdpwXmXpgmGc4IkmURpWkVZdiwjZVG0co9GMWkcX7pGR4xuM5FpS4GUNT+nFFblg7sRpHn2l5El+L52FGXhYk1ZJ0lKN+vCRQp7WxWGXWHtGJ61sNA3OENsJZQhJWSXlakFYCRU8WZ/H6YtlUBQRq0WWRw1NXZrUOd2nURodSWqSlmHpZ2w1XTl5mMvdU2Fddz03fNb1CR9y1BcRQA) Notice how `rfl` fails for the 1st and 5th tests that use the original `Fin.reverseInduction`, but the 3rd and 7th tests that use the new code in this PR succeed. Closes #9141. --------- Co-authored-by: Markus Himmel --- src/Init/Data/Fin/Lemmas.lean | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/Init/Data/Fin/Lemmas.lean b/src/Init/Data/Fin/Lemmas.lean index a7a1935493..70352acdac 100644 --- a/src/Init/Data/Fin/Lemmas.lean +++ b/src/Init/Data/Fin/Lemmas.lean @@ -927,24 +927,34 @@ For the induction: -/ @[elab_as_elim] def reverseInduction {motive : Fin (n + 1) → Sort _} (last : motive (Fin.last n)) (cast : ∀ i : Fin n, motive i.succ → motive (castSucc i)) (i : Fin (n + 1)) : motive i := - if hi : i = Fin.last n then _root_.cast (congrArg motive hi.symm) last - else - let j : Fin n := ⟨i, Nat.lt_of_le_of_ne (Nat.le_of_lt_succ i.2) fun h => hi (Fin.ext h)⟩ - cast _ (reverseInduction last cast j.succ) -termination_by n + 1 - i -decreasing_by decreasing_with - -- FIXME: we put the proof down here to avoid getting a dummy `have` in the definition - try simp only [Nat.succ_sub_succ_eq_sub] - exact Nat.add_sub_add_right .. ▸ Nat.sub_lt_sub_left i.2 (Nat.lt_succ_self i) + let rec go (j : Nat) (h) (h2 : i ≤ j) (x : motive ⟨j, h⟩) : motive i := + if hi : i.1 = j then _root_.cast (by simp [← hi]) x + else match j with + | 0 => by omega + | j + 1 => go j (by omega) (by omega) (cast ⟨j, by omega⟩ x) + go _ _ (by omega) last @[simp] theorem reverseInduction_last {n : Nat} {motive : Fin (n + 1) → Sort _} {zero succ} : (reverseInduction zero succ (Fin.last n) : motive (Fin.last n)) = zero := by - rw [reverseInduction]; simp + rw [reverseInduction, reverseInduction.go]; simp + +@[simp] theorem reverseInduction_castSucc_aux {n : Nat} {motive : Fin (n + 1) → Sort _} {succ} + (i : Fin n) (j : Nat) (h) (h2 : i.1 < j) (zero : motive ⟨j, h⟩) : + reverseInduction.go (motive := motive) succ i.castSucc j h (Nat.le_of_lt h2) zero = + succ i (reverseInduction.go succ i.succ j h h2 zero) := by + induction j generalizing i with + | zero => omega + | succ j ih => + rw [reverseInduction.go, dif_neg (by exact Nat.ne_of_lt h2)] + by_cases hij : i = j + · subst hij; simp [reverseInduction.go] + dsimp only + rw [ih _ _ (by omega), eq_comm, reverseInduction.go, dif_neg (by change i.1 + 1 ≠ _; omega)] @[simp] theorem reverseInduction_castSucc {n : Nat} {motive : Fin (n + 1) → Sort _} {zero succ} (i : Fin n) : reverseInduction (motive := motive) zero succ (castSucc i) = succ i (reverseInduction zero succ i.succ) := by - rw [reverseInduction, dif_neg (Fin.ne_of_lt (Fin.castSucc_lt_last i))]; rfl + rw [reverseInduction, reverseInduction_castSucc_aux _ _ _ i.isLt, reverseInduction] /-- Proves a statement by cases on the underlying `Nat` value in a `Fin (n + 1)`, checking whether the