feat: expand doMatch

This commit is contained in:
Leonardo de Moura 2020-10-06 19:07:47 -07:00
parent f4ccb78014
commit 294a750110
4 changed files with 76 additions and 16 deletions

View file

@ -69,7 +69,7 @@ namespace Do
/- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/
structure Alt (σ : Type) :=
(ref : Syntax) (vars : Array Name) (patterns : Array Syntax) (rhs : σ)
(ref : Syntax) (vars : Array Name) (patterns : Syntax) (rhs : σ)
/-
Auxiliary datastructure for representing a `do` code block, and compiling "reassignments" (e.g., `x := x + 1`).
@ -111,21 +111,21 @@ inductive Code
| reassign (xs : Array Name) (stx : Syntax) (cont : Code)
/- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/
| joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (cont : Code)
| action (stx : Syntax) (cont : Code)
| returnAction (stx : Syntax)
| action (stx : Syntax) (cont : Code) -- TODO: rename to `seq`?
| returnAction (stx : Syntax) -- TODO: rename to `result`?
| «break» (ref : Syntax)
| «continue» (ref : Syntax)
| «return» (ref : Syntax) (val : Syntax)
/- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/
| ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
| «match» (ref : Syntax) (discrs : Array Syntax) (type? : Option Syntax) (alts : Array (Alt Code))
| «match» (ref : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt Code))
| jmp (ref : Syntax) (jpName : Name) (args : Array Syntax)
instance Code.inhabited : Inhabited Code :=
⟨Code.«break» (arbitrary _)⟩
instance Alt.inhabited : Inhabited (Alt Code) :=
⟨{ ref := arbitrary _, vars := #[], patterns := #[], rhs := arbitrary _ }⟩
⟨{ ref := arbitrary _, vars := #[], patterns := arbitrary _, rhs := arbitrary _ }⟩
/- A code block, and the collection of variables updated by it. -/
structure CodeBlock :=
@ -149,12 +149,10 @@ partial def toMessageDataAux (updateVars : MessageData) : Code → MessageData
| Code.«continue» _ => "continue " ++ updateVars
| Code.«return» _ _ => "return ... " ++ updateVars
| Code.«match» _ ds t alts =>
"match " ++ MessageData.joinSep (ds.toList.map MessageData.ofSyntax) ", " ++ " with " ++
"match " ++ ds ++ " with " ++
alts.foldl
(fun (acc : MessageData) (alt : Alt Code) =>
acc ++ Format.line ++ "| "
++ MessageData.joinSep (alt.patterns.toList.map MessageData.ofSyntax) ", "
++ " => " ++ toMessageDataAux alt.rhs)
acc ++ Format.line ++ "| " ++ alt.patterns ++ " => " ++ toMessageDataAux alt.rhs)
Format.nil
private def nameSetToArray (s : NameSet) : Array Name :=
@ -444,6 +442,15 @@ unit ← `(PUnit.unit);
let unit := unit.copyInfo ref;
pure { c with code := Code.ite ref none mkNullNode cond (Code.«return» ref unit) c.code }
def mkMatch (ref : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do
-- nary version of homogenize
let ws := alts.foldl (fun (ws : NameSet) alt => union ws alt.rhs.uvars) {};
alts : Array (Alt Code) ← alts.mapM fun alt => do {
rhs ← extendUpdatedVars alt.rhs ws;
pure { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code }
};
pure { code := Code.«match» ref discrs optType alts, uvars := ws }
/- Return a code block that executes `terminal` and then `k`.
This method assumes `terminal` is a terminal -/
def concat (terminal : CodeBlock) (k : CodeBlock) : TermElabM CodeBlock := do
@ -499,13 +506,16 @@ letDecls.foldlM
def getDoIdDeclVar (doIdDecl : Syntax) : Name :=
(doIdDecl.getArg 0).getId
def getPatternVarNames (pvars : Array PatternVar) : Array Name :=
pvars.filterMap fun pvar => match pvar with
| PatternVar.localVar x => some x
| _ => none
-- termParser >> leftArrow >> termParser >> optional (" | " >> termParser)
def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Name) := do
let pattern := doPatDecl.getArg 0;
patternVars ← getPatternVars pattern;
pure $ patternVars.filterMap fun patternVar => match patternVar with
| PatternVar.localVar x => some x
| _ => none
pure $ getPatternVarNames patternVars
-- parser! "let " >> (doIdDecl <|> doPatDecl)
def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do
@ -823,7 +833,19 @@ partial def toTerm : Code → M Syntax
else do
c ← liftM (expandReturnAction e);
toTerm c
| Code.«match» _ _ _ _ => liftM $ Macro.throwError Syntax.missing "WIP"
| Code.«match» ref discrs optType alts => do
termSepAlts : Array Syntax ← alts.foldlM
(fun (termAlts : Array Syntax) alt => do
let termAlts := termAlts.push $ mkAtomFrom alt.ref "|";
rhs ← toTerm alt.rhs;
let termAlt := mkNode `Lean.Parser.Term.matchAlt #[alt.patterns, mkAtomFrom alt.ref "=>", rhs];
let termAlts := termAlts.push termAlt;
pure termAlts)
#[];
let firstVBar := termSepAlts.get! 0;
let termSepAlts := mkNullNode $ termSepAlts.extract 1 termSepAlts.size;
let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode #[firstVBar], termSepAlts];
pure $ mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", discrs, optType, mkAtomFrom ref "with", termMatchAlts]
def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do
term ← toTerm code { m := m, kind := kind, uvars := uvars };
@ -984,8 +1006,27 @@ partial def doSeqToCode : List Syntax → M CodeBlock
else do
auxDo ← `(do let r ← $forInTerm; $uvarsTuple:term := r);
doSeqToCode (getDoSeqElems (getDoSeq auxDo) ++ doElems)
else if k == `Lean.Parser.Term.doMatch then
throwError "WIP"
else if k == `Lean.Parser.Term.doMatch then do
/- Recall that
```
def doMatchAlt := sepBy1 termParser ", " >> darrow >> doSeq
def doMatchAlts := parser! optional "| " >> sepBy1 doMatchAlt "|"
def doMatch := parser! "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts
-/
let ref := doElem;
let discrs := doElem.getArg 1;
let optType := doElem.getArg 2;
let matchAlts := ((doElem.getArg 4).getArg 1).getArgs.getSepElems; -- Array of `doMatchAlt`
alts : Array (Alt CodeBlock) ← matchAlts.mapM fun matchAlt => do {
let patterns := matchAlt.getArg 0;
pvars ← liftM $ getPatternsVars patterns.getArgs.getSepElems;
let vars := getPatternVarNames pvars;
let rhs := matchAlt.getArg 2;
rhs ← withNewVars vars $ doSeqToCode (getDoSeqElems rhs);
pure { ref := matchAlt, vars := vars, patterns := patterns, rhs := rhs : Alt CodeBlock }
};
matchCode ← liftM $ mkMatch ref discrs optType alts;
concatWithRest matchCode
else if k == `Lean.Parser.Term.doTry then
throwError "WIP"
else if k == `Lean.Parser.Term.doBreak then do

View file

@ -502,6 +502,10 @@ def getPatternVars (patternStx : Syntax) : TermElabM (Array PatternVar) := do
(_, s) ← (CollectPatternVars.collect patternStx).run {};
pure s.vars
def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := do
(_, s) ← (patterns.mapM fun pattern => CollectPatternVars.collect pattern).run {};
pure s.vars
/- We convert the collected `PatternVar`s intro `PatternVarDecl` -/
inductive PatternVarDecl
/- For `anonymousVar`, we create both a metavariable and a free variable. The free variable is used as an assignment for the metavariable

View file

@ -72,7 +72,7 @@ else if c_2 then
@[builtinDoElemParser] def doFor := parser! "for " >> termParser >> " in " >> withForbidden "do" termParser >> "do " >> doSeq
/- `match`-expression where the right-hand-side of alternatives is a `doSeq` instead of a `term` -/
def doMatchAlt : Parser := sepBy1 termParser ", " >> darrow >> doSeq
def doMatchAlt : Parser := parser! sepBy1 termParser ", " >> darrow >> doSeq
def doMatchAlts : Parser := parser! withPosition $ (optional "| ") >> sepBy1 doMatchAlt (checkColGe "alternatives must be indented" >> "|")
@[builtinDoElemParser] def doMatch := parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts

View file

@ -128,3 +128,18 @@ rfl
def f3 (x : Nat) : IO Bool := do
let y ← cond (x == 0) (do IO.println "hello"; true) false;
!y
def f4 (x y : Nat) : Nat × Nat := do
match x with
| 0 => y := y + 1
| _ => x := x + y
return (x, y)
#eval f4 0 10
#eval f4 5 10
theorem ex9 (y : Nat) : f4 0 y = (0, y+1) :=
rfl
theorem ex10 (x y : Nat) : f4 (x+1) y = ((x+1)+y, y) :=
rfl