From db1b110f7ef0edb410bfb7fc87aa5b01554659a8 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 6 Oct 2020 14:52:54 -0700 Subject: [PATCH] fix: use `let*` to avoid bad error messages in `do` notation cc @Kha --- src/Lean/Elab/Do.lean | 26 +++++++++++++++++++++++- tests/lean/doNotation1.lean | 5 +++++ tests/lean/doNotation1.lean.expected.out | 6 ++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index b27edd364a..3b8d25f7e3 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -772,7 +772,31 @@ binders ← ps.mapM fun ⟨id, useTypeOf⟩ => do { ctx ← read; let m := ctx.m; type ← `($m _); -`(let $(mkIdentFrom ref j):ident $binders:explicitBinder* : $type := $body; $k) +/- +We use `let*` instead of `let` for joinpoints to make sure `$k` is elaborated before `$body`. +By elaborating `$k` first, we "learn" more about `$body`'s type. +For example, consider the following example `do` expression +``` +def f (x : Nat) : IO Unit := do +if x > 0 then + IO.println "x is not zero" -- Error is here +IO.mkRef true +``` +it is expanded into +``` +def f (x : Nat) : IO Unit := do +let jp (u : Unit) : IO _ := + IO.mkRef true; +if x > 0 then + IO.println "not zero" + jp () +else + jp () +``` +If we use the regular `let` instead of `let*`, the joinpoint `jp` will be elaborated and its type will be inferred to be `Unit → IO (IO.Ref Bool)`. +Then, we get a typing error at `jp ()`. By using `let*`, we first elaborate `if x > 0 ...` and learn that `jp` has type `Unit → IO Unit`. +Then, we get the expected type mismatch error at `IO.mkRef true`. -/ +`(let* $(mkIdentFrom ref j):ident $binders:explicitBinder* : $type := $body; $k) def mkJoinPoint (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := do r ← mkJoinPointCore j ps body k; diff --git a/tests/lean/doNotation1.lean b/tests/lean/doNotation1.lean index 8e769c7a4d..698d5c65db 100644 --- a/tests/lean/doNotation1.lean +++ b/tests/lean/doNotation1.lean @@ -44,3 +44,8 @@ def f10 (x : Nat) : IO Unit := do IO.println x #print f10 -- we do not generate an unnecessary bind + +def f11 (x : Nat) : IO Unit := do +if x > 0 then + IO.println "x is not zero" +IO.mkRef true -- error here as expected diff --git a/tests/lean/doNotation1.lean.expected.out b/tests/lean/doNotation1.lean.expected.out index aaabf1c03d..be9ab12f52 100644 --- a/tests/lean/doNotation1.lean.expected.out +++ b/tests/lean/doNotation1.lean.expected.out @@ -21,3 +21,9 @@ doNotation1.lean:37:2: error: invalid 'do' element, it must be inside 'for' doNotation1.lean:40:0: error: must be last element in a 'do' sequence def f10 : Nat → IO Unit := fun (x : Nat) => IO.println x +doNotation1.lean:51:0: error: type mismatch + IO.mkRef true +has type + EIO IO.Error (IO.Ref Bool) +but is expected to have type + EIO IO.Error Unit