feat: expand doMatch
This commit is contained in:
parent
f4ccb78014
commit
294a750110
4 changed files with 76 additions and 16 deletions
|
|
@ -69,7 +69,7 @@ namespace Do
|
|||
|
||||
/- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/
|
||||
structure Alt (σ : Type) :=
|
||||
(ref : Syntax) (vars : Array Name) (patterns : Array Syntax) (rhs : σ)
|
||||
(ref : Syntax) (vars : Array Name) (patterns : Syntax) (rhs : σ)
|
||||
|
||||
/-
|
||||
Auxiliary datastructure for representing a `do` code block, and compiling "reassignments" (e.g., `x := x + 1`).
|
||||
|
|
@ -111,21 +111,21 @@ inductive Code
|
|||
| reassign (xs : Array Name) (stx : Syntax) (cont : Code)
|
||||
/- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/
|
||||
| joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (cont : Code)
|
||||
| action (stx : Syntax) (cont : Code)
|
||||
| returnAction (stx : Syntax)
|
||||
| action (stx : Syntax) (cont : Code) -- TODO: rename to `seq`?
|
||||
| returnAction (stx : Syntax) -- TODO: rename to `result`?
|
||||
| «break» (ref : Syntax)
|
||||
| «continue» (ref : Syntax)
|
||||
| «return» (ref : Syntax) (val : Syntax)
|
||||
/- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `var?`. -/
|
||||
| ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
|
||||
| «match» (ref : Syntax) (discrs : Array Syntax) (type? : Option Syntax) (alts : Array (Alt Code))
|
||||
| «match» (ref : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt Code))
|
||||
| jmp (ref : Syntax) (jpName : Name) (args : Array Syntax)
|
||||
|
||||
instance Code.inhabited : Inhabited Code :=
|
||||
⟨Code.«break» (arbitrary _)⟩
|
||||
|
||||
instance Alt.inhabited : Inhabited (Alt Code) :=
|
||||
⟨{ ref := arbitrary _, vars := #[], patterns := #[], rhs := arbitrary _ }⟩
|
||||
⟨{ ref := arbitrary _, vars := #[], patterns := arbitrary _, rhs := arbitrary _ }⟩
|
||||
|
||||
/- A code block, and the collection of variables updated by it. -/
|
||||
structure CodeBlock :=
|
||||
|
|
@ -149,12 +149,10 @@ partial def toMessageDataAux (updateVars : MessageData) : Code → MessageData
|
|||
| Code.«continue» _ => "continue " ++ updateVars
|
||||
| Code.«return» _ _ => "return ... " ++ updateVars
|
||||
| Code.«match» _ ds t alts =>
|
||||
"match " ++ MessageData.joinSep (ds.toList.map MessageData.ofSyntax) ", " ++ " with " ++
|
||||
"match " ++ ds ++ " with " ++
|
||||
alts.foldl
|
||||
(fun (acc : MessageData) (alt : Alt Code) =>
|
||||
acc ++ Format.line ++ "| "
|
||||
++ MessageData.joinSep (alt.patterns.toList.map MessageData.ofSyntax) ", "
|
||||
++ " => " ++ toMessageDataAux alt.rhs)
|
||||
acc ++ Format.line ++ "| " ++ alt.patterns ++ " => " ++ toMessageDataAux alt.rhs)
|
||||
Format.nil
|
||||
|
||||
private def nameSetToArray (s : NameSet) : Array Name :=
|
||||
|
|
@ -444,6 +442,15 @@ unit ← `(PUnit.unit);
|
|||
let unit := unit.copyInfo ref;
|
||||
pure { c with code := Code.ite ref none mkNullNode cond (Code.«return» ref unit) c.code }
|
||||
|
||||
def mkMatch (ref : Syntax) (discrs : Syntax) (optType : Syntax) (alts : Array (Alt CodeBlock)) : TermElabM CodeBlock := do
|
||||
-- nary version of homogenize
|
||||
let ws := alts.foldl (fun (ws : NameSet) alt => union ws alt.rhs.uvars) {};
|
||||
alts : Array (Alt Code) ← alts.mapM fun alt => do {
|
||||
rhs ← extendUpdatedVars alt.rhs ws;
|
||||
pure { ref := alt.ref, vars := alt.vars, patterns := alt.patterns, rhs := rhs.code : Alt Code }
|
||||
};
|
||||
pure { code := Code.«match» ref discrs optType alts, uvars := ws }
|
||||
|
||||
/- Return a code block that executes `terminal` and then `k`.
|
||||
This method assumes `terminal` is a terminal -/
|
||||
def concat (terminal : CodeBlock) (k : CodeBlock) : TermElabM CodeBlock := do
|
||||
|
|
@ -499,13 +506,16 @@ letDecls.foldlM
|
|||
def getDoIdDeclVar (doIdDecl : Syntax) : Name :=
|
||||
(doIdDecl.getArg 0).getId
|
||||
|
||||
def getPatternVarNames (pvars : Array PatternVar) : Array Name :=
|
||||
pvars.filterMap fun pvar => match pvar with
|
||||
| PatternVar.localVar x => some x
|
||||
| _ => none
|
||||
|
||||
-- termParser >> leftArrow >> termParser >> optional (" | " >> termParser)
|
||||
def getDoPatDeclVars (doPatDecl : Syntax) : TermElabM (Array Name) := do
|
||||
let pattern := doPatDecl.getArg 0;
|
||||
patternVars ← getPatternVars pattern;
|
||||
pure $ patternVars.filterMap fun patternVar => match patternVar with
|
||||
| PatternVar.localVar x => some x
|
||||
| _ => none
|
||||
pure $ getPatternVarNames patternVars
|
||||
|
||||
-- parser! "let " >> (doIdDecl <|> doPatDecl)
|
||||
def getDoLetArrowVars (doLetArrow : Syntax) : TermElabM (Array Name) := do
|
||||
|
|
@ -823,7 +833,19 @@ partial def toTerm : Code → M Syntax
|
|||
else do
|
||||
c ← liftM (expandReturnAction e);
|
||||
toTerm c
|
||||
| Code.«match» _ _ _ _ => liftM $ Macro.throwError Syntax.missing "WIP"
|
||||
| Code.«match» ref discrs optType alts => do
|
||||
termSepAlts : Array Syntax ← alts.foldlM
|
||||
(fun (termAlts : Array Syntax) alt => do
|
||||
let termAlts := termAlts.push $ mkAtomFrom alt.ref "|";
|
||||
rhs ← toTerm alt.rhs;
|
||||
let termAlt := mkNode `Lean.Parser.Term.matchAlt #[alt.patterns, mkAtomFrom alt.ref "=>", rhs];
|
||||
let termAlts := termAlts.push termAlt;
|
||||
pure termAlts)
|
||||
#[];
|
||||
let firstVBar := termSepAlts.get! 0;
|
||||
let termSepAlts := mkNullNode $ termSepAlts.extract 1 termSepAlts.size;
|
||||
let termMatchAlts := mkNode `Lean.Parser.Term.matchAlts #[mkNullNode #[firstVBar], termSepAlts];
|
||||
pure $ mkNode `Lean.Parser.Term.«match» #[mkAtomFrom ref "match", discrs, optType, mkAtomFrom ref "with", termMatchAlts]
|
||||
|
||||
def run (code : Code) (m : Syntax) (uvars : Array Name := #[]) (kind := Kind.regular) : MacroM Syntax := do
|
||||
term ← toTerm code { m := m, kind := kind, uvars := uvars };
|
||||
|
|
@ -984,8 +1006,27 @@ partial def doSeqToCode : List Syntax → M CodeBlock
|
|||
else do
|
||||
auxDo ← `(do let r ← $forInTerm; $uvarsTuple:term := r);
|
||||
doSeqToCode (getDoSeqElems (getDoSeq auxDo) ++ doElems)
|
||||
else if k == `Lean.Parser.Term.doMatch then
|
||||
throwError "WIP"
|
||||
else if k == `Lean.Parser.Term.doMatch then do
|
||||
/- Recall that
|
||||
```
|
||||
def doMatchAlt := sepBy1 termParser ", " >> darrow >> doSeq
|
||||
def doMatchAlts := parser! optional "| " >> sepBy1 doMatchAlt "|"
|
||||
def doMatch := parser! "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts
|
||||
-/
|
||||
let ref := doElem;
|
||||
let discrs := doElem.getArg 1;
|
||||
let optType := doElem.getArg 2;
|
||||
let matchAlts := ((doElem.getArg 4).getArg 1).getArgs.getSepElems; -- Array of `doMatchAlt`
|
||||
alts : Array (Alt CodeBlock) ← matchAlts.mapM fun matchAlt => do {
|
||||
let patterns := matchAlt.getArg 0;
|
||||
pvars ← liftM $ getPatternsVars patterns.getArgs.getSepElems;
|
||||
let vars := getPatternVarNames pvars;
|
||||
let rhs := matchAlt.getArg 2;
|
||||
rhs ← withNewVars vars $ doSeqToCode (getDoSeqElems rhs);
|
||||
pure { ref := matchAlt, vars := vars, patterns := patterns, rhs := rhs : Alt CodeBlock }
|
||||
};
|
||||
matchCode ← liftM $ mkMatch ref discrs optType alts;
|
||||
concatWithRest matchCode
|
||||
else if k == `Lean.Parser.Term.doTry then
|
||||
throwError "WIP"
|
||||
else if k == `Lean.Parser.Term.doBreak then do
|
||||
|
|
|
|||
|
|
@ -502,6 +502,10 @@ def getPatternVars (patternStx : Syntax) : TermElabM (Array PatternVar) := do
|
|||
(_, s) ← (CollectPatternVars.collect patternStx).run {};
|
||||
pure s.vars
|
||||
|
||||
def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := do
|
||||
(_, s) ← (patterns.mapM fun pattern => CollectPatternVars.collect pattern).run {};
|
||||
pure s.vars
|
||||
|
||||
/- We convert the collected `PatternVar`s intro `PatternVarDecl` -/
|
||||
inductive PatternVarDecl
|
||||
/- For `anonymousVar`, we create both a metavariable and a free variable. The free variable is used as an assignment for the metavariable
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ else if c_2 then
|
|||
@[builtinDoElemParser] def doFor := parser! "for " >> termParser >> " in " >> withForbidden "do" termParser >> "do " >> doSeq
|
||||
|
||||
/- `match`-expression where the right-hand-side of alternatives is a `doSeq` instead of a `term` -/
|
||||
def doMatchAlt : Parser := sepBy1 termParser ", " >> darrow >> doSeq
|
||||
def doMatchAlt : Parser := parser! sepBy1 termParser ", " >> darrow >> doSeq
|
||||
def doMatchAlts : Parser := parser! withPosition $ (optional "| ") >> sepBy1 doMatchAlt (checkColGe "alternatives must be indented" >> "|")
|
||||
@[builtinDoElemParser] def doMatch := parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> doMatchAlts
|
||||
|
||||
|
|
|
|||
|
|
@ -128,3 +128,18 @@ rfl
|
|||
def f3 (x : Nat) : IO Bool := do
|
||||
let y ← cond (x == 0) (do IO.println "hello"; true) false;
|
||||
!y
|
||||
|
||||
def f4 (x y : Nat) : Nat × Nat := do
|
||||
match x with
|
||||
| 0 => y := y + 1
|
||||
| _ => x := x + y
|
||||
return (x, y)
|
||||
|
||||
#eval f4 0 10
|
||||
#eval f4 5 10
|
||||
|
||||
theorem ex9 (y : Nat) : f4 0 y = (0, y+1) :=
|
||||
rfl
|
||||
|
||||
theorem ex10 (x y : Nat) : f4 (x+1) y = ((x+1)+y, y) :=
|
||||
rfl
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue