Some examples (e.g., `hello`) may now segfault. See https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Problems.20when.20updating.20Lean/near/249599195.
131 lines
4.4 KiB
Text
131 lines
4.4 KiB
Text
/-
|
||
Copyright (c) 2021 Mac Malone. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Mac Malone
|
||
-/
|
||
|
||
namespace Lake
|
||
|
||
-- # Async / Await
|
||
|
||
/--
`Async m n` provides `async`, which turns an `m α` computation into an `m`
computation producing a handle of type `n α`. The handle type `n` is an
`outParam`, so it is determined by `m` during instance resolution. How (or
whether) the work actually runs concurrently is up to the instance.
-/
class Async (m : Type u → Type v) (n : outParam (Type u → Type u)) where
  /-- Wrap the computation `x : m α` so that it yields a handle `n α`. -/
  async {α : Type u} : m α → m (n α)

export Async (async)
|
||
|
||
/--
`Await m n` provides `await`, which turns a handle `n α` back into an `m α`
computation yielding its value. As with `Async`, `n` is an `outParam`
determined by `m`.
-/
class Await (m : Type u → Type v) (n : outParam (Type u → Type u)) where
  /-- Recover the value behind the handle `x : n α` as an `m α` computation. -/
  await {α : Type u} : n α → m α

export Await (await)
|
||
|
||
/--
Applicative-style combinators that sequence an existing handle `n α` with a
further `m` computation and yield (in `m`) a new handle. No default
implementations are supplied; each field documents a suggested definition in
terms of `async` and `await`.
-/
class ApplicativeAsync (m : Type u → Type v) (n : outParam (Type u → Type u)) extends Async m n, Await m n where
  /--
  Sequence the handle `x` with `y`, keeping the result of `x`.
  Suggested definition (not a default): `fun x y => async (await x <* y)`.
  -/
  seqLeftAsync {α β : Type u} : n α → m β → m (n α)
  /--
  Sequence the handle `x` with `y`, keeping the result of `y`.
  Suggested definition (not a default): `fun x y => async (await x *> y)`.
  -/
  seqRightAsync {α β : Type u} : n α → m β → m (n β)

export ApplicativeAsync (seqLeftAsync seqRightAsync)
|
||
|
||
/--
Monad-style combinators that continue from the result of a handle `n α` and
yield (in `m`) a handle for the continuation's result. No default
implementations are supplied; each field documents a suggested definition in
terms of `async` and `await`.
-/
class MonadAsync (m : Type u → Type v) (n : outParam (Type u → Type u)) extends ApplicativeAsync m n where
  /--
  Continue the handle `x` with `f`, giving a handle for `f`'s result.
  Suggested definition (not a default): `fun f x => async (await x >>= f)`.
  -/
  mapAsync {α β : Type u} : (α → m β) → n α → m (n β)
  /--
  Continue the handle `x` with `f`, which itself produces a handle; the
  result handle stands for that inner handle's value.
  Suggested definition (not a default):
  `fun x f => async (await x >>= f >>= await)`.
  -/
  bindAsync {α β : Type u} : n α → (α → m (n β)) → m (n β)

export MonadAsync (mapAsync bindAsync)
|
||
|
||
section
|
||
variable [MonadAsync m n]
|
||
|
||
-- ## List Utilities
|
||
|
||
/--
`MonadAsync` version of `IO.mapTasks`: `bindAsync` through every handle in
`ts` from left to right, collect their results in the original order, and
finish with `async (f results)`.
-/
def mapListAsync (f : List α → m β) (ts : List (n α)) : m (n β) :=
  go ts []
where
  /-- Walk the remaining handles; results are accumulated in reverse. -/
  go : List (n α) → List α → m (n β)
    | [], acc => async (f acc.reverse)
    | h :: rest, acc => bindAsync h fun x => go rest (x :: acc)
|
||
|
||
/--
Chain `bindAsync` through every handle in `ts` from left to right, discarding
their results, and end with `task`. For an empty list, `task` is returned
unchanged.
-/
def afterListAsync (task : m (n β)) (ts : List (n α)) : m (n β) :=
  match ts with
  | [] => task
  | t :: rest => bindAsync t fun _ => afterListAsync task rest
|
||
|
||
/--
Chain the handles `task :: ts`: `bindAsync` through each handle except the
last (discarding results), then return the last handle lifted with `pure`.
With `ts = []`, this is just `pure task`.
-/
def andThenListAsync [Pure m] (task : n α) (ts : List (n α)) : m (n α) :=
  match ts with
  | [] => pure task
  | next :: rest => bindAsync task fun _ => andThenListAsync next rest
|
||
|
||
/--
Sequence a list of unit handles with `andThenListAsync`, starting from the
first element. An empty list yields `pure (pure ())`.
-/
def seqListAsync [Pure m] [Pure n] (ts : List (n PUnit)) : m (n PUnit) :=
  match ts with
  | [] => pure (pure ())
  | first :: rest => andThenListAsync first rest
|
||
|
||
-- ## Array Utilities
|
||
-- These follow the pattern of the `Array` iterators in Lean core (a safe reference definition plus an `unsafe` `implementedBy` variant).
|
||
|
||
/--
Compiled implementation of `mapArrayAsync` (installed on that definition via
`@[implementedBy]`). It `bindAsync`s through the handles at indices
`[start, stop)` in order, pushing each result into an array, and finishes
with `async (f results)`.

`stop` is clamped to `ts.size`, so this agrees with the reference definition
when `stop` is out of range. An empty range yields `async (f #[])`.
-/
@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
  let rec @[specialize] fold (i : USize) (stop : USize) (as : Array α) : m (n β) :=
    if i == stop then
      async (f as)
    else
      -- Unchecked access: `i` stays within `[start, stop)` and `stop ≤ ts.size`.
      bindAsync (ts.uget i lcProof) fun a => fold (i+1) stop (as.push a)
  -- Clamp like the reference `mapArrayAsync`, which processes
  -- `[start, ts.size)` when `stop` exceeds the array size.
  let stop := if stop ≤ ts.size then stop else ts.size
  if start < stop then
    -- Exactly `stop - start` results will be pushed (was `start - stop`, i.e. 0).
    fold (USize.ofNat start) (USize.ofNat stop) (Array.mkEmpty (stop - start))
  else
    async (f #[])
|
||
|
||
/--
Array version of `mapListAsync`: `bindAsync` through the handles at indices
`[start, stop)` in order, collect their results into an array, and finish
with `async (f results)`. A `stop` past the end of `ts` is clamped to
`ts.size`.

This is the reference definition; compiled code uses `mapArrayAsyncUnsafe`.
-/
@[implementedBy mapArrayAsyncUnsafe]
def mapArrayAsync (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
  let fold (stop : Nat) (h : stop ≤ ts.size) :=
    -- `fuel` bounds the number of remaining steps so `loop` recurses structurally.
    let rec loop (fuel : Nat) (idx : Nat) (acc : Array α) : m (n β) :=
      if hlt : idx < stop then
        match fuel with
        | 0 => async (f acc)
        | fuel' + 1 =>
          bindAsync (ts.get ⟨idx, Nat.lt_of_lt_of_le hlt h⟩) fun a => loop fuel' (idx + 1) (acc.push a)
      else
        async (f acc)
    loop (stop - start) start (Array.mkEmpty (stop - start))
  if h : stop ≤ ts.size then
    fold stop h
  else
    fold ts.size (Nat.le_refl _)
|
||
|
||
/--
Compiled implementation of `afterArrayAsync` (installed on that definition
via `@[implementedBy]`). It `bindAsync`s through the handles at indices
`[start, stop)` in order, discarding their results, and ends with `task`.

`stop` is clamped to `ts.size`, so this agrees with the reference definition
when `stop` is out of range. An empty range returns `task` directly.
-/
@[inline] unsafe def afterArrayAsyncUnsafe (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
  let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) :=
    if i == stop then
      task
    else
      -- Unchecked access: `i` stays within `[start, stop)` and `stop ≤ ts.size`.
      bindAsync (ts.uget i lcProof) fun _ => fold (i+1) stop
  -- Clamp like the reference `afterArrayAsync`, which processes
  -- `[start, ts.size)` when `stop` exceeds the array size.
  let stop := if stop ≤ ts.size then stop else ts.size
  if start < stop then
    fold (USize.ofNat start) (USize.ofNat stop)
  else
    task
|
||
|
||
/--
Array version of `afterListAsync`: `bindAsync` through the handles at indices
`[start, stop)` in order, discarding their results, and end with `task`. A
`stop` past the end of `ts` is clamped to `ts.size`; an empty range returns
`task` directly.

This is the reference definition; compiled code uses `afterArrayAsyncUnsafe`.
-/
@[implementedBy afterArrayAsyncUnsafe]
def afterArrayAsync (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
  let fold (stop : Nat) (h : stop ≤ ts.size) :=
    -- `fuel` bounds the number of remaining steps so `loop` recurses structurally.
    let rec loop (fuel : Nat) (idx : Nat) : m (n β) :=
      if hlt : idx < stop then
        match fuel with
        | 0 => task
        | fuel' + 1 =>
          bindAsync (ts.get ⟨idx, Nat.lt_of_lt_of_le hlt h⟩) fun _ => loop fuel' (idx + 1)
      else
        task
    loop (stop - start) start
  if h : stop ≤ ts.size then
    fold stop h
  else
    fold ts.size (Nat.le_refl _)
|
||
|
||
/--
Array version of `seqListAsync`: `bindAsync` through every handle except the
last (in order, discarding results), then return the last handle lifted into
`m` with `pure`. An empty array yields `pure (pure ())`.
-/
def seqArrayAsync [Pure m] [Pure n] (ts : Array (n PUnit)) : m (n PUnit) :=
  if h : 0 < ts.size then
    -- `afterArrayAsync` takes an `m (n _)` action as its `task`, so the bare
    -- last handle must be wrapped with `pure`, as `andThenListAsync` does.
    afterArrayAsync (pure (ts.get ⟨ts.size - 1, Nat.sub_lt h (by decide)⟩)) ts.pop
  else
    pure (pure ())
|
||
|
||
end
|