reefactor: improve async API

This commit is contained in:
tydeu 2021-08-09 05:04:38 -04:00
parent d0fbc93143
commit 4f75dd99d1
3 changed files with 163 additions and 26 deletions

125
Lake/Async.lean Normal file
View file

@ -0,0 +1,125 @@
/-
Copyright (c) 2021 Mac Malone. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mac Malone
-/
namespace Lake
-- # Async / Await
class Async (m : Type u → Type v) (n : outParam $ Type u → Type u) where
async : m α → m (n α)
export Async (async)
class Await (m : Type u → Type v) (n : outParam $ Type u → Type u) where
await : n α → m α
export Await (await)
class MonadAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) [Monad m] extends Async m n, Await m n where
mapAsync {α β : Type u} : (α → m β) → n α → m (n β) := fun f x => async (await x >>= f)
bindAsync {α β : Type u} : n α → (α → m (n β)) → m (n β) := fun x f => async (await x >>= f >>= await)
export MonadAsync (mapAsync bindAsync)
section
variable [Monad m] [MonadAsync m n]
-- ## List Utilities
/-- `MonadAsync` version of `IO.mapTasks` -/
def mapListAsync (f : List α → m β) (ts : List (n α)) : m (n β) :=
go ts []
where
go
| [], as => async (f as.reverse)
| t::ts, as => bindAsync t fun a => go ts (a :: as)
def afterListAsync (task : m (n β)) : (ts : List (n α)) → m (n β)
| [] => task
| t::ts => bindAsync t fun _ => afterListAsync task ts
def andThenListAsync (task : (n α)) : (ts : List (n α)) → m (n α)
| [] => task
| t::ts => bindAsync task fun _ => andThenListAsync t ts
def seqListAsync [Pure n] : (ts : List (n PUnit)) → m (n PUnit)
| [] => pure (pure ())
| t::ts => andThenListAsync t ts
-- ## Array Utilities
-- These Follow the pattern of Array iterators established in the Lean core.
@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
let rec @[specialize] fold (i : USize) (stop : USize) (as : Array α) : m (n β) := do
if i == stop then
async (f as)
else
bindAsync (ts.uget i lcProof) fun a => fold (i+1) stop (as.push a)
if start < stop then
if stop ≤ ts.size then
fold (USize.ofNat start) (USize.ofNat stop) (Array.mkEmpty (start - stop))
else
async (f #[])
else
async (f #[])
@[implementedBy mapArrayAsyncUnsafe]
def mapArrayAsync (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
let fold (stop : Nat) (h : stop ≤ ts.size) :=
let rec loop (i : Nat) (j : Nat) (as : Array α) : m (n β) := do
if hlt : j < stop then
match i with
| Nat.zero => async (f as)
| Nat.succ i' =>
let t := ts.get ⟨j, Nat.ltOfLtOfLe hlt h⟩
bindAsync t fun a => loop i' (j+1) (as.push a)
else
async (f as)
loop (stop - start) start (Array.mkEmpty (stop - start))
if h : stop ≤ ts.size then
fold stop h
else
fold ts.size (Nat.leRefl _)
@[inline] unsafe def afterArrayAsyncUnsafe (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) := do
if i == stop then
task
else
bindAsync (ts.uget i lcProof) fun _ => fold (i+1) stop
if start < stop then
if stop ≤ ts.size then
fold (USize.ofNat start) (USize.ofNat stop)
else
task
else
task
@[implementedBy afterArrayAsyncUnsafe]
def afterArrayAsync (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
let fold (stop : Nat) (h : stop ≤ ts.size) :=
let rec loop (i : Nat) (j : Nat) : m (n β) := do
if hlt : j < stop then
match i with
| Nat.zero => task
| Nat.succ i' =>
let t := ts.get ⟨j, Nat.ltOfLtOfLe hlt h⟩
bindAsync t fun a => loop i' (j+1)
else
task
loop (stop - start) start
if h : stop ≤ ts.size then
fold stop h
else
fold ts.size (Nat.leRefl _)
def seqArrayAsync [Pure n] (ts : Array (n PUnit)) : m (n PUnit) :=
if h : 0 < ts.size then
afterArrayAsync (ts.get ⟨ts.size - 1, Nat.subLt h (by decide)⟩) ts.pop
else
pure (pure ())
end

View file

@ -91,17 +91,23 @@ namespace ActiveBuildTarget
-- ### Combinators
def after (target : ActiveBuildTarget t a) (act : IO PUnit) : IO BuildTask :=
afterTask target.task act
target.task.andThen act
def afterList (targets : List (ActiveBuildTarget t a)) (act : IO PUnit) : IO BuildTask :=
afterTaskList (targets.map (·.task)) act
def afterArray (targets : Array (ActiveBuildTarget t a)) (act : IO PUnit) : IO BuildTask :=
afterTaskArray (targets.map (·.task)) act
instance : HAndThen (ActiveBuildTarget t a) (IO PUnit) (IO BuildTask) :=
⟨ActiveBuildTarget.after⟩
instance : HAndThen (List (ActiveBuildTarget t a)) (IO PUnit) (IO BuildTask) :=
⟨ActiveBuildTarget.afterList⟩
instance : HAndThen (Array (ActiveBuildTarget t a)) (IO PUnit) (IO BuildTask) :=
⟨ActiveBuildTarget.afterArray⟩
end ActiveBuildTarget
--------------------------------------------------------------------------------
@ -134,11 +140,17 @@ def materializeAsync [Async m n] (self : Target t m a) : m (n PUnit) :=
def materialize (self : Target t m a) : m PUnit :=
self.task
def materializeList [Monad m] [MonadAsync m n] (targets : List (Target t m a)) : m PUnit := do
(← targets.mapM (·.materializeAsync)).forM await
def materializeListAsync [Monad m] [Pure n] [MonadAsync m n] (targets : List (Target t m a)) : m (n PUnit) := do
seqListAsync (← targets.mapM (·.materializeAsync))
def materializeArray [Monad m] [MonadAsync m n] (targets : Array (Target t m a)) : m PUnit := do
(← targets.mapM (·.materializeAsync)).forM await
def materializeList [Monad m] [Pure n] [MonadAsync m n] (targets : List (Target t m a)) : m PUnit := do
await <| ← materializeListAsync targets
def materializeArrayAsync [Monad m] [Pure n] [MonadAsync m n] (targets : Array (Target t m a)) : m (n PUnit) := do
seqArrayAsync (← targets.mapM (·.materializeAsync))
def materializeArray [Monad m] [Pure n] [MonadAsync m n] (targets : Array (Target t m a)) : m PUnit := do
await <| ← materializeArrayAsync targets
end Target

View file

@ -3,6 +3,7 @@ Copyright (c) 2021 Mac Malone. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mac Malone
-/
import Lake.Async
namespace Lake
@ -28,29 +29,25 @@ def spawn (act : IO α) (prio := Task.Priority.dedicated) : IO (IOTask α) :=
def await (self : IOTask α) : IO α := do
IO.ofExcept (← IO.wait self)
def collectAll (tasks : List (IOTask α)) : IO (IOTask (List α)) :=
IO.asTask (tasks.mapM (·.await))
def mapAsync (f : α → IO β) (self : IOTask α) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
IO.mapTask (fun x => do let x ← IO.ofExcept x; f x) self prio
end IOTask
def andThen (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
IO.mapTask (fun x => IO.ofExcept x *> act) self prio
-- # Async / Await
instance : HAndThen (IOTask α) (IO β) (IO (IOTask β)) := ⟨andThen⟩
class Async (m : Type u → Type v) (n : outParam $ Type u → Type u) where
async : m α → m (n α)
export Async (async)
class Await (m : outParam $ Type u → Type v) (n : Type u → Type u) where
await : n α → m α
export Await (await)
class MonadAsync (m : Type u → Type v) (n : outParam $ Type u → Type u)
extends Async m n, Await m n
def bindAsync (self : IOTask α) (f : α → IO (IOTask β)) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
IO.bindTask self (fun x => do let x ← IO.ofExcept x; f x) prio
instance : MonadAsync IO IOTask where
async := IOTask.spawn
await := IOTask.await
mapAsync := IOTask.mapAsync
bindAsync := IOTask.bindAsync
end IOTask
-- # Build Task
@ -71,14 +68,17 @@ end BuildTask
instance : Inhabited BuildTask := ⟨BuildTask.nop⟩
def afterTask (task : BuildTask) (act : IO PUnit) : IO BuildTask :=
IO.mapTask (fun x => IO.ofExcept x *> act) task
def afterTaskList (tasks : List BuildTask) (act : IO PUnit) : IO BuildTask :=
IO.mapTasks (fun xs => xs.forM IO.ofExcept *> act) tasks
afterListAsync (async act) tasks
def afterTaskArray (tasks : Array BuildTask) (act : IO PUnit) : IO BuildTask :=
afterArrayAsync (async act) tasks
instance : HAndThen BuildTask (IO PUnit) (IO BuildTask) :=
⟨afterTask⟩
IOTask.andThen
instance : HAndThen (List BuildTask) (IO PUnit) (IO BuildTask) :=
⟨afterTaskList⟩
instance : HAndThen (Array BuildTask) (IO PUnit) (IO BuildTask) :=
⟨afterTaskArray⟩