diff --git a/src/Lean/PrettyPrinter/Formatter.lean b/src/Lean/PrettyPrinter/Formatter.lean index d4cb7e4d2e..669c3c689a 100644 --- a/src/Lean/PrettyPrinter/Formatter.lean +++ b/src/Lean/PrettyPrinter/Formatter.lean @@ -108,6 +108,9 @@ modify fun st => { st with stack := stack } def push (f : Format) : FormatterM Unit := modify fun st => { st with stack := st.stack.push f } +def pushLine : FormatterM Unit := +push Format.line + /-- Execute `x` at the right-most child of the current node, if any, then advance to the left. -/ def visitArgs (x : FormatterM Unit) : FormatterM Unit := do stx ← getCur; @@ -219,6 +222,13 @@ ctx ← read; env ← getEnv; pure $ Parser.tokenFn { input := s, fileName := "", fileMap := FileMap.ofString "", prec := 0, env := env, tokens := ctx.table } (Parser.mkParserState s) +def pushTokenCore (tk : String) : FormatterM Unit := +if tk.trimRight == tk then + push tk +else do + pushLine; + push tk.trimRight + def pushToken (tk : String) : FormatterM Unit := do st ← get; -- If there is no space between `tk` and the next word, compare parsing `tk` with and without the next word @@ -228,15 +238,15 @@ if st.leadWord != "" && tk.trimRight == tk then do if t1.pos == t2.pos then do -- same result => use `tk` as is, extend `leadWord` if not prefixed by whitespace modify fun st => { st with leadWord := if tk.trimLeft == tk then tk ++ st.leadWord else "" }; - push tk + pushTokenCore tk else do -- different result => add space modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" }; - push $ tk ++ " " + pushTokenCore $ tk ++ " " else do { -- already separated => use `tk` as is modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" }; - push tk + pushTokenCore tk } @[combinatorFormatter symbol] @@ -367,7 +377,7 @@ push " " @[combinatorFormatter Lean.Parser.checkOutsideQuot] def checkOutsideQuot.formatter : Formatter := pure () @[combinatorFormatter Lean.Parser.skip] def skip.formatter : Formatter := pure () -@[combinatorFormatter Lean.Parser.ppSpace] def ppSpace.formatter : Formatter := push " " +@[combinatorFormatter Lean.Parser.ppSpace] def ppSpace.formatter : Formatter := pushLine @[combinatorFormatter Lean.Parser.ppLine] def ppLine.formatter : Formatter := push "\n" @[combinatorFormatter pushNone] def pushNone.formatter : Formatter := goLeft @@ -393,7 +403,7 @@ table ← Parser.builtinTokenTable.get; catchInternalId backtrackExceptionId (do (_, st) ← (formatter { table := table }).run { stxTrav := Syntax.Traverser.fromSyntax stx }; - pure $ st.stack.get! 0) + pure $ Format.group $ st.stack.get! 0) (fun _ => throwError "format: uncaught backtrack exception") def formatTerm := format $ categoryParser.formatter `term diff --git a/tests/lean/PPRoundtrip.lean b/tests/lean/PPRoundtrip.lean index d0eb874757..bc23693191 100644 --- a/tests/lean/PPRoundtrip.lean +++ b/tests/lean/PPRoundtrip.lean @@ -9,21 +9,23 @@ open Lean.Format open Lean.Meta def check (stx : TermElabM Syntax) (optionsPerPos : OptionsPerPos := {}) : TermElabM Unit := do - stx ← stx; - e ← elabTermAndSynthesize stx none <* throwErrorIfErrors; - stx' ← liftMetaM $ delab Name.anonymous [] e optionsPerPos; - stx' ← liftCoreM $ PrettyPrinter.parenthesizeTerm stx'; - f' ← liftCoreM $ PrettyPrinter.formatTerm stx'; - IO.println $ toString f'; - env ← getEnv; - match Parser.runParserCategory env `term (toString f') "" with - | Except.error e => throwErrorAt stx e - | Except.ok stx'' => do - e' ← elabTermAndSynthesize stx'' none <* throwErrorIfErrors; - unlessM (isDefEq e e') $ - throwErrorAt stx (fmt "failed to round-trip" ++ line ++ fmt e ++ line ++ fmt e') +opts ← getOptions; +stx ← stx; +e ← elabTermAndSynthesize stx none <* throwErrorIfErrors; +stx' ← liftMetaM $ delab Name.anonymous [] e optionsPerPos; +stx' ← liftCoreM $ PrettyPrinter.parenthesizeTerm stx'; +f' ← liftCoreM $ PrettyPrinter.formatTerm stx'; +IO.println $ f'.pretty opts; +env ← getEnv; +match Parser.runParserCategory env `term (toString f') "" with +| Except.error e => throwErrorAt stx e +| Except.ok stx'' => do + e' ← elabTermAndSynthesize stx'' none <* throwErrorIfErrors; + unlessM (isDefEq e e') $ + throwErrorAt stx (fmt "failed to round-trip" ++ line ++ fmt e ++ line ++ fmt e') -- set_option trace.PrettyPrinter.parenthesize true +set_option format.width 20 -- #eval check `(?m) -- fails round-trip diff --git a/tests/lean/PPRoundtrip.lean.expected.out b/tests/lean/PPRoundtrip.lean.expected.out index 18a7c50159..86001cc62f 100644 --- a/tests/lean/PPRoundtrip.lean.expected.out +++ b/tests/lean/PPRoundtrip.lean.expected.out @@ -15,23 +15,60 @@ id.{2} Nat id (@id Type Nat) fun (a : Nat) => a fun (a b : Nat) => a -fun (a : Nat) (b : Bool) => a +fun +(a : +Nat) +(b : +Bool) => +a fun {a b : Nat} => a -typeAs ({α : Type} → α → α) fun {α : Type} (a : α) => a -fun {α : Type} [inst : HasToString α] (a : α) => @toString α inst a +typeAs ({α : +Type} → +α → +α) fun +{α : +Type} +(a : +α) => +a +fun +{α : +Type} +[inst : +HasToString α] +(a : +α) => +@toString α inst a (α : Type) → α (α β : Type) → α Type → Type → Type (α : Type) → α → α -(α : Type) → (a : α) → a = a +(α : +Type) → +(a : +α) → +a = +a {α : Type} → α -{α : Type} → [inst : HasToString α] → α +{α : +Type} → +[inst : +HasToString α] → +α 0 1 42 "hi" -{ type := Nat, val := 0 : PointedType } +{ +type := +Nat, +val := +0 : +PointedType } (1, 2, 3) (1, 2).fst 1 < 2 || true -id (fun (a : Nat) => a) 0 +id (fun +(a : +Nat) => +a) 0