lean4-htt/src/Lean/PrettyPrinter/Formatter.lean
2020-08-06 09:27:12 -07:00

347 lines
13 KiB
Text

/-
Copyright (c) 2020 Sebastian Ullrich. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sebastian Ullrich
-/
/-!
The formatter turns a `Syntax` tree into a `Format` object, inserting both mandatory whitespace (to separate adjacent
tokens) as well as "pretty" optional whitespace.
The basic approach works much like the parenthesizer: A right-to-left traversal over the syntax tree and the parser that
produced it, driven by parser-specific handlers registered via an attribute. The traversal is right-to-left so that when
emitting a token, we already know the text following it and can decide whether or not whitespace between the two is
necessary.
-/
import Lean.Parser
import Lean.Meta
import Lean.Elab.Quotation
namespace Lean
namespace PrettyPrinter
namespace Formatter
/-- Read-only configuration of the formatter. -/
structure Context :=
-- Token table used by `parseToken` to decide whether adjacent tokens need separating whitespace.
(table : Parser.TokenTable)
/-- State threaded through `FormatterM` (via `StateT`). -/
structure State :=
-- Cursor into the `Syntax` tree being formatted; the formatters move it right-to-left.
(stxTrav : Syntax.Traverser)
-- Textual content of `stack` up to the first whitespace (not enclosed in an escaped ident). We assume that the textual
-- content of `stack` is modified only by `pushToken` (and reset by `checkWsBefore.formatter`), so `leadWord` is
-- adjusted there accordingly. Direct `push` calls elsewhere (e.g. in `quotedSymbol.formatter`) do not update it.
(leadWord : String := "")
-- Stack of generated Format objects, analogous to the Syntax stack in the parser.
-- Note, however, that the stack is reversed because of the right-to-left traversal.
(stack : Array Format := #[])
end Formatter
/-- The formatter monad: reads a `Formatter.Context`, threads a `Formatter.State`, and runs in `MetaM`. -/
abbrev FormatterM := ReaderT Formatter.Context $ StateT Formatter.State MetaM
/--
A formatter receives the `Parser` expression for the current syntax position and pushes the resulting `Format`
objects onto the formatter stack. -/
abbrev Formatter := Expr → FormatterM Unit
/--
Create the `[formatter]` attribute (and its `[builtinFormatter]` variant). The attribute argument must be an
identifier naming a declaration that exists in the environment, normally the `Parser` handled by the annotated
formatter. The annotated declaration must have type `Lean.PrettyPrinter.Formatter`. -/
unsafe def mkFormatterAttribute : IO (KeyedDeclsAttribute Formatter) :=
KeyedDeclsAttribute.init {
builtinName := `builtinFormatter,
name := `formatter,
descr := "Register a formatter.
[formatter c] registers a declaration of type `Lean.PrettyPrinter.Formatter` for the `Parser` declaration `c`.",
valueTypeName := `Lean.PrettyPrinter.Formatter,
-- resolve the attribute argument to the name of an existing declaration; `visit` looks formatters up by this name
evalKey := fun env args => match attrParamSyntaxToIdentifier args with
| some id => match env.find? id with
| some _ => pure id
| none => throw ("invalid [formatter] argument, unknown identifier '" ++ toString id ++ "'")
| none => throw "invalid [formatter] argument, expected identifier"
} `Lean.PrettyPrinter.formatterAttribute
/--
The `[formatter]` attribute. Its value is produced by `mkFormatterAttribute` at initialization time (`@[init]`);
the `arbitrary _` body is only a placeholder. -/
@[init mkFormatterAttribute] constant formatterAttribute : KeyedDeclsAttribute Formatter := arbitrary _
namespace Formatter
open Lean.Meta
open Lean.Format
-- Expose the `stxTrav` field of the formatter state as the syntax traverser, so that `getCur`, `goLeft`, `goDown`,
-- `goUp`, ... operate on it inside `FormatterM`.
instance FormatterM.monadTraverser : Syntax.MonadTraverser FormatterM := ⟨{
get := State.stxTrav <$> get,
set := fun t => modify (fun st => { st with stxTrav := t }),
modifyGet := fun _ f => modifyGet (fun st => let (a, t) := f st.stxTrav; (a, { st with stxTrav := t })) }⟩
open Syntax.MonadTraverser
/-- The current stack of generated `Format` objects (right-most text first, because of the right-to-left traversal). -/
def getStack : FormatterM (Array Format) :=
State.stack <$> get
/-- Number of `Format` objects currently on the stack. -/
def getStackSize : FormatterM Nat :=
(fun (formats : Array Format) => formats.size) <$> getStack
/-- Replace the whole `Format` stack with `stack`. -/
def setStack (stack : Array Format) : FormatterM Unit := do
st ← get;
set { st with stack := stack }
/-- Push `f` on top of the `Format` stack. -/
def push (f : Format) : FormatterM Unit := do
st ← get;
set { st with stack := st.stack.push f }
/--
Run `x` with the traverser at the right-most child of the current node (skipped when the node has no children),
return to the node, and finally step to the left of it. -/
def visitArgs (x : FormatterM Unit) : FormatterM Unit := do
stx ← getCur;
let numArgs := stx.getArgs.size;
when (numArgs > 0) $ do {
  goDown (numArgs - 1);
  x;
  goUp
};
goLeft
/--
Run `x`, then replace all `Format` objects it pushed by the single value `fn` computes from them. `fn` receives the
objects in stack order, i.e. right-most text first. -/
def fold (fn : Array Format → Format) (x : FormatterM Unit) : FormatterM Unit := do
base ← getStackSize;
x;
stack ← getStack;
let produced := stack.extract base stack.size;
setStack $ (stack.shrink base).push (fn produced)
/--
Run `x` and merge everything it produced into one `Format`, in textual (left-to-right) order. The stack is reversed,
so each later entry is prepended. -/
def concat (x : FormatterM Unit) : FormatterM Unit :=
fold (fun fs => fs.foldl (fun acc f => f ++ acc) Format.nil) x
/-- Format the children of the current node with `x` (see `visitArgs`) and concatenate the results. -/
def concatArgs (x : FormatterM Unit) : FormatterM Unit :=
concat $ visitArgs x
/--
  Call an appropriate `[formatter]` depending on the `Parser` `Expr` `p`. After the call, the traverser position
  should be to the left of all nodes produced by `p`, or at the left-most child if there are no other nodes left.

  Lookup order, after reducing `p` with `whnfCore`:
  1. a formatter registered for the head constant of `p`;
  2. the `choice` pseudo-parser;
  3. otherwise `p` is unfolded one step and visited again.

  Throws if no formatter is found and `p` cannot be unfolded. -/
partial def visit : Formatter | p => do
stx ← getCur;
-- do reductions _except_ for definition unfolding
p ← liftM $ whnfCore p;
trace! `PrettyPrinter.format ("formatting" ++ MessageData.nest 2 (line ++ stx) ++ line ++ "using" ++ MessageData.nest 2 (line ++ p));
sp ← getStackSize;
let c := Expr.constName? p.getAppFn;
env ← liftM getEnv;
match c >>= (formatterAttribute.ext.getState env).table.find? with
| some (f::_) => do
  -- call first matching formatter
  f p
| _ =>
  -- `choice` is not an actual parser, so special-case it here
  if c == some `choice then do
    -- `stx` is still the `choice` node itself: nothing above moved the traverser. Each of its children is an
    -- alternative parse of the same input. Visit all of them right-to-left, each with the formatter for *its own*
    -- kind. The traverser moves one child to the left after every `visit`, so `getCur` yields the next alternative.
    visitArgs do {
      sp ← getStackSize;
      stx.getArgs.size.forM $ fun _ => do {
        alt ← getCur;
        visit (mkConst alt.getKind)
      };
      stack ← getStack;
      -- all alternatives must render identically, otherwise there is no canonical output
      when (stack.size > sp && stack.anyRange sp stack.size fun f => pretty f != pretty (stack.get! sp))
        panic! "Formatter.visit: inequal choice children";
      -- discard all but one child format
      setStack $ stack.extract 0 (sp+1)
    }
  else do {
    -- (try to) unfold definition and recurse
    some p' ← liftM $ unfoldDefinition? p
      | throw $ Exception.other $ "no known formatter for '" ++ toString p ++ "'";
    visit p'
  };
stack ← getStack;
trace! `PrettyPrinter.format (" => " ++ (stack.extract sp stack.size).foldl (fun acc f => repr (toString f) ++ " " ++ acc) "")
open Lean.Parser
/--
Format the current node with the `mkAntiquot` formatter if it is an antiquotation. Otherwise throw, so that callers
can fall back to another formatter via `<|>`. -/
def visitAntiquot : Formatter | _ => do
stx ← getCur;
if Elab.Term.Quotation.isAntiquot stx then
-- NOTE(review): the two `mkNatLit 0` arguments look like placeholders; confirm the `mkAntiquot` formatter never
-- inspects them
visit (mkAppN (mkConst `Lean.Parser.mkAntiquot) #[mkNatLit 0, mkNatLit 0])
else
throw $ Exception.other $ "not an antiquotation"
/--
Formatter for `categoryParser`. Format an antiquotation if there is one. Otherwise dispatch `visit` on the constant
named by the current node's kind. -/
@[builtinFormatter categoryParser]
def categoryParser.formatter : Formatter | p => visitAntiquot p <|> do
stx ← getCur;
visit (mkConst stx.getKind)
/--
Formatter for `termParser`. Same as `categoryParser.formatter`, except that on a `nullKind` node it throws
`BACKTRACK`. This lets an enclosing `orelse` try its other alternative. -/
@[builtinFormatter termParser]
def termParser.formatter : Formatter | p => do
stx ← getCur;
-- this can happen at `termParser <|> many1 commandParser` in `Term.stxQuot`
if stx.getKind == nullKind then
throw $ Exception.other "BACKTRACK"
else
categoryParser.formatter p
/--
Formatter for `withAntiquot`. Try the antiquotation formatter first. Otherwise format with argument 1 of the
application (presumably the regular, non-antiquotation parser). -/
@[builtinFormatter withAntiquot]
def withAntiquot.formatter : Formatter | p =>
visitAntiquot p <|> visit (p.getArg! 1)
/-- Formatter for `try p`: simply format the wrapped parser (the last argument of the application). -/
@[builtinFormatter try]
def try.formatter : Formatter | p =>
visit p.appArg!
/--
Formatter for `andthen p q`. The traversal is right-to-left, so `q` (argument 1) is formatted before `p`
(argument 0). -/
@[builtinFormatter andthen]
def andthen.formatter : Formatter | p =>
visit (p.getArg! 1) *> visit (p.getArg! 0)
/--
Check that the current node has kind `k`. Otherwise trace the mismatch and throw the `BACKTRACK` exception, which
`orelse.formatter` catches in order to try its other alternative. -/
def checkKind (k : SyntaxNodeKind) : FormatterM Unit := do
stx ← getCur;
when (k != stx.getKind) $ do {
trace! `PrettyPrinter.format.backtrack ("unexpected node kind '" ++ toString stx.getKind ++ "', expected '" ++ toString k ++ "'");
-- HACK; see `orelse.formatter`
throw $ Exception.other "BACKTRACK"
}
/--
Formatter for `node k p`:
1. evaluate the kind `k` (argument 0) and check it against the current node;
2. format the node's children right-to-left with the parser argument;
3. concatenate the results. -/
@[builtinFormatter node]
def node.formatter : Formatter | p => do
k ← liftM $ reduceEval $ p.getArg! 0;
checkKind k;
concatArgs $ visit p.appArg!
/--
Formatter for `trailingNode k p`. Works like `node.formatter`, except that after `p`'s part of the node has been
formatted, the remaining left-most child is formatted via `categoryParser.formatter`. That child is the leading term,
which `p` itself did not produce. -/
@[builtinFormatter trailingNode]
def trailingNode.formatter : Formatter | p => do
k ← liftM $ reduceEval $ p.getArg! 0;
checkKind k;
concatArgs do
visit p.appArg!;
-- leading term, not actually produced by `p`
categoryParser.formatter p
/--
Run the token parser on the string `s` in isolation, using the current environment and token table. Returns the final
parser state. `pushToken` compares the resulting positions to detect tokens that would fuse; `identNoAntiquot`
compares the resulting syntax stack to decide whether to escape an identifier. -/
def parseToken (s : String) : FormatterM ParserState := do
ctx ← read;
env ← liftM getEnv;
pure $ Parser.tokenFn { input := s, fileName := "", fileMap := FileMap.ofString "", prec := 0, env := env, tokens := ctx.table } (Parser.mkParserState s)
/--
Push the token text `tk`, appending a separating space when needed, and maintain `leadWord`.
Because formatting runs right-to-left, `leadWord` holds the beginning of the text that will follow `tk`. Suppose `tk`
would directly abut that text: it has no trailing whitespace and `leadWord` is non-empty. Then parsing `tk` alone is
compared with parsing `tk ++ leadWord`. If the end position differs, the two would fuse into a different token, so a
space is appended. -/
def pushToken (tk : String) : FormatterM Unit := do
st ← get;
-- If there is no space between `tk` and the next word, compare parsing `tk` with and without the next word
if st.leadWord != "" && tk.trimRight == tk then do
t1 ← parseToken tk.trimLeft;
t2 ← parseToken $ tk.trimLeft ++ st.leadWord;
if t1.pos == t2.pos then do
-- same result => use `tk` as is, extend `leadWord` if not prefixed by whitespace
modify fun st => { st with leadWord := if tk.trimLeft == tk then tk ++ st.leadWord else "" };
push tk
else do
-- different result => add space
modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" };
push $ tk ++ " "
else do {
-- already separated => use `tk` as is
-- NOTE(review): if `tk` has trailing whitespace, `leadWord` becomes `tk` *including* that whitespace rather than only
-- the text up to the first whitespace, as documented on `State.leadWord`
modify fun st => { st with leadWord := if tk.trimLeft == tk then tk else "" };
push tk
}
/--
Formatter for `symbol sym`: evaluate the symbol string (argument 0), push it as a token, and move left past the
atom. -/
@[builtinFormatter symbol]
def symbol.formatter : Formatter | p => do
let sym := p.getArg! 0;
sym ← liftM $ reduceEval sym;
pushToken sym;
goLeft
-- The following symbol variants are formatted exactly like `symbol`: argument 0 of the application is evaluated and
-- printed as the token (for `unicodeSymbol`, presumably the Unicode spelling).
@[builtinFormatter symbolNoWs] def symbolNoWs.formatter := symbol.formatter
@[builtinFormatter unicodeSymbol] def unicodeSymbol.formatter := symbol.formatter
@[builtinFormatter nonReservedSymbol] def nonReservedSymbol.formatter := symbol.formatter
/--
Formatter for identifiers. Prints the identifier's name as a token. If re-parsing that text does not yield exactly the
current syntax, the name is escaped as `«name»`. Escaping is only implemented for atomic names; other names panic. -/
@[builtinFormatter identNoAntiquot]
def identNoAntiquot.formatter : Formatter | _ => do
checkKind identKind;
stx ← getCur;
let s := stx.getId.toString;
-- try to parse `s` as-is; if it fails, escape
pst ← parseToken s;
let s := if pst.stxStack == #[stx] then s else match stx.getId with
| Name.str Name.anonymous s _ => "«" ++ s ++ "»"
| _ => panic! "unimplemented: escaping non-atomic identifiers (is anyone even using those?)";
pushToken s;
goLeft
/--
Formatter for `rawIdent`: print the identifier's name as a token. Unlike `identNoAntiquot.formatter`, it never
escapes the name. -/
@[builtinFormatter rawIdent] def rawIdent.formatter : Formatter | _ => do
checkKind identKind;
stx ← getCur;
pushToken stx.getId.toString;
goLeft
/--
Print an atom verbatim as a token and move left. The atom is the current syntax itself or, if that is a node, the
node's first child. The node kind is checked against `k` unless `k` is anonymous. Throws if no atom is found. -/
def visitAtom (k : SyntaxNodeKind) : Formatter | p => do
stx ← getCur;
when (k != Name.anonymous) $
checkKind k;
Syntax.atom _ val ← pure $ stx.ifNode (fun n => n.getArg 0) (fun _ => stx)
| throw $ Exception.other $ "not an atom: " ++ toString stx;
pushToken val;
goLeft
-- Literal formatters: check the literal's node kind and print its atom verbatim (see `visitAtom`).
@[builtinFormatter charLitNoAntiquot] def charLitNoAntiquot.formatter := visitAtom charLitKind
@[builtinFormatter strLitNoAntiquot] def strLitNoAntiquot.formatter := visitAtom strLitKind
@[builtinFormatter nameLitNoAntiquot] def nameLitNoAntiquot.formatter := visitAtom nameLitKind
@[builtinFormatter numLitNoAntiquot] def numLitNoAntiquot.formatter := visitAtom numLitKind
@[builtinFormatter fieldIdx] def fieldIdx.formatter := visitAtom fieldIdxKind
/--
Formatter for `many p`. The current node has one child per repetition; each is formatted with `p` (argument 0), and
the results are concatenated. -/
@[builtinFormatter many]
def many.formatter : Formatter | p => do
stx ← getCur;
concatArgs $ stx.getArgs.size.forM $ fun _ => visit (p.getArg! 0)
/--
Formatter for `many1 p`. Same as `many.formatter` when the current node is a null node. Otherwise the single
repetition was not wrapped in a null node, so it is formatted directly with `p`. -/
@[builtinFormatter many1] def many1.formatter : Formatter | p => do
stx ← getCur;
if stx.getKind == nullKind then do
many.formatter p
else
-- can happen with `unboxSingleton = true`
visit (p.getArg! 0)
/--
Formatter for `optional p`. If the current node has a child, it is formatted once with `p` (argument 0). Otherwise
nothing is printed, because `visitArgs` skips empty nodes. -/
@[builtinFormatter Parser.optional]
def optional.formatter : Formatter | p => do
concatArgs $ visit (p.getArg! 0)
/--
Formatter for `sepBy p sep`. The children alternate element/separator starting with an element at index 0. They are
visited right-to-left: even indices with `p` (argument 0), odd indices with `sep` (argument 1). -/
@[builtinFormatter sepBy]
def sepBy.formatter : Formatter | p => do
stx ← getCur;
concatArgs $ (List.range stx.getArgs.size).reverse.forM $ fun i => visit (p.getArg! (i % 2))
-- `sepBy1` produces the same syntax shape as `sepBy`, so it is formatted identically.
@[builtinFormatter sepBy1] def sepBy1.formatter := sepBy.formatter
/--
Formatter for `p <|> q`. Try `p` (argument 0) first. If it throws `BACKTRACK` (see `checkKind`), restore the complete
formatter state, including the traverser and stack, and try `q` (argument 1). Other exceptions are rethrown. -/
@[builtinFormatter orelse] def orelse.formatter : Formatter | p => do
st ← get;
-- HACK: We have no (immediate) information on which side of the orelse could have produced the current node, so try
-- them in turn. Uses the syntax traverser non-linearly!
catch (visit (p.getArg! 0)) $ fun e => match e with
| Exception.other "BACKTRACK" => set st *> visit (p.getArg! 1)
| _ => throw e
/--
Formatter for `withPosition f`. `f` (argument 0) is applied to a dummy `sorryAx` constant in place of the parse
position, and the resulting parser is formatted. This presumably relies on the formatter never inspecting that
position. -/
@[builtinFormatter withPosition] def withPosition.formatter : Formatter | p => do
-- call closure with dummy position
visit $ mkApp (p.getArg! 0) (mkConst `sorryAx [levelZero])
/--
Formatter for `checkWsBefore`: emit a space, and clear `leadWord` because the text to the right now starts with
whitespace. The traverser is not moved. -/
@[builtinFormatter checkWsBefore] def checkWsBefore.formatter : Formatter | p => do
modify fun st => { st with leadWord := "" };
push " "
-- Parser checks that neither print anything nor move the traverser: formatting them is a no-op.
-- NOTE(review): `checkNoWsBefore` is ignored here, so `pushToken` may still insert a separating space where the
-- parser forbids whitespace; confirm this cannot occur in practice.
@[builtinFormatter checkPrec] def checkPrec.formatter : Formatter | p => pure ()
@[builtinFormatter checkStackTop] def checkStackTop.formatter : Formatter | p => pure ()
@[builtinFormatter checkNoWsBefore] def checkNoWsBefore.formatter : Formatter | p => pure ()
@[builtinFormatter checkTailWs] def checkTailWs.formatter : Formatter | p => pure ()
@[builtinFormatter checkColGe] def checkColGe.formatter : Formatter | p => pure ()
-- `pushNone` contributes a syntax node with no text: skip over it without printing anything.
@[builtinFormatter pushNone] def pushNone.formatter : Formatter | p => goLeft
open Lean.Parser.Command
-- The comment body is printed verbatim as an atom, without a node-kind check.
@[builtinFormatter commentBody] def commentBody.formatter := visitAtom Name.anonymous
-- TODO: delete with old frontend
/-- Formatter for quoted symbols: prints the symbol atom surrounded by backticks. -/
@[builtinFormatter quotedSymbol] def quotedSymbol.formatter : Formatter | p => do
checkKind quotedSymbolKind;
concatArgs do
-- right-to-left: closing backtick, then the symbol atom, then the opening backtick
-- NOTE(review): the backticks are pushed with `push`, not `pushToken`, so `leadWord` is not updated for them
push "`"; goLeft;
visitAtom Name.anonymous p;
push "`"; goLeft
-- An unquoted symbol is printed verbatim as an atom, without a node-kind check.
@[builtinFormatter unquotedSymbol] def unquotedSymbol.formatter := visitAtom Name.anonymous
end Formatter
/--
Format `stx`, using `parser` (an `Expr` denoting the `Parser` that produced it) and the token table `table`.
- Runs with at least default transparency, presumably so that `visit` can unfold parser definitions.
- Returns the bottom entry of the resulting `Format` stack; `get! 0` panics if formatting produced nothing. -/
def format (table : Parser.TokenTable) (parser : Expr) (stx : Syntax) : MetaM Format := Meta.withAtLeastTransparency Meta.TransparencyMode.default do
(_, st) ← Formatter.visit parser { table := table } { stxTrav := Syntax.Traverser.fromSyntax stx };
pure $ st.stack.get! 0
/-- Format a term syntax tree using `Lean.Parser.termParser` applied to `0` (presumably the precedence). -/
def formatTerm (table) := format table (mkApp (mkConst `Lean.Parser.termParser) (mkNatLit 0))
/-- Format a command syntax tree using `Lean.Parser.commandParser` applied to `0` (presumably the precedence). -/
def formatCommand (table) := format table (mkApp (mkConst `Lean.Parser.commandParser) (mkNatLit 0))
/--
Register every trace class this file emits on. `PrettyPrinter.format` is used by `visit`.
`PrettyPrinter.format.backtrack` is used by `checkKind` and was previously not registered. -/
@[init] private def regTraceClasses : IO Unit := do
registerTraceClass `PrettyPrinter.format;
registerTraceClass `PrettyPrinter.format.backtrack;
pure ()
end PrettyPrinter
end Lean