feat: implement match_syntax

This commit is contained in:
Sebastian Ullrich 2019-12-12 14:43:14 +01:00 committed by Leonardo de Moura
parent 53276e99dc
commit e8944fcf9d
2 changed files with 118 additions and 0 deletions

View file

@ -159,6 +159,12 @@ def find (p : α → Bool) : List α → Option α
| true => some a
| false => find as
def findSome? (f : α → Option β) : List α → Option β
| [] => none
| a::as => match f a with
| some b => some b
| none => findSome? as
def elem [HasBeq α] (a : α) : List α → Bool
| [] => false
| b::bs => match a == b with

View file

@ -82,6 +82,118 @@ fun stx expectedType? => do
stx ← stxQuot.expand env (stx.getArg 1);
elabTerm stx expectedType?
private abbrev Alt := List Syntax × Syntax
private def isVarPat? (pat : Syntax) : Option (Syntax → TermElabM Syntax) :=
if pat.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(%%rhs discr)
else if pat.isOfKind `Lean.Parser.Term.hole then some pure
else if pat.isOfKind `Lean.Parser.Term.stxQuot then
let quoted := pat.getArg 1;
if quoted.isAtom then some pure
else if quoted.isOfKind `Lean.Parser.Term.antiquot then
let anti := quoted.getArg 1;
if anti.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(%%rhs discr)
-- TODO: *, ?
else unreachable!
else none
else none
private def altNextNode? : Alt → Option SyntaxNode
| (pat::_, _) =>
if (isVarPat? pat).isNone && pat.isOfKind `Lean.Parser.Term.stxQuot then
let quoted := pat.getArg 1;
some quoted.asNode
else none
| _ => none
private def explodeHeadPat (numArgs : Nat) : Alt → TermElabM Alt
| (pat::pats, rhs) => match isVarPat? pat with
| some fnRhs => do
newPat ← `(_);
let newPats := List.replicate numArgs newPat;
rhs ← fnRhs rhs;
pure (newPats ++ pats, rhs)
| none =>
if pat.isOfKind `Lean.Parser.Term.stxQuot then do
let quoted := pat.getArg 1;
let newPats := quoted.getArgs.toList.map $ fun arg => Syntax.node `Lean.Parser.Term.stxQuot #[mkAtom "`(", arg, mkAtom ")"];
pure (newPats ++ pats, rhs)
else throwError pat $ "unsupported `syntax_match` pattern kind " ++ toString pat.getKind
| _ => unreachable!
private def nodeShape (n : SyntaxNode) : SyntaxNodeKind × Nat :=
(n.getKind, n.getArgs.size)
private partial def compileStxMatch (ref : Syntax) : List Syntax → List Alt → TermElabM Syntax
| [], ([], rhs)::_ => pure rhs
| _, [] => throwError ref "non-exhaustive 'match_syntax'"
| discr::discrs, alts =>
match alts.findSome? altNextNode? with
| some node => do
let shape := nodeShape node;
newDiscrs ← (List.range node.getArgs.size).mapM $ fun i => `(Lean.Syntax.getArg discr %%(Lean.HasQuote.quote i));
let yesAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n == shape | none => true;
yesAlts ← yesAlts.mapM $ explodeHeadPat node.getArgs.size;
let noAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n != shape | none => true;
yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts;
no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts;
`(let discr := %%discr; if Lean.Syntax.isOfKind discr %%(Lean.HasQuote.quote (Prod.fst shape)) then %%yes else %%no)
| none => do
alts ← alts.mapM $ explodeHeadPat 0;
res ← withFreshMacroScope $ compileStxMatch discrs alts;
`(let discr := %%discr; %%res)
--| _, _ => unreachable!
| discrs, alts => throwError ref $ toString (discrs, alts)
private partial def getAntiquotVarsAux : Syntax → TermElabM (List Syntax)
| Syntax.node `Lean.Parser.Term.antiquot args =>
let anti := args.get! 1;
if anti.isOfKind `Lean.Parser.Term.id then pure [anti]
else throwError anti "syntax_match: antiquotation must be variable"
| Syntax.node k args => do
List.join <$> args.toList.mapM getAntiquotVarsAux
| _ => pure []
private partial def getAntiquotVars (stx : Syntax) : TermElabM (List Syntax) :=
if stx.isOfKind `Lean.Parser.Term.stxQuot then do
let quoted := stx.getArg 1;
getAntiquotVarsAux stx
else pure []
private def letBindRhss (cont : List Alt → TermElabM Syntax) : List Alt → List Alt → TermElabM Syntax
| [], altsRev' => cont altsRev'.reverse
| (pats, rhs)::alts, altsRev' => do
vars ← List.join <$> pats.mapM getAntiquotVars;
match vars with
| [] => do
rhs' ← `(rhs ());
cont ← withFreshMacroScope $ letBindRhss alts ((pats, rhs')::altsRev');
`(let rhs := fun _ => %%rhs; %%cont)
| _ => do
-- rhs ← `(fun %%vars... => %%rhs)
let rhs := Syntax.node `Lean.Parser.Term.fun #[mkAtom "fun", Syntax.node `null vars.toArray, mkAtom "=>", rhs];
-- rhs' ← `(rhs %%vars...);
rhs' ← `(rhs);
cont ← withFreshMacroScope $ letBindRhss alts ((pats, rhs')::altsRev');
`(let rhs := %%rhs; %%cont)
def match_syntax.expand (stx : SyntaxNode) : TermElabM Syntax := do
let discr := stx.getArg 1;
let alts := stx.getArg 3;
alts ← alts.getArgs.mapM $ fun alt => do {
let pats := alt.getArg 1;
pat ← if pats.getArgs.size == 1 then pure $ pats.getArg 0
else throwError stx.val "syntax_match: expected exactly one pattern per alternative";
let rhs := alt.getArg 3;
pure ([pat], rhs)
};
letBindRhss (compileStxMatch stx.val [discr]) alts.toList []
@[builtinTermElab «match_syntax»] def elabMatchSyntax : TermElab :=
fun stx expectedType? => do
stx ← match_syntax.expand stx;
elabTerm stx expectedType?
-- REMOVE with old frontend
private def exprPlaceholder := mkMVar Name.anonymous