refactor: avoid double exception layer with AsyncList
This commit is contained in:
parent
d8ec900ae9
commit
d503fe6d13
3 changed files with 18 additions and 26 deletions
|
|
@ -35,29 +35,22 @@ def ofList : List α → AsyncList ε α :=
|
|||
|
||||
instance : Coe (List α) (AsyncList ε α) := ⟨ofList⟩
|
||||
|
||||
private def coeErr {β} [Coe Error ε] (t : Task $ Except Error $ Except ε β) : Task (Except ε β) :=
|
||||
t.map $ fun
|
||||
| Except.ok v => v
|
||||
| Except.error (e : Error) => Except.error (e : ε)
|
||||
|
||||
/-- A stateful step computation `f` is applied iteratively, forming an async
|
||||
stream. The stream ends once `f` returns `none` for the first time. The
|
||||
computation can throw IO exceptions, so to handle this `ε` must include
|
||||
`IO.Error`.
|
||||
stream. The stream ends once `f` returns `none` for the first time.
|
||||
|
||||
For cooperatively cancelling an ongoing computation, we recommend referencing
|
||||
a cancellation token in `f` and checking it when appropriate. -/
|
||||
partial def unfoldAsync [Coe Error ε] (f : StateT σ (ExceptT ε IO) $ Option α) (init : σ)
|
||||
: IO (AsyncList ε α) := do
|
||||
let rec step (s : σ) : ExceptT ε IO (AsyncList ε α) := do
|
||||
partial def unfoldAsync (f : StateT σ (EIO ε) $ Option α) (init : σ)
|
||||
: BaseIO (AsyncList ε α) := do
|
||||
let rec step (s : σ) : EIO ε (AsyncList ε α) := do
|
||||
let (aNext, sNext) ← f s
|
||||
match aNext with
|
||||
| none => return nil
|
||||
| some aNext => do
|
||||
let tNext ← coeErr <$> asTask (step sNext)
|
||||
let tNext ← EIO.asTask (step sNext)
|
||||
return cons aNext $ asyncTail tNext
|
||||
|
||||
let tInit ← coeErr <$> asTask (step init)
|
||||
let tInit ← EIO.asTask (step init)
|
||||
asyncTail tInit
|
||||
|
||||
/-- The computed, synchronous list. If an async tail was present, returns also
|
||||
|
|
@ -73,7 +66,7 @@ partial def getAll : AsyncList ε α → List α × Option ε
|
|||
| Except.error e => ⟨[], some e⟩
|
||||
|
||||
/-- Spawns a `Task` waiting on the prefix of elements for which `p` is true. -/
|
||||
partial def waitAll [Coe Error ε] (p : α → Bool := fun _ => true) : AsyncList ε α → BaseIO (Task (List α × Option ε))
|
||||
partial def waitAll (p : α → Bool := fun _ => true) : AsyncList ε α → BaseIO (Task (List α × Option ε))
|
||||
| cons hd tl => do
|
||||
if p hd then
|
||||
let t ← tl.waitAll p
|
||||
|
|
@ -82,26 +75,22 @@ partial def waitAll [Coe Error ε] (p : α → Bool := fun _ => true) : AsyncLis
|
|||
return Task.pure ⟨[hd], none⟩
|
||||
| nil => return Task.pure ⟨[], none⟩
|
||||
| asyncTail tl => do
|
||||
let t : Task (Except IO.Error (List α × Option ε)) ← BaseIO.bindTask tl fun
|
||||
| Except.ok tl => Task.map Except.ok <$> tl.waitAll p
|
||||
| Except.error e => return Task.pure <| Except.ok ⟨[], some e⟩
|
||||
t.map fun
|
||||
| Except.error e => ⟨[], some e⟩
|
||||
| Except.ok v => v
|
||||
BaseIO.bindTask tl fun
|
||||
| Except.ok tl => tl.waitAll p
|
||||
| Except.error e => Task.pure ⟨[], some e⟩
|
||||
|
||||
/-- Spawns a `Task` acting like `List.find?` but which will wait for tail evalution
|
||||
when necessary to traverse the list. If the tail terminates before a matching element
|
||||
is found, the task throws the terminating value. -/
|
||||
partial def waitFind? (p : α → Bool) [Coe Error ε] : AsyncList ε α → BaseIO (Task $ Except ε $ Option α)
|
||||
partial def waitFind? (p : α → Bool) : AsyncList ε α → BaseIO (Task $ Except ε $ Option α)
|
||||
| nil => return Task.pure <| Except.ok none
|
||||
| cons hd tl => do
|
||||
if p hd then return Task.pure <| Except.ok <| some hd
|
||||
else tl.waitFind? p
|
||||
| asyncTail tl => do
|
||||
let t ← BaseIO.bindTask tl fun
|
||||
| Except.ok tl => Task.map Except.ok <$> tl.waitFind? p
|
||||
| Except.error e => return Task.pure <| Except.ok <| Except.error e
|
||||
coeErr t
|
||||
BaseIO.bindTask tl fun
|
||||
| Except.ok tl => tl.waitFind? p
|
||||
| Except.error e => Task.pure <| Except.error e
|
||||
|
||||
/-- Extends the `finishedPrefix` as far as possible. If computation was ongoing
|
||||
and has finished, also returns the terminating value. -/
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ section Elab
|
|||
private def AsyncElabState.lastSnap (s : AsyncElabState) : Snapshot :=
|
||||
s.snaps.getD (s.snaps.size - 1) s.headerSnap
|
||||
|
||||
abbrev AsyncElabM := StateT AsyncElabState $ ExceptT ElabTaskError IO
|
||||
abbrev AsyncElabM := StateT AsyncElabState <| EIO ElabTaskError
|
||||
|
||||
-- Placed here instead of Lean.Server.Utils because of an import loop
|
||||
private def publishReferences (m : DocumentMeta) (s : AsyncElabState) (hOut : FS.Stream) : IO Unit := do
|
||||
|
|
|
|||
|
|
@ -24,6 +24,9 @@ inductive ElabTaskError where
|
|||
instance : Coe IO.Error ElabTaskError :=
|
||||
⟨ElabTaskError.ioError⟩
|
||||
|
||||
instance : MonadLift IO (EIO ElabTaskError) where
|
||||
monadLift act := act.toEIO (coe ·)
|
||||
|
||||
structure CancelToken where
|
||||
ref : IO.Ref Bool
|
||||
deriving Inhabited
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue