refactor: use Lists for Array reference implementation

Motivation: better reduction in the kernel.

cc @Kha
This commit is contained in:
Leonardo de Moura 2020-11-17 17:05:53 -08:00
parent 160a263049
commit 7e533b4650
11 changed files with 148 additions and 83 deletions

View file

@ -17,12 +17,11 @@ variables {α : Type u}
@[extern "lean_mk_array"]
def mkArray {α : Type u} (n : Nat) (v : α) : Array α := {
sz := n,
data := fun _ => v
data := List.replicate n v
}
theorem sizeMkArrayEq (n : Nat) (v : α) : (mkArray n v).size = n :=
rfl
List.lengthReplicateEq ..
instance : EmptyCollection (Array α) := ⟨Array.empty⟩
instance : Inhabited (Array α) := ⟨Array.empty⟩
@ -55,15 +54,14 @@ abbrev getLit {α : Type u} {n : Nat} (a : Array α) (i : Nat) (h₁ : a.size =
@[extern "lean_array_fset"]
def set (a : Array α) (i : @& Fin a.size) (v : α) : Array α := {
sz := a.sz,
data := fun j => if h : i = j then v else a.data j
data := a.data.set i v
}
theorem szFSetEq (a : Array α) (i : Fin a.size) (v : α) : (set a i v).size = a.size :=
rfl
theorem sizeSetEq (a : Array α) (i : Fin a.size) (v : α) : (set a i v).size = a.size :=
List.lengthSetEq ..
theorem szPushEq (a : Array α) (v : α) : (push a v).size = a.size + 1 :=
rfl
theorem sizePushEq (a : Array α) (v : α) : (push a v).size = a.size + 1 :=
List.lengthConcatEq ..
/- Low-level version of `fset` which is as fast as a C array fset.
`Fin` values are represented as tag pointers in the Lean runtime. Thus,
@ -81,8 +79,8 @@ def set! (a : Array α) (i : @& Nat) (v : α) : Array α :=
def swap (a : Array α) (i j : @& Fin a.size) : Array α :=
let v₁ := a.get i
let v₂ := a.get j
let a := a.set i v₂
a.set j v₁
let a' := a.set i v₂
a'.set (sizeSetEq a i v₂ ▸ j) v₁
@[extern "lean_array_swap"]
def swap! (a : Array α) (i j : @& Nat) : Array α :=
@ -106,8 +104,7 @@ def swapAt! {α : Type} (a : Array α) (i : Nat) (v : α) : α × Array α :=
@[extern "lean_array_pop"]
def pop (a : Array α) : Array α := {
sz := Nat.pred a.size,
data := fun ⟨j, h⟩ => a.get ⟨j, Nat.ltOfLtOfLe h (Nat.predLe _)⟩
data := a.data.dropLast
}
def shrink {α : Type u} (a : Array α) (n : Nat) : Array α :=
@ -121,9 +118,9 @@ def modifyM {m : Type u → Type v} [Monad m] [Inhabited α] (a : Array α) (i :
if h : i < a.size then
let idx : Fin a.size := ⟨i, h⟩
let v := a.get idx
let a := a.set idx (arbitrary α)
let a' := a.set idx (arbitrary α)
let v ← f v
pure $ (a.set idx v)
pure $ (a'.set (sizeSetEq a .. ▸ idx) v)
else
pure a
@ -463,6 +460,7 @@ partial def reverse {α : Type u} (as : Array α) : Array α :=
else
(true, r)
@[export lean_array_to_list]
def toList {α : Type u} (as : Array α) : List α :=
as.foldr List.cons []
@ -489,7 +487,7 @@ def List.redLength {α : Type u} : List α → Nat
| [] => 0
| _::as => as.redLength + 1
@[inline, matchPattern]
@[inline, matchPattern, export lean_list_to_array]
def List.toArray {α : Type u} (as : List α) : Array α :=
as.toArrayAux (Array.mkEmpty as.redLength)
@ -568,17 +566,42 @@ theorem ext {α} (a b : Array α)
(h₁ : a.size = b.size)
(h₂ : (i : Nat) → (hi₁ : i < a.size) → (hi₂ : i < b.size) → a.get ⟨i, hi₁⟩ = b.get ⟨i, hi₂⟩)
: a = b := by
match a, b, h₁, h₂ with
| ⟨sz₁, f₁⟩, ⟨sz₂, f₂⟩, h₁, h₂ =>
subst h₁
have f₁ = f₂ from funext fun ⟨i, hi₁⟩ => h₂ i hi₁ hi₁
subst this
exact rfl
let rec extAux (a b : List α)
(h₁ : a.length = b.length)
(h₂ : (i : Nat) → (hi₁ : i < a.length) → (hi₂ : i < b.length) → a.get i hi₁ = b.get i hi₂)
: a = b := by
induction a generalizing b
| nil =>
cases b
| nil => rfl
| cons b bs => rw [List.lengthConsEq] at h₁; injection h₁
| cons a as ih =>
cases b
| nil => rw [List.lengthConsEq] at h₁; injection h₁
| cons b bs =>
have hz₁ : 0 < (a::as).length by rw [List.lengthConsEq]; apply Nat.zeroLtSucc
have hz₂ : 0 < (b::bs).length by rw [List.lengthConsEq]; apply Nat.zeroLtSucc
have headEq : a = b from h₂ 0 hz₁ hz₂
have h₁' : as.length = bs.length by rw [List.lengthConsEq, List.lengthConsEq] at h₁; injection h₁; assumption
have h₂' : (i : Nat) → (hi₁ : i < as.length) → (hi₂ : i < bs.length) → as.get i hi₁ = bs.get i hi₂ by
intro i hi₁ hi₂
have hi₁' : i+1 < (a::as).length by rw [List.lengthConsEq]; apply Nat.succLtSucc; assumption
have hi₂' : i+1 < (b::bs).length by rw [List.lengthConsEq]; apply Nat.succLtSucc; assumption
have (a::as).get (i+1) hi₁' = (b::bs).get (i+1) hi₂' from h₂ (i+1) hi₁' hi₂'
apply this
have tailEq : as = bs from ih bs h₁' h₂'
rw [headEq, tailEq]
rfl
cases a; cases b
apply congrArg
apply extAux
assumption
assumption
theorem extLit {α : Type u} {n : Nat}
(a b : Array α)
(hsz₁ : a.size = n) (hsz₂ : b.size = n)
(h : ∀ (i : Nat) (hi : i < n), a.getLit i hsz₁ hi = b.getLit i hsz₂ hi) : a = b :=
(h : (i : Nat) → (hi : i < n) → a.getLit i hsz₁ hi = b.getLit i hsz₂ hi) : a = b :=
Array.ext a b (hsz₁.trans hsz₂.symm) fun i hi₁ hi₂ => h i (hsz₁ ▸ hi₁)
end Array
@ -612,26 +635,28 @@ def feraseIdx {α} (a : Array α) (i : Fin a.size) : Array α :=
def eraseIdx {α} (a : Array α) (i : Nat) : Array α :=
if i < a.size then eraseIdxAux (i+1) a else a
theorem szFSwapEq {α} (a : Array α) (i j : Fin a.size) : (a.swap i j).size = a.size :=
theorem sizeSwapEq {α} (a : Array α) (i j : Fin a.size) : (a.swap i j).size = a.size := by
show ((a.set i (a.get j)).set (sizeSetEq a i _ ▸ j) (a.get i)).size = a.size
rw [sizeSetEq, sizeSetEq]
rfl
theorem szPopEq {α} (a : Array α) : a.pop.size = a.size - 1 :=
rfl
theorem sizePopEq {α} (a : Array α) : a.pop.size = a.size - 1 :=
List.lengthDropLast ..
section
/- Instance for justifying `partial` declaration.
We should be able to delete it as soon as we restore support for well-founded recursion. -/
instance eraseIdxSzAuxInstance {α} (a : Array α) : Inhabited { r : Array α // r.size = a.size - 1 } :=
⟨⟨a.pop, szPopEq a⟩⟩
⟨⟨a.pop, sizePopEq a⟩⟩
partial def eraseIdxSzAux {α} (a : Array α) : ∀ (i : Nat) (r : Array α), r.size = a.size → { r : Array α // r.size = a.size - 1 }
| i, r, heq =>
if h : i < r.size then
let idx : Fin r.size := ⟨i, h⟩;
let idx1 : Fin r.size := ⟨i - 1, Nat.ltOfLeOfLt (Nat.predLe i) h⟩;
eraseIdxSzAux a (i+1) (r.swap idx idx1) ((szFSwapEq r idx idx1).trans heq)
eraseIdxSzAux a (i+1) (r.swap idx idx1) ((sizeSwapEq r idx idx1).trans heq)
else
⟨r.pop, (szPopEq r).trans (heq ▸ rfl)⟩
⟨r.pop, (sizePopEq r).trans (heq ▸ rfl)⟩
end
def eraseIdx' {α} (a : Array α) (i : Fin a.size) : { r : Array α // r.size = a.size - 1 } :=

View file

@ -69,13 +69,6 @@ def eraseIdx : List α → Nat → List α
| a::as, 0 => as
| a::as, n+1 => a :: eraseIdx as n
def lengthAux : List α → Nat → Nat
| [], n => n
| a::as, n => lengthAux as (n+1)
def length (as : List α) : Nat :=
lengthAux as 0
def isEmpty : List α → Bool
| [] => true
| _ :: _ => false
@ -243,9 +236,6 @@ def unzip : List (α × β) → List α × List β
| [] => ([], [])
| (a, b) :: t => match unzip t with | (al, bl) => (a::al, b::bl)
def replicate (n : Nat) (a : α) : List α :=
n.repeat (fun xs => a :: xs) []
def rangeAux : Nat → List Nat → List Nat
| 0, ns => ns
| n+1, ns => rangeAux n (n::ns)
@ -335,4 +325,54 @@ protected def beq [BEq α] : List α → List α → Bool
instance [BEq α] : BEq (List α) := ⟨List.beq⟩
def replicate {α : Type u} (n : Nat) (a : α) : List α :=
let rec loop : Nat → List α → List α
| 0, as => as
| n+1, as => loop n (a::as)
loop n []
def dropLast {α} : List α → List α
| [] => []
| [a] => []
| a::as => a :: dropLast as
def lengthReplicateEq {α} (n : Nat) (a : α) : (replicate n a).length = n :=
let rec aux (n : Nat) (as : List α) : (replicate.loop a n as).length = n + as.length := by
induction n generalizing as
| zero => rw [Nat.zeroAdd]; rfl
| succ n ih =>
show length (replicate.loop a n (a::as)) = Nat.succ n + length as
rw [ih, lengthConsEq, Nat.addSucc, Nat.succAdd]
rfl
aux n []
def lengthConcatEq {α} (as : List α) (a : α) : (concat as a).length = as.length + 1 := by
induction as
| nil => rfl
| cons x xs ih =>
show length (x :: concat xs a) = length (x :: xs) + 1
rw [lengthConsEq, lengthConsEq, ih]
rfl
def lengthSetEq {α} (as : List α) (i : Nat) (a : α) : (as.set i a).length = as.length := by
induction as generalizing i
| nil => rfl
| cons x xs ih =>
cases i
| zero => rfl
| succ i =>
show length (x :: set xs i a) = length (x :: xs)
rw [lengthConsEq, lengthConsEq, ih]
rfl
def lengthDropLast {α} (as : List α) : as.dropLast.length = as.length - 1 := by
match as with
| [] => rfl
| [a] => rfl
| a::b::as =>
have ih := lengthDropLast (b::as)
show (a :: dropLast (b::as)).length = (a::b::as).length - 1
rw [lengthConsEq, ih, lengthConsEq, lengthConsEq, lengthConsEq]
rfl
end List

View file

@ -875,6 +875,31 @@ def List.foldl {α β} (f : α → β → α) : (init : α) → List β → α
| a, nil => a
| a, cons b l => foldl f (f a b) l
def List.lengthAux {α : Type u} : List α → Nat → Nat
| nil, n => n
| cons a as, n => lengthAux as (Nat.succ n)
def List.length {α : Type u} (as : List α) : Nat :=
lengthAux as 0
theorem List.lengthConsEq {α} (a : α) (as : List α) : Eq (cons a as).length as.length.succ :=
let rec aux (a : α) (as : List α) : (n : Nat) → Eq ((cons a as).lengthAux n) (as.lengthAux n).succ :=
match as with
| nil => fun _ => rfl
| cons a as => fun n => aux a as n.succ
aux a as 0
def List.concat {α : Type u} : List αα → List α
| nil, b => cons b nil
| cons a as, b => cons a (concat as b)
def List.get {α : Type u} : (as : List α) → (i : Nat) → Less i as.length → α
| nil, i, h => absurd h (Nat.notLtZero _)
| cons a as, 0, h => a
| cons a as, Nat.succ i, h =>
have Less i.succ as.length.succ from lengthConsEq .. ▸ h
get as i (Nat.leOfSuccLeSucc this)
structure String :=
(data : List Char)
@ -933,18 +958,15 @@ The Compiler has special support for arrays.
They are implemented using dynamic arrays: https://en.wikipedia.org/wiki/Dynamic_array
-/
structure Array (α : Type u) :=
(sz : Nat)
(data : Fin sz → α)
(data : List α)
attribute [extern "lean_array_mk"] Array.mk
attribute [extern "lean_array_data"] Array.data
attribute [extern "lean_array_sz"] Array.sz
attribute [extern "lean_array_to_list"] Array.data
attribute [extern "lean_list_to_array"] Array.mk
/- The parameter `c` is the initial capacity -/
@[extern "lean_mk_empty_array_with_capacity"]
def Array.mkEmpty {α : Type u} (c : @& Nat) : Array α := {
sz := 0,
data := fun ⟨x, h⟩ => absurd h (Nat.notLtZero x)
data := List.nil
}
def Array.empty {α : Type u} : Array α :=
@ -952,11 +974,11 @@ def Array.empty {α : Type u} : Array α :=
@[reducible, extern "lean_array_get_size"]
def Array.size {α : Type u} (a : @& Array α) : Nat :=
a.sz
a.data.length
@[extern "lean_array_fget"]
def Array.get {α : Type u} (a : @& Array α) (i : @& Fin a.size) : α :=
a.data i
a.data.get i.val i.isLt
/- "Comfortable" version of `fget`. It performs a bound check at runtime. -/
@[extern "lean_array_get"]
@ -968,8 +990,7 @@ def Array.getOp {α : Type u} [Inhabited α] (self : Array α) (idx : Nat) : α
@[extern "lean_array_push"]
def Array.push {α : Type u} (a : Array α) (v : α) : Array α := {
sz := Nat.succ a.sz,
data := fun ⟨j, h₁⟩ => dite (Eq j a.sz) (fun _ => v) (fun h₂ => a.data ⟨j, Nat.ltOfLeOfNe (Nat.leOfLtSucc h₁) h₂⟩)
data := List.concat a.data v
}
class Bind (m : Type u → Type v) :=

View file

@ -12,7 +12,7 @@ def HashMapBucket (α : Type u) (β : Type v) :=
def HashMapBucket.update {α : Type u} {β : Type v} (data : HashMapBucket α β) (i : USize) (d : AssocList α β) (h : i.toNat < data.val.size) : HashMapBucket α β :=
⟨ data.val.uset i d h,
by rw [Array.szFSetEq]; exact data.property ⟩
by rw [Array.sizeSetEq]; exact data.property ⟩
structure HashMapImp (α : Type u) (β : Type v) :=
(size : Nat)

View file

@ -11,7 +11,7 @@ def HashSetBucket (α : Type u) :=
def HashSetBucket.update {α : Type u} (data : HashSetBucket α) (i : USize) (d : List α) (h : i.toNat < data.val.size) : HashSetBucket α :=
⟨ data.val.uset i d h,
by rw [Array.szFSetEq]; exact data.property ⟩
by rw [Array.sizeSetEq]; exact data.property ⟩
structure HashSetImp (α : Type u) :=
(size : Nat)

View file

@ -146,15 +146,15 @@ partial def popLeaf : PersistentArrayNode α → Option (Array α) × Array (Per
if h : cs.size ≠ 0 then
let idx : Fin cs.size := ⟨cs.size - 1, Nat.predLt h⟩
let last := cs.get idx
let cs := cs.set idx (arbitrary _)
let cs' := cs.set idx (arbitrary _)
match popLeaf last with
| (none, _) => (none, emptyArray)
| (some l, newLast) =>
if newLast.size == 0 then
let cs := cs.pop
if cs.isEmpty then (some l, emptyArray) else (some l, cs)
let cs' := cs'.pop
if cs'.isEmpty then (some l, emptyArray) else (some l, cs')
else
(some l, cs.set idx (node newLast))
(some l, cs'.set (Array.sizeSetEq cs idx _ ▸ idx) (node newLast))
else
(none, emptyArray)
| leaf vs => (some vs, emptyArray)

View file

@ -66,11 +66,11 @@ abbrev EntriesNode (α β) := { n : Node α β // IsEntriesNode n }
private theorem setSizeEq {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (i : Fin ks.size) (j : Fin vs.size) (k : α) (v : β)
: (ks.set i k).size = (vs.set j v).size := by
rw [Array.szFSetEq, Array.szFSetEq vs, h]
rw [Array.sizeSetEq, Array.sizeSetEq vs, h]
rfl
private theorem pushSizeEq {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (k : α) (v : β) : (ks.push k).size = (vs.push v).size := by
rw [Array.szPushEq, Array.szPushEq, h]
rw [Array.sizePushEq, Array.sizePushEq, h]
rfl
partial def insertAtCollisionNodeAux [BEq α] : CollisionNode α β → Nat → α → β → CollisionNode α β

View file

@ -1,5 +1,3 @@
def foo (a : Array Nat) : Array Nat :=
let a := a.push 0
let a := a.push 1
@ -10,12 +8,12 @@ a
def main : IO UInt32 := do
let a : Array Nat := Array.empty
IO.println (toString a)
IO.println (toString a.sz)
IO.println (toString a.size)
let a := foo a
IO.println (toString a)
let a := a.map (fun a => a + 10)
IO.println (toString a)
IO.println (toString a.sz)
IO.println (toString a.size)
let a1 := a.pop
let a2 := a.push 100
IO.println (toString a1)

View file

@ -1,5 +1,3 @@
def f1 (x : Nat × Nat) : Nat :=
match x with
| { fst := x, snd := y } => x - y
@ -46,7 +44,7 @@ match x with
def Vector (α : Type) (n : Nat) := { a : Array α // a.size = n }
def mkVec {α : Type} (n : Nat) (a : α) : Vector α n :=
⟨mkArray n a, rfl
⟨mkArray n a, Array.sizeMkArrayEq ..
structure S :=
(n : Nat)

View file

@ -1,7 +1,6 @@
#check @Array.mk
def v : Array Nat := @Array.mk Nat 10 (fun ⟨i, _⟩ => i)
def v : Array Nat := Array.mk [1, 2, 3, 4]
def w : Array Nat :=
(mkArray 9 1).push 3

View file

@ -1,4 +1,3 @@
universes u v
theorem eqLitOfSize0 {α : Type u} (a : Array α) (hsz : a.size = 0) : a = #[] :=
@ -63,18 +62,3 @@ theorem matchArrayLit.eq3 {α : Type u} (C : Array α → Sort v)
(a₁ a₂ a₃ : α)
: matchArrayLit C #[a₁, a₂, a₃] h₁ h₂ h₃ h₄ = h₃ a₁ a₂ a₃ :=
rfl
theorem matchArrayLit.eq4 {α : Type u} (C : Array α → Sort v)
(h₁ : Unit → C #[])
(h₂ : ∀ a₁, C #[a₁])
(h₃ : ∀ a₁ a₂ a₃, C #[a₁, a₂, a₃])
(h₄ : ∀ a, C a)
(a : Array α)
(n₁ : a.size ≠ 0) (n₂ : a.size ≠ 1) (n₃ : a.size ≠ 3)
: matchArrayLit C a h₁ h₂ h₃ h₄ = h₄ a :=
match a, n₁, n₂, n₃ with
| ⟨0, _⟩, n₁, _, _ => absurd rfl n₁
| ⟨1, _⟩, _, n₂, _ => absurd rfl n₂
| ⟨2, _⟩, _, _, _ => rfl
| ⟨3, _⟩, _, _, n₃ => absurd rfl n₃
| ⟨n+4, _⟩, _, _, _ => rfl