diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index 2541fb0ea5..8573e6e150 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -15,29 +15,17 @@ open Meta namespace Do +/- A `doMatch` alternative. `vars` is the array of variables declared by `patterns`. -/ structure Alt (σ : Type) := -(ref : Syntax) (patterns : Array Syntax) (rhs : σ) - -structure VarDecl := -(ref : Syntax) (name : Name) (pure : Bool) (letDecl : Syntax) - -structure JPDecl (σ : Type) := -(ref : Syntax) (name : Name) (params : Array Name) (body : σ) +(ref : Syntax) (vars : Array Name) (patterns : Array Syntax) (rhs : σ) /- - Auxiliary datastructure for representing a `do` code block. + Auxiliary datastructure for representing a `do` code block, and compiling "reassignments" (e.g., `x := x + 1`). We convert `Code` into a `Syntax` term representing the: - `do`-block, or - the visitor argument for the `forIn` combinator. - We have 2 kinds of declaration - - `vdecl`: variable declaration - - `jdecl`: join-point declaration - - and actions (e.g., `IO.println "hello"`) - - `action` - - We have 6 terminals + We say the following constructors are terminals: - `break`: for interrupting a `for x in s` - `continue`: for interrupting the current iteration of a `for x in s` - `return`: returning the result of the computation. @@ -45,23 +33,34 @@ structure JPDecl (σ : Type) := - `match`: pattern matching - `jmp` a goto to a join-point - We say `break`, `continue` and `return` are "exit points" + We say the terminals `break`, `continue` and `return` are "exit points" The terminal `return` also contains the name of the variable containing the result of the computation. We ignore this value when inside a `for x in s`. + - `decl` represents all declaration-like `doElem`s (e.g., `let`, `have`, `let rec`). The field `stx` is the actual `doElem`, + `vars` is the array of variables declared by it, and `cont` is the next instruction in the `do` code block. 
+ `vars` is an array since we have declarations such as `let (a, b) := s`. + + - `reassign` is a reassignment-like `doElem` (e.g., `x := x + 1`). + + - `jointpoint` is a join point declaration: an auxiliary `let`-declaration used to represent the control-flow. + + - `action` is an action-like `doElem` (e.g., `IO.println "hello"`, `dbgTrace! "foo"`). + A code block `C` is well-formed if - - For every `jmp r j as` in `C`, there is a `jdecl r j ps b k` s.t. `jmp r j` is in `k`, and - `ps.size == as.size` --/ + - For every `jmp ref j as` in `C`, there is a `jointpoint j ps b k` and `jmp ref j as` is in `k`, and + `ps.size == as.size` -/ inductive Code -| vdecl (decl : VarDecl) (reassignment : Bool) (cont : Code) -| jdecl (decl : JPDecl Code) (cont : Code) -| action (term : Syntax) (cond : Code) +| decl (xs : Array Name) (stx : Syntax) (cont : Code) +| reassign (xs : Array Name) (stx : Syntax) (cont : Code) +| jointpoint (name : Name) (params : Array Name) (body : Code) (cont : Code) +| action (stx : Syntax) (cond : Code) | «break» (ref : Syntax) | «continue» (ref : Syntax) -| «return» (ref : Syntax) (var? : Option Name) -| ite (ref : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) +| «return» (ref : Syntax) (x? : Option Name) +/- Recall that an if-then-else may declare a variable using `optIdent` for the branches `thenBranch` and `elseBranch`. We store the variable name at `h?`. -/ +| ite (ref : Syntax) (h? : Option Name) (optIdent : Syntax) (cond : Syntax) (thenBranch : Code) (elseBranch : Code) | «match» (ref : Syntax) (discrs : Array Syntax) (type? 
: Option Syntax) (alts : Array (Alt Code)) | jmp (ref : Syntax) (jpName : Name) (args : Array Name) @@ -69,20 +68,24 @@ instance Code.inhabited : Inhabited Code := ⟨Code.«break» (arbitrary _)⟩ instance Alt.inhabited : Inhabited (Alt Code) := -⟨{ ref := arbitrary _, patterns := #[], rhs := arbitrary _ }⟩ +⟨{ ref := arbitrary _, vars := #[], patterns := #[], rhs := arbitrary _ }⟩ /- A code block, and the collection of variables updated by it. -/ structure CodeBlock := (code : Code) (uvars : NameSet := {}) -- set of variables updated by `code` +private def varsToMessageData (vars : Array Name) : MessageData := +MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.simpMacroScopes)) " " + partial def toMessageDataAux (updateVars : MessageData) : Code → MessageData -| Code.vdecl d r k => - (if r then "" else "let ") ++ d.name ++ " " ++ (if d.pure then ":=" else "←") ++ " ... " ++ Format.line ++ toMessageDataAux k -| Code.jdecl d k => - "let " ++ d.name.simpMacroScopes ++ " " ++ toString d.params.toList ++ ":=" ++ indentD (toMessageDataAux d.body) ++ Format.line ++ toMessageDataAux k +| Code.decl xs _ k => "let " ++ varsToMessageData xs ++ " := ... " ++ Format.line ++ toMessageDataAux k +| Code.reassign xs _ k => varsToMessageData xs ++ " := ... 
" ++ Format.line ++ toMessageDataAux k +| Code.jointpoint n ps body k => + "let " ++ n.simpMacroScopes ++ " " ++ varsToMessageData ps ++ " := " ++ indentD (toMessageDataAux body) + ++ Format.line ++ toMessageDataAux k | Code.action e k => e ++ Format.line ++ toMessageDataAux k -| Code.ite _ c t e => "if " ++ c ++ " then " ++ indentD (toMessageDataAux t) ++ Format.line ++ "else " ++ indentD (toMessageDataAux e) +| Code.ite _ _ _ c t e => "if " ++ c ++ " then " ++ indentD (toMessageDataAux t) ++ Format.line ++ "else " ++ indentD (toMessageDataAux e) | Code.jmp _ j xs => "jmp " ++ j.simpMacroScopes ++ " " ++ toString xs.toList | Code.«break» _ => "break " ++ updateVars | Code.«continue» _ => "continue " ++ updateVars @@ -104,22 +107,12 @@ def CodeBlock.toMessageData (c : CodeBlock) : MessageData := let us := (nameSetToArray c.uvars).toList.map MessageData.ofName; toMessageDataAux (MessageData.ofList us) c.code -partial def getSomeRef : Code → Syntax -| Code.vdecl d _ _ => d.ref -| Code.jdecl d _ => d.ref -| Code.action e _ => e -| Code.ite ref _ _ _ => ref -| Code.jmp ref _ _ => ref -| Code.«break» ref => ref -| Code.«continue» ref => ref -| Code.«return» ref _ => ref -| Code.«match» ref _ _ _ => ref - partial def hasExitPoint : Code → Bool -| Code.vdecl _ _ k => hasExitPoint k -| Code.jdecl d k => hasExitPoint d.body || hasExitPoint k +| Code.decl _ _ k => hasExitPoint k +| Code.reassign _ _ k => hasExitPoint k +| Code.jointpoint _ _ b k => hasExitPoint b || hasExitPoint k | Code.action _ k => hasExitPoint k -| Code.ite _ _ t e => hasExitPoint t || hasExitPoint e +| Code.ite _ _ _ _ t e => hasExitPoint t || hasExitPoint e | Code.jmp _ _ _ => false | Code.«break» _ => true | Code.«continue» _ => true @@ -127,10 +120,11 @@ partial def hasExitPoint : Code → Bool | Code.«match» _ _ _ alts => alts.any fun alt => hasExitPoint alt.rhs partial def convertReturnIntoJmpAux (jp : Name) (xs : Array Name) : Code → Code -| Code.vdecl d r k => Code.vdecl d r $ 
convertReturnIntoJmpAux k -| Code.jdecl d k => Code.jdecl { d with body := convertReturnIntoJmpAux d.body } $ convertReturnIntoJmpAux k +| Code.decl xs stx k => Code.decl xs stx $ convertReturnIntoJmpAux k +| Code.reassign xs stx k => Code.reassign xs stx $ convertReturnIntoJmpAux k +| Code.jointpoint n ps b k => Code.jointpoint n ps (convertReturnIntoJmpAux b) (convertReturnIntoJmpAux k) | Code.action e k => Code.action e $ convertReturnIntoJmpAux k -| Code.ite ref c t e => Code.ite ref c (convertReturnIntoJmpAux t) (convertReturnIntoJmpAux e) +| Code.ite ref x? h c t e => Code.ite ref x? h c (convertReturnIntoJmpAux t) (convertReturnIntoJmpAux e) | Code.«match» ref ds t alts => Code.«match» ref ds t $ alts.map fun alt => { alt with rhs := convertReturnIntoJmpAux alt.rhs } | Code.«return» ref _ => Code.jmp ref jp xs | c => c @@ -139,30 +133,48 @@ partial def convertReturnIntoJmpAux (jp : Name) (xs : Array Name) : Code → Cod def convertReturnIntoJmp (c : Code) (jp : Name) (xs : Array Name) : Code := convertReturnIntoJmpAux jp xs c -def mkJPDecls (jpDecls : Array (JPDecl Code)) (k : Code) : Code := -jpDecls.foldr (fun jp r => Code.jdecl jp r) k +structure JPDecl := +(name : Name) (params : Array Name) (body : Code) -def mkFreshJP (ref : Syntax) (ps : Array Name) (body : Code) : TermElabM (JPDecl Code) := do +def attachJP (jpDecl : JPDecl) (k : Code) : Code := +Code.jointpoint jpDecl.name jpDecl.params jpDecl.body k + +def attachJPs (jpDecls : Array JPDecl) (k : Code) : Code := +jpDecls.foldr attachJP k + +def mkFreshJP (ps : Array Name) (body : Code) : TermElabM JPDecl := do name ← mkFreshUserName `jp; -pure { ref := ref, name := name, params := ps, body := body } +pure { name := name, params := ps, body := body } -def addFreshJP (ref : Syntax) (ps : Array Name) (body : Code) : StateRefT (Array (JPDecl Code)) TermElabM Name := do -jp ← liftM $ mkFreshJP ref ps body; -modify fun (jps : Array (JPDecl Code)) => jps.push jp; +def addFreshJP (ps : Array Name) (body : 
Code) : StateRefT (Array JPDecl) TermElabM Name := do +jp ← liftM $ mkFreshJP ps body; +modify fun (jps : Array JPDecl) => jps.push jp; pure jp.name +def insertVars (rs : NameSet) (xs : Array Name) : NameSet := +xs.foldl (fun (rs : NameSet) x => rs.insert x) rs + +def eraseVars (rs : NameSet) (xs : Array Name) : NameSet := +xs.foldl (fun (rs : NameSet) x => rs.erase x) rs + +def eraseOptVar (rs : NameSet) (x? : Option Name) : NameSet := +match x? with +| none => rs +| some x => rs.erase x + /- `pullExitPointsAux rs c` auxiliary method for `pullExitPoints`, `rs` is the set of update variable in the current path. -/ -partial def pullExitPointsAux : NameSet → Code → StateRefT (Array (JPDecl Code)) TermElabM Code -| rs, Code.vdecl d false k => Code.vdecl d false <$> pullExitPointsAux (rs.erase d.name) k -| rs, Code.vdecl d true k => Code.vdecl d true <$> pullExitPointsAux (rs.insert d.name) k -| rs, Code.jdecl d k => do b ← pullExitPointsAux rs d.body; Code.jdecl { d with body := b } <$> pullExitPointsAux rs k -| rs, Code.action e k => Code.action e <$> pullExitPointsAux rs k -| rs, Code.ite ref c t e => Code.ite ref c <$> pullExitPointsAux rs t <*> pullExitPointsAux rs e -| rs, Code.«match» ref ds t alts => Code.«match» ref ds t <$> alts.mapM fun alt => do rhs ← pullExitPointsAux rs alt.rhs; pure { alt with rhs := rhs } +partial def pullExitPointsAux : NameSet → Code → StateRefT (Array JPDecl) TermElabM Code +| rs, Code.decl xs stx k => Code.decl xs stx <$> pullExitPointsAux (eraseVars rs xs) k +| rs, Code.reassign xs stx k => Code.reassign xs stx <$> pullExitPointsAux (insertVars rs xs) k +| rs, Code.jointpoint j ps b k => Code.jointpoint j ps <$> pullExitPointsAux rs b <*> pullExitPointsAux rs k +| rs, Code.action e k => Code.action e <$> pullExitPointsAux rs k +| rs, Code.ite ref x? o c t e => Code.ite ref x? o c <$> pullExitPointsAux (eraseOptVar rs x?) t <*> pullExitPointsAux (eraseOptVar rs x?) 
e +| rs, Code.«match» ref ds t alts => Code.«match» ref ds t <$> alts.mapM fun alt => do + rhs ← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs; pure { alt with rhs := rhs } | rs, c@(Code.jmp _ _ _) => pure c -| rs, Code.«break» ref => do let xs := nameSetToArray rs; jp ← addFreshJP ref xs (Code.«break» ref); pure $ Code.jmp ref jp xs -| rs, Code.«continue» ref => do let xs := nameSetToArray rs; jp ← addFreshJP ref xs (Code.«continue» ref); pure $ Code.jmp ref jp xs -| rs, Code.«return» ref y? => do +| rs, Code.«break» ref => do let xs := nameSetToArray rs; jp ← addFreshJP xs (Code.«break» ref); pure $ Code.jmp ref jp xs +| rs, Code.«continue» ref => do let xs := nameSetToArray rs; jp ← addFreshJP xs (Code.«continue» ref); pure $ Code.jmp ref jp xs +| rs, Code.«return» ref y? => do let xs := nameSetToArray rs; (ps, xs) ← match y? with | none => pure (xs, xs) @@ -172,7 +184,7 @@ partial def pullExitPointsAux : NameSet → Code → StateRefT (Array (JPDecl Co yFresh ← mkFreshUserName y; pure (xs.push y, xs.push yFresh) }; - jp ← addFreshJP ref ps (Code.«return» ref y?); + jp ← addFreshJP ps (Code.«return» ref y?); pure $ Code.jmp ref jp xs /- @@ -227,23 +239,35 @@ We implement the method as follows. 
Let `us` be `c.uvars`, then def pullExitPoints (c : Code) : TermElabM Code := if hasExitPoint c then do (c, jpDecls) ← (pullExitPointsAux {} c).run #[]; - pure $ mkJPDecls jpDecls c + pure $ attachJPs jpDecls c else pure c partial def extendUpdatedVarsAux (ws : NameSet) : Code → TermElabM Code -| Code.jdecl d k => do b ← extendUpdatedVarsAux d.body; Code.jdecl { d with body := b } <$> extendUpdatedVarsAux k -| Code.action e k => Code.action e <$> extendUpdatedVarsAux k -| Code.ite ref c t e => Code.ite ref c <$> extendUpdatedVarsAux t <*> extendUpdatedVarsAux e -| Code.«match» ref ds t alts => Code.«match» ref ds t <$> alts.mapM fun alt => do rhs ← extendUpdatedVarsAux alt.rhs; pure { alt with rhs := rhs } -| Code.vdecl d true k => Code.vdecl d true <$> extendUpdatedVarsAux k -| c@(Code.vdecl d false k) => - if ws.contains d.name then - -- This `let` declaration is shadowing a variable in ws +| Code.jointpoint j ps b k => Code.jointpoint j ps <$> extendUpdatedVarsAux b <*> extendUpdatedVarsAux k +| Code.action e k => Code.action e <$> extendUpdatedVarsAux k +| c@(Code.«match» ref ds t alts) => + if alts.any fun alt => alt.vars.any fun x => ws.contains x then + -- If a pattern variable is shadowing a variable in ws, we `pullExitPoints` pullExitPoints c else - Code.vdecl d false <$> extendUpdatedVarsAux k -| c => pure c + Code.«match» ref ds t <$> alts.mapM fun alt => do rhs ← extendUpdatedVarsAux alt.rhs; pure { alt with rhs := rhs } +| Code.ite ref none o c t e => + Code.ite ref none o c <$> extendUpdatedVarsAux t <*> extendUpdatedVarsAux e +| c@(Code.ite ref (some h) o cond t e) => + if ws.contains h then + -- if the `h` at `if h:c then t else e` shadows a variable in `ws`, we `pullExitPoints` + pullExitPoints c + else + Code.ite ref (some h) o cond <$> extendUpdatedVarsAux t <*> extendUpdatedVarsAux e +| Code.reassign xs stx k => Code.reassign xs stx <$> extendUpdatedVarsAux k +| c@(Code.decl xs stx k) => + if xs.any fun x => ws.contains x then + -- One of the 
declared variables is shadowing a variable in `ws` + pullExitPoints c + else + Code.decl xs stx <$> extendUpdatedVarsAux k +| c => pure c /- Extend the set of updated variables. It assumes `ws` is a super set of `c.uvars`. @@ -275,21 +299,26 @@ pure (c₁, c₂) /- Extending code blocks with variable declarations: `let x : t := v` and `let x : t ← v`. We remove `x` from the collection of updated varibles. +Remark: `stx` is the syntax for the declaration (e.g., `letDecl`), and `xs` are the variables +declared by it. It is an array because we have let-declarations that declare multiple variables. +Example: `let (x, y) := t` -/ -def mkVarDecl (d : VarDecl) (c : CodeBlock) : CodeBlock := -let x := d.name; -{ code := Code.vdecl d false c.code, uvars := c.uvars.erase x } +def mkVarDeclCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : CodeBlock := +{ code := Code.decl xs stx c.code, uvars := eraseVars c.uvars xs } /- Extending code blocks with reassignments: `x : t := v` and `x : t ← v`. +Remark: `stx` is the syntax for the reassignment, and `xs` are the variables +reassigned by it. It is an array because a single reassignment may update multiple variables. +Example: `(x, y) ← t` -/ -def mkReassign (d : VarDecl) (c : CodeBlock) : TermElabM CodeBlock := do -let x := d.name; -let ws := c.uvars.insert x; --- We must pull "exit points" IF `x` is not in `c.uvars`, but is shadowed by a declaration in `c` +def mkReassignCore (xs : Array Name) (stx : Syntax) (c : CodeBlock) : TermElabM CodeBlock := do +let us := c.uvars; +let ws := insertVars us xs; +-- If `xs` contains a new updated variable, then we must use `extendUpdatedVars`. 
-- See discussion at `pullExitPoints` -code ← if !c.uvars.contains x then extendUpdatedVarsAux ws c.code else pure c.code; -pure { code := Code.vdecl d true code, uvars := ws } +code ← if xs.any fun x => !us.contains x then extendUpdatedVarsAux ws c.code else pure c.code; +pure { code := Code.reassign xs stx code, uvars := ws } def mkAction (action : Syntax) (c : CodeBlock) : CodeBlock := { c with code := Code.action action c.code } @@ -303,10 +332,11 @@ def mkBreak (ref : Syntax) : CodeBlock := def mkContinue (ref : Syntax) : CodeBlock := { code := Code.«continue» ref } -def mkIte (ref : Syntax) (c : Syntax) (thenBranch : CodeBlock) (elseBranch : CodeBlock) : TermElabM CodeBlock := do +def mkIte (ref : Syntax) (optIdent : Syntax) (cond : Syntax) (thenBranch : CodeBlock) (elseBranch : CodeBlock) : TermElabM CodeBlock := do +let x? := if optIdent.isNone then none else some (optIdent.getArg 0).getId; (thenBranch, elseBranch) ← homogenize thenBranch elseBranch; pure { - code := Code.ite ref c thenBranch.code elseBranch.code, + code := Code.ite ref x? 
optIdent cond thenBranch.code elseBranch.code, uvars := thenBranch.uvars, } @@ -315,18 +345,18 @@ pure { def concat (terminal : CodeBlock) (k : CodeBlock) : TermElabM CodeBlock := do (terminal, k) ← homogenize terminal k; let xs := nameSetToArray k.uvars; -jpDecl ← mkFreshJP (getSomeRef k.code) xs k.code; +jpDecl ← mkFreshJP xs k.code; let jp := jpDecl.name; pure { - code := Code.jdecl jpDecl (convertReturnIntoJmp terminal.code jp xs), + code := attachJP jpDecl (convertReturnIntoJmp terminal.code jp xs), uvars := terminal.uvars, } def mkWhen (ref : Syntax) (cond : Syntax) (c : CodeBlock) : CodeBlock := -{ c with code := Code.ite ref cond c.code (Code.«return» ref none) } +{ c with code := Code.ite ref none mkNullNode cond c.code (Code.«return» ref none) } def mkUnless (ref : Syntax) (cond : Syntax) (c : CodeBlock) : CodeBlock := -{ c with code := Code.ite ref cond (Code.«return» ref none) c.code } +{ c with code := Code.ite ref none mkNullNode cond (Code.«return» ref none) c.code } private def mkTuple (elems : Array Syntax) : MacroM Syntax := if elems.size == 1 then pure (elems.get! 
0) diff --git a/tests/lean/run/doCodeBlock.lean b/tests/lean/run/doCodeBlock.lean index b35ec87b90..4b6b287778 100644 --- a/tests/lean/run/doCodeBlock.lean +++ b/tests/lean/run/doCodeBlock.lean @@ -6,8 +6,11 @@ namespace Lean.Elab.Term.Do def ref := Syntax.missing -def vdecl (name : Name) (pure := true) : VarDecl := -{ ref := ref, name := name, pure := pure, letDecl := Syntax.missing } +def mkVarDecl (x : Name) (k : CodeBlock) : CodeBlock := +mkVarDeclCore #[x] Syntax.missing k + +def mkReassign (x : Name) (k : CodeBlock) : TermElabM CodeBlock := +mkReassignCore #[x] Syntax.missing k def print (c : CodeBlock) : TermElabM Unit := do let msg := c.toMessageData @@ -17,14 +20,14 @@ pure () def tst : TermElabM Unit := do let x := mkIdentFrom ref `x -let c ← mkIte ref (← `($x < 1)) - (mkVarDecl (vdecl `w) (mkVarDecl (vdecl `z) (← mkReassign (vdecl `x) (mkReturn ref)))) - (mkVarDecl (vdecl `x) (← mkReassign (vdecl `y) (mkBreak ref))) +let c ← mkIte ref mkNullNode (← `($x < 1)) + (mkVarDecl `w (mkVarDecl `z (← mkReassign `x (mkReturn ref)))) + (mkVarDecl `x (← mkReassign `y (mkBreak ref))) print c IO.println "-----" -let c ← concat c (mkVarDecl (vdecl `w) (← mkReassign (vdecl `z) (mkReturn ref))) +let c ← concat c (mkVarDecl `w (← mkReassign `z (mkReturn ref))) print c -let c ← mkReassign (vdecl `w) c +let c ← mkReassign `w c IO.println "-----" print c pure ()