lean4-htt/src/Init/Data/Fin/Fold.lean
Eric Wieser ae1ab94992
fix: replace bad simp lemmas for Id (#7352)
This PR reworks the `simp` set around the `Id` monad so that it no longer elides or
unfolds `pure` and `Id.run`.

In particular, it stops encoding the "defeq abuse" of `Id X = X` in the
statements of theorems, instead using `Id.run` and `pure` to pass back
and forth between these two spellings. Often when writing these with
`pure`, they generalize to other lawful monads; though such changes were
split off to other PRs.

This fixes the problem with the current simp set where `Id.run (pure x)`
is simplified to `Id.run x`, instead of the desirable `x`.
This is particularly bad because the `x` is sometimes inferred with type
`Id X` instead of `X`, which prevents other `simp` lemmas about `X` from
firing.

Making `Id` reducible instead is not an option, as then the `Monad`
instances would have nothing to key on.

---------

Co-authored-by: Sebastian Graf <sg@lean-fro.org>
Co-authored-by: Kim Morrison <kim@tqft.net>
Co-authored-by: Paul Reichert <6992158+datokrat@users.noreply.github.com>
2025-05-22 22:45:35 +00:00

344 lines
12 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2023 François G. Dorais. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: François G. Dorais
-/
module
prelude
import Init.Data.Nat.Linear
import Init.Control.Lawful.Basic
import Init.Data.Fin.Lemmas
namespace Fin
/--
Combine all the values that can be represented by `Fin n` with an initial value, starting at `0` and
nesting to the left.

Example:
 * `Fin.foldl 3 (· + ·.val) (0 : Nat) = ((0 + (0 : Fin 3).val) + (1 : Fin 3).val) + (2 : Fin 3).val`
-/
@[inline] def foldl (n) (f : α → Fin n → α) (init : α) : α := loop init 0 where
  /--
  Inner loop for `Fin.foldl`. `Fin.foldl.loop n f x i = f (f (f x i) ...) (n-1)`

  While `i < n` it combines the accumulator `x` with `⟨i, h⟩` via `f` and continues at `i+1`;
  once `i ≥ n` it returns `x` unchanged.
  -/
  @[semireducible, specialize] loop (x : α) (i : Nat) : α :=
    if h : i < n then loop (f x ⟨i, h⟩) (i+1) else x
  -- `n - i` strictly decreases on the recursive call, which only happens under `h : i < n`.
  termination_by n - i
/--
Combine all the values that can be represented by `Fin n` with an initial value, starting at `n - 1`
and nesting to the right.

Example:
 * `Fin.foldr 3 (·.val + ·) (0 : Nat) = (0 : Fin 3).val + ((1 : Fin 3).val + ((2 : Fin 3).val + 0))`
-/
@[inline] def foldr (n) (f : Fin n → α → α) (init : α) : α := loop n (Nat.le_refl n) init where
  /--
  Inner loop for `Fin.foldr`. `Fin.foldr.loop n f i x = f 0 (f ... (f (i-1) x))`

  The bound `i ≤ n` is carried along so that every index `⟨i, h⟩` built on the way down to `0`
  is a valid `Fin n`. Recursion is structural on `i`.
  -/
  @[specialize] loop : (i : _) → i ≤ n → α → α
    | 0, _, x => x
    -- `h : i+1 ≤ n` doubles as the bound `i < n` for `⟨i, h⟩ : Fin n`.
    | i+1, h, x => loop i (Nat.le_of_lt h) (f ⟨i, h⟩ x)
  termination_by structural i => i
/--
Folds a monadic function over all the values in `Fin n` from left to right, starting with `0`.

It is the sequence of steps:
```
Fin.foldlM n f x₀ = do
  let x₁ ← f x₀ 0
  let x₂ ← f x₁ 1
  ...
  let xₙ ← f xₙ₋₁ (n-1)
  pure xₙ
```
-/
@[inline] def foldlM [Monad m] (n) (f : α → Fin n → m α) (init : α) : m α := loop init 0 where
  /--
  Inner loop for `Fin.foldlM`.

  ```
  Fin.foldlM.loop n f xᵢ i = do
    let xᵢ₊₁ ← f xᵢ i
    ...
    let xₙ ← f xₙ₋₁ (n-1)
    pure xₙ
  ```
  -/
  @[semireducible, specialize] loop (x : α) (i : Nat) : m α := do
    -- While `i < n`, bind the step at `⟨i, h⟩` and continue at `i+1`; otherwise return `x`.
    if h : i < n then f x ⟨i, h⟩ >>= (loop · (i+1)) else pure x
  termination_by n - i
  -- Discharges `n - (i+1) < n - i` for the recursive call, which is made under `h : i < n`.
  decreasing_by decreasing_trivial_pre_omega
/--
Folds a monadic function over `Fin n` from right to left, starting with `n-1`.

It is the sequence of steps:
```
Fin.foldrM n f xₙ = do
  let xₙ₋₁ ← f (n-1) xₙ
  let xₙ₋₂ ← f (n-2) xₙ₋₁
  ...
  let x₀ ← f 0 x₁
  pure x₀
```
-/
@[inline] def foldrM [Monad m] (n) (f : Fin n → α → m α) (init : α) : m α :=
  loop ⟨n, Nat.le_refl n⟩ init where
  /--
  Inner loop for `Fin.foldrM`.

  ```
  Fin.foldrM.loop n f i xᵢ = do
    let xᵢ₋₁ ← f (i-1) xᵢ
    ...
    let x₁ ← f 1 x₂
    let x₀ ← f 0 x₁
    pure x₀
  ```
  -/
  @[semireducible, specialize] loop : {i // i ≤ n} → α → m α
    | ⟨0, _⟩, x => pure x
    -- `h : i+1 ≤ n` doubles as the bound `i < n` for `⟨i, h⟩ : Fin n`.
    | ⟨i+1, h⟩, x => f ⟨i, h⟩ x >>= loop ⟨i, Nat.le_of_lt h⟩
/-! ### foldlM -/
/--
Congruence for `foldlM` along an equation `w : n = k` between the bounds. The step function is
transported by casting each index back with `Fin.cast w.symm`.
-/
@[congr] theorem foldlM_congr [Monad m] {n k : Nat} (w : n = k) (f : α → Fin n → m α) :
    foldlM n f = foldlM k (fun x i => f x (i.cast w.symm)) := by
  cases w
  rfl
/--
Unfolds one step of `foldlM.loop` while the index is in range (`i < n`): bind `f x ⟨i, h⟩`, then
continue at `i+1`.
-/
theorem foldlM_loop_lt [Monad m] (f : α → Fin n → m α) (x) (h : i < n) :
    foldlM.loop n f x i = f x ⟨i, h⟩ >>= (foldlM.loop n f . (i+1)) := by
  rw [foldlM.loop, dif_pos h]
/-- `foldlM.loop` started at index `n` performs no steps and returns `pure x`. -/
theorem foldlM_loop_eq [Monad m] (f : α → Fin n → m α) (x) : foldlM.loop n f x n = pure x := by
  rw [foldlM.loop, dif_neg (Nat.lt_irrefl _)]
/--
Peels off the step at index `i` of `foldlM.loop` over `Fin (n+1)`. After binding `f x ⟨i, h⟩`, the
remainder is a loop over `Fin n`, still starting at `i`, whose function is reindexed by `Fin.succ`.
-/
theorem foldlM_loop [Monad m] (f : α → Fin (n+1) → m α) (x) (h : i < n+1) :
    foldlM.loop (n+1) f x i = f x ⟨i, h⟩ >>= (foldlM.loop n (fun x j => f x j.succ) . i) := by
  if h' : i < n then
    rw [foldlM_loop_lt _ _ h]
    congr; funext
    rw [foldlM_loop_lt _ _ h', foldlM_loop]; rfl
  else
    -- `i < n+1` and `¬ i < n` force `i = n`, so both remaining loops are already at their end.
    cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h')
    rw [foldlM_loop_lt]
    congr; funext
    rw [foldlM_loop_eq, foldlM_loop_eq]
-- The recursive use of `foldlM_loop` is at index `i+1` under `h' : i < n`, so `n - i` decreases.
termination_by n - i
/-- A monadic left fold over `Fin 0` performs no steps: it is just `pure`. -/
@[simp] theorem foldlM_zero [Monad m] (f : α → Fin 0 → m α) : foldlM 0 f = pure := by
  funext x; exact foldlM_loop_eq f x
/--
Unfolds a monadic left fold over `Fin (n+1)`: apply `f` at index `0`, then fold the remaining
indices as a fold over `Fin n` whose function is reindexed by `Fin.succ`.
-/
theorem foldlM_succ [Monad m] (f : α → Fin (n+1) → m α) :
    foldlM (n+1) f = fun x => f x 0 >>= foldlM n (fun x j => f x j.succ) := by
  funext x
  exact foldlM_loop ..
/--
Variant of `foldlM_succ` that splits off `Fin.last n` rather than `0`. The fold over `Fin (n+1)`
first folds over `Fin n` (indices embedded with `castSucc`) and then applies `f` at `Fin.last n`.
The proof reassociates binds with `bind_assoc`, hence the `[LawfulMonad m]` assumption.
-/
theorem foldlM_succ_last [Monad m] [LawfulMonad m] (f : α → Fin (n+1) → m α) :
    foldlM (n+1) f = fun x => foldlM n (fun x j => f x j.castSucc) x >>= (f · (Fin.last n)) := by
  funext x
  induction n generalizing x with
  | zero =>
    simp [foldlM_succ]
  | succ n ih =>
    rw [foldlM_succ]
    conv => rhs; rw [foldlM_succ]
    simp only [castSucc_zero, castSucc_succ, bind_assoc]
    congr 1
    funext x
    rw [ih]
    simp
/--
Splits a monadic left fold over `Fin (n + k)` into two folds. The first runs over the first `n`
indices (embedded with `castLE`); the second runs over the last `k` indices (shifted with `natAdd n`).
-/
theorem foldlM_add [Monad m] [LawfulMonad m] (f : α → Fin (n + k) → m α) :
    foldlM (n + k) f =
      fun x => foldlM n (fun x i => f x (i.castLE (Nat.le_add_right n k))) x >>= foldlM k (fun x i => f x (i.natAdd n)) := by
  induction k with
  | zero =>
    funext x
    simp
  | succ k ih =>
    funext x
    simp [foldlM_succ_last, ← Nat.add_assoc, ih]
/-! ### foldrM -/
/--
Congruence for `foldrM` along an equation `w : n = k` between the bounds. The step function is
transported by casting each index back with `Fin.cast w.symm`.
-/
@[congr] theorem foldrM_congr [Monad m] {n k : Nat} (w : n = k) (f : Fin n → α → m α) :
    foldrM n f = foldrM k (fun i => f (i.cast w.symm)) := by
  cases w
  rfl
/-- `foldrM.loop` at index `0` performs no steps and returns `pure x`. -/
theorem foldrM_loop_zero [Monad m] (f : Fin n → α → m α) (x) :
    foldrM.loop n f ⟨0, Nat.zero_le _⟩ x = pure x := by
  rw [foldrM.loop]
/-- Unfolds one step of `foldrM.loop` from index `i+1`: bind `f ⟨i, h⟩ x`, then continue from `i`. -/
theorem foldrM_loop_succ [Monad m] (f : Fin n → α → m α) (x) (h : i < n) :
    foldrM.loop n f ⟨i+1, h⟩ x = f ⟨i, h⟩ x >>= foldrM.loop n f ⟨i, Nat.le_of_lt h⟩ := by
  rw [foldrM.loop]
/--
Peels the index-`0` step off the end of `foldrM.loop` over `Fin (n+1)`. The loop from `i+1` equals
the loop over `Fin n` (function reindexed by `Fin.succ`) from `i`, followed by `f 0`.
-/
theorem foldrM_loop [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) (x) (h : i+1 ≤ n+1) :
    foldrM.loop (n+1) f ⟨i+1, h⟩ x =
      foldrM.loop n (fun j => f j.succ) ⟨i, Nat.le_of_succ_le_succ h⟩ x >>= f 0 := by
  induction i generalizing x with
  | zero =>
    rw [foldrM_loop_zero, foldrM_loop_succ, pure_bind]
    -- Rewrite `f 0 x` as `f 0 x >>= pure` so both sides have the shape `_ >>= _`.
    conv => rhs; rw [←bind_pure (f 0 x)]
    congr
    funext
    simp [foldrM_loop_zero]
  | succ i ih =>
    rw [foldrM_loop_succ, foldrM_loop_succ, bind_assoc]
    congr; funext; exact ih ..
/-- A monadic right fold over `Fin 0` starts its loop at index `0`, so it is just `pure`. -/
@[simp] theorem foldrM_zero [Monad m] (f : Fin 0 → α → m α) : foldrM 0 f = pure := by
  funext x; exact foldrM_loop_zero f x
/--
Unfolds a monadic right fold over `Fin (n+1)`: fold over `Fin n` with the function reindexed by
`Fin.succ`, then apply `f` at index `0`.
-/
theorem foldrM_succ [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) :
    foldrM (n+1) f = fun x => foldrM n (fun i => f i.succ) x >>= f 0 := by
  funext x
  exact foldrM_loop ..
/--
Variant of `foldrM_succ` that splits off `Fin.last n` rather than `0`. The step at `Fin.last n` runs
first, then the fold over `Fin n` with the function reindexed by `castSucc`.
-/
theorem foldrM_succ_last [Monad m] [LawfulMonad m] (f : Fin (n+1) → α → m α) :
    foldrM (n+1) f = fun x => f (Fin.last n) x >>= foldrM n (fun i => f i.castSucc) := by
  funext x
  induction n generalizing x with
  | zero => simp [foldrM_succ]
  | succ n ih =>
    rw [foldrM_succ]
    conv => rhs; rw [foldrM_succ]
    simp [ih]
/--
Splits a monadic right fold over `Fin (n + k)` into two folds. The first runs over the last `k`
indices (shifted with `natAdd n`); the second runs over the first `n` indices (embedded with
`castLE`).
-/
theorem foldrM_add [Monad m] [LawfulMonad m] (f : Fin (n + k) → α → m α) :
    foldrM (n + k) f =
      fun x => foldrM k (fun i => f (i.natAdd n)) x >>= foldrM n (fun i => f (i.castLE (Nat.le_add_right n k))) := by
  induction k with
  | zero =>
    simp
  | succ k ih =>
    funext x
    simp [foldrM_succ_last, ← Nat.add_assoc, ih]
/-! ### foldl -/
/--
Congruence for `foldl` along an equation `w : n = k` between the bounds. The step function is
transported by casting each index back with `Fin.cast w.symm`.
-/
@[congr] theorem foldl_congr {n k : Nat} (w : n = k) (f : α → Fin n → α) :
    foldl n f = foldl k (fun x i => f x (i.cast w.symm)) := by
  cases w
  rfl
/--
Unfolds one step of `foldl.loop` while the index is in range (`i < n`). The accumulator becomes
`f x ⟨i, h⟩` and the loop continues at `i+1`.
-/
theorem foldl_loop_lt (f : α → Fin n → α) (x) (h : i < n) :
    foldl.loop n f x i = foldl.loop n f (f x ⟨i, h⟩) (i+1) := by
  rw [foldl.loop, dif_pos h]
/-- `foldl.loop` started at index `n` performs no steps and returns the accumulator `x`. -/
theorem foldl_loop_eq (f : α → Fin n → α) (x) : foldl.loop n f x n = x := by
  rw [foldl.loop, dif_neg (Nat.lt_irrefl _)]
/--
Peels off the step at index `i` of `foldl.loop` over `Fin (n+1)`. The rest is a loop over `Fin n`,
still starting at `i`, whose function is reindexed by `Fin.succ` and whose accumulator is
`f x ⟨i, h⟩`.
-/
theorem foldl_loop (f : α → Fin (n+1) → α) (x) (h : i < n+1) :
    foldl.loop (n+1) f x i = foldl.loop n (fun x j => f x j.succ) (f x ⟨i, h⟩) i := by
  if h' : i < n then
    rw [foldl_loop_lt _ _ h]
    rw [foldl_loop_lt _ _ h', foldl_loop]; rfl
  else
    -- `i < n+1` and `¬ i < n` force `i = n`, so both remaining loops are already at their end.
    cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h')
    rw [foldl_loop_lt]
    rw [foldl_loop_eq, foldl_loop_eq]
-- Same explicit measure as `foldlM_loop`: the recursive use of `foldl_loop` is at index `i+1`
-- under `h' : i < n`, so `n - i` decreases.
termination_by n - i
/-- A left fold over `Fin 0` performs no steps and returns the initial value unchanged. -/
@[simp] theorem foldl_zero (f : α → Fin 0 → α) (x) : foldl 0 f x = x := by
  exact foldl_loop_eq f x
/--
Unfolds a left fold over `Fin (n+1)`: combine `x` with index `0` first, then fold over `Fin n`
with the function reindexed by `Fin.succ`.
-/
theorem foldl_succ (f : α → Fin (n+1) → α) (x) :
    foldl (n+1) f x = foldl n (fun x i => f x i.succ) (f x 0) :=
  foldl_loop ..
/--
Variant of `foldl_succ` that splits off `last n` rather than `0`: fold over `Fin n` (indices
embedded with `castSucc`), then apply `f` at `last n`.
-/
theorem foldl_succ_last (f : α → Fin (n+1) → α) (x) :
    foldl (n+1) f x = f (foldl n (f · ·.castSucc) x) (last n) := by
  rw [foldl_succ]
  induction n generalizing x with
  | zero => simp [foldl_succ, Fin.last]
  | succ n ih => rw [foldl_succ, ih (f · ·.succ), foldl_succ]; simp
/--
Splits a left fold over `Fin (n + m)` into two folds. The first runs over the first `n` indices
(embedded with `castLE`); the second continues over the last `m` indices (shifted with `natAdd n`).
-/
theorem foldl_add (f : α → Fin (n + m) → α) (x) :
    foldl (n + m) f x =
      foldl m (fun x i => f x (i.natAdd n))
        (foldl n (fun x i => f x (i.castLE (Nat.le_add_right n m))) x) := by
  induction m with
  | zero => simp
  | succ m ih => simp [foldl_succ_last, ih, ← Nat.add_assoc]
/-- `foldl` agrees with `foldlM` run in the `Id` monad (step `pure <| f · ·`), read back with `Id.run`. -/
theorem foldl_eq_foldlM (f : α → Fin n → α) (x) :
    foldl n f x = (foldlM (m := Id) n (pure <| f · ·) x).run := by
  induction n generalizing x <;> simp [foldl_succ, foldlM_succ, *]
-- This is not marked `@[simp]` as it would match on every occurrence of `foldlM`.
/--
A monadic left fold whose step is `fun x i => pure (f x i)` collapses to `pure` of the pure left
fold `foldl n f x`.
-/
theorem foldlM_pure [Monad m] [LawfulMonad m] {n} {f : α → Fin n → α} :
    foldlM n (fun x i => pure (f x i)) x = (pure (foldl n f x) : m α) := by
  induction n generalizing x with
  | zero => simp
  | succ n ih => simp [foldlM_succ, foldl_succ, ih]
/-! ### foldr -/
/--
Congruence for `foldr` along an equation `w : n = k` between the bounds. The step function is
transported by casting each index back with `Fin.cast w.symm`.
-/
@[congr] theorem foldr_congr {n k : Nat} (w : n = k) (f : Fin n → α → α) :
    foldr n f = foldr k (fun i => f (i.cast w.symm)) := by
  subst w
  rfl
/-- `foldr.loop` at index `0` performs no steps and returns the accumulator `x`. -/
theorem foldr_loop_zero (f : Fin n → α → α) (x) :
    foldr.loop n f 0 (Nat.zero_le _) x = x := by
  rw [foldr.loop]
/--
Unfolds one step of `foldr.loop` from index `i+1`. The accumulator becomes `f ⟨i, h⟩ x` and the loop
continues from `i`.
-/
theorem foldr_loop_succ (f : Fin n → α → α) (x) (h : i < n) :
    foldr.loop n f (i+1) h x = foldr.loop n f i (Nat.le_of_lt h) (f ⟨i, h⟩ x) := by
  rw [foldr.loop]
/--
Peels the index-`0` step off the outside of `foldr.loop` over `Fin (n+1)`. The loop from `i+1`
equals `f 0` applied to the loop over `Fin n` (function reindexed by `Fin.succ`) from `i`.
-/
theorem foldr_loop (f : Fin (n+1) → α → α) (x) (h : i+1 ≤ n+1) :
    foldr.loop (n+1) f (i+1) h x =
      f 0 (foldr.loop n (fun j => f j.succ) i (Nat.le_of_succ_le_succ h) x) := by
  induction i generalizing x with
  | zero => simp [foldr_loop_succ, foldr_loop_zero]
  | succ i ih => rw [foldr_loop_succ, ih]; rfl
/-- A right fold over `Fin 0` performs no steps and returns the initial value unchanged. -/
@[simp] theorem foldr_zero (f : Fin 0 → α → α) (x) : foldr 0 f x = x :=
  foldr_loop_zero ..
/--
Unfolds a right fold over `Fin (n+1)`: `f 0` applied to the fold over `Fin n` whose function is
reindexed by `Fin.succ`.
-/
theorem foldr_succ (f : Fin (n+1) → α → α) (x) :
    foldr (n+1) f x = f 0 (foldr n (fun i => f i.succ) x) := foldr_loop ..
/--
Variant of `foldr_succ` that splits off `last n` rather than `0`. The step at `last n` is applied to
`x` first; the result is then folded over `Fin n` (indices embedded with `castSucc`).
-/
theorem foldr_succ_last (f : Fin (n+1) → α → α) (x) :
    foldr (n+1) f x = foldr n (f ·.castSucc) (f (last n) x) := by
  induction n generalizing x with
  | zero => simp [foldr_succ, Fin.last]
  | succ n ih => rw [foldr_succ, ih (f ·.succ), foldr_succ]; simp
/--
Splits a right fold over `Fin (n + m)` into two folds. The inner fold runs over the last `m` indices
(shifted with `natAdd n`); its result is folded over the first `n` indices (embedded with `castLE`).
-/
theorem foldr_add (f : Fin (n + m) → α → α) (x) :
    foldr (n + m) f x =
      foldr n (fun i => f (i.castLE (Nat.le_add_right n m)))
        (foldr m (fun i => f (i.natAdd n)) x) := by
  induction m generalizing x with
  | zero => simp
  | succ m ih => simp [foldr_succ_last, ih, ← Nat.add_assoc]
/-- `foldr` agrees with `foldrM` run in the `Id` monad (step `pure <| f · ·`), read back with `Id.run`. -/
theorem foldr_eq_foldrM (f : Fin n → α → α) (x) :
    foldr n f x = (foldrM (m := Id) n (pure <| f · ·) x).run := by
  induction n <;> simp [foldr_succ, foldrM_succ, *]
/-- A left fold over the reversed indices `i.rev` equals the right fold with the same function. -/
theorem foldl_rev (f : Fin n → α → α) (x) :
    foldl n (fun x i => f i.rev x) x = foldr n f x := by
  induction n generalizing x with
  | zero => simp
  | succ n ih => rw [foldl_succ, foldr_succ_last, ← ih]; simp [rev_succ]
/-- A right fold over the reversed indices `i.rev` equals the left fold with the same function. -/
theorem foldr_rev (f : α → Fin n → α) (x) :
    foldr n (fun i x => f x i.rev) x = foldl n f x := by
  induction n generalizing x with
  | zero => simp
  | succ n ih => rw [foldl_succ_last, foldr_succ, ← ih]; simp [rev_succ]
-- This is not marked `@[simp]` as it would match on every occurrence of `foldrM`.
/--
A monadic right fold whose step is `fun i x => pure (f i x)` collapses to `pure` of the pure right
fold `foldr n f x`.
-/
theorem foldrM_pure [Monad m] [LawfulMonad m] {n} {f : Fin n → α → α} :
    foldrM n (fun i x => pure (f i x)) x = (pure (foldr n f x) : m α) := by
  induction n generalizing x with
  | zero => simp
  | succ n ih => simp [foldrM_succ, foldr_succ, ih]
end Fin