/-
Copyright (c) 2024 Shreyas Srinivas. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Shreyas Srinivas, François G. Dorais, Kim Morrison
-/
prelude
import Init.Data.Array.Lemmas
import Init.Data.Array.MapIdx
import Init.Data.Range

/-!
# Vectors

`Vector α n` is a thin wrapper around `Array α` for arrays of fixed size `n`.

Most operations in this file are implemented by delegating to the corresponding `Array` operation
on the underlying array and re-establishing the size invariant.
-/

/--
`Vector α n` is an `Array α` with size `n`.

A vector bundles the underlying array (field `toArray`, inherited via `extends Array α`) with a
proof `size_toArray` that the array has exactly `n` elements.
-/
structure Vector (α : Type u) (n : Nat) extends Array α where
  /-- Array size: the underlying array has exactly `n` elements. -/
  size_toArray : toArray.size = n
deriving Repr, DecidableEq

-- Registered as a simp lemma so that `by simp` can discharge size goals about `toArray`.
attribute [simp] Vector.size_toArray

/-- Convert `xs : Array α` to `Vector α xs.size`. The size proof is `rfl`. -/
abbrev Array.toVector (xs : Array α) : Vector α xs.size := .mk xs rfl

namespace Vector

/--
Literal syntax for `Vector α n`: `#v[x, y, z]` denotes the vector with the listed elements.
-/
syntax "#v[" withoutPosition(sepBy(term, ", ")) "]" : term

-- Expands `#v[e₁, ..., eₖ]` to `Vector.mk (n := k) #[e₁, ..., eₖ] rfl`, where `k` is the number
-- of elements written in the literal.
open Lean in
macro_rules
  | `(#v[ $elems,* ]) => `(Vector.mk (n := $(quote elems.getElems.size)) #[$elems,*] rfl)

/--
Custom eliminator for `Vector α n` through `Array α`.

To define or prove something for every vector, it suffices to handle an array `a` together with a
proof `ha : a.size = n`.
-/
@[elab_as_elim]
def elimAsArray {motive : Vector α n → Sort u}
    (mk : ∀ (a : Array α) (ha : a.size = n), motive ⟨a, ha⟩) :
    (v : Vector α n) → motive v
  | ⟨a, ha⟩ => mk a ha

/--
Custom eliminator for `Vector α n` through `List α`.

To define or prove something for every vector, it suffices to handle a list `a` together with a
proof `ha : a.length = n`; the vector is the one whose underlying array wraps `a`.
-/
@[elab_as_elim]
def elimAsList {motive : Vector α n → Sort u}
    (mk : ∀ (a : List α) (ha : a.length = n), motive ⟨⟨a⟩, ha⟩) :
    (v : Vector α n) → motive v
  | ⟨⟨a⟩, ha⟩ => mk a ha

/-- Make an empty vector (size `0`) with pre-allocated capacity `capacity`. -/
@[inline] def mkEmpty (capacity : Nat) : Vector α 0 := ⟨.mkEmpty capacity, rfl⟩

/-- Makes a vector of size `n` with all cells containing `v`. -/
@[inline] def mkVector (n) (v : α) : Vector α n := ⟨mkArray n v, by simp⟩

/-- Returns a vector of size `1` with element `v`. -/
@[inline] def singleton (v : α) : Vector α 1 := ⟨#[v], rfl⟩

-- The default vector of size `n` has every cell set to the default element of `α`.
instance [Inhabited α] : Inhabited (Vector α n) where
  default := mkVector n default

/-- Get an element of a vector using a `Fin` index.
-/
@[inline] def get (v : Vector α n) (i : Fin n) : α :=
  -- Cast `i : Fin n` along `size_toArray` so that it indexes the underlying array.
  v.toArray[(i.cast v.size_toArray.symm).1]

/-- Get an element of a vector using a `USize` index and a proof that the index is within bounds. -/
@[inline] def uget (v : Vector α n) (i : USize) (h : i.toNat < n) : α :=
  -- Rewrite the bound `i.toNat < n` into a bound against `v.toArray.size`.
  v.toArray.uget i (v.size_toArray.symm ▸ h)

-- `v[i]` notation for vectors with a `Nat` index; the validity condition is `i < n`.
instance : GetElem (Vector α n) Nat α fun _ i => i < n where
  getElem x i h := get x ⟨i, h⟩

/--
Check if there is an element which satisfies `a == ·`.

Delegates to `Array.contains` on the underlying array.
-/
def contains [BEq α] (v : Vector α n) (a : α) : Bool := v.toArray.contains a

/--
`a ∈ v` is a predicate which asserts that `a` is in the vector `v`.

It wraps membership of `a` in the underlying array.
-/
structure Mem (as : Vector α n) (a : α) : Prop where
  val : a ∈ as.toArray

-- `a ∈ v` notation for vectors, defined via `Vector.Mem`.
instance : Membership α (Vector α n) where
  mem := Mem

/--
Get an element of a vector using a `Nat` index. Returns the given default value if the index is out
of bounds.
-/
@[inline] def getD (v : Vector α n) (i : Nat) (default : α) : α := v.toArray.getD i default

/-- The last element of a vector. Panics if the vector is empty. -/
@[inline] def back! [Inhabited α] (v : Vector α n) : α := v.toArray.back!

/-- The last element of a vector, or `none` if the vector is empty. -/
@[inline] def back? (v : Vector α n) : Option α := v.toArray.back?

/--
The last element of a non-empty vector.

Non-emptiness is expressed by the `NeZero n` instance, which is used to prove `n - 1 < n`.
-/
@[inline] def back [NeZero n] (v : Vector α n) : α :=
  v[n - 1]'(Nat.sub_one_lt (NeZero.ne n))

/--
The first element of a non-empty vector.

Non-emptiness is expressed by the `NeZero n` instance, which is used to prove `0 < n`.
-/
@[inline] def head [NeZero n] (v : Vector α n) := v[0]'(Nat.pos_of_neZero n)

/-- Push an element `x` to the end of a vector, producing a vector of size `n + 1`. -/
@[inline] def push (v : Vector α n) (x : α) : Vector α (n + 1) :=
  ⟨v.toArray.push x, by simp⟩

/--
Remove the last element of a vector.

The result has size `n - 1`; since subtraction on `Nat` is truncated, this is also well-typed when
`n = 0`.
-/
@[inline] def pop (v : Vector α n) : Vector α (n - 1) :=
  ⟨Array.pop v.toArray, by simp⟩

/--
Set an element in a vector using a `Nat` index, with a tactic provided proof that the index is in
bounds.

This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def set (v : Vector α n) (i : Nat) (x : α) (h : i < n := by get_elem_tactic): Vector α n :=
  -- `simp [*]` turns the hypothesis `h : i < n` into the bound required by `Array.set`.
  ⟨v.toArray.set i x (by simp [*]), by simp⟩

/--
Set an element in a vector using a `Nat` index. Returns the vector unchanged if the index is out of
bounds.

This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def setIfInBounds (v : Vector α n) (i : Nat) (x : α) : Vector α n :=
  ⟨v.toArray.setIfInBounds i x, by simp⟩

/--
Set an element in a vector using a `Nat` index. Panics if the index is out of bounds.

This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def set! (v : Vector α n) (i : Nat) (x : α) : Vector α n :=
  ⟨v.toArray.set! i x, by simp⟩

/--
Monadic left fold over the elements of a vector, starting from `b`.

Delegates to `Array.foldlM` on the underlying array.
-/
@[inline] def foldlM [Monad m] (f : β → α → m β) (b : β) (v : Vector α n) : m β :=
  v.toArray.foldlM f b

/--
Monadic right fold over the elements of a vector, starting from `b`.

Delegates to `Array.foldrM` on the underlying array.
-/
@[inline] def foldrM [Monad m] (f : α → β → m β) (b : β) (v : Vector α n) : m β :=
  v.toArray.foldrM f b

/--
Left fold over the elements of a vector, starting from `b`.

Delegates to `Array.foldl` on the underlying array.
-/
@[inline] def foldl (f : β → α → β) (b : β) (v : Vector α n) : β :=
  v.toArray.foldl f b

/--
Right fold over the elements of a vector, starting from `b`.

Delegates to `Array.foldr` on the underlying array.
-/
@[inline] def foldr (f : α → β → β) (b : β) (v : Vector α n) : β :=
  v.toArray.foldr f b

/-- Append two vectors. The result has size `n + m`. -/
@[inline] def append (v : Vector α n) (w : Vector α m) : Vector α (n + m) :=
  ⟨v.toArray ++ w.toArray, by simp⟩

-- `v ++ w` notation for vectors of possibly different sizes, defined via `Vector.append`.
instance : HAppend (Vector α n) (Vector α m) (Vector α (n + m)) where
  hAppend := append

/--
Creates a vector from another with a provably equal length.

The underlying array is reused as is; only the size index changes from `n` to `m`.
-/
@[inline] protected def cast (h : n = m) (v : Vector α n) : Vector α m :=
  ⟨v.toArray, by simp [*]⟩

/--
Extracts the slice of a vector from indices `start` to `stop` (exclusive). If `start ≥ stop`, the
result is empty. If `stop` is greater than the size of the vector, the size is used instead.

The size of the result is `min stop n - start`.
-/
@[inline] def extract (v : Vector α n) (start stop : Nat) : Vector α (min stop n - start) :=
  ⟨v.toArray.extract start stop, by simp⟩

/-- Maps elements of a vector using the function `f`.
-/
@[inline] def map (f : α → β) (v : Vector α n) : Vector β n :=
  ⟨v.toArray.map f, by simp⟩

/-- Maps elements of a vector using the function `f`, which also receives the index of the element. -/
@[inline] def mapIdx (f : Nat → α → β) (v : Vector α n) : Vector β n :=
  ⟨v.toArray.mapIdx f, by simp⟩

/--
Maps elements of a vector using the function `f`,
which also receives the index of the element, and the fact that the index is less than the size of
the vector.
-/
@[inline] def mapFinIdx (v : Vector α n) (f : (i : Nat) → α → (h : i < n) → β) : Vector β n :=
  -- `Array.mapFinIdx` supplies a bound against the underlying array's size; rewriting with
  -- `size_toArray` converts it into the bound `i < n` that `f` expects.
  ⟨v.toArray.mapFinIdx (fun i a h => f i a (by simpa [v.size_toArray] using h)), by simp⟩

/--
Flattens a vector of `m` vectors, each of size `n`, into a single vector of size `m * n`.

Implemented by flattening the array of underlying arrays.
-/
@[inline] def flatten (v : Vector (Vector α n) m) : Vector α (m * n) :=
  ⟨(v.toArray.map Vector.toArray).flatten,
    by rcases v; simp_all [Function.comp_def, Array.map_const']⟩

/--
Applies `f` to every element of `v` and concatenates the resulting vectors, each of size `m`, into
a single vector of size `n * m`.
-/
@[inline] def flatMap (v : Vector α n) (f : α → Vector β m) : Vector β (n * m) :=
  ⟨v.toArray.flatMap fun a => (f a).toArray, by simp [Array.map_const']⟩

/--
Pairs each element of a vector with its index.

Delegates to `Array.zipWithIndex` on the underlying array.
-/
@[inline] def zipWithIndex (v : Vector α n) : Vector (α × Nat) n :=
  ⟨v.toArray.zipWithIndex, by simp⟩

/-- Maps corresponding elements of two vectors of equal size using the function `f`. -/
@[inline] def zipWith (a : Vector α n) (b : Vector β n) (f : α → β → φ) : Vector φ n :=
  ⟨Array.zipWith a.toArray b.toArray f, by simp⟩

/-- The vector of length `n` whose `i`-th element is `f i`. -/
@[inline] def ofFn (f : Fin n → α) : Vector α n :=
  ⟨Array.ofFn f, by simp⟩

/--
Swap two elements of a vector using `Nat` indices, with tactic provided proofs that both indices
are in bounds.

This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def swap (v : Vector α n) (i j : Nat)
    (hi : i < n := by get_elem_tactic) (hj : j < n := by get_elem_tactic) : Vector α n :=
  -- Both bounds are restated against the underlying array's size before calling `Array.swap`.
  ⟨v.toArray.swap i j (by simpa using hi) (by simpa using hj), by simp⟩

/--
Swap two elements of a vector using `Nat` indices. Returns the vector unchanged if either index is
out of bounds.

This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def swapIfInBounds (v : Vector α n) (i j : Nat) : Vector α n :=
  ⟨v.toArray.swapIfInBounds i j, by simp⟩

/--
Swaps an element of a vector with a given value using a `Nat` index, with a tactic provided proof
that the index is in bounds. The original value is returned along with the updated vector.

This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def swapAt (v : Vector α n) (i : Nat) (x : α) (hi : i < n := by get_elem_tactic) :
    α × Vector α n :=
  -- `a` is the pair (old value, updated array) returned by `Array.swapAt`.
  let a := v.toArray.swapAt i x (by simpa using hi)
  -- `simp [a]` unfolds the local definition to prove that the updated array still has size `n`.
  ⟨a.fst, a.snd, by simp [a]⟩

/--
Swaps an element of a vector with a given value using a `Nat` index. Panics if the index is out of
bounds. The original value is returned along with the updated vector.

This will perform the update destructively provided that the vector has a reference count of 1.
-/
@[inline] def swapAt! (v : Vector α n) (i : Nat) (x : α) : α × Vector α n :=
  -- `a` is the pair (old value, updated array) returned by `Array.swapAt!`.
  let a := v.toArray.swapAt! i x
  ⟨a.fst, a.snd, by simp [a]⟩

/-- The vector `#v[0,1,2,...,n-1]`. -/
@[inline] def range (n : Nat) : Vector Nat n := ⟨Array.range n, by simp⟩

/--
Extract the first `m` elements of a vector. If `m` is greater than or equal to the size of the
vector then the vector is returned unchanged.

The size of the result is `min m n`.
-/
@[inline] def take (v : Vector α n) (m : Nat) : Vector α (min m n) :=
  ⟨v.toArray.take m, by simp⟩

/--
Deletes the first `m` elements of a vector. If `m` is greater than or equal to the size of the
vector then the empty vector is returned.

The size of the result is `n - m`.
-/
@[inline] def drop (v : Vector α n) (m : Nat) : Vector α (n - m) :=
  -- Implemented as the slice of the underlying array from index `m` up to its size.
  ⟨v.toArray.extract m v.size, by simp⟩

/--
Compares two vectors of the same size using a given boolean relation `r`. `isEqv v w r` returns
`true` if and only if `r v[i] w[i]` is true for all indices `i`.
-/
@[inline] def isEqv (v w : Vector α n) (r : α → α → Bool) : Bool :=
  -- Delegates to `Array.isEqvAux` on the underlying arrays with bound `n`; the two `by simp`
  -- arguments are the side conditions it requires (both follow from `size_toArray`).
  Array.isEqvAux v.toArray w.toArray (by simp) r n (by simp)

-- Boolean equality on vectors: elementwise comparison with `==` via `isEqv`.
instance [BEq α] : BEq (Vector α n) where
  beq a b := isEqv a b (· == ·)

/-- Reverse the elements of a vector.
-/
@[inline] def reverse (v : Vector α n) : Vector α n :=
  ⟨v.toArray.reverse, by simp⟩

/--
Delete an element of a vector using a `Nat` index and a tactic provided proof that the index is in
bounds. The result has size `n - 1`.
-/
@[inline] def eraseIdx (v : Vector α n) (i : Nat) (h : i < n := by get_elem_tactic) :
    Vector α (n-1) :=
  -- Rewrite the bound `i < n` into a bound against `v.toArray.size`.
  ⟨v.toArray.eraseIdx i (v.size_toArray.symm ▸ h), by simp [Array.size_eraseIdx]⟩

/-- Delete an element of a vector using a `Nat` index. Panics if the index is out of bounds. -/
@[inline] def eraseIdx! (v : Vector α n) (i : Nat) : Vector α (n-1) :=
  if _ : i < n then
    v.eraseIdx i
  else
    -- `panic!` needs an `Inhabited` instance for the result type; `v.pop` is a value of the
    -- required type `Vector α (n-1)` that is available without assuming `Inhabited α`.
    have : Inhabited (Vector α (n-1)) := ⟨v.pop⟩
    panic! "index out of bounds"

/-- Delete the first element of a vector. Returns the empty vector if the input vector is empty. -/
@[inline] def tail (v : Vector α n) : Vector α (n-1) :=
  if _ : 0 < n then
    v.eraseIdx 0
  else
    -- Here `n = 0`, so `n = n - 1` and the (empty) vector can simply be re-indexed.
    v.cast (by omega)

/--
Finds the first index of a given value in a vector using `==` for comparison. Returns `none` if
no element of the vector matches the given value.
-/
@[inline] def indexOf? [BEq α] (v : Vector α n) (x : α) : Option (Fin n) :=
  -- The index found in the underlying array is cast along `size_toArray` to a `Fin n`.
  (v.toArray.indexOf? x).map (Fin.cast v.size_toArray)

/--
Returns `true` when `v` is a prefix of the vector `w`.

The two vectors may have different sizes. Delegates to `Array.isPrefixOf` on the underlying arrays.
-/
@[inline] def isPrefixOf [BEq α] (v : Vector α m) (w : Vector α n) : Bool :=
  v.toArray.isPrefixOf w.toArray

/-- Returns `true` within the monad `m` if `p` returns `true` for any element of the vector. -/
@[inline] def anyM [Monad m] (p : α → m Bool) (v : Vector α n) : m Bool :=
  v.toArray.anyM p

/-- Returns `true` within the monad `m` if `p` returns `true` for all elements of the vector. -/
@[inline] def allM [Monad m] (p : α → m Bool) (v : Vector α n) : m Bool :=
  v.toArray.allM p

/-- Returns `true` if `p` returns `true` for any element of the vector. -/
@[inline] def any (v : Vector α n) (p : α → Bool) : Bool :=
  v.toArray.any p

/-- Returns `true` if `p` returns `true` for all elements of the vector. -/
@[inline] def all (v : Vector α n) (p : α → Bool) : Bool :=
  v.toArray.all p

/-!
### Lexicographic ordering -/

-- `v < w` on vectors is defined as `<` on the underlying arrays.
instance instLT [LT α] : LT (Vector α n) := ⟨fun v w => v.toArray < w.toArray⟩
-- `v ≤ w` on vectors is defined as `≤` on the underlying arrays. Note that this instance only
-- assumes `LT α`, not `LE α`.
instance instLE [LT α] : LE (Vector α n) := ⟨fun v w => v.toArray ≤ w.toArray⟩

/--
Strict lexicographic comparator for vectors of the same size.

`lex v w lt` is true if and only if there is an index `i` such that `lt v[i] w[i]`, and for all
`j < i`, `lt v[j] w[j]` is false and `v[j] == w[j]`.

In particular, `lex v w lt` is `false` when no such index exists, for example when `v` is pairwise
equivalent via `==` to `w`, or when `n = 0`.

If `lt` is not given, it defaults to `(· < ·)`.
-/
def lex [BEq α] (v w : Vector α n) (lt : α → α → Bool := by exact (· < ·)) : Bool := Id.run do
  -- Binding the membership proof as `h` makes the bound on `i` available for `v[i]` and `w[i]`.
  for h : i in [0 : n] do
    if lt v[i] w[i] then
      -- First position where the vectors differ in favor of `v`: `v` is smaller.
      return true
    else if v[i] != w[i] then
      -- First position where the vectors differ, but not in favor of `v`.
      return false
  -- Every position compared equal, so `v` is not strictly smaller than `w`.
  return false