diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index 58ed842f82..b8a657dd6c 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -481,8 +481,19 @@ else throwError "unexpected kind of let declaration" def getDoLetVars (doLet : Syntax) : TermElabM (Array Name) := +-- parser! "let " >> letDecl getLetDeclVars (doLet.getArg 1) +def getDoLetRecVars (doLetRec : Syntax) : TermElabM (Array Name) := do +-- letRecDecls is an array of `(group (optional attributes >> letDecl))` +let letRecDecls := (doLetRec.getArg 1).getArgs.getSepElems; +let letDecls := letRecDecls.map fun p => p.getArg 1; +letDecls.foldlM + (fun allVars letDecl => do + vars ← getLetDeclVars letDecl; + pure (allVars ++ vars)) + #[] + -- ident >> optType >> leftArrow >> termParser def getDoIdDeclVar (doIdDecl : Syntax) : Name := (doIdDecl.getArg 0).getId @@ -682,7 +693,9 @@ if kind == `Lean.Parser.Term.doLet then let letDecl := decl.getArg 1; `(let $letDecl:letDecl; $k) else if kind == `Lean.Parser.Term.doLetRec then - liftM $ Macro.throwError decl "WIP" + let letRecToken := decl.getArg 0; + let letRecDecls := decl.getArg 1; + pure $ mkNode `Lean.Parser.Term.letrec #[letRecToken, letRecDecls, mkNullNode, k] else if kind == `Lean.Parser.Term.doLetArrow then let arg := decl.getArg 1; let ref := arg; @@ -887,7 +900,8 @@ partial def doSeqToCode : List Syntax → M CodeBlock vars ← liftM $ getDoLetVars doElem; mkVarDeclCore vars doElem <$> withNewVars vars (doSeqToCode doElems) else if k == `Lean.Parser.Term.doLetRec then do - throwError "WIP" + vars ← liftM $ getDoLetRecVars doElem; + mkVarDeclCore vars doElem <$> withNewVars vars (doSeqToCode doElems) else if k == `Lean.Parser.Term.doLetArrow then do vars ← liftM $ getDoLetArrowVars doElem; mkVarDeclCore vars doElem <$> withNewVars vars (doSeqToCode doElems) diff --git a/tests/lean/run/doNotation2.lean b/tests/lean/run/doNotation2.lean index 854e625536..44cff136dd 100644 --- a/tests/lean/run/doNotation2.lean +++ b/tests/lean/run/doNotation2.lean @@ -90,3 +90,24 @@ return sum theorem ex7 : sumDiff [(2, 1), (10, 5)] = 6 := rfl + +def f1 (x : Nat) : IO Unit := do +let rec loop : Nat → IO Unit + | 0 => pure () + | x+1 => do IO.println x; loop x +loop x + +#eval f1 10 + +partial def f2 (x : Nat) : IO Unit := do +let rec + isEven : Nat → Bool + | 0 => true + | x+1 => isOdd x, + isOdd : Nat → Bool + | 0 => false + | x+1 => isEven x +IO.println ("isOdd(" ++ toString x ++ "): " ++ toString (isOdd x)) + +#eval f2 11 +#eval f2 10