From 94c7945bd3355bbb503e6cc2a7dd8cfed30a51f0 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 30 Sep 2020 19:20:16 -0700 Subject: [PATCH] feat: `do` code blocks WIP --- src/Lean/Elab/Do.lean | 418 +++++++++++++++++++++----------- tests/lean/run/doCodeBlock.lean | 35 +++ 2 files changed, 309 insertions(+), 144 deletions(-) create mode 100644 tests/lean/run/doCodeBlock.lean diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index 91cdf5618e..2541fb0ea5 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -18,6 +18,12 @@ namespace Do structure Alt (σ : Type) := (ref : Syntax) (patterns : Array Syntax) (rhs : σ) +structure VarDecl := +(ref : Syntax) (name : Name) (pure : Bool) (letDecl : Syntax) + +structure JPDecl (σ : Type) := +(ref : Syntax) (name : Name) (params : Array Name) (body : σ) + /- Auxiliary datastructure for representing a `do` code block. We convert `Code` into a `Syntax` term representing the: @@ -39,49 +45,288 @@ structure Alt (σ : Type) := - `match`: pattern matching - `jmp` a goto to a join-point - We store the set of updated variables `uvars` in the terminals `break`, `continue`, and `return`. + We say `break`, `continue` and `return` are "exit points" + The terminal `return` also contains the name of the variable containing the result of the computation. We ignore this value when inside a `for x in s`. A code block `C` is well-formed if - 1- The collection of updated variables is the same in all `break` - `continue` and `return` in `C`. - - 2- For every `jmp r j as` in `C`, there is a `jdecl r j ps b k` s.t. `jmp r j` is in `k`, and - `ps.size == as.size` - - 3- The update variables occurring in `break`, `continue`, and `return` are pairwise distinct. - - We use the notation `C[u_1, ..., u_k]` to denote a code block that updates variables `u_1, ..., u_k` - + - For every `jmp r j as` in `C`, there is a `jdecl r j ps b k` s.t. `jmp r j` is in `k`, and + `ps.size == as.size` -/ inductive Code -| vdecl (ref : Syntax) (id : Name) (type : Syntax) (pure : Bool) (val : Syntax) (cont : Code) -| jdecl (ref : Syntax) (id : Name) (params : Array Name) (body : Code) (cont : Code) +| vdecl (decl : VarDecl) (reassignment : Bool) (cont : Code) +| jdecl (decl : JPDecl Code) (cont : Code) | action (term : Syntax) (cond : Code) -| «break» (ref : Syntax) (uvars : Array Name) -| «continue» (ref : Syntax) (uvars : Array Name) -| «return» (ref : Syntax) (var? : Option Name) (uvars : Array Name) +| «break» (ref : Syntax) +| «continue» (ref : Syntax) +| «return» (ref : Syntax) (var? : Option Name) | ite (ref : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) | «match» (ref : Syntax) (discrs : Array Syntax) (type? : Option Syntax) (alts : Array (Alt Code)) | jmp (ref : Syntax) (jpName : Name) (args : Array Name) -instance body.inhabited : Inhabited Code := -⟨Code.«break» (arbitrary _) #[]⟩ +instance Code.inhabited : Inhabited Code := +⟨Code.«break» (arbitrary _)⟩ -instance alt.inhabited : Inhabited (Alt Code) := +instance Alt.inhabited : Inhabited (Alt Code) := ⟨{ ref := arbitrary _, patterns := #[], rhs := arbitrary _ }⟩ -partial def getUpdatedVars? : Code → Option (Array Name) -| Code.vdecl _ _ _ _ _ k => getUpdatedVars? k -| Code.jdecl _ _ _ b k => getUpdatedVars? b <|> getUpdatedVars? k -| Code.action _ k => getUpdatedVars? k -| Code.«break» _ uvars => some uvars -| Code.«continue» _ uvars => some uvars -| Code.«return» _ _ uvars => some uvars -| Code.ite _ _ t e => getUpdatedVars? t <|> getUpdatedVars? e -| Code.«match» _ _ _ alts => alts.findSome? fun alt => getUpdatedVars? alt.rhs -| Code.jmp _ _ _ => none +/- A code block, and the collection of variables updated by it. -/ +structure CodeBlock := +(code : Code) +(uvars : NameSet := {}) -- set of variables updated by `code` + +partial def toMessageDataAux (updateVars : MessageData) : Code → MessageData +| Code.vdecl d r k => + (if r then "" else "let ") ++ d.name ++ " " ++ (if d.pure then ":=" else "←") ++ " ... " ++ Format.line ++ toMessageDataAux k +| Code.jdecl d k => + "let " ++ d.name.simpMacroScopes ++ " " ++ toString d.params.toList ++ ":=" ++ indentD (toMessageDataAux d.body) ++ Format.line ++ toMessageDataAux k +| Code.action e k => e ++ Format.line ++ toMessageDataAux k +| Code.ite _ c t e => "if " ++ c ++ " then " ++ indentD (toMessageDataAux t) ++ Format.line ++ "else " ++ indentD (toMessageDataAux e) +| Code.jmp _ j xs => "jmp " ++ j.simpMacroScopes ++ " " ++ toString xs.toList +| Code.«break» _ => "break " ++ updateVars +| Code.«continue» _ => "continue " ++ updateVars +| Code.«return» _ none => "return " ++ updateVars +| Code.«return» _ (some x) => "return " ++ x ++ " " ++ updateVars +| Code.«match» _ ds t alts => + "match " ++ MessageData.joinSep (ds.toList.map MessageData.ofSyntax) ", " ++ " with " ++ + alts.foldl + (fun (acc : MessageData) (alt : Alt Code) => + acc ++ Format.line ++ "| " + ++ MessageData.joinSep (alt.patterns.toList.map MessageData.ofSyntax) ", " + ++ " => " ++ toMessageDataAux alt.rhs) + Format.nil + +private def nameSetToArray (s : NameSet) : Array Name := +s.fold (fun (xs : Array Name) x => xs.push x) #[] + +def CodeBlock.toMessageData (c : CodeBlock) : MessageData := +let us := (nameSetToArray c.uvars).toList.map MessageData.ofName; +toMessageDataAux (MessageData.ofList us) c.code + +partial def getSomeRef : Code → Syntax +| Code.vdecl d _ _ => d.ref +| Code.jdecl d _ => d.ref +| Code.action e _ => e +| Code.ite ref _ _ _ => ref +| Code.jmp ref _ _ => ref +| Code.«break» ref => ref +| Code.«continue» ref => ref +| Code.«return» ref _ => ref +| Code.«match» ref _ _ _ => ref + +partial def hasExitPoint : Code → Bool +| Code.vdecl _ _ k => hasExitPoint k +| Code.jdecl d k => hasExitPoint d.body || hasExitPoint k +| Code.action _ k => hasExitPoint k +| Code.ite _ _ t e => hasExitPoint t || hasExitPoint e +| Code.jmp _ _ _ => false +| Code.«break» _ => true +| Code.«continue» _ => true +| Code.«return» _ _ => true +| Code.«match» _ _ _ alts => alts.any fun alt => hasExitPoint alt.rhs + +partial def convertReturnIntoJmpAux (jp : Name) (xs : Array Name) : Code → Code +| Code.vdecl d r k => Code.vdecl d r $ convertReturnIntoJmpAux k +| Code.jdecl d k => Code.jdecl { d with body := convertReturnIntoJmpAux d.body } $ convertReturnIntoJmpAux k +| Code.action e k => Code.action e $ convertReturnIntoJmpAux k +| Code.ite ref c t e => Code.ite ref c (convertReturnIntoJmpAux t) (convertReturnIntoJmpAux e) +| Code.«match» ref ds t alts => Code.«match» ref ds t $ alts.map fun alt => { alt with rhs := convertReturnIntoJmpAux alt.rhs } +| Code.«return» ref _ => Code.jmp ref jp xs +| c => c + +/- Convert `return _ x` instructions in `c` into `jmp _ jp xs`. -/ +def convertReturnIntoJmp (c : Code) (jp : Name) (xs : Array Name) : Code := +convertReturnIntoJmpAux jp xs c + +def mkJPDecls (jpDecls : Array (JPDecl Code)) (k : Code) : Code := +jpDecls.foldr (fun jp r => Code.jdecl jp r) k + +def mkFreshJP (ref : Syntax) (ps : Array Name) (body : Code) : TermElabM (JPDecl Code) := do +name ← mkFreshUserName `jp; +pure { ref := ref, name := name, params := ps, body := body } + +def addFreshJP (ref : Syntax) (ps : Array Name) (body : Code) : StateRefT (Array (JPDecl Code)) TermElabM Name := do +jp ← liftM $ mkFreshJP ref ps body; +modify fun (jps : Array (JPDecl Code)) => jps.push jp; +pure jp.name + +/- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/ +partial def pullExitPointsAux : NameSet → Code → StateRefT (Array (JPDecl Code)) TermElabM Code +| rs, Code.vdecl d false k => Code.vdecl d false <$> pullExitPointsAux (rs.erase d.name) k +| rs, Code.vdecl d true k => Code.vdecl d true <$> pullExitPointsAux (rs.insert d.name) k +| rs, Code.jdecl d k => do b ← pullExitPointsAux rs d.body; Code.jdecl { d with body := b } <$> pullExitPointsAux rs k +| rs, Code.action e k => Code.action e <$> pullExitPointsAux rs k +| rs, Code.ite ref c t e => Code.ite ref c <$> pullExitPointsAux rs t <*> pullExitPointsAux rs e +| rs, Code.«match» ref ds t alts => Code.«match» ref ds t <$> alts.mapM fun alt => do rhs ← pullExitPointsAux rs alt.rhs; pure { alt with rhs := rhs } +| rs, c@(Code.jmp _ _ _) => pure c +| rs, Code.«break» ref => do let xs := nameSetToArray rs; jp ← addFreshJP ref xs (Code.«break» ref); pure $ Code.jmp ref jp xs +| rs, Code.«continue» ref => do let xs := nameSetToArray rs; jp ← addFreshJP ref xs (Code.«continue» ref); pure $ Code.jmp ref jp xs +| rs, Code.«return» ref y? => do + let xs := nameSetToArray rs; + (ps, xs) ← match y? with + | none => pure (xs, xs) + | some y => + if rs.contains y then pure (xs, xs) + else do { + yFresh ← mkFreshUserName y; + pure (xs.push y, xs.push yFresh) + }; + jp ← addFreshJP ref ps (Code.«return» ref y?); + pure $ Code.jmp ref jp xs + +/- +Auxiliary operation for adding new variables to `c.uvars` (updated variables). +When a new variable is not already in `c.uvars`, but is shadowed by some declaration in `c.code`, +we create auxiliary join points to make sure we preserve the semantics of the code block. +Example: suppose we have the code block `print x; let x := 10; return x`. And we want to extend it +with the reassignment `x := x + 1`. We first use `pullExitPoints` to create +``` +let jp (x!1) := return x!1; +print x; +let x := 10; +jmp jp x +``` +and then we add the reassignment +``` +x := x + 1 +let jp (x!1) := return x!1; +print x; +let x := 10; +jmp jp x +``` +Note that we created a fresh variable `x!1` to avoid accidental name capture. + +``` +print x; +let x := 10 +y := y + 1; +return x; +``` +We transform it into +``` +let jp (y x!1) := return x!1; +print x; +let x := 10 +y := y + 1; +jmp jp y x +``` +and then we add the reassignment as in the previous example. +We need to include `y` in the jump, because each exit point is implicitly returning the set of +update variables. + +We implement the method as follows. Let `us` be `c.uvars`, then +1- for each `return _ y` in `c`, we create a join point + `let j (us y!1) := return y!1` + and replace the `return _ y` with `jmp us y` +2- for each `break`, we create a join point + `let j (us) := break` + and replace the `break` with `jmp us`. +3- Same as 2 for `continue`. +-/ +def pullExitPoints (c : Code) : TermElabM Code := +if hasExitPoint c then do + (c, jpDecls) ← (pullExitPointsAux {} c).run #[]; + pure $ mkJPDecls jpDecls c +else + pure c + +partial def extendUpdatedVarsAux (ws : NameSet) : Code → TermElabM Code +| Code.jdecl d k => do b ← extendUpdatedVarsAux d.body; Code.jdecl { d with body := b } <$> extendUpdatedVarsAux k +| Code.action e k => Code.action e <$> extendUpdatedVarsAux k +| Code.ite ref c t e => Code.ite ref c <$> extendUpdatedVarsAux t <*> extendUpdatedVarsAux e +| Code.«match» ref ds t alts => Code.«match» ref ds t <$> alts.mapM fun alt => do rhs ← extendUpdatedVarsAux alt.rhs; pure { alt with rhs := rhs } +| Code.vdecl d true k => Code.vdecl d true <$> extendUpdatedVarsAux k +| c@(Code.vdecl d false k) => + if ws.contains d.name then + -- This `let` declaration is shadowing a variable in ws + pullExitPoints c + else + Code.vdecl d false <$> extendUpdatedVarsAux k +| c => pure c + +/- +Extend the set of updated variables. It assumes `ws` is a super set of `c.uvars`. +We **cannot** simply update the field `c.uvars`, because `c` may have shadowed some variable in `ws`. +See discussion at `pullExitPoints`. +-/ +def extendUpdatedVars (c : CodeBlock) (ws : NameSet) : TermElabM CodeBlock := +if ws.any fun x => !c.uvars.contains x then do + -- `ws` contains a variable that is not in `c.uvars`, but in `c.dvars` (i.e., it has been shadowed) + code ← extendUpdatedVarsAux ws c.code; + pure { code := code, uvars := ws } +else + pure { c with uvars := ws } + +private def union (s₁ s₂ : NameSet) : NameSet := +s₁.fold (fun (s : NameSet) x => s.insert x) s₂ + +/- +Given two code blocks `c₁` and `c₂`, make sure they have the same set of updated variables. +Let `ws` the union of the updated variables in `c₁‵ and ‵c₂`. +We use `extendUpdatedVars c₁ ws` and `extendUpdatedVars c₂ ws` +-/ +def homogenize (c₁ c₂ : CodeBlock) : TermElabM (CodeBlock × CodeBlock) := do +let ws := union c₁.uvars c₂.uvars; +c₁ ← extendUpdatedVars c₁ ws; +c₂ ← extendUpdatedVars c₂ ws; +pure (c₁, c₂) + +/- +Extending code blocks with variable declarations: `let x : t := v` and `let x : t ← v`. +We remove `x` from the collection of updated varibles. +-/ +def mkVarDecl (d : VarDecl) (c : CodeBlock) : CodeBlock := +let x := d.name; +{ code := Code.vdecl d false c.code, uvars := c.uvars.erase x } + +/- +Extending code blocks with reassignments: `x : t := v` and `x : t ← v`. +-/ +def mkReassign (d : VarDecl) (c : CodeBlock) : TermElabM CodeBlock := do +let x := d.name; +let ws := c.uvars.insert x; +-- We must pull "exit points" IF `x` is not in `c.uvars`, but is shadowed by a declaration in `c` +-- See discussion at `pullExitPoints` +code ← if !c.uvars.contains x then extendUpdatedVarsAux ws c.code else pure c.code; +pure { code := Code.vdecl d true code, uvars := ws } + +def mkAction (action : Syntax) (c : CodeBlock) : CodeBlock := +{ c with code := Code.action action c.code } + +def mkReturn (ref : Syntax) (x? : Option Name := none) : CodeBlock := +{ code := Code.«return» ref x? } + +def mkBreak (ref : Syntax) : CodeBlock := +{ code := Code.«break» ref } + +def mkContinue (ref : Syntax) : CodeBlock := +{ code := Code.«continue» ref } + +def mkIte (ref : Syntax) (c : Syntax) (thenBranch : CodeBlock) (elseBranch : CodeBlock) : TermElabM CodeBlock := do +(thenBranch, elseBranch) ← homogenize thenBranch elseBranch; +pure { + code := Code.ite ref c thenBranch.code elseBranch.code, + uvars := thenBranch.uvars, +} + +/- Return a code block that executes `terminal` and then `k`. + This method assumes `terminal` is a terminal -/ +def concat (terminal : CodeBlock) (k : CodeBlock) : TermElabM CodeBlock := do +(terminal, k) ← homogenize terminal k; +let xs := nameSetToArray k.uvars; +jpDecl ← mkFreshJP (getSomeRef k.code) xs k.code; +let jp := jpDecl.name; +pure { + code := Code.jdecl jpDecl (convertReturnIntoJmp terminal.code jp xs), + uvars := terminal.uvars, +} + +def mkWhen (ref : Syntax) (cond : Syntax) (c : CodeBlock) : CodeBlock := +{ c with code := Code.ite ref cond c.code (Code.«return» ref none) } + +def mkUnless (ref : Syntax) (cond : Syntax) (c : CodeBlock) : CodeBlock := +{ c with code := Code.ite ref cond (Code.«return» ref none) c.code } private def mkTuple (elems : Array Syntax) : MacroM Syntax := if elems.size == 1 then pure (elems.get! 0) @@ -90,123 +335,8 @@ else (fun elem tuple => `(($elem, $tuple))) (elems.back) -/- -Extending code blocks with variable declarations: `let x : t := v` and `let x : t ← v`. - -Suppose we have a code block `C[us]`, and we want to extend it with the -`let x : t := v` declaration. We first remove `x` from the collection of updated variables `us`, obtaining `us'` -and return: -``` -Code.vdecl _ x t true v C[us'] -``` -The operation is the same for `let x : t ← v`, but we set `pure` with `false`. --/ - -/- -Extending code blocks with reassignments: `x : t := v` and `x : t ← v`. - -Suppose we have a code block `C[us]`, and we want to extend it with the -`x : t := v` reassignment. If `x` is in `us`, then we just return -``` -Code.vdecl _ x t true v C[us] -``` -If `x` is not in `us`, we create a C'[x, us] in the following way -1- for each `return _ y us` occurring in `C[us]`, we create a join point - `let j (y us) := return y [x, us]` - and we replace the `return _ y us` with `jmp y us` -2- for each `break us` occurring in `C[us]`, we create a join point - `let j (us) := break [x, us]` - and we replace the `break us` with `jmp us`. -3- Same as 2 for `continue us` -Finally, we return -``` -Code.vdecl _ x t true v C'[x, us] -``` - -Note that it would be incorrect to just add `x` to the set of updated variables of each `break`, `continue`, and `return`. -The problem is that `C` may have shadowed `x`. As an example, consider the following piece of code -``` -let x ← action₁; -- declares 'x' -x := x + 1; -- reassigns 'x' -IO.println x; -let x ← action₂; -- shadows previous x -IO.println x -``` -The code block `C` for -``` -IO.println x; -let x ← action₂; -- shadows previous x -IO.println x -``` -is -``` -Code.action (IO.println x) $ -Code.vdecl _ x _ false action₂ $ -Code.action (IO.println x) $ -Code.return _ none [] -``` -Here is the incorrect way of extending it with the assignment `x := x + 1`. -``` -Code.vdecl _ x _ true (x+1) $ -Code.action (IO.println x) $ -Code.vdecl _ x _ false action₂ $ -Code.action (IO.println x) $ -Code.return _ none [x] -``` -The code above incorrectly returns the shadowed `x` as the updated value for `x`. -The process above using join-point produces the correct result: -``` -Code.vdecl _ x _ true (x+1) $ -Code.jdecl _ j [] (Code.return _ none [x]) $ -Code.action (IO.println x) $ -Code.vdecl _ x _ false action₂ $ -Code.action (IO.println x) $ -Code.jmp _ j [] -``` -The join point `j` returns the correct `x`. --/ - -/- -Combining two code-blocks `C[us]` `D[vs]` into an if-then-else with condition `c`. -If `us == vs`, then it is easy. We just return: -``` -Code.ite _ c C[us] D[us] -``` -Otherwise, let `ws` be the union of `us` and `vs`. The for each `return`, `continue`, and `break` occurring in `C[us]` and `D[vs]`, we create -an auxiliary join point using a process similar to the one we used for extending code-blocks with reassignment operations. -For example, for a `break us` in `C[us]` we create a join point -``` -Code.jdecl _ j [us] (Code.break [ws]) $ ... -``` -and replace `break us` with `jmp _ j us`. -We call this operation `homogenise : Code → Code → Code × Code`. It takes two code blocks and returns two new code blocks that have the same -collection of updated variables. -Given `(C'[ws], D'[ws]) := homogenize C[us] D[vs]`, we return -``` -ite c C'[ws] D'[ws] -``` - -The process of creating `match` terminal is similar. - --/ - -/- -We say a code-block `C[us]` is "terminal-like" if it is a sequence of join-point declarations followed by a `Code.ite` or `Code.match`. -That is, `C[us]` is obtained by the `mkIte` and `mkMatch` primitives. - -For concatenating two joint points `C[us]` `D[vs]`, where `C[us]` is a terminal-like code block, we first consider the simpler case where `us == vs`, -then we use `homogenize` for implementing the general case. -If `us == vs`, we first create a joint point `j` for `D[us]`, and then replace each `return _ _ [us]` in `C[us]` with a `jmp j`, obtaining `C'[us]`. -The result is like -``` -Code.jdecl _ j [] (D[us]) $ -C'[us] -``` --/ - end Do - structure ExtractMonadResult := (m : Expr) (α : Expr) diff --git a/tests/lean/run/doCodeBlock.lean b/tests/lean/run/doCodeBlock.lean new file mode 100644 index 0000000000..b35ec87b90 --- /dev/null +++ b/tests/lean/run/doCodeBlock.lean @@ -0,0 +1,35 @@ +import Lean + +new_frontend + +namespace Lean.Elab.Term.Do + +def ref := Syntax.missing + +def vdecl (name : Name) (pure := true) : VarDecl := +{ ref := ref, name := name, pure := pure, letDecl := Syntax.missing } + +def print (c : CodeBlock) : TermElabM Unit := do +let msg := c.toMessageData +let msg ← addMessageContext msg +IO.println (← liftIO msg.toString) +pure () + +def tst : TermElabM Unit := do +let x := mkIdentFrom ref `x +let c ← mkIte ref (← `($x < 1)) + (mkVarDecl (vdecl `w) (mkVarDecl (vdecl `z) (← mkReassign (vdecl `x) (mkReturn ref)))) + (mkVarDecl (vdecl `x) (← mkReassign (vdecl `y) (mkBreak ref))) +print c +IO.println "-----" +let c ← concat c (mkVarDecl (vdecl `w) (← mkReassign (vdecl `z) (mkReturn ref))) +print c +let c ← mkReassign (vdecl `w) c +IO.println "-----" +print c +pure () + +#eval tst + + +end Lean.Elab.Term.Do