diff --git a/src/Lean/Elab.lean b/src/Lean/Elab.lean index fbb4fa5682..99de7a8bb6 100644 --- a/src/Lean/Elab.lean +++ b/src/Lean/Elab.lean @@ -31,3 +31,4 @@ import Lean.Elab.Extra import Lean.Elab.GenInjective import Lean.Elab.BuiltinTerm import Lean.Elab.Arg +import Lean.Elab.PatternVar diff --git a/src/Lean/Elab/Do.lean b/src/Lean/Elab/Do.lean index e4c7a2a9f3..a6d1c14fbf 100644 --- a/src/Lean/Elab/Do.lean +++ b/src/Lean/Elab/Do.lean @@ -5,7 +5,7 @@ Authors: Leonardo de Moura -/ import Lean.Elab.Term import Lean.Elab.Binders -import Lean.Elab.Match +import Lean.Elab.PatternVar import Lean.Elab.Quotation.Util import Lean.Parser.Do diff --git a/src/Lean/Elab/Match.lean b/src/Lean/Elab/Match.lean index c17c25396a..5a4eedbd4e 100644 --- a/src/Lean/Elab/Match.lean +++ b/src/Lean/Elab/Match.lean @@ -11,26 +11,12 @@ import Lean.Meta.GeneralizeVars import Lean.Elab.SyntheticMVars import Lean.Elab.Arg import Lean.Parser.Term +import Lean.Elab.PatternVar namespace Lean.Elab.Term open Meta open Lean.Parser.Term -/- This modules assumes "match"-expressions use the following syntax. - -```lean -def matchDiscr := leading_parser optional (try (ident >> checkNoWsBefore "no space before ':'" >> ":")) >> termParser - -def «match» := leading_parser:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts -``` --/ - -structure MatchAltView where - ref : Syntax - patterns : Array Syntax - rhs : Syntax - deriving Inhabited - private def expandSimpleMatch (stx discr lhsVar rhs : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do let newStx ← `(let $lhsVar := $discr; $rhs) withMacroExpansion stx newStx <| elabTerm newStx expectedType? @@ -180,29 +166,8 @@ private def getMatchAlts : Syntax → Array MatchAltView | _ => none | _ => #[] -inductive PatternVar where - | localVar (userName : Name) - -- anonymous variables (`_`) are encoded using metavariables - | anonymousVar (mvarId : MVarId) - -instance : ToString PatternVar := ⟨fun - | PatternVar.localVar x => toString x - | PatternVar.anonymousVar mvarId => s!"?m{mvarId}"⟩ - builtin_initialize Parser.registerBuiltinNodeKind `MVarWithIdKind -/-- - Create an auxiliary Syntax node wrapping a fresh metavariable id. - We use this kind of Syntax for representing `_` occurring in patterns. - The metavariables are created before we elaborate the patterns into `Expr`s. -/ -private def mkMVarSyntax : TermElabM Syntax := do - let mvarId ← mkFreshId - return Syntax.node `MVarWithIdKind #[Syntax.node mvarId #[]] - -/-- Given a syntax node constructed using `mkMVarSyntax`, return its MVarId -/ -private def getMVarSyntaxMVarId (stx : Syntax) : MVarId := - stx[0].getKind - open Meta.Match (mkInaccessible inaccessible?) /-- @@ -215,330 +180,6 @@ open Meta.Match (mkInaccessible inaccessible?) let e ← elabTerm stx[1] expectedType? return mkInaccessible e -/- - Patterns define new local variables. - This module collect them and preprocess `_` occurring in patterns. - Recall that an `_` may represent anonymous variables or inaccessible terms - that are implied by typing constraints. Thus, we represent them with fresh named holes `?x`. - After we elaborate the pattern, if the metavariable remains unassigned, we transform it into - a regular pattern variable. Otherwise, it becomes an inaccessible term. - - Macros occurring in patterns are expanded before the `collectPatternVars` method is executed. - The following kinds of Syntax are handled by this module - - Constructor applications - - Applications of functions tagged with the `[matchPattern]` attribute - - Identifiers - - Anonymous constructors - - Structure instances - - Inaccessible terms - - Named patterns - - Tuple literals - - Type ascriptions - - Literals: num, string and char --/ -namespace CollectPatternVars - -structure State where - found : NameSet := {} - vars : Array PatternVar := #[] - -abbrev M := StateRefT State TermElabM - -private def throwCtorExpected {α} : M α := - throwError "invalid pattern, constructor or constant marked with '[matchPattern]' expected" - -private def getNumExplicitCtorParams (ctorVal : ConstructorVal) : TermElabM Nat := - forallBoundedTelescope ctorVal.type ctorVal.numParams fun ps _ => do - let mut result := 0 - for p in ps do - let localDecl ← getLocalDecl p.fvarId! - if localDecl.binderInfo.isExplicit then - result := result+1 - pure result - -private def throwInvalidPattern {α} : M α := - throwError "invalid pattern" - -/- -An application in a pattern can be - -1- A constructor application - The elaborator assumes fields are accessible and inductive parameters are not accessible. - -2- A regular application `(f ...)` where `f` is tagged with `[matchPattern]`. - The elaborator assumes implicit arguments are not accessible and explicit ones are accessible. --/ - -structure Context where - funId : Syntax - ctorVal? : Option ConstructorVal -- It is `some`, if constructor application - explicit : Bool - ellipsis : Bool - paramDecls : Array (Name × BinderInfo) -- parameters names and binder information - paramDeclIdx : Nat := 0 - namedArgs : Array NamedArg - args : List Arg - newArgs : Array Syntax := #[] - deriving Inhabited - -private def isDone (ctx : Context) : Bool := - ctx.paramDeclIdx ≥ ctx.paramDecls.size - -private def finalize (ctx : Context) : M Syntax := do - if ctx.namedArgs.isEmpty && ctx.args.isEmpty then - let fStx ← `(@$(ctx.funId):ident) - return Syntax.mkApp fStx ctx.newArgs - else - throwError "too many arguments" - -private def isNextArgAccessible (ctx : Context) : Bool := - let i := ctx.paramDeclIdx - match ctx.ctorVal? with - | some ctorVal => i ≥ ctorVal.numParams -- For constructor applications only fields are accessible - | none => - if h : i < ctx.paramDecls.size then - -- For `[matchPattern]` applications, only explicit parameters are accessible. - let d := ctx.paramDecls.get ⟨i, h⟩ - d.2.isExplicit - else - false - -private def getNextParam (ctx : Context) : (Name × BinderInfo) × Context := - let i := ctx.paramDeclIdx - let d := ctx.paramDecls[i] - (d, { ctx with paramDeclIdx := ctx.paramDeclIdx + 1 }) - -private def processVar (idStx : Syntax) : M Syntax := do - unless idStx.isIdent do - throwErrorAt idStx "identifier expected" - let id := idStx.getId - unless id.eraseMacroScopes.isAtomic do - throwError "invalid pattern variable, must be atomic" - if (← get).found.contains id then - throwError "invalid pattern, variable '{id}' occurred more than once" - modify fun s => { s with vars := s.vars.push (PatternVar.localVar id), found := s.found.insert id } - return idStx - -private def nameToPattern : Name → TermElabM Syntax - | Name.anonymous => `(Name.anonymous) - | Name.str p s _ => do let p ← nameToPattern p; `(Name.str $p $(quote s) _) - | Name.num p n _ => do let p ← nameToPattern p; `(Name.num $p $(quote n) _) - -private def quotedNameToPattern (stx : Syntax) : TermElabM Syntax := - match stx[0].isNameLit? with - | some val => nameToPattern val - | none => throwIllFormedSyntax - -private def doubleQuotedNameToPattern (stx : Syntax) : TermElabM Syntax := do - match stx[1].isNameLit? with - | some val => nameToPattern (← resolveGlobalConstNoOverloadWithInfo stx[1] val) - | none => throwIllFormedSyntax - -partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroScope do - let k := stx.getKind - if k == identKind then - processId stx - else if k == ``Lean.Parser.Term.app then - processCtorApp stx - else if k == ``Lean.Parser.Term.anonymousCtor then - let elems ← stx[1].getArgs.mapSepElemsM collect - return stx.setArg 1 <| mkNullNode elems - else if k == ``Lean.Parser.Term.structInst then - /- - ``` - leading_parser "{" >> optional (atomic (termParser >> " with ")) - >> manyIndent (group (structInstField >> optional ", ")) - >> optional ".." - >> optional (" : " >> termParser) - >> " }" - ``` - -/ - let withMod := stx[1] - unless withMod.isNone do - throwErrorAt withMod "invalid struct instance pattern, 'with' is not allowed in patterns" - let fields ← stx[2].getArgs.mapM fun p => do - -- p is of the form (group (structInstField >> optional ", ")) - let field := p[0] - -- leading_parser structInstLVal >> " := " >> termParser - let newVal ← collect field[2] - let field := field.setArg 2 newVal - pure <| field.setArg 0 field - return stx.setArg 2 <| mkNullNode fields - else if k == ``Lean.Parser.Term.hole then - let r ← mkMVarSyntax - modify fun s => { s with vars := s.vars.push <| PatternVar.anonymousVar <| getMVarSyntaxMVarId r } - return r - else if k == ``Lean.Parser.Term.paren then - let arg := stx[1] - if arg.isNone then - return stx -- `()` - else - let t := arg[0] - let s := arg[1] - if s.isNone || s[0].getKind == ``Lean.Parser.Term.typeAscription then - -- Ignore `s`, since it empty or it is a type ascription - let t ← collect t - let arg := arg.setArg 0 t - return stx.setArg 1 arg - else - return stx - else if k == ``Lean.Parser.Term.explicitUniv then - processCtor stx[0] - else if k == ``Lean.Parser.Term.namedPattern then - /- Recall that - def namedPattern := check... >> trailing_parser "@" >> termParser -/ - let id := stx[0] - discard <| processVar id - let pat := stx[2] - let pat ← collect pat - `(_root_.namedPattern $id $pat) - else if k == ``Lean.Parser.Term.binop then - let lhs ← collect stx[2] - let rhs ← collect stx[3] - return stx.setArg 2 lhs |>.setArg 3 rhs - else if k == ``Lean.Parser.Term.inaccessible then - return stx - else if k == strLitKind then - return stx - else if k == numLitKind then - return stx - else if k == scientificLitKind then - return stx - else if k == charLitKind then - return stx - else if k == ``Lean.Parser.Term.quotedName then - /- Quoted names have an elaboration function associated with them, and they will not be macro expanded. - Note that macro expansion is not a good option since it produces a term using the smart constructors `Name.mkStr`, `Name.mkNum` - instead of the constructors `Name.str` and `Name.num` -/ - quotedNameToPattern stx - else if k == ``Lean.Parser.Term.doubleQuotedName then - /- Similar to previous case -/ - doubleQuotedNameToPattern stx - else if k == choiceKind then - throwError "invalid pattern, notation is ambiguous" - else - throwInvalidPattern - -where - - processCtorApp (stx : Syntax) : M Syntax := do - let (f, namedArgs, args, ellipsis) ← expandApp stx true - processCtorAppCore f namedArgs args ellipsis - - processCtor (stx : Syntax) : M Syntax := do - processCtorAppCore stx #[] #[] false - - /- Check whether `stx` is a pattern variable or constructor-like (i.e., constructor or constant tagged with `[matchPattern]` attribute) -/ - processId (stx : Syntax) : M Syntax := do - match (← resolveId? stx "pattern" (withInfo := true)) with - | none => processVar stx - | some f => match f with - | Expr.const fName _ _ => - match (← getEnv).find? fName with - | some (ConstantInfo.ctorInfo _) => processCtor stx - | some _ => - if hasMatchPatternAttribute (← getEnv) fName then - processCtor stx - else - processVar stx - | none => throwCtorExpected - | _ => processVar stx - - pushNewArg (accessible : Bool) (ctx : Context) (arg : Arg) : M Context := do - match arg with - | Arg.stx stx => - let stx ← if accessible then collect stx else pure stx - return { ctx with newArgs := ctx.newArgs.push stx } - | _ => unreachable! - - processExplicitArg (accessible : Bool) (ctx : Context) : M Context := do - match ctx.args with - | [] => - if ctx.ellipsis then - pushNewArg accessible ctx (Arg.stx (← `(_))) - else - throwError "explicit parameter is missing, unused named arguments {ctx.namedArgs.map fun narg => narg.name}" - | arg::args => - pushNewArg accessible { ctx with args := args } arg - - processImplicitArg (accessible : Bool) (ctx : Context) : M Context := do - if ctx.explicit then - processExplicitArg accessible ctx - else - pushNewArg accessible ctx (Arg.stx (← `(_))) - - processCtorAppContext (ctx : Context) : M Syntax := do - if isDone ctx then - finalize ctx - else - let accessible := isNextArgAccessible ctx - let (d, ctx) := getNextParam ctx - match ctx.namedArgs.findIdx? fun namedArg => namedArg.name == d.1 with - | some idx => - let arg := ctx.namedArgs[idx] - let ctx := { ctx with namedArgs := ctx.namedArgs.eraseIdx idx } - let ctx ← pushNewArg accessible ctx arg.val - processCtorAppContext ctx - | none => - let ctx ← match d.2 with - | BinderInfo.implicit => processImplicitArg accessible ctx - | BinderInfo.instImplicit => processImplicitArg accessible ctx - | _ => processExplicitArg accessible ctx - processCtorAppContext ctx - - processCtorAppCore (f : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (ellipsis : Bool) : M Syntax := do - let args := args.toList - let (fId, explicit) ← match f with - | `($fId:ident) => pure (fId, false) - | `(@$fId:ident) => pure (fId, true) - | _ => throwError "identifier expected" - let some (Expr.const fName _ _) ← resolveId? fId "pattern" (withInfo := true) | throwCtorExpected - let fInfo ← getConstInfo fName - let paramDecls ← forallTelescopeReducing fInfo.type fun xs _ => xs.mapM fun x => do - let d ← getFVarLocalDecl x - return (d.userName, d.binderInfo) - match fInfo with - | ConstantInfo.ctorInfo val => - processCtorAppContext - { funId := fId, explicit := explicit, ctorVal? := val, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis } - | _ => - if hasMatchPatternAttribute (← getEnv) fName then - processCtorAppContext - { funId := fId, explicit := explicit, ctorVal? := none, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis } - else - throwCtorExpected - -def main (alt : MatchAltView) : M MatchAltView := do - let patterns ← alt.patterns.mapM fun p => do - trace[Elab.match] "collecting variables at pattern: {p}" - collect p - return { alt with patterns := patterns } - -end CollectPatternVars - -private def collectPatternVars (alt : MatchAltView) : TermElabM (Array PatternVar × MatchAltView) := do - let (alt, s) ← (CollectPatternVars.main alt).run {} - return (s.vars, alt) - -/- Return the pattern variables in the given pattern. - Remark: this method is not used by the main `match` elaborator, but in the precheck hook and other macros (e.g., at `Do.lean`). -/ -def getPatternVars (patternStx : Syntax) : TermElabM (Array PatternVar) := do - let patternStx ← liftMacroM <| expandMacros patternStx - let (_, s) ← (CollectPatternVars.collect patternStx).run {} - return s.vars - -def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := do - let collect : CollectPatternVars.M Unit := do - for pattern in patterns do - discard <| CollectPatternVars.collect (← liftMacroM <| expandMacros pattern) - let (_, s) ← collect.run {} - return s.vars - -def getPatternVarNames (pvars : Array PatternVar) : Array Name := - pvars.filterMap fun - | PatternVar.localVar x => some x - | _ => none - open Lean.Elab.Term.Quotation in @[builtinQuotPrecheck Lean.Parser.Term.match] def precheckMatch : Precheck | `(match $[$discrs:term],* with $[| $[$patss],* => $rhss]*) => do diff --git a/src/Lean/Elab/MatchAltView.lean b/src/Lean/Elab/MatchAltView.lean new file mode 100644 index 0000000000..40c47578c9 --- /dev/null +++ b/src/Lean/Elab/MatchAltView.lean @@ -0,0 +1,25 @@ +/- +Copyright (c) 2021 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Elab.Term + +namespace Lean.Elab.Term + +/- This modules assumes "match"-expressions use the following syntax. + +```lean +def matchDiscr := leading_parser optional (try (ident >> checkNoWsBefore "no space before ':'" >> ":")) >> termParser + +def «match» := leading_parser:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts +``` +-/ + +structure MatchAltView where + ref : Syntax + patterns : Array Syntax + rhs : Syntax + deriving Inhabited + +end Lean.Elab.Term diff --git a/src/Lean/Elab/PatternVar.lean b/src/Lean/Elab/PatternVar.lean new file mode 100644 index 0000000000..555c1b8f13 --- /dev/null +++ b/src/Lean/Elab/PatternVar.lean @@ -0,0 +1,359 @@ +/- +Copyright (c) 2021 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +import Lean.Meta.Match.MatchPatternAttr +import Lean.Elab.Arg +import Lean.Elab.MatchAltView + +namespace Lean.Elab.Term + +open Meta + +inductive PatternVar where + | localVar (userName : Name) + -- anonymous variables (`_`) are encoded using metavariables + | anonymousVar (mvarId : MVarId) + +instance : ToString PatternVar := ⟨fun + | PatternVar.localVar x => toString x + | PatternVar.anonymousVar mvarId => s!"?m{mvarId}"⟩ + +/-- + Create an auxiliary Syntax node wrapping a fresh metavariable id. + We use this kind of Syntax for representing `_` occurring in patterns. + The metavariables are created before we elaborate the patterns into `Expr`s. -/ +private def mkMVarSyntax : TermElabM Syntax := do + let mvarId ← mkFreshId + return Syntax.node `MVarWithIdKind #[Syntax.node mvarId #[]] + +/-- Given a syntax node constructed using `mkMVarSyntax`, return its MVarId -/ +def getMVarSyntaxMVarId (stx : Syntax) : MVarId := + stx[0].getKind + +/- + Patterns define new local variables. + This module collect them and preprocess `_` occurring in patterns. + Recall that an `_` may represent anonymous variables or inaccessible terms + that are implied by typing constraints. Thus, we represent them with fresh named holes `?x`. + After we elaborate the pattern, if the metavariable remains unassigned, we transform it into + a regular pattern variable. Otherwise, it becomes an inaccessible term. + + Macros occurring in patterns are expanded before the `collectPatternVars` method is executed. + The following kinds of Syntax are handled by this module + - Constructor applications + - Applications of functions tagged with the `[matchPattern]` attribute + - Identifiers + - Anonymous constructors + - Structure instances + - Inaccessible terms + - Named patterns + - Tuple literals + - Type ascriptions + - Literals: num, string and char +-/ +namespace CollectPatternVars + +structure State where + found : NameSet := {} + vars : Array PatternVar := #[] + +abbrev M := StateRefT State TermElabM + +private def throwCtorExpected {α} : M α := + throwError "invalid pattern, constructor or constant marked with '[matchPattern]' expected" + +private def getNumExplicitCtorParams (ctorVal : ConstructorVal) : TermElabM Nat := + forallBoundedTelescope ctorVal.type ctorVal.numParams fun ps _ => do + let mut result := 0 + for p in ps do + let localDecl ← getLocalDecl p.fvarId! + if localDecl.binderInfo.isExplicit then + result := result+1 + pure result + +private def throwInvalidPattern {α} : M α := + throwError "invalid pattern" + +/- +An application in a pattern can be + +1- A constructor application + The elaborator assumes fields are accessible and inductive parameters are not accessible. + +2- A regular application `(f ...)` where `f` is tagged with `[matchPattern]`. + The elaborator assumes implicit arguments are not accessible and explicit ones are accessible. +-/ + +structure Context where + funId : Syntax + ctorVal? : Option ConstructorVal -- It is `some`, if constructor application + explicit : Bool + ellipsis : Bool + paramDecls : Array (Name × BinderInfo) -- parameters names and binder information + paramDeclIdx : Nat := 0 + namedArgs : Array NamedArg + args : List Arg + newArgs : Array Syntax := #[] + deriving Inhabited + +private def isDone (ctx : Context) : Bool := + ctx.paramDeclIdx ≥ ctx.paramDecls.size + +private def finalize (ctx : Context) : M Syntax := do + if ctx.namedArgs.isEmpty && ctx.args.isEmpty then + let fStx ← `(@$(ctx.funId):ident) + return Syntax.mkApp fStx ctx.newArgs + else + throwError "too many arguments" + +private def isNextArgAccessible (ctx : Context) : Bool := + let i := ctx.paramDeclIdx + match ctx.ctorVal? with + | some ctorVal => i ≥ ctorVal.numParams -- For constructor applications only fields are accessible + | none => + if h : i < ctx.paramDecls.size then + -- For `[matchPattern]` applications, only explicit parameters are accessible. + let d := ctx.paramDecls.get ⟨i, h⟩ + d.2.isExplicit + else + false + +private def getNextParam (ctx : Context) : (Name × BinderInfo) × Context := + let i := ctx.paramDeclIdx + let d := ctx.paramDecls[i] + (d, { ctx with paramDeclIdx := ctx.paramDeclIdx + 1 }) + +private def processVar (idStx : Syntax) : M Syntax := do + unless idStx.isIdent do + throwErrorAt idStx "identifier expected" + let id := idStx.getId + unless id.eraseMacroScopes.isAtomic do + throwError "invalid pattern variable, must be atomic" + if (← get).found.contains id then + throwError "invalid pattern, variable '{id}' occurred more than once" + modify fun s => { s with vars := s.vars.push (PatternVar.localVar id), found := s.found.insert id } + return idStx + +private def nameToPattern : Name → TermElabM Syntax + | Name.anonymous => `(Name.anonymous) + | Name.str p s _ => do let p ← nameToPattern p; `(Name.str $p $(quote s) _) + | Name.num p n _ => do let p ← nameToPattern p; `(Name.num $p $(quote n) _) + +private def quotedNameToPattern (stx : Syntax) : TermElabM Syntax := + match stx[0].isNameLit? with + | some val => nameToPattern val + | none => throwIllFormedSyntax + +private def doubleQuotedNameToPattern (stx : Syntax) : TermElabM Syntax := do + match stx[1].isNameLit? with + | some val => nameToPattern (← resolveGlobalConstNoOverloadWithInfo stx[1] val) + | none => throwIllFormedSyntax + +partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroScope do + let k := stx.getKind + if k == identKind then + processId stx + else if k == ``Lean.Parser.Term.app then + processCtorApp stx + else if k == ``Lean.Parser.Term.anonymousCtor then + let elems ← stx[1].getArgs.mapSepElemsM collect + return stx.setArg 1 <| mkNullNode elems + else if k == ``Lean.Parser.Term.structInst then + /- + ``` + leading_parser "{" >> optional (atomic (termParser >> " with ")) + >> manyIndent (group (structInstField >> optional ", ")) + >> optional ".." + >> optional (" : " >> termParser) + >> " }" + ``` + -/ + let withMod := stx[1] + unless withMod.isNone do + throwErrorAt withMod "invalid struct instance pattern, 'with' is not allowed in patterns" + let fields ← stx[2].getArgs.mapM fun p => do + -- p is of the form (group (structInstField >> optional ", ")) + let field := p[0] + -- leading_parser structInstLVal >> " := " >> termParser + let newVal ← collect field[2] + let field := field.setArg 2 newVal + pure <| field.setArg 0 field + return stx.setArg 2 <| mkNullNode fields + else if k == ``Lean.Parser.Term.hole then + let r ← mkMVarSyntax + modify fun s => { s with vars := s.vars.push <| PatternVar.anonymousVar <| getMVarSyntaxMVarId r } + return r + else if k == ``Lean.Parser.Term.paren then + let arg := stx[1] + if arg.isNone then + return stx -- `()` + else + let t := arg[0] + let s := arg[1] + if s.isNone || s[0].getKind == ``Lean.Parser.Term.typeAscription then + -- Ignore `s`, since it empty or it is a type ascription + let t ← collect t + let arg := arg.setArg 0 t + return stx.setArg 1 arg + else + return stx + else if k == ``Lean.Parser.Term.explicitUniv then + processCtor stx[0] + else if k == ``Lean.Parser.Term.namedPattern then + /- Recall that + def namedPattern := check... >> trailing_parser "@" >> termParser -/ + let id := stx[0] + discard <| processVar id + let pat := stx[2] + let pat ← collect pat + `(_root_.namedPattern $id $pat) + else if k == ``Lean.Parser.Term.binop then + let lhs ← collect stx[2] + let rhs ← collect stx[3] + return stx.setArg 2 lhs |>.setArg 3 rhs + else if k == ``Lean.Parser.Term.inaccessible then + return stx + else if k == strLitKind then + return stx + else if k == numLitKind then + return stx + else if k == scientificLitKind then + return stx + else if k == charLitKind then + return stx + else if k == ``Lean.Parser.Term.quotedName then + /- Quoted names have an elaboration function associated with them, and they will not be macro expanded. + Note that macro expansion is not a good option since it produces a term using the smart constructors `Name.mkStr`, `Name.mkNum` + instead of the constructors `Name.str` and `Name.num` -/ + quotedNameToPattern stx + else if k == ``Lean.Parser.Term.doubleQuotedName then + /- Similar to previous case -/ + doubleQuotedNameToPattern stx + else if k == choiceKind then + throwError "invalid pattern, notation is ambiguous" + else + throwInvalidPattern + +where + + processCtorApp (stx : Syntax) : M Syntax := do + let (f, namedArgs, args, ellipsis) ← expandApp stx true + processCtorAppCore f namedArgs args ellipsis + + processCtor (stx : Syntax) : M Syntax := do + processCtorAppCore stx #[] #[] false + + /- Check whether `stx` is a pattern variable or constructor-like (i.e., constructor or constant tagged with `[matchPattern]` attribute) -/ + processId (stx : Syntax) : M Syntax := do + match (← resolveId? stx "pattern" (withInfo := true)) with + | none => processVar stx + | some f => match f with + | Expr.const fName _ _ => + match (← getEnv).find? fName with + | some (ConstantInfo.ctorInfo _) => processCtor stx + | some _ => + if hasMatchPatternAttribute (← getEnv) fName then + processCtor stx + else + processVar stx + | none => throwCtorExpected + | _ => processVar stx + + pushNewArg (accessible : Bool) (ctx : Context) (arg : Arg) : M Context := do + match arg with + | Arg.stx stx => + let stx ← if accessible then collect stx else pure stx + return { ctx with newArgs := ctx.newArgs.push stx } + | _ => unreachable! + + processExplicitArg (accessible : Bool) (ctx : Context) : M Context := do + match ctx.args with + | [] => + if ctx.ellipsis then + pushNewArg accessible ctx (Arg.stx (← `(_))) + else + throwError "explicit parameter is missing, unused named arguments {ctx.namedArgs.map fun narg => narg.name}" + | arg::args => + pushNewArg accessible { ctx with args := args } arg + + processImplicitArg (accessible : Bool) (ctx : Context) : M Context := do + if ctx.explicit then + processExplicitArg accessible ctx + else + pushNewArg accessible ctx (Arg.stx (← `(_))) + + processCtorAppContext (ctx : Context) : M Syntax := do + if isDone ctx then + finalize ctx + else + let accessible := isNextArgAccessible ctx + let (d, ctx) := getNextParam ctx + match ctx.namedArgs.findIdx? fun namedArg => namedArg.name == d.1 with + | some idx => + let arg := ctx.namedArgs[idx] + let ctx := { ctx with namedArgs := ctx.namedArgs.eraseIdx idx } + let ctx ← pushNewArg accessible ctx arg.val + processCtorAppContext ctx + | none => + let ctx ← match d.2 with + | BinderInfo.implicit => processImplicitArg accessible ctx + | BinderInfo.instImplicit => processImplicitArg accessible ctx + | _ => processExplicitArg accessible ctx + processCtorAppContext ctx + + processCtorAppCore (f : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (ellipsis : Bool) : M Syntax := do + let args := args.toList + let (fId, explicit) ← match f with + | `($fId:ident) => pure (fId, false) + | `(@$fId:ident) => pure (fId, true) + | _ => throwError "identifier expected" + let some (Expr.const fName _ _) ← resolveId? fId "pattern" (withInfo := true) | throwCtorExpected + let fInfo ← getConstInfo fName + let paramDecls ← forallTelescopeReducing fInfo.type fun xs _ => xs.mapM fun x => do + let d ← getFVarLocalDecl x + return (d.userName, d.binderInfo) + match fInfo with + | ConstantInfo.ctorInfo val => + processCtorAppContext + { funId := fId, explicit := explicit, ctorVal? := val, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis } + | _ => + if hasMatchPatternAttribute (← getEnv) fName then + processCtorAppContext + { funId := fId, explicit := explicit, ctorVal? := none, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis } + else + throwCtorExpected + +def main (alt : MatchAltView) : M MatchAltView := do + let patterns ← alt.patterns.mapM fun p => do + trace[Elab.match] "collecting variables at pattern: {p}" + collect p + return { alt with patterns := patterns } + +end CollectPatternVars + +def collectPatternVars (alt : MatchAltView) : TermElabM (Array PatternVar × MatchAltView) := do + let (alt, s) ← (CollectPatternVars.main alt).run {} + return (s.vars, alt) + +/- Return the pattern variables in the given pattern. + Remark: this method is not used by the main `match` elaborator, but in the precheck hook and other macros (e.g., at `Do.lean`). -/ +def getPatternVars (patternStx : Syntax) : TermElabM (Array PatternVar) := do + let patternStx ← liftMacroM <| expandMacros patternStx + let (_, s) ← (CollectPatternVars.collect patternStx).run {} + return s.vars + +def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := do + let collect : CollectPatternVars.M Unit := do + for pattern in patterns do + discard <| CollectPatternVars.collect (← liftMacroM <| expandMacros pattern) + let (_, s) ← collect.run {} + return s.vars + +def getPatternVarNames (pvars : Array PatternVar) : Array Name := + pvars.filterMap fun + | PatternVar.localVar x => some x + | _ => none + +end Lean.Elab.Term