fix: simple-match macro

This commit is contained in:
Leonardo de Moura 2021-01-12 06:41:32 -08:00
parent 9aaa52cf66
commit 1ebf69e163
3 changed files with 44 additions and 4 deletions

View file

@ -862,12 +862,29 @@ private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermEla
let matchOptType := getMatchOptType stx
elabMatchAux discrStxs altViews matchOptType expectedType
private def isPatternVar (stx : Syntax) : TermElabM Bool := do
match (← resolveId? stx "pattern") with
| none => isAtomicIdent stx
| some f => match f with
| Expr.const fName _ _ =>
match (← getEnv).find? fName with
| some (ConstantInfo.ctorInfo _) => return false
| _ => isAtomicIdent stx
| _ => isAtomicIdent stx
where
isAtomicIdent (stx : Syntax) : Bool :=
stx.isIdent && stx.getId.eraseMacroScopes.isAtomic
-- parser! "match " >> sepBy1 termParser ", " >> optType >> " with " >> matchAlts
@[builtinTermElab «match»] def elabMatch : TermElab := fun stx expectedType? =>
@[builtinTermElab «match»] def elabMatch : TermElab := fun stx expectedType? => do
match stx with
| `(match $discr:term with | $y:ident => $rhs:term) => expandSimpleMatch stx discr y rhs expectedType?
| `(match $discr:term : $type with | $y:ident => $rhs:term) => expandSimpleMatchWithType stx discr y type rhs expectedType?
| _ => do
| `(match $discr:term with | $y:ident => $rhs:term) =>
if (← isPatternVar y) then expandSimpleMatch stx discr y rhs expectedType? else elabMatchDefault stx expectedType?
| `(match $discr:term : $type with | $y:ident => $rhs:term) =>
if (← isPatternVar y) then expandSimpleMatchWithType stx discr y type rhs expectedType? else elabMatchDefault stx expectedType?
| _ => elabMatchDefault stx expectedType?
where
elabMatchDefault (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
match (← expandNonAtomicDiscrs? stx) with
| some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
| none =>

17
tests/lean/patvar.lean Normal file
View file

@ -0,0 +1,17 @@
-- set_option trace.Elab true
def myId : List α → List α
| List.nil => List.nil
def constNil : List α → List α
| List.nil => List.nil
| List.cons x y => List.nil
def failing1 : List α → List α
| [] => List.nil
def failing2 : List α → List α
| x => List.nil
| foo.bar => List.nil -- "invalid pattern variable"
def myId2 : List α → List α
| foo.bar => foo.bar

View file

@ -0,0 +1,6 @@
patvar.lean:3:0: error: missing cases:
(List.cons _ _)
patvar.lean:10:0: error: missing cases:
(List.cons _ _)
patvar.lean:14:0: error: invalid pattern variable, must be atomic
patvar.lean:17:0: error: invalid pattern variable, must be atomic