fix: parenthesizer

This commit is contained in:
Sebastian Ullrich 2020-06-09 11:26:00 +02:00
parent c8ee21747b
commit a78ceb8121

View file

@ -20,36 +20,36 @@ parenthesizers can be added for new node kinds, but the data collected in the im
appropriate for other parenthesization strategies.
Usages of a parser defined via `prattParser` in general have the form `p rbp`, where `rbp` is the right-binding power.
Recall that a Pratt parser greedily parses a leading token with precedence at least `rbp` (otherwise it fails) followed
by zero or more trailing tokens with precedence *higher* than `rbp`. Thus we should parenthesize a syntax node `stx`
produced by `p rbp` if
Recall that a Pratt parser greedily runs a leading parser with precedence at least `rbp` (otherwise it fails) followed
by zero or more trailing parsers with precedence *higher* than `rbp`; the precedence of a parser is encoded by an
initial call to `checkRbpLe/Lt`, respectively. Thus we should parenthesize a syntax node `stx` supposedly produced by
`p rbp` if
1. the leading/any trailing token in `stx` has precedence < `rbp`/<= `rbp`, respectively (because without parentheses,
`p rbp` would not have parsed all of `stx`), or
2. the token to *the right of* `stx`, if any, has precedence > `rbp` (because without parentheses, `p rbp` would have
parsed it as well and made it a part of `stx`).
1. the leading/any trailing parser involved in `stx` has precedence < `rbp`/<= `rbp`, respectively (because without
parentheses, `p rbp` would not produce all of `stx`), or
2. the trailing parser parsing the input to *the right of* `stx`, if any, has precedence > `rbp` (because without
parentheses, `p rbp` would have parsed it as well and made it a part of `stx`).
Note that in case 2, it is also sufficient to parenthesize a *parent* node as long as the offending token is still to
Note that in case 2, it is also sufficient to parenthesize a *parent* node as long as the offending parser is still to
the right of that node. For example, imagine the tree structure of `(f $ fun x => x) y` without parentheses. We need to
insert *some* parentheses between `x` and `y` since the lambda body is parsed with precedence 0, while `y` as an
identifier has precedence `appPrec`. But we need to parenthesize the `$` node anyway since the precedence of its
insert *some* parentheses between `x` and `y` since the lambda body is parsed with precedence 0, while the identifier
parser for `y` has precedence `appPrec`. But we need to parenthesize the `$` node anyway since the precedence of its
RHS (0) again is smaller than that of `y`. So it's better to only parenthesize the outer node than ending up with
`(f $ (fun x => x)) y`.
Unfortunately, we cannot determine the precedence of a token just by looking at the token table because it can actually
have different precedences in different contexts (e.g. because of whitespace sensitivity). Thus we need to look at the
parser that produced the token as well.
# Implementation
We transform the syntax tree and collect the necessary precedence information for that in a single traversal over the
syntax tree and the parser (as a `Lean.Expr`) that produced it. The traversal is right-to-left to cover case 2. More
specifically, for every Pratt parser call, we store as monadic state the (current) first and minimum precedence of any
token (`firstLbp`/`minLbp`) in this call, if any, and the precedence of the nested trailing Pratt parser call
(`trailRbp`), if any. We subtract 1 from the precedence of trailing tokens so that we don't have to differentiate
between leading and trailing tokens in `minLbp`. If `stP` is the state resulting from the traversal of a Pratt
parser call `p rbp`, and `st` is the state of the surrounding call, we parenthesize if `rbp > stP.minLbp` (case 1) or if
`stP.trailRbp < st.firstLbp` (case 2).
parser (`firstLbp`/`minLbp`) in this call, if any, and the precedence of the nested trailing Pratt parser call
(`trailRbp`), if any. We subtract 1 from the precedence of trailing parsers so that we don't have to differentiate
between leading and trailing parsers in `minLbp`. If `stP` is the state resulting from the traversal of a Pratt parser
call `p rbp`, and `st` is the state of the surrounding call, we parenthesize if `rbp > stP.minLbp` (case 1) or if
`stP.trailRbp < st.firstLbp` (case 2). Note that because trailing parsers are encoded as
`checkRblLt lbp >> trailingNode p`, when we check if we should parenthesize the parser's LHS (the first child in the
node) inside `trailingNode`, `st.firstLbp` is actually already set to the trailing parser's precedence even though we
are doing a left-to-right traversal.
The primary traversal is over the parser `Expr`. The `visit` function takes such a parser and, if it is the application
of a constant `c`, looks for a `[parenthesizer c]` declaration. If it exists, we run it, which might again call `visit`.
@ -61,16 +61,14 @@ The traversal over the `Syntax` object is complicated by the fact that a parser
node, but an arbitrary (but constant, for each parser) amount that it pushes on top of the parser stack. This amount can
even be zero for parsers such as `checkWsBefore`. Thus we cannot simply pass and return a `Syntax` object to and from
`visit`. Instead, we use a `Syntax.Traverser` that allows arbitrary movement and modification inside the syntax tree.
Our traversal invariant is that a parser interpreter should stop at the node to the *left* of all nodes its parser
produced, except when it is already at the left-most child. This special case is not an issue in practice since if there
is another parser to the left that produced zero nodes in this case, it should always do so, so there is no danger of
the left-most child being processed multiple times.
Our traversal invariant is that a parser interpreter should stop at the syntax object to the *right* of all syntax
objects its parser produced.
Ultimately, most parenthesizers are implemented via three primitives that do all the actual syntax traversal:
`visitParenthesizable mkParen rbp` recurses on the current node and afterwards transforms it with `mkParen` if the above
condition for `p rbp` is fulfilled. `visitToken lbp` does not recurse but updates `firstLbp` and advances one node to
the left. `visitArgs x` executes `x` on the right-most child of the current node and then advances one node to the left
(of the original current node).
condition for `p rbp` is fulfilled. `goRight` advances to the next syntax sibling and is used on atoms. `visitArgs x` executes
`x` on the first child of the current node and then advances to the next sibling (of the original current node).
-/
prelude
@ -84,6 +82,7 @@ namespace Syntax
/--
Represents a cursor into a syntax tree that can be read, written, and advanced down/up/left/right.
Indices are allowed to be out-of-bound, in which case `cur` is `Syntax.missing`.
If the `Traverser` is used linearly, updates are linear in the `Syntax` object as well.
-/
structure Traverser :=
@ -101,23 +100,27 @@ def setCur (t : Traverser) (stx : Syntax) : Traverser :=
/-- Advance to the `idx`-th child of the current node. -/
def down (t : Traverser) (idx : Nat) : Traverser :=
{ cur := t.cur.getArg idx, parents := t.parents.push $ t.cur.setArg idx (arbitrary _), idxs := t.idxs.push idx }
if idx < t.cur.getNumArgs then
{ cur := t.cur.getArg idx, parents := t.parents.push $ t.cur.setArg idx (arbitrary _), idxs := t.idxs.push idx }
else
{ cur := Syntax.missing, parents := t.parents.push t.cur, idxs := t.idxs.push idx }
/-- Advance to the parent of the current node, if any. -/
def up (t : Traverser) : Traverser :=
if t.parents.size > 0 then
{ cur := t.parents.back.setArg t.idxs.back t.cur, parents := t.parents.pop, idxs := t.idxs.pop }
let cur := if t.idxs.back < t.parents.back.getNumArgs then t.parents.back.setArg t.idxs.back t.cur else t.parents.back;
{ cur := cur, parents := t.parents.pop, idxs := t.idxs.pop }
else t
/-- Advance to the left sibling of the current node, if any. -/
def left (t : Traverser) : Traverser :=
if t.idxs.size > 0 && t.idxs.back > 0 then
if t.parents.size > 0 then
t.up.down (t.idxs.back - 1)
else t
/-- Advance to the right sibling of the current node, if any. -/
def right (t :Traverser) : Traverser :=
if t.idxs.size > 0 && t.idxs.back + 1 < t.parents.back.getArgs.size then
def right (t : Traverser) : Traverser :=
if t.parents.size > 0 then
t.up.down (t.idxs.back + 1)
else t
@ -155,9 +158,7 @@ structure Context :=
structure State :=
(stxTrav : Syntax.Traverser)
-- precedence of the current left-most token, if any; see module doc for details
(firstLbp : Option Nat := none)
-- current minimum precedence of tokens, if any; see module doc for details
-- current minimum precedence in this Pratt parser call, if any; see module doc for details
(minLbp : Option Nat := none)
-- precedence of the trailing Pratt parser call if any; see module doc for details
(trailRbp : Option Nat := none)
@ -196,12 +197,15 @@ instance ParenthesizerM.monadTraverser : Syntax.MonadTraverser ParenthesizerM :=
open Syntax.MonadTraverser
def addLbp (lbp : Nat) : ParenthesizerM Unit :=
modify $ fun st => { st with minLbp := Nat.min (st.minLbp.getD lbp) lbp }
/-- Execute `x` at the right-most child of the current node, if any, then advance to the left. -/
def visitArgs (x : ParenthesizerM Unit) : ParenthesizerM Unit := do
stx ← getCur;
when (stx.getArgs.size > 0) $
goDown (stx.getArgs.size - 1) *> x <* goUp;
goLeft
goDown 0 *> x <* goUp;
goRight
/--
Call an appropriate `[parenthesizer]` depending on the `Parser` `Expr` `p`. After the call, the traverser position
@ -242,16 +246,15 @@ instance monadQuotation : MonadQuotation ParenthesizerM := {
def visitAntiquot : ParenthesizerM Unit := do
stx ← getCur;
if Elab.Term.Quotation.isAntiquot stx then visitArgs $ do
-- antiquot syntax is, simplified, `"$" "$"* antiquotExpr ":" (nonReservedSymbol name) "*"?`
goLeft; goLeft; goLeft; -- now on `antiquotExpr`
-- antiquot syntax is, simplified, `syntax:appPrec "$" "$"* antiquotExpr ":" (nonReservedSymbol name) "*"?`
goRight; goRight; -- now on `antiquotExpr`
visit (mkConst `Lean.Parser.antiquotExpr);
-- set precedence; see special case in `currLbp`
modify (fun st => { st with firstLbp := Parser.appPrec, minLbp := Parser.appPrec })
addLbp appPrec
else
throw $ Exception.other $ "not an antiquotation"
/-- Recurse using `visit`, and parenthesize the result using `mkParen` if necessary. -/
def visitParenthesizable (mkParen : Syntax → Syntax) (rbp : Nat) : ParenthesizerM Unit := do
def visitParenthesizable (mkParen : Syntax → Syntax) (rbp : Nat) (trailLbp : Option Nat := none) : ParenthesizerM Unit := do
stx ← getCur;
idx ← getIdx;
st ← get;
@ -262,13 +265,11 @@ adaptReader (fun (ctx : Context) => { ctx with mkParen := some mkParen }) $
visit (mkConst stx.getKind);
{ minLbp := some minLbpP, trailRbp := trailRbpP, .. } ← get
| panic! "visitParenthesizable: visited a term without tokens?!";
trace! `PrettyPrinter.parenthesize ("...precedences are " ++ fmt rbp ++ " >? " ++ fmt minLbpP ++ ", " ++ fmt trailRbpP ++ " <? " ++ fmt st.firstLbp);
trace! `PrettyPrinter.parenthesize ("...precedences are " ++ fmt rbp ++ " >? " ++ fmt minLbpP ++ ", " ++ fmt trailRbpP ++ " <=? " ++ fmt trailLbp);
-- Should we parenthesize?
when (rbp > minLbpP || match trailRbpP, st.firstLbp with some trailRbpP, some firstLbp => trailRbpP < firstLbp | _, _ => false) $ do {
-- The recursive `visit` call, by the invariant, has moved to the next node to the left. In order to parenthesize
-- the original node, we must first move to the right, except if we already were at the left-most child in the first
-- place.
when (idx > 0) goRight;
when (rbp > minLbpP || match trailRbpP, trailLbp with some trailRbpP, some trailLbp => trailRbpP <= trailLbp | _, _ => false) $ do {
-- The recursive `visit` call, by the invariant, has moved to the next child, so move back temporarily
goLeft;
stx ← getCur;
match stx.getHeadInfo, stx.getTailInfo with
| some hi, some ti =>
@ -279,23 +280,22 @@ when (rbp > minLbpP || match trailRbpP, st.firstLbp with some trailRbpP, some fi
setCur stx
| _, _ => setCur (mkParen stx);
stx ← getCur; trace! `PrettyPrinter.parenthesize ("parenthesized: " ++ stx.formatStx none);
goLeft;
goRight;
-- after parenthesization, there is no more trailing parser
modify (fun st => { st with minLbp := appPrec, firstLbp := appPrec, trailRbp := none })
modify (fun st => { st with minLbp := appPrec, trailRbp := none })
};
{ trailRbp := trailRbpP, .. } ← get;
-- If we already had a token at this level (`st.firstLbp ≠ none`), keep the trailing parser. Otherwise, use the minimum of
-- `rbp` and `trailRbpP`.
let trailRbp := match trailRbpP, st.firstLbp with
| _, some _ => st.trailRbp
| some trailRbpP, _ => some (Nat.min trailRbpP rbp)
| _, _ => some rbp;
modify (fun stP => { stP with trailRbp := trailRbp })
modify $ fun stP => { stP with
minLbp := match trailLbp with
| some trailLbp => some (Nat.min (stP.minLbp.getD trailLbp) trailLbp)
| _ => st.minLbp,
trailRbp := match stP.trailRbp with
| some trailRbpP => some (Nat.min trailRbpP rbp)
| _ => some rbp }
/-- Set token precedence and advance to the left. -/
def visitToken (lbp : Nat) : ParenthesizerM Unit := do
modify (fun st => { st with firstLbp := lbp });
goLeft
/-- Clear `trailRbp` and advance. -/
def visitToken : Parenthesizer | p => do
modify (fun st => { st with trailRbp := none });
goRight
def evalNat (e : Expr) : ParenthesizerM Nat := do
e ← liftM $ whnf e;
@ -365,7 +365,7 @@ visit p.appArg!
@[builtinParenthesizer andthen]
def andthen.parenthesizer : Parenthesizer | p =>
visit (p.getArg! 1) *> visit (p.getArg! 0)
visit (p.getArg! 0) *> visit (p.getArg! 1)
@[builtinParenthesizer node]
def node.parenthesizer : Parenthesizer | p => do
@ -376,8 +376,12 @@ when (k != stx.getKind) $ do {
-- HACK; see `orelse.parenthesizer`
throw $ Exception.other "BACKTRACK"
};
visitArgs $ visit p.appArg!;
modify $ fun st => { st with minLbp := st.firstLbp }
visitArgs $ visit p.appArg!
@[builtinParenthesizer checkRbpLe]
def checkRbpLe.parenthesizer : Parenthesizer | p => do
prec ← evalNat $ p.getArg! 0;
addLbp prec
@[builtinParenthesizer trailingNode]
def trailingNode.parenthesizer : Parenthesizer | p => do
@ -389,35 +393,35 @@ when (k != stx.getKind) $ do {
throw $ Exception.other "BACKTRACK"
};
visitArgs $ do {
visit p.appArg!;
-- After visiting the nodes actually produced by the parser passed to `trailingNode`, we are positioned on the
-- left-most child, which is the term injected by `trailingNode` in place of the recursion. Left recursion is not an
-- issue for the parenthesizer, so we can think of this child being produced by `termParser 0`, or whichever Pratt
-- parser is calling us; we only need to know its `mkParen`, which we retrieve from the context.
some lbp ← State.firstLbp <$> get -- the trailing token's precedence; subtract 1 as described above
| panic! "trailingNode.parenthesizer: visited a trailing term without tokens?!";
{ mkParen := some mkParen, .. } ← read
| panic! "trailingNode.parenthesizer called outside of visitParenthesizable call";
visitAntiquot <|> visitParenthesizable mkParen 0;
modify $ fun st => { st with minLbp := Nat.min (st.minLbp.getD (lbp - 1)) (lbp - 1) }
{ minLbp := trailLbp, .. } ← get;
visitAntiquot <|> visitParenthesizable mkParen 0 trailLbp;
visit p.appArg!
}
@[builtinParenthesizer symbol]
def symbol.parenthesizer : Parenthesizer | p =>
evalOptPrec p.appArg! >>= visitToken
@[builtinParenthesizer checkRbpLt]
def checkRbpLt.parenthesizer : Parenthesizer | p => do
prec ← evalNat $ p.getArg! 0;
addLbp (prec - 1)
@[builtinParenthesizer symbolNoWs] def symbolNoWs.parenthesizer := symbol.parenthesizer
@[builtinParenthesizer unicodeSymbol] def unicodeSymbol.parenthesizer := symbol.parenthesizer
@[builtinParenthesizer symbol] def symbol.parenthesizer := visitToken
@[builtinParenthesizer symbolNoWs] def symbolNoWs.parenthesizer := visitToken
@[builtinParenthesizer unicodeSymbol] def unicodeSymbol.parenthesizer := visitToken
@[builtinParenthesizer identNoAntiquot] def identNoAntiquot.parenthesizer : Parenthesizer | p => visitToken appPrec
@[builtinParenthesizer rawIdent] def rawIdent.parenthesizer : Parenthesizer | p => visitToken appPrec
@[builtinParenthesizer nonReservedSymbol] def nonReservedSymbol.parenthesizer : Parenthesizer | p => visitToken appPrec
@[builtinParenthesizer identNoAntiquot] def identNoAntiquot.parenthesizer := visitToken
@[builtinParenthesizer rawIdent] def rawIdent.parenthesizer := visitToken
@[builtinParenthesizer nonReservedSymbol] def nonReservedSymbol.parenthesizer := visitToken
@[builtinParenthesizer charLitNoAntiquot] def charLitNoAntiquot.parenthesizer := identNoAntiquot.parenthesizer
@[builtinParenthesizer strLitNoAntiquot] def strLitNoAntiquot.parenthesizer := identNoAntiquot.parenthesizer
@[builtinParenthesizer nameLitNoAntiquot] def nameLitNoAntiquot.parenthesizer := identNoAntiquot.parenthesizer
@[builtinParenthesizer numLitNoAntiquot] def numLitNoAntiquot.parenthesizer := identNoAntiquot.parenthesizer
@[builtinParenthesizer fieldIdx] def fieldIdx.parenthesizer := identNoAntiquot.parenthesizer
@[builtinParenthesizer charLitNoAntiquot] def charLitNoAntiquot.parenthesizer := visitToken
@[builtinParenthesizer strLitNoAntiquot] def strLitNoAntiquot.parenthesizer := visitToken
@[builtinParenthesizer nameLitNoAntiquot] def nameLitNoAntiquot.parenthesizer := visitToken
@[builtinParenthesizer numLitNoAntiquot] def numLitNoAntiquot.parenthesizer := visitToken
@[builtinParenthesizer fieldIdx] def fieldIdx.parenthesizer := visitToken
@[builtinParenthesizer many]
def many.parenthesizer : Parenthesizer | p => do
@ -439,7 +443,7 @@ visitArgs $ visit (p.getArg! 0)
@[builtinParenthesizer sepBy]
def sepBy.parenthesizer : Parenthesizer | p => do
stx ← getCur;
visitArgs $ (List.range stx.getArgs.size).reverse.forM $ fun i => visit (p.getArg! (i % 2))
visitArgs $ (List.range stx.getArgs.size).forM $ fun i => visit (p.getArg! (i % 2))
@[builtinParenthesizer sepBy1] def sepBy1.parenthesizer := sepBy.parenthesizer
@ -460,31 +464,11 @@ visit $ mkApp (p.getArg! 0) (mkConst `sorryAx [levelZero])
@[builtinParenthesizer checkNoWsBefore] def checkNoWsBefore.parenthesizer : Parenthesizer | p => pure ()
@[builtinParenthesizer checkTailWs] def checkTailWs.parenthesizer : Parenthesizer | p => pure ()
@[builtinParenthesizer checkColGe] def checkColGe.parenthesizer : Parenthesizer | p => pure ()
@[builtinParenthesizer checkRbpLt] def checkRbpLt.parenthesizer : Parenthesizer | p => pure ()
open Lean.Parser.Command
@[builtinParenthesizer commentBody] def commentBody.parenthesizer : Parenthesizer | p => goLeft
@[builtinParenthesizer quotedSymbol] def quotedSymbol.parenthesizer : Parenthesizer | p => goLeft
@[builtinParenthesizer unquotedSymbol] def unquotedSymbol.parenthesizer : Parenthesizer | p => goLeft
section
open Lean.Parser.Term
def depArrow' := leadingNode `Lean.Parser.Term.depArrow $ bracketedBinder true >> unicodeSymbol " → " " -> " >> termParser
end
/-
`depArrow` is defined as
```
parser! bracketedBinder true >> checkRbpLe 25 "expected parentheses around dependent arrow" >> unicodeSymbol " → " " -> " >> termParser
```
There is no generally sensible parenthesizer implementation for `checkRbpLe`, so we special-case the entire
parser by ignoring `checkRbpLe` and lowering the result LBP to 25 (instead of the LBP of `bracketedBinder`, i.e.
`appPrec`). Thus terms such as `f ((a : _) -> b)` will be reparenthesized correctly since the new LBP is now lower than
the surrounding RBP (`appPrec`). -/
@[builtinParenthesizer Term.depArrow]
def depArrow.parenthesizer : Parenthesizer | p => do
visit (mkConst `Lean.PrettyPrinter.Parenthesizer.depArrow');
modify $ fun st => { st with firstLbp := some 25, minLbp := some 25 }
@[builtinParenthesizer commentBody] def commentBody.parenthesizer := visitToken
@[builtinParenthesizer quotedSymbol] def quotedSymbol.parenthesizer := visitToken
@[builtinParenthesizer unquotedSymbol] def unquotedSymbol.parenthesizer := visitToken
end Parenthesizer