lean4-htt/src/Init/Data/BitVec/Folds.lean
Tobias Grosser a9efbf04f4
feat: make BitVec.getElem the simp normal form and use it in ext (#5498)
This PR makes `BitVec.getElem` the simp normal form in case a proof is
available and changes `ext` to return `x[i]` + a hypothesis that proves
that we are in-bounds. This aligns `BitVec` further with the API
conventions of the Lean standard datatypes.

We move our proofs to this new normal form, which results in slightly
smaller proofs. With the exception of `getElem_ofFin`, no new API
surface is added as the `getElem` API has already been completed over
the previous months. We also move `getElem_shiftConcat_*` a bit higher
as they are needed in earlier proofs. To keep the changeset small, we do
not update the API of `BVDecide` but insert `←
BitVec.getLsbD_eq_getElem` at the few locations where it is needed.
Finally, we add a simproc for getElem, mirroring the existing ones for
getLsbD/getMsbD.

---------

Co-authored-by: Alex Keizer <alex@keizer.dev>
2025-02-16 00:04:56 +00:00

115 lines
4 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2023 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joe Hendrix, Harun Khan
-/
prelude
import Init.Data.BitVec.Lemmas
import Init.Data.Nat.Lemmas
import Init.Data.Fin.Iterate
set_option linter.missingDocs true
namespace BitVec
/--
`iunfoldr f s` is an iterative operation that applies the step function `f` once for every
index `i : Fin w`, threading a state of type `α` through the iteration and emitting one bit
per step.

Starting from `s_0 = s`, it produces a sequence of state values `[s_0, s_1, .., s_w]` and a
bitvector `v` where `f i s_i = (s_{i+1}, b_i)` and `b_i` is the `i`th least-significant bit
of `v` (i.e., `v.getLsbD i = b_i`). The returned pair is `(s_w, v)`.

Theorems involving `iunfoldr` can be eliminated using `iunfoldr_replace` below.
-/
def iunfoldr (f : Fin w → α → α × Bool) (s : α) : α × BitVec w :=
  -- The accumulator at index `i` pairs the current state with the `i` bits produced so far;
  -- it starts as `(s, nil)`.
  Fin.hIterate (fun i => α × BitVec i) (s, nil) fun i q =>
    -- Run one step on the current state, keep the new state, and `cons` the emitted bit
    -- onto the bits accumulated so far.
    (fun p => ⟨p.fst, cons p.snd q.snd⟩) (f i q.fst)
/--
The state component returned by `iunfoldr` can be read off from any function `state : Nat → α`
that tracks the state across iterations: if `s = state 0` and every step of `f` sends
`state i` to `state (i+1)`, then the final state is `state w`.
-/
theorem iunfoldr.fst_eq
    {f : Fin w → α → α × Bool} (state : Nat → α) (s : α)
    (init : s = state 0)
    (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :
    (iunfoldr f s).fst = state w := by
  unfold iunfoldr
  -- Invariant carried through the iteration: the accumulated state at index `i` is `state i`.
  apply Fin.hIterate_elim (fun i (p : α × BitVec i) => p.fst = state i)
  case init =>
    exact init
  case step =>
    intro i ⟨s, v⟩ p
    simp_all [ind i]
/--
Auxiliary form of the correctness theorem for `iunfoldr`, stated with `getLsbD` in the step
hypothesis and with the result bitvector written as `BitVec.truncate w value`.

If `state 0 = a` and every step `f i (state i)` yields `(state (i+1), value.getLsbD i)`, then
`iunfoldr f a = (state w, BitVec.truncate w value)`.
-/
private theorem iunfoldr.eq_test
    {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α)
    (init : state 0 = a)
    (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value.getLsbD i.val)) :
    iunfoldr f a = (state w, BitVec.truncate w value) := by
  -- Closed form of the accumulator at index `i`: the tracked state paired with `value`
  -- truncated to `i` bits.
  apply Fin.hIterate_eq (fun i => ((state i, BitVec.truncate i value) : α × BitVec i))
  case init =>
    simp only [init, eq_nil]
  case step =>
    intro i
    simp_all [setWidth_succ]
/--
Combined characterisation of both components of `iunfoldr f (state 0)`.

Assuming every step of `f` sends `state i` to `state (i+1)`:
* the `i`th least-significant bit of the resulting bitvector is the bit emitted by
  `f i (state i)`, and
* the resulting state is `state w`.
-/
theorem iunfoldr_getLsbD' {f : Fin w → α → α × Bool} (state : Nat → α)
    (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :
    (∀ i : Fin w, getLsbD (iunfoldr f (state 0)).snd i.val = (f i (state i.val)).snd)
    ∧ (iunfoldr f (state 0)).fst = state w := by
  unfold iunfoldr
  simp
  -- Strengthened invariant at index `j` (given `hj : j ≤ w`): every bit accumulated so far
  -- agrees with the bit emitted by `f`, and the accumulated state is `state j`.
  apply Fin.hIterate_elim
    (fun j (p : α × BitVec j) => (hj : j ≤ w) →
      (∀ i : Fin j, getLsbD p.snd i.val = (f ⟨i.val, Nat.lt_of_lt_of_le i.isLt hj⟩ (state i.val)).snd)
      ∧ p.fst = state j)
  -- Discharge the bound hypothesis `hj` required by the invariant.
  case hj => simp
  case init =>
    intro
    apply And.intro
    · -- There is no `i : Fin 0`, so the statement about bits holds vacuously.
      intro i
      have := Fin.pos i
      contradiction
    · rfl
  case step =>
    intro j ⟨s, v⟩ ih hj
    apply And.intro
    case left =>
      intro i
      simp only [getLsbD_cons]
      have hj2 : j.val ≤ w := by simp
      -- Either `i < j` (a bit produced by an earlier step) or `i = j` (the bit just added).
      cases (Nat.lt_or_eq_of_le (Nat.lt_succ.mp i.isLt)) with
      | inl h3 =>
        simp [if_neg, (Nat.ne_of_lt h3)]
        exact (ih hj2).1 ⟨i.val, h3⟩
      | inr h3 =>
        simp [h3, if_pos]
        -- In both sub-cases the goal reduces to the state half of the induction hypothesis.
        cases (Nat.eq_zero_or_pos j.val) with
        | inl hj3 =>
          congr
          rw [← (ih hj2).2]
        | inr hj3 =>
          congr
          exact (ih hj2).2
    case right =>
      simp
      have hj2 : j.val ≤ w := by simp
      rw [← ind j, ← (ih hj2).2]
/--
Bit `i` of the bitvector produced by `iunfoldr f (state 0)` is the bit emitted by
`f i (state i)`, provided every step of `f` sends `state i` to `state (i+1)`.

This is the first component of `iunfoldr_getLsbD'`, specialised to a single index.
-/
theorem iunfoldr_getLsbD {f : Fin w → α → α × Bool} (state : Nat → α) (i : Fin w)
    (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) :
    getLsbD (iunfoldr f (state 0)).snd i.val = (f i (state i.val)).snd := by
  exact (iunfoldr_getLsbD' state ind).1 i
/--
Correctness theorem for `iunfoldr`.

Given a function `state : Nat → α` and a bitvector `value` such that `state 0 = a` and every
step `f i (state i)` yields `(state (i+1), value[i])`, the result of `iunfoldr f a` is exactly
`(state w, value)`. This lets occurrences of `iunfoldr` be replaced by a closed form.
-/
theorem iunfoldr_replace
    {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α)
    (init : state 0 = a)
    (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value[i.val])) :
    iunfoldr f a = (state w, value) := by
  simp [iunfoldr.eq_test state value a init step]
/--
Variant of `iunfoldr_replace` that only describes the bitvector component: under the same
hypotheses (`state 0 = a`, and every step `f i (state i)` yields `(state (i+1), value[i])`),
the bitvector produced by `iunfoldr f a` is `value`.
-/
theorem iunfoldr_replace_snd
    {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α)
    (init : state 0 = a)
    (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value[i.val])) :
    (iunfoldr f a).snd = value := by
  simp [iunfoldr.eq_test state value a init step]
end BitVec