feat: do code blocks

WIP
This commit is contained in:
Leonardo de Moura 2020-09-30 19:20:16 -07:00
parent a3218dd063
commit 94c7945bd3
2 changed files with 309 additions and 144 deletions

View file

@ -18,6 +18,12 @@ namespace Do
structure Alt (σ : Type) :=
(ref : Syntax) (patterns : Array Syntax) (rhs : σ)
structure VarDecl :=
(ref : Syntax) (name : Name) (pure : Bool) (letDecl : Syntax)
structure JPDecl (σ : Type) :=
(ref : Syntax) (name : Name) (params : Array Name) (body : σ)
/-
Auxiliary datastructure for representing a `do` code block.
We convert `Code` into a `Syntax` term representing the:
@ -39,49 +45,288 @@ structure Alt (σ : Type) :=
- `match`: pattern matching
- `jmp` a goto to a join-point
We store the set of updated variables `uvars` in the terminals `break`, `continue`, and `return`.
We say `break`, `continue` and `return` are "exit points"
The terminal `return` also contains the name of the variable containing the result of the computation.
We ignore this value when inside a `for x in s`.
A code block `C` is well-formed if
1- The collection of updated variables is the same in all `break`
`continue` and `return` in `C`.
2- For every `jmp r j as` in `C`, there is a `jdecl r j ps b k` s.t. `jmp r j` is in `k`, and
`ps.size == as.size`
3- The update variables occurring in `break`, `continue`, and `return` are pairwise distinct.
We use the notation `C[u_1, ..., u_k]` to denote a code block that updates variables `u_1, ..., u_k`
- For every `jmp r j as` in `C`, there is a `jdecl r j ps b k` s.t. `jmp r j` is in `k`, and
`ps.size == as.size`
-/
inductive Code
| vdecl (ref : Syntax) (id : Name) (type : Syntax) (pure : Bool) (val : Syntax) (cont : Code)
| jdecl (ref : Syntax) (id : Name) (params : Array Name) (body : Code) (cont : Code)
| vdecl (decl : VarDecl) (reassignment : Bool) (cont : Code)
| jdecl (decl : JPDecl Code) (cont : Code)
| action (term : Syntax) (cond : Code)
| «break» (ref : Syntax) (uvars : Array Name)
| «continue» (ref : Syntax) (uvars : Array Name)
| «return» (ref : Syntax) (var? : Option Name) (uvars : Array Name)
| «break» (ref : Syntax)
| «continue» (ref : Syntax)
| «return» (ref : Syntax) (var? : Option Name)
| ite (ref : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code)
| «match» (ref : Syntax) (discrs : Array Syntax) (type? : Option Syntax) (alts : Array (Alt Code))
| jmp (ref : Syntax) (jpName : Name) (args : Array Name)
instance body.inhabited : Inhabited Code :=
⟨Code.«break» (arbitrary _) #[]
instance Code.inhabited : Inhabited Code :=
⟨Code.«break» (arbitrary _)⟩
instance alt.inhabited : Inhabited (Alt Code) :=
instance Alt.inhabited : Inhabited (Alt Code) :=
⟨{ ref := arbitrary _, patterns := #[], rhs := arbitrary _ }⟩
partial def getUpdatedVars? : Code → Option (Array Name)
| Code.vdecl _ _ _ _ _ k => getUpdatedVars? k
| Code.jdecl _ _ _ b k => getUpdatedVars? b <|> getUpdatedVars? k
| Code.action _ k => getUpdatedVars? k
| Code.«break» _ uvars => some uvars
| Code.«continue» _ uvars => some uvars
| Code.«return» _ _ uvars => some uvars
| Code.ite _ _ t e => getUpdatedVars? t <|> getUpdatedVars? e
| Code.«match» _ _ _ alts => alts.findSome? fun alt => getUpdatedVars? alt.rhs
| Code.jmp _ _ _ => none
/- A code block, and the collection of variables updated by it. -/
structure CodeBlock :=
(code : Code)
(uvars : NameSet := {}) -- set of variables updated by `code`
partial def toMessageDataAux (updateVars : MessageData) : Code → MessageData
| Code.vdecl d r k =>
(if r then "" else "let ") ++ d.name ++ " " ++ (if d.pure then ":=" else "←") ++ " ... " ++ Format.line ++ toMessageDataAux k
| Code.jdecl d k =>
"let " ++ d.name.simpMacroScopes ++ " " ++ toString d.params.toList ++ ":=" ++ indentD (toMessageDataAux d.body) ++ Format.line ++ toMessageDataAux k
| Code.action e k => e ++ Format.line ++ toMessageDataAux k
| Code.ite _ c t e => "if " ++ c ++ " then " ++ indentD (toMessageDataAux t) ++ Format.line ++ "else " ++ indentD (toMessageDataAux e)
| Code.jmp _ j xs => "jmp " ++ j.simpMacroScopes ++ " " ++ toString xs.toList
| Code.«break» _ => "break " ++ updateVars
| Code.«continue» _ => "continue " ++ updateVars
| Code.«return» _ none => "return " ++ updateVars
| Code.«return» _ (some x) => "return " ++ x ++ " " ++ updateVars
| Code.«match» _ ds t alts =>
"match " ++ MessageData.joinSep (ds.toList.map MessageData.ofSyntax) ", " ++ " with " ++
alts.foldl
(fun (acc : MessageData) (alt : Alt Code) =>
acc ++ Format.line ++ "| "
++ MessageData.joinSep (alt.patterns.toList.map MessageData.ofSyntax) ", "
++ " => " ++ toMessageDataAux alt.rhs)
Format.nil
private def nameSetToArray (s : NameSet) : Array Name :=
s.fold (fun (xs : Array Name) x => xs.push x) #[]
def CodeBlock.toMessageData (c : CodeBlock) : MessageData :=
let us := (nameSetToArray c.uvars).toList.map MessageData.ofName;
toMessageDataAux (MessageData.ofList us) c.code
partial def getSomeRef : Code → Syntax
| Code.vdecl d _ _ => d.ref
| Code.jdecl d _ => d.ref
| Code.action e _ => e
| Code.ite ref _ _ _ => ref
| Code.jmp ref _ _ => ref
| Code.«break» ref => ref
| Code.«continue» ref => ref
| Code.«return» ref _ => ref
| Code.«match» ref _ _ _ => ref
partial def hasExitPoint : Code → Bool
| Code.vdecl _ _ k => hasExitPoint k
| Code.jdecl d k => hasExitPoint d.body || hasExitPoint k
| Code.action _ k => hasExitPoint k
| Code.ite _ _ t e => hasExitPoint t || hasExitPoint e
| Code.jmp _ _ _ => false
| Code.«break» _ => true
| Code.«continue» _ => true
| Code.«return» _ _ => true
| Code.«match» _ _ _ alts => alts.any fun alt => hasExitPoint alt.rhs
partial def convertReturnIntoJmpAux (jp : Name) (xs : Array Name) : Code → Code
| Code.vdecl d r k => Code.vdecl d r $ convertReturnIntoJmpAux k
| Code.jdecl d k => Code.jdecl { d with body := convertReturnIntoJmpAux d.body } $ convertReturnIntoJmpAux k
| Code.action e k => Code.action e $ convertReturnIntoJmpAux k
| Code.ite ref c t e => Code.ite ref c (convertReturnIntoJmpAux t) (convertReturnIntoJmpAux e)
| Code.«match» ref ds t alts => Code.«match» ref ds t $ alts.map fun alt => { alt with rhs := convertReturnIntoJmpAux alt.rhs }
| Code.«return» ref _ => Code.jmp ref jp xs
| c => c
/- Convert `return _ x` instructions in `c` into `jmp _ jp xs`. -/
def convertReturnIntoJmp (c : Code) (jp : Name) (xs : Array Name) : Code :=
convertReturnIntoJmpAux jp xs c
def mkJPDecls (jpDecls : Array (JPDecl Code)) (k : Code) : Code :=
jpDecls.foldr (fun jp r => Code.jdecl jp r) k
def mkFreshJP (ref : Syntax) (ps : Array Name) (body : Code) : TermElabM (JPDecl Code) := do
name ← mkFreshUserName `jp;
pure { ref := ref, name := name, params := ps, body := body }
def addFreshJP (ref : Syntax) (ps : Array Name) (body : Code) : StateRefT (Array (JPDecl Code)) TermElabM Name := do
jp ← liftM $ mkFreshJP ref ps body;
modify fun (jps : Array (JPDecl Code)) => jps.push jp;
pure jp.name
/- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/
partial def pullExitPointsAux : NameSet → Code → StateRefT (Array (JPDecl Code)) TermElabM Code
| rs, Code.vdecl d false k => Code.vdecl d false <$> pullExitPointsAux (rs.erase d.name) k
| rs, Code.vdecl d true k => Code.vdecl d true <$> pullExitPointsAux (rs.insert d.name) k
| rs, Code.jdecl d k => do b ← pullExitPointsAux rs d.body; Code.jdecl { d with body := b } <$> pullExitPointsAux rs k
| rs, Code.action e k => Code.action e <$> pullExitPointsAux rs k
| rs, Code.ite ref c t e => Code.ite ref c <$> pullExitPointsAux rs t <*> pullExitPointsAux rs e
| rs, Code.«match» ref ds t alts => Code.«match» ref ds t <$> alts.mapM fun alt => do rhs ← pullExitPointsAux rs alt.rhs; pure { alt with rhs := rhs }
| rs, c@(Code.jmp _ _ _) => pure c
| rs, Code.«break» ref => do let xs := nameSetToArray rs; jp ← addFreshJP ref xs (Code.«break» ref); pure $ Code.jmp ref jp xs
| rs, Code.«continue» ref => do let xs := nameSetToArray rs; jp ← addFreshJP ref xs (Code.«continue» ref); pure $ Code.jmp ref jp xs
| rs, Code.«return» ref y? => do
let xs := nameSetToArray rs;
(ps, xs) ← match y? with
| none => pure (xs, xs)
| some y =>
if rs.contains y then pure (xs, xs)
else do {
yFresh ← mkFreshUserName y;
pure (xs.push y, xs.push yFresh)
};
jp ← addFreshJP ref ps (Code.«return» ref y?);
pure $ Code.jmp ref jp xs
/-
Auxiliary operation for adding new variables to `c.uvars` (updated variables).
When a new variable is not already in `c.uvars`, but is shadowed by some declaration in `c.code`,
we create auxiliary join points to make sure we preserve the semantics of the code block.
Example: suppose we have the code block `print x; let x := 10; return x`. And we want to extend it
with the reassignment `x := x + 1`. We first use `pullExitPoints` to create
```
let jp (x!1) := return x!1;
print x;
let x := 10;
jmp jp x
```
and then we add the reassignment
```
x := x + 1
let jp (x!1) := return x!1;
print x;
let x := 10;
jmp jp x
```
Note that we created a fresh variable `x!1` to avoid accidental name capture.
```
print x;
let x := 10
y := y + 1;
return x;
```
We transform it into
```
let jp (y x!1) := return x!1;
print x;
let x := 10
y := y + 1;
jmp jp y x
```
and then we add the reassignment as in the previous example.
We need to include `y` in the jump, because each exit point is implicitly returning the set of
update variables.
We implement the method as follows. Let `us` be `c.uvars`, then
1- for each `return _ y` in `c`, we create a join point
`let j (us y!1) := return y!1`
and replace the `return _ y` with `jmp us y`
2- for each `break`, we create a join point
`let j (us) := break`
and replace the `break` with `jmp us`.
3- Same as 2 for `continue`.
-/
def pullExitPoints (c : Code) : TermElabM Code :=
if hasExitPoint c then do
(c, jpDecls) ← (pullExitPointsAux {} c).run #[];
pure $ mkJPDecls jpDecls c
else
pure c
partial def extendUpdatedVarsAux (ws : NameSet) : Code → TermElabM Code
| Code.jdecl d k => do b ← extendUpdatedVarsAux d.body; Code.jdecl { d with body := b } <$> extendUpdatedVarsAux k
| Code.action e k => Code.action e <$> extendUpdatedVarsAux k
| Code.ite ref c t e => Code.ite ref c <$> extendUpdatedVarsAux t <*> extendUpdatedVarsAux e
| Code.«match» ref ds t alts => Code.«match» ref ds t <$> alts.mapM fun alt => do rhs ← extendUpdatedVarsAux alt.rhs; pure { alt with rhs := rhs }
| Code.vdecl d true k => Code.vdecl d true <$> extendUpdatedVarsAux k
| c@(Code.vdecl d false k) =>
if ws.contains d.name then
-- This `let` declaration is shadowing a variable in ws
pullExitPoints c
else
Code.vdecl d false <$> extendUpdatedVarsAux k
| c => pure c
/-
Extend the set of updated variables. It assumes `ws` is a super set of `c.uvars`.
We **cannot** simply update the field `c.uvars`, because `c` may have shadowed some variable in `ws`.
See discussion at `pullExitPoints`.
-/
def extendUpdatedVars (c : CodeBlock) (ws : NameSet) : TermElabM CodeBlock :=
if ws.any fun x => !c.uvars.contains x then do
-- `ws` contains a variable that is not in `c.uvars`, but in `c.dvars` (i.e., it has been shadowed)
code ← extendUpdatedVarsAux ws c.code;
pure { code := code, uvars := ws }
else
pure { c with uvars := ws }
private def union (s₁ s₂ : NameSet) : NameSet :=
s₁.fold (fun (s : NameSet) x => s.insert x) s₂
/-
Given two code blocks `c₁` and `c₂`, make sure they have the same set of updated variables.
Let `ws` the union of the updated variables in `c₁ and c₂`.
We use `extendUpdatedVars c₁ ws` and `extendUpdatedVars c₂ ws`
-/
def homogenize (c₁ c₂ : CodeBlock) : TermElabM (CodeBlock × CodeBlock) := do
let ws := union c₁.uvars c₂.uvars;
c₁ ← extendUpdatedVars c₁ ws;
c₂ ← extendUpdatedVars c₂ ws;
pure (c₁, c₂)
/-
Extending code blocks with variable declarations: `let x : t := v` and `let x : t ← v`.
We remove `x` from the collection of updated varibles.
-/
def mkVarDecl (d : VarDecl) (c : CodeBlock) : CodeBlock :=
let x := d.name;
{ code := Code.vdecl d false c.code, uvars := c.uvars.erase x }
/-
Extending code blocks with reassignments: `x : t := v` and `x : t ← v`.
-/
def mkReassign (d : VarDecl) (c : CodeBlock) : TermElabM CodeBlock := do
let x := d.name;
let ws := c.uvars.insert x;
-- We must pull "exit points" IF `x` is not in `c.uvars`, but is shadowed by a declaration in `c`
-- See discussion at `pullExitPoints`
code ← if !c.uvars.contains x then extendUpdatedVarsAux ws c.code else pure c.code;
pure { code := Code.vdecl d true code, uvars := ws }
def mkAction (action : Syntax) (c : CodeBlock) : CodeBlock :=
{ c with code := Code.action action c.code }
def mkReturn (ref : Syntax) (x? : Option Name := none) : CodeBlock :=
{ code := Code.«return» ref x? }
def mkBreak (ref : Syntax) : CodeBlock :=
{ code := Code.«break» ref }
def mkContinue (ref : Syntax) : CodeBlock :=
{ code := Code.«continue» ref }
def mkIte (ref : Syntax) (c : Syntax) (thenBranch : CodeBlock) (elseBranch : CodeBlock) : TermElabM CodeBlock := do
(thenBranch, elseBranch) ← homogenize thenBranch elseBranch;
pure {
code := Code.ite ref c thenBranch.code elseBranch.code,
uvars := thenBranch.uvars,
}
/- Return a code block that executes `terminal` and then `k`.
This method assumes `terminal` is a terminal -/
def concat (terminal : CodeBlock) (k : CodeBlock) : TermElabM CodeBlock := do
(terminal, k) ← homogenize terminal k;
let xs := nameSetToArray k.uvars;
jpDecl ← mkFreshJP (getSomeRef k.code) xs k.code;
let jp := jpDecl.name;
pure {
code := Code.jdecl jpDecl (convertReturnIntoJmp terminal.code jp xs),
uvars := terminal.uvars,
}
def mkWhen (ref : Syntax) (cond : Syntax) (c : CodeBlock) : CodeBlock :=
{ c with code := Code.ite ref cond c.code (Code.«return» ref none) }
def mkUnless (ref : Syntax) (cond : Syntax) (c : CodeBlock) : CodeBlock :=
{ c with code := Code.ite ref cond (Code.«return» ref none) c.code }
private def mkTuple (elems : Array Syntax) : MacroM Syntax :=
if elems.size == 1 then pure (elems.get! 0)
@ -90,123 +335,8 @@ else
(fun elem tuple => `(($elem, $tuple)))
(elems.back)
/-
Extending code blocks with variable declarations: `let x : t := v` and `let x : t ← v`.
Suppose we have a code block `C[us]`, and we want to extend it with the
`let x : t := v` declaration. We first remove `x` from the collection of updated variables `us`, obtaining `us'`
and return:
```
Code.vdecl _ x t true v C[us']
```
The operation is the same for `let x : t ← v`, but we set `pure` with `false`.
-/
/-
Extending code blocks with reassignments: `x : t := v` and `x : t ← v`.
Suppose we have a code block `C[us]`, and we want to extend it with the
`x : t := v` reassignment. If `x` is in `us`, then we just return
```
Code.vdecl _ x t true v C[us]
```
If `x` is not in `us`, we create a C'[x, us] in the following way
1- for each `return _ y us` occurring in `C[us]`, we create a join point
`let j (y us) := return y [x, us]`
and we replace the `return _ y us` with `jmp y us`
2- for each `break us` occurring in `C[us]`, we create a join point
`let j (us) := break [x, us]`
and we replace the `break us` with `jmp us`.
3- Same as 2 for `continue us`
Finally, we return
```
Code.vdecl _ x t true v C'[x, us]
```
Note that it would be incorrect to just add `x` to the set of updated variables of each `break`, `continue`, and `return`.
The problem is that `C` may have shadowed `x`. As an example, consider the following piece of code
```
let x ← action₁; -- declares 'x'
x := x + 1; -- reassigns 'x'
IO.println x;
let x ← action₂; -- shadows previous x
IO.println x
```
The code block `C` for
```
IO.println x;
let x ← action₂; -- shadows previous x
IO.println x
```
is
```
Code.action (IO.println x) $
Code.vdecl _ x _ false action₂ $
Code.action (IO.println x) $
Code.return _ none []
```
Here is the incorrect way of extending it with the assignment `x := x + 1`.
```
Code.vdecl _ x _ true (x+1) $
Code.action (IO.println x) $
Code.vdecl _ x _ false action₂ $
Code.action (IO.println x) $
Code.return _ none [x]
```
The code above incorrectly returns the shadowed `x` as the updated value for `x`.
The process above using join-point produces the correct result:
```
Code.vdecl _ x _ true (x+1) $
Code.jdecl _ j [] (Code.return _ none [x]) $
Code.action (IO.println x) $
Code.vdecl _ x _ false action₂ $
Code.action (IO.println x) $
Code.jmp _ j []
```
The join point `j` returns the correct `x`.
-/
/-
Combining two code-blocks `C[us]` `D[vs]` into an if-then-else with condition `c`.
If `us == vs`, then it is easy. We just return:
```
Code.ite _ c C[us] D[us]
```
Otherwise, let `ws` be the union of `us` and `vs`. The for each `return`, `continue`, and `break` occurring in `C[us]` and `D[vs]`, we create
an auxiliary join point using a process similar to the one we used for extending code-blocks with reassignment operations.
For example, for a `break us` in `C[us]` we create a join point
```
Code.jdecl _ j [us] (Code.break [ws]) $ ...
```
and replace `break us` with `jmp _ j us`.
We call this operation `homogenise : Code → Code → Code × Code`. It takes two code blocks and returns two new code blocks that have the same
collection of updated variables.
Given `(C'[ws], D'[ws]) := homogenize C[us] D[vs]`, we return
```
ite c C'[ws] D'[ws]
```
The process of creating `match` terminal is similar.
-/
/-
We say a code-block `C[us]` is "terminal-like" if it is a sequence of join-point declarations followed by a `Code.ite` or `Code.match`.
That is, `C[us]` is obtained by the `mkIte` and `mkMatch` primitives.
For concatenating two joint points `C[us]` `D[vs]`, where `C[us]` is a terminal-like code block, we first consider the simpler case where `us == vs`,
then we use `homogenize` for implementing the general case.
If `us == vs`, we first create a joint point `j` for `D[us]`, and then replace each `return _ _ [us]` in `C[us]` with a `jmp j`, obtaining `C'[us]`.
The result is like
```
Code.jdecl _ j [] (D[us]) $
C'[us]
```
-/
end Do
structure ExtractMonadResult :=
(m : Expr)
(α : Expr)

View file

@ -0,0 +1,35 @@
import Lean
new_frontend
namespace Lean.Elab.Term.Do
def ref := Syntax.missing
def vdecl (name : Name) (pure := true) : VarDecl :=
{ ref := ref, name := name, pure := pure, letDecl := Syntax.missing }
def print (c : CodeBlock) : TermElabM Unit := do
let msg := c.toMessageData
let msg ← addMessageContext msg
IO.println (← liftIO msg.toString)
pure ()
def tst : TermElabM Unit := do
let x := mkIdentFrom ref `x
let c ← mkIte ref (← `($x < 1))
(mkVarDecl (vdecl `w) (mkVarDecl (vdecl `z) (← mkReassign (vdecl `x) (mkReturn ref))))
(mkVarDecl (vdecl `x) (← mkReassign (vdecl `y) (mkBreak ref)))
print c
IO.println "-----"
let c ← concat c (mkVarDecl (vdecl `w) (← mkReassign (vdecl `z) (mkReturn ref)))
print c
let c ← mkReassign (vdecl `w) c
IO.println "-----"
print c
pure ()
#eval tst
end Lean.Elab.Term.Do