diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index fdc8e079a2..9f6b03973c 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -1841,7 +1841,9 @@ inductive Syntax where | atom (info : SourceInfo) (val : String) : Syntax | ident (info : SourceInfo) (rawVal : Substring) (val : Name) (preresolved : List (Prod Name (List String))) : Syntax -structure TSyntax (k : SyntaxNodeKind) where +def SyntaxNodeKinds := List SyntaxNodeKind + +structure TSyntax (ks : SyntaxNodeKinds) where raw : Syntax instance : Inhabited Syntax where @@ -1967,20 +1969,29 @@ partial def getTailPos? (stx : Syntax) (originalOnly := false) : Option String.P loop 0 | _, _ => none -structure TSyntaxArray (kind : SyntaxNodeKind) where - raw : Array Syntax - /-- An array of syntax elements interspersed with separators. Can be coerced to/from `Array Syntax` to automatically remove/insert the separators. -/ structure SepArray (sep : String) where elemsAndSeps : Array Syntax -structure TSepArray (kind : SyntaxNodeKind) (sep : String) where +structure TSepArray (ks : SyntaxNodeKinds) (sep : String) where elemsAndSeps : Array Syntax end Syntax +abbrev TSyntaxArray (ks : SyntaxNodeKinds) := Array (TSyntax ks) + +unsafe def TSyntaxArray.rawImpl : TSyntaxArray ks → Array Syntax := unsafeCast + +@[implementedBy TSyntaxArray.rawImpl] +opaque TSyntaxArray.raw (as : TSyntaxArray ks) : Array Syntax := Array.empty + +unsafe def TSyntaxArray.mkImpl : Array Syntax → TSyntaxArray ks := unsafeCast + +@[implementedBy TSyntaxArray.mkImpl] +opaque TSyntaxArray.mk (as : Array Syntax) : TSyntaxArray ks := Array.empty + def SourceInfo.fromRef (ref : Syntax) : SourceInfo := match ref.getPos?, ref.getTailPos? with | some pos, some tailPos => SourceInfo.synthetic pos tailPos diff --git a/src/Lean/Elab/Quotation.lean b/src/Lean/Elab/Quotation.lean index f93716a1d4..9cf19878d2 100644 --- a/src/Lean/Elab/Quotation.lean +++ b/src/Lean/Elab/Quotation.lean @@ -18,17 +18,24 @@ open Lean.Syntax open Meta /-- `C[$(e)]` ~> `let a := e; C[$a]`. Used in the implementation of antiquot splices. -/ -private partial def floatOutAntiquotTerms : Syntax → StateT (Syntax → TermElabM Syntax) TermElabM Syntax - | stx@(Syntax.node i k args) => do - if isAntiquot stx && !isEscapedAntiquot stx then - let e := getAntiquotTerm stx - if !e.isIdent || !e.getId.isAtomic then - return ← withFreshMacroScope do - let a ← `(a) - modify (fun _ stx => (`(let $a:ident := $e; $stx) : TermElabM _)) - pure <| stx.setArg 2 a +private partial def floatOutAntiquotTerms (stx : Syntax) : StateT (Syntax → TermElabM Syntax) TermElabM Syntax := + if isAntiquots stx && !isEscapedAntiquot (getCanonicalAntiquot stx) then + let e := getAntiquotTerm (getCanonicalAntiquot stx) + if !e.isIdent || !e.getId.isAtomic then + withFreshMacroScope do + let a ← `(a) + modify (fun _ (stx : Syntax) => (`(let $a:ident := $e; $stx) : TermElabM Syntax)) + let stx := if stx.isOfKind choiceKind then + mkNullNode <| stx.getArgs.map (·.setArg 2 a) + else + stx.setArg 2 a + return stx + else + return stx + else if let Syntax.node i k args := stx then return Syntax.node i k (← args.mapM floatOutAntiquotTerms) - | stx => pure stx + else + return stx private def getSepFromSplice (splice : Syntax) : String := if let Syntax.atom _ sep := getAntiquotSpliceSuffix splice then @@ -99,9 +106,9 @@ private partial def quoteSyntax : Syntax → TermElabM Syntax `(Syntax.ident info $(quote rawVal) (addMacroScope mainModule $val scp) $(quote preresolved)) -- if antiquotation, insert contents as-is, else recurse | stx@(Syntax.node _ k _) => do - if isAntiquot stx && !isEscapedAntiquot stx then - let (k, _) := antiquotKind? stx |>.get! - `(@TSyntax.raw $(quote k) $(getAntiquotTerm stx)) + if isAntiquots stx && !isEscapedAntiquot (getCanonicalAntiquot stx) then + let ks := antiquotKinds stx + `(@TSyntax.raw $(quote <| ks.map (·.1)) $(getAntiquotTerm (getCanonicalAntiquot stx))) else if isTokenAntiquot stx && !isEscapedAntiquot stx then match stx[0] with | Syntax.atom _ val => `(Syntax.atom (Option.getD (getHeadInfo? $(getAntiquotTerm stx)) info) $(quote val)) @@ -119,15 +126,15 @@ private partial def quoteSyntax : Syntax → TermElabM Syntax for arg in stx.getArgs do if k == nullKind && isAntiquotSuffixSplice arg then let antiquot := getAntiquotSuffixSpliceInner arg - let (k, _) := antiquotKind? antiquot |>.get! - let val := getAntiquotTerm antiquot + let ks := antiquotKinds antiquot |>.map (·.1) + let val := getAntiquotTerm (getCanonicalAntiquot antiquot) args := args.append (appendName := appendName) <| ← match antiquotSuffixSplice? arg with - | `optional => `(match Option.map (@TSyntax.raw $(quote k)) $val:term with + | `optional => `(match Option.map (@TSyntax.raw $(quote ks)) $val:term with | some x => Array.empty.push x | none => Array.empty) - | `many => `(@TSyntaxArray.raw $(quote k) $val) - | `sepBy => `(@TSepArray.elemsAndSeps $(quote k) $(quote <| getSepFromSplice arg) $val) + | `many => `(@TSyntaxArray.raw $(quote ks) $val) + | `sepBy => `(@TSepArray.elemsAndSeps $(quote ks) $(quote <| getSepFromSplice arg) $val) | k => throwErrorAt arg "invalid antiquotation suffix splice kind '{k}'" else if k == nullKind && isAntiquotSplice arg then let k := antiquotSpliceKind? arg @@ -243,7 +250,7 @@ inductive HeadCheck where -- the node kind. -- without arity: `($x:k) -- with arity: any quotation without an antiquotation head pattern - | shape (k : SyntaxNodeKind) (arity : Option Nat) + | shape (k : List SyntaxNodeKind) (arity : Option Nat) -- Match step that succeeds on `null` nodes of arity at least `numPrefix + numSuffix`, introducing discriminants -- for the first `numPrefix` children, one `null` node for those in between, and for the `numSuffix` last children. -- example: `([$x, $xs,*, $y]) is `slice 2 2` @@ -299,12 +306,12 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo := unconditionally pure else if quoted.isTokenAntiquot then unconditionally (`(let $(quoted.getAntiquotTerm) := discr; $(·))) - else if isAntiquot quoted && !isEscapedAntiquot quoted then + else if isAntiquots quoted && !isEscapedAntiquot (getCanonicalAntiquot quoted) then -- quotation contains a single antiquotation - let (k, pseudoKind) := antiquotKind? quoted |>.get! - let rhsFn := match getAntiquotTerm quoted with + let (ks, pseudoKinds) := antiquotKinds quoted |>.unzip + let rhsFn := match getAntiquotTerm (getCanonicalAntiquot quoted) with | `(_) => pure - | `($id:ident) => fun stx => `(let $id := @TSyntax.mk $(quote k) discr; $(stx)) + | `($id:ident) => fun stx => `(let $id := @TSyntax.mk $(quote ks) discr; $(stx)) | anti => fun _ => throwErrorAt anti "unsupported antiquotation kind in pattern" -- Antiquotation kinds like `$id:ident` influence the parser, but also need to be considered by -- `match` (but not by quotation terms). For example, `($id:ident) and `($e) are not @@ -316,27 +323,29 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo := -- let id := stx; let e := stx; ... -- else -- let e := stx; ... - if pseudoKind then unconditionally rhsFn else pure { - check := shape k none, + if pseudoKinds.all id then unconditionally rhsFn else pure { + check := shape ks none, onMatch := fun | other _ => undecided - | taken@(shape k' sz) => - if k' == k then + | taken@(shape ks' sz) => + if ks' == ks then covered (adaptRhs rhsFn ∘ noOpMatchAdaptPats taken) (exhaustive := sz.isNone) else uncovered | _ => uncovered, - doMatch := fun yes no => do `(cond (Syntax.isOfKind discr $(quote k)) $(← yes []) $(← no)), + doMatch := fun yes no => do + let cond ← ks.foldlM (fun cond k => `(or $cond (Syntax.isOfKind discr $(quote k)))) (← `(false)) + `(cond $cond $(← yes []) $(← no)), } else if isAntiquotSuffixSplice quoted then throwErrorAt quoted "unexpected antiquotation splice" else if isAntiquotSplice quoted then throwErrorAt quoted "unexpected antiquotation splice" else if quoted.getArgs.size == 1 && isAntiquotSuffixSplice quoted[0] then let inner := getAntiquotSuffixSpliceInner quoted[0] - let anti := getAntiquotTerm inner - let (k, _) := antiquotKind? inner |>.get! + let anti := getAntiquotTerm (getCanonicalAntiquot inner) + let ks := antiquotKinds inner |>.map (·.1) unconditionally fun rhs => match antiquotSuffixSplice? quoted[0] with - | `optional => `(let $anti := Option.map (@TSyntax.mk $(quote k)) (Syntax.getOptional? discr); $rhs) - | `many => `(let $anti := @TSyntaxArray.mk $(quote k) (Syntax.getArgs discr); $rhs) - | `sepBy => `(let $anti := @TSepArray.mk $(quote k) $(quote <| getSepFromSplice quoted[0]) (Syntax.getArgs discr); $rhs) + | `optional => `(let $anti := Option.map (@TSyntax.mk $(quote ks)) (Syntax.getOptional? discr); $rhs) + | `many => `(let $anti := @TSyntaxArray.mk $(quote ks) (Syntax.getArgs discr); $rhs) + | `sepBy => `(let $anti := @TSepArray.mk $(quote ks) $(quote <| getSepFromSplice quoted[0]) (Syntax.getArgs discr); $rhs) | k => throwErrorAt quoted "invalid antiquotation suffix splice kind '{k}'" else if quoted.getArgs.size == 1 && isAntiquotSplice quoted[0] then pure { check := other pat, @@ -427,15 +436,15 @@ private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo := -- but matching on literals is quite rare. other quoted else - shape kind argPats.size, + shape [kind] argPats.size, onMatch := fun | other stx' => if (quoted.isIdent || lit) && quoted == stx' then covered pure (exhaustive := true) else uncovered - | shape k' sz => - if k' == kind && sz == argPats.size then + | shape ks sz => + if ks == [kind] && sz == argPats.size then covered (fun (pats, rhs) => pure (argPats.toList ++ pats, rhs)) (exhaustive := true) else uncovered diff --git a/src/Lean/Elab/Quotation/Util.lean b/src/Lean/Elab/Quotation/Util.lean index 8f7fac396b..d9870e416b 100644 --- a/src/Lean/Elab/Quotation/Util.lean +++ b/src/Lean/Elab/Quotation/Util.lean @@ -15,7 +15,7 @@ register_builtin_option hygiene : Bool := { def getAntiquotationIds (stx : Syntax) : TermElabM (Array Syntax) := do let mut ids := #[] - for stx in stx.topDown do + for stx in stx.topDown (firstChoiceOnly := true) do if (isAntiquot stx || isTokenAntiquot stx) && !isEscapedAntiquot stx then let anti := getAntiquotTerm stx if anti.isIdent then ids := ids.push anti diff --git a/src/Lean/Syntax.lean b/src/Lean/Syntax.lean index 221e490ee4..caf502ae7f 100644 --- a/src/Lean/Syntax.lean +++ b/src/Lean/Syntax.lean @@ -383,6 +383,15 @@ def isAntiquot : Syntax → Bool | Syntax.node _ (Name.str _ "antiquot" _) _ => true | _ => false +def isAntiquots (stx : Syntax) : Bool := + stx.isAntiquot || (stx.isOfKind choiceKind && stx.getNumArgs > 0 && stx.getArgs.all isAntiquot) + +def getCanonicalAntiquot (stx : Syntax) : Syntax := + if stx.isOfKind choiceKind then + stx[0] + else + stx + def mkAntiquotNode (kind : Name) (term : Syntax) (nesting := 0) (name : Option String := none) (isPseudoKind := false) : Syntax := let nesting := mkNullNode (mkArray nesting (mkAtom "$")) let term := @@ -420,6 +429,14 @@ def antiquotKind? : Syntax → Option (SyntaxNodeKind × Bool) | Syntax.node _ (Name.str k "antiquot" _) args => (k, false) | _ => none +def antiquotKinds (stx : Syntax) : List (SyntaxNodeKind × Bool) := + if stx.isOfKind choiceKind then + stx.getArgs.filterMap antiquotKind? |>.toList + else + match antiquotKind? stx with + | some stx => [stx] + | none => [] + -- An "antiquotation splice" is something like `$[...]?` or `$[...]*`. def antiquotSpliceKind? : Syntax → Option SyntaxNodeKind | Syntax.node _ (Name.str k "antiquot_scope" _) _ => some k