diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index 0e0581f896..d21cab310c 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -85,17 +85,18 @@ structure Alt (σ : Type) := - `reassign` is an reassignment-like `doElem` (e.g., `x := x + 1`). - - `jointpoint` is a join point declaration: an auxiliary `let`-declaration used to represent the control-flow. + - `joinpoint` is a join point declaration: an auxiliary `let`-declaration used to represent the control-flow. - `action` is an action-like `doElem` (e.g., `IO.println "hello"`, `dbgTrace! "foo"`). A code block `C` is well-formed if - - For every `jmp ref j as` in `C`, there is a `jointpoint j ps b k` and `jmp ref j as` is in `k`, and + - For every `jmp ref j as` in `C`, there is a `joinpoint j ps b k` and `jmp ref j as` is in `k`, and `ps.size == as.size` -/ inductive Code | decl (xs : Array Name) (stx : Syntax) (cont : Code) | reassign (xs : Array Name) (stx : Syntax) (cont : Code) -| jointpoint (name : Name) (params : Array Name) (body : Code) (cont : Code) +/- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/ +| joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (cont : Code) | action (stx : Syntax) (cond : Code) | «break» (ref : Syntax) | «continue» (ref : Syntax) @@ -122,8 +123,8 @@ MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.simpMacroSco partial def toMessageDataAux (updateVars : MessageData) : Code → MessageData | Code.decl xs _ k => "let " ++ varsToMessageData xs ++ " := ... " ++ Format.line ++ toMessageDataAux k | Code.reassign xs _ k => varsToMessageData xs ++ " := ... " ++ Format.line ++ toMessageDataAux k -| Code.jointpoint n ps body k => - "let " ++ n.simpMacroScopes ++ " " ++ varsToMessageData ps ++ " := " ++ indentD (toMessageDataAux body) +| Code.joinpoint n ps body k => + "let " ++ n.simpMacroScopes ++ " " ++ varsToMessageData (ps.map Prod.fst) ++ " := " ++ indentD (toMessageDataAux body) ++ Format.line ++ toMessageDataAux k | Code.action e k => e ++ Format.line ++ toMessageDataAux k | Code.ite _ _ _ c t e => "if " ++ c ++ " then " ++ indentD (toMessageDataAux t) ++ Format.line ++ "else " ++ indentD (toMessageDataAux e) @@ -151,7 +152,7 @@ toMessageDataAux (MessageData.ofList us) c.code partial def hasExitPoint : Code → Bool | Code.decl _ _ k => hasExitPoint k | Code.reassign _ _ k => hasExitPoint k -| Code.jointpoint _ _ b k => hasExitPoint b || hasExitPoint k +| Code.joinpoint _ _ b k => hasExitPoint b || hasExitPoint k | Code.action _ k => hasExitPoint k | Code.ite _ _ _ _ t e => hasExitPoint t || hasExitPoint e | Code.jmp _ _ _ => false @@ -160,10 +161,22 @@ partial def hasExitPoint : Code → Bool | Code.«return» _ _ => true | Code.«match» _ _ _ alts => alts.any fun alt => hasExitPoint alt.rhs +partial def hasContinueBreak : Code → Bool +| Code.decl _ _ k => hasContinueBreak k +| Code.reassign _ _ k => hasContinueBreak k +| Code.joinpoint _ _ b k => hasContinueBreak b || hasContinueBreak k +| Code.action _ k => hasContinueBreak k +| Code.ite _ _ _ _ t e => hasContinueBreak t || hasContinueBreak e +| Code.jmp _ _ _ => false +| Code.«break» _ => true +| Code.«continue» _ => true +| Code.«return» _ _ => false +| Code.«match» _ _ _ alts => alts.any fun alt => hasContinueBreak alt.rhs + partial def convertReturnIntoJmpAux (jp : Name) (xs : Array Name) : Code → Code | Code.decl xs stx k => Code.decl xs stx $ convertReturnIntoJmpAux k | Code.reassign xs stx k => Code.reassign xs stx $ convertReturnIntoJmpAux k -| Code.jointpoint n ps b k => Code.jointpoint n ps (convertReturnIntoJmpAux b) (convertReturnIntoJmpAux k) +| Code.joinpoint n ps b k => Code.joinpoint n ps (convertReturnIntoJmpAux b) (convertReturnIntoJmpAux k) | Code.action e k => Code.action e $ convertReturnIntoJmpAux k | Code.ite ref x? h c t e => Code.ite ref x? h c (convertReturnIntoJmpAux t) (convertReturnIntoJmpAux e) | Code.«match» ref ds t alts => Code.«match» ref ds t $ alts.map fun alt => { alt with rhs := convertReturnIntoJmpAux alt.rhs } @@ -175,23 +188,29 @@ def convertReturnIntoJmp (c : Code) (jp : Name) (xs : Array Name) : Code := convertReturnIntoJmpAux jp xs c structure JPDecl := -(name : Name) (params : Array Name) (body : Code) +(name : Name) (params : Array (Name × Bool)) (body : Code) def attachJP (jpDecl : JPDecl) (k : Code) : Code := -Code.jointpoint jpDecl.name jpDecl.params jpDecl.body k +Code.joinpoint jpDecl.name jpDecl.params jpDecl.body k def attachJPs (jpDecls : Array JPDecl) (k : Code) : Code := jpDecls.foldr attachJP k -def mkFreshJP (ps : Array Name) (body : Code) : TermElabM JPDecl := do +def mkFreshJP (ps : Array (Name × Bool)) (body : Code) : TermElabM JPDecl := do name ← mkFreshUserName `jp; pure { name := name, params := ps, body := body } -def addFreshJP (ps : Array Name) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do +def mkFreshJP' (xs : Array Name) (body : Code) : TermElabM JPDecl := +mkFreshJP (xs.map fun x => (x, true)) body + +def addFreshJP (ps : Array (Name × Bool)) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do jp ← liftM $ mkFreshJP ps body; modify fun (jps : Array JPDecl) => jps.push jp; pure jp.name +def addFreshJP' (xs : Array Name) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := +addFreshJP (xs.map fun x => (x, true)) body + def insertVars (rs : NameSet) (xs : Array Name) : NameSet := xs.foldl (fun (rs : NameSet) x => rs.insert x) rs @@ -207,23 +226,24 @@ match x? with partial def pullExitPointsAux : NameSet → Code → StateRefT (Array JPDecl) TermElabM Code | rs, Code.decl xs stx k => Code.decl xs stx <$> pullExitPointsAux (eraseVars rs xs) k | rs, Code.reassign xs stx k => Code.reassign xs stx <$> pullExitPointsAux (insertVars rs xs) k -| rs, Code.jointpoint j ps b k => Code.jointpoint j ps <$> pullExitPointsAux rs b <*> pullExitPointsAux rs k +| rs, Code.joinpoint j ps b k => Code.joinpoint j ps <$> pullExitPointsAux rs b <*> pullExitPointsAux rs k | rs, Code.action e k => Code.action e <$> pullExitPointsAux rs k | rs, Code.ite ref x? o c t e => Code.ite ref x? o c <$> pullExitPointsAux (eraseOptVar rs x?) t <*> pullExitPointsAux (eraseOptVar rs x?) e | rs, Code.«match» ref ds t alts => Code.«match» ref ds t <$> alts.mapM fun alt => do rhs ← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs; pure { alt with rhs := rhs } | rs, c@(Code.jmp _ _ _) => pure c -| rs, Code.«break» ref => do let xs := nameSetToArray rs; jp ← addFreshJP xs (Code.«break» ref); pure $ Code.jmp ref jp xs -| rs, Code.«continue» ref => do let xs := nameSetToArray rs; jp ← addFreshJP xs (Code.«continue» ref); pure $ Code.jmp ref jp xs +| rs, Code.«break» ref => do let xs := nameSetToArray rs; jp ← addFreshJP' xs (Code.«break» ref); pure $ Code.jmp ref jp xs +| rs, Code.«continue» ref => do let xs := nameSetToArray rs; jp ← addFreshJP' xs (Code.«continue» ref); pure $ Code.jmp ref jp xs | rs, Code.«return» ref y? => do let xs := nameSetToArray rs; + let ps := xs.map fun x => (x, true); (ps, xs, y?) ← match y? with - | none => pure (xs, xs, none) + | none => pure (ps, xs, none) | some y => - if rs.contains y then pure (xs, xs, some y) + if rs.contains y then pure (ps, xs, some y) else do { yFresh ← mkFreshUserName y; - pure (xs.push yFresh, xs.push y, some yFresh) + pure (ps.push (yFresh, false), xs.push y, some yFresh) }; jp ← addFreshJP ps (Code.«return» ref y?); pure $ Code.jmp ref jp xs @@ -285,7 +305,7 @@ else pure c partial def extendUpdatedVarsAux (ws : NameSet) : Code → TermElabM Code -| Code.jointpoint j ps b k => Code.jointpoint j ps <$> extendUpdatedVarsAux b <*> extendUpdatedVarsAux k +| Code.joinpoint j ps b k => Code.joinpoint j ps <$> extendUpdatedVarsAux b <*> extendUpdatedVarsAux k | Code.action e k => Code.action e <$> extendUpdatedVarsAux k | c@(Code.«match» ref ds t alts) => if alts.any fun alt => alt.vars.any fun x => ws.contains x then @@ -386,7 +406,7 @@ pure { def concat (terminal : CodeBlock) (k : CodeBlock) : TermElabM CodeBlock := do (terminal, k) ← homogenize terminal k; let xs := nameSetToArray k.uvars; -jpDecl ← mkFreshJP xs k.code; +jpDecl ← mkFreshJP' xs k.code; let jp := jpDecl.name; pure { code := attachJP jpDecl (convertReturnIntoJmp terminal.code jp xs), @@ -522,10 +542,213 @@ let doIf := expandDoIf doIf; thenBranch := doIf.getArg 4, elseBranch := (doIf.getArg 6).getArg 1 } +private def mkUnit (ref : Syntax) : MacroM Syntax := do +unit ← `(PUnit.unit); +pure $ unit.copyInfo ref + +private def mkTuple (ref : Syntax) (elems : Array Syntax) : MacroM Syntax := +if elems.size == 0 then do + mkUnit ref +else if elems.size == 1 then + pure (elems.get! 0) +else + (elems.extract 0 (elems.size - 1)).foldrM + (fun elem tuple => do + tuple ← `(($elem, $tuple)); + pure $ tuple.copyInfo ref) + (elems.back) + +-- Code block to syntax term +namespace ToTerm + +inductive Kind +| regular | forInNestedTerm | forIn | forInMap + +structure Context := +(m : Syntax) -- Syntax to reference the monad associated with the do notation. +(uvars : Array Name) +(kind : Kind) + +abbrev M := ReaderT Context MacroM + +def mkUVarTuple (ref : Syntax) : M Syntax := do +ctx ← read; +let uvarIdents := ctx.uvars.map fun x => mkIdentFrom ref x; +liftM $ mkTuple ref uvarIdents + +/- Note that, in the current design, we can only reassign variables that were declared in the do-block. + Thus, if `ctx.kind == Kind.regular`, then `ctx.uvars` must be empty. + Therefore, the following method should never create a tuple. + We keep it as-is because we may change the design decision in the future. -/ +def mkResultUVarTuple (ref : Syntax) (x? : Option Name) : M Syntax := do +ctx ← read; +match x?, ctx.uvars.isEmpty with +| none, true => liftM $ mkUnit ref +| none, false => do unit ← liftM $ mkUnit ref; uvars ← mkUVarTuple ref; liftM $ mkTuple ref #[unit, uvars] +| some x, true => pure $ mkIdentFrom ref x +| some x, false => do uvars ← mkUVarTuple ref; liftM $ mkTuple ref #[mkIdentFrom ref x, uvars] + +def returnToTermCore (ref : Syntax) (x? : Option Name) : M Syntax := do +ctx ← read; +match ctx.kind with +| Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«return» $u)) +| Kind.regular => do r ← mkResultUVarTuple ref x?; `(HasPure.pure $r) +| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.yield $u)) + +def returnToTerm (ref : Syntax) (x? : Option Name) : M Syntax := do +r ← returnToTermCore ref x?; +pure $ r.copyInfo ref + +def continueToTermCore (ref : Syntax) : M Syntax := do +ctx ← read; +match ctx.kind with +| Kind.regular => unreachable! +| Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«continue» $u)) +| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.yield $u)) + +def continueToTerm (ref : Syntax) : M Syntax := do +r ← continueToTermCore ref; +pure $ r.copyInfo ref + +def breakToTermCore (ref : Syntax) : M Syntax := do +ctx ← read; +match ctx.kind with +| Kind.regular => unreachable! +| Kind.forInNestedTerm => do u ← mkUVarTuple ref; `(HasPure.pure (DoResult.«break» $u)) +| _ => do u ← mkUVarTuple ref; `(HasPure.pure (ForInStep.done $u)) + +def breakToTerm (ref : Syntax) : M Syntax := do +r ← breakToTermCore ref; +pure $ r.copyInfo ref + +def actionToTermCore (action : Syntax) (k : Syntax) : MacroM Syntax := withFreshMacroScope do +if action.getKind == `Lean.Parser.Term.doDbgTrace then + let msg := action.getArg 1; + `(dbgTrace! $msg; $k) +else if action.getKind == `Lean.Parser.Term.doAssert then + let cond := action.getArg 1; + `(assert! $cond; $k) +else do + `(HasBind.bind $action (fun _ => $k)) + +def actionToTerm (action : Syntax) (k : Syntax) : MacroM Syntax := do +r ← actionToTermCore action k; +pure $ r.copyInfo action + +def declToTermCore (decl : Syntax) (k : Syntax) : M Syntax := withFreshMacroScope do +let kind := decl.getKind; +if kind == `Lean.Parser.Term.doLet then + let letDecl := decl.getArg 1; + `(let $letDecl:letDecl; $k) +else if kind == `Lean.Parser.Term.doLetRec then + liftM $ Macro.throwError decl "WIP" +else if kind == `Lean.Parser.Term.doLetArrow then + let arg := decl.getArg 1; + let ref := arg; + if arg.getKind == `Lean.Parser.Term.doIdDecl then + let id := arg.getArg 0; + let type := expandOptType ref (arg.getArg 1); + let val := arg.getArg 3; + `(HasBind.bind $val (fun ($id:ident : $type) => $k)) + else if arg.getKind == `Lean.Parser.Term.doPatDecl then do + -- termParser >> leftArrow >> termParser >> optional (" | " >> termParser) + let pat := arg.getArg 0; + let discr := arg.getArg 2; + let optElse := arg.getArg 3; + if optElse.isNone then + `(HasBind.bind $discr (fun x => match x with | $pat => $k)) + else do + let elseBody := optElse.getArg 1; + y ← `(y); + ret ← returnToTerm ref y.getId; + elseBody ← `(HasBind.bind $elseBody (fun y => $ret)); + `(HasBind.bind $discr (fun x => match x with | $pat => $k | _ => $elseBody)) + else + liftM $ Macro.throwError decl "unexpected kind of 'do' declaration" +else if kind == `Lean.Parser.Term.doHave then + liftM $ Macro.throwError decl ("WIP " ++ toString decl) +else + liftM $ Macro.throwError decl "unexpected kind of 'do' declaration" + +def declToTerm (decl : Syntax) (k : Syntax) : M Syntax := do +r ← declToTermCore decl k; +pure $ r.copyInfo decl + +def reassignToTermCore (reassign : Syntax) (k : Syntax) : MacroM Syntax := withFreshMacroScope do +let kind := reassign.getKind; +if kind == `Lean.Parser.Term.doReassign then + let letDecl := mkNode `Lean.Parser.Term.letDecl #[reassign.getArg 0]; + `(let $letDecl:letDecl; $k) +else if kind == `Lean.Parser.Term.doReassignArrow then + Macro.throwError reassign ("WIP " ++ toString reassign) +else + Macro.throwError reassign "unexpected kind of 'do' reassignment" + +def reassignToTerm (reassign : Syntax) (k : Syntax) : MacroM Syntax := do +r ← reassignToTermCore reassign k; +pure $ r.copyInfo reassign + +def mkIte (ref : Syntax) (optIdent : Syntax) (cond : Syntax) (thenBranch : Syntax) (elseBranch : Syntax) : Syntax := +mkNode `Lean.Parser.Term.«if» #[mkAtomFrom ref "if", optIdent, cond, mkAtomFrom ref "then", thenBranch, mkAtomFrom ref "else", elseBranch] + +def mkJoinPointCore (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := withFreshMacroScope do +let ref := body; +binders ← ps.mapM fun ⟨id, useTypeOf⟩ => do { + type ← if useTypeOf then `(typeOf! $(mkIdentFrom ref id)) else `(_); + let binderType := mkNullNode #[mkAtomFrom ref ":", type]; + pure $ mkNode `Lean.Parser.Term.explicitBinder #[mkAtomFrom ref "(", mkNullNode #[mkIdentFrom ref id], binderType, mkNullNode, mkAtomFrom ref ")"] +}; +ctx ← read; +let m := ctx.m; +type ← `($m _); +`(let $(mkIdentFrom ref j):ident $binders:explicitBinder* : $type := $body; $k) + +def mkJoinPoint (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := do +r ← mkJoinPointCore j ps body k; +pure $ r.copyInfo body + +def mkJmp (ref : Syntax) (j : Name) (args : Array Name) : Syntax := +mkAppStx (mkIdentFrom ref j) (args.map $ mkIdentFrom ref) + +partial def toTerm : Code → M Syntax +| Code.«return» ref x? => returnToTerm ref x? +| Code.«continue» ref => continueToTerm ref +| Code.«break» ref => breakToTerm ref +| Code.joinpoint j ps b k => do b ← toTerm b; k ← toTerm k; mkJoinPoint j ps b k +| Code.jmp ref j args => pure $ mkJmp ref j args +| Code.decl _ stx k => do k ← toTerm k; declToTerm stx k +| Code.reassign _ stx k => do k ← toTerm k; liftM $ reassignToTerm stx k +| Code.action stx k => do k ← toTerm k; liftM $ actionToTerm stx k +| Code.ite ref _ o c t e => do t ← toTerm t; e ← toTerm e; pure $ mkIte ref o c t e +| _ => liftM $ Macro.throwError Syntax.missing "WIP" + +private def getKindUVars (c : CodeBlock) (forInVar? : Option Name) : Kind × Array Name := +match forInVar? with +| none => + if hasContinueBreak c.code then + (Kind.forInNestedTerm, nameSetToArray c.uvars) + else + (Kind.regular, nameSetToArray c.uvars) +| some forInVar => + if c.uvars.contains forInVar then + let uvars := #[forInVar] ++ nameSetToArray (c.uvars.erase forInVar); + (Kind.forInMap, uvars) + else + (Kind.forIn, nameSetToArray c.uvars) + +def run (c : CodeBlock) (m : Syntax) (forInVar? : Option Name := none) : MacroM (Array Name × Syntax) := do +let code := c.code; +let (kind, uvars) := getKindUVars c forInVar?; +term ← toTerm code { m := m, kind := kind, uvars := uvars }; +pure (uvars, term) + +end ToTerm + namespace ToCodeBlock structure Context := (ref : Syntax) +(m : Syntax) -- Syntax representing the monad associated with the do notation. (varSet : NameSet := {}) (insideFor : Bool := false) @@ -581,6 +804,9 @@ else do (doElem, doElemsNew) ← (expandLiftMethodAux doElem).run []; pure (doElemsNew, doElem) +instance auususus : HasToString SourceInfo := +{ toString := fun info => toString info.pos } + partial def doSeqToCode : List Syntax → M CodeBlock | [] => do ctx ← read; pure $ mkReturn ctx.ref | doElem::doElems => withRef doElem do @@ -596,6 +822,11 @@ partial def doSeqToCode : List Syntax → M CodeBlock k ← doSeqToCode doElems; liftM $ concat c k }; + let auxDoToCode (auxDo : Syntax) : M CodeBlock := do { + let auxDoElems := getDoSeqElems (getDoSeq auxDo); + let auxDoElems := auxDoElems.map fun auxDoElem => auxDoElem.copyInfo doElem; + doSeqToCode auxDoElems + }; let k := doElem.getKind; if k == `Lean.Parser.Term.doLet then do vars ← liftM $ getDoLetVars doElem; @@ -651,8 +882,8 @@ partial def doSeqToCode : List Syntax → M CodeBlock | some x => pure $ mkReturn ref x | none => withFreshMacroScope do auxDo ← `(do let x := $arg; return x); - doSeqToCode (getDoSeqElems (getDoSeq auxDo)) - else if k == `Lean.Parser.Term.doDbgTracethen then + auxDoToCode auxDo + else if k == `Lean.Parser.Term.doDbgTrace then mkAction doElem <$> doSeqToCode doElems else if k == `Lean.Parser.Term.doAssert then mkAction doElem <$> doSeqToCode doElems @@ -660,31 +891,36 @@ partial def doSeqToCode : List Syntax → M CodeBlock let term := doElem.getArg 0; if doElems.isEmpty then withFreshMacroScope do auxDo ← `(do let x ← $term; return x); - doSeqToCode (getDoSeqElems (getDoSeq auxDo)) + auxDoToCode auxDo else mkAction term <$> doSeqToCode doElems else throwError ("unexpected do-element" ++ Format.line ++ toString doElem) -def run (doStx : Syntax) : TermElabM CodeBlock := -(doSeqToCode $ getDoSeqElems $ getDoSeq doStx).run { ref := doStx } +def run (doStx : Syntax) (m : Syntax) : TermElabM CodeBlock := +(doSeqToCode $ getDoSeqElems $ getDoSeq doStx).run { ref := doStx, m := m } end ToCodeBlock -private def mkTuple (elems : Array Syntax) : MacroM Syntax := -if elems.size == 1 then pure (elems.get! 0) -else - (elems.extract 0 (elems.size - 1)).foldrM - (fun elem tuple => `(($elem, $tuple))) - (elems.back) +/- Create a synthetic metavariable `?m` and assign `m` to it. + We use `?m` to refer to `m` when expanding the `do` notation. -/ +private def mkMonadAlias (m : Expr) : TermElabM Syntax := do +result ← `(?m); +mType ← inferType m; +mvar ← elabTerm result mType; +assignExprMVar mvar.mvarId! m; +pure result -- @[builtinTermElab «do»] def elabDo : TermElab := fun stx expectedType? => do tryPostponeIfNoneOrMVar expectedType?; bindInfo ← extractBind expectedType?; - codeBlock ← ToCodeBlock.run stx; - throwError ("WIP" ++ Format.line ++ codeBlock.toMessageData) + m ← mkMonadAlias bindInfo.m; + codeBlock ← ToCodeBlock.run stx m; + (_, stxNew) ← liftMacroM $ ToTerm.run codeBlock m; + trace! `Elab.do stxNew; + withMacroExpansion stx stxNew $ elabTerm stxNew expectedType? end Do diff --git a/src/Lean/Syntax.lean b/src/Lean/Syntax.lean index f328ff2899..aea9b543b9 100644 --- a/src/Lean/Syntax.lean +++ b/src/Lean/Syntax.lean @@ -267,7 +267,7 @@ partial def replaceInfo (info : SourceInfo) : Syntax → Syntax def copyInfo (s : Syntax) (source : Syntax) : Syntax := match source.getHeadInfo with | none => s -| some info => s.setInfo info +| some info => s.setHeadInfo info private def reprintLeaf (info : SourceInfo) (val : String) : String := -- no source info => add gracious amounts of whitespace to definitely separate tokens