feat: antiquotation scopes

This commit is contained in:
Sebastian Ullrich 2020-12-04 17:23:30 +01:00
parent bdabdd78cf
commit d7f27a140e
7 changed files with 147 additions and 71 deletions

View file

@ -28,10 +28,11 @@ def mkAntiquotNode (term : Syntax) (nesting := 0) (name : Option String := none)
| false => mkNullNode
mkNode (kind ++ `antiquot) #[mkAtom "$", nesting, term, name, splice]
-- Antiquotations can be escaped as in `$$x`, which is useful for nesting macros.
-- Antiquotations can be escaped as in `$$x`, which is useful for nesting macros. Also works for antiquotation scopes.
def isEscapedAntiquot (stx : Syntax) : Bool :=
!stx[1].getArgs.isEmpty
-- Also works for antiquotation scopes.
def unescapeAntiquot (stx : Syntax) : Syntax :=
if isAntiquot stx then
stx.setArg 1 $ mkNullNode stx[1].getArgs.pop
@ -57,11 +58,36 @@ def antiquotKind? : Syntax → Option SyntaxNodeKind
def isAntiquotSplice (stx : Syntax) : Bool :=
isAntiquot stx && stx[4].getOptional?.isSome
-- An "antiquotation scope" is something like `$[...]?` or `$[...]*`. Note that the latter could be of kind `many` or
-- `sepBy`, which have different implementations.
def antiquotScopeKind? : Syntax → Option SyntaxNodeKind
| Syntax.node (Name.str k "antiquot_scope" _) args => some k
| _ => none
def isAntiquotScope (stx : Syntax) : Bool :=
antiquotScopeKind? stx |>.isSome
def getAntiquotScopeContents (stx : Syntax) : Array Syntax :=
stx[3].getArgs
def getAntiquotScopeSuffix (stx : Syntax) : Syntax :=
stx[5]
-- If any item of a `many` node is an antiquotation splice, its result should
-- be substituted into the `many` node's children
def isAntiquotSplicePat (stx : Syntax) : Bool :=
stx.isOfKind nullKind && stx.getArgs.any fun arg => isAntiquotSplice arg && !isEscapedAntiquot arg
partial def getAntiquotationIds : Syntax → TermElabM (List Syntax)
| stx@(Syntax.node k args) =>
if isAntiquot stx && !isEscapedAntiquot stx then
let anti := getAntiquotTerm stx
if anti.isIdent then [anti]
else throwErrorAt stx "complex antiquotation not allowed here"
else
List.join <$> args.toList.mapM getAntiquotationIds
| _ => []
-- Elaborate the content of a syntax quotation term
private partial def quoteSyntax : Syntax → TermElabM Syntax
| Syntax.ident info rawVal val preresolved => do
@ -78,6 +104,8 @@ private partial def quoteSyntax : Syntax → TermElabM Syntax
-- splices must occur in a `many` node
if isAntiquotSplice stx then throwErrorAt stx "unexpected antiquotation splice"
else pure $ getAntiquotTerm stx
else if isAntiquotScope stx && !isEscapedAntiquot stx then
throwErrorAt stx "unexpected antiquotation splice"
else
let empty ← `(Array.empty);
-- if escaped antiquotation, decrement by one escape level
@ -86,6 +114,28 @@ private partial def quoteSyntax : Syntax → TermElabM Syntax
if k == nullKind && isAntiquotSplice arg then
-- antiquotation splice pattern: inject args array
`(Array.appendCore $args $(getAntiquotTerm arg))
else if k == nullKind && isAntiquotScope arg then do
let k := antiquotScopeKind? arg
let inner ← (getAntiquotScopeContents arg).mapM quoteSyntax
let arr ← match (← getAntiquotationIds arg) with
| [] => throwErrorAt stx "antiquotation scope must contain at least one antiquotation"
| [id] => match k with
| `optional => `(match $id:ident with
| some $id:ident => $(quote inner)
| none => #[])
| _ => `(Array.map (fun $id => $(inner[0])) $id)
| [id1, id2] => match k with
| `optional => `(match $id1:ident, $id2:ident with
| some $id1:ident, some $id2:ident => $(quote inner)
| _ => #[])
| _ => `(Array.zipWith $id1 $id2 fun $id1 $id2 => $(inner[0]))
| _ => throwErrorAt stx "too many antiquotations in antiquotation scope; don't be greedy"
let arr ←
if k == `sepBy then
let Syntax.atom _ sep ← (getAntiquotScopeSuffix arg)[0] | unreachable!
`(mkSepArray $arr (mkAtom $(Syntax.mkStrLit sep)))
else arr
`(Array.appendCore $args $arr)
else do
let arg ← quoteSyntax arg;
`(Array.push $args $arg)) empty
@ -246,21 +296,11 @@ private partial def compileStxMatch : List Syntax → List Alt → TermElabM Syn
`(let discr := $discr; ite (Eq $cond true) $yes $no)
| _, _ => unreachable!
private partial def getPatternVarsAux : Syntax → List Syntax
| stx@(Syntax.node k args) =>
if isAntiquot stx && !isEscapedAntiquot stx then
let anti := getAntiquotTerm stx
if anti.isIdent then [anti]
else []
else
List.join $ args.toList.map getPatternVarsAux
| _ => []
-- Get all pattern vars (as `Syntax.ident`s) in `stx`
partial def getPatternVars (stx : Syntax) : List Syntax :=
partial def getPatternVars (stx : Syntax) : TermElabM (List Syntax) :=
if isQuot stx then do
let quoted := stx.getArg 1;
getPatternVarsAux stx
getAntiquotationIds stx
else if stx.isIdent then
[stx]
else []
@ -270,7 +310,7 @@ partial def getPatternVars (stx : Syntax) : List Syntax :=
private def letBindRhss (cont : List Alt → TermElabM Syntax) : List Alt → List Alt → TermElabM Syntax
| [], altsRev' => cont altsRev'.reverse
| (pats, rhs)::alts, altsRev' => do
let vars := List.join $ pats.map getPatternVars
let vars ← List.join <$> pats.mapM getPatternVars
match vars with
-- no antiquotations => introduce Unit parameter to preserve evaluation order
| [] =>

View file

@ -515,7 +515,7 @@ def optionalFn (p : ParserFn) : ParserFn := fun c s =>
firstTokens := p.firstTokens.toOptional
}
@[inline] def optional (p : Parser) : Parser := {
@[inline] def optionalNoAntiquot (p : Parser) : Parser := {
info := optionaInfo p.info,
fn := optionalFn p.fn
}
@ -559,7 +559,7 @@ partial def manyAux (p : ParserFn) : ParserFn := fun c s =>
let s := manyAux p c s
s.mkNode nullKind iniSz
@[inline] def many (p : Parser) : Parser := {
@[inline] def manyNoAntiquot (p : Parser) : Parser := {
info := noFirstTokenInfo p.info,
fn := manyFn p.fn
}
@ -569,7 +569,7 @@ partial def manyAux (p : ParserFn) : ParserFn := fun c s =>
let s := andthenFn p (manyAux p) c s
s.mkNode nullKind iniSz
@[inline] def many1 (p : Parser) : Parser := {
@[inline] def many1NoAntiquot (p : Parser) : Parser := {
info := p.info,
fn := many1Fn p.fn
}
@ -618,12 +618,12 @@ def sepBy1Fn (allowTrailingSep : Bool) (p : ParserFn) (sep : ParserFn) : ParserF
firstTokens := p.firstTokens
}
@[inline] def sepBy (p sep : Parser) (allowTrailingSep : Bool := false) : Parser := {
@[inline] def sepByNoAntiquot (p sep : Parser) (allowTrailingSep : Bool := false) : Parser := {
info := sepByInfo p.info sep.info,
fn := sepByFn allowTrailingSep p.fn sep.fn
}
@[inline] def sepBy1 (p sep : Parser) (allowTrailingSep : Bool := false) : Parser := {
@[inline] def sepBy1NoAntiquot (p sep : Parser) (allowTrailingSep : Bool := false) : Parser := {
info := sepBy1Info p.info sep.info,
fn := sepBy1Fn allowTrailingSep p.fn sep.fn
}
@ -647,7 +647,7 @@ def withResultOfFn (p : ParserFn) (f : Syntax → Syntax) : ParserFn := fun c s
}
@[inline] def many1Unbox (p : Parser) : Parser :=
withResultOf (many1 p) fun stx => if stx.getNumArgs == 1 then stx.getArg 0 else stx
withResultOf (many1NoAntiquot p) fun stx => if stx.getNumArgs == 1 then stx.getArg 0 else stx
partial def satisfyFn (p : Char → Bool) (errorMsg : String := "unexpected character") : ParserFn := fun c s =>
let i := s.pos
@ -1603,10 +1603,10 @@ def mkAntiquot (name : String) (kind : Option SyntaxNodeKind) (anonymous := true
-- antiquotations are not part of the "standard" syntax, so hide "expected '$'" on error
leadingNode kind maxPrec $ atomic $
setExpected [] "$" >>
many (checkNoWsBefore "" >> "$") >>
manyNoAntiquot (checkNoWsBefore "" >> "$") >>
checkNoWsBefore "no space before spliced term" >> antiquotExpr >>
nameP >>
optional (checkNoWsBefore "" >> symbol "*")
optionalNoAntiquot (checkNoWsBefore "" >> symbol "*")
def tryAnti (c : ParserContext) (s : ParserState) : Bool :=
let (s, stx?) := peekToken c s
@ -1623,34 +1623,20 @@ def tryAnti (c : ParserContext) (s : ParserState) : Bool :=
info := orelseInfo antiquotP.info p.info
}
/- ===================== -/
/- End of Antiquotations -/
/- ===================== -/
def mkAntiquotScope (kind : SyntaxNodeKind) (p suffix : Parser) : Parser :=
let kind := kind ++ `antiquot_scope
leadingNode kind maxPrec $ atomic $
setExpected [] "$" >>
manyNoAntiquot (checkNoWsBefore "" >> "$") >>
checkNoWsBefore "no space before spliced term" >> symbol "[" >> node nullKind p >> symbol "]" >>
suffix
def nodeWithAntiquot (name : String) (kind : SyntaxNodeKind) (p : Parser) (anonymous := false) : Parser :=
withAntiquot (mkAntiquot name kind anonymous) $ node kind p
def ident : Parser :=
withAntiquot (mkAntiquot "ident" identKind) identNoAntiquot
-- `ident` and `rawIdent` produce the same syntax tree, so we reuse the antiquotation kind name
def rawIdent : Parser :=
withAntiquot (mkAntiquot "ident" identKind) rawIdentNoAntiquot
def numLit : Parser :=
withAntiquot (mkAntiquot "numLit" numLitKind) numLitNoAntiquot
def scientificLit : Parser :=
withAntiquot (mkAntiquot "scientificLit" scientificLitKind) scientificLitNoAntiquot
def strLit : Parser :=
withAntiquot (mkAntiquot "strLit" strLitKind) strLitNoAntiquot
def charLit : Parser :=
withAntiquot (mkAntiquot "charLit" charLitKind) charLitNoAntiquot
def nameLit : Parser :=
withAntiquot (mkAntiquot "nameLit" nameLitKind) nameLitNoAntiquot
/- ===================== -/
/- End of Antiquotations -/
/- ===================== -/
def categoryParserOfStackFn (offset : Nat) : ParserFn := fun ctx s =>
let stack := s.stxStack

View file

@ -14,8 +14,48 @@ namespace Parser
-- synthesize pretty printers for parsers declared prior to `Lean.PrettyPrinter`
-- (because `Parser.Extension` depends on them)
attribute [runBuiltinParserAttributeHooks]
leadingNode termParser commandParser antiquotNestedExpr antiquotExpr mkAntiquot nodeWithAntiquot
ident numLit scientificLit charLit strLit nameLit
leadingNode termParser commandParser mkAntiquot nodeWithAntiquot
@[runBuiltinParserAttributeHooks] def optional (p : Parser) : Parser :=
optionalNoAntiquot (withAntiquot (mkAntiquotScope `optional p (symbol "?")) p)
@[runBuiltinParserAttributeHooks] def many (p : Parser) : Parser :=
manyNoAntiquot (withAntiquot (mkAntiquotScope `many p (symbol "*")) p)
@[runBuiltinParserAttributeHooks] def many1 (p : Parser) : Parser :=
many1NoAntiquot (withAntiquot (mkAntiquotScope `many p (symbol "*")) p)
-- all the separators you could ever want
@[runBuiltinParserAttributeHooks] def sepByScopeSuffixes : Parser :=
parser! (symbol "," <|> symbol ";" <|> symbol "|") >> symbol "*"
@[runBuiltinParserAttributeHooks] def sepBy (p psep : Parser) (allowTrailingSep : Bool := false) : Parser :=
sepByNoAntiquot (withAntiquot (mkAntiquotScope `sepBy p sepByScopeSuffixes) p) psep allowTrailingSep
@[runBuiltinParserAttributeHooks] def sepBy1 (p psep : Parser) (allowTrailingSep : Bool := false) : Parser :=
sepBy1NoAntiquot (withAntiquot (mkAntiquotScope `sepBy p sepByScopeSuffixes) p) psep allowTrailingSep
@[runBuiltinParserAttributeHooks] def ident : Parser :=
withAntiquot (mkAntiquot "ident" identKind) identNoAntiquot
-- `ident` and `rawIdent` produce the same syntax tree, so we reuse the antiquotation kind name
@[runBuiltinParserAttributeHooks] def rawIdent : Parser :=
withAntiquot (mkAntiquot "ident" identKind) rawIdentNoAntiquot
@[runBuiltinParserAttributeHooks] def numLit : Parser :=
withAntiquot (mkAntiquot "numLit" numLitKind) numLitNoAntiquot
@[runBuiltinParserAttributeHooks] def scientificLit : Parser :=
withAntiquot (mkAntiquot "scientificLit" scientificLitKind) scientificLitNoAntiquot
@[runBuiltinParserAttributeHooks] def strLit : Parser :=
withAntiquot (mkAntiquot "strLit" strLitKind) strLitNoAntiquot
@[runBuiltinParserAttributeHooks] def charLit : Parser :=
withAntiquot (mkAntiquot "charLit" charLitKind) charLitNoAntiquot
@[runBuiltinParserAttributeHooks] def nameLit : Parser :=
withAntiquot (mkAntiquot "nameLit" nameLitKind) nameLitNoAntiquot
@[runBuiltinParserAttributeHooks, inline] def group (p : Parser) : Parser :=
node nullKind p

View file

@ -337,14 +337,14 @@ def identNoAntiquot.formatter : Formatter := do
pushToken info id.toString
goLeft
@[combinatorFormatter Lean.Parser.rawIdent] def rawIdent.formatter : Formatter := do
@[combinatorFormatter Lean.Parser.rawIdentNoAntiquot] def rawIdentNoAntiquot.formatter : Formatter := do
checkKind identKind
let Syntax.ident info _ id _ ← getCur
| throwError m!"not an ident: {← getCur}"
pushToken info id.toString
goLeft
@[combinatorFormatter Lean.Parser.identEq] def identEq.formatter (id : Name) := rawIdent.formatter
@[combinatorFormatter Lean.Parser.identEq] def identEq.formatter (id : Name) := rawIdentNoAntiquot.formatter
def visitAtom (k : SyntaxNodeKind) : Formatter := do
let stx ← getCur
@ -362,30 +362,30 @@ def visitAtom (k : SyntaxNodeKind) : Formatter := do
@[combinatorFormatter Lean.Parser.scientificLitNoAntiquot] def scientificLitNoAntiquot.formatter := visitAtom scientificLitKind
@[combinatorFormatter Lean.Parser.fieldIdx] def fieldIdx.formatter := visitAtom fieldIdxKind
@[combinatorFormatter Lean.Parser.many]
def many.formatter (p : Formatter) : Formatter := do
@[combinatorFormatter Lean.Parser.manyNoAntiquot]
def manyNoAntiquot.formatter (p : Formatter) : Formatter := do
let stx ← getCur
visitArgs $ stx.getArgs.size.forM fun _ => p
@[combinatorFormatter Lean.Parser.many1] def many1.formatter (p : Formatter) : Formatter := many.formatter p
@[combinatorFormatter Lean.Parser.many1NoAntiquot] def many1NoAntiquot.formatter (p : Formatter) : Formatter := manyNoAntiquot.formatter p
@[combinatorFormatter Lean.Parser.optional]
def optional.formatter (p : Formatter) : Formatter := visitArgs p
@[combinatorFormatter Lean.Parser.optionalNoAntiquot]
def optionalNoAntiquot.formatter (p : Formatter) : Formatter := visitArgs p
@[combinatorFormatter Lean.Parser.many1Unbox]
def many1Unbox.formatter (p : Formatter) : Formatter := do
let stx ← getCur
if stx.getKind == nullKind then do
many.formatter p
manyNoAntiquot.formatter p
else
p
@[combinatorFormatter Lean.Parser.sepBy]
def sepBy.formatter (p pSep : Formatter) : Formatter := do
@[combinatorFormatter Lean.Parser.sepByNoAntiquot]
def sepByNoAntiquot.formatter (p pSep : Formatter) : Formatter := do
let stx ← getCur
visitArgs $ (List.range stx.getArgs.size).reverse.forM $ fun i => if i % 2 == 0 then p else pSep
@[combinatorFormatter Lean.Parser.sepBy1] def sepBy1.formatter := sepBy.formatter
@[combinatorFormatter Lean.Parser.sepBy1NoAntiquot] def sepBy1NoAntiquot.formatter := sepByNoAntiquot.formatter
@[combinatorFormatter Lean.Parser.withPosition] def withPosition.formatter (p : Formatter) : Formatter := p
@[combinatorFormatter Lean.Parser.withoutPosition] def withoutPosition.formatter (p : Formatter) : Formatter := p

View file

@ -405,7 +405,7 @@ def trailingNode.parenthesizer (k : SyntaxNodeKind) (prec : Nat) (p : Parenthesi
@[combinatorParenthesizer Lean.Parser.unicodeSymbol] def unicodeSymbol.parenthesizer (sym asciiSym : String) := visitToken
@[combinatorParenthesizer Lean.Parser.identNoAntiquot] def identNoAntiquot.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.rawIdent] def rawIdent.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.rawIdentNoAntiquot] def rawIdentNoAntiquot.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.identEq] def identEq.parenthesizer (id : Name) := visitToken
@[combinatorParenthesizer Lean.Parser.nonReservedSymbol] def nonReservedSymbol.parenthesizer (sym : String) (includeIdent : Bool) := visitToken
@ -416,33 +416,33 @@ def trailingNode.parenthesizer (k : SyntaxNodeKind) (prec : Nat) (p : Parenthesi
@[combinatorParenthesizer Lean.Parser.scientificLitNoAntiquot] def scientificLitNoAntiquot.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.fieldIdx] def fieldIdx.parenthesizer := visitToken
@[combinatorParenthesizer Lean.Parser.many]
def many.parenthesizer (p : Parenthesizer) : Parenthesizer := do
@[combinatorParenthesizer Lean.Parser.manyNoAntiquot]
def manyNoAntiquot.parenthesizer (p : Parenthesizer) : Parenthesizer := do
let stx ← getCur
visitArgs $ stx.getArgs.size.forM fun _ => p
@[combinatorParenthesizer Lean.Parser.many1]
def many1.parenthesizer (p : Parenthesizer) : Parenthesizer := do
many.parenthesizer p
@[combinatorParenthesizer Lean.Parser.many1NoAntiquot]
def many1NoAntiquot.parenthesizer (p : Parenthesizer) : Parenthesizer := do
manyNoAntiquot.parenthesizer p
@[combinatorParenthesizer Lean.Parser.many1Unbox]
def many1Unbox.parenthesizer (p : Parenthesizer) : Parenthesizer := do
let stx ← getCur
if stx.getKind == nullKind then
many.parenthesizer p
manyNoAntiquot.parenthesizer p
else
p
@[combinatorParenthesizer Lean.Parser.optional]
def optional.parenthesizer (p : Parenthesizer) : Parenthesizer := do
@[combinatorParenthesizer Lean.Parser.optionalNoAntiquot]
def optionalNoAntiquot.parenthesizer (p : Parenthesizer) : Parenthesizer := do
visitArgs p
@[combinatorParenthesizer Lean.Parser.sepBy]
def sepBy.parenthesizer (p pSep : Parenthesizer) : Parenthesizer := do
@[combinatorParenthesizer Lean.Parser.sepByNoAntiquot]
def sepByNoAntiquot.parenthesizer (p pSep : Parenthesizer) : Parenthesizer := do
let stx ← getCur
visitArgs $ (List.range stx.getArgs.size).reverse.forM $ fun i => if i % 2 == 0 then p else pSep
@[combinatorParenthesizer Lean.Parser.sepBy1] def sepBy1.parenthesizer := sepBy.parenthesizer
@[combinatorParenthesizer Lean.Parser.sepBy1NoAntiquot] def sepBy1NoAntiquot.parenthesizer := sepByNoAntiquot.parenthesizer
@[combinatorParenthesizer Lean.Parser.withPosition] def withPosition.parenthesizer (p : Parenthesizer) : Parenthesizer := do
-- We assume the formatter will indent syntax sufficiently such that parenthesizing a `withPosition` node is never necessary

View file

@ -38,4 +38,11 @@ end Syntax
#eval run $ do let a ← `(a.{0}); match_syntax a with `($id:ident) => pure id | _ => pure a
#eval run $ do let a ← `(match a with | a => 1 | _ => 2); match_syntax a with `(match $e with $eqns:matchAlt*) => pure eqns | _ => pure #[]
#eval run do let a ← some <$> `(a); `({ a := a $[: $a]?})
#eval run do let a ← pure none; `({ a := a $[: $a]?})
#eval run do
let pats := #[← `(a), ← `(a + 1)]
let rhss := #[← `(b), ← `(b + 1)]
`(match a with $[$pats => $rhss]|*)
end Lean

View file

@ -26,3 +26,6 @@
"`a._@.UnhygienicMain._hyg.1"
"(Term.explicitUniv `a._@.UnhygienicMain._hyg.1 \".{\" [(numLit \"0\")] \"}\")"
"#[]"
"(Term.structInst\n \"{\"\n []\n [[(Term.structInstField `a._@.UnhygienicMain._hyg.1 [] \":=\" `a._@.UnhygienicMain._hyg.1) []]]\n []\n [\":\" `a._@.UnhygienicMain._hyg.1]\n \"}\")"
"(Term.structInst\n \"{\"\n []\n [[(Term.structInstField `a._@.UnhygienicMain._hyg.1 [] \":=\" `a._@.UnhygienicMain._hyg.1) []]]\n []\n []\n \"}\")"
"(Term.match\n \"match\"\n [(Term.matchDiscr [] `a._@.UnhygienicMain._hyg.1)]\n []\n \"with\"\n (Term.matchAlts\n []\n [(Term.matchAlt [`a._@.UnhygienicMain._hyg.1] \"=>\" `b._@.UnhygienicMain._hyg.1)\n \"|\"\n (Term.matchAlt\n [(_kind.term._@.Init.Notation._hyg.637 `a._@.UnhygienicMain._hyg.1 \"+\" (numLit \"1\"))]\n \"=>\"\n (_kind.term._@.Init.Notation._hyg.637 `b._@.UnhygienicMain._hyg.1 \"+\" (numLit \"1\")))]))"