fix: expand doIf notation before lifting nested methods

This commit is contained in:
Leonardo de Moura 2020-10-19 11:32:06 -07:00
parent c05f73577a
commit 437f4670ed
2 changed files with 39 additions and 7 deletions

View file

@ -613,14 +613,15 @@ mkDoSeq #[doElem]
>> many (group (" else " >> " if ") >> optIdent >> termParser >> " then " >> doSeq)
>> optional (" else " >> doSeq)
```
Given a `doIf`, return an equivalente `doIf` that has no `else if`s and the `else` is not none. -/
private def expandDoIf (doIf : Syntax) : MacroM Syntax := do
let ref := doIf
If the given syntax is a `doIf`, return an equivalente `doIf` that has no `else if`s and the `else` is not none. -/
private def expandDoIf? (stx : Syntax) : MacroM (Option Syntax) := do
if stx.getKind != `Lean.Parser.Term.doIf then pure none else
let doIf := stx
let ref := stx
let doElseIfs := doIf[5].getArgs
let doElse := doIf[6]
if doElseIfs.isEmpty && !doElse.isNone then
pure doIf
pure none
else
let doElse ←
if doElse.isNone then
@ -641,7 +642,7 @@ else
mkSingletonDoSeq $ mkNode `Lean.Parser.Term.doIf doIfArgs])
doElse
let doIf := doIf.setArg 6 doElse
pure $ doIf.setArg 5 mkNullNode -- remove else-ifs
pure $ some $ doIf.setArg 5 mkNullNode -- remove else-ifs
structure DoIfView :=
(ref : Syntax)
@ -650,8 +651,8 @@ structure DoIfView :=
(thenBranch : Syntax)
(elseBranch : Syntax)
/- This method assumes `expandDoIf?` is not applicable. -/
private def mkDoIfView (doIf : Syntax) : MacroM DoIfView := do
let doIf ← expandDoIf doIf
pure {
ref := doIf,
optIdent := doIf[1],
@ -1414,6 +1415,9 @@ partial def doSeqToCode : List Syntax → M CodeBlock
| doElem::doElems => withRef doElem do
match (← liftMacroM $ expandMacro? doElem) with
| some doElem => doSeqToCode (doElem::doElems)
| none =>
match (← liftMacroM $ expandDoIf? doElem) with
| some doElem => doSeqToCode (doElem::doElems)
| none =>
let (liftedDoElems, doElem) ← liftM (liftMacroM $ expandLiftMethod doElem : TermElabM _)
if !liftedDoElems.isEmpty then

View file

@ -0,0 +1,28 @@
#lang lean4
def foo (x : Nat) : IO Bool := do
if x == 0 then
throw $ IO.userError "foo: unexpected zero"
pure (x == 1)
def tst (x : Nat) : IO Unit := do
if x == 0 then
IO.println "x is 0"
else if (← foo x) then
IO.println "x is 1"
else
IO.println "other"
#eval tst 0
#eval tst 1
#eval tst 2
syntax term "<|||>" term : doElem
macro_rules
| `(doElem| $a:term <|||> $b:term) => `(doElem| if (← $a:term) then pure true else $b:term)
def tst2 : IO Bool := do
pure true <|||> (← throw $ IO.userError "failed")
#eval tst2