From efa119bc94eebc4a66e4a86a95be70fe233e0dca Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Thu, 27 Aug 2020 16:11:17 +0200 Subject: [PATCH] feat: make std streams `Stream`s --- src/Init/System/IO.lean | 104 +++++++++++++++++++++++------ src/runtime/io.cpp | 50 +++++++------- tests/lean/stdio.lean | 4 +- tests/lean/stdio.lean.expected.out | 2 - 4 files changed, 110 insertions(+), 50 deletions(-) diff --git a/src/Init/System/IO.lean b/src/Init/System/IO.lean index f9929c1776..bf8df1c22b 100644 --- a/src/Init/System/IO.lean +++ b/src/Init/System/IO.lean @@ -6,7 +6,7 @@ Authors: Luke Nelson, Jared Roesch, Leonardo de Moura, Sebastian Ullrich prelude import Init.Control.EState import Init.Control.Reader -import Init.Data.String.Basic +import Init.Data.String import Init.Data.ByteArray import Init.System.IOError import Init.System.FilePath @@ -105,22 +105,33 @@ inductive FS.Mode constant FS.Handle : Type := Unit +/-- + A pure-Lean abstraction of POSIX streams. We use `Stream`s for the standard streams stdin/stdout/stderr so we can + capture output of `#eval` commands into memory. -/ +structure FS.Stream := +(isEof : IO Bool) +(flush : IO Unit) +(read : forall (bytes : USize), IO ByteArray) +(write : ByteArray → IO Unit) +(getLine : IO String) +(putStr : String → IO Unit) + namespace Prim open FS @[extern "lean_get_stdin"] -constant getStdin : IO FS.Handle := arbitrary _ +constant getStdin : IO FS.Stream := arbitrary _ @[extern "lean_get_stdout"] -constant getStdout : IO FS.Handle := arbitrary _ +constant getStdout : IO FS.Stream := arbitrary _ @[extern "lean_get_stderr"] -constant getStderr : IO FS.Handle := arbitrary _ +constant getStderr : IO FS.Stream := arbitrary _ @[extern "lean_get_set_stdin"] -constant setStdin : FS.Handle → IO FS.Handle := arbitrary _ +constant setStdin : FS.Stream → IO FS.Stream := arbitrary _ @[extern "lean_get_set_stdout"] -constant setStdout : FS.Handle → IO FS.Handle := arbitrary _ +constant setStdout : FS.Stream → IO FS.Stream := arbitrary _ @[extern "lean_get_set_stderr"] -constant setStderr : FS.Handle → IO FS.Handle := arbitrary _ +constant setStderr : FS.Stream → IO FS.Stream := arbitrary _ /-- Run action with `stdin` closed and `stdout+stderr` captured into a `String`. -/ @[extern "lean_with_isolated_streams"] @@ -150,7 +161,6 @@ constant Handle.mk (s : @& String) (mode : @& String) : IO Handle := arbitrary _ constant Handle.isEof (h : @& Handle) : IO Bool := arbitrary _ @[extern "lean_io_prim_handle_flush"] constant Handle.flush (h : @& Handle) : IO Unit := arbitrary _ --- TODO: replace `String` with byte buffer @[extern "lean_io_prim_handle_read"] constant Handle.read (h : @& Handle) (bytes : USize) : IO ByteArray := arbitrary _ @[extern "lean_io_prim_handle_write"] @@ -235,53 +245,60 @@ def lines (fname : String) : m (Array String) := do h ← Handle.mk fname Mode.read false; linesAux h #[] +namespace Stream + +def putStrLn (strm : FS.Stream) (s : String) : m Unit := +liftIO (strm.putStr s) *> liftIO (strm.putStr "\n") + +end Stream + end FS section variables {m : Type → Type} [Monad m] [MonadIO m] -def getStdin : m FS.Handle := +def getStdin : m FS.Stream := liftIO Prim.getStdin -def getStdout : m FS.Handle := +def getStdout : m FS.Stream := liftIO Prim.getStdout -def getStderr : m FS.Handle := +def getStderr : m FS.Stream := liftIO Prim.getStderr -/-- Replaces the stdin handle and returns its previous value. -/ -def setStdin : FS.Handle → m FS.Handle := +/-- Replaces the stdin stream of the current thread and returns its previous value. -/ +def setStdin : FS.Stream → m FS.Stream := liftIO ∘ Prim.setStdin -/-- Replaces the stdout handle and returns its previous value. -/ -def setStdout : FS.Handle → m FS.Handle := +/-- Replaces the stdout stream of the current thread and returns its previous value. -/ +def setStdout : FS.Stream → m FS.Stream := liftIO ∘ Prim.setStdout -/-- Replaces the stderr handle and returns its previous value. -/ -def setStderr : FS.Handle → m FS.Handle := +/-- Replaces the stderr stream of the current thread and returns its previous value. -/ +def setStderr : FS.Stream → m FS.Stream := liftIO ∘ Prim.setStderr -def withStdin [MonadFinally m] {α} (h : FS.Handle) (x : m α) : m α := do +def withStdin [MonadFinally m] {α} (h : FS.Stream) (x : m α) : m α := do prev ← setStdin h; finally x (discard $ setStdin prev) -def withStdout [MonadFinally m] {α} (h : FS.Handle) (x : m α) : m α := do +def withStdout [MonadFinally m] {α} (h : FS.Stream) (x : m α) : m α := do prev ← setStdout h; finally x (discard $ setStdout prev) -def withStderr [MonadFinally m] {α} (h : FS.Handle) (x : m α) : m α := do +def withStderr [MonadFinally m] {α} (h : FS.Stream) (x : m α) : m α := do prev ← setStderr h; finally x (discard $ setStderr prev) def print {α} [HasToString α] (s : α) : m Unit := do out ← getStdout; -out.putStr $ toString s +liftIO $ out.putStr $ toString s def println {α} [HasToString α] (s : α) : m Unit := print s *> print "\n" def eprint {α} [HasToString α] (s : α) : m Unit := do out ← getStderr; -out.putStr $ toString s +liftIO $ out.putStr $ toString s def eprintln {α} [HasToString α] (s : α) : m Unit := eprint s *> eprint "\n" @@ -358,6 +375,49 @@ instance st2eio {ε} : MonadLift (ST IO.RealWorld) (EIO ε) := def mkRef {α : Type} {m : Type → Type} [Monad m] [MonadLiftT (ST IO.RealWorld) m] (a : α) : m (IO.Ref α) := ST.mkRef a +namespace FS +namespace Stream + +@[export lean_stream_of_handle] +def ofHandle (h : Handle) : Stream := { + isEof := Prim.Handle.isEof h, + flush := Prim.Handle.flush h, + read := Prim.Handle.read h, + write := Prim.Handle.write h, + getLine := Prim.Handle.getLine h, + putStr := Prim.Handle.putStr h, +} + +structure Buffer := +(data : ByteArray := ByteArray.empty) +(pos : Nat := 0) + +def ofBuffer (r : Ref Buffer) : Stream := { + isEof := do b ← r.get; pure $ b.pos >= b.data.size, + flush := pure (), + read := fun n => do + b ← r.get; + let data := b.data.extract b.pos (b.pos + n.toNat); + r.set { b with pos := b.pos + data.size }; + pure data, + write := fun data => r.modify fun b => + -- set `exact` to `false` so that repeatedly writing to the stream does not impose quadratic run time + { b with data := data.copySlice 0 b.data b.pos data.size false, pos := b.pos + data.size }, + getLine := do + b ← r.get; + let pos := match b.data.findIdxAux (fun u => u == 0 || u = '\n'.toNat.toUInt8) b.pos with + -- include '\n', but not '\0' + | some pos => if b.data.get! pos == 0 then pos else pos + 1 + | none => b.data.size; + r.set { b with pos := pos }; + pure $ String.fromUTF8Unchecked $ b.data.extract b.pos pos, + putStr := fun s => + let data := s.toUTF8; + r.modify fun b => + { b with data := data.copySlice 0 b.data b.pos data.size false, pos := b.pos + data.size }, +} +end Stream +end FS end IO universe u diff --git a/src/runtime/io.cpp b/src/runtime/io.cpp index 3686476ad3..a1c5ac855d 100644 --- a/src/runtime/io.cpp +++ b/src/runtime/io.cpp @@ -125,53 +125,55 @@ static lean_object * io_wrap_handle(FILE *hfile) { return lean_alloc_external(g_io_handle_external_class, hfile); } -static object * g_handle_stdin = nullptr; -static object * g_handle_stdout = nullptr; -static object * g_handle_stderr = nullptr; -MK_THREAD_LOCAL_GET(object *, get_handle_current_stdin, g_handle_stdin); -MK_THREAD_LOCAL_GET(object *, get_handle_current_stdout, g_handle_stdout); -MK_THREAD_LOCAL_GET(object *, get_handle_current_stderr, g_handle_stderr); +extern "C" obj_res lean_stream_of_handle(obj_arg h); -/* getStdin : IO FS.Handle */ +static object * g_stream_stdin = nullptr; +static object * g_stream_stdout = nullptr; +static object * g_stream_stderr = nullptr; +MK_THREAD_LOCAL_GET(object *, get_stream_current_stdin, g_stream_stdin); +MK_THREAD_LOCAL_GET(object *, get_stream_current_stdout, g_stream_stdout); +MK_THREAD_LOCAL_GET(object *, get_stream_current_stderr, g_stream_stderr); + +/* getStdin : IO FS.Stream */ extern "C" obj_res lean_get_stdin(obj_arg /* w */) { - object * r = get_handle_current_stdin(); + object * r = get_stream_current_stdin(); inc_ref(r); return set_io_result(r); } -/* getStdout : IO FS.Handle */ +/* getStdout : IO FS.Stream */ extern "C" obj_res lean_get_stdout(obj_arg /* w */) { - object * r = get_handle_current_stdout(); + object * r = get_stream_current_stdout(); inc_ref(r); return set_io_result(r); } -/* getStderr : IO FS.Handle */ +/* getStderr : IO FS.Stream */ extern "C" obj_res lean_get_stderr(obj_arg /* w */) { - object * r = get_handle_current_stderr(); + object * r = get_stream_current_stderr(); inc_ref(r); return set_io_result(r); } -/* setStdin : FS.Handle -> IO FS.Handle */ +/* setStdin : FS.Stream -> IO FS.Stream */ extern "C" obj_res lean_get_set_stdin(obj_arg h, obj_arg /* w */) { - object * & x = get_handle_current_stdin(); + object * & x = get_stream_current_stdin(); object * r = x; x = h; return set_io_result(r); } -/* setStdout : FS.Handle -> IO FS.Handle */ +/* setStdout : FS.Stream -> IO FS.Stream */ extern "C" obj_res lean_get_set_stdout(obj_arg h, obj_arg /* w */) { - object * & x = get_handle_current_stdout(); + object * & x = get_stream_current_stdout(); object * r = x; x = h; return set_io_result(r); } -/* setStderr : FS.Handle -> IO FS.Handle */ +/* setStderr : FS.Stream -> IO FS.Stream */ extern "C" obj_res lean_get_set_stderr(obj_arg h, obj_arg /* w */) { - object * & x = get_handle_current_stderr(); + object * & x = get_stream_current_stderr(); object * r = x; x = h; return set_io_result(r); @@ -719,12 +721,12 @@ void initialize_io() { _setmode(_fileno(stderr), _O_BINARY); _setmode(_fileno(stdin), _O_BINARY); #endif - g_handle_stdout = io_wrap_handle(stdout); - mark_persistent(g_handle_stdout); - g_handle_stderr = io_wrap_handle(stderr); - mark_persistent(g_handle_stderr); - g_handle_stdin = io_wrap_handle(stdin); - mark_persistent(g_handle_stdin); + g_stream_stdout = lean_stream_of_handle(io_wrap_handle(stdout)); + mark_persistent(g_stream_stdout); + g_stream_stderr = lean_stream_of_handle(io_wrap_handle(stderr)); + mark_persistent(g_stream_stderr); + g_stream_stdin = lean_stream_of_handle(io_wrap_handle(stdin)); + mark_persistent(g_stream_stdin); } void finalize_io() { diff --git a/tests/lean/stdio.lean b/tests/lean/stdio.lean index a2aa6b519a..13e995b9b7 100644 --- a/tests/lean/stdio.lean +++ b/tests/lean/stdio.lean @@ -13,11 +13,11 @@ open IO def test : IO Unit := do FS.withFile "stdout1.txt" IO.FS.Mode.write $ fun h₁ => do { h₂ ← FS.Handle.mk "stdout2.txt" IO.FS.Mode.write; - withStdout h₁ $ do + withStdout (Stream.ofHandle h₁) $ do println "line 1"; catch ( do - withStdout h₂ $ println "line 2"; + withStdout (Stream.ofHandle h₂) $ println "line 2"; throw $ IO.userError "my error" ) ( fun e => println e ); println "line 3" }; diff --git a/tests/lean/stdio.lean.expected.out b/tests/lean/stdio.lean.expected.out index 7334bb1b47..a9b798dee7 100644 --- a/tests/lean/stdio.lean.expected.out +++ b/tests/lean/stdio.lean.expected.out @@ -1,5 +1,4 @@ print stdout - print stderr line 4 @@ -10,4 +9,3 @@ line 3 > stdout2.txt line 2 -