From db914052cebab54c5dd9bfbeb1b8689a2cc39665 Mon Sep 17 00:00:00 2001
From: Leonardo de Moura
Date: Mon, 23 Mar 2020 12:26:09 -0700
Subject: [PATCH] fix: IO.getLine

---
Reviewer notes (text between `---` and the diffstat is ignored by `git am`):

* `lean_get_line` is dropped. `lean_io_prim_handle_get_line` now reads
  64-byte chunks with `fgets` into a stack buffer and accumulates them in
  a `std::string`.
* The EOF and "getLine failed" error objects are allocated once in
  `initialize_io` and marked persistent.
* Revisions to the submitted diff: the leftover `std::cout` debug print on
  the EOF path is removed; the header comment states `IO String`, which is
  what the function returns; hitting end-of-file before any byte is read
  reports the EOF error instead of "getLine failed".
* Indentation and blank context lines were reconstructed; use
  `git apply --ignore-whitespace` if a hunk is rejected.

 src/runtime/io.cpp                | 90 ++++++++++++++++----------------
 tests/lean/run/getline_crash.lean | 69 +++++++++++++++++++++++++
 2 files changed, 115 insertions(+), 44 deletions(-)
 create mode 100644 tests/lean/run/getline_crash.lean

diff --git a/src/runtime/io.cpp b/src/runtime/io.cpp
index ab0cfc0012..21f8989fe8 100644
--- a/src/runtime/io.cpp
+++ b/src/runtime/io.cpp
@@ -35,10 +35,8 @@ Author: Leonardo de Moura
 
 namespace lean {
 
-extern "C" lean_object* lean_string_append(lean_object*, lean_object*);
-
 extern "C" lean_object* lean_mk_io_error_already_exists(uint32_t, lean_object*);
-extern "C" lean_object* lean_mk_io_error_eof();
+extern "C" lean_object* lean_mk_io_error_eof(lean_object*);
 extern "C" lean_object* lean_mk_io_error_hardware_fault(uint32_t, lean_object*);
 extern "C" lean_object* lean_mk_io_error_illegal_operation(uint32_t, lean_object*);
 extern "C" lean_object* lean_mk_io_error_inappropriate_type(uint32_t, lean_object*);
@@ -296,11 +294,13 @@ extern "C" obj_res lean_io_prim_handle_write_byte(b_obj_arg h, uint8 c, obj_arg
     }
 }
 
+static object * g_io_error_eof = nullptr;
+
 /* Handle.read : (@& Handle) → USize → IO ByteArray */
 extern "C" obj_res lean_io_prim_handle_read(b_obj_arg h, usize nbytes, obj_arg /* w */) {
     FILE * fp = io_get_handle(h);
     if (feof(fp)) {
-        return set_io_error(lean_mk_io_error_eof());
+        return set_io_error(g_io_error_eof);
     }
     obj_res res = lean_alloc_sarray(1, 0, nbytes);
     usize n = std::fread(lean_sarray_cptr(res), 1, nbytes, fp);
@@ -328,47 +328,45 @@ extern "C" obj_res lean_io_prim_handle_write(b_obj_arg h, b_obj_arg buf, obj_arg
     }
 }
 
-obj_res lean_get_line(FILE * fp) {
-    const int buf_sz = 64;
-    lean_string_object * buf_str = lean_to_string(lean_alloc_string(0, buf_sz, 0));
-    lean_object * res_str = lean_alloc_string(1, buf_sz, 0);
-    lean_to_string(res_str)->m_data[0] = '\0';
-    char * out = nullptr;
-    do {
-        out = std::fgets(buf_str->m_data, buf_sz, fp);
-        if (out != nullptr) {
-            buf_str->m_size = strlen(buf_str->m_data);
-            buf_str->m_length = buf_str->m_size;
-            buf_str->m_size++;
-            res_str = lean_string_append(res_str, reinterpret_cast<lean_object *>(buf_str));
-        }
-    } while (out != nullptr && buf_str->m_size == buf_sz);
-    dec_ref(reinterpret_cast<lean_object *>(buf_str));
-    lean_to_string(res_str)->m_length = utf8_strlen(lean_to_string(res_str)->m_data);
-    if (out == nullptr && !feof(fp)) {
-        dec_ref(res_str);
-        return nullptr;
-    } else {
-        return res_str;
-    }
-}
+static object * g_io_error_getline = nullptr;
 
-/* Handle.getLine : (@& Handle) → IO Unit */
-/* The line returned by `lean_io_prim_handle_get_line` */
-/* is truncated at the first '\0' character and the */
-/* rest of the line is discarded. */
+/*
+  Handle.getLine : (@& Handle) → IO String
+  Reads up to and including the next '\n', or up to end-of-file.
+  The line returned by `lean_io_prim_handle_get_line`
+  is truncated at the first '\0' character and the
+  rest of the line is discarded. */
 extern "C" obj_res lean_io_prim_handle_get_line(b_obj_arg h, obj_arg /* w */) {
     FILE * fp = io_get_handle(h);
     if (feof(fp)) {
-        return set_io_error(lean_mk_io_error_eof());
+        return set_io_error(g_io_error_eof);
     }
-    object * res = lean_get_line(fp);
-    if (res != nullptr) {
-        return set_io_result(res);
-    } else if (feof(fp)) {
-        return set_io_result(lean_mk_string(""));
-    } else {
-        return set_io_error(decode_io_error(errno, nullptr));
+    const int buf_sz = 64;
+    char buf_str[buf_sz];
+    std::string result;
+    bool first = true;
+    while (true) {
+        char * out = std::fgets(buf_str, buf_sz, fp);
+        if (out != nullptr) {
+            /* A chunk that does not fill the buffer, or that ends in '\n', completes the line. */
+            if (strlen(buf_str) < buf_sz-1 || buf_str[buf_sz-2] == '\n') {
+                if (first) {
+                    return set_io_result(mk_string(out));
+                } else {
+                    result.append(out);
+                    return set_io_result(mk_string(result));
+                }
+            }
+            result.append(out);
+        } else if (!first && std::feof(fp)) {
+            return set_io_result(mk_string(result));
+        } else if (std::feof(fp)) {
+            /* End-of-file was hit before any byte was read: report EOF, as the check above does. */
+            return set_io_error(g_io_error_eof);
+        } else {
+            return set_io_error(g_io_error_getline);
+        }
+        first = false;
     }
 }
 
@@ -481,9 +479,9 @@ extern "C" obj_res lean_io_app_dir(obj_arg) {
     char buf2[PATH_MAX];
     uint32_t bufsize = PATH_MAX;
     if (_NSGetExecutablePath(buf1, &bufsize) != 0)
-        return set_io_error(mk_string("failed to locate application"));
+        return set_io_error("failed to locate application");
     if (!realpath(buf1, buf2))
-        return set_io_error(mk_string("failed to resolve symbolic links when locating application"));
+        return set_io_error("failed to resolve symbolic links when locating application");
     return set_io_result(mk_string(buf2));
 #else
     // Linux version
@@ -493,7 +491,7 @@ extern "C" obj_res lean_io_app_dir(obj_arg) {
     pid_t pid = getpid();
     snprintf(path, PATH_MAX, "/proc/%d/exe", pid);
     if (readlink(path, dest, PATH_MAX) == -1) {
-        return set_io_error(mk_string("failed to locate application"));
+        return set_io_error("failed to locate application");
     } else {
         return set_io_result(mk_string(dest));
     }
@@ -611,8 +609,12 @@ extern "C" obj_res lean_io_ref_ptr_eq(b_obj_arg ref1, b_obj_arg ref2, obj_arg) {
 }
 
 void initialize_io() {
-    g_io_error_nullptr_read = mk_string("null reference read");
+    g_io_error_nullptr_read = mk_io_user_error(mk_string("null reference read"));
     mark_persistent(g_io_error_nullptr_read);
+    g_io_error_getline = mk_io_user_error(mk_string("getLine failed"));
+    mark_persistent(g_io_error_getline);
+    g_io_error_eof = lean_mk_io_error_eof(lean_box(0));
+    mark_persistent(g_io_error_eof);
     g_io_handle_external_class = lean_register_external_class(io_handle_finalizer, io_handle_foreach);
 }
 
diff --git a/tests/lean/run/getline_crash.lean b/tests/lean/run/getline_crash.lean
new file mode 100644
index 0000000000..62445b3ed2
--- /dev/null
+++ b/tests/lean/run/getline_crash.lean
@@ -0,0 +1,69 @@
+def tstGetLine (str : String) : IO Unit := do
+let path := "tmp_file";
+IO.FS.withFile path IO.FS.Mode.write $ λ (h : IO.FS.Handle) =>
+  h.putStrLn str;
+IO.FS.withFile path IO.FS.Mode.read $ λ (h : IO.FS.Handle) => do
+  str' ← h.getLine;
+  IO.println str.length;
+  IO.println str'.length;
+  IO.print str';
+  unless (str'.length == str.length + 1) $
+    throw (IO.userError ("unexpected length: " ++ toString str'.trim.length));
+  unless (str'.trim == str) $
+    throw (IO.userError ("unexpected result: " ++ str'))
+
+def tstGetLine2 (str1 str2 : String) : IO Unit := do
+let path := "tmp_file";
+IO.FS.withFile path IO.FS.Mode.write $ λ (h : IO.FS.Handle) => do {
+  h.putStrLn str1; h.putStr str2
+};
+IO.FS.withFile path IO.FS.Mode.read $ λ (h : IO.FS.Handle) => do
+  str1' ← h.getLine;
+  str2' ← h.getLine;
+  unless (str1'.length == str1.length + 1) $
+    throw (IO.userError ("unexpected length: " ++ toString str1'.trim.length));
+  unless (str1'.trim == str1) $
+    throw (IO.userError ("unexpected result: " ++ str1'));
+  unless (str2'.length == str2.length) $
+    throw (IO.userError ("unexpected length: " ++ toString str2'.trim.length));
+  unless (str2'.trim == str2) $
+    throw (IO.userError ("unexpected result: " ++ str2'))
+
+def tstGetLineFailure1 (str : String) : IO Unit := do
+let path := "tmp_file";
+IO.FS.withFile path IO.FS.Mode.write $ λ (h : IO.FS.Handle) => do {
+  h.putStrLn str
+};
+IO.FS.withFile path IO.FS.Mode.read $ λ (h : IO.FS.Handle) => do
+  whenM (catch (do (h.getLine >>= IO.println); (h.getLine >>= IO.println); (h.getLine >>= IO.println); pure true) (fun _ => pure false)) $
+    throw $ IO.userError "unexpected success"
+
+def tstGetLineFailure2 (str : String) : IO Unit := do
+let path := "tmp_file";
+IO.FS.withFile path IO.FS.Mode.write $ λ (h : IO.FS.Handle) => do {
+  h.putStrLn str
+};
+IO.FS.withFile path IO.FS.Mode.read $ λ (h : IO.FS.Handle) => do
+  whenM (catch (do (h.getLine >>= IO.println); (h.getLine >>= IO.println); pure true) (fun _ => pure false)) $
+    throw $ IO.userError "unexpected success"
+
+#eval tstGetLineFailure1 "abc"
+#eval tstGetLineFailure2 "abc"
+
+#eval tstGetLine ("".pushn 'α' 40)
+#eval tstGetLine "a"
+#eval tstGetLine ""
+#eval tstGetLine ("".pushn 'α' 20)
+#eval tstGetLine ("".pushn 'a' 61)
+#eval tstGetLine ("".pushn 'α' 61)
+#eval tstGetLine ("".pushn 'a' 62)
+#eval tstGetLine ("".pushn 'a' 63)
+#eval tstGetLine ("".pushn 'a' 64)
+#eval tstGetLine ("".pushn 'a' 65)
+#eval tstGetLine ("".pushn 'a' 66)
+#eval tstGetLine ("".pushn 'a' 128)
+
+#eval tstGetLine2 ("".pushn 'α' 20) ("".pushn 'β' 20)
+#eval tstGetLine2 ("".pushn 'α' 40) ("".pushn 'β' 40)
+#eval tstGetLine2 ("".pushn 'a' 61) ("".pushn 'b' 61)
+#eval tstGetLine2 ("".pushn 'a' 61) ("".pushn 'b' 62)