fix: use let* to avoid bad error messages in do notation

cc @Kha
This commit is contained in:
Leonardo de Moura 2020-10-06 14:52:54 -07:00
parent d124718b05
commit db1b110f7e
3 changed files with 36 additions and 1 deletions

View file

@ -772,7 +772,31 @@ binders ← ps.mapM fun ⟨id, useTypeOf⟩ => do {
ctx ← read;
let m := ctx.m;
type ← `($m _);
`(let $(mkIdentFrom ref j):ident $binders:explicitBinder* : $type := $body; $k)
/-
We use `let*` instead of `let` for joinpoints to make sure `$k` is elaborated before `$body`.
By elaborating `$k` first, we "learn" more about `$body`'s type.
For example, consider the following example `do` expression
```
def f (x : Nat) : IO Unit := do
if x > 0 then
IO.println "x is not zero" -- Error is here
IO.mkRef true
```
it is expanded into
```
def f (x : Nat) : IO Unit := do
let jp (u : Unit) : IO _ :=
IO.mkRef true;
if x > 0 then
IO.println "not zero"
jp ()
else
jp ()
```
If we use the regular `let` instead of `let*`, the joinpoint `jp` will be elaborated and its type will be inferred to be `Unit → IO (IO.Ref Bool)`.
Then, we get a typing error at `jp ()`. By using `let*`, we first elaborate `if x > 0 ...` and learn that `jp` has type `Unit → IO Unit`.
Then, we get the expected type mismatch error at `IO.mkRef true`. -/
`(let* $(mkIdentFrom ref j):ident $binders:explicitBinder* : $type := $body; $k)
def mkJoinPoint (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := do
r ← mkJoinPointCore j ps body k;

View file

@ -44,3 +44,8 @@ def f10 (x : Nat) : IO Unit := do
IO.println x
#print f10 -- we do not generate an unnecessary bind
def f11 (x : Nat) : IO Unit := do
if x > 0 then
IO.println "x is not zero"
IO.mkRef true -- error here as expected

View file

@ -21,3 +21,9 @@ doNotation1.lean:37:2: error: invalid 'do' element, it must be inside 'for'
doNotation1.lean:40:0: error: must be last element in a 'do' sequence
def f10 : Nat → IO Unit :=
fun (x : Nat) => IO.println x
doNotation1.lean:51:0: error: type mismatch
IO.mkRef true
has type
EIO IO.Error (IO.Ref Bool)
but is expected to have type
EIO IO.Error Unit