lean4-htt/src/Lean/PrettyPrinter/Formatter.lean
2020-10-07 09:46:47 +02:00

444 lines
17 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2020 Sebastian Ullrich. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sebastian Ullrich
-/
/-!
The formatter turns a `Syntax` tree into a `Format` object, inserting both mandatory whitespace (to separate adjacent
tokens) as well as "pretty" optional whitespace.
The basic approach works much like the parenthesizer: A right-to-left traversal over the syntax tree, driven by
parser-specific handlers registered via attributes. The traversal is right-to-left so that when emitting a token, we
already know the text following it and can decide whether or not whitespace between the two is necessary.
-/
import Lean.CoreM
import Lean.Parser.Extension
import Lean.KeyedDeclsAttribute
import Lean.ParserCompiler.Attribute
import Lean.PrettyPrinter.Backtrack
namespace Lean
namespace PrettyPrinter
namespace Formatter
/-- Read-only configuration of the formatter. -/
structure Context :=
-- options the formatter was started with; `indent` reads the default indentation width from them
(options : Options)
-- token table used by `parseToken` to check how emitted text would be tokenized
(table : Parser.TokenTable)
/-- Mutable state of the formatter. -/
structure State :=
-- cursor into the syntax tree being formatted
(stxTrav : Syntax.Traverser)
-- Textual content of `stack` up to the first whitespace (not enclosed in an escaped ident). We assume that the textual
-- content of `stack` is modified only by `pushToken` and `pushLine`, so `leadWord` is adjusted there accordingly
-- (text pushed directly via `push` does not update it).
(leadWord : String := "")
-- Stack of generated Format objects, analogous to the Syntax stack in the parser.
-- Note, however, that the stack is reversed because of the right-to-left traversal.
(stack : Array Format := #[])
end Formatter
-- The formatter monad: reads a `Formatter.Context`, keeps a `Formatter.State` (via `StateRefT`) and runs in `CoreM`.
abbrev FormatterM := ReaderT Formatter.Context $ StateRefT Formatter.State $ CoreM
/--
Backtracking choice: run `p₁`; if it throws the internal backtrack exception (`backtrackExceptionId`), restore the
whole state saved before `p₁` ran (syntax traverser, `leadWord` and Format stack) and run `p₂` instead.
-/
@[inline] def FormatterM.orelse {α} (p₁ p₂ : FormatterM α) : FormatterM α := do
s ← get;
catchInternalId backtrackExceptionId
p₁
(fun _ => do set s; p₂)
-- Makes `p₁ <|> p₂` on `FormatterM` mean `FormatterM.orelse p₁ p₂`.
instance Formatter.orelse {α} : HasOrelse (FormatterM α) := ⟨FormatterM.orelse⟩
-- A formatter returns no value; its effects are on the traverser position and the Format stack.
abbrev Formatter := FormatterM Unit
/--
Create the `[formatter k]` / `[builtinFormatter k]` attribute, which registers a `Lean.PrettyPrinter.Formatter`
declaration for the syntax node kind `k`. The argument must be an identifier naming a valid syntax node kind; for
the builtin variant, the name of an existing declaration is accepted as well.
-/
unsafe def mkFormatterAttribute : IO (KeyedDeclsAttribute Formatter) :=
KeyedDeclsAttribute.init {
builtinName := `builtinFormatter,
name := `formatter,
descr := "Register a formatter for a parser.
[formatter k] registers a declaration of type `Lean.PrettyPrinter.Formatter` for the `SyntaxNodeKind` `k`.",
valueTypeName := `Lean.PrettyPrinter.Formatter,
evalKey := fun builtin args => do
env ← getEnv;
match attrParamSyntaxToIdentifier args with
| some id =>
-- `isValidSyntaxNodeKind` is updated only in the next stage for new `[builtin*Parser]`s, but we try to
-- synthesize a formatter for it immediately, so we just check for a declaration in this case
if (builtin && (env.find? id).isSome) || Parser.isValidSyntaxNodeKind env id then pure id
else throwError ("invalid [formatter] argument, unknown syntax kind '" ++ toString id ++ "'")
| none => throwError "invalid [formatter] argument, expected identifier"
} `Lean.PrettyPrinter.formatterAttribute
/-- The `[formatter]` attribute (initialized by `mkFormatterAttribute`), mapping syntax node kinds to formatters. -/
@[init mkFormatterAttribute] constant formatterAttribute : KeyedDeclsAttribute Formatter := arbitrary _
/--
Create the `[combinatorFormatter c]` attribute, which registers a formatter for the parser combinator declaration
`c` (see the attribute description string for details).
-/
unsafe def mkCombinatorFormatterAttribute : IO ParserCompiler.CombinatorAttribute :=
ParserCompiler.registerCombinatorAttribute
`combinatorFormatter
"Register a formatter for a parser combinator.
[combinatorFormatter c] registers a declaration of type `Lean.PrettyPrinter.Formatter` for the `Parser` declaration `c`.
Note that, unlike with [formatter], this is not a node kind since combinators usually do not introduce their own node kinds.
The tagged declaration may optionally accept parameters corresponding to (a prefix of) those of `c`, where `Parser` is replaced
with `Formatter` in the parameter types."
/-- The `[combinatorFormatter]` attribute (initialized by `mkCombinatorFormatterAttribute`). -/
@[init mkCombinatorFormatterAttribute] constant combinatorFormatterAttribute : ParserCompiler.CombinatorAttribute := arbitrary _
namespace Formatter
open Lean.Core
open Lean.Parser
/--
Abort the current formatter by throwing the internal backtrack exception. `FormatterM.orelse` (and thus `<|>` on
`FormatterM`) catches it and tries the next alternative; `format` turns an uncaught one into an error.
-/
def throwBacktrack {α} : FormatterM α :=
  throw (Exception.internal backtrackExceptionId)
-- Expose the `stxTrav` field of the formatter state as the syntax traverser behind `getCur`, `goDown`, `goLeft`, ...
-- (from `Syntax.MonadTraverser`, opened below).
instance FormatterM.monadTraverser : Syntax.MonadTraverser FormatterM := ⟨{
get := State.stxTrav <$> get,
set := fun t => modify (fun st => { st with stxTrav := t }),
modifyGet := fun _ f => modifyGet (fun st => let (a, t) := f st.stxTrav; (a, { st with stxTrav := t })) }⟩
open Syntax.MonadTraverser
/-- The Format stack accumulated so far (in reverse source order, see `State.stack`). -/
def getStack : FormatterM (Array Format) :=
  State.stack <$> get
/-- The current height of the Format stack. -/
def getStackSize : FormatterM Nat :=
  Array.size <$> getStack
/-- Replace the Format stack wholesale. -/
def setStack (stack : Array Format) : FormatterM Unit :=
  modify fun s => { s with stack := stack }
/-- Push one Format object onto the stack (does not touch `leadWord`). -/
def push (f : Format) : FormatterM Unit :=
  modify fun s => { s with stack := s.stack.push f }
/-- Push a `Format.line` and clear `leadWord`, since the text to the right now starts with whitespace. -/
def pushLine : FormatterM Unit :=
  push Format.line *> modify fun st => { st with leadWord := "" }
/--
Execute `x` at the right-most child of the current node, if any, then advance to the left.
`x` starts at the last child, because the traversal runs right-to-left. Afterwards `goUp` returns to the current
node and `goLeft` moves to its left sibling. If the node has no children, `x` is not run at all.
-/
def visitArgs (x : FormatterM Unit) : FormatterM Unit := do
stx ← getCur;
when (stx.getArgs.size > 0) $
goDown (stx.getArgs.size - 1) *> x <* goUp;
goLeft
/--
Execute `x`, pass the array of Format objects it pushed to `fn`, and push `fn`'s result in their place.
The array is in stack order, i.e. reversed relative to the source text (right-to-left traversal).
-/
def fold (fn : Array Format → Format) (x : FormatterM Unit) : FormatterM Unit := do
sp ← getStackSize;
x;
stack ← getStack;
let f := fn $ stack.extract sp stack.size;
setStack $ (stack.shrink sp).push f
/--
Execute `x` and concatenate the generated Format objects into one. Each element is prepended to the accumulator,
which undoes the stack reversal so that the result is in source order.
-/
def concat (x : FormatterM Unit) : FormatterM Unit := do
fold (Array.foldl (fun acc f => f ++ acc) Format.nil) x
/-- Run `x` over the children of the current node (see `visitArgs`) and concatenate its output. -/
def concatArgs (x : FormatterM Unit) : FormatterM Unit :=
concat (visitArgs x)
/--
Run `x`, concatenate its output into a single Format (see `concat`), and wrap that in `Format.nest`.
The indentation defaults to `Format.getIndent` of the context options; `ppDedent` passes a negative value.
-/
def indent (x : Formatter) (indent : Option Int := none) : Formatter := do
concat x;
ctx ← read;
let indent := indent.getD $ Format.getIndent ctx.options;
-- `concat` left exactly one new Format on top of the stack: replace it by its nested version
modify fun st => { st with stack := st.stack.pop.push (Format.nest indent st.stack.back) }
/-- Run `x`, concatenate its output into a single Format (see `concat`), and wrap that in `Format.group`. -/
def group (x : Formatter) : Formatter := do
concat x;
-- replace the Format just pushed by `concat` with its grouped version
modify fun st => { st with stack := st.stack.pop.push (Format.group st.stack.back) }
-- Formatter for `p1 <|> p2` (see `FormatterM.orelse`): `p1` is tried first, and `p2` only if `p1` backtracks.
@[combinatorFormatter Lean.Parser.orelse] def orelse.formatter (p1 p2 : Formatter) : Formatter :=
-- HACK: We have no (immediate) information on which side of the orelse could have produced the current node, so try
-- them in turn. Uses the syntax traverser non-linearly!
p1 <|> p2
-- `mkAntiquot` is quite complex, so we'd rather have its formatter synthesized below the actual parser definition.
-- Note that there is a mutual recursion
-- `categoryParser -> mkAntiquot -> termParser -> categoryParser`, so we need to introduce an indirection somewhere
-- anyway.
/--
Formatter for `mkAntiquot name kind anonymous`. It is implemented externally via `lean_mk_antiquot_formatter`;
the body here is only a placeholder.
-/
@[extern "lean_mk_antiquot_formatter"]
constant mkAntiquot.formatter' (name : String) (kind : Option SyntaxNodeKind) (anonymous := true) : Formatter :=
arbitrary _
/--
Run the formatter registered via `[formatter k]` / `[builtinFormatter k]` for the node kind `k`. If several are
registered, the first one returned by the attribute is used. Throws an error if none is registered.
-/
def formatterForKind (k : SyntaxNodeKind) : Formatter := do
  env ← getEnv;
  match formatterAttribute.getValues env k with
  | p :: _ => p
  | []     => throwError $ "no known formatter for kind '" ++ k ++ "'"
/--
Formatter for `withAntiquot antiP p`: try the antiquotation formatter `antiP` first, and fall back to `p` if it
backtracks. This is the same as `orelse.formatter antiP p`.
-/
@[combinatorFormatter Lean.Parser.withAntiquot]
def withAntiquot.formatter (antiP p : Formatter) : Formatter :=
  -- TODO: could be optimized using `isAntiquot` (which would have to be moved), but I'd rather
  -- fix the backtracking hack outright.
  antiP <|> p
/--
Formatter for `categoryParser cat`: formats the current node with the formatter registered for its kind (or as an
antiquotation of category `cat`), wrapped in an indented group.

For a `choice` node, every alternative is formatted in turn, each with the formatter for its own kind. Because the
traversal is right-to-left, the alternatives are visited from last to first. All alternatives are expected to
render to the same text (otherwise this panics); only one of their Format objects is kept.
-/
@[combinatorFormatter Lean.Parser.categoryParser]
def categoryParser.formatter (cat : Name) : Formatter := group $ indent do
  stx ← getCur;
  if stx.getKind == `choice then
    visitArgs do {
      sp ← getStackSize;
      -- `visitArgs` starts at the right-most child, and each child formatter moves one step to the left. So the
      -- alternatives of the `choice` node `stx` itself (not the children of its last alternative, which is what
      -- `getCur` would return here) must be dispatched in reverse order.
      stx.getArgs.reverse.forM fun alt => formatterForKind alt.getKind;
      stack ← getStack;
      when (stack.size > sp && stack.anyRange sp stack.size fun f => f.pretty != (stack.get! sp).pretty)
        panic! "Formatter.visit: inequal choice children";
      -- discard all but one child format
      setStack $ stack.extract 0 (sp+1)
    }
  else
    withAntiquot.formatter (mkAntiquot.formatter' cat.toString none) (formatterForKind stx.getKind)
/--
Formatter for `categoryParserOfStack offset`. The category name is read from the sibling `offset` positions to the
left of the current node (the index subtraction is on `Nat`, so it saturates at 0), and `categoryParser.formatter`
is run with that name.
-/
@[combinatorFormatter Lean.Parser.categoryParserOfStack]
def categoryParserOfStack.formatter (offset : Nat) : Formatter := do
st ← get;
let stx := st.stxTrav.parents.back.getArg (st.stxTrav.idxs.back - offset);
categoryParser.formatter stx.getId
-- `try p` is formatted exactly like `p`.
@[combinatorFormatter Lean.Parser.try]
def try.formatter (p : Formatter) : Formatter := p
-- `lookahead` and `notFollowedBy` emit nothing and do not move the traverser; the wrapped formatter is ignored.
@[combinatorFormatter Lean.Parser.lookahead]
def lookahead.formatter (p : Formatter) : Formatter := pure ()
@[combinatorFormatter Lean.Parser.notFollowedBy]
def notFollowedBy.formatter (p : Formatter) : Formatter := pure ()
-- Sequencing: because the traversal runs right-to-left, the second component is formatted before the first.
@[combinatorFormatter Lean.Parser.andthen]
def andthen.formatter (p1 p2 : Formatter) : Formatter := do
  p2;
  p1
/--
Backtrack (see `throwBacktrack`) unless the current node has kind `k`. The mismatch is traced under
`PrettyPrinter.format.backtrack`.
-/
def checkKind (k : SyntaxNodeKind) : FormatterM Unit := do
stx ← getCur;
when (k != stx.getKind) $ do {
trace! `PrettyPrinter.format.backtrack ("unexpected node kind '" ++ toString stx.getKind ++ "', expected '" ++ toString k ++ "'");
throwBacktrack
}
/-- Formatter for `node k p`: check that the node has kind `k`, then run `p` on its children and concatenate. -/
@[combinatorFormatter Lean.Parser.node]
def node.formatter (k : SyntaxNodeKind) (p : Formatter) : Formatter := do
checkKind k;
concatArgs p
/--
Formatter for `trailingNode k _ p`. The node's children are a leading term followed by what `p` produced, so
(right-to-left) `p` runs first and the remaining left-most child is then formatted as a category term.
-/
@[combinatorFormatter Lean.Parser.trailingNode]
def trailingNode.formatter (k : SyntaxNodeKind) (_ : Nat) (p : Formatter) : Formatter := do
checkKind k;
concatArgs do
p;
-- leading term, not actually produced by `p`
-- NOTE(review): `foo` is a dummy category; `categoryParser.formatter` only uses the name for its antiquotation
-- formatter
categoryParser.formatter `foo
/--
Run `Parser.tokenFn` on the string `s` alone, using the current environment and the context's token table, and
return the resulting parser state. Used to check how emitted text would be tokenized.
-/
def parseToken (s : String) : FormatterM ParserState := do
ctx ← read;
env ← getEnv;
pure $ Parser.tokenFn { input := s, fileName := "", fileMap := FileMap.ofString "", prec := 0, env := env, tokens := ctx.table } (Parser.mkParserState s)
/--
Push the token text `tk`. If `tk` ends with a space character, its trailing whitespace is replaced by a
`Format.line`, pushed first because it sits to the right of the token. `pushLine` also clears `leadWord`.
-/
def pushTokenCore (tk : String) : FormatterM Unit :=
-- only ' ' is tested here, while `trimRight` below strips any trailing whitespace
if tk.toSubstring.dropRightWhile (fun s => s == ' ') == tk.toSubstring then
push tk
else do
pushLine;
push tk.trimRight
/--
Push token `tk` (which may carry leading/trailing spaces) and update `leadWord`.

`tk` has no trailing whitespace and the text to its right starts with a non-empty `leadWord`, so the two might fuse
into a different token sequence. To detect this, `tk` is tokenized alone and followed by `leadWord`. If the
tokenizer stops at the same position in both cases, `tk` is emitted as is; otherwise a separator is added.
Afterwards `leadWord` becomes:
* `tk ++ leadWord` when the two are adjacent;
* `tk` when they are separated;
* empty when `tk` starts with whitespace.
-/
def pushToken (tk : String) : FormatterM Unit := do
st ← get;
-- If there is no space between `tk` and the next word, compare parsing `tk` with and without the next word
if st.leadWord != "" && tk.trimRight == tk then do
t1 ← parseToken tk.trimLeft;
t2 ← parseToken $ tk.trimLeft ++ st.leadWord;
if t1.pos == t2.pos then do
-- same result => use `tk` as is, extend `leadWord` if not prefixed by whitespace
pushTokenCore tk;
modify fun st => { st with leadWord := if tk.trimLeft == tk then tk ++ st.leadWord else "" }
else do
-- different result => add space (`pushTokenCore` turns the trailing " " into a `Format.line`)
pushTokenCore $ tk ++ " ";
modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" }
else do {
-- already separated => use `tk` as is
-- NOTE(review): if `tk` itself has trailing whitespace, `leadWord` keeps it here, although `State.leadWord` is
-- documented as the text up to the first whitespace
pushTokenCore tk;
modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" }
}
/--
Formatter for `symbol sym`: if the current node is the token `sym` (per `Syntax.isToken`), emit it and move left.
Otherwise backtrack, so that an enclosing `orelse.formatter` can try another alternative.
-/
@[combinatorFormatter symbol]
def symbol.formatter (sym : String) : Formatter := do
stx ← getCur;
if stx.isToken sym then do
pushToken sym;
goLeft
else do
trace! `PrettyPrinter.format.backtrack ("unexpected syntax '" ++ stx ++ "', expected symbol '" ++ sym ++ "'");
throwBacktrack
-- These symbol variants are formatted exactly like `symbol`.
@[combinatorFormatter symbolNoWs] def symbolNoWs.formatter := symbol.formatter
@[combinatorFormatter nonReservedSymbol] def nonReservedSymbol.formatter := symbol.formatter
/--
Formatter for `unicodeSymbol sym asciiSym`: emit whichever of the two spellings the current atom uses, then move
left. Like `symbol.formatter`, it backtracks when the atom matches neither spelling, instead of emitting a token
that is not in the syntax tree. This lets an enclosing `orelse.formatter` try another alternative.
Throws a (non-backtracking) error if the current node is not an atom.
-/
@[combinatorFormatter unicodeSymbol]
def unicodeSymbol.formatter (sym asciiSym : String) : Formatter := do
  stx ← getCur;
  Syntax.atom _ val ← pure stx
    | throwError $ "not an atom: " ++ toString stx;
  if val == sym.trim then
    pushToken sym
  else if val == asciiSym.trim then
    pushToken asciiSym
  else do {
    trace! `PrettyPrinter.format.backtrack ("unexpected syntax '" ++ stx ++ "', expected symbol '" ++ sym ++ "' or '" ++ asciiSym ++ "'");
    throwBacktrack
  };
  goLeft
/--
Formatter for identifiers: emit the identifier (with macro scopes simplified via `simpMacroScopes`) and move left.
* Anonymous names are printed as `[anonymous]`.
* Inaccessible names, names with numeric components, and loose bound variables (printed with a `#` prefix) are
  printed as-is, since they cannot be parsed back anyway.
* Otherwise the printed name is re-tokenized. If that does not yield the identifier back, every component of the
  same simplified name is escaped with `«»`.
-/
@[combinatorFormatter identNoAntiquot]
def identNoAntiquot.formatter : Formatter := do
  checkKind identKind;
  stx ← getCur;
  let id := stx.getId;
  let id := id.simpMacroScopes;
  let s := id.toString;
  if id.isAnonymous then
    pushToken "[anonymous]"
  else if isInaccessibleUserName id || id.components.any Name.isNum ||
    -- loose bvar
    "#".isPrefixOf s then
    -- not parsable anyway, output as-is
    pushToken s
  else do {
    -- try to parse `s` as-is; if it fails, escape
    pst ← parseToken s;
    if pst.stxStack == #[stx] then
      pushToken s
    else
      -- escape the simplified `id` that `s` was printed from (and that the checks above inspected), not the raw
      -- `stx.getId`, which may still contain unsimplified macro scopes
      let n := id;
      -- TODO: do something better than escaping all parts
      let n := (n.components.map fun c => "«" ++ toString c ++ "»").foldl mkNameStr Name.anonymous;
      pushToken n.toString
  };
  goLeft
-- Formatter for `rawIdent`: emit the identifier's name verbatim (no escaping, no macro-scope simplification).
@[combinatorFormatter rawIdent] def rawIdent.formatter : Formatter := do
  checkKind identKind;
  cur ← getCur;
  let n := cur.getId;
  pushToken n.toString;
  goLeft
-- `identEq` is formatted like a raw identifier.
@[combinatorFormatter Lean.Parser.identEq] def identEq.formatter := rawIdent.formatter
/--
Emit the atom of a literal-like node and move left.
If `k` is not `Name.anonymous`, the current node must have kind `k` (otherwise backtrack).
If the current syntax is a node, its first child is used; otherwise the syntax itself is used. Either way it must
be an atom, or an error is thrown.
-/
def visitAtom (k : SyntaxNodeKind) : Formatter := do
stx ← getCur;
when (k != Name.anonymous) $
checkKind k;
Syntax.atom _ val ← pure $ stx.ifNode (fun n => n.getArg 0) (fun _ => stx)
| throwError $ "not an atom: " ++ toString stx;
pushToken val;
goLeft
-- Literal formatters: each checks its literal node kind and emits the contained atom.
@[combinatorFormatter charLitNoAntiquot] def charLitNoAntiquot.formatter := visitAtom charLitKind
@[combinatorFormatter strLitNoAntiquot] def strLitNoAntiquot.formatter := visitAtom strLitKind
@[combinatorFormatter nameLitNoAntiquot] def nameLitNoAntiquot.formatter := visitAtom nameLitKind
@[combinatorFormatter numLitNoAntiquot] def numLitNoAntiquot.formatter := visitAtom numLitKind
@[combinatorFormatter fieldIdx] def fieldIdx.formatter := visitAtom fieldIdxKind
/-- Formatter for `many p`: run `p` once per child of the current node (right-to-left) and concatenate. -/
@[combinatorFormatter many]
def many.formatter (p : Formatter) : Formatter := do
stx ← getCur;
concatArgs $ stx.getArgs.size.forM fun _ => p
-- `many1 p` is formatted exactly like `many p`.
@[combinatorFormatter many1] def many1.formatter (p : Formatter) : Formatter :=
many.formatter p
/-- Formatter for `optional p`: run `p` starting at the right-most child of the current node (if it has any children) and concatenate. -/
@[combinatorFormatter Parser.optional]
def optional.formatter (p : Formatter) : Formatter := do
concatArgs p
/--
Formatter for `many1Unbox p`: a node of kind `nullKind` is formatted like `many p`; any other node is formatted
directly with `p`.
-/
@[combinatorFormatter Parser.many1Unbox]
def many1Unbox.formatter (p : Formatter) : Formatter := do
stx ← getCur;
if stx.getKind == nullKind then do
many.formatter p
else
p
/--
Formatter for `sepBy p pSep`: children at even indices are formatted with `p` and those at odd indices with the
separator formatter `pSep`, from the last index down to 0.
-/
@[combinatorFormatter sepBy]
def sepBy.formatter (p pSep : Formatter) : Formatter := do
stx ← getCur;
concatArgs $ (List.range stx.getArgs.size).reverse.forM $ fun i => if i % 2 == 0 then p else pSep
-- `sepBy1` is formatted exactly like `sepBy`.
@[combinatorFormatter sepBy1] def sepBy1.formatter := sepBy.formatter
-- The following formatters ignore their extra arguments and simply run the wrapped formatter `p`.
@[combinatorFormatter Lean.Parser.withPosition] def withPosition.formatter (p : Formatter) : Formatter := p
@[combinatorFormatter Lean.Parser.withoutPosition] def withoutPosition.formatter (p : Formatter) : Formatter := p
@[combinatorFormatter Lean.Parser.withForbidden] def withForbidden.formatter (tk : Token) (p : Formatter) : Formatter := p
@[combinatorFormatter Lean.Parser.withoutForbidden] def withoutForbidden.formatter (p : Formatter) : Formatter := p
@[combinatorFormatter Lean.Parser.setExpected]
def setExpected.formatter (expected : List String) (p : Formatter) : Formatter := p
@[combinatorFormatter Lean.Parser.toggleInsideQuot]
def toggleInsideQuot.formatter (p : Formatter) : Formatter := p
-- `checkWsBefore` emits a line break, so the token to its left is separated by whitespace.
@[combinatorFormatter checkWsBefore] def checkWsBefore.formatter : Formatter := pushLine
-- The following checks produce no output and do not move the traverser.
@[combinatorFormatter checkPrec] def checkPrec.formatter : Formatter := pure ()
@[combinatorFormatter checkStackTop] def checkStackTop.formatter : Formatter := pure ()
@[combinatorFormatter checkNoWsBefore] def checkNoWsBefore.formatter : Formatter := pure ()
@[combinatorFormatter checkTailWs] def checkTailWs.formatter : Formatter := pure ()
@[combinatorFormatter checkColGe] def checkColGe.formatter : Formatter := pure ()
@[combinatorFormatter checkColGt] def checkColGt.formatter : Formatter := pure ()
@[combinatorFormatter eoi] def eoi.formatter : Formatter := pure ()
@[combinatorFormatter notFollowedByCategoryToken] def notFollowedByCategoryToken.formatter : Formatter := pure ()
@[combinatorFormatter checkNoImmediateColon] def checkNoImmediateColon.formatter : Formatter := pure ()
@[combinatorFormatter Lean.Parser.checkInsideQuot] def checkInsideQuot.formatter : Formatter := pure ()
@[combinatorFormatter Lean.Parser.checkOutsideQuot] def checkOutsideQuot.formatter : Formatter := pure ()
@[combinatorFormatter Lean.Parser.skip] def skip.formatter : Formatter := pure ()
-- Pretty-printing hints.
-- Whitespace pushed here ends up directly to the left of the text emitted so far. That text now starts with
-- whitespace, so `leadWord` (the text up to the first whitespace) must be cleared, as `pushLine` does. Otherwise
-- `pushToken` would still treat the next token as adjacent to the old word and could add a redundant space.
@[combinatorFormatter Lean.Parser.ppHardSpace] def ppHardSpace.formatter : Formatter := do
  push " ";
  modify fun st => { st with leadWord := "" }
@[combinatorFormatter Lean.Parser.ppSpace] def ppSpace.formatter : Formatter := pushLine
@[combinatorFormatter Lean.Parser.ppLine] def ppLine.formatter : Formatter := do
  push "\n";
  modify fun st => { st with leadWord := "" }
-- `ppGroup p`: format `p` as an indented group.
@[combinatorFormatter Lean.Parser.ppGroup] def ppGroup.formatter (p : Formatter) : Formatter := group $ indent p
-- `ppDedent p`: format `p` nested by the negated default indentation.
@[combinatorFormatter Lean.Parser.ppDedent] def ppDedent.formatter (p : Formatter) : Formatter := do
  opts ← getOptions;
  indent p (some (-(Format.getIndent opts)))
-- `pushNone`: skip the current node without emitting anything.
@[combinatorFormatter pushNone] def pushNone.formatter : Formatter := goLeft
-- TODO: delete with old frontend
/--
Formatter for `quotedSymbol`: a node of kind `quotedSymbolKind` whose children are visited right-to-left as
closing backtick, atom, opening backtick. The backticks are emitted with a plain `push`, which does not update
`leadWord`.
-/
@[combinatorFormatter quotedSymbol] def quotedSymbol.formatter : Formatter := do
checkKind quotedSymbolKind;
concatArgs do
push "`"; goLeft;
visitAtom Name.anonymous;
push "`"; goLeft
-- `unquotedSymbol`: emit the current atom without any node-kind check.
@[combinatorFormatter unquotedSymbol] def unquotedSymbol.formatter := visitAtom Name.anonymous
/-- Formatter for `ite`: run `t` if the decidable condition `c` holds, and `e` otherwise. -/
@[combinatorFormatter ite, macroInline] def ite {α : Type} (c : Prop) [h : Decidable c] (t e : Formatter) : Formatter :=
if c then t else e
end Formatter
open Formatter
/--
Run `formatter` on `stx`, starting at its root, and return the produced Format wrapped in a group.

A formatter normally leaves exactly one Format object on the stack. Still, the whole stack is concatenated,
reversing it because it was produced by a right-to-left traversal. This way no output is silently dropped when
several objects remain, and there is no out-of-bounds panic when none remain.
An uncaught backtrack exception is reported as an error.
-/
def format (formatter : Formatter) (stx : Syntax) : CoreM Format := do
  options ← getOptions;
  table ← Parser.builtinTokenTable.get;
  catchInternalId backtrackExceptionId
    (do
      (_, st) ← (formatter { table := table, options := options }).run { stxTrav := Syntax.Traverser.fromSyntax stx };
      pure $ Format.group $ Array.foldl (fun acc f => f ++ acc) Format.nil st.stack)
    (fun _ => throwError "format: uncaught backtrack exception")
/-- Format a syntax tree of category `term`. -/
def formatTerm := format (categoryParser.formatter `term)
/-- Format a syntax tree of category `command`. -/
def formatCommand := format (categoryParser.formatter `command)
-- Register the `PrettyPrinter.format` trace class at initialization.
@[init] private def regTraceClasses : IO Unit :=
  registerTraceClass `PrettyPrinter.format
end PrettyPrinter
end Lean