diff --git a/src/Init/Lean/Elab/Quotation.lean b/src/Init/Lean/Elab/Quotation.lean index c15d3e6837..0a7242b6bb 100644 --- a/src/Init/Lean/Elab/Quotation.lean +++ b/src/Init/Lean/Elab/Quotation.lean @@ -59,6 +59,7 @@ instance Array.HasQuote {α : Type} [HasQuote α] : HasQuote (Array α) := namespace Elab namespace Term +-- antiquotation node kinds are formed from the original node kind (if any) plus "antiquot" def isAntiquot : Syntax → Bool | Syntax.node (Name.str _ "antiquot" _) _ => true | _ => false @@ -141,96 +142,124 @@ fun stx expectedType? => do -- an "alternative" of patterns plus right-hand side private abbrev Alt := List Syntax × Syntax --- If `pat` is an unconditional pattern, return a transformation of the RHS that appropriately introduces --- bindings on the RHS. -private def isVarPat? (pat : Syntax) : Option (Syntax → TermElabM Syntax) := --- TODO: reimplement using match_syntax -if pat.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(let $pat := discr; $rhs) -else if pat.isOfKind `Lean.Parser.Term.hole then some pure +/-- Information on a pattern's head that influences the compilation of a single + match step. -/ +structure HeadInfo := +-- Node kind to match, if any +(kind : Option SyntaxNodeKind := none) +-- Nested patterns for each argument, if any. In a single match step, we only +-- check that the arity matches. The arity is usually implied by the node kind, +-- but not in the case of `many` nodes. +(argPats : Option (Array Syntax) := none) +-- Function to apply to the right-hand side in case the match succeeds. Used to +-- bind pattern variables. +(rhsFn : Syntax → TermElabM Syntax := pure) + +instance HeadInfo.Inhabited : Inhabited HeadInfo := ⟨{}⟩ + +/-- `h1.generalizes h2` iff h1 is equal to or more general than h2, i.e. it matches all nodes + h2 matches. This induces a partial ordering. -/ +def HeadInfo.generalizes : HeadInfo → HeadInfo → Bool +| { kind := none, .. }, _ => true +| { kind := some k1, argPats := none, .. }, + { kind := some k2, .. } => k1 == k2 +| { kind := some k1, argPats := some ps1, .. }, + { kind := some k2, argPats := some ps2, .. } => k1 == k2 && ps1.size == ps2.size +| _, _ => false + +private def getHeadInfo (alt : Alt) : HeadInfo := +let pat := alt.fst.head!; +let unconditional (rhsFn) := { HeadInfo . rhsFn := rhsFn }; +-- variable pattern +if pat.isOfKind `Lean.Parser.Term.id then unconditional $ fun rhs => `(let $pat := discr; $rhs) +-- wildcard pattern +else if pat.isOfKind `Lean.Parser.Term.hole then unconditional pure +-- quotation pattern else if pat.isOfKind `Lean.Parser.Term.stxQuot then let quoted := pat.getArg 1; - -- We assume that atoms are uniquely determined by the surrounding node and never have to be checked - if quoted.isAtom then some pure - -- TODO: antiquotations with kinds (`$id:id`) probably can't be handled as unconditional patterns - else if isAntiquot quoted then + match quoted with + -- We assume that atoms are uniquely determined by the node kind and never have to be checked + | Syntax.atom _ _ => unconditional pure + -- quotation is a single antiquotation + | Syntax.node (Name.str k "antiquot" _) _ => + -- Antiquotation kinds like `$id:id` influence the parser, but also need to be considered by + -- match_syntax (but not by quotation terms). For example, `($id:id) and `($e) are not + -- distinguishable without checking the kind of the node to be captured. Note that some + -- antiquotations like the latter one for terms do not correspond to any actual node kind + -- (signified by `k == Name.anonymous`), so we would only check for `Term.id` here. + -- + -- if stx.isOfKind `Lean.Parser.Term.id then + -- let id := stx; ... + -- else + -- let e := stx; ... + let kind := if k == Name.anonymous then none else some k; let anti := quoted.getArg 1; - if isAntiquotSplice quoted then some $ fun _ => throwError quoted "unexpected antiquotation splice" - else if anti.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(let $anti := discr; $rhs) - else unreachable! - else if isAntiquotSplicePat quoted then - let anti := (quoted.getArg 0).getArg 1; - some $ fun rhs => `(let $anti := Syntax.getArgs discr; $rhs) - else none -else none - --- If the first pattern of the alternative is a conditional pattern, return the node we should match against -private def altNextNode? : Alt → Option SyntaxNode -| (pat::_, _) => - if (isVarPat? pat).isNone && pat.isOfKind `Lean.Parser.Term.stxQuot then - let quoted := pat.getArg 1; - some quoted.asNode - else none -| _ => none + -- Splices should only appear inside a nullKind node, see next case + if isAntiquotSplice quoted then unconditional $ fun _ => throwError quoted "unexpected antiquotation splice" + else if anti.isOfKind `Lean.Parser.Term.id then { kind := kind, rhsFn := fun rhs => `(let $anti := discr; $rhs) } + else unconditional $ fun _ => throwError anti "syntax_match: antiquotation must be variable" + | _ => + -- quotation is a single antiquotation splice => bind args array + if isAntiquotSplicePat quoted then + let anti := (quoted.getArg 0).getArg 1; + unconditional $ fun rhs => `(let $anti := Syntax.getArgs discr; $rhs) + else + -- not an antiquotation: match head shape + let argPats := quoted.getArgs.map $ fun arg => Syntax.node `Lean.Parser.Term.stxQuot #[mkAtom "`(", arg, mkAtom ")"]; + { kind := quoted.getKind, argPats := argPats } +else + unconditional $ fun _ => throwError pat "syntax_match: unexpected pattern kind" -- Assuming that the first pattern of the alternative is taken, replace it with patterns (if any) for its -- child nodes. -- Ex: `($a + (- $b)) => `($a), `(+), `(- $b) -- Note: The atom pattern `(+) will be discarded in a later step -private def explodeHeadPat (numArgs : Nat) : Alt → TermElabM Alt -| (pat::pats, rhs) => match isVarPat? pat with - | some fnRhs => do - -- unconditional pattern: replace with appropriate number of wildcards - newPat ← `(_); - let newPats := List.replicate numArgs newPat; - rhs ← fnRhs rhs; +private def explodeHeadPat (numArgs : Nat) : HeadInfo × Alt → TermElabM Alt +| (info, (pat::pats, rhs)) => do + let newPats := match info.argPats with + | some argPats => argPats.toList + | none => List.replicate numArgs $ Unhygienic.run `(_); + rhs ← info.rhsFn rhs; pure (newPats ++ pats, rhs) - | none => - if pat.isOfKind `Lean.Parser.Term.stxQuot then do - let quoted := pat.getArg 1; - let newPats := quoted.getArgs.toList.map $ fun arg => Syntax.node `Lean.Parser.Term.stxQuot #[mkAtom "`(", arg, mkAtom ")"]; - pure (newPats ++ pats, rhs) - else throwError pat $ "unsupported `syntax_match` pattern kind " ++ toString pat.getKind | _ => unreachable! --- The "shape" is the information that should be compared in a single matching step. Currently, it is the node kind --- and its arity (which is not constant in the case of `many` nodes) -private def nodeShape (n : SyntaxNode) : SyntaxNodeKind × Nat := -(n.getKind, n.getArgs.size) - private partial def compileStxMatch (ref : Syntax) : List Syntax → List Alt → TermElabM Syntax | [], ([], rhs)::_ => pure rhs -- nothing left to match | _, [] => throwError ref "non-exhaustive 'match_syntax'" -| discr::discrs, alts => - match alts.findSome? altNextNode? with - -- at least one conditional pattern: introduce an `if` for it and recurse - | some node => do - let shape := nodeShape node; - -- introduce pattern matches on the discriminant's children - newDiscrs ← (List.range node.getArgs.size).mapM $ fun i => `(Syntax.getArg discr $(quote i)); - -- collect matching alternatives and explode them - let yesAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n == shape | none => true; - yesAlts ← yesAlts.mapM $ explodeHeadPat node.getArgs.size; - -- non-matching alternatives are left as-is - -- NOTE: unconditional patterns must go into both `yesAlts` and `noAlts` - let noAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n != shape | none => true; - -- NOTE: use fresh macro scopes for recursion so that different `discr`s introduced by the quotation below do not collide - yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts; - no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts; - `(let discr := $discr; if Syntax.isOfKind discr $(quote shape.fst) && Array.size (Syntax.getArgs discr) == $(quote shape.snd) then $yes else $no) - -- only unconditional patterns: introduce binds and discard patterns - | none => do - alts ← alts.mapM $ explodeHeadPat 0; - res ← withFreshMacroScope $ compileStxMatch discrs alts; - `(let discr := $discr; $res) +| discr::discrs, alts => do + let alts := (alts.map getHeadInfo).zip alts; + -- Choose a most specific pattern, ie. a minimal element according to `generalizes`. + -- If there are multiple minimal elements, the choice does not matter. + let (info, alt) := alts.tail!.foldl (fun (min : HeadInfo × Alt) (alt : HeadInfo × Alt) => if min.1.generalizes alt.1 then alt else min) alts.head!; + -- introduce pattern matches on the discriminant's children if there are any nested patterns + newDiscrs ← match info.argPats with + | some pats => (List.range pats.size).mapM $ fun i => `(Syntax.getArg discr $(quote i)) + | none => pure []; + -- collect matching alternatives and explode them + let yesAlts := alts.filter $ fun (alt : HeadInfo × Alt) => alt.1.generalizes info; + yesAlts ← yesAlts.mapM $ explodeHeadPat newDiscrs.length; + -- NOTE: use fresh macro scopes for recursive call so that different `discr`s introduced by the quotations below do not collide + yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts; + some kind ← pure info.kind + -- unconditional match step + | `(let discr := $discr; $yes); + -- conditional match step + let noAlts := (alts.filter $ fun (alt : HeadInfo × Alt) => !info.generalizes alt.1).map Prod.snd; + no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts; + cond ← match info.argPats with + | some pats => `(Syntax.isOfKind discr $(quote kind) && Array.size (Syntax.getArgs discr) == $(quote pats.size)) + | none => `(Syntax.isOfKind discr $(quote kind)); + `(let discr := $discr; if $cond then $yes else $no) | _, _ => unreachable! private partial def getAntiquotVarsAux : Syntax → TermElabM (List Syntax) -| Syntax.node `Lean.Parser.Term.antiquot args => - let anti := args.get! 1; - if anti.isOfKind `Lean.Parser.Term.id then pure [anti] - else throwError anti "syntax_match: antiquotation must be variable" -| Syntax.node k args => do - List.join <$> args.toList.mapM getAntiquotVarsAux +| stx@(Syntax.node k args) => do + if isAntiquot stx then + let anti := args.get! 1; + if anti.isOfKind `Lean.Parser.Term.id then pure [anti] + else throwError anti "syntax_match: antiquotation must be variable" + else + List.join <$> args.toList.mapM getAntiquotVarsAux | _ => pure [] -- Get all antiquotations (as Term.id nodes) in `stx`