feat: refactor match_syntax compiler to properly match quotation kinds, which can act as both "variable" and "constructor" patterns simultaneously

This commit is contained in:
Sebastian Ullrich 2019-12-22 00:02:39 +01:00 committed by Leonardo de Moura
parent 9bf8c96502
commit fe9bd200da

View file

@ -59,6 +59,7 @@ instance Array.HasQuote {α : Type} [HasQuote α] : HasQuote (Array α) :=
namespace Elab
namespace Term
-- antiquotation node kinds are formed from the original node kind (if any) plus "antiquot"
def isAntiquot : Syntax → Bool
| Syntax.node (Name.str _ "antiquot" _) _ => true
| _ => false
@ -141,96 +142,124 @@ fun stx expectedType? => do
-- an "alternative" of patterns plus right-hand side
private abbrev Alt := List Syntax × Syntax
-- If `pat` is an unconditional pattern, return a transformation of the RHS that appropriately introduces
-- bindings on the RHS.
private def isVarPat? (pat : Syntax) : Option (Syntax → TermElabM Syntax) :=
-- TODO: reimplement using match_syntax
if pat.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(let $pat := discr; $rhs)
else if pat.isOfKind `Lean.Parser.Term.hole then some pure
/-- Information on a pattern's head that influences the compilation of a single
match step. -/
structure HeadInfo :=
-- Node kind to match, if any
(kind : Option SyntaxNodeKind := none)
-- Nested patterns for each argument, if any. In a single match step, we only
-- check that the arity matches. The arity is usually implied by the node kind,
-- but not in the case of `many` nodes.
(argPats : Option (Array Syntax) := none)
-- Function to apply to the right-hand side in case the match succeeds. Used to
-- bind pattern variables.
(rhsFn : Syntax → TermElabM Syntax := pure)
instance HeadInfo.Inhabited : Inhabited HeadInfo := ⟨{}⟩
/-- `h1.generalizes h2` iff h1 is equal to or more general than h2, i.e. it matches all nodes
h2 matches. This induces a partial ordering. -/
def HeadInfo.generalizes : HeadInfo → HeadInfo → Bool
| { kind := none, .. }, _ => true
| { kind := some k1, argPats := none, .. },
{ kind := some k2, .. } => k1 == k2
| { kind := some k1, argPats := some ps1, .. },
{ kind := some k2, argPats := some ps2, .. } => k1 == k2 && ps1.size == ps2.size
| _, _ => false
private def getHeadInfo (alt : Alt) : HeadInfo :=
let pat := alt.fst.head!;
let unconditional (rhsFn) := { HeadInfo . rhsFn := rhsFn };
-- variable pattern
if pat.isOfKind `Lean.Parser.Term.id then unconditional $ fun rhs => `(let $pat := discr; $rhs)
-- wildcard pattern
else if pat.isOfKind `Lean.Parser.Term.hole then unconditional pure
-- quotation pattern
else if pat.isOfKind `Lean.Parser.Term.stxQuot then
let quoted := pat.getArg 1;
-- We assume that atoms are uniquely determined by the surrounding node and never have to be checked
if quoted.isAtom then some pure
-- TODO: antiquotations with kinds (`$id:id`) probably can't be handled as unconditional patterns
else if isAntiquot quoted then
match quoted with
-- We assume that atoms are uniquely determined by the node kind and never have to be checked
| Syntax.atom _ _ => unconditional pure
-- quotation is a single antiquotation
| Syntax.node (Name.str k "antiquot" _) _ =>
-- Antiquotation kinds like `$id:id` influence the parser, but also need to be considered by
-- match_syntax (but not by quotation terms). For example, `($id:id) and `($e) are not
-- distinguishable without checking the kind of the node to be captured. Note that some
-- antiquotations like the latter one for terms do not correspond to any actual node kind
-- (signified by `k == Name.anonymous`), so we would only check for `Term.id` here.
--
-- if stx.isOfKind `Lean.Parser.Term.id then
-- let id := stx; ...
-- else
-- let e := stx; ...
let kind := if k == Name.anonymous then none else some k;
let anti := quoted.getArg 1;
if isAntiquotSplice quoted then some $ fun _ => throwError quoted "unexpected antiquotation splice"
else if anti.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(let $anti := discr; $rhs)
else unreachable!
else if isAntiquotSplicePat quoted then
let anti := (quoted.getArg 0).getArg 1;
some $ fun rhs => `(let $anti := Syntax.getArgs discr; $rhs)
else none
else none
-- If the first pattern of the alternative is a conditional pattern, return the node we should match against
private def altNextNode? : Alt → Option SyntaxNode
| (pat::_, _) =>
if (isVarPat? pat).isNone && pat.isOfKind `Lean.Parser.Term.stxQuot then
let quoted := pat.getArg 1;
some quoted.asNode
else none
| _ => none
-- Splices should only appear inside a nullKind node, see next case
if isAntiquotSplice quoted then unconditional $ fun _ => throwError quoted "unexpected antiquotation splice"
else if anti.isOfKind `Lean.Parser.Term.id then { kind := kind, rhsFn := fun rhs => `(let $anti := discr; $rhs) }
else unconditional $ fun _ => throwError anti "syntax_match: antiquotation must be variable"
| _ =>
-- quotation is a single antiquotation splice => bind args array
if isAntiquotSplicePat quoted then
let anti := (quoted.getArg 0).getArg 1;
unconditional $ fun rhs => `(let $anti := Syntax.getArgs discr; $rhs)
else
-- not an antiquotation: match head shape
let argPats := quoted.getArgs.map $ fun arg => Syntax.node `Lean.Parser.Term.stxQuot #[mkAtom "`(", arg, mkAtom ")"];
{ kind := quoted.getKind, argPats := argPats }
else
unconditional $ fun _ => throwError pat "syntax_match: unexpected pattern kind"
-- Assuming that the first pattern of the alternative is taken, replace it with patterns (if any) for its
-- child nodes.
-- Ex: `($a + (- $b)) => `($a), `(+), `(- $b)
-- Note: The atom pattern `(+) will be discarded in a later step
private def explodeHeadPat (numArgs : Nat) : Alt → TermElabM Alt
| (pat::pats, rhs) => match isVarPat? pat with
| some fnRhs => do
-- unconditional pattern: replace with appropriate number of wildcards
newPat ← `(_);
let newPats := List.replicate numArgs newPat;
rhs ← fnRhs rhs;
private def explodeHeadPat (numArgs : Nat) : HeadInfo × Alt → TermElabM Alt
| (info, (pat::pats, rhs)) => do
let newPats := match info.argPats with
| some argPats => argPats.toList
| none => List.replicate numArgs $ Unhygienic.run `(_);
rhs ← info.rhsFn rhs;
pure (newPats ++ pats, rhs)
| none =>
if pat.isOfKind `Lean.Parser.Term.stxQuot then do
let quoted := pat.getArg 1;
let newPats := quoted.getArgs.toList.map $ fun arg => Syntax.node `Lean.Parser.Term.stxQuot #[mkAtom "`(", arg, mkAtom ")"];
pure (newPats ++ pats, rhs)
else throwError pat $ "unsupported `syntax_match` pattern kind " ++ toString pat.getKind
| _ => unreachable!
-- The "shape" is the information that should be compared in a single matching step. Currently, it is the node kind
-- and its arity (which is not constant in the case of `many` nodes)
private def nodeShape (n : SyntaxNode) : SyntaxNodeKind × Nat :=
(n.getKind, n.getArgs.size)
private partial def compileStxMatch (ref : Syntax) : List Syntax → List Alt → TermElabM Syntax
| [], ([], rhs)::_ => pure rhs -- nothing left to match
| _, [] => throwError ref "non-exhaustive 'match_syntax'"
| discr::discrs, alts =>
match alts.findSome? altNextNode? with
-- at least one conditional pattern: introduce an `if` for it and recurse
| some node => do
let shape := nodeShape node;
-- introduce pattern matches on the discriminant's children
newDiscrs ← (List.range node.getArgs.size).mapM $ fun i => `(Syntax.getArg discr $(quote i));
-- collect matching alternatives and explode them
let yesAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n == shape | none => true;
yesAlts ← yesAlts.mapM $ explodeHeadPat node.getArgs.size;
-- non-matching alternatives are left as-is
-- NOTE: unconditional patterns must go into both `yesAlts` and `noAlts`
let noAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n != shape | none => true;
-- NOTE: use fresh macro scopes for recursion so that different `discr`s introduced by the quotation below do not collide
yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts;
no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts;
`(let discr := $discr; if Syntax.isOfKind discr $(quote shape.fst) && Array.size (Syntax.getArgs discr) == $(quote shape.snd) then $yes else $no)
-- only unconditional patterns: introduce binds and discard patterns
| none => do
alts ← alts.mapM $ explodeHeadPat 0;
res ← withFreshMacroScope $ compileStxMatch discrs alts;
`(let discr := $discr; $res)
| discr::discrs, alts => do
let alts := (alts.map getHeadInfo).zip alts;
-- Choose a most specific pattern, ie. a minimal element according to `generalizes`.
-- If there are multiple minimal elements, the choice does not matter.
let (info, alt) := alts.tail!.foldl (fun (min : HeadInfo × Alt) (alt : HeadInfo × Alt) => if min.1.generalizes alt.1 then alt else min) alts.head!;
-- introduce pattern matches on the discriminant's children if there are any nested patterns
newDiscrs ← match info.argPats with
| some pats => (List.range pats.size).mapM $ fun i => `(Syntax.getArg discr $(quote i))
| none => pure [];
-- collect matching alternatives and explode them
let yesAlts := alts.filter $ fun (alt : HeadInfo × Alt) => alt.1.generalizes info;
yesAlts ← yesAlts.mapM $ explodeHeadPat newDiscrs.length;
-- NOTE: use fresh macro scopes for recursive call so that different `discr`s introduced by the quotations below do not collide
yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts;
some kind ← pure info.kind
-- unconditional match step
| `(let discr := $discr; $yes);
-- conditional match step
let noAlts := (alts.filter $ fun (alt : HeadInfo × Alt) => !info.generalizes alt.1).map Prod.snd;
no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts;
cond ← match info.argPats with
| some pats => `(Syntax.isOfKind discr $(quote kind) && Array.size (Syntax.getArgs discr) == $(quote pats.size))
| none => `(Syntax.isOfKind discr $(quote kind));
`(let discr := $discr; if $cond then $yes else $no)
| _, _ => unreachable!
private partial def getAntiquotVarsAux : Syntax → TermElabM (List Syntax)
| Syntax.node `Lean.Parser.Term.antiquot args =>
let anti := args.get! 1;
if anti.isOfKind `Lean.Parser.Term.id then pure [anti]
else throwError anti "syntax_match: antiquotation must be variable"
| Syntax.node k args => do
List.join <$> args.toList.mapM getAntiquotVarsAux
| stx@(Syntax.node k args) => do
if isAntiquot stx then
let anti := args.get! 1;
if anti.isOfKind `Lean.Parser.Term.id then pure [anti]
else throwError anti "syntax_match: antiquotation must be variable"
else
List.join <$> args.toList.mapM getAntiquotVarsAux
| _ => pure []
-- Get all antiquotations (as Term.id nodes) in `stx`