feat: pp.parens option to pretty print with all parentheses (#2934)

This PR adds the option `pp.parens` (default: false) that causes the
pretty printer to eagerly insert parentheses, which can be useful for
teaching and for understanding the structure of expressions. For
example, it causes `p → q → r` to pretty print as `p → (q → r)`.

Any notations with precedence greater than or equal to `maxPrec` do not
receive such discretionary parentheses, since this precedence level is
considered to be infinity.

This option was a feature in the Lean 3 community edition.
This commit is contained in:
Kyle Miller 2024-11-15 11:11:54 -08:00 committed by GitHub
parent b1e0c1b594
commit 691acde696
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 83 additions and 6 deletions

View file

@ -24,6 +24,11 @@ register_builtin_option pp.notation : Bool := {
group := "pp"
descr := "(pretty printer) disable/enable notation (infix, mixfix, postfix operators and unicode characters)"
}
register_builtin_option pp.parens : Bool := {
defValue := false
group := "pp"
descr := "(pretty printer) if set to true, notation is wrapped in parentheses regardless of precedence"
}
register_builtin_option pp.unicode.fun : Bool := {
defValue := false
group := "pp"
@ -248,6 +253,7 @@ def getPPNatLit (o : Options) : Bool := o.get pp.natLit.name (getPPNumericTypes
def getPPCoercions (o : Options) : Bool := o.get pp.coercions.name (!getPPAll o)
def getPPExplicit (o : Options) : Bool := o.get pp.explicit.name (getPPAll o)
def getPPNotation (o : Options) : Bool := o.get pp.notation.name (!getPPAll o)
def getPPParens (o : Options) : Bool := o.get pp.parens.name pp.parens.defValue
def getPPUnicodeFun (o : Options) : Bool := o.get pp.unicode.fun.name false
def getPPMatch (o : Options) : Bool := o.get pp.match.name (!getPPAll o)
def getPPFieldNotation (o : Options) : Bool := o.get pp.fieldNotation.name (!getPPAll o)

View file

@ -8,6 +8,7 @@ import Lean.Parser.Extension
import Lean.Parser.StrInterpolation
import Lean.ParserCompiler.Attribute
import Lean.PrettyPrinter.Basic
import Lean.PrettyPrinter.Delaborator.Options
/-!
@ -82,8 +83,10 @@ namespace PrettyPrinter
namespace Parenthesizer
structure Context where
-- We need to store this `categoryParser` argument to deal with the implicit Pratt parser call in `trailingNode.parenthesizer`.
/-- We need to store this `categoryParser` argument to deal with the implicit Pratt parser call in `trailingNode.parenthesizer`. -/
cat : Name := Name.anonymous
/-- Whether to add parentheses regardless of any other conditions. This is cached from the `pp.parens` option. -/
forceParens : Bool := false
structure State where
stxTrav : Syntax.Traverser
@ -217,8 +220,13 @@ def maybeParenthesize (cat : Name) (canJuxtapose : Bool) (mkParen : Syntax → S
let { minPrec := some minPrec, trailPrec := trailPrec, trailCat := trailCat, .. } ← get
| trace[PrettyPrinter.parenthesize] "visited a syntax tree without precedences?!{line ++ format stx}"
trace[PrettyPrinter.parenthesize] (m!"...precedences are {prec} >? {minPrec}" ++ if canJuxtapose then m!", {(trailPrec, trailCat)} <=? {(st.contPrec, st.contCat)}" else "")
-- Should we parenthesize?
if (prec > minPrec || canJuxtapose && match trailPrec, st.contPrec with | some trailPrec, some contPrec => trailCat == st.contCat && trailPrec <= contPrec | _, _ => false) then
/- Should we parenthesize?
* Note about forceParens mode: we don't insert outermost parentheses (we use the syntax traverser parents to detect this),
and we don't insert parentheses when we are at `maxPrec` (since this is effectively infinity).
-/
if (((← read).forceParens && !st.stxTrav.parents.isEmpty && minPrec < Parser.maxPrec)
|| prec > minPrec
|| canJuxtapose && match trailPrec, st.contPrec with | some trailPrec, some contPrec => trailCat == st.contCat && trailPrec <= contPrec | _, _ => false) then
-- The recursive `visit` call, by the invariant, has moved to the preceding node. In order to parenthesize
-- the original node, we must first move to the right, except if we already were at the left-most child in the first
-- place.
@ -540,16 +548,23 @@ instance : Coe (Parenthesizer → Parenthesizer → Parenthesizer) Parenthesizer
end Parenthesizer
open Parenthesizer
/-- Add necessary parentheses in `stx` parsed by `parser`. -/
/--
Adds necessary parentheses in `stx` parsed by `parser`.
-/
def parenthesize (parenthesizer : Parenthesizer) (stx : Syntax) : CoreM Syntax := do
trace[PrettyPrinter.parenthesize.input] "{format stx}"
let opts ← getOptions
catchInternalId backtrackExceptionId
(do
let (_, st) ← (parenthesizer {}).run { stxTrav := Syntax.Traverser.fromSyntax stx }
let (_, st) ← (parenthesizer { forceParens := getPPParens opts }).run { stxTrav := Syntax.Traverser.fromSyntax stx }
pure st.stxTrav.cur)
(fun _ => throwError "parenthesize: uncaught backtrack exception")
def parenthesizeCategory (cat : Name) := parenthesize <| categoryParser.parenthesizer cat 0
/--
Adds necessary parentheses to the syntax in the given category (for example, `term`, `tactic`, or `command`).
-/
def parenthesizeCategory (cat : Name) (stx : Syntax) :=
parenthesize (categoryParser.parenthesizer cat 0) stx
def parenthesizeTerm := parenthesizeCategory `term
def parenthesizeTactic := parenthesizeCategory `tactic

View file

@ -0,0 +1,56 @@
/-!
# Tests for the `pp.parens` pretty printing option
-/
set_option pp.parens true
/-!
No parentheses around numeral.
-/
/-- info: 1 : Nat -/
#guard_msgs in #check 1
/-!
No parentheses around variable.
-/
/-- info: x : Nat -/
#guard_msgs in variable (x : Nat) in #check x
/-!
No parentheses around each individual function application.
-/
def f (x y z : Nat) : Nat := x + y + z
/-- info: f 1 2 3 : Nat -/
#guard_msgs in #check f 1 2 3
/-!
Example arithmetic expressions
-/
/-- info: (1 + (2 * 3)) + 4 : Nat -/
#guard_msgs in #check 1 + 2 * 3 + 4
/-- info: Nat.add_assoc : ∀ (n m k : Nat), (((n + m) + k) = (n + (m + k))) -/
#guard_msgs in #check (Nat.add_assoc)
/-!
Implication chains
-/
/-- info: p → (q → r) : Prop -/
#guard_msgs in variable (p q r : Prop) in #check p → q → r
/-!
No parentheses around list literals
-/
/-- info: [1, 2, 3] ++ [3, 4, 5] : List Nat -/
#guard_msgs in #check [1,2,3] ++ [3,4,5]
/-!
Parentheses around body of forall.
-/
/-- info: ∀ (p : (Nat → (Nat → Prop))), (p (1 + 2) 3) : Prop -/
#guard_msgs in #check ∀ (p : Nat → Nat → Prop), p (1 + 2) 3
/-!
Parentheses around branches of `if`.
-/
/-- info: if True then (1 + 2) else (2 + 3) : Nat -/
#guard_msgs in #check if True then 1 + 2 else 2 + 3