feat: refactor match_syntax compiler to properly match quotation kinds, which can act as both "variable" and "constructor" patterns simultaneously
This commit is contained in:
parent
9bf8c96502
commit
fe9bd200da
1 changed files with 102 additions and 73 deletions
|
|
@ -59,6 +59,7 @@ instance Array.HasQuote {α : Type} [HasQuote α] : HasQuote (Array α) :=
|
|||
namespace Elab
|
||||
namespace Term
|
||||
|
||||
-- antiquotation node kinds are formed from the original node kind (if any) plus "antiquot"
|
||||
def isAntiquot : Syntax → Bool
|
||||
| Syntax.node (Name.str _ "antiquot" _) _ => true
|
||||
| _ => false
|
||||
|
|
@ -141,96 +142,124 @@ fun stx expectedType? => do
|
|||
-- an "alternative" of patterns plus right-hand side
|
||||
private abbrev Alt := List Syntax × Syntax
|
||||
|
||||
-- If `pat` is an unconditional pattern, return a transformation of the RHS that appropriately introduces
|
||||
-- bindings on the RHS.
|
||||
private def isVarPat? (pat : Syntax) : Option (Syntax → TermElabM Syntax) :=
|
||||
-- TODO: reimplement using match_syntax
|
||||
if pat.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(let $pat := discr; $rhs)
|
||||
else if pat.isOfKind `Lean.Parser.Term.hole then some pure
|
||||
/-- Information on a pattern's head that influences the compilation of a single
|
||||
match step. -/
|
||||
structure HeadInfo :=
|
||||
-- Node kind to match, if any
|
||||
(kind : Option SyntaxNodeKind := none)
|
||||
-- Nested patterns for each argument, if any. In a single match step, we only
|
||||
-- check that the arity matches. The arity is usually implied by the node kind,
|
||||
-- but not in the case of `many` nodes.
|
||||
(argPats : Option (Array Syntax) := none)
|
||||
-- Function to apply to the right-hand side in case the match succeeds. Used to
|
||||
-- bind pattern variables.
|
||||
(rhsFn : Syntax → TermElabM Syntax := pure)
|
||||
|
||||
instance HeadInfo.Inhabited : Inhabited HeadInfo := ⟨{}⟩
|
||||
|
||||
/-- `h1.generalizes h2` iff h1 is equal to or more general than h2, i.e. it matches all nodes
|
||||
h2 matches. This induces a partial ordering. -/
|
||||
def HeadInfo.generalizes : HeadInfo → HeadInfo → Bool
|
||||
| { kind := none, .. }, _ => true
|
||||
| { kind := some k1, argPats := none, .. },
|
||||
{ kind := some k2, .. } => k1 == k2
|
||||
| { kind := some k1, argPats := some ps1, .. },
|
||||
{ kind := some k2, argPats := some ps2, .. } => k1 == k2 && ps1.size == ps2.size
|
||||
| _, _ => false
|
||||
|
||||
private def getHeadInfo (alt : Alt) : HeadInfo :=
|
||||
let pat := alt.fst.head!;
|
||||
let unconditional (rhsFn) := { HeadInfo . rhsFn := rhsFn };
|
||||
-- variable pattern
|
||||
if pat.isOfKind `Lean.Parser.Term.id then unconditional $ fun rhs => `(let $pat := discr; $rhs)
|
||||
-- wildcard pattern
|
||||
else if pat.isOfKind `Lean.Parser.Term.hole then unconditional pure
|
||||
-- quotation pattern
|
||||
else if pat.isOfKind `Lean.Parser.Term.stxQuot then
|
||||
let quoted := pat.getArg 1;
|
||||
-- We assume that atoms are uniquely determined by the surrounding node and never have to be checked
|
||||
if quoted.isAtom then some pure
|
||||
-- TODO: antiquotations with kinds (`$id:id`) probably can't be handled as unconditional patterns
|
||||
else if isAntiquot quoted then
|
||||
match quoted with
|
||||
-- We assume that atoms are uniquely determined by the node kind and never have to be checked
|
||||
| Syntax.atom _ _ => unconditional pure
|
||||
-- quotation is a single antiquotation
|
||||
| Syntax.node (Name.str k "antiquot" _) _ =>
|
||||
-- Antiquotation kinds like `$id:id` influence the parser, but also need to be considered by
|
||||
-- match_syntax (but not by quotation terms). For example, `($id:id) and `($e) are not
|
||||
-- distinguishable without checking the kind of the node to be captured. Note that some
|
||||
-- antiquotations like the latter one for terms do not correspond to any actual node kind
|
||||
-- (signified by `k == Name.anonymous`), so we would only check for `Term.id` here.
|
||||
--
|
||||
-- if stx.isOfKind `Lean.Parser.Term.id then
|
||||
-- let id := stx; ...
|
||||
-- else
|
||||
-- let e := stx; ...
|
||||
let kind := if k == Name.anonymous then none else some k;
|
||||
let anti := quoted.getArg 1;
|
||||
if isAntiquotSplice quoted then some $ fun _ => throwError quoted "unexpected antiquotation splice"
|
||||
else if anti.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(let $anti := discr; $rhs)
|
||||
else unreachable!
|
||||
else if isAntiquotSplicePat quoted then
|
||||
let anti := (quoted.getArg 0).getArg 1;
|
||||
some $ fun rhs => `(let $anti := Syntax.getArgs discr; $rhs)
|
||||
else none
|
||||
else none
|
||||
|
||||
-- If the first pattern of the alternative is a conditional pattern, return the node we should match against
|
||||
private def altNextNode? : Alt → Option SyntaxNode
|
||||
| (pat::_, _) =>
|
||||
if (isVarPat? pat).isNone && pat.isOfKind `Lean.Parser.Term.stxQuot then
|
||||
let quoted := pat.getArg 1;
|
||||
some quoted.asNode
|
||||
else none
|
||||
| _ => none
|
||||
-- Splices should only appear inside a nullKind node, see next case
|
||||
if isAntiquotSplice quoted then unconditional $ fun _ => throwError quoted "unexpected antiquotation splice"
|
||||
else if anti.isOfKind `Lean.Parser.Term.id then { kind := kind, rhsFn := fun rhs => `(let $anti := discr; $rhs) }
|
||||
else unconditional $ fun _ => throwError anti "syntax_match: antiquotation must be variable"
|
||||
| _ =>
|
||||
-- quotation is a single antiquotation splice => bind args array
|
||||
if isAntiquotSplicePat quoted then
|
||||
let anti := (quoted.getArg 0).getArg 1;
|
||||
unconditional $ fun rhs => `(let $anti := Syntax.getArgs discr; $rhs)
|
||||
else
|
||||
-- not an antiquotation: match head shape
|
||||
let argPats := quoted.getArgs.map $ fun arg => Syntax.node `Lean.Parser.Term.stxQuot #[mkAtom "`(", arg, mkAtom ")"];
|
||||
{ kind := quoted.getKind, argPats := argPats }
|
||||
else
|
||||
unconditional $ fun _ => throwError pat "syntax_match: unexpected pattern kind"
|
||||
|
||||
-- Assuming that the first pattern of the alternative is taken, replace it with patterns (if any) for its
|
||||
-- child nodes.
|
||||
-- Ex: `($a + (- $b)) => `($a), `(+), `(- $b)
|
||||
-- Note: The atom pattern `(+) will be discarded in a later step
|
||||
private def explodeHeadPat (numArgs : Nat) : Alt → TermElabM Alt
|
||||
| (pat::pats, rhs) => match isVarPat? pat with
|
||||
| some fnRhs => do
|
||||
-- unconditional pattern: replace with appropriate number of wildcards
|
||||
newPat ← `(_);
|
||||
let newPats := List.replicate numArgs newPat;
|
||||
rhs ← fnRhs rhs;
|
||||
private def explodeHeadPat (numArgs : Nat) : HeadInfo × Alt → TermElabM Alt
|
||||
| (info, (pat::pats, rhs)) => do
|
||||
let newPats := match info.argPats with
|
||||
| some argPats => argPats.toList
|
||||
| none => List.replicate numArgs $ Unhygienic.run `(_);
|
||||
rhs ← info.rhsFn rhs;
|
||||
pure (newPats ++ pats, rhs)
|
||||
| none =>
|
||||
if pat.isOfKind `Lean.Parser.Term.stxQuot then do
|
||||
let quoted := pat.getArg 1;
|
||||
let newPats := quoted.getArgs.toList.map $ fun arg => Syntax.node `Lean.Parser.Term.stxQuot #[mkAtom "`(", arg, mkAtom ")"];
|
||||
pure (newPats ++ pats, rhs)
|
||||
else throwError pat $ "unsupported `syntax_match` pattern kind " ++ toString pat.getKind
|
||||
| _ => unreachable!
|
||||
|
||||
-- The "shape" is the information that should be compared in a single matching step. Currently, it is the node kind
|
||||
-- and its arity (which is not constant in the case of `many` nodes)
|
||||
private def nodeShape (n : SyntaxNode) : SyntaxNodeKind × Nat :=
|
||||
(n.getKind, n.getArgs.size)
|
||||
|
||||
private partial def compileStxMatch (ref : Syntax) : List Syntax → List Alt → TermElabM Syntax
|
||||
| [], ([], rhs)::_ => pure rhs -- nothing left to match
|
||||
| _, [] => throwError ref "non-exhaustive 'match_syntax'"
|
||||
| discr::discrs, alts =>
|
||||
match alts.findSome? altNextNode? with
|
||||
-- at least one conditional pattern: introduce an `if` for it and recurse
|
||||
| some node => do
|
||||
let shape := nodeShape node;
|
||||
-- introduce pattern matches on the discriminant's children
|
||||
newDiscrs ← (List.range node.getArgs.size).mapM $ fun i => `(Syntax.getArg discr $(quote i));
|
||||
-- collect matching alternatives and explode them
|
||||
let yesAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n == shape | none => true;
|
||||
yesAlts ← yesAlts.mapM $ explodeHeadPat node.getArgs.size;
|
||||
-- non-matching alternatives are left as-is
|
||||
-- NOTE: unconditional patterns must go into both `yesAlts` and `noAlts`
|
||||
let noAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n != shape | none => true;
|
||||
-- NOTE: use fresh macro scopes for recursion so that different `discr`s introduced by the quotation below do not collide
|
||||
yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts;
|
||||
no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts;
|
||||
`(let discr := $discr; if Syntax.isOfKind discr $(quote shape.fst) && Array.size (Syntax.getArgs discr) == $(quote shape.snd) then $yes else $no)
|
||||
-- only unconditional patterns: introduce binds and discard patterns
|
||||
| none => do
|
||||
alts ← alts.mapM $ explodeHeadPat 0;
|
||||
res ← withFreshMacroScope $ compileStxMatch discrs alts;
|
||||
`(let discr := $discr; $res)
|
||||
| discr::discrs, alts => do
|
||||
let alts := (alts.map getHeadInfo).zip alts;
|
||||
-- Choose a most specific pattern, ie. a minimal element according to `generalizes`.
|
||||
-- If there are multiple minimal elements, the choice does not matter.
|
||||
let (info, alt) := alts.tail!.foldl (fun (min : HeadInfo × Alt) (alt : HeadInfo × Alt) => if min.1.generalizes alt.1 then alt else min) alts.head!;
|
||||
-- introduce pattern matches on the discriminant's children if there are any nested patterns
|
||||
newDiscrs ← match info.argPats with
|
||||
| some pats => (List.range pats.size).mapM $ fun i => `(Syntax.getArg discr $(quote i))
|
||||
| none => pure [];
|
||||
-- collect matching alternatives and explode them
|
||||
let yesAlts := alts.filter $ fun (alt : HeadInfo × Alt) => alt.1.generalizes info;
|
||||
yesAlts ← yesAlts.mapM $ explodeHeadPat newDiscrs.length;
|
||||
-- NOTE: use fresh macro scopes for recursive call so that different `discr`s introduced by the quotations below do not collide
|
||||
yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts;
|
||||
some kind ← pure info.kind
|
||||
-- unconditional match step
|
||||
| `(let discr := $discr; $yes);
|
||||
-- conditional match step
|
||||
let noAlts := (alts.filter $ fun (alt : HeadInfo × Alt) => !info.generalizes alt.1).map Prod.snd;
|
||||
no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts;
|
||||
cond ← match info.argPats with
|
||||
| some pats => `(Syntax.isOfKind discr $(quote kind) && Array.size (Syntax.getArgs discr) == $(quote pats.size))
|
||||
| none => `(Syntax.isOfKind discr $(quote kind));
|
||||
`(let discr := $discr; if $cond then $yes else $no)
|
||||
| _, _ => unreachable!
|
||||
|
||||
private partial def getAntiquotVarsAux : Syntax → TermElabM (List Syntax)
|
||||
| Syntax.node `Lean.Parser.Term.antiquot args =>
|
||||
let anti := args.get! 1;
|
||||
if anti.isOfKind `Lean.Parser.Term.id then pure [anti]
|
||||
else throwError anti "syntax_match: antiquotation must be variable"
|
||||
| Syntax.node k args => do
|
||||
List.join <$> args.toList.mapM getAntiquotVarsAux
|
||||
| stx@(Syntax.node k args) => do
|
||||
if isAntiquot stx then
|
||||
let anti := args.get! 1;
|
||||
if anti.isOfKind `Lean.Parser.Term.id then pure [anti]
|
||||
else throwError anti "syntax_match: antiquotation must be variable"
|
||||
else
|
||||
List.join <$> args.toList.mapM getAntiquotVarsAux
|
||||
| _ => pure []
|
||||
|
||||
-- Get all antiquotations (as Term.id nodes) in `stx`
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue