/-
Source snapshot: lean4-htt/src/Init/Data/BitVec/Folds.lean (2026-02-05 09:10:32 +00:00;
129 lines, 4.8 KiB), captured from a web file viewer.

Viewer note: this file contains Unicode characters (for example `α`, `→`, `∀`, `⟨`, `⟩`) that
might be confused with other characters. They are intentional Lean notation.
-/

/-
Copyright (c) 2023 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joe Hendrix, Harun Khan
-/
module
prelude
import all Init.Data.BitVec.Basic
public import Init.Data.BitVec.Basic
public import Init.Ext
import Init.Data.BitVec.Lemmas
import Init.Data.Fin.Iterate
public section
set_option linter.missingDocs true
namespace BitVec
/--
Constructs a bitvector by iteratively computing a state for each bit using the function `f`,
starting with the initial state `s`. At each step, the prior state and the current bit index are
passed to `f`, and it produces the next state value along with a bit. These bits are assembled into
the final bitvector.

It produces a sequence of state values `[s_0, s_1, .., s_w]` and a bitvector `v` where
`f i s_i = (s_{i+1}, b_i)` and `b_i` is the `i`th least-significant bit of `v`
(i.e., `v.getLsbD i = b_i`).

The theorem `iunfoldr_replace` allows uses of `BitVec.iunfoldr` to be replaced with declarative
specifications that are easier to reason about.
-/
def iunfoldr (f : Fin w → α → α × Bool) (s : α) : α × BitVec w :=
  -- `Fin.hIterate` threads a value of type `α × BitVec i` through the indices `i < w`: step `i`
  -- feeds the current state to `f i` and `cons`es the produced bit onto the accumulated
  -- `BitVec i`, yielding an `α × BitVec (i + 1)`.
  Fin.hIterate (fun i => α × BitVec i) (s, nil) fun i q =>
    (fun p => ⟨p.fst, cons p.snd q.snd⟩) (f i q.fst)
/--
If `state` describes the states threaded through `iunfoldr f` (it agrees with the initial state
`s` at `0`, and `f i` maps `state i` to `state (i + 1)` in its first component), then the final
state returned by `iunfoldr f s` is `state w`.
-/
theorem iunfoldr.fst_eq
    {f : Fin w → α → α × Bool} (state : Nat → α) (s : α)
    (init : s = state 0)
    (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :
    (iunfoldr f s).fst = state w := by
  unfold iunfoldr
  -- Invariant along the iteration: after `i` steps the carried state is `state i`.
  apply Fin.hIterate_elim (fun i (p : α × BitVec i) => p.fst = state i)
  case init =>
    exact init
  case step =>
    intro i ⟨s, v⟩ p
    simp_all [ind i]
/--
Full characterization of `iunfoldr f a` when, at every index `i`, `f i` maps `state i` to the next
state `state (i + 1)` together with bit `i` of `value`: the result is the final state `state w`
paired with `value` truncated to width `w`. This is the shared core of `iunfoldr_replace` and
`iunfoldr_replace_snd`.
-/
private theorem iunfoldr.eq_test
    {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α)
    (init : state 0 = a)
    (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value.getLsbD i.val)) :
    iunfoldr f a = (state w, BitVec.truncate w value) := by
  -- After `i` steps the accumulator is `state i` paired with the low `i` bits of `value`.
  apply Fin.hIterate_eq (fun i => ((state i, BitVec.truncate i value) : α × BitVec i))
  case init =>
    simp only [init, eq_nil]
  case step =>
    intro i
    simp_all [setWidth_succ]
/--
Bit-level specification of `iunfoldr`. Suppose `state` describes the states threaded through
`iunfoldr f`, i.e. `f i` maps `state i` to `state (i + 1)` in its first component. Then, starting
from `state 0`:

* bit `i` of the resulting bitvector is the bit that `f i` produces from `state i`, and
* the final state is `state w`.

The two facts are proved together because identifying the bit added at step `j` requires knowing
that the accumulated state at that point is `state j`.
-/
theorem iunfoldr_getLsbD' {f : Fin w → α → α × Bool} (state : Nat → α)
    (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :
    (∀ i : Fin w, getLsbD (iunfoldr f (state 0)).snd i.val = (f i (state i.val)).snd)
    ∧ (iunfoldr f (state 0)).fst = state w := by
  unfold iunfoldr
  simp
  -- The invariant at step `j` is stated under `j ≤ w`; that bound is what lets each `i : Fin j`
  -- be viewed as an index into `Fin w`, so that `f` can be applied to it.
  apply Fin.hIterate_elim
    (fun j (p : α × BitVec j) => (hj : j ≤ w) →
      (∀ i : Fin j, getLsbD p.snd i.val = (f ⟨i.val, Nat.lt_of_lt_of_le i.isLt hj⟩ (state i.val)).snd)
        ∧ p.fst = state j)
  -- The invariant is used at `j = w`, so its `w ≤ w` premise remains as a side goal.
  case hj => simp
  case init =>
    intro
    apply And.intro
    -- `Fin 0` is empty, so there are no bits to check.
    · intro i
      have := Fin.pos i
      contradiction
    · rfl
  case step =>
    intro j ⟨s, v⟩ ih hj
    apply And.intro
    case left =>
      intro i
      simp only [getLsbD_cons]
      have hj2 : j.val ≤ w := by simp
      -- Bit `i` of `cons b v` is either an older bit of `v` (`i < j`) or the new bit `b` (`i = j`).
      cases (Nat.lt_or_eq_of_le (Nat.lt_succ_iff.mp i.isLt)) with
      | inl h3 =>
        simp [(Nat.ne_of_lt h3)]
        exact (ih hj2).1 ⟨i.val, h3⟩
      | inr h3 =>
        simp [h3]
        cases (Nat.eq_zero_or_pos j.val) with
        | inl hj3 =>
          congr
          rw [← (ih hj2).2]
        | inr hj3 =>
          congr
          exact (ih hj2).2
    case right =>
      simp
      have hj2 : j.val ≤ w := by simp
      rw [← ind j, ← (ih hj2).2]
/--
If `state` describes the states threaded through `iunfoldr f`, then bit `i` of
`(iunfoldr f (state 0)).snd` is the bit that `f i` produces from `state i`. This is the bit half of
`iunfoldr_getLsbD'`.
-/
theorem iunfoldr_getLsbD {f : Fin w → α → α × Bool} (state : Nat → α) (i : Fin w)
    (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :
    getLsbD (iunfoldr f (state 0)).snd i.val = (f i (state i.val)).snd := by
  exact (iunfoldr_getLsbD' state ind).1 i
/--
Given a function `state` that provides the correct state for every potential iteration count and a
function `f` that computes these states (along with the bits of `value`) from the correct initial
state, the result of applying `BitVec.iunfoldr f` to the initial state is the state corresponding
to the bitvector's width paired with the bitvector that consists of each computed bit.

This theorem can be used to prove properties of functions that are defined using `BitVec.iunfoldr`.
-/
theorem iunfoldr_replace
    {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α)
    (init : state 0 = a)
    (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value[i.val])) :
    iunfoldr f a = (state w, value) := by
  simp [iunfoldr.eq_test state value a init step]
/--
The bitvector half of `iunfoldr_replace`: if `f i` maps `state i` to `state (i + 1)` together with
bit `i` of `value`, and `state 0 = a`, then the bitvector produced by `iunfoldr f a` is `value`.
-/
theorem iunfoldr_replace_snd
    {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α)
    (init : state 0 = a)
    (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value[i.val])) :
    (iunfoldr f a).snd = value := by
  simp [iunfoldr.eq_test state value a init step]
end BitVec