diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index c02dfc3da8..fc95262c65 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -69,7 +69,7 @@ namespace Do /- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/ structure Alt (σ : Type) := -(ref : Syntax) (vars : Array Name) (patterns : Array Syntax) (rhs : σ) +(ref : Syntax) (vars : Array Name) (patterns : Syntax) (rhs : σ) /- Auxiliary datastructure for representing a `do` code block, and compiling "reassignments" (e.g., `x := x + 1`). @@ -111,21 +111,21 @@ inductive Code | reassign (xs : Array Name) (stx : Syntax) (cont : Code) /- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/ | joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (cont : Code) -| action (stx : Syntax) (cont : Code) -| returnAction (stx : Syntax) +| action (stx : Syntax) (cont : Code) -- TODO: rename to `seq`? +| returnAction (stx : Syntax) -- TODO: rename to `result`? | «break» (ref : Syntax) | «continue» (ref : Syntax) | «return» (ref : Syntax) (val : Syntax) /- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/ | ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) -| «match» (ref : Syntax) (discrs : Array Syntax) (type? : Option Syntax) (alts : Array (Alt Code)) +| «match» (ref : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt Code)) | jmp (ref : Syntax) (jpName : Name) (args : Array Syntax) instance Code.inhabited : Inhabited Code := ⟨Code.«break» (arbitrary _)⟩ instance Alt.inhabited : Inhabited (Alt Code) := -⟨{ ref := arbitrary _, vars := #[], patterns := #[], rhs := arbitrary _ }⟩ +⟨{ ref := arbitrary _, vars := #[], patterns := arbitrary _, rhs := arbitrary _ }⟩ /- A code block, and the collection of variables updated by it. -/ structure CodeBlock := @@ -149,12 +149,10 @@ partial def toMessageDataAux (updateVars : MessageData) : Code → MessageData | Code.«continue» _ => "continue " ++ updateVars | Code.«return» _ _ => "return ... " ++ updateVars | Code.«match» _ ds t alts => - "match " ++ MessageData.joinSep (ds.toList.map MessageData.ofSyntax) ", " ++ " with " ++ + "match " ++ ds ++ " with " ++ alts.foldl (fun (acc : MessageData) (alt : Alt Code) => - acc ++ Format.line ++ "| " - ++ MessageData.joinSep (alt.patterns.toList.map MessageData.ofSyntax) ", " - ++ " => " ++ toMessageDataAux alt.rhs) + acc ++ Format.line ++ "| " ++ alt.patterns ++ " => " ++ toMessageDataAux alt.rhs) Format.nil private def nameSetToArray (s : NameSet) : Array Name := @@ -444,6 +442,15 @@ unit ← `(PUnit.unit); let unit := unit.copyInfo ref; pure { c with code := Code.ite ref none mkNullNode cond (Code.«return» ref unit) c.code } +def mkMatch (ref : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do +-- nary version of homogenize +let ws := alts.foldl (fun (ws : NameSet) alt => union ws alt.rhs.uvars) {}; +alts : Array (Alt Code) ← alts.mapM fun alt => do { + rhs ← extendUpdatedVars alt.rhs ws; + pure { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code } +}; +pure { code := Code.«match» ref discrs optType alts, uvars := ws } + /- Return a code block that executes `terminal` and then `k`. This method assumes `terminal` is a terminal -/ def concat (terminal : CodeBlock) (k : CodeBlock) : TermElabM CodeBlock := do @@ -499,13 +506,16 @@ letDecls.foldlM def getDoIdDeclVar (doIdDecl : Syntax) : Name := (doIdDecl.getArg 0).getId +def getPatternVarNames (pvars : Array PatternVar) : Array Name := +pvars.filterMap fun pvar => match pvar with + | PatternVar.localVar x => some x + | _ => none + -- termParser >> leftArrow >> termParser >> optional (" | " >> termParser) def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Name) := do let pattern := doPatDecl.getArg 0; patternVars ← getPatternVars pattern; -pure $ patternVars.filterMap fun patternVar => match patternVar with - | PatternVar.localVar x => some x - | _ => none +pure $ getPatternVarNames patternVars -- parser! "let " >> (doIdDecl <|> doPatDecl) def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do @@ -823,7 +833,19 @@ partial def toTerm : Code → M Syntax else do c ← liftM (expandReturnAction e); toTerm c -| Code.«match» _ _ _ _ => liftM $ Macro.throwError Syntax.missing "WIP" +| Code.«match» ref discrs optType alts => do + termSepAlts : Array Syntax ← alts.foldlM + (fun (termAlts : Array Syntax) alt => do + let termAlts := termAlts.push $ mkAtomFrom alt.ref "|"; + rhs ← toTerm alt.rhs; + let termAlt := mkNode `Lean.Parser.Term.matchAlt #[alt.patterns, mkAtomFrom alt.ref "=>", rhs]; + let termAlts := termAlts.push termAlt; + pure termAlts) + #[]; + let firstVBar := termSepAlts.get! 0; + let termSepAlts := mkNullNode $ termSepAlts.extract 1 termSepAlts.size; + let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode #[firstVBar], termSepAlts]; + pure $ mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", discrs, optType, mkAtomFrom ref "with", termMatchAlts] def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do term ← toTerm code { m := m, kind := kind, uvars := uvars }; @@ -984,8 +1006,27 @@ partial def doSeqToCode : List Syntax → M CodeBlock else do auxDo ← `(do let r ← $forInTerm; $uvarsTuple:term := r); doSeqToCode (getDoSeqElems (getDoSeq auxDo) ++ doElems) - else if k == `Lean.Parser.Term.doMatch then - throwError "WIP" + else if k == `Lean.Parser.Term.doMatch then do + /- Recall that + ``` + def doMatchAlt := sepBy1 termParser ", " >> darrow >> doSeq + def doMatchAlts := parser! optional "| " >> sepBy1 doMatchAlt "|" + def doMatch := parser! "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts + -/ + let ref := doElem; + let discrs := doElem.getArg 1; + let optType := doElem.getArg 2; + let matchAlts := ((doElem.getArg 4).getArg 1).getArgs.getSepElems; -- Array of `doMatchAlt` + alts : Array (Alt CodeBlock) ← matchAlts.mapM fun matchAlt => do { + let patterns := matchAlt.getArg 0; + pvars ← liftM $ getPatternsVars patterns.getArgs.getSepElems; + let vars := getPatternVarNames pvars; + let rhs := matchAlt.getArg 2; + rhs ← withNewVars vars $ doSeqToCode (getDoSeqElems rhs); + pure { ref := matchAlt, vars := vars, patterns := patterns, rhs := rhs : Alt CodeBlock } + }; + matchCode ← liftM $ mkMatch ref discrs optType alts; + concatWithRest matchCode else if k == `Lean.Parser.Term.doTry then throwError "WIP" else if k == `Lean.Parser.Term.doBreak then do diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index 14c57bbd7f..336d99ac11 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -502,6 +502,10 @@ def getPatternVars (patternStx : Syntax) : TermElabM (Array PatternVar) := do (_, s) ← (CollectPatternVars.collect patternStx).run {}; pure s.vars +def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := do +(_, s) ← (patterns.mapM fun pattern => CollectPatternVars.collect pattern).run {}; +pure s.vars + /- We convert the collected `PatternVar`s intro `PatternVarDecl` -/ inductive PatternVarDecl /- For `anonymousVar`, we create both a metavariable and a free variable. The free variable is used as an assignment for the metavariable diff --git a/src/Lean/Parser/Do.lean b/src/Lean/Parser/Do.lean index 4f52837101..af42064382 100644 --- a/src/Lean/Parser/Do.lean +++ b/src/Lean/Parser/Do.lean @@ -72,7 +72,7 @@ else if c_2 then @[builtinDoElemParser] def doFor := parser! "for " >> termParser >> " in " >> withForbidden "do" termParser >> "do " >> doSeq /- `match`-expression where the right-hand-side of alternatives is a `doSeq` instead of a `term` -/ -def doMatchAlt : Parser := sepBy1 termParser ", " >> darrow >> doSeq +def doMatchAlt : Parser := parser! sepBy1 termParser ", " >> darrow >> doSeq def doMatchAlts : Parser := parser! withPosition $ (optional "| ") >> sepBy1 doMatchAlt (checkColGe "alternatives must be indented" >> "|") @[builtinDoElemParser] def doMatch := parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts diff --git a/tests/lean/run/doNotation2.lean b/tests/lean/run/doNotation2.lean index c1d623b983..12a24e4615 100644 --- a/tests/lean/run/doNotation2.lean +++ b/tests/lean/run/doNotation2.lean @@ -128,3 +128,18 @@ rfl def f3 (x : Nat) : IO Bool := do let y ← cond (x == 0) (do IO.println "hello"; true) false; !y + +def f4 (x y : Nat) : Nat × Nat := do +match x with +| 0 => y := y + 1 +| _ => x := x + y +return (x, y) + +#eval f4 0 10 +#eval f4 5 10 + +theorem ex9 (y : Nat) : f4 0 y = (0, y+1) := +rfl + +theorem ex10 (x y : Nat) : f4 (x+1) y = ((x+1)+y, y) := +rfl