feat: implement match_syntax
This commit is contained in:
parent
53276e99dc
commit
e8944fcf9d
2 changed files with 118 additions and 0 deletions
|
|
@ -159,6 +159,12 @@ def find (p : α → Bool) : List α → Option α
|
|||
| true => some a
|
||||
| false => find as
|
||||
|
||||
def findSome? (f : α → Option β) : List α → Option β
|
||||
| [] => none
|
||||
| a::as => match f a with
|
||||
| some b => some b
|
||||
| none => findSome? as
|
||||
|
||||
def elem [HasBeq α] (a : α) : List α → Bool
|
||||
| [] => false
|
||||
| b::bs => match a == b with
|
||||
|
|
|
|||
|
|
@ -82,6 +82,118 @@ fun stx expectedType? => do
|
|||
stx ← stxQuot.expand env (stx.getArg 1);
|
||||
elabTerm stx expectedType?
|
||||
|
||||
private abbrev Alt := List Syntax × Syntax
|
||||
|
||||
private def isVarPat? (pat : Syntax) : Option (Syntax → TermElabM Syntax) :=
|
||||
if pat.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(%%rhs discr)
|
||||
else if pat.isOfKind `Lean.Parser.Term.hole then some pure
|
||||
else if pat.isOfKind `Lean.Parser.Term.stxQuot then
|
||||
let quoted := pat.getArg 1;
|
||||
if quoted.isAtom then some pure
|
||||
else if quoted.isOfKind `Lean.Parser.Term.antiquot then
|
||||
let anti := quoted.getArg 1;
|
||||
if anti.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(%%rhs discr)
|
||||
-- TODO: *, ?
|
||||
else unreachable!
|
||||
else none
|
||||
else none
|
||||
|
||||
private def altNextNode? : Alt → Option SyntaxNode
|
||||
| (pat::_, _) =>
|
||||
if (isVarPat? pat).isNone && pat.isOfKind `Lean.Parser.Term.stxQuot then
|
||||
let quoted := pat.getArg 1;
|
||||
some quoted.asNode
|
||||
else none
|
||||
| _ => none
|
||||
|
||||
private def explodeHeadPat (numArgs : Nat) : Alt → TermElabM Alt
|
||||
| (pat::pats, rhs) => match isVarPat? pat with
|
||||
| some fnRhs => do
|
||||
newPat ← `(_);
|
||||
let newPats := List.replicate numArgs newPat;
|
||||
rhs ← fnRhs rhs;
|
||||
pure (newPats ++ pats, rhs)
|
||||
| none =>
|
||||
if pat.isOfKind `Lean.Parser.Term.stxQuot then do
|
||||
let quoted := pat.getArg 1;
|
||||
let newPats := quoted.getArgs.toList.map $ fun arg => Syntax.node `Lean.Parser.Term.stxQuot #[mkAtom "`(", arg, mkAtom ")"];
|
||||
pure (newPats ++ pats, rhs)
|
||||
else throwError pat $ "unsupported `syntax_match` pattern kind " ++ toString pat.getKind
|
||||
| _ => unreachable!
|
||||
|
||||
private def nodeShape (n : SyntaxNode) : SyntaxNodeKind × Nat :=
|
||||
(n.getKind, n.getArgs.size)
|
||||
|
||||
private partial def compileStxMatch (ref : Syntax) : List Syntax → List Alt → TermElabM Syntax
|
||||
| [], ([], rhs)::_ => pure rhs
|
||||
| _, [] => throwError ref "non-exhaustive 'match_syntax'"
|
||||
| discr::discrs, alts =>
|
||||
match alts.findSome? altNextNode? with
|
||||
| some node => do
|
||||
let shape := nodeShape node;
|
||||
newDiscrs ← (List.range node.getArgs.size).mapM $ fun i => `(Lean.Syntax.getArg discr %%(Lean.HasQuote.quote i));
|
||||
let yesAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n == shape | none => true;
|
||||
yesAlts ← yesAlts.mapM $ explodeHeadPat node.getArgs.size;
|
||||
let noAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n != shape | none => true;
|
||||
yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts;
|
||||
no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts;
|
||||
`(let discr := %%discr; if Lean.Syntax.isOfKind discr %%(Lean.HasQuote.quote (Prod.fst shape)) then %%yes else %%no)
|
||||
| none => do
|
||||
alts ← alts.mapM $ explodeHeadPat 0;
|
||||
res ← withFreshMacroScope $ compileStxMatch discrs alts;
|
||||
`(let discr := %%discr; %%res)
|
||||
--| _, _ => unreachable!
|
||||
| discrs, alts => throwError ref $ toString (discrs, alts)
|
||||
|
||||
private partial def getAntiquotVarsAux : Syntax → TermElabM (List Syntax)
|
||||
| Syntax.node `Lean.Parser.Term.antiquot args =>
|
||||
let anti := args.get! 1;
|
||||
if anti.isOfKind `Lean.Parser.Term.id then pure [anti]
|
||||
else throwError anti "syntax_match: antiquotation must be variable"
|
||||
| Syntax.node k args => do
|
||||
List.join <$> args.toList.mapM getAntiquotVarsAux
|
||||
| _ => pure []
|
||||
|
||||
private partial def getAntiquotVars (stx : Syntax) : TermElabM (List Syntax) :=
|
||||
if stx.isOfKind `Lean.Parser.Term.stxQuot then do
|
||||
let quoted := stx.getArg 1;
|
||||
getAntiquotVarsAux stx
|
||||
else pure []
|
||||
|
||||
private def letBindRhss (cont : List Alt → TermElabM Syntax) : List Alt → List Alt → TermElabM Syntax
|
||||
| [], altsRev' => cont altsRev'.reverse
|
||||
| (pats, rhs)::alts, altsRev' => do
|
||||
vars ← List.join <$> pats.mapM getAntiquotVars;
|
||||
match vars with
|
||||
| [] => do
|
||||
rhs' ← `(rhs ());
|
||||
cont ← withFreshMacroScope $ letBindRhss alts ((pats, rhs')::altsRev');
|
||||
`(let rhs := fun _ => %%rhs; %%cont)
|
||||
| _ => do
|
||||
-- rhs ← `(fun %%vars... => %%rhs)
|
||||
let rhs := Syntax.node `Lean.Parser.Term.fun #[mkAtom "fun", Syntax.node `null vars.toArray, mkAtom "=>", rhs];
|
||||
-- rhs' ← `(rhs %%vars...);
|
||||
rhs' ← `(rhs);
|
||||
cont ← withFreshMacroScope $ letBindRhss alts ((pats, rhs')::altsRev');
|
||||
`(let rhs := %%rhs; %%cont)
|
||||
|
||||
def match_syntax.expand (stx : SyntaxNode) : TermElabM Syntax := do
|
||||
let discr := stx.getArg 1;
|
||||
let alts := stx.getArg 3;
|
||||
alts ← alts.getArgs.mapM $ fun alt => do {
|
||||
let pats := alt.getArg 1;
|
||||
pat ← if pats.getArgs.size == 1 then pure $ pats.getArg 0
|
||||
else throwError stx.val "syntax_match: expected exactly one pattern per alternative";
|
||||
let rhs := alt.getArg 3;
|
||||
pure ([pat], rhs)
|
||||
};
|
||||
letBindRhss (compileStxMatch stx.val [discr]) alts.toList []
|
||||
|
||||
@[builtinTermElab «match_syntax»] def elabMatchSyntax : TermElab :=
|
||||
fun stx expectedType? => do
|
||||
stx ← match_syntax.expand stx;
|
||||
elabTerm stx expectedType?
|
||||
|
||||
-- REMOVE with old frontend
|
||||
private def exprPlaceholder := mkMVar Name.anonymous
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue