diff --git a/src/Lean/Elab/Quotation.lean b/src/Lean/Elab/Quotation.lean index eca6c4b12d..9d83b9032a 100644 --- a/src/Lean/Elab/Quotation.lean +++ b/src/Lean/Elab/Quotation.lean @@ -204,7 +204,7 @@ private def noOpMatchAdaptPats : HeadCheck → Alt → Alt private def adaptRhs (fn : Syntax → TermElabM Syntax) : Alt → TermElabM Alt | (pats, rhs) => do (pats, ← fn rhs) -private def getHeadInfo (alt : Alt) : TermElabM HeadInfo := +private partial def getHeadInfo (alt : Alt) : TermElabM HeadInfo := let pat := alt.fst.head! let unconditionally (rhsFn) := pure { check := unconditional, @@ -317,8 +317,13 @@ private def getHeadInfo (alt : Alt) : TermElabM HeadInfo := `(ite (Eq $cond true) $(← yes newDiscrs) $(← no)) } else match pat with - | `(_) => unconditionally pure - | `($id:ident) => unconditionally (`(let $id := discr; $(·))) + | `(_) => unconditionally pure + | `($id:ident) => unconditionally (`(let $id := discr; $(·))) + | `($id:ident@$pat) => do + let info ← getHeadInfo (pat::alt.1.tail!, alt.2) + { info with onMatch := fun taken => match info.onMatch taken with + | covered f exh => covered (fun alt => f alt >>= adaptRhs (`(let $id := discr; $(·)))) exh + | r => r } | _ => throwErrorAt! pat "match_syntax: unexpected pattern kind {pat}" private partial def compileStxMatch (discrs : List Syntax) (alts : List Alt) : TermElabM Syntax := do @@ -378,7 +383,9 @@ def match_syntax.expand (stx : Syntax) : TermElabM Syntax := do match stx with | `(match $[$discrs:term],* with $[| $[$patss],* => $rhss]*) => do -- letBindRhss ... - if patss.all (·.all (!·.isQuot)) then + if !patss.any (·.any (fun + | `($id@$pat) => pat.isQuot + | pat => pat.isQuot)) then -- no quotations => fall back to regular `match` throwUnsupportedSyntax let stx ← compileStxMatch discrs.toList (patss.map (·.toList) |>.zip rhss).toList diff --git a/tests/lean/StxQuot.lean b/tests/lean/StxQuot.lean index 24676b3896..952b1692d9 100644 --- a/tests/lean/StxQuot.lean +++ b/tests/lean/StxQuot.lean @@ -24,6 +24,7 @@ end Lean.Syntax #eval run $ do let a ← `(Nat.one); `(f $(id a)) #eval run $ do let a ← `(Nat.one); `($(a).b) #eval run $ do let a ← `(1 + 2); match a with | `($a + $b) => `($b + $a) | _ => pure Syntax.missing +#eval run $ do let a ← `(1 + 2); match a with | stx@`($a + $b) => `($stx + $a) | _ => pure Syntax.missing #eval run $ do let a ← `(def foo := 1); match a with | `($f:command) => pure f | _ => pure Syntax.missing #eval run $ do let a ← `(def foo := 1 def bar := 2); match a with | `($f:command $g:command) => `($g:command $f:command) | _ => pure Syntax.missing diff --git a/tests/lean/StxQuot.lean.expected.out b/tests/lean/StxQuot.lean.expected.out index 3d6007df28..f2ab494a46 100644 --- a/tests/lean/StxQuot.lean.expected.out +++ b/tests/lean/StxQuot.lean.expected.out @@ -15,6 +15,7 @@ "(Term.app `f._@.UnhygienicMain._hyg.1 [`Nat.one._@.UnhygienicMain._hyg.1])" "(Term.proj `Nat.one._@.UnhygienicMain._hyg.1 \".\" `b._@.UnhygienicMain._hyg.1)" "(term_+_ (numLit \"2\") \"+\" (numLit \"1\"))" +"(term_+_ (term_+_ (numLit \"1\") \"+\" (numLit \"2\")) \"+\" (numLit \"1\"))" "(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])))" "[(Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `bar._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"2\") [])))\n (Command.declaration\n (Command.declModifiers [] [] [] [] [] [])\n (Command.def\n \"def\"\n (Command.declId `foo._@.UnhygienicMain._hyg.1 [])\n (Command.optDeclSig [] [])\n (Command.declValSimple \":=\" (numLit \"1\") [])))]" "0" @@ -34,4 +35,4 @@ "(Term.match\n \"match\"\n [(Term.matchDiscr [] `a._@.UnhygienicMain._hyg.1)]\n []\n \"with\"\n (Term.matchAlts\n [(Term.matchAlt \"|\" [`a._@.UnhygienicMain._hyg.1] \"=>\" `b._@.UnhygienicMain._hyg.1)\n (Term.matchAlt\n \"|\"\n [(term_+_ `a._@.UnhygienicMain._hyg.1 \"+\" (numLit \"1\"))]\n \"=>\"\n (term_+_ `b._@.UnhygienicMain._hyg.1 \"+\" (numLit \"1\")))]))" "(Term.match\n \"match\"\n [(Term.matchDiscr [] `a._@.UnhygienicMain._hyg.1)]\n []\n \"with\"\n (Term.matchAlts\n [(Term.matchAlt \"|\" [`a._@.UnhygienicMain._hyg.1] \"=>\" `b._@.UnhygienicMain._hyg.1)\n (Term.matchAlt\n \"|\"\n [(term_+_ `a._@.UnhygienicMain._hyg.1 \"+\" (numLit \"1\"))]\n \"=>\"\n (term_+_ `b._@.UnhygienicMain._hyg.1 \"+\" (numLit \"1\")))]))" "#[`a._@.UnhygienicMain._hyg.1, `b._@.UnhygienicMain._hyg.1]" -StxQuot.lean:70:33: error: expected parser to return exactly one syntax object +StxQuot.lean:71:33: error: expected parser to return exactly one syntax object