/-
Copyright (c) 2020 Sebastian Ullrich. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sebastian Ullrich
-/

/-!
The formatter turns a `Syntax` tree into a `Format` object, inserting both mandatory whitespace (to separate adjacent
tokens) as well as "pretty" optional whitespace.

The basic approach works much like the parenthesizer: A right-to-left traversal over the syntax tree, driven by
parser-specific handlers registered via attributes. The traversal is right-to-left so that when emitting a token, we
already know the text following it and can decide whether or not whitespace between the two is necessary.
-/

import Lean.CoreM
import Lean.Parser.Extension
import Lean.KeyedDeclsAttribute
import Lean.ParserCompiler.Attribute
import Lean.PrettyPrinter.Backtrack

namespace Lean
namespace PrettyPrinter
namespace Formatter

-- Read-only input of a formatter run: the user options and the token table used by `parseToken`.
structure Context :=
(options : Options)
(table   : Parser.TokenTable)

-- Mutable state threaded through a formatter run.
structure State :=
(stxTrav  : Syntax.Traverser)
-- Textual content of `stack` up to the first whitespace (not enclosed in an escaped ident). We assume that the textual
-- content of `stack` is modified only by `pushToken` and `pushWhitespace` (which `pushLine` is built on), so
-- `leadWord` is adjusted there accordingly.
(leadWord : String := "")
-- Stack of generated Format objects, analogous to the Syntax stack in the parser.
-- Note, however, that the stack is reversed because of the right-to-left traversal.
(stack    : Array Format := #[])

end Formatter

abbrev FormatterM := ReaderT Formatter.Context $ StateRefT Formatter.State $ CoreM

-- Backtracking choice: run `p₁`; if it throws the internal backtrack exception, restore the state saved before
-- `p₁` started and run `p₂` instead.
@[inline] def FormatterM.orelse {α} (p₁ p₂ : FormatterM α) : FormatterM α := do
  s ← get;
  catchInternalId backtrackExceptionId
    p₁
    (fun _ => do set s; p₂)

instance Formatter.orelse {α} : HasOrelse (FormatterM α) := ⟨FormatterM.orelse⟩

abbrev Formatter := FormatterM Unit

-- Creates the `[formatter k]` / `[builtinFormatter k]` attribute keyed by syntax node kind.
unsafe def mkFormatterAttribute : IO (KeyedDeclsAttribute Formatter) :=
KeyedDeclsAttribute.init {
  builtinName := `builtinFormatter,
  name := `formatter,
  descr := "Register a formatter for a parser. [formatter k] registers a declaration of type `Lean.PrettyPrinter.Formatter` for the `SyntaxNodeKind` `k`.",
  valueTypeName := `Lean.PrettyPrinter.Formatter,
  evalKey := fun builtin args => do
    env ← getEnv;
    match attrParamSyntaxToIdentifier args with
    | some id =>
      -- `isValidSyntaxNodeKind` is updated only in the next stage for new `[builtin*Parser]`s, but we try to
      -- synthesize a formatter for it immediately, so we just check for a declaration in this case
      if (builtin && (env.find? id).isSome) || Parser.isValidSyntaxNodeKind env id then pure id
      else throwError ("invalid [formatter] argument, unknown syntax kind '" ++ toString id ++ "'")
    | none    => throwError "invalid [formatter] argument, expected identifier"
} `Lean.PrettyPrinter.formatterAttribute
@[init mkFormatterAttribute] constant formatterAttribute : KeyedDeclsAttribute Formatter := arbitrary _

-- Creates the `[combinatorFormatter c]` attribute keyed by parser combinator declaration.
unsafe def mkCombinatorFormatterAttribute : IO ParserCompiler.CombinatorAttribute :=
ParserCompiler.registerCombinatorAttribute
  `combinatorFormatter
  "Register a formatter for a parser combinator. [combinatorFormatter c] registers a declaration of type `Lean.PrettyPrinter.Formatter` for the `Parser` declaration `c`. Note that, unlike with [formatter], this is not a node kind since combinators usually do not introduce their own node kinds.
The tagged declaration may optionally accept parameters corresponding to (a prefix of) those of `c`, where `Parser` is replaced with `Formatter` in the parameter types."
@[init mkCombinatorFormatterAttribute] constant combinatorFormatterAttribute : ParserCompiler.CombinatorAttribute := arbitrary _

namespace Formatter

open Lean.Core
open Lean.Parser

/-- Abort the current alternative; caught by `FormatterM.orelse` (and, as a last resort, by `format`). -/
def throwBacktrack {α} : FormatterM α :=
throw $ Exception.internal backtrackExceptionId

-- Exposes the `stxTrav` field of the state as the syntax traverser of `FormatterM`.
instance FormatterM.monadTraverser : Syntax.MonadTraverser FormatterM := ⟨{
  get       := State.stxTrav <$> get,
  set       := fun t => modify (fun st => { st with stxTrav := t }),
  modifyGet := fun _ f => modifyGet (fun st => let (a, t) := f st.stxTrav; (a, { st with stxTrav := t }))
}⟩

open Syntax.MonadTraverser

/-- The current (reversed) stack of generated `Format` objects. -/
def getStack : FormatterM (Array Format) := do
  st ← get;
  pure st.stack

/-- Number of `Format` objects currently on the stack. -/
def getStackSize : FormatterM Nat := do
  stack ← getStack;
  pure stack.size

/-- Replace the whole stack. Does not touch `leadWord`. -/
def setStack (stack : Array Format) : FormatterM Unit :=
modify fun st => { st with stack := stack }

/-- Push `f` onto the stack as-is. Does not touch `leadWord`; see `pushToken` and `pushWhitespace`. -/
def push (f : Format) : FormatterM Unit :=
modify fun st => { st with stack := st.stack.push f }

/--
Push `f`, which must render as whitespace, and reset `leadWord`: the text to the right is now separated from
whatever is emitted next, so `pushToken` need not insert another space. -/
def pushWhitespace (f : Format) : FormatterM Unit := do
  push f;
  modify fun st => { st with leadWord := "" }

/-- Push an optional line break (`Format.line`) and reset `leadWord`. -/
def pushLine : FormatterM Unit :=
pushWhitespace Format.line

/-- Execute `x` at the right-most child of the current node, if any, then advance to the left. -/
def visitArgs (x : FormatterM Unit) : FormatterM Unit := do
  stx ← getCur;
  when (stx.getArgs.size > 0) $
    goDown (stx.getArgs.size - 1) *> x <* goUp;
  goLeft

/-- Execute `x`, pass array of generated Format objects to `fn`, and push result. -/
def fold (fn : Array Format → Format) (x : FormatterM Unit) : FormatterM Unit := do
  sp ← getStackSize;
  x;
  stack ← getStack;
  -- everything above the saved stack pointer `sp` was produced by `x`
  let f := fn $ stack.extract sp stack.size;
  setStack $ (stack.shrink sp).push f

/-- Execute `x` and concatenate generated Format objects. -/
def concat (x : FormatterM Unit) : FormatterM Unit := do
  -- the stack is reversed, so later elements are prepended (`f ++ acc`) to restore source order
  fold (Array.foldl (fun acc f => if acc.isNil then f else f ++ acc) Format.nil) x

/--
Execute `x`, concatenate its output and wrap it in `Format.nest`. The nesting depth is `indent` if given,
otherwise `Format.getIndent` of the current options. -/
def indent (x : Formatter) (indent : Option Int := none) : Formatter := do
  concat x;
  ctx ← read;
  let indent := indent.getD $ Format.getIndent ctx.options;
  modify fun st => { st with stack := st.stack.pop.push (Format.nest indent st.stack.back) }

/-- Execute `x`, concatenate its output and wrap it in `Format.fill`. -/
def group (x : Formatter) : Formatter := do
  concat x;
  modify fun st => { st with stack := st.stack.pop.push (Format.fill st.stack.back) }

@[combinatorFormatter Lean.Parser.orelse] def orelse.formatter (p1 p2 : Formatter) : Formatter :=
-- HACK: We have no (immediate) information on which side of the orelse could have produced the current node, so try
-- them in turn. Uses the syntax traverser non-linearly!
p1 <|> p2

-- `mkAntiquot` is quite complex, so we'd rather have its formatter synthesized below the actual parser definition.
-- Note that there is a mutual recursion
-- `categoryParser -> mkAntiquot -> termParser -> categoryParser`, so we need to introduce an indirection somewhere
-- anyway.
@[extern "lean_mk_antiquot_formatter"]
constant mkAntiquot.formatter' (name : String) (kind : Option SyntaxNodeKind) (anonymous := true) : Formatter :=
arbitrary _

/-- Run the first formatter registered via `formatterAttribute` for kind `k`; error if there is none. -/
def formatterForKind (k : SyntaxNodeKind) : Formatter := do
  env ← getEnv;
  p::_ ← pure $ formatterAttribute.getValues env k
    | throwError $ "no known formatter for kind '" ++ k ++ "'";
  p

@[combinatorFormatter Lean.Parser.withAntiquot]
def withAntiquot.formatter (antiP p : Formatter) : Formatter :=
-- TODO: could be optimized using `isAntiquot` (which would have to be moved), but I'd rather
-- fix the backtracking hack outright.
orelse.formatter antiP p

-- Formats the current node by dispatching on its kind. For a `choice` node, every alternative is formatted, the
-- results are checked to render identically, and only one of them is kept.
@[combinatorFormatter Lean.Parser.categoryParser]
def categoryParser.formatter (cat : Name) : Formatter := group $ indent do
  stx ← getCur;
  if stx.getKind == `choice then
    visitArgs do {
      stx ← getCur;
      sp ← getStackSize;
      stx.getArgs.forM fun stx => formatterForKind stx.getKind;
      stack ← getStack;
      when (stack.size > sp && stack.anyRange sp stack.size fun f => f.pretty != (stack.get! sp).pretty)
        panic! "Formatter.visit: inequal choice children";
      -- discard all but one child format
      setStack $ stack.extract 0 (sp+1)
    }
  else
    withAntiquot.formatter (mkAntiquot.formatter' cat.toString none) (formatterForKind stx.getKind)

-- Reads the category name from the sibling `offset` positions to the left of the current node and formats the
-- current node with that category.
@[combinatorFormatter Lean.Parser.categoryParserOfStack]
def categoryParserOfStack.formatter (offset : Nat) : Formatter := do
  st ← get;
  let stx := st.stxTrav.parents.back.getArg (st.stxTrav.idxs.back - offset);
  categoryParser.formatter stx.getId

@[combinatorFormatter Lean.Parser.try]
def try.formatter (p : Formatter) : Formatter :=
p

-- `lookahead` and `notFollowedBy` emit nothing and do not move the traverser.
@[combinatorFormatter Lean.Parser.lookahead]
def lookahead.formatter (p : Formatter) : Formatter := pure ()

@[combinatorFormatter Lean.Parser.notFollowedBy]
def notFollowedBy.formatter (p : Formatter) : Formatter := pure ()

-- Right-to-left traversal: the second parser's output is formatted first.
@[combinatorFormatter Lean.Parser.andthen]
def andthen.formatter (p1 p2 : Formatter) : Formatter :=
p2 *> p1

/-- Backtrack (after emitting a trace message) unless the current node has kind `k`. -/
def checkKind (k : SyntaxNodeKind) : FormatterM Unit := do
  stx ← getCur;
  when (k != stx.getKind) $ do {
    trace! `PrettyPrinter.format.backtrack ("unexpected node kind '" ++ toString stx.getKind ++ "', expected '" ++ toString k ++ "'");
    throwBacktrack
  }

@[combinatorFormatter Lean.Parser.node]
def node.formatter (k : SyntaxNodeKind) (p : Formatter) : Formatter := do
  checkKind k;
  visitArgs p

@[combinatorFormatter Lean.Parser.trailingNode]
def trailingNode.formatter (k : SyntaxNodeKind) (_ : Nat) (p : Formatter) : Formatter := do
  checkKind k;
  visitArgs do
    p;
    -- leading term, not actually produced by `p`
    categoryParser.formatter `foo

/-- Run the tokenizer (`Parser.tokenFn`) on `s` with the token table from the context and return the parser state. -/
def parseToken (s : String) : FormatterM ParserState := do
  ctx ← read;
  env ← getEnv;
  pure $ Parser.tokenFn { input := s, fileName := "", fileMap := FileMap.ofString "", prec := 0, env := env, tokens := ctx.table } (Parser.mkParserState s)

/--
Push the token text `tk`. If `tk` ends in spaces, those are replaced by an optional line break, which is pushed
first because the stack is in reverse order. -/
def pushTokenCore (tk : String) : FormatterM Unit :=
if tk.toSubstring.dropRightWhile (fun s => s == ' ') == tk.toSubstring then
  push tk
else do
  pushLine;
  push tk.trimRight

/--
Push token `tk`, inserting a space after it if it would otherwise merge with the following word (`leadWord`)
when tokenized, and update `leadWord` to describe the new left end of the output. -/
def pushToken (tk : String) : FormatterM Unit := do
  st ← get;
  -- If there is no space between `tk` and the next word, compare parsing `tk` with and without the next word
  if st.leadWord != "" && tk.trimRight == tk then do
    t1 ← parseToken tk.trimLeft;
    t2 ← parseToken $ tk.trimLeft ++ st.leadWord;
    if t1.pos == t2.pos then do
      -- same result => use `tk` as is, extend `leadWord` if not prefixed by whitespace
      pushTokenCore tk;
      modify fun st => { st with leadWord := if tk.trimLeft == tk then tk ++ st.leadWord else "" }
    else do
      -- different result => add space
      pushTokenCore $ tk ++ " ";
      modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" }
  else do {
    -- already separated => use `tk` as is
    pushTokenCore tk;
    modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" }
  }

-- Emits `sym` if the current node is that token, otherwise backtracks.
@[combinatorFormatter symbol]
def symbol.formatter (sym : String) : Formatter := do
  stx ← getCur;
  if stx.isToken sym then do
    pushToken sym;
    goLeft
  else do
    trace! `PrettyPrinter.format.backtrack ("unexpected syntax '" ++ stx ++ "', expected symbol '" ++ sym ++ "'");
    throwBacktrack

@[combinatorFormatter nonReservedSymbol]
def nonReservedSymbol.formatter := symbol.formatter

-- Emits `sym` if the atom's text equals the trimmed `sym`, and `asciiSym` otherwise.
@[combinatorFormatter unicodeSymbol]
def unicodeSymbol.formatter (sym asciiSym : String) : Formatter := do
  stx ← getCur;
  Syntax.atom _ val ← pure stx
    | throwError $ "not an atom: " ++ toString stx;
  if val == sym.trim then
    pushToken sym
  else
    pushToken asciiSym;
  goLeft

-- Emits an identifier (macro scopes removed), escaping every component with `«»` if the plain text does not
-- tokenize back to the same identifier.
@[combinatorFormatter identNoAntiquot]
def identNoAntiquot.formatter : Formatter := do
  checkKind identKind;
  stx ← getCur;
  let id := stx.getId;
  let id := id.simpMacroScopes;
  let s := id.toString;
  if id.isAnonymous then
    pushToken "[anonymous]"
  else if isInaccessibleUserName id || id.components.any Name.isNum ||
      -- loose bvar
      "#".isPrefixOf s then
    -- not parsable anyway, output as-is
    pushToken s
  else do {
    -- try to parse `s` as-is; if it fails, escape
    pst ← parseToken s;
    if pst.stxStack == #[stx] then
      pushToken s
    else
      let n := stx.getId;
      -- TODO: do something better than escaping all parts
      let n := (n.components.map fun c => "«" ++ toString c ++ "»").foldl mkNameStr Name.anonymous;
      pushToken n.toString
  };
  goLeft

-- Emits the identifier's name unchanged (no macro-scope removal, no escaping).
@[combinatorFormatter rawIdent]
def rawIdent.formatter : Formatter := do
  checkKind identKind;
  stx ← getCur;
  pushToken stx.getId.toString;
  goLeft

@[combinatorFormatter Lean.Parser.identEq]
def identEq.formatter := rawIdent.formatter

/--
Emit the atom at the current position, or the atom in argument 0 if the current position is a node.
If `k` is not `Name.anonymous`, the current node must have kind `k` (otherwise backtrack). -/
def visitAtom (k : SyntaxNodeKind) : Formatter := do
  stx ← getCur;
  when (k != Name.anonymous) $
    checkKind k;
  Syntax.atom _ val ← pure $ stx.ifNode (fun n => n.getArg 0) (fun _ => stx)
    | throwError $ "not an atom: " ++ toString stx;
  pushToken val;
  goLeft

@[combinatorFormatter charLitNoAntiquot] def charLitNoAntiquot.formatter := visitAtom charLitKind
@[combinatorFormatter strLitNoAntiquot] def strLitNoAntiquot.formatter := visitAtom strLitKind
@[combinatorFormatter nameLitNoAntiquot] def nameLitNoAntiquot.formatter := visitAtom nameLitKind
@[combinatorFormatter numLitNoAntiquot] def numLitNoAntiquot.formatter := visitAtom numLitKind
@[combinatorFormatter fieldIdx] def fieldIdx.formatter := visitAtom fieldIdxKind

-- Runs `p` once per argument of the current node.
@[combinatorFormatter many]
def many.formatter (p : Formatter) : Formatter := do
  stx ← getCur;
  visitArgs $ stx.getArgs.size.forM fun _ => p

@[combinatorFormatter many1] def many1.formatter (p : Formatter) : Formatter := many.formatter p

@[combinatorFormatter Parser.optional]
def optional.formatter (p : Formatter) : Formatter := do
  visitArgs p

@[combinatorFormatter Parser.many1Unbox]
def many1Unbox.formatter (p : Formatter) : Formatter := do
  stx ← getCur;
  if stx.getKind == nullKind then do
    many.formatter p
  else
    p

-- Arguments alternate element/separator; indices are visited from the highest down to 0 (right to left),
-- running `p` at even indices and `pSep` at odd ones.
@[combinatorFormatter sepBy]
def sepBy.formatter (p pSep : Formatter) : Formatter := do
  stx ← getCur;
  visitArgs $ (List.range stx.getArgs.size).reverse.forM $ fun i => if i % 2 == 0 then p else pSep

@[combinatorFormatter sepBy1] def sepBy1.formatter := sepBy.formatter

-- The following combinators only affect parsing; their formatters just run the wrapped formatter.
@[combinatorFormatter Lean.Parser.withPosition] def withPosition.formatter (p : Formatter) : Formatter := do
  p
@[combinatorFormatter Lean.Parser.withoutPosition] def withoutPosition.formatter (p : Formatter) : Formatter := do
  p
@[combinatorFormatter Lean.Parser.withForbidden] def withForbidden.formatter (tk : Token) (p : Formatter) : Formatter := do
  p
@[combinatorFormatter Lean.Parser.withoutForbidden] def withoutForbidden.formatter (p : Formatter) : Formatter := do
  p
@[combinatorFormatter Lean.Parser.setExpected]
def setExpected.formatter (expected : List String) (p : Formatter) : Formatter :=
p
@[combinatorFormatter Lean.Parser.toggleInsideQuot]
def toggleInsideQuot.formatter (p : Formatter) : Formatter :=
p

-- Whitespace is required here: emit a line break unless the output to the right already starts with whitespace
-- (i.e. `leadWord` is empty).
@[combinatorFormatter checkWsBefore] def checkWsBefore.formatter : Formatter := do
  st ← get;
  when (st.leadWord != "") $
    pushLine

-- The following parser-side checks produce no output and do not move the traverser.
@[combinatorFormatter checkPrec] def checkPrec.formatter : Formatter := pure ()
@[combinatorFormatter checkStackTop] def checkStackTop.formatter : Formatter := pure ()
@[combinatorFormatter checkNoWsBefore] def checkNoWsBefore.formatter : Formatter := pure ()
@[combinatorFormatter checkTailWs] def checkTailWs.formatter : Formatter := pure ()
@[combinatorFormatter checkColGe] def checkColGe.formatter : Formatter := pure ()
@[combinatorFormatter checkColGt] def checkColGt.formatter : Formatter := pure ()

@[combinatorFormatter eoi] def eoi.formatter : Formatter := pure ()
@[combinatorFormatter notFollowedByCategoryToken] def notFollowedByCategoryToken.formatter : Formatter := pure ()
@[combinatorFormatter checkNoImmediateColon] def checkNoImmediateColon.formatter : Formatter := pure ()
@[combinatorFormatter Lean.Parser.checkInsideQuot] def checkInsideQuot.formatter : Formatter := pure ()
@[combinatorFormatter Lean.Parser.checkOutsideQuot] def checkOutsideQuot.formatter : Formatter := pure ()
@[combinatorFormatter Lean.Parser.skip] def skip.formatter : Formatter := pure ()

-- Explicit whitespace. All three go through `pushWhitespace` so that `leadWord` is reset and the token emitted
-- next (to the left) is not given an additional separating space by `pushToken`.
@[combinatorFormatter Lean.Parser.ppHardSpace] def ppHardSpace.formatter : Formatter := pushWhitespace " "
@[combinatorFormatter Lean.Parser.ppSpace] def ppSpace.formatter : Formatter := pushLine
@[combinatorFormatter Lean.Parser.ppLine] def ppLine.formatter : Formatter := pushWhitespace "\n"
@[combinatorFormatter Lean.Parser.ppGroup] def ppGroup.formatter (p : Formatter) : Formatter := group $ indent p
@[combinatorFormatter Lean.Parser.ppIndent] def ppIndent.formatter (p : Formatter) : Formatter := indent p
-- Nests `p` by the negated default indentation.
@[combinatorFormatter Lean.Parser.ppDedent] def ppDedent.formatter (p : Formatter) : Formatter := do
  opts ← getOptions;
  indent p (some (-(Format.getIndent opts)))

-- Skips the current node without emitting anything.
@[combinatorFormatter pushNone] def pushNone.formatter : Formatter := goLeft

-- TODO: delete with old frontend
@[combinatorFormatter quotedSymbol]
def quotedSymbol.formatter : Formatter := do
  checkKind quotedSymbolKind;
  visitArgs do
    push "`"; goLeft;
    visitAtom Name.anonymous;
    push "`"; goLeft

-- Not implemented yet.
@[combinatorFormatter Lean.Parser.interpolatedStr]
def interpolatedStr.formatter (p : Formatter) : Formatter :=
throwError "NIY"

@[combinatorFormatter unquotedSymbol]
def unquotedSymbol.formatter := visitAtom Name.anonymous

-- Formatter counterpart of a parser-level `ite`: selects `t` or `e` by the decidable proposition `c`.
@[combinatorFormatter ite, macroInline]
def ite {α : Type} (c : Prop) [h : Decidable c] (t e : Formatter) : Formatter :=
if c then t else e

end Formatter
open Formatter

/--
Run `formatter` on `stx` with the current options and the builtin token table, and return the single resulting
`Format` wrapped in `Format.fill`. A backtrack exception escaping `formatter` is reported as an error. -/
def format (formatter : Formatter) (stx : Syntax) : CoreM Format := do
  options ← getOptions;
  table ← Parser.builtinTokenTable.get;
  catchInternalId backtrackExceptionId
    (do
      (_, st) ← (concat formatter { table := table, options := options }).run { stxTrav := Syntax.Traverser.fromSyntax stx };
      -- `concat` leaves exactly one element on the initially empty stack
      pure $ Format.fill $ st.stack.get! 0)
    (fun _ => throwError "format: uncaught backtrack exception")

def formatTerm := format $ categoryParser.formatter `term
def formatCommand := format $ categoryParser.formatter `command

@[init] private def regTraceClasses : IO Unit := do
  registerTraceClass `PrettyPrinter.format;
  pure ()

end PrettyPrinter
end Lean