From e8944fcf9dc0f5509091f54cddfa3e86a1c87c55 Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Thu, 12 Dec 2019 14:43:14 +0100 Subject: [PATCH] feat: implement match_syntax --- src/Init/Data/List/Basic.lean | 6 ++ src/Init/Lean/Elab/Quotation.lean | 112 ++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+) diff --git a/src/Init/Data/List/Basic.lean b/src/Init/Data/List/Basic.lean index 05a84824a3..e3b7f3f571 100644 --- a/src/Init/Data/List/Basic.lean +++ b/src/Init/Data/List/Basic.lean @@ -159,6 +159,12 @@ def find (p : α → Bool) : List α → Option α | true => some a | false => find as +def findSome? (f : α → Option β) : List α → Option β +| [] => none +| a::as => match f a with + | some b => some b + | none => findSome? as + def elem [HasBeq α] (a : α) : List α → Bool | [] => false | b::bs => match a == b with diff --git a/src/Init/Lean/Elab/Quotation.lean b/src/Init/Lean/Elab/Quotation.lean index 5c48c70083..efdff03890 100644 --- a/src/Init/Lean/Elab/Quotation.lean +++ b/src/Init/Lean/Elab/Quotation.lean @@ -82,6 +82,118 @@ fun stx expectedType? => do stx ← stxQuot.expand env (stx.getArg 1); elabTerm stx expectedType? +private abbrev Alt := List Syntax × Syntax + +private def isVarPat? (pat : Syntax) : Option (Syntax → TermElabM Syntax) := +if pat.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(%%rhs discr) +else if pat.isOfKind `Lean.Parser.Term.hole then some pure +else if pat.isOfKind `Lean.Parser.Term.stxQuot then + let quoted := pat.getArg 1; + if quoted.isAtom then some pure + else if quoted.isOfKind `Lean.Parser.Term.antiquot then + let anti := quoted.getArg 1; + if anti.isOfKind `Lean.Parser.Term.id then some $ fun rhs => `(%%rhs discr) + -- TODO: *, ? + else unreachable! + else none +else none + +private def altNextNode? : Alt → Option SyntaxNode +| (pat::_, _) => + if (isVarPat? pat).isNone && pat.isOfKind `Lean.Parser.Term.stxQuot then + let quoted := pat.getArg 1; + some quoted.asNode + else none +| _ => none + +private def explodeHeadPat (numArgs : Nat) : Alt → TermElabM Alt +| (pat::pats, rhs) => match isVarPat? pat with + | some fnRhs => do + newPat ← `(_); + let newPats := List.replicate numArgs newPat; + rhs ← fnRhs rhs; + pure (newPats ++ pats, rhs) + | none => + if pat.isOfKind `Lean.Parser.Term.stxQuot then do + let quoted := pat.getArg 1; + let newPats := quoted.getArgs.toList.map $ fun arg => Syntax.node `Lean.Parser.Term.stxQuot #[mkAtom "`(", arg, mkAtom ")"]; + pure (newPats ++ pats, rhs) + else throwError pat $ "unsupported `syntax_match` pattern kind " ++ toString pat.getKind +| _ => unreachable! + +private def nodeShape (n : SyntaxNode) : SyntaxNodeKind × Nat := +(n.getKind, n.getArgs.size) + +private partial def compileStxMatch (ref : Syntax) : List Syntax → List Alt → TermElabM Syntax +| [], ([], rhs)::_ => pure rhs +| _, [] => throwError ref "non-exhaustive 'match_syntax'" +| discr::discrs, alts => + match alts.findSome? altNextNode? with + | some node => do + let shape := nodeShape node; + newDiscrs ← (List.range node.getArgs.size).mapM $ fun i => `(Lean.Syntax.getArg discr %%(Lean.HasQuote.quote i)); + let yesAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n == shape | none => true; + yesAlts ← yesAlts.mapM $ explodeHeadPat node.getArgs.size; + let noAlts := alts.filter $ fun alt => match altNextNode? alt with some n => nodeShape n != shape | none => true; + yes ← withFreshMacroScope $ compileStxMatch (newDiscrs ++ discrs) yesAlts; + no ← withFreshMacroScope $ compileStxMatch (discr::discrs) noAlts; + `(let discr := %%discr; if Lean.Syntax.isOfKind discr %%(Lean.HasQuote.quote (Prod.fst shape)) then %%yes else %%no) + | none => do + alts ← alts.mapM $ explodeHeadPat 0; + res ← withFreshMacroScope $ compileStxMatch discrs alts; + `(let discr := %%discr; %%res) +--| _, _ => unreachable! +| discrs, alts => throwError ref $ toString (discrs, alts) + +private partial def getAntiquotVarsAux : Syntax → TermElabM (List Syntax) +| Syntax.node `Lean.Parser.Term.antiquot args => + let anti := args.get! 1; + if anti.isOfKind `Lean.Parser.Term.id then pure [anti] + else throwError anti "syntax_match: antiquotation must be variable" +| Syntax.node k args => do + List.join <$> args.toList.mapM getAntiquotVarsAux +| _ => pure [] + +private partial def getAntiquotVars (stx : Syntax) : TermElabM (List Syntax) := +if stx.isOfKind `Lean.Parser.Term.stxQuot then do + let quoted := stx.getArg 1; + getAntiquotVarsAux stx +else pure [] + +private def letBindRhss (cont : List Alt → TermElabM Syntax) : List Alt → List Alt → TermElabM Syntax +| [], altsRev' => cont altsRev'.reverse +| (pats, rhs)::alts, altsRev' => do + vars ← List.join <$> pats.mapM getAntiquotVars; + match vars with + | [] => do + rhs' ← `(rhs ()); + cont ← withFreshMacroScope $ letBindRhss alts ((pats, rhs')::altsRev'); + `(let rhs := fun _ => %%rhs; %%cont) + | _ => do + -- rhs ← `(fun %%vars... => %%rhs) + let rhs := Syntax.node `Lean.Parser.Term.fun #[mkAtom "fun", Syntax.node `null vars.toArray, mkAtom "=>", rhs]; + -- rhs' ← `(rhs %%vars...); + rhs' ← `(rhs); + cont ← withFreshMacroScope $ letBindRhss alts ((pats, rhs')::altsRev'); + `(let rhs := %%rhs; %%cont) + +def match_syntax.expand (stx : SyntaxNode) : TermElabM Syntax := do +let discr := stx.getArg 1; +let alts := stx.getArg 3; +alts ← alts.getArgs.mapM $ fun alt => do { + let pats := alt.getArg 1; + pat ← if pats.getArgs.size == 1 then pure $ pats.getArg 0 + else throwError stx.val "syntax_match: expected exactly one pattern per alternative"; + let rhs := alt.getArg 3; + pure ([pat], rhs) +}; +letBindRhss (compileStxMatch stx.val [discr]) alts.toList [] + +@[builtinTermElab «match_syntax»] def elabMatchSyntax : TermElab := +fun stx expectedType? => do + stx ← match_syntax.expand stx; + elabTerm stx expectedType? + -- REMOVE with old frontend private def exprPlaceholder := mkMVar Name.anonymous