From 1fe192802b3b909f33df4bc80898299ee0835135 Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Wed, 29 Jul 2020 11:52:17 +0200 Subject: [PATCH] fix: parenthesizer --- src/Lean/PrettyPrinter/Parenthesizer.lean | 138 ++++++++++++---------- tests/lean/PPRoundtrip.lean | 5 +- tests/lean/PPRoundtrip.lean.expected.out | 14 +++ 3 files changed, 95 insertions(+), 62 deletions(-) diff --git a/src/Lean/PrettyPrinter/Parenthesizer.lean b/src/Lean/PrettyPrinter/Parenthesizer.lean index 60ab9ff759..178389ea1b 100644 --- a/src/Lean/PrettyPrinter/Parenthesizer.lean +++ b/src/Lean/PrettyPrinter/Parenthesizer.lean @@ -19,16 +19,16 @@ parsers defined via `Lean.Parser.prattParser`, which includes both aforementione parenthesizers can be added for new node kinds, but the data collected in the implementation below might not be appropriate for other parenthesization strategies. -Usages of a parser defined via `prattParser` in general have the form `p rbp`, where `rbp` is the right-binding power. -Recall that a Pratt parser greedily runs a leading parser with precedence at least `rbp` (otherwise it fails) followed -by zero or more trailing parsers with precedence *higher* than `rbp`; the precedence of a parser is encoded by an -initial call to `checkRbpLe/Lt`, respectively. Thus we should parenthesize a syntax node `stx` supposedly produced by -`p rbp` if +Usages of a parser defined via `prattParser` in general have the form `p prec`, where `prec` is the minimum precedence +or binding power. Recall that a Pratt parser greedily runs a leading parser with precedence at least `prec` (otherwise +it fails) followed by zero or more trailing parsers with precedence at least `prec`; the precedence of a parser is +encoded in the call to `leadingNode/trailingNode`, respectively. Thus we should parenthesize a syntax node `stx` +supposedly produced by `p prec` if -1. 
the leading/any trailing parser involved in `stx` has precedence < `rbp`/<= `rbp`, respectively (because without - parentheses, `p rbp` would not produce all of `stx`), or -2. the trailing parser parsing the input to *the right of* `stx`, if any, has precedence > `rbp` (because without - parentheses, `p rbp` would have parsed it as well and made it a part of `stx`). +1. the leading/any trailing parser involved in `stx` has precedence < `prec` (because without parentheses, `p prec` + would not produce all of `stx`), or +2. the trailing parser parsing the input to *the right of* `stx`, if any, has precedence >= `prec` (because without + parentheses, `p prec` would have parsed it as well and made it a part of `stx`). Note that in case 2, it is also sufficient to parenthesize a *parent* node as long as the offending parser is still to the right of that node. For example, imagine the tree structure of `(f $ fun x => x) y` without parentheses. We need to @@ -41,15 +41,11 @@ RHS (0) again is smaller than that of `y`. So it's better to only parenthesize t We transform the syntax tree and collect the necessary precedence information for that in a single traversal over the syntax tree and the parser (as a `Lean.Expr`) that produced it. The traversal is right-to-left to cover case 2. More -specifically, for every Pratt parser call, we store as monadic state the (current) first and minimum precedence of any -parser (`firstLbp`/`minLbp`) in this call, if any, and the precedence of the nested trailing Pratt parser call -(`trailRbp`), if any. We subtract 1 from the precedence of trailing parsers so that we don't have to differentiate -between leading and trailing parsers in `minLbp`. If `stP` is the state resulting from the traversal of a Pratt parser -call `p rbp`, and `st` is the state of the surrounding call, we parenthesize if `rbp > stP.minLbp` (case 1) or if -`stP.trailRbp < st.firstLbp` (case 2). 
Note that because trailing parsers are encoded as -`checkRblLt lbp >> trailingNode p`, when we check if we should parenthesize the parser's LHS (the first child in the -node) inside `trailingNode`, `st.firstLbp` is actually already set to the trailing parser's precedence even though we -are doing a left-to-right traversal. +specifically, for every Pratt parser call, we store as monadic state the precedence of the left-most trailing parser and +the minimum precedence of any parser (`contPrec`/`minPrec`) in this call, if any, and the precedence of the nested +trailing Pratt parser call (`trailPrec`), if any. If `stP` is the state resulting from the traversal of a Pratt parser +call `p prec`, and `st` is the state of the surrounding call, we parenthesize if `prec > stP.minPrec` (case 1) or if +`stP.trailPrec <= st.contPrec` (case 2). The primary traversal is over the parser `Expr`. The `visit` function takes such a parser and, if it is the application of a constant `c`, looks for a `[parenthesizer c]` declaration. If it exists, we run it, which might again call `visit`. @@ -61,13 +57,16 @@ The traversal over the `Syntax` object is complicated by the fact that a parser node, but an arbitrary (but constant, for each parser) amount that it pushes on top of the parser stack. This amount can even be zero for parsers such as `checkWsBefore`. Thus we cannot simply pass and return a `Syntax` object to and from `visit`. Instead, we use a `Syntax.Traverser` that allows arbitrary movement and modification inside the syntax tree. -Our traversal invariant is that a parser interpreter should stop at the syntax object to the *right* of all syntax -objects its parser produced. +Our traversal invariant is that a parser interpreter should stop at the syntax object to the *left* of all syntax +objects its parser produced, except when it is already at the left-most child. 
This special case is not an issue in +practice since if there is another parser to the left that produced zero nodes in this case, it should always do so, so +there is no danger of the left-most child being processed multiple times. Ultimately, most parenthesizers are implemented via three primitives that do all the actual syntax traversal: -`visitParenthesizable mkParen rbp` recurses on the current node and afterwards transforms it with `mkParen` if the above -condition for `p rbp` is fulfilled. `goRight` advances to the next syntax sibling and is used on atoms. `visitArgs x` executes -`x` on the first child of the current node and then advances to the next sibling (of the original current node). +`visitParenthesizable mkParen prec` recurses on the current node and afterwards transforms it with `mkParen` if the above +condition for `p prec` is fulfilled. `visitToken` advances to the preceding sibling and is used on atoms. `visitArgs x` +executes `x` on the last child of the current node and then advances to the preceding sibling (of the original current +node). 
-/ @@ -85,10 +84,14 @@ structure Context := structure State := (stxTrav : Syntax.Traverser) +-- precedence of the current left-most trailing parser, if any; see module doc for details +(contPrec : Option Nat := none) -- current minimum precedence in this Pratt parser call, if any; see module doc for details -(minLbp : Option Nat := none) +(minPrec : Option Nat := none) -- precedence of the trailing Pratt parser call if any; see module doc for details -(trailRbp : Option Nat := none) +(trailPrec : Option Nat := none) +-- true iff we have already visited a token on this parser level; used for detecting trailing parsers +(visitedToken : Bool := false) end Parenthesizer @@ -124,15 +127,15 @@ instance ParenthesizerM.monadTraverser : Syntax.MonadTraverser ParenthesizerM := open Syntax.MonadTraverser -def addLbp (lbp : Nat) : ParenthesizerM Unit := -modify $ fun st => { st with minLbp := Nat.min (st.minLbp.getD lbp) lbp } +def addPrecCheck (prec : Nat) : ParenthesizerM Unit := +modify $ fun st => { st with contPrec := prec, minPrec := Nat.min (st.minPrec.getD prec) prec } /-- Execute `x` at the right-most child of the current node, if any, then advance to the left. -/ def visitArgs (x : ParenthesizerM Unit) : ParenthesizerM Unit := do stx ← getCur; when (stx.getArgs.size > 0) $ - goDown 0 *> x <* goUp; -goRight + goDown (stx.getArgs.size - 1) *> x <* goUp; +goLeft /-- Call an appropriate `[parenthesizer]` depending on the `Parser` `Expr` `p`. 
After the call, the traverser position @@ -174,29 +177,31 @@ def visitAntiquot : ParenthesizerM Unit := do stx ← getCur; if Elab.Term.Quotation.isAntiquot stx then visitArgs $ do -- antiquot syntax is, simplified, `syntax:maxPrec "$" "$"* antiquotExpr ":" (nonReservedSymbol name) "*"?` - goRight; goRight; -- now on `antiquotExpr` + goLeft; goLeft; goLeft; -- now on `antiquotExpr` visit (mkConst `Lean.Parser.antiquotExpr); - addLbp maxPrec + addPrecCheck maxPrec else throw $ Exception.other $ "not an antiquotation" /-- Recurse using `visit`, and parenthesize the result using `mkParen` if necessary. -/ -def visitParenthesizable (mkParen : Syntax → Syntax) (rbp : Nat) (trailLbp : Option Nat := none) : ParenthesizerM Unit := do +def visitParenthesizable (mkParen : Syntax → Syntax) (prec : Nat) : ParenthesizerM Unit := do stx ← getCur; idx ← getIdx; st ← get; --- reset lbp/rbp and store `mkParen` for the recursive call +-- reset all state except the traverser and store `mkParen` for the recursive call set { stxTrav := st.stxTrav }; adaptReader (fun (ctx : Context) => { ctx with mkParen := some mkParen }) $ -- we assume that each node kind is produced by a 0-ary parser of the same name visit (mkConst stx.getKind); -{ minLbp := some minLbpP, trailRbp := trailRbpP, .. } ← get +{ minPrec := some minPrec, trailPrec := trailPrec, .. } ← get | panic! "visitParenthesizable: visited a term without tokens?!"; -trace! `PrettyPrinter.parenthesize ("...precedences are " ++ fmt rbp ++ " >? " ++ fmt minLbpP ++ ", " ++ fmt trailRbpP ++ " <=? " ++ fmt trailLbp); +trace! `PrettyPrinter.parenthesize ("...precedences are " ++ fmt prec ++ " >? " ++ fmt minPrec ++ ", " ++ fmt trailPrec ++ " <=? " ++ fmt st.contPrec); -- Should we parenthesize? 
-when (rbp > minLbpP || match trailRbpP, trailLbp with some trailRbpP, some trailLbp => trailRbpP <= trailLbp | _, _ => false) $ do { - -- The recursive `visit` call, by the invariant, has moved to the next child, so move back temporarily - goLeft; +when (prec > minPrec || match trailPrec, st.contPrec with some trailPrec, some contPrec => trailPrec <= contPrec | _, _ => false) $ do { + -- The recursive `visit` call, by the invariant, has moved to the preceding node. In order to parenthesize + -- the original node, we must first move to the right, except if we already were at the left-most child in the first + -- place. + when (idx > 0) goRight; stx ← getCur; match stx.getHeadInfo, stx.getTailInfo with | some hi, some ti => @@ -207,22 +212,22 @@ when (rbp > minLbpP || match trailRbpP, trailLbp with some trailRbpP, some trail setCur stx | _, _ => setCur (mkParen stx); stx ← getCur; trace! `PrettyPrinter.parenthesize ("parenthesized: " ++ stx.formatStx none); - goRight; + goLeft; -- after parenthesization, there is no more trailing parser - modify (fun st => { st with minLbp := maxPrec, trailRbp := none }) + modify (fun st => { st with contPrec := maxPrec, trailPrec := none }) }; -modify $ fun stP => { stP with - minLbp := match trailLbp with - | some trailLbp => some (Nat.min (stP.minLbp.getD trailLbp) trailLbp) - | _ => st.minLbp, - trailRbp := match stP.trailRbp with - | some trailRbpP => some (Nat.min trailRbpP rbp) - | _ => some rbp } +{ trailPrec := trailPrec, .. } ← get; +-- If we already had a token at this level, keep the trailing parser. Otherwise, use the minimum of +-- `prec` and `trailPrec`. +let trailPrec := if st.visitedToken then st.trailPrec else match trailPrec with + | some trailPrec => some (Nat.min trailPrec prec) + | _ => some prec; +modify (fun stP => { stP with minPrec := st.minPrec, trailPrec := trailPrec }) -/-- Clear `trailRbp` and advance. -/ +/-- Adjust state and advance. 
-/ def visitToken : Parenthesizer | p => do -modify (fun st => { st with trailRbp := none }); -goRight +modify (fun st => { st with contPrec := none, visitedToken := true }); +goLeft def evalNat (e : Expr) : ParenthesizerM Nat := do e ← liftM $ whnf e; @@ -264,18 +269,18 @@ stx ← getCur; if stx.getKind == nullKind then throw $ Exception.other "BACKTRACK" else do - lbp ← evalNat p.appArg!; - visitParenthesizable (fun stx => Unhygienic.run `(($stx))) lbp + prec ← evalNat p.appArg!; + visitParenthesizable (fun stx => Unhygienic.run `(($stx))) prec @[builtinParenthesizer tacticParser] def tacticParser.parenthesizer : Parenthesizer | p => visitAntiquot <|> do -lbp ← evalNat p.appArg!; -visitParenthesizable (fun stx => Unhygienic.run `(tactic|($stx))) lbp +prec ← evalNat p.appArg!; +visitParenthesizable (fun stx => Unhygienic.run `(tactic|($stx))) prec @[builtinParenthesizer levelParser] def levelParser.parenthesizer : Parenthesizer | p => visitAntiquot <|> do -lbp ← evalNat p.appArg!; -visitParenthesizable (fun stx => Unhygienic.run `(level|($stx))) lbp +prec ← evalNat p.appArg!; +visitParenthesizable (fun stx => Unhygienic.run `(level|($stx))) prec @[builtinParenthesizer categoryParser] def categoryParser.parenthesizer : Parenthesizer | p => visitAntiquot <|> do @@ -292,7 +297,7 @@ visit p.appArg! @[builtinParenthesizer andthen] def andthen.parenthesizer : Parenthesizer | p => -visit (p.getArg! 0) *> visit (p.getArg! 1) +visit (p.getArg! 1) *> visit (p.getArg! 0) @[builtinParenthesizer node] def node.parenthesizer : Parenthesizer | p => do @@ -308,7 +313,17 @@ visitArgs $ visit p.appArg! @[builtinParenthesizer checkPrec] def checkPrec.parenthesizer : Parenthesizer | p => do prec ← evalNat $ p.getArg! 0; -addLbp prec +addPrecCheck prec + +@[builtinParenthesizer leadingNode] +def leadingNode.parenthesizer : Parenthesizer | p => do +-- Unfold `leadingNode` as usual, but limit `contPrec` to `maxPrec-1` afterwards. 
+-- This is because `maxPrec-1` is the precedence of function application, which is the only way to turn a leading parser +-- into a trailing one. +some p ← liftM $ unfoldDefinition? p + | unreachable!; +visit p; +modify $ fun st => { st with contPrec := (fun p => Nat.min (maxPrec-1) p) <$> st.contPrec } @[builtinParenthesizer trailingNode] def trailingNode.parenthesizer : Parenthesizer | p => do @@ -321,14 +336,15 @@ when (k != stx.getKind) $ do { throw $ Exception.other "BACKTRACK" }; visitArgs $ do { + visit p.appArg!; + addPrecCheck prec; -- After visiting the nodes actually produced by the parser passed to `trailingNode`, we are positioned on the -- left-most child, which is the term injected by `trailingNode` in place of the recursion. Left recursion is not an -- issue for the parenthesizer, so we can think of this child being produced by `termParser 0`, or whichever Pratt -- parser is calling us; we only need to know its `mkParen`, which we retrieve from the context. { mkParen := some mkParen, .. } ← read | panic! "trailingNode.parenthesizer called outside of visitParenthesizable call"; - visitAntiquot <|> visitParenthesizable mkParen 0 prec; - visit p.appArg! + visitAntiquot <|> visitParenthesizable mkParen 0 } @[builtinParenthesizer symbol] def symbol.parenthesizer := visitToken @@ -365,7 +381,7 @@ visitArgs $ visit (p.getArg! 0) @[builtinParenthesizer sepBy] def sepBy.parenthesizer : Parenthesizer | p => do stx ← getCur; -visitArgs $ (List.range stx.getArgs.size).forM $ fun i => visit (p.getArg! (i % 2)) +visitArgs $ (List.range stx.getArgs.size).reverse.forM $ fun i => visit (p.getArg! 
(i % 2)) @[builtinParenthesizer sepBy1] def sepBy1.parenthesizer := sepBy.parenthesizer diff --git a/tests/lean/PPRoundtrip.lean b/tests/lean/PPRoundtrip.lean index 8c00ea9b3a..e931c2216d 100644 --- a/tests/lean/PPRoundtrip.lean +++ b/tests/lean/PPRoundtrip.lean @@ -50,7 +50,8 @@ def typeAs.{u} (α : Type u) (a : α) := () #eval check `(fun {a b : Nat} => a) -- implicit lambdas work as long as the expected type is preserved #eval check `(typeAs ({α : Type} → (a : α) → α) fun a => a) -section set_option pp.explicit true +section + set_option pp.explicit true #eval check `(fun {α : Type} [HasToString α] (a : α) => toString a) end @@ -71,3 +72,5 @@ end #eval check `((1,2).fst) #eval check `(1 < 2 || true) + +#eval check `(id (fun a => a) 0) diff --git a/tests/lean/PPRoundtrip.lean.expected.out b/tests/lean/PPRoundtrip.lean.expected.out index 7c81304b56..7c23ca06a3 100644 --- a/tests/lean/PPRoundtrip.lean.expected.out +++ b/tests/lean/PPRoundtrip.lean.expected.out @@ -135,3 +135,17 @@ (null (Term.app (Term.id `HasLess.Less (null)) (null (Term.num (numLit "1")) (Term.num (numLit "2")))) (null)) ")") (Term.id `Bool.true (null)))) +(Term.app + (Term.id `id (null)) + (null + (Term.paren + "(" + (null + (Term.fun + "fun" + (null (Term.paren "(" (null (Term.id `a (null)) (null (Term.typeAscription ":" (Term.id `Nat (null))))) ")")) + "=>" + (Term.id `a (null))) + (null)) + ")") + (Term.num (numLit "0"))))