fix: auto-group syntax parsers where necessary

This commit is contained in:
Sebastian Ullrich 2022-06-18 11:07:25 +02:00
parent e49a81bb56
commit eab64997cd
5 changed files with 99 additions and 64 deletions

View file

@ -1,6 +1,8 @@
Unreleased
---------
* The `group(·)` `syntax` combinator is now introduced automatically where necessary, such as when using multiple parsers inside `(...)+`.
* Add ["Typed Macros"](https://github.com/leanprover/lean4/pull/1251): syntax trees produced and accepted by syntax antiquotations now remember their syntax kinds, preventing accidental production of ill-formed syntax trees and reducing the need for explicit `:kind` antiquotation annotations. See PR for details.
* Aliases of protected definitions are protected too. Example:

View file

@ -17,16 +17,17 @@ def expandOptPrecedence (stx : Syntax) : MacroM (Option Nat) :=
else
return some (← evalPrec stx[0][1])
private def mkParserSeq (ds : Array Term) : TermElabM Syntax := do
private def mkParserSeq (ds : Array (Term × Nat)) : TermElabM (Term × Nat) := do
if ds.size == 0 then
throwUnsupportedSyntax
else if ds.size == 1 then
pure ds[0]
else
let mut r := ds[0]
for d in ds[1:ds.size] do
let mut (r, stackSum) := ds[0]
for (d, stackSz) in ds[1:ds.size] do
r ← `(ParserDescr.binary `andthen $r $d)
return r
stackSum := stackSum + stackSz
return (r, stackSum)
structure ToParserDescrContext where
catName : Name
@ -36,12 +37,20 @@ structure ToParserDescrContext where
behavior : Parser.LeadingIdentBehavior
abbrev ToParserDescrM := ReaderT ToParserDescrContext (StateRefT (Option Nat) TermElabM)
abbrev ToParserDescr := ToParserDescrM (Term × Nat)
private def markAsTrailingParser (lhsPrec : Nat) : ToParserDescrM Unit := set (some lhsPrec)
@[inline] private def withNotFirst {α} (x : ToParserDescrM α) : ToParserDescrM α :=
withReader (fun ctx => { ctx with first := false }) x
@[inline] private def withNestedParser {α} (x : ToParserDescrM α) : ToParserDescrM α :=
def ensureUnaryOutput (x : Term × Nat) : Term :=
let (stx, stackSz) := x
if stackSz != 1 then
Unhygienic.run ``(ParserDescr.unary $(quote `group) $stx)
else
stx
@[inline] private def withNestedParser (x : ToParserDescr) : ToParserDescr := do
withReader (fun ctx => { ctx with leftRec := false, first := false }) x
def checkLeftRec (stx : Syntax) : ToParserDescrM Bool := do
@ -83,12 +92,12 @@ open TSyntax.Compat in
Given a `stx` of category `syntax`, return a pair `(newStx, lhsPrec?)`,
where `newStx` is of category `term`. After elaboration, `newStx` should have type
`TrailingParserDescr` if `lhsPrec?.isSome`, and `ParserDescr` otherwise. -/
partial def toParserDescr (stx : Syntax) (catName : Name) : TermElabM (Term × Option Nat) := do
partial def toParserDescr (stx : Syntax) (catName : Name) : TermElabM ((Term × Nat) × Option Nat) := do
let env ← getEnv
let behavior := Parser.leadingIdentBehavior env catName
(process stx { catName := catName, first := true, leftRec := true, behavior := behavior }).run none
where
process (stx : Syntax) : ToParserDescrM Term := withRef stx do
process (stx : Syntax) : ToParserDescr := withRef stx do
let kind := stx.getKind
if kind == nullKind then
processSeq stx
@ -99,9 +108,9 @@ where
else if kind == ``Lean.Parser.Syntax.cat then
processNullaryOrCat stx
else if kind == ``Lean.Parser.Syntax.unary then
processUnary stx
processAlias stx[0] #[stx[2]]
else if kind == ``Lean.Parser.Syntax.binary then
processBinary stx
processAlias stx[0] #[stx[2], stx[4]]
else if kind == ``Lean.Parser.Syntax.sepBy then
processSepBy stx
else if kind == ``Lean.Parser.Syntax.sepBy1 then
@ -138,12 +147,39 @@ where
throwErrorAt stx "invalid atomic left recursive syntax"
let prec? ← liftMacroM <| expandOptPrecedence stx[1]
let prec := prec?.getD 0
`(ParserDescr.cat $(quote catName) $(quote prec))
return (← `(ParserDescr.cat $(quote catName) $(quote prec)), 1)
processAlias (id : Syntax) (args : Array Syntax) := do
let aliasName := id.getId.eraseMacroScopes
let info ← Parser.getParserAliasInfo aliasName
let args ← args.mapM (withNestedParser ∘ process)
let (args, stackSz) := if let some stackSz := info.stackSz? then
if !info.autoGroupArgs then
(args.map (·.1), stackSz)
else
(args.map ensureUnaryOutput, stackSz)
else
let (args, stackSzs) := args.unzip
(args, stackSzs.foldl (· + ·) 0)
let stx ← match args with
| #[] => Parser.ensureConstantParserAlias aliasName; ``(ParserDescr.const $(quote aliasName))
| #[p1] => Parser.ensureUnaryParserAlias aliasName; ``(ParserDescr.unary $(quote aliasName) $p1)
| #[p1, p2] => Parser.ensureBinaryParserAlias aliasName; ``(ParserDescr.binary $(quote aliasName) $p1 $p2)
| _ => unreachable!
return (stx, stackSz)
processNullaryOrCat (stx : Syntax) := do
match (← resolveParserName stx[0]) with
| [(c, true)] => ensureNoPrec stx; return mkIdentFrom stx c
| [(c, false)] => ensureNoPrec stx; `(ParserDescr.parser $(quote c))
| [(c, true)] =>
ensureNoPrec stx
-- `syntax _ :=` at least enforces this
let stackSz := 1
return (mkIdentFrom stx c, stackSz)
| [(c, false)] =>
ensureNoPrec stx
-- as usual, we assume that people using `Parser` know what they are doing
let stackSz := 1
return (← `(ParserDescr.parser $(quote c)), stackSz)
| cs@(_ :: _ :: _) => throwError "ambiguous parser declaration {cs.map (·.1)}"
| [] =>
let id := stx[0].getId.eraseMacroScopes
@ -151,37 +187,23 @@ where
processParserCategory stx
else if (← Parser.isParserAlias id) then
ensureNoPrec stx
Parser.ensureConstantParserAlias id
`(ParserDescr.const $(quote id))
processAlias stx[0] #[]
else
throwError "unknown parser declaration/category/alias '{id}'"
processUnary (stx : Syntax) := do
let aliasName := (stx[0].getId).eraseMacroScopes
Parser.ensureUnaryParserAlias aliasName
let d ← withNestedParser do process stx[2]
`(ParserDescr.unary $(quote aliasName) $d)
processBinary (stx : Syntax) := do
let aliasName := (stx[0].getId).eraseMacroScopes
Parser.ensureBinaryParserAlias aliasName
let d₁ ← withNestedParser do process stx[2]
let d₂ ← withNestedParser do process stx[4]
`(ParserDescr.binary $(quote aliasName) $d₁ $d₂)
processSepBy (stx : Syntax) := do
let p ← withNestedParser $ process stx[1]
let p ← ensureUnaryOutput <$> withNestedParser do process stx[1]
let sep := stx[3]
let psep ← if stx[4].isNone then `(ParserDescr.symbol $sep) else process stx[4][1]
let psep ← if stx[4].isNone then `(ParserDescr.symbol $sep) else ensureUnaryOutput <$> withNestedParser do process stx[4][1]
let allowTrailingSep := !stx[5].isNone
`(ParserDescr.sepBy $p $sep $psep $(quote allowTrailingSep))
return (← `(ParserDescr.sepBy $p $sep $psep $(quote allowTrailingSep)), 1)
processSepBy1 (stx : Syntax) := do
let p ← withNestedParser do process stx[1]
let p ← ensureUnaryOutput <$> withNestedParser do process stx[1]
let sep := stx[3]
let psep ← if stx[4].isNone then `(ParserDescr.symbol $sep) else process stx[4][1]
let psep ← if stx[4].isNone then `(ParserDescr.symbol $sep) else ensureUnaryOutput <$> withNestedParser do process stx[4][1]
let allowTrailingSep := !stx[5].isNone
`(ParserDescr.sepBy1 $p $sep $psep $(quote allowTrailingSep))
return (← `(ParserDescr.sepBy1 $p $sep $psep $(quote allowTrailingSep)), 1)
isValidAtom (s : String) : Bool :=
!s.isEmpty &&
@ -198,14 +220,14 @@ where
/- For syntax categories where initialized with `LeadingIdentBehavior` different from default (e.g., `tactic`), we automatically mark
the first symbol as nonReserved. -/
if (← read).behavior != Parser.LeadingIdentBehavior.default && (← read).first then
`(ParserDescr.nonReservedSymbol $(quote atom) false)
return (← `(ParserDescr.nonReservedSymbol $(quote atom) false), 1)
else
`(ParserDescr.symbol $(quote atom))
return (← `(ParserDescr.symbol $(quote atom)), 1)
| none => throwUnsupportedSyntax
processNonReserved (stx : Syntax) := do
match stx[1].isStrLit? with
| some atom => `(ParserDescr.nonReservedSymbol $(quote atom) false)
| some atom => return (← `(ParserDescr.nonReservedSymbol $(quote atom) false), 1)
| none => throwUnsupportedSyntax
@ -319,7 +341,7 @@ def resolveSyntaxKind (k : Name) : CommandElabM Name := do
let prio ← liftMacroM <| evalOptPrio prio?
let stxNodeKind := (← getCurrNamespace) ++ name
let catParserId := mkIdentFrom stx (cat.appendAfter "Parser")
let (val, lhsPrec?) ← runTermElabM none fun _ => Term.toParserDescr syntaxParser cat
let ((val, _), lhsPrec?) ← runTermElabM none fun _ => Term.toParserDescr syntaxParser cat
let declName := mkIdentFrom stx name
let d ← if let some lhsPrec := lhsPrec? then
`($[$doc?:docComment]? @[$attrKind:attrKind $catParserId:ident $(quote prio):num] def $declName:ident : Lean.TrailingParserDescr :=
@ -333,7 +355,7 @@ def resolveSyntaxKind (k : Name) : CommandElabM Name := do
@[builtinCommandElab «syntaxAbbrev»] def elabSyntaxAbbrev : CommandElab := fun stx => do
let `($[$doc?:docComment]? syntax $declName:ident := $[$ps:stx]*) ← pure stx | throwUnsupportedSyntax
-- TODO: nonatomic names
let (val, _) ← runTermElabM none fun _ => Term.toParserDescr (mkNullNode ps) Name.anonymous
let ((val, _), _) ← runTermElabM none fun _ => Term.toParserDescr (mkNullNode ps) Name.anonymous
let stxNodeKind := (← getCurrNamespace) ++ declName.getId
let stx' ← `($[$doc?:docComment]? def $declName:ident : Lean.ParserDescr := ParserDescr.nodeWithAntiquot $(quote (toString declName.getId)) $(quote stxNodeKind) $val)
withMacroExpansion stx stx' <| elabCommand stx'

View file

@ -19,28 +19,28 @@ open Lean.PrettyPrinter.Parenthesizer
open Lean.PrettyPrinter.Formatter
builtin_initialize
register_parser_alias "ws" checkWsBefore
register_parser_alias "noWs" checkNoWsBefore
register_parser_alias "linebreak" checkLinebreakBefore
register_parser_alias "ws" checkWsBefore { stackSz? := none }
register_parser_alias "noWs" checkNoWsBefore { stackSz? := none }
register_parser_alias "linebreak" checkLinebreakBefore { stackSz? := none }
register_parser_alias (kind := numLitKind) "num" numLit
register_parser_alias (kind := strLitKind) "str" strLit
register_parser_alias (kind := charLitKind) "char" charLit
register_parser_alias (kind := nameLitKind) "name" nameLit
register_parser_alias (kind := scientificLitKind) "scientific" scientificLit
register_parser_alias (kind := identKind) "ident" ident
register_parser_alias "colGt" checkColGt
register_parser_alias "colGe" checkColGe
register_parser_alias lookahead
register_parser_alias atomic
register_parser_alias "colGt" checkColGt { stackSz? := none }
register_parser_alias "colGe" checkColGe { stackSz? := none }
register_parser_alias lookahead { stackSz? := some 0 }
register_parser_alias atomic { stackSz? := none }
register_parser_alias many
register_parser_alias many1
register_parser_alias manyIndent
register_parser_alias many1Indent
register_parser_alias optional
register_parser_alias withPosition
register_parser_alias optional { autoGroupArgs := false }
register_parser_alias withPosition { stackSz? := none }
register_parser_alias (kind := interpolatedStrKind) interpolatedStr
register_parser_alias orelse
register_parser_alias andthen
register_parser_alias andthen { stackSz? := none }
registerAlias "notFollowedBy" (notFollowedBy · "element")
Parenthesizer.registerAlias "notFollowedBy" notFollowedBy.parenthesizer

View file

@ -202,14 +202,25 @@ def getBinaryAlias {α} (mapRef : IO.Ref (AliasTable α)) (aliasName : Name) : I
abbrev ParserAliasValue := AliasValue Parser
structure ParserAliasInfo where
/-- Number of syntax nodes produced by this parser. `none` means "sum of input sizes". -/
stackSz? : Option Nat := some 1
/-- Whether arguments should be wrapped in `group(·)` if they do not produce exactly one syntax node. -/
autoGroupArgs : Bool := stackSz?.isSome
builtin_initialize parserAliasesRef : IO.Ref (NameMap ParserAliasValue) ← IO.mkRef {}
builtin_initialize parserAlias2kindRef : IO.Ref (NameMap SyntaxNodeKind) ← IO.mkRef {}
builtin_initialize parserAliases2infoRef : IO.Ref (NameMap ParserAliasInfo) ← IO.mkRef {}
def getParserAliasInfo (aliasName : Name) : IO ParserAliasInfo := do
return (← parserAliases2infoRef.get).findD aliasName {}
-- Later, we define macro `register_parser_alias` which registers a parser, formatter and parenthesizer
def registerAlias (aliasName : Name) (p : ParserAliasValue) (kind? : Option SyntaxNodeKind := none) : IO Unit := do
def registerAlias (aliasName : Name) (p : ParserAliasValue) (kind? : Option SyntaxNodeKind := none) (info : ParserAliasInfo := {}) : IO Unit := do
registerAliasCore parserAliasesRef aliasName p
if let some kind := kind? then
parserAlias2kindRef.modify (·.insert aliasName kind)
parserAliases2infoRef.modify (·.insert aliasName info)
instance : Coe Parser ParserAliasValue := { coe := AliasValue.const }
instance : Coe (Parser → Parser) ParserAliasValue := { coe := AliasValue.unary }

View file

@ -169,29 +169,29 @@ attribute [runBuiltinParserAttributeHooks]
ppHardSpace ppSpace ppLine ppGroup ppRealGroup ppRealFill ppIndent ppDedent
ppAllowUngrouped ppDedentIfGrouped ppHardLineUnlessUngrouped
syntax "register_parser_alias" group("(" &"kind" " := " term ")")? (strLit)? ident : term
syntax "register_parser_alias" group("(" &"kind" " := " term ")")? (strLit)? ident (colGt term)? : term
macro_rules
| `(register_parser_alias $[(kind := $kind?)]? $(aliasName?)? $declName) => do
| `(register_parser_alias $[(kind := $kind?)]? $(aliasName?)? $declName $(info?)?) => do
let [(fullDeclName, [])] ← Macro.resolveGlobalName declName.getId |
Macro.throwError "expected non-overloaded constant name"
let aliasName := aliasName?.getD (Syntax.mkStrLit declName.getId.toString)
`(do Parser.registerAlias $aliasName $declName (kind? := some $(kind?.getD (quote fullDeclName)))
`(do Parser.registerAlias $aliasName $declName $(info?.getD (Unhygienic.run `({}))) (kind? := some $(kind?.getD (quote fullDeclName)))
PrettyPrinter.Formatter.registerAlias $aliasName $(mkIdentFrom declName (declName.getId ++ `formatter))
PrettyPrinter.Parenthesizer.registerAlias $aliasName $(mkIdentFrom declName (declName.getId ++ `parenthesizer)))
builtin_initialize
register_parser_alias group
register_parser_alias ppHardSpace
register_parser_alias ppSpace
register_parser_alias ppLine
register_parser_alias ppGroup
register_parser_alias ppRealGroup
register_parser_alias ppRealFill
register_parser_alias ppIndent
register_parser_alias ppDedent
register_parser_alias ppAllowUngrouped
register_parser_alias ppDedentIfGrouped
register_parser_alias ppHardLineUnlessUngrouped
register_parser_alias group { autoGroupArgs := false }
register_parser_alias ppHardSpace { stackSz? := none }
register_parser_alias ppSpace { stackSz? := none }
register_parser_alias ppLine { stackSz? := none }
register_parser_alias ppGroup { stackSz? := none }
register_parser_alias ppRealGroup { stackSz? := none }
register_parser_alias ppRealFill { stackSz? := none }
register_parser_alias ppIndent { stackSz? := none }
register_parser_alias ppDedent { stackSz? := none }
register_parser_alias ppAllowUngrouped { stackSz? := none }
register_parser_alias ppDedentIfGrouped { stackSz? := none }
register_parser_alias ppHardLineUnlessUngrouped { stackSz? := none }
end Parser