feat: add CodeBlock to Syntax converter

This commit is contained in:
Leonardo de Moura 2020-10-04 16:19:09 -07:00
parent b4289d5c5d
commit 9810b405d9
2 changed files with 269 additions and 33 deletions

View file

@ -85,17 +85,18 @@ structure Alt (σ : Type) :=
- `reassign` is a reassignment-like `doElem` (e.g., `x := x + 1`).
- `jointpoint` is a join point declaration: an auxiliary `let`-declaration used to represent the control-flow.
- `joinpoint` is a join point declaration: an auxiliary `let`-declaration used to represent the control-flow.
- `action` is an action-like `doElem` (e.g., `IO.println "hello"`, `dbgTrace! "foo"`).
A code block `C` is well-formed if
- For every `jmp ref j as` in `C`, there is a `jointpoint j ps b k` and `jmp ref j as` is in `k`, and
- For every `jmp ref j as` in `C`, there is a `joinpoint j ps b k` and `jmp ref j as` is in `k`, and
`ps.size == as.size` -/
inductive Code
| decl (xs : Array Name) (stx : Syntax) (cont : Code)
| reassign (xs : Array Name) (stx : Syntax) (cont : Code)
| jointpoint (name : Name) (params : Array Name) (body : Code) (cont : Code)
/- The Boolean value in `params` indicates whether we should use `(x : typeof! x)` when generating term Syntax or not -/
| joinpoint (name : Name) (params : Array (Name × Bool)) (body : Code) (cont : Code)
| action (stx : Syntax) (cond : Code)
| «break» (ref : Syntax)
| «continue» (ref : Syntax)
@ -122,8 +123,8 @@ MessageData.joinSep (vars.toList.map fun n => MessageData.ofName (n.simpMacroSco
partial def toMessageDataAux (updateVars : MessageData) : Code → MessageData
| Code.decl xs _ k => "let " ++ varsToMessageData xs ++ " := ... " ++ Format.line ++ toMessageDataAux k
| Code.reassign xs _ k => varsToMessageData xs ++ " := ... " ++ Format.line ++ toMessageDataAux k
| Code.jointpoint n ps body k =>
"let " ++ n.simpMacroScopes ++ " " ++ varsToMessageData ps ++ " := " ++ indentD (toMessageDataAux body)
| Code.joinpoint n ps body k =>
"let " ++ n.simpMacroScopes ++ " " ++ varsToMessageData (ps.map Prod.fst) ++ " := " ++ indentD (toMessageDataAux body)
++ Format.line ++ toMessageDataAux k
| Code.action e k => e ++ Format.line ++ toMessageDataAux k
| Code.ite _ _ _ c t e => "if " ++ c ++ " then " ++ indentD (toMessageDataAux t) ++ Format.line ++ "else " ++ indentD (toMessageDataAux e)
@ -151,7 +152,7 @@ toMessageDataAux (MessageData.ofList us) c.code
partial def hasExitPoint : Code → Bool
| Code.decl _ _ k => hasExitPoint k
| Code.reassign _ _ k => hasExitPoint k
| Code.jointpoint _ _ b k => hasExitPoint b || hasExitPoint k
| Code.joinpoint _ _ b k => hasExitPoint b || hasExitPoint k
| Code.action _ k => hasExitPoint k
| Code.ite _ _ _ _ t e => hasExitPoint t || hasExitPoint e
| Code.jmp _ _ _ => false
@ -160,10 +161,22 @@ partial def hasExitPoint : Code → Bool
| Code.«return» _ _ => true
| Code.«match» _ _ _ alts => alts.any fun alt => hasExitPoint alt.rhs
/- Return `true` iff a `break` or `continue` occurs somewhere in the given code.
   `return` and `jmp` never count. For a join point both the body and the continuation are searched;
   for `ite` both branches; for `match` the right-hand side of every alternative. -/
partial def hasContinueBreak : Code → Bool
| Code.«break» _              => true
| Code.«continue» _           => true
| Code.«return» _ _           => false
| Code.jmp _ _ _              => false
| Code.decl _ _ rest          => hasContinueBreak rest
| Code.reassign _ _ rest      => hasContinueBreak rest
| Code.action _ rest          => hasContinueBreak rest
| Code.joinpoint _ _ body rest => hasContinueBreak body || hasContinueBreak rest
| Code.ite _ _ _ _ thenC elseC => hasContinueBreak thenC || hasContinueBreak elseC
| Code.«match» _ _ _ alts     => alts.any fun alt => hasContinueBreak alt.rhs
partial def convertReturnIntoJmpAux (jp : Name) (xs : Array Name) : Code → Code
| Code.decl xs stx k => Code.decl xs stx $ convertReturnIntoJmpAux k
| Code.reassign xs stx k => Code.reassign xs stx $ convertReturnIntoJmpAux k
| Code.jointpoint n ps b k => Code.jointpoint n ps (convertReturnIntoJmpAux b) (convertReturnIntoJmpAux k)
| Code.joinpoint n ps b k => Code.joinpoint n ps (convertReturnIntoJmpAux b) (convertReturnIntoJmpAux k)
| Code.action e k => Code.action e $ convertReturnIntoJmpAux k
| Code.ite ref x? h c t e => Code.ite ref x? h c (convertReturnIntoJmpAux t) (convertReturnIntoJmpAux e)
| Code.«match» ref ds t alts => Code.«match» ref ds t $ alts.map fun alt => { alt with rhs := convertReturnIntoJmpAux alt.rhs }
@ -175,23 +188,29 @@ def convertReturnIntoJmp (c : Code) (jp : Name) (xs : Array Name) : Code :=
convertReturnIntoJmpAux jp xs c
structure JPDecl :=
(name : Name) (params : Array Name) (body : Code)
(name : Name) (params : Array (Name × Bool)) (body : Code)
def attachJP (jpDecl : JPDecl) (k : Code) : Code :=
Code.jointpoint jpDecl.name jpDecl.params jpDecl.body k
Code.joinpoint jpDecl.name jpDecl.params jpDecl.body k
/- Wrap `k` with every join point declaration in `jpDecls`, folding from the right
   so that the first declaration ends up outermost. -/
def attachJPs (jpDecls : Array JPDecl) (k : Code) : Code :=
  jpDecls.foldr (fun jpDecl code => attachJP jpDecl code) k
def mkFreshJP (ps : Array Name) (body : Code) : TermElabM JPDecl := do
def mkFreshJP (ps : Array (Name × Bool)) (body : Code) : TermElabM JPDecl := do
name ← mkFreshUserName `jp;
pure { name := name, params := ps, body := body }
def addFreshJP (ps : Array Name) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do
/- Variant of `mkFreshJP` that takes plain parameter names: every name is paired with the flag `true`. -/
def mkFreshJP' (xs : Array Name) (body : Code) : TermElabM JPDecl :=
  let params := xs.map (fun param => (param, true));
  mkFreshJP params body
def addFreshJP (ps : Array (Name × Bool)) (body : Code) : StateRefT (Array JPDecl) TermElabM Name := do
jp ← liftM $ mkFreshJP ps body;
modify fun (jps : Array JPDecl) => jps.push jp;
pure jp.name
/- Variant of `addFreshJP` that takes plain parameter names: every name is paired with the flag `true`. -/
def addFreshJP' (xs : Array Name) (body : Code) : StateRefT (Array JPDecl) TermElabM Name :=
  let params := xs.map (fun param => (param, true));
  addFreshJP params body
/- Insert every name of `xs` into the set `rs` and return the enlarged set. -/
def insertVars (rs : NameSet) (xs : Array Name) : NameSet :=
  xs.foldl (fun (acc : NameSet) (x : Name) => acc.insert x) rs
@ -207,23 +226,24 @@ match x? with
partial def pullExitPointsAux : NameSet → Code → StateRefT (Array JPDecl) TermElabM Code
| rs, Code.decl xs stx k => Code.decl xs stx <$> pullExitPointsAux (eraseVars rs xs) k
| rs, Code.reassign xs stx k => Code.reassign xs stx <$> pullExitPointsAux (insertVars rs xs) k
| rs, Code.jointpoint j ps b k => Code.jointpoint j ps <$> pullExitPointsAux rs b <*> pullExitPointsAux rs k
| rs, Code.joinpoint j ps b k => Code.joinpoint j ps <$> pullExitPointsAux rs b <*> pullExitPointsAux rs k
| rs, Code.action e k => Code.action e <$> pullExitPointsAux rs k
| rs, Code.ite ref x? o c t e => Code.ite ref x? o c <$> pullExitPointsAux (eraseOptVar rs x?) t <*> pullExitPointsAux (eraseOptVar rs x?) e
| rs, Code.«match» ref ds t alts => Code.«match» ref ds t <$> alts.mapM fun alt => do
rhs ← pullExitPointsAux (eraseVars rs alt.vars) alt.rhs; pure { alt with rhs := rhs }
| rs, c@(Code.jmp _ _ _) => pure c
| rs, Code.«break» ref => do let xs := nameSetToArray rs; jp ← addFreshJP xs (Code.«break» ref); pure $ Code.jmp ref jp xs
| rs, Code.«continue» ref => do let xs := nameSetToArray rs; jp ← addFreshJP xs (Code.«continue» ref); pure $ Code.jmp ref jp xs
| rs, Code.«break» ref => do let xs := nameSetToArray rs; jp ← addFreshJP' xs (Code.«break» ref); pure $ Code.jmp ref jp xs
| rs, Code.«continue» ref => do let xs := nameSetToArray rs; jp ← addFreshJP' xs (Code.«continue» ref); pure $ Code.jmp ref jp xs
| rs, Code.«return» ref y? => do
let xs := nameSetToArray rs;
let ps := xs.map fun x => (x, true);
(ps, xs, y?) ← match y? with
| none => pure (xs, xs, none)
| none => pure (ps, xs, none)
| some y =>
if rs.contains y then pure (xs, xs, some y)
if rs.contains y then pure (ps, xs, some y)
else do {
yFresh ← mkFreshUserName y;
pure (xs.push yFresh, xs.push y, some yFresh)
pure (ps.push (yFresh, false), xs.push y, some yFresh)
};
jp ← addFreshJP ps (Code.«return» ref y?);
pure $ Code.jmp ref jp xs
@ -285,7 +305,7 @@ else
pure c
partial def extendUpdatedVarsAux (ws : NameSet) : Code → TermElabM Code
| Code.jointpoint j ps b k => Code.jointpoint j ps <$> extendUpdatedVarsAux b <*> extendUpdatedVarsAux k
| Code.joinpoint j ps b k => Code.joinpoint j ps <$> extendUpdatedVarsAux b <*> extendUpdatedVarsAux k
| Code.action e k => Code.action e <$> extendUpdatedVarsAux k
| c@(Code.«match» ref ds t alts) =>
if alts.any fun alt => alt.vars.any fun x => ws.contains x then
@ -386,7 +406,7 @@ pure {
def concat (terminal : CodeBlock) (k : CodeBlock) : TermElabM CodeBlock := do
(terminal, k) ← homogenize terminal k;
let xs := nameSetToArray k.uvars;
jpDecl ← mkFreshJP xs k.code;
jpDecl ← mkFreshJP' xs k.code;
let jp := jpDecl.name;
pure {
code := attachJP jpDecl (convertReturnIntoJmp terminal.code jp xs),
@ -522,10 +542,213 @@ let doIf := expandDoIf doIf;
thenBranch := doIf.getArg 4,
elseBranch := (doIf.getArg 6).getArg 1 }
/- Build the syntax for `PUnit.unit` and copy source info from `ref` onto it (via `Syntax.copyInfo`). -/
private def mkUnit (ref : Syntax) : MacroM Syntax := do
  unitStx ← `(PUnit.unit);
  pure (unitStx.copyInfo ref)
/- Build a right-nested tuple term out of `elems`:
   - no element: the unit value produced by `mkUnit`;
   - one element: that element, unchanged;
   - otherwise: `(e₁, (e₂, ... eₙ))`, where every pair node gets source info copied from `ref`. -/
private def mkTuple (ref : Syntax) (elems : Array Syntax) : MacroM Syntax :=
  if elems.size == 0 then
    mkUnit ref
  else if elems.size == 1 then
    pure (elems.get! 0)
  else
    -- Fold over all elements but the last one; the last element seeds the accumulator.
    let front := elems.extract 0 (elems.size - 1);
    front.foldrM
      (fun elem acc => do
        pair ← `(($elem, $acc));
        pure (pair.copyInfo ref))
      (elems.back)
-- Code block to syntax term
namespace ToTerm
/- Translation mode, selected by `getKindUVars` and consulted when translating `return`, `break` and `continue`:
   - `regular`: no for-in variable was given and the code contains no `break`/`continue`;
   - `forInNestedTerm`: no for-in variable was given but the code contains `break`/`continue`;
   - `forInMap`: a for-in variable was given and it is one of the code block's `uvars`;
   - `forIn`: a for-in variable was given and it is not one of the `uvars`. -/
inductive Kind
| regular | forInNestedTerm | forIn | forInMap
/- Read-only data threaded through the code-block-to-term translation. -/
structure Context :=
(m : Syntax) -- Syntax to reference the monad associated with the do notation.
-- Variables packed into the tuple built by `mkUVarTuple`
-- (presumably the "updated variables" of the code block; confirm against `CodeBlock.uvars`).
(uvars : Array Name)
-- Selects the shape of the terms generated for `return`, `break` and `continue`.
(kind : Kind)
/- Monad used by the translation: `MacroM` extended with a read-only `Context`. -/
abbrev M := ReaderT Context MacroM
/- Build the tuple of identifiers for `ctx.uvars`; every identifier is created with `mkIdentFrom ref`. -/
def mkUVarTuple (ref : Syntax) : M Syntax := do
  ctx ← read;
  liftM $ mkTuple ref (ctx.uvars.map (mkIdentFrom ref))
/- Build the term returned by a `return` in a regular code block: the result
   (the identifier `x` when `x? = some x`, the unit value otherwise), paired with the
   tuple of `ctx.uvars` when that array is not empty.

   Note that, in the current design, we can only reassign variables that were declared in the do-block.
   Thus, if `ctx.kind == Kind.regular`, then `ctx.uvars` must be empty.
   Therefore, this method should never create a tuple.
   We keep it as-is because we may change the design decision in the future. -/
def mkResultUVarTuple (ref : Syntax) (x? : Option Name) : M Syntax := do
  ctx ← read;
  result ← match x? with
    | none   => liftM $ mkUnit ref
    | some x => pure $ mkIdentFrom ref x;
  if ctx.uvars.isEmpty then
    pure result
  else do
    uvars ← mkUVarTuple ref;
    liftM $ mkTuple ref #[result, uvars]
/- Translate a `return` with optional result variable `x?` into a term. The shape depends on `ctx.kind`:
   - `regular`: `HasPure.pure r`, where `r` is built by `mkResultUVarTuple ref x?`;
   - `forInNestedTerm`: `HasPure.pure (DoResult.return u)`;
   - any other kind: `HasPure.pure (ForInStep.yield u)`;
   where `u` is the tuple built by `mkUVarTuple ref`. -/
def returnToTermCore (ref : Syntax) (x? : Option Name) : M Syntax := do
  ctx ← read;
  match ctx.kind with
  | Kind.regular => do
    result ← mkResultUVarTuple ref x?;
    `(HasPure.pure $result)
  | Kind.forInNestedTerm => do
    uvars ← mkUVarTuple ref;
    `(HasPure.pure (DoResult.«return» $uvars))
  | _ => do
    uvars ← mkUVarTuple ref;
    `(HasPure.pure (ForInStep.yield $uvars))
/- Same as `returnToTermCore`, with source info copied from `ref` onto the resulting term. -/
def returnToTerm (ref : Syntax) (x? : Option Name) : M Syntax :=
  (fun (stx : Syntax) => stx.copyInfo ref) <$> returnToTermCore ref x?
/- Translate a `continue` into a term. The shape depends on `ctx.kind`:
   - `forInNestedTerm`: `HasPure.pure (DoResult.continue u)`;
   - `regular`: marked `unreachable!`;
   - any other kind: `HasPure.pure (ForInStep.yield u)`;
   where `u` is the tuple built by `mkUVarTuple ref`. -/
def continueToTermCore (ref : Syntax) : M Syntax := do
  ctx ← read;
  match ctx.kind with
  | Kind.forInNestedTerm => do
    uvars ← mkUVarTuple ref;
    `(HasPure.pure (DoResult.«continue» $uvars))
  | Kind.regular => unreachable!
  | _ => do
    uvars ← mkUVarTuple ref;
    `(HasPure.pure (ForInStep.yield $uvars))
/- Same as `continueToTermCore`, with source info copied from `ref` onto the resulting term. -/
def continueToTerm (ref : Syntax) : M Syntax :=
  (fun (stx : Syntax) => stx.copyInfo ref) <$> continueToTermCore ref
/- Translate a `break` into a term. The shape depends on `ctx.kind`:
   - `forInNestedTerm`: `HasPure.pure (DoResult.break u)`;
   - `regular`: marked `unreachable!`;
   - any other kind: `HasPure.pure (ForInStep.done u)`;
   where `u` is the tuple built by `mkUVarTuple ref`. -/
def breakToTermCore (ref : Syntax) : M Syntax := do
  ctx ← read;
  match ctx.kind with
  | Kind.forInNestedTerm => do
    uvars ← mkUVarTuple ref;
    `(HasPure.pure (DoResult.«break» $uvars))
  | Kind.regular => unreachable!
  | _ => do
    uvars ← mkUVarTuple ref;
    `(HasPure.pure (ForInStep.done $uvars))
/- Same as `breakToTermCore`, with source info copied from `ref` onto the resulting term. -/
def breakToTerm (ref : Syntax) : M Syntax :=
  (fun (stx : Syntax) => stx.copyInfo ref) <$> breakToTermCore ref
/- Translate an action-like `doElem` followed by the continuation term `k`:
   - `doDbgTrace`: `dbgTrace! msg; k`, with `msg` taken from argument 1 of the action;
   - `doAssert`: `assert! cond; k`, with `cond` taken from argument 1 of the action;
   - anything else: `HasBind.bind action (fun _ => k)`, i.e. the action's result is discarded. -/
def actionToTermCore (action : Syntax) (k : Syntax) : MacroM Syntax := withFreshMacroScope do
  let actionKind := action.getKind;
  if actionKind == `Lean.Parser.Term.doDbgTrace then
    let traceMsg := action.getArg 1;
    `(dbgTrace! $traceMsg; $k)
  else if actionKind == `Lean.Parser.Term.doAssert then
    let assertion := action.getArg 1;
    `(assert! $assertion; $k)
  else
    `(HasBind.bind $action (fun _ => $k))
/- Same as `actionToTermCore`, with source info copied from `action` onto the resulting term. -/
def actionToTerm (action : Syntax) (k : Syntax) : MacroM Syntax :=
  (fun (stx : Syntax) => stx.copyInfo action) <$> actionToTermCore action k
/- Translate a declaration-like `doElem` `decl`, followed by the continuation term `k`, into a term:
   - `doLet`: a plain `let` around `k`, reusing the `letDecl` stored at argument 1;
   - `doLetArrow` holding a `doIdDecl`: `HasBind.bind val (fun (id : type) => k)`;
   - `doLetArrow` holding a `doPatDecl`: bind the discriminant and `match` the result against the pattern;
     when an else-branch is present it becomes the fallback (`_`) alternative;
   - `doLetRec` and `doHave`: not implemented yet, a "WIP" error is thrown;
   - anything else: an "unexpected kind" error is thrown. -/
def declToTermCore (decl : Syntax) (k : Syntax) : M Syntax := withFreshMacroScope do
let kind := decl.getKind;
if kind == `Lean.Parser.Term.doLet then
let letDecl := decl.getArg 1;
`(let $letDecl:letDecl; $k)
else if kind == `Lean.Parser.Term.doLetRec then
liftM $ Macro.throwError decl "WIP"
else if kind == `Lean.Parser.Term.doLetArrow then
let arg := decl.getArg 1;
-- The nested declaration is the reference syntax for the type expansion and the `return` built below.
let ref := arg;
if arg.getKind == `Lean.Parser.Term.doIdDecl then
let id := arg.getArg 0;
let type := expandOptType ref (arg.getArg 1);
let val := arg.getArg 3;
`(HasBind.bind $val (fun ($id:ident : $type) => $k))
else if arg.getKind == `Lean.Parser.Term.doPatDecl then do
-- termParser >> leftArrow >> termParser >> optional (" | " >> termParser)
let pat := arg.getArg 0;
let discr := arg.getArg 2;
let optElse := arg.getArg 3;
if optElse.isNone then
`(HasBind.bind $discr (fun x => match x with | $pat => $k))
else do
let elseBody := optElse.getArg 1;
-- `y` names the value produced by the else-branch; quoting it here lets `returnToTerm`
-- refer to the same identifier that the `fun y => ...` below binds.
y ← `(y);
ret ← returnToTerm ref y.getId;
-- Run the else-branch and return its result.
elseBody ← `(HasBind.bind $elseBody (fun y => $ret));
`(HasBind.bind $discr (fun x => match x with | $pat => $k | _ => $elseBody))
else
liftM $ Macro.throwError decl "unexpected kind of 'do' declaration"
else if kind == `Lean.Parser.Term.doHave then
liftM $ Macro.throwError decl ("WIP " ++ toString decl)
else
liftM $ Macro.throwError decl "unexpected kind of 'do' declaration"
/- Same as `declToTermCore`, with source info copied from `decl` onto the resulting term. -/
def declToTerm (decl : Syntax) (k : Syntax) : M Syntax :=
  (fun (stx : Syntax) => stx.copyInfo decl) <$> declToTermCore decl k
/- Translate a reassignment-like `doElem` followed by the continuation term `k`:
   - `doReassign`: a `let` around `k` whose `letDecl` node wraps argument 0 of the reassignment;
   - `doReassignArrow`: not implemented yet, a "WIP" error is thrown;
   - anything else: an "unexpected kind" error is thrown. -/
def reassignToTermCore (reassign : Syntax) (k : Syntax) : MacroM Syntax := withFreshMacroScope do
  let reassignKind := reassign.getKind;
  if reassignKind == `Lean.Parser.Term.doReassign then
    let newDecl := mkNode `Lean.Parser.Term.letDecl #[reassign.getArg 0];
    `(let $newDecl:letDecl; $k)
  else if reassignKind == `Lean.Parser.Term.doReassignArrow then
    Macro.throwError reassign ("WIP " ++ toString reassign)
  else
    Macro.throwError reassign "unexpected kind of 'do' reassignment"
/- Same as `reassignToTermCore`, with source info copied from `reassign` onto the resulting term. -/
def reassignToTerm (reassign : Syntax) (k : Syntax) : MacroM Syntax :=
  (fun (stx : Syntax) => stx.copyInfo reassign) <$> reassignToTermCore reassign k
/- Assemble an `if`-term node by hand. The keyword atoms are created from `ref`;
   `optIdent` is inserted as-is between the `if` atom and the condition. -/
def mkIte (ref : Syntax) (optIdent : Syntax) (cond : Syntax) (thenBranch : Syntax) (elseBranch : Syntax) : Syntax :=
  let ifTk   := mkAtomFrom ref "if";
  let thenTk := mkAtomFrom ref "then";
  let elseTk := mkAtomFrom ref "else";
  mkNode `Lean.Parser.Term.«if» #[ifTk, optIdent, cond, thenTk, thenBranch, elseTk, elseBranch]
/- Build `let j (p₁ : T₁) ... (pₙ : Tₙ) : m _ := body; k`.
   For every parameter `(id, useTypeOf)` the binder type is `typeOf! id` when `useTypeOf` is `true`
   and a hole `_` otherwise. `m` is the monad syntax stored in the context.
   All atoms and identifiers created here take `body` as their reference syntax. -/
def mkJoinPointCore (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax := withFreshMacroScope do
  let ref := body;
  binders ← ps.mapM fun ⟨id, useTypeOf⟩ => do {
    let idStx := mkIdentFrom ref id;
    paramType ← if useTypeOf then `(typeOf! $idStx) else `(_);
    let typeAscription := mkNullNode #[mkAtomFrom ref ":", paramType];
    pure $ mkNode `Lean.Parser.Term.explicitBinder
      #[mkAtomFrom ref "(", mkNullNode #[mkIdentFrom ref id], typeAscription, mkNullNode, mkAtomFrom ref ")"]
  };
  ctx ← read;
  let monad := ctx.m;
  resultType ← `($monad _);
  `(let $(mkIdentFrom ref j):ident $binders:explicitBinder* : $resultType := $body; $k)
/- Same as `mkJoinPointCore`, with source info copied from `body` onto the resulting term. -/
def mkJoinPoint (j : Name) (ps : Array (Name × Bool)) (body : Syntax) (k : Syntax) : M Syntax :=
  (fun (stx : Syntax) => stx.copyInfo body) <$> mkJoinPointCore j ps body k
/- Build the application `j a₁ ... aₙ`; the function and argument identifiers are all created from `ref`. -/
def mkJmp (ref : Syntax) (j : Name) (args : Array Name) : Syntax :=
  let target  := mkIdentFrom ref j;
  let argStxs := args.map fun arg => mkIdentFrom ref arg;
  mkAppStx target argStxs
/- Translate `Code` into term syntax.
   Continuations (and, for join points, the body first) are translated recursively and then
   combined by the corresponding `...ToTerm` / `mk...` helper.
   `match` is not supported yet: a "WIP" error is thrown, reported at the `match` syntax itself
   so that the message carries a source position. -/
partial def toTerm : Code → M Syntax
| Code.«return» ref x?     => returnToTerm ref x?
| Code.«continue» ref      => continueToTerm ref
| Code.«break» ref         => breakToTerm ref
| Code.joinpoint j ps b k  => do b ← toTerm b; k ← toTerm k; mkJoinPoint j ps b k
| Code.jmp ref j args      => pure $ mkJmp ref j args
| Code.decl _ stx k        => do k ← toTerm k; declToTerm stx k
| Code.reassign _ stx k    => do k ← toTerm k; liftM $ reassignToTerm stx k
| Code.action stx k        => do k ← toTerm k; liftM $ actionToTerm stx k
| Code.ite ref _ o c t e   => do t ← toTerm t; e ← toTerm e; pure $ mkIte ref o c t e
-- `match` is the only remaining constructor; name it so the error can point at its syntax.
| Code.«match» ref _ _ _   => liftM $ Macro.throwError ref "WIP"
/- Choose the translation kind for `c` and the order of its updated variables.
   - No for-in variable: `forInNestedTerm` if the code contains `break`/`continue`, `regular` otherwise.
   - For-in variable that is one of `c.uvars`: `forInMap`, with that variable moved to the front.
   - For-in variable that is not in `c.uvars`: `forIn`. -/
private def getKindUVars (c : CodeBlock) (forInVar? : Option Name) : Kind × Array Name :=
  match forInVar? with
  | some forInVar =>
    if c.uvars.contains forInVar then
      (Kind.forInMap, #[forInVar] ++ nameSetToArray (c.uvars.erase forInVar))
    else
      (Kind.forIn, nameSetToArray c.uvars)
  | none =>
    let kind := if hasContinueBreak c.code then Kind.forInNestedTerm else Kind.regular;
    (kind, nameSetToArray c.uvars)
/- Translate the code block `c` into a term. `m` is the syntax used to refer to the monad.
   Returns the updated variables, in the order chosen by `getKindUVars`, together with the term. -/
def run (c : CodeBlock) (m : Syntax) (forInVar? : Option Name := none) : MacroM (Array Name × Syntax) := do
  let (kind, uvars) := getKindUVars c forInVar?;
  term ← toTerm c.code { m := m, kind := kind, uvars := uvars };
  pure (uvars, term)
end ToTerm
namespace ToCodeBlock
structure Context :=
(ref : Syntax)
(m : Syntax) -- Syntax representing the monad associated with the do notation.
(varSet : NameSet := {})
(insideFor : Bool := false)
@ -581,6 +804,9 @@ else do
(doElem, doElemsNew) ← (expandLiftMethodAux doElem).run [];
pure (doElemsNew, doElem)
/- Render a `SourceInfo` as the string form of its `pos` field.
   NOTE(review): the instance name looks like a leftover placeholder (possibly a debugging aid);
   confirm whether it should be renamed or removed. -/
instance auususus : HasToString SourceInfo :=
  ⟨fun srcInfo => toString srcInfo.pos⟩
partial def doSeqToCode : List Syntax → M CodeBlock
| [] => do ctx ← read; pure $ mkReturn ctx.ref
| doElem::doElems => withRef doElem do
@ -596,6 +822,11 @@ partial def doSeqToCode : List Syntax → M CodeBlock
k ← doSeqToCode doElems;
liftM $ concat c k
};
let auxDoToCode (auxDo : Syntax) : M CodeBlock := do {
let auxDoElems := getDoSeqElems (getDoSeq auxDo);
let auxDoElems := auxDoElems.map fun auxDoElem => auxDoElem.copyInfo doElem;
doSeqToCode auxDoElems
};
let k := doElem.getKind;
if k == `Lean.Parser.Term.doLet then do
vars ← liftM $ getDoLetVars doElem;
@ -651,8 +882,8 @@ partial def doSeqToCode : List Syntax → M CodeBlock
| some x => pure $ mkReturn ref x
| none => withFreshMacroScope do
auxDo ← `(do let x := $arg; return x);
doSeqToCode (getDoSeqElems (getDoSeq auxDo))
else if k == `Lean.Parser.Term.doDbgTracethen then
auxDoToCode auxDo
else if k == `Lean.Parser.Term.doDbgTrace then
mkAction doElem <$> doSeqToCode doElems
else if k == `Lean.Parser.Term.doAssert then
mkAction doElem <$> doSeqToCode doElems
@ -660,31 +891,36 @@ partial def doSeqToCode : List Syntax → M CodeBlock
let term := doElem.getArg 0;
if doElems.isEmpty then withFreshMacroScope do
auxDo ← `(do let x ← $term; return x);
doSeqToCode (getDoSeqElems (getDoSeq auxDo))
auxDoToCode auxDo
else
mkAction term <$> doSeqToCode doElems
else
throwError ("unexpected do-element" ++ Format.line ++ toString doElem)
def run (doStx : Syntax) : TermElabM CodeBlock :=
(doSeqToCode $ getDoSeqElems $ getDoSeq doStx).run { ref := doStx }
def run (doStx : Syntax) (m : Syntax) : TermElabM CodeBlock :=
(doSeqToCode $ getDoSeqElems $ getDoSeq doStx).run { ref := doStx, m := m }
end ToCodeBlock
private def mkTuple (elems : Array Syntax) : MacroM Syntax :=
if elems.size == 1 then pure (elems.get! 0)
else
(elems.extract 0 (elems.size - 1)).foldrM
(fun elem tuple => `(($elem, $tuple)))
(elems.back)
/- Create a synthetic metavariable `?m` and assign the monad `m` to it.
   The returned syntax `?m` is what the `do` expansion uses to refer to `m`.
   The metavariable is elaborated against the inferred type of `m` before being assigned. -/
private def mkMonadAlias (m : Expr) : TermElabM Syntax := do
  aliasStx ← `(?m);
  monadType ← inferType m;
  aliasMVar ← elabTerm aliasStx monadType;
  assignExprMVar aliasMVar.mvarId! m;
  pure aliasStx
-- @[builtinTermElab «do»]
def elabDo : TermElab :=
fun stx expectedType? => do
tryPostponeIfNoneOrMVar expectedType?;
bindInfo ← extractBind expectedType?;
codeBlock ← ToCodeBlock.run stx;
throwError ("WIP" ++ Format.line ++ codeBlock.toMessageData)
m ← mkMonadAlias bindInfo.m;
codeBlock ← ToCodeBlock.run stx m;
(_, stxNew) ← liftMacroM $ ToTerm.run codeBlock m;
trace! `Elab.do stxNew;
withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
end Do

View file

@ -267,7 +267,7 @@ partial def replaceInfo (info : SourceInfo) : Syntax → Syntax
def copyInfo (s : Syntax) (source : Syntax) : Syntax :=
match source.getHeadInfo with
| none => s
| some info => s.setInfo info
| some info => s.setHeadInfo info
private def reprintLeaf (info : SourceInfo) (val : String) : String :=
-- no source info => add gracious amounts of whitespace to definitely separate tokens