diff --git a/src/Lean/PrettyPrinter/Formatter.lean b/src/Lean/PrettyPrinter/Formatter.lean
index d4cb7e4d2e..669c3c689a 100644
--- a/src/Lean/PrettyPrinter/Formatter.lean
+++ b/src/Lean/PrettyPrinter/Formatter.lean
@@ -108,6 +108,9 @@ modify fun st => { st with stack := stack }
def push (f : Format) : FormatterM Unit :=
modify fun st => { st with stack := st.stack.push f }
+def pushLine : FormatterM Unit :=
+push Format.line
+
/-- Execute `x` at the right-most child of the current node, if any, then advance to the left. -/
def visitArgs (x : FormatterM Unit) : FormatterM Unit := do
stx ← getCur;
@@ -219,6 +222,13 @@ ctx ← read;
env ← getEnv;
pure $ Parser.tokenFn { input := s, fileName := "", fileMap := FileMap.ofString "", prec := 0, env := env, tokens := ctx.table } (Parser.mkParserState s)
+def pushTokenCore (tk : String) : FormatterM Unit :=
+if tk.trimRight == tk then
+ push tk
+else do
+ pushLine;
+ push tk.trimRight
+
def pushToken (tk : String) : FormatterM Unit := do
st ← get;
-- If there is no space between `tk` and the next word, compare parsing `tk` with and without the next word
@@ -228,15 +238,15 @@ if st.leadWord != "" && tk.trimRight == tk then do
if t1.pos == t2.pos then do
-- same result => use `tk` as is, extend `leadWord` if not prefixed by whitespace
modify fun st => { st with leadWord := if tk.trimLeft == tk then tk ++ st.leadWord else "" };
- push tk
+ pushTokenCore tk
else do
-- different result => add space
modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" };
- push $ tk ++ " "
+ pushTokenCore $ tk ++ " "
else do {
-- already separated => use `tk` as is
modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" };
- push tk
+ pushTokenCore tk
}
@[combinatorFormatter symbol]
@@ -367,7 +377,7 @@ push " "
@[combinatorFormatter Lean.Parser.checkOutsideQuot] def checkOutsideQuot.formatter : Formatter := pure ()
@[combinatorFormatter Lean.Parser.skip] def skip.formatter : Formatter := pure ()
-@[combinatorFormatter Lean.Parser.ppSpace] def ppSpace.formatter : Formatter := push " "
+@[combinatorFormatter Lean.Parser.ppSpace] def ppSpace.formatter : Formatter := pushLine
@[combinatorFormatter Lean.Parser.ppLine] def ppLine.formatter : Formatter := push "\n"
@[combinatorFormatter pushNone] def pushNone.formatter : Formatter := goLeft
@@ -393,7 +403,7 @@ table ← Parser.builtinTokenTable.get;
catchInternalId backtrackExceptionId
(do
(_, st) ← (formatter { table := table }).run { stxTrav := Syntax.Traverser.fromSyntax stx };
- pure $ st.stack.get! 0)
+ pure $ Format.group $ st.stack.get! 0)
(fun _ => throwError "format: uncaught backtrack exception")
def formatTerm := format $ categoryParser.formatter `term
diff --git a/tests/lean/PPRoundtrip.lean b/tests/lean/PPRoundtrip.lean
index d0eb874757..bc23693191 100644
--- a/tests/lean/PPRoundtrip.lean
+++ b/tests/lean/PPRoundtrip.lean
@@ -9,21 +9,23 @@ open Lean.Format
open Lean.Meta
def check (stx : TermElabM Syntax) (optionsPerPos : OptionsPerPos := {}) : TermElabM Unit := do
- stx ← stx;
- e ← elabTermAndSynthesize stx none <* throwErrorIfErrors;
- stx' ← liftMetaM $ delab Name.anonymous [] e optionsPerPos;
- stx' ← liftCoreM $ PrettyPrinter.parenthesizeTerm stx';
- f' ← liftCoreM $ PrettyPrinter.formatTerm stx';
- IO.println $ toString f';
- env ← getEnv;
- match Parser.runParserCategory env `term (toString f') "" with
- | Except.error e => throwErrorAt stx e
- | Except.ok stx'' => do
- e' ← elabTermAndSynthesize stx'' none <* throwErrorIfErrors;
- unlessM (isDefEq e e') $
- throwErrorAt stx (fmt "failed to round-trip" ++ line ++ fmt e ++ line ++ fmt e')
+opts ← getOptions;
+stx ← stx;
+e ← elabTermAndSynthesize stx none <* throwErrorIfErrors;
+stx' ← liftMetaM $ delab Name.anonymous [] e optionsPerPos;
+stx' ← liftCoreM $ PrettyPrinter.parenthesizeTerm stx';
+f' ← liftCoreM $ PrettyPrinter.formatTerm stx';
+IO.println $ f'.pretty opts;
+env ← getEnv;
+match Parser.runParserCategory env `term (toString f') "" with
+| Except.error e => throwErrorAt stx e
+| Except.ok stx'' => do
+ e' ← elabTermAndSynthesize stx'' none <* throwErrorIfErrors;
+ unlessM (isDefEq e e') $
+ throwErrorAt stx (fmt "failed to round-trip" ++ line ++ fmt e ++ line ++ fmt e')
-- set_option trace.PrettyPrinter.parenthesize true
+set_option format.width 20
-- #eval check `(?m) -- fails round-trip
diff --git a/tests/lean/PPRoundtrip.lean.expected.out b/tests/lean/PPRoundtrip.lean.expected.out
index 18a7c50159..86001cc62f 100644
--- a/tests/lean/PPRoundtrip.lean.expected.out
+++ b/tests/lean/PPRoundtrip.lean.expected.out
@@ -15,23 +15,60 @@ id.{2} Nat
id (@id Type Nat)
fun (a : Nat) => a
fun (a b : Nat) => a
-fun (a : Nat) (b : Bool) => a
+fun
+(a :
+Nat)
+(b :
+Bool) =>
+a
fun {a b : Nat} => a
-typeAs ({α : Type} → α → α) fun {α : Type} (a : α) => a
-fun {α : Type} [inst : HasToString α] (a : α) => @toString α inst a
+typeAs ({α :
+Type} →
+α →
+α) fun
+{α :
+Type}
+(a :
+α) =>
+a
+fun
+{α :
+Type}
+[inst :
+HasToString α]
+(a :
+α) =>
+@toString α inst a
(α : Type) → α
(α β : Type) → α
Type → Type → Type
(α : Type) → α → α
-(α : Type) → (a : α) → a = a
+(α :
+Type) →
+(a :
+α) →
+a =
+a
{α : Type} → α
-{α : Type} → [inst : HasToString α] → α
+{α :
+Type} →
+[inst :
+HasToString α] →
+α
0
1
42
"hi"
-{ type := Nat, val := 0 : PointedType }
+{
+type :=
+Nat,
+val :=
+0 :
+PointedType }
(1, 2, 3)
(1, 2).fst
1 < 2 || true
-id (fun (a : Nat) => a) 0
+id (fun
+(a :
+Nat) =>
+a) 0