fix: store syntax kinds of parser aliases in order to construct correct antiquotations in macro and elab
This commit is contained in:
parent
2c54a0d17a
commit
7d48d125da
9 changed files with 92 additions and 69 deletions
|
|
@ -76,20 +76,20 @@ def elabElabRulesAux (doc? : Option Syntax) (attrKind : Syntax) (k : SyntaxNodeK
|
|||
do elabElabRulesAux doc? attrKind (← resolveSyntaxKind kind.getId) cat? expty? alts
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
@[builtinMacro Lean.Parser.Command.elab]
|
||||
def expandElab : Macro
|
||||
@[builtinCommandElab Lean.Parser.Command.elab]
|
||||
def elabElab : CommandElab
|
||||
| `($[$doc?:docComment]? $attrKind:attrKind
|
||||
elab$[:$prec?]? $[(name := $name?)]? $[(priority := $prio?)]? $args:macroArg* :
|
||||
$cat $[<= $expectedType?]? => $rhs) => do
|
||||
let prio ← evalOptPrio prio?
|
||||
let prio ← liftMacroM <| evalOptPrio prio?
|
||||
let (stxParts, patArgs) := (← args.mapM expandMacroArg).unzip
|
||||
-- name
|
||||
let name ← match name? with
|
||||
| some name => pure name.getId
|
||||
| none => mkNameFromParserSyntax cat.getId (mkNullNode stxParts)
|
||||
let pat := mkNode ((← Macro.getCurrNamespace) ++ name) patArgs
|
||||
| none => liftMacroM <| mkNameFromParserSyntax cat.getId (mkNullNode stxParts)
|
||||
let pat := mkNode ((← getCurrNamespace) ++ name) patArgs
|
||||
`($[$doc?:docComment]? $attrKind:attrKind syntax$[:$prec?]? (name := $(← mkIdentFromRef name)) (priority := $(quote prio)) $[$stxParts]* : $cat
|
||||
$[$doc?:docComment]? elab_rules : $cat $[<= $expectedType?]? | `($pat) => $rhs)
|
||||
| _ => Macro.throwUnsupported
|
||||
$[$doc?:docComment]? elab_rules : $cat $[<= $expectedType?]? | `($pat) => $rhs) >>= elabCommand
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
end Lean.Elab.Command
|
||||
|
|
|
|||
|
|
@ -10,19 +10,19 @@ open Lean.Syntax
|
|||
open Lean.Parser.Term hiding macroArg
|
||||
open Lean.Parser.Command
|
||||
|
||||
@[builtinMacro Lean.Parser.Command.macro] def expandMacro : Macro
|
||||
@[builtinCommandElab Lean.Parser.Command.macro] def elabMacro : CommandElab
|
||||
| `($[$doc?:docComment]? $attrKind:attrKind
|
||||
macro%$tk$[:$prec?]? $[(name := $name?)]? $[(priority := $prio?)]? $args:macroArg* :
|
||||
$cat => $rhs) => do
|
||||
let prio ← evalOptPrio prio?
|
||||
let prio ← liftMacroM <| evalOptPrio prio?
|
||||
let (stxParts, patArgs) := (← args.mapM expandMacroArg).unzip
|
||||
-- name
|
||||
let name ← match name? with
|
||||
| some name => pure name.getId
|
||||
| none => mkNameFromParserSyntax cat.getId (mkNullNode stxParts)
|
||||
| none => liftMacroM <| mkNameFromParserSyntax cat.getId (mkNullNode stxParts)
|
||||
/- The command `syntax [<kind>] ...` adds the current namespace to the syntax node kind.
|
||||
So, we must include current namespace when we create a pattern for the following `macro_rules` commands. -/
|
||||
let pat := mkNode ((← Macro.getCurrNamespace) ++ name) patArgs
|
||||
let pat := mkNode ((← getCurrNamespace) ++ name) patArgs
|
||||
let stxCmd ← `($[$doc?:docComment]? $attrKind:attrKind
|
||||
syntax%$tk$[:$prec?]? (name := $(← mkIdentFromRef name)) (priority := $(quote prio)) $[$stxParts]* : $cat)
|
||||
let macroRulesCmd ← if rhs.getArgs.size == 1 then
|
||||
|
|
@ -33,7 +33,7 @@ open Lean.Parser.Command
|
|||
-- `rhs` is of the form `` `( $body ) ``
|
||||
let rhsBody := rhs[1]
|
||||
`($[$doc?:docComment]? macro_rules%$tk | `($pat) => `($rhsBody))
|
||||
return mkNullNode #[stxCmd, macroRulesCmd]
|
||||
| _ => Macro.throwUnsupported
|
||||
elabCommand <| mkNullNode #[stxCmd, macroRulesCmd]
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
end Lean.Elab.Command
|
||||
|
|
|
|||
|
|
@ -11,33 +11,49 @@ open Lean.Parser.Term hiding macroArg
|
|||
open Lean.Parser.Command
|
||||
|
||||
/- Convert `macro` arg into a `syntax` command item and a pattern element -/
|
||||
def expandMacroArg (stx : Syntax) : MacroM (Syntax × Syntax) := do
|
||||
let (id?, id, stx) ← match (← expandMacros stx) with
|
||||
partial def expandMacroArg (stx : Syntax) : CommandElabM (Syntax × Syntax) := do
|
||||
let (id?, id, stx) ← match (← liftMacroM <| expandMacros stx) with
|
||||
| `(macroArg| $id:ident:$stx) => pure (some id, id, stx)
|
||||
| `(macroArg| $stx:stx) => pure (none, (← `(x)), stx)
|
||||
| _ => Macro.throwUnsupported
|
||||
| _ => throwUnsupportedSyntax
|
||||
let pat ← match stx with
|
||||
| `(stx| $s:str) => pure <| mkNode `token_antiquot #[← strLitToPattern s, mkAtom "%", mkAtom "$", id]
|
||||
| `(stx| &$s:str) => pure <| mkNode `token_antiquot #[← strLitToPattern s, mkAtom "%", mkAtom "$", id]
|
||||
| `(stx| optional($_)) => pure <| mkSplicePat `optional id "?"
|
||||
| `(stx| many($_)) => pure <| mkSplicePat `many id "*"
|
||||
| `(stx| many1($_)) => pure <| mkSplicePat `many id "*"
|
||||
| `(stx| sepBy($_, $sep:str $[, $stxsep]? $[, allowTrailingSep]?)) =>
|
||||
pure <| mkSplicePat `sepBy id ((isStrLit? sep).get! ++ "*")
|
||||
| `(stx| sepBy1($_, $sep:str $[, $stxsep]? $[, allowTrailingSep]?)) =>
|
||||
pure <| mkSplicePat `sepBy id ((isStrLit? sep).get! ++ "*")
|
||||
| `(stx| $s:str) => pure <| mkNode `token_antiquot #[← liftMacroM <| strLitToPattern s, mkAtom "%", mkAtom "$", id]
|
||||
| `(stx| &$s:str) => pure <| mkNode `token_antiquot #[← liftMacroM <| strLitToPattern s, mkAtom "%", mkAtom "$", id]
|
||||
| `(stx| optional($stx)) => mkSplicePat `optional stx id "?"
|
||||
| `(stx| many($stx)) => mkSplicePat `many stx id "*"
|
||||
| `(stx| many1($stx)) => mkSplicePat `many stx id "*"
|
||||
| `(stx| sepBy($stx, $sep:str $[, $stxsep]? $[, allowTrailingSep]?)) =>
|
||||
mkSplicePat `sepBy stx id ((isStrLit? sep).get! ++ "*")
|
||||
| `(stx| sepBy1($stx, $sep:str $[, $stxsep]? $[, allowTrailingSep]?)) =>
|
||||
mkSplicePat `sepBy stx id ((isStrLit? sep).get! ++ "*")
|
||||
-- NOTE: all `interpolatedStr(·)` reuse the same node kind
|
||||
| `(stx| interpolatedStr(term)) => pure <| Syntax.mkAntiquotNode interpolatedStrKind id
|
||||
| _ => match id? with
|
||||
-- if there is a binding, we assume the user knows what they are doing
|
||||
| some id =>
|
||||
if stx.isOfKind ``Lean.Parser.Syntax.cat then
|
||||
let parser := stx[0].getId.eraseMacroScopes
|
||||
pure <| mkAntiquotNode id (kind := parser) (isPseudoKind := true)
|
||||
else
|
||||
pure <| mkAntiquotNode id
|
||||
| some id => mkAntiquotNode stx id
|
||||
-- otherwise `group` the syntax to enforce arity 1, e.g. for `noWs`
|
||||
| none => return (← `(stx| group($stx)), mkAntiquotNode id)
|
||||
| none => return (← `(stx| group($stx)), (← mkAntiquotNode stx id))
|
||||
pure (stx, pat)
|
||||
where mkSplicePat kind id suffix :=
|
||||
mkNullNode #[mkAntiquotSuffixSpliceNode kind (mkAntiquotNode id) suffix]
|
||||
where
|
||||
mkSplicePat kind stx id suffix :=
|
||||
return mkNullNode #[mkAntiquotSuffixSpliceNode kind (← mkAntiquotNode stx id) suffix]
|
||||
mkAntiquotNode
|
||||
| `(stx| ($stx)), term => mkAntiquotNode stx term
|
||||
| `(stx| $id:ident$[:$p:prec]?), term => do
|
||||
let kind ← match (← Elab.Term.resolveParserName id) with
|
||||
-- a syntax abbrev, assume kind == decl name
|
||||
| [(c, _)] => pure c
|
||||
| cs@(_ :: _ :: _) => throwError "ambiguous parser declaration {cs.map (·.1)}"
|
||||
| [] =>
|
||||
let id := id.getId.eraseMacroScopes
|
||||
if Parser.isParserCategory (← getEnv) id then
|
||||
return Syntax.mkAntiquotNode id term (isPseudoKind := true)
|
||||
else if (← Parser.isParserAlias id) then
|
||||
pure <| (← Parser.getSyntaxKindOfParserAlias? id).getD Name.anonymous
|
||||
else
|
||||
throwError "unknown parser declaration/category/alias '{id}'"
|
||||
pure <| Syntax.mkAntiquotNode kind term
|
||||
| stx, term =>
|
||||
pure <| Syntax.mkAntiquotNode Name.anonymous term (isPseudoKind := true)
|
||||
|
||||
end Lean.Elab.Command
|
||||
|
|
|
|||
|
|
@ -57,6 +57,27 @@ def checkLeftRec (stx : Syntax) : ToParserDescrM Bool := do
|
|||
markAsTrailingParser (prec?.getD 0)
|
||||
return true
|
||||
|
||||
/-- Resolve the given parser name and return a list of candidates.
|
||||
Each candidate is a pair `(resolvedParserName, isDescr)`.
|
||||
`isDescr == true` if the type of `resolvedParserName` is a `ParserDescr`. -/
|
||||
def resolveParserName [Monad m] [MonadInfoTree m] [MonadResolveName m] [MonadEnv m] [MonadError m] (parserName : Syntax) : m (List (Name × Bool)) := do
|
||||
try
|
||||
let candidates ← resolveGlobalConstWithInfos parserName
|
||||
/- Convert `candidates` in a list of pairs `(c, isDescr)`, where `c` is the parser name,
|
||||
and `isDescr` is true iff `c` has type `Lean.ParserDescr` or `Lean.TrailingParser` -/
|
||||
let env ← getEnv
|
||||
return candidates.filterMap fun c =>
|
||||
match env.find? c with
|
||||
| none => none
|
||||
| some info =>
|
||||
match info.type with
|
||||
| Expr.const ``Lean.Parser.TrailingParser _ _ => (c, false)
|
||||
| Expr.const ``Lean.Parser.Parser _ _ => (c, false)
|
||||
| Expr.const ``Lean.ParserDescr _ _ => (c, true)
|
||||
| Expr.const ``Lean.TrailingParserDescr _ _ => (c, true)
|
||||
| _ => none
|
||||
catch _ => return []
|
||||
|
||||
/--
|
||||
Given a `stx` of category `syntax`, return a pair `(newStx, lhsPrec?)`,
|
||||
where `newStx` is of category `term`. After elaboration, `newStx` should have type
|
||||
|
|
@ -106,27 +127,6 @@ where
|
|||
let args ← args.mapIdxM fun i arg => withReader (fun ctx => { ctx with first := ctx.first && i.val == 0 }) do process arg
|
||||
mkParserSeq args
|
||||
|
||||
/- Resolve the given parser name and return a list of candidates.
|
||||
Each candidate is a pair `(resolvedParserName, isDescr)`.
|
||||
`isDescr == true` if the type of `resolvedParserName` is a `ParserDescr`. -/
|
||||
resolveParserName (parserName : Syntax) : ToParserDescrM (List (Name × Bool)) := do
|
||||
try
|
||||
let candidates ← resolveGlobalConstWithInfos parserName
|
||||
/- Convert `candidates` in a list of pairs `(c, isDescr)`, where `c` is the parser name,
|
||||
and `isDescr` is true iff `c` has type `Lean.ParserDescr` or `Lean.TrailingParser` -/
|
||||
let env ← getEnv
|
||||
return candidates.filterMap fun c =>
|
||||
match env.find? c with
|
||||
| none => none
|
||||
| some info =>
|
||||
match info.type with
|
||||
| Expr.const ``Lean.Parser.TrailingParser _ _ => (c, false)
|
||||
| Expr.const ``Lean.Parser.Parser _ _ => (c, false)
|
||||
| Expr.const ``Lean.ParserDescr _ _ => (c, true)
|
||||
| Expr.const ``Lean.TrailingParserDescr _ _ => (c, true)
|
||||
| _ => none
|
||||
catch _ => return []
|
||||
|
||||
ensureNoPrec (stx : Syntax) :=
|
||||
unless stx[1].isNone do
|
||||
throwErrorAt stx[1] "unexpected precedence"
|
||||
|
|
|
|||
|
|
@ -22,12 +22,12 @@ builtin_initialize
|
|||
register_parser_alias "ws" checkWsBefore
|
||||
register_parser_alias "noWs" checkNoWsBefore
|
||||
register_parser_alias "linebreak" checkLinebreakBefore
|
||||
register_parser_alias "num" numLit
|
||||
register_parser_alias "str" strLit
|
||||
register_parser_alias "char" charLit
|
||||
register_parser_alias "name" nameLit
|
||||
register_parser_alias "scientific" scientificLit
|
||||
register_parser_alias ident
|
||||
register_parser_alias (kind := numLitKind) "num" numLit
|
||||
register_parser_alias (kind := strLitKind) "str" strLit
|
||||
register_parser_alias (kind := charLitKind) "char" charLit
|
||||
register_parser_alias (kind := nameLitKind) "name" nameLit
|
||||
register_parser_alias (kind := scientificLitKind) "scientific" scientificLit
|
||||
register_parser_alias (kind := identKind) "ident" ident
|
||||
register_parser_alias "colGt" checkColGt
|
||||
register_parser_alias "colGe" checkColGe
|
||||
register_parser_alias lookahead
|
||||
|
|
@ -38,7 +38,7 @@ builtin_initialize
|
|||
register_parser_alias many1Indent
|
||||
register_parser_alias optional
|
||||
register_parser_alias withPosition
|
||||
register_parser_alias interpolatedStr
|
||||
register_parser_alias (kind := interpolatedStrKind) interpolatedStr
|
||||
register_parser_alias orelse
|
||||
register_parser_alias andthen
|
||||
|
||||
|
|
|
|||
|
|
@ -203,10 +203,13 @@ def getBinaryAlias {α} (mapRef : IO.Ref (AliasTable α)) (aliasName : Name) : I
|
|||
abbrev ParserAliasValue := AliasValue Parser
|
||||
|
||||
builtin_initialize parserAliasesRef : IO.Ref (NameMap ParserAliasValue) ← IO.mkRef {}
|
||||
builtin_initialize parserAlias2kindRef : IO.Ref (NameMap SyntaxNodeKind) ← IO.mkRef {}
|
||||
|
||||
-- Later, we define macro registerParserAlias! which registers a parser, formatter and parenthesizer
|
||||
def registerAlias (aliasName : Name) (p : ParserAliasValue) : IO Unit := do
|
||||
-- Later, we define macro `register_parser_alias` which registers a parser, formatter and parenthesizer
|
||||
def registerAlias (aliasName : Name) (p : ParserAliasValue) (kind? : Option SyntaxNodeKind := none) : IO Unit := do
|
||||
registerAliasCore parserAliasesRef aliasName p
|
||||
if let some kind := kind? then
|
||||
parserAlias2kindRef.modify (·.insert aliasName kind)
|
||||
|
||||
instance : Coe Parser ParserAliasValue := { coe := AliasValue.const }
|
||||
instance : Coe (Parser → Parser) ParserAliasValue := { coe := AliasValue.unary }
|
||||
|
|
@ -217,6 +220,9 @@ def isParserAlias (aliasName : Name) : IO Bool := do
|
|||
| some _ => pure true
|
||||
| _ => pure false
|
||||
|
||||
def getSyntaxKindOfParserAlias? (aliasName : Name) : IO (Option SyntaxNodeKind) :=
|
||||
return (← parserAlias2kindRef.get).find? aliasName
|
||||
|
||||
def ensureUnaryParserAlias (aliasName : Name) : IO Unit :=
|
||||
discard $ getUnaryAlias parserAliasesRef aliasName
|
||||
|
||||
|
|
|
|||
|
|
@ -169,9 +169,11 @@ attribute [runBuiltinParserAttributeHooks]
|
|||
ppHardSpace ppSpace ppLine ppGroup ppRealGroup ppRealFill ppIndent ppDedent
|
||||
ppAllowUngrouped ppDedentIfGrouped ppHardLineUnlessUngrouped
|
||||
|
||||
macro "register_parser_alias" aliasName?:optional(strLit) declName:ident : term =>
|
||||
macro "register_parser_alias" kind?:group("(" &"kind" " := " term ")")? aliasName?:optional(strLit) declName:ident : term => do
|
||||
let [(fullDeclName, [])] ← Macro.resolveGlobalName declName.getId |
|
||||
Macro.throwError "expected non-overloaded constant name"
|
||||
let aliasName := aliasName?.getD (Syntax.mkStrLit declName.getId.toString)
|
||||
`(do Parser.registerAlias $aliasName $declName
|
||||
`(do Parser.registerAlias $aliasName $declName (kind? := some $(kind?.map (·[3]) |>.getD (quote fullDeclName)))
|
||||
PrettyPrinter.Formatter.registerAlias $aliasName $(mkIdentFrom declName (declName.getId ++ `formatter))
|
||||
PrettyPrinter.Parenthesizer.registerAlias $aliasName $(mkIdentFrom declName (declName.getId ++ `parenthesizer)))
|
||||
|
||||
|
|
|
|||
|
|
@ -240,7 +240,7 @@ def categoryFormatterCore (cat : Name) : Formatter := do
|
|||
-- TODO: We could use elaborator data here to format the chosen child when available
|
||||
formatterForKind (← getCur).getKind
|
||||
else if cat == `rawStx then
|
||||
withAntiquot.formatter (mkAntiquot.formatter' cat.toString none) (push stx.formatStx *> goLeft)
|
||||
withAntiquot.formatter (mkAntiquot.formatter' cat.toString cat (isPseudoKind := true)) (push stx.formatStx *> goLeft)
|
||||
else
|
||||
withAntiquot.formatter (mkAntiquot.formatter' cat.toString cat (isPseudoKind := true)) (formatterForKind stx.getKind)
|
||||
modify fun st => { st with mustBeGrouped := true, isUngrouped := !st.mustBeGrouped }
|
||||
|
|
|
|||
|
|
@ -382,8 +382,7 @@ def isAntiquot : Syntax → Bool
|
|||
| Syntax.node _ (Name.str _ "antiquot" _) _ => true
|
||||
| _ => false
|
||||
|
||||
-- TODO: `kind` should not be optional for best `TSyntax` usage
|
||||
def mkAntiquotNode (term : Syntax) (nesting := 0) (name : Option String := none) (kind := `pseudo) (isPseudoKind := false) : Syntax :=
|
||||
def mkAntiquotNode (kind : Name) (term : Syntax) (nesting := 0) (name : Option String := none) (isPseudoKind := false) : Syntax :=
|
||||
let nesting := mkNullNode (mkArray nesting (mkAtom "$"))
|
||||
let term :=
|
||||
if term.isIdent then term
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue