feat: upgrade TSyntax to union of kinds

This commit is contained in:
Sebastian Ullrich 2022-04-17 11:16:22 +02:00
parent 3a61cc247e
commit 43ba121e98
4 changed files with 79 additions and 42 deletions

View file

@ -1841,7 +1841,9 @@ inductive Syntax where
| atom (info : SourceInfo) (val : String) : Syntax
| ident (info : SourceInfo) (rawVal : Substring) (val : Name) (preresolved : List (Prod Name (List String))) : Syntax
structure TSyntax (k : SyntaxNodeKind) where
def SyntaxNodeKinds := List SyntaxNodeKind
structure TSyntax (ks : SyntaxNodeKinds) where
raw : Syntax
instance : Inhabited Syntax where
@ -1967,20 +1969,29 @@ partial def getTailPos? (stx : Syntax) (originalOnly := false) : Option String.P
loop 0
| _, _ => none
structure TSyntaxArray (kind : SyntaxNodeKind) where
raw : Array Syntax
/--
An array of syntax elements interspersed with separators. Can be coerced to/from `Array Syntax` to automatically
remove/insert the separators. -/
structure SepArray (sep : String) where
elemsAndSeps : Array Syntax
structure TSepArray (kind : SyntaxNodeKind) (sep : String) where
structure TSepArray (ks : SyntaxNodeKinds) (sep : String) where
elemsAndSeps : Array Syntax
end Syntax
abbrev TSyntaxArray (ks : SyntaxNodeKinds) := Array (TSyntax ks)
unsafe def TSyntaxArray.rawImpl : TSyntaxArray ks → Array Syntax := unsafeCast
@[implementedBy TSyntaxArray.rawImpl]
opaque TSyntaxArray.raw (as : TSyntaxArray ks) : Array Syntax := Array.empty
unsafe def TSyntaxArray.mkImpl : Array Syntax → TSyntaxArray ks := unsafeCast
@[implementedBy TSyntaxArray.mkImpl]
opaque TSyntaxArray.mk (as : Array Syntax) : TSyntaxArray ks := Array.empty
def SourceInfo.fromRef (ref : Syntax) : SourceInfo :=
match ref.getPos?, ref.getTailPos? with
| some pos, some tailPos => SourceInfo.synthetic pos tailPos

View file

@ -18,17 +18,24 @@ open Lean.Syntax
open Meta
/-- `C[$(e)]` ~> `let a := e; C[$a]`. Used in the implementation of antiquot splices. -/
private partial def floatOutAntiquotTerms : Syntax → StateT (Syntax → TermElabM Syntax) TermElabM Syntax
| stx@(Syntax.node i k args) => do
if isAntiquot stx && !isEscapedAntiquot stx then
let e := getAntiquotTerm stx
if !e.isIdent || !e.getId.isAtomic then
return ← withFreshMacroScope do
let a ← `(a)
modify (fun _ stx => (`(let $a:ident := $e; $stx) : TermElabM _))
pure <| stx.setArg 2 a
private partial def floatOutAntiquotTerms (stx : Syntax) : StateT (Syntax → TermElabM Syntax) TermElabM Syntax :=
if isAntiquots stx && !isEscapedAntiquot (getCanonicalAntiquot stx) then
let e := getAntiquotTerm (getCanonicalAntiquot stx)
if !e.isIdent || !e.getId.isAtomic then
withFreshMacroScope do
let a ← `(a)
modify (fun _ (stx : Syntax) => (`(let $a:ident := $e; $stx) : TermElabM Syntax))
let stx := if stx.isOfKind choiceKind then
mkNullNode <| stx.getArgs.map (·.setArg 2 a)
else
stx.setArg 2 a
return stx
else
return stx
else if let Syntax.node i k args := stx then
return Syntax.node i k (← args.mapM floatOutAntiquotTerms)
| stx => pure stx
else
return stx
private def getSepFromSplice (splice : Syntax) : String :=
if let Syntax.atom _ sep := getAntiquotSpliceSuffix splice then
@ -99,9 +106,9 @@ private partial def quoteSyntax : Syntax → TermElabM Syntax
`(Syntax.ident info $(quote rawVal) (addMacroScope mainModule $val scp) $(quote preresolved))
-- if antiquotation, insert contents as-is, else recurse
| stx@(Syntax.node _ k _) => do
if isAntiquot stx && !isEscapedAntiquot stx then
let (k, _) := antiquotKind? stx |>.get!
`(@TSyntax.raw $(quote k) $(getAntiquotTerm stx))
if isAntiquots stx && !isEscapedAntiquot (getCanonicalAntiquot stx) then
let ks := antiquotKinds stx
`(@TSyntax.raw $(quote <| ks.map (·.1)) $(getAntiquotTerm (getCanonicalAntiquot stx)))
else if isTokenAntiquot stx && !isEscapedAntiquot stx then
match stx[0] with
| Syntax.atom _ val => `(Syntax.atom (Option.getD (getHeadInfo? $(getAntiquotTerm stx)) info) $(quote val))
@ -119,15 +126,15 @@ private partial def quoteSyntax : Syntax → TermElabM Syntax
for arg in stx.getArgs do
if k == nullKind && isAntiquotSuffixSplice arg then
let antiquot := getAntiquotSuffixSpliceInner arg
let (k, _) := antiquotKind? antiquot |>.get!
let val := getAntiquotTerm antiquot
let ks := antiquotKinds antiquot |>.map (·.1)
let val := getAntiquotTerm (getCanonicalAntiquot antiquot)
args := args.append (appendName := appendName) <| ←
match antiquotSuffixSplice? arg with
| `optional => `(match Option.map (@TSyntax.raw $(quote k)) $val:term with
| `optional => `(match Option.map (@TSyntax.raw $(quote ks)) $val:term with
| some x => Array.empty.push x
| none => Array.empty)
| `many => `(@TSyntaxArray.raw $(quote k) $val)
| `sepBy => `(@TSepArray.elemsAndSeps $(quote k) $(quote <| getSepFromSplice arg) $val)
| `many => `(@TSyntaxArray.raw $(quote ks) $val)
| `sepBy => `(@TSepArray.elemsAndSeps $(quote ks) $(quote <| getSepFromSplice arg) $val)
| k => throwErrorAt arg "invalid antiquotation suffix splice kind '{k}'"
else if k == nullKind && isAntiquotSplice arg then
let k := antiquotSpliceKind? arg
@ -243,7 +250,7 @@ inductive HeadCheck where
-- the node kind.
-- without arity: `($x:k)
-- with arity: any quotation without an antiquotation head pattern
| shape (k : SyntaxNodeKind) (arity : Option Nat)
| shape (k : List SyntaxNodeKind) (arity : Option Nat)
-- Match step that succeeds on `null` nodes of arity at least `numPrefix + numSuffix`, introducing discriminants
-- for the first `numPrefix` children, one `null` node for those in between, and for the `numSuffix` last children.
-- example: `([$x, $xs,*, $y]) is `slice 2 2`
@ -299,12 +306,12 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo :=
unconditionally pure
else if quoted.isTokenAntiquot then
unconditionally (`(let $(quoted.getAntiquotTerm) := discr; $(·)))
else if isAntiquot quoted && !isEscapedAntiquot quoted then
else if isAntiquots quoted && !isEscapedAntiquot (getCanonicalAntiquot quoted) then
-- quotation contains a single antiquotation
let (k, pseudoKind) := antiquotKind? quoted |>.get!
let rhsFn := match getAntiquotTerm quoted with
let (ks, pseudoKinds) := antiquotKinds quoted |>.unzip
let rhsFn := match getAntiquotTerm (getCanonicalAntiquot quoted) with
| `(_) => pure
| `($id:ident) => fun stx => `(let $id := @TSyntax.mk $(quote k) discr; $(stx))
| `($id:ident) => fun stx => `(let $id := @TSyntax.mk $(quote ks) discr; $(stx))
| anti => fun _ => throwErrorAt anti "unsupported antiquotation kind in pattern"
-- Antiquotation kinds like `$id:ident` influence the parser, but also need to be considered by
-- `match` (but not by quotation terms). For example, `($id:ident) and `($e) are not
@ -316,27 +323,29 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo :=
-- let id := stx; let e := stx; ...
-- else
-- let e := stx; ...
if pseudoKind then unconditionally rhsFn else pure {
check := shape k none,
if pseudoKinds.all id then unconditionally rhsFn else pure {
check := shape ks none,
onMatch := fun
| other _ => undecided
| taken@(shape k' sz) =>
if k' == k then
| taken@(shape ks' sz) =>
if ks' == ks then
covered (adaptRhs rhsFn ∘ noOpMatchAdaptPats taken) (exhaustive := sz.isNone)
else uncovered
| _ => uncovered,
doMatch := fun yes no => do `(cond (Syntax.isOfKind discr $(quote k)) $(← yes []) $(← no)),
doMatch := fun yes no => do
let cond ← ks.foldlM (fun cond k => `(or $cond (Syntax.isOfKind discr $(quote k)))) (← `(false))
`(cond $cond $(← yes []) $(← no)),
}
else if isAntiquotSuffixSplice quoted then throwErrorAt quoted "unexpected antiquotation splice"
else if isAntiquotSplice quoted then throwErrorAt quoted "unexpected antiquotation splice"
else if quoted.getArgs.size == 1 && isAntiquotSuffixSplice quoted[0] then
let inner := getAntiquotSuffixSpliceInner quoted[0]
let anti := getAntiquotTerm inner
let (k, _) := antiquotKind? inner |>.get!
let anti := getAntiquotTerm (getCanonicalAntiquot inner)
let ks := antiquotKinds inner |>.map (·.1)
unconditionally fun rhs => match antiquotSuffixSplice? quoted[0] with
| `optional => `(let $anti := Option.map (@TSyntax.mk $(quote k)) (Syntax.getOptional? discr); $rhs)
| `many => `(let $anti := @TSyntaxArray.mk $(quote k) (Syntax.getArgs discr); $rhs)
| `sepBy => `(let $anti := @TSepArray.mk $(quote k) $(quote <| getSepFromSplice quoted[0]) (Syntax.getArgs discr); $rhs)
| `optional => `(let $anti := Option.map (@TSyntax.mk $(quote ks)) (Syntax.getOptional? discr); $rhs)
| `many => `(let $anti := @TSyntaxArray.mk $(quote ks) (Syntax.getArgs discr); $rhs)
| `sepBy => `(let $anti := @TSepArray.mk $(quote ks) $(quote <| getSepFromSplice quoted[0]) (Syntax.getArgs discr); $rhs)
| k => throwErrorAt quoted "invalid antiquotation suffix splice kind '{k}'"
else if quoted.getArgs.size == 1 && isAntiquotSplice quoted[0] then pure {
check := other pat,
@ -427,15 +436,15 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo :=
-- but matching on literals is quite rare.
other quoted
else
shape kind argPats.size,
shape [kind] argPats.size,
onMatch := fun
| other stx' =>
if (quoted.isIdent || lit) && quoted == stx' then
covered pure (exhaustive := true)
else
uncovered
| shape k' sz =>
if k' == kind && sz == argPats.size then
| shape ks sz =>
if ks == [kind] && sz == argPats.size then
covered (fun (pats, rhs) => pure (argPats.toList ++ pats, rhs)) (exhaustive := true)
else
uncovered

View file

@ -15,7 +15,7 @@ register_builtin_option hygiene : Bool := {
def getAntiquotationIds (stx : Syntax) : TermElabM (Array Syntax) := do
let mut ids := #[]
for stx in stx.topDown do
for stx in stx.topDown (firstChoiceOnly := true) do
if (isAntiquot stx || isTokenAntiquot stx) && !isEscapedAntiquot stx then
let anti := getAntiquotTerm stx
if anti.isIdent then ids := ids.push anti

View file

@ -383,6 +383,15 @@ def isAntiquot : Syntax → Bool
| Syntax.node _ (Name.str _ "antiquot" _) _ => true
| _ => false
def isAntiquots (stx : Syntax) : Bool :=
stx.isAntiquot || (stx.isOfKind choiceKind && stx.getNumArgs > 0 && stx.getArgs.all isAntiquot)
def getCanonicalAntiquot (stx : Syntax) : Syntax :=
if stx.isOfKind choiceKind then
stx[0]
else
stx
def mkAntiquotNode (kind : Name) (term : Syntax) (nesting := 0) (name : Option String := none) (isPseudoKind := false) : Syntax :=
let nesting := mkNullNode (mkArray nesting (mkAtom "$"))
let term :=
@ -420,6 +429,14 @@ def antiquotKind? : Syntax → Option (SyntaxNodeKind × Bool)
| Syntax.node _ (Name.str k "antiquot" _) args => (k, false)
| _ => none
def antiquotKinds (stx : Syntax) : List (SyntaxNodeKind × Bool) :=
if stx.isOfKind choiceKind then
stx.getArgs.filterMap antiquotKind? |>.toList
else
match antiquotKind? stx with
| some stx => [stx]
| none => []
-- An "antiquotation splice" is something like `$[...]?` or `$[...]*`.
def antiquotSpliceKind? : Syntax → Option SyntaxNodeKind
| Syntax.node _ (Name.str k "antiquot_scope" _) _ => some k