/-
Copyright (c) 2023 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joe Hendrix, Harun Khan
-/
|
||
module
|
||
|
||
prelude
|
||
import all Init.Data.BitVec.Basic
|
||
public import Init.Data.BitVec.Basic
|
||
public import Init.Ext
|
||
import Init.Data.BitVec.Lemmas
|
||
import Init.Data.Fin.Iterate
|
||
|
||
public section
|
||
|
||
set_option linter.missingDocs true
|
||
|
||
namespace BitVec
|
||
|
||
/--
Constructs a bitvector by iteratively computing a state for each bit using the function `f`,
starting with the initial state `s`. At each step, the current bit index and the prior state are
passed to `f`, which produces the next state value along with a bit. These bits are assembled into
the final bitvector, which is returned together with the last state.

It produces a sequence of state values `[s_0, s_1, …, s_w]` (with `s_0 = s`) and a bitvector `v`
where `f i s_i = (s_{i+1}, b_i)` and `b_i` is the `i`th least-significant bit of `v`
(i.e., `getLsbD v i = b_i`). The result is the pair `(s_w, v)`.

The theorem `iunfoldr_replace` allows uses of `BitVec.iunfoldr` to be replaced with declarative
specifications that are easier to reason about.
-/
def iunfoldr (f : Fin w → α → α × Bool) (s : α) : α × BitVec w :=
  -- Iterate `w` times starting from `(s, nil)`. Step `i` runs `f i` on the current state and
  -- `cons`es the produced bit onto the `i`-bit accumulator, so that bit lands at index `i`.
  Fin.hIterate (fun i => α × BitVec i) (s, nil) fun i q =>
    (fun p => ⟨p.fst, cons p.snd q.snd⟩) (f i q.fst)
|
||
|
||
/--
If `state` tracks the state threaded through `iunfoldr`, then the final state returned by
`iunfoldr f s` is `state w`. Concretely, `state` must start at `s`, and for every index `i` the
first component of `f i (state i)` must be `state (i + 1)`. No assumption is made about the bits
that `f` produces.
-/
theorem iunfoldr.fst_eq
    {f : Fin w → α → α × Bool} (state : Nat → α) (s : α)
    (init : s = state 0)
    (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :
    (iunfoldr f s).fst = state w := by
  unfold iunfoldr
  -- Eliminate over the iteration with the invariant "after `i` steps the state is `state i`".
  apply Fin.hIterate_elim (fun i (p : α × BitVec i) => p.fst = state i)
  case init =>
    exact init
  case step =>
    intro i ⟨s, v⟩ p
    simp_all [ind i]
|
||
|
||
/--
Auxiliary characterization of `iunfoldr` used by `iunfoldr_replace`. Assume `state 0 = a`, and
assume each step maps `state i` to `(state (i + 1), value.getLsbD i)`. Then `iunfoldr f a` equals
`(state w, BitVec.truncate w value)`.
-/
private theorem iunfoldr.eq_test
    {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α)
    (init : state 0 = a)
    (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value.getLsbD i.val)) :
    iunfoldr f a = (state w, BitVec.truncate w value) := by
  -- The accumulator after `i` steps is described as `(state i, BitVec.truncate i value)`.
  apply Fin.hIterate_eq (fun i => ((state i, BitVec.truncate i value) : α × BitVec i))
  case init =>
    simp only [init, eq_nil]
  case step =>
    intro i
    simp_all [setWidth_succ]
|
||
|
||
/--
Describes `iunfoldr` at both the bit level and the state level. Assume that for every index `i`,
the first component of `f i (state i)` is `state (i + 1)`. Then, starting from `state 0`:
- bit `i` of the resulting bitvector is the bit produced by `f i (state i)`, and
- the final state is `state w`.

`iunfoldr_getLsbD` exposes the first conjunct.
-/
theorem iunfoldr_getLsbD' {f : Fin w → α → α × Bool} (state : Nat → α)
    (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :
    (∀ i : Fin w, getLsbD (iunfoldr f (state 0)).snd i.val = (f i (state i.val)).snd)
    ∧ (iunfoldr f (state 0)).fst = state w := by
  unfold iunfoldr
  simp
  -- Invariant after `j` steps: every bit `i < j` is the bit `f` produced at step `i`, and the
  -- state is `state j`. It is stated only for `j ≤ w`, so that `i < j` yields a valid `Fin w`
  -- index for `f`.
  apply Fin.hIterate_elim
      (fun j (p : α × BitVec j) => (hj : j ≤ w) →
        (∀ i : Fin j,
          getLsbD p.snd i.val = (f ⟨i.val, Nat.lt_of_lt_of_le i.isLt hj⟩ (state i.val)).snd)
        ∧ p.fst = state j)
  -- Side condition of the invariant at the final index.
  case hj => simp
  case init =>
    intro
    apply And.intro
    -- Zero steps: `Fin 0` is empty, so there are no bits to check.
    · intro i
      have := Fin.pos i
      contradiction
    · rfl
  case step =>
    intro j ⟨s, v⟩ ih hj
    apply And.intro
    case left =>
      -- A bit of `cons b v` at index `i < j` comes from `v`; at index `i = j` it is the new bit.
      intro i
      simp only [getLsbD_cons]
      have hj2 : j.val ≤ w := by simp
      cases (Nat.lt_or_eq_of_le (Nat.lt_succ_iff.mp i.isLt)) with
      | inl h3 =>
        simp [(Nat.ne_of_lt h3)]
        exact (ih hj2).1 ⟨i.val, h3⟩
      | inr h3 =>
        simp [h3]
        cases (Nat.eq_zero_or_pos j.val) with
        | inl hj3 =>
          congr
          rw [← (ih hj2).2]
        | inr hj3 =>
          congr
          exact (ih hj2).2
    case right =>
      -- The new state is the first component of `f j s`, which `ind` identifies with
      -- `state (j + 1)`.
      simp
      have hj2 : j.val ≤ w := by simp
      rw [← ind j, ← (ih hj2).2]
|
||
|
||
|
||
/--
Describes individual bits of an `iunfoldr` result. Assume that for every index, the first
component of `f i (state i)` is `state (i + 1)`. Then bit `i` of `iunfoldr f (state 0)` is the
bit produced by `f i (state i)`.
-/
theorem iunfoldr_getLsbD {f : Fin w → α → α × Bool} (state : Nat → α) (i : Fin w)
    (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :
    getLsbD (iunfoldr f (state 0)).snd i.val = (f i (state i.val)).snd :=
  (iunfoldr_getLsbD' state ind).1 i
|
||
|
||
/--
Characterizes `BitVec.iunfoldr f a` by a declarative specification. Let `state n` be the state
after `n` iterations, and let `value` be the bitvector the iteration should produce. Assume:
- `state 0 = a`, and
- for every index `i`, `f i (state i)` returns `(state (i + 1), value[i])`.

Then `BitVec.iunfoldr f a` is `(state w, value)`: the state after all `w` iterations, paired with
`value`.

This theorem can be used to prove properties of functions that are defined using `BitVec.iunfoldr`.
-/
theorem iunfoldr_replace
    {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α)
    (init : state 0 = a)
    (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value[i.val])) :
    iunfoldr f a = (state w, value) := by
  simp [iunfoldr.eq_test state value a init step]
|
||
|
||
/--
Gives the bitvector component of `BitVec.iunfoldr f a` under the same hypotheses as
`iunfoldr_replace`. If `state 0 = a` and each step maps `state i` to
`(state (i + 1), value[i])`, the produced bitvector is exactly `value`.
-/
theorem iunfoldr_replace_snd
    {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α)
    (init : state 0 = a)
    (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value[i.val])) :
    (iunfoldr f a).snd = value := by
  -- Project the second component out of the full characterization.
  simp [iunfoldr_replace state value a init step]
|
||
|
||
end BitVec
|