358 lines
14 KiB
Text
358 lines
14 KiB
Text
/-
|
||
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Leonardo de Moura
|
||
-/
|
||
import Lean.Meta.Match.MatchPatternAttr
|
||
import Lean.Elab.Arg
|
||
import Lean.Elab.MatchAltView
|
||
|
||
namespace Lean.Elab.Term
|
||
|
||
open Meta
|
||
|
||
/-- Syntax of a collected pattern variable. The collector only records nodes that pass `Syntax.isIdent`. -/
abbrev PatternVar := Syntax -- TODO: should be `Ident`
|
||
|
||
/-!
Patterns define new local variables.
This module collects them and preprocesses `_` occurring in patterns.
Recall that an `_` may represent anonymous variables or inaccessible terms
that are implied by typing constraints. Thus, we represent them with fresh named holes `?x`.
After we elaborate the pattern, if the metavariable remains unassigned, we transform it into
a regular pattern variable. Otherwise, it becomes an inaccessible term.

Macros occurring in patterns are expanded before the `collectPatternVars` method is executed.
The following kinds of Syntax are handled by this module
- Constructor applications
- Applications of functions tagged with the `[matchPattern]` attribute
- Identifiers
- Anonymous constructors
- Structure instances
- Inaccessible terms
- Named patterns
- Tuple literals
- Type ascriptions
- Literals: num, string and char
-/
|
||
namespace CollectPatternVars
|
||
|
||
/-- Accumulator threaded through the collector monad: the pattern variables registered so far. -/
structure State where
  /-- Names of the pattern variables registered so far; consulted to reject duplicates. -/
  found : NameSet := {}
  /-- The same variables as syntax nodes, in the order in which they were registered. -/
  vars : Array PatternVar := #[]
  deriving Inhabited
|
||
|
||
/-- Monad of the pattern variable collector: `TermElabM` extended with a mutable `State`. -/
abbrev M := StateRefT State TermElabM
|
||
|
||
/-- Abort with the error reported when the head of an application pattern is neither a
constructor nor a constant tagged with `[matchPattern]`. -/
private def throwCtorExpected {α : Type} : M α :=
  throwError
    "invalid pattern, constructor or constant marked with '[matchPattern]' expected"
|
||
|
||
/-- Abort with the generic error used for syntax that is not supported in patterns. -/
private def throwInvalidPattern {α : Type} : M α :=
  throwError
    "invalid pattern"
|
||
|
||
/-!
An application in a pattern can be

1- A constructor application
   The elaborator assumes fields are accessible and inductive parameters are not accessible.

2- A regular application `(f ...)` where `f` is tagged with `[matchPattern]`.
   The elaborator assumes implicit arguments are not accessible and explicit ones are accessible.
-/
|
||
|
||
/--
Bookkeeping used while an application pattern is rewritten into a fully explicit application.
The parameters of the applied function are visited one at a time, consuming `namedArgs` and
`args` and appending the resulting argument syntax to `newArgs`.
-/
structure Context where
  /-- Identifier of the function being applied. -/
  funId : Ident
  /-- It is `some` iff the application is a constructor application. -/
  ctorVal? : Option ConstructorVal
  /-- `true` when the function was written with a leading `@`. -/
  explicit : Bool
  /-- When `true`, a missing explicit argument is filled in with `_` instead of being an error. -/
  ellipsis : Bool
  /-- Names and binder information of the parameters of the applied function. -/
  paramDecls : Array (Name × BinderInfo)
  /-- Index into `paramDecls` of the next parameter to process. -/
  paramDeclIdx : Nat := 0
  /-- Named arguments that have not been consumed yet. -/
  namedArgs : Array NamedArg
  /-- Positional arguments that have not been consumed yet. -/
  args : List Arg
  /-- Arguments produced so far for the resulting explicit application. -/
  newArgs : Array Term := #[]
  deriving Inhabited
|
||
|
||
/-- `true` once every entry of `ctx.paramDecls` has been visited. -/
private def isDone (ctx : Context) : Bool :=
  decide (ctx.paramDecls.size ≤ ctx.paramDeclIdx)
|
||
|
||
/--
Build the explicit application `@f newArgs` from `ctx`.
Fails with "too many arguments" if any named or positional argument is still unconsumed.
-/
private def finalize (ctx : Context) : M Syntax := do
  -- Leftover arguments mean the user supplied more arguments than the function has parameters.
  unless ctx.namedArgs.isEmpty && ctx.args.isEmpty do
    throwError "too many arguments"
  let head ← `(@$(ctx.funId):ident)
  return Syntax.mkApp head ctx.newArgs
|
||
|
||
/--
Decide whether the argument for the parameter at `ctx.paramDeclIdx` is accessible, i.e.,
whether pattern variables should be collected from it.
-/
private def isNextArgAccessible (ctx : Context) : Bool :=
  match ctx.ctorVal? with
  | some ctor =>
    -- Constructor application: only the fields (which follow the inductive parameters) are accessible.
    decide (ctor.numParams ≤ ctx.paramDeclIdx)
  | none =>
    -- `[matchPattern]` application: only explicit parameters are accessible.
    match ctx.paramDecls[ctx.paramDeclIdx]? with
    | some (_, binderInfo) => binderInfo.isExplicit
    | none => false
|
||
|
||
/--
Return the parameter declaration at `ctx.paramDeclIdx` together with `ctx` advanced to the
following parameter. The lookup uses `[·]!`, so callers must check `isDone` first.
-/
private def getNextParam (ctx : Context) : (Name × BinderInfo) × Context :=
  let current := ctx.paramDecls[ctx.paramDeclIdx]!
  let advanced := { ctx with paramDeclIdx := ctx.paramDeclIdx + 1 }
  (current, advanced)
|
||
|
||
/--
Register `idStx` as a new pattern variable and return it unchanged.
Fails if `idStx` is not an identifier, if its name (modulo macro scopes) is not atomic,
or if a variable with the same name was already registered.
-/
private def processVar (idStx : Syntax) : M Syntax := do
  if !idStx.isIdent then
    throwErrorAt idStx "identifier expected"
  let varName := idStx.getId
  if !varName.eraseMacroScopes.isAtomic then
    throwError "invalid pattern variable, must be atomic"
  let alreadyFound := (← get).found.contains varName
  if alreadyFound then
    throwError "invalid pattern, variable '{varName}' occurred more than once"
  -- Record the variable both in the lookup set and in the ordered array.
  modify fun st => { st with vars := st.vars.push idStx, found := st.found.insert varName }
  return idStx
|
||
|
||
/--
Compare the `vars` arrays of `s₁` and `s₂` using `Array.isEqvAux` with `==`, starting at index
`startingAt`. Returns `false` when the two arrays have different sizes.
-/
private def samePatternsVariables (startingAt : Nat) (s₁ s₂ : State) : Bool :=
  if sameSize : s₁.vars.size = s₂.vars.size then
    Array.isEqvAux s₁.vars s₂.vars sameSize (fun v₁ v₂ => v₁ == v₂) startingAt
  else
    false
|
||
|
||
open TSyntax.Compat in
/--
Traverse the pattern `stx`, registering the pattern variables it introduces in the collector
state, and return the preprocessed pattern. Syntax kinds that are not handled below are
rejected with "invalid pattern".
-/
partial def collect (stx : Syntax) : M Syntax := withRef stx <| withFreshMacroScope do
  let kind := stx.getKind
  if kind == identKind then
    processId stx
  else if kind == ``Lean.Parser.Term.app then
    processCtorApp stx
  else if kind == ``Lean.Parser.Term.explicitUniv then
    processCtor stx[0]
  else if kind == ``Lean.Parser.Term.hole || kind == ``Lean.Parser.Term.syntheticHole then
    -- Holes are wrapped as inaccessible terms.
    `(.( $stx ))
  else if kind == ``Lean.Parser.Term.dotIdent
      || kind == ``Lean.Parser.Term.inaccessible
      || kind == strLitKind
      || kind == numLitKind
      || kind == scientificLitKind
      || kind == charLitKind
      || kind == ``Lean.Parser.Term.quotedName
      || kind == ``Lean.Parser.Term.doubleQuotedName then
    -- These kinds are returned unchanged and contribute no pattern variables.
    return stx
  else if kind == ``Lean.Parser.Term.anonymousCtor then
    return stx.setArg 1 (mkNullNode (← stx[1].getArgs.mapSepElemsM collect))
  else if kind == ``Lean.Parser.Term.paren then
    let body := stx[1]
    if body.isNone then
      return stx -- `()`
    let tail := body[1]
    unless tail.isNone || tail[0].getKind == ``Lean.Parser.Term.typeAscription do
      return stx
    -- `tail` is ignored, since it is empty or it is a type ascription
    let inner ← collect body[0]
    return stx.setArg 1 (body.setArg 0 inner)
  else if kind == ``Lean.Parser.Term.namedPattern then
    /- Recall that
      ```
      def namedPattern := check... >> trailing_parser "@" >> optional (atomic (ident >> ":")) >> termParser
      ```
      TODO: pattern variable for equality proof
    -/
    let id := stx[0]
    discard <| processVar id
    -- Use the user-provided name for the equality proof if there is one, a fresh `h` otherwise.
    let proofId ← if stx[2].isNone then
      `(h)
    else
      pure stx[2][0]
    let subPattern ← collect stx[3]
    discard <| processVar proofId
    ``(_root_.namedPattern $id $subPattern $proofId)
  else if kind == ``Lean.Parser.Term.binop then
    let lhs ← collect stx[2]
    let rhs ← collect stx[3]
    return (stx.setArg 2 lhs).setArg 3 rhs
  else if kind == choiceKind then
    /- Remark: If there are `Term.structInst` alternatives, we keep only them. This is a hack to get rid of
       Set-like notation in patterns. Recall that in Mathlib `{a, b}` can be a set with two elements or the
       structure instance `{ a := a, b := b }`. Possible alternative solution: add a `pattern` category, or at least register
       the `Syntax` node kinds that are allowed in patterns. -/
    let allAlts := stx.getArgs
    let structAlts := allAlts.filter (·.isOfKind ``Parser.Term.structInst)
    let alts := if structAlts.isEmpty then allAlts else structAlts
    let initialState ← get
    let firstAlt ← collect alts[0]!
    let referenceState ← get
    let mut newAlts := #[firstAlt]
    for alt in alts[1:] do
      -- Every alternative is processed from the same starting state and must
      -- introduce the same variables as the first one.
      set initialState
      newAlts := newAlts.push (← collect alt)
      unless samePatternsVariables initialState.vars.size referenceState (← get) do
        throwError "invalid pattern, overloaded notation is only allowed when all alternative have the same set of pattern variables"
    set referenceState
    return mkNode choiceKind newAlts
  else match stx with
    | `({ $[$srcs?,* with]? $fields,* $[..%$ell?]? $[: $ty?]? }) =>
      if let some srcs := srcs? then
        throwErrorAt (mkNullNode srcs) "invalid struct instance pattern, 'with' is not allowed in patterns"
      let fields ← fields.getElems.mapM fun
        | `(Parser.Term.structInstField| $lval:structInstLVal := $val) => do
          let collectedVal ← collect val
          `(Parser.Term.structInstField| $lval:structInstLVal := $collectedVal)
        | _ => throwInvalidPattern -- `structInstFieldAbbrev` should be expanded at this point
      `({ $[$srcs?,* with]? $fields,* $[..%$ell?]? $[: $ty?]? })
    | _ => throwInvalidPattern
|
||
|
||
where
|
||
|
||
/-- Process the application pattern `stx`. Applications whose head is a `dotIdent` get their
arguments collected in place; all other heads are delegated to `processCtorAppCore`. -/
processCtorApp (stx : Syntax) : M Syntax := do
  let (f, namedArgs, args, ellipsis) ← expandApp stx
  if f.getKind != ``Parser.Term.dotIdent then
    processCtorAppCore f namedArgs args ellipsis
  else
    let namedArgsNew ← namedArgs.mapM fun
      | { ref, name, val := Arg.stx arg } => withRef ref do
        `(Lean.Parser.Term.namedArgument| ($(mkIdentFrom ref name) := $(← collect arg)))
      | _ => unreachable!
    let mut argsNew ← args.mapM fun
      | Arg.stx arg => collect arg
      | _ => unreachable!
    if ellipsis then
      -- Re-append the trailing `..` to the rebuilt application.
      argsNew := argsNew.push (mkNode ``Parser.Term.ellipsis #[mkAtomFrom stx ".."])
    return Syntax.mkApp f (namedArgsNew ++ argsNew)
|
||
|
||
/-- Process `stx` as a constructor-like head applied to no arguments. -/
processCtor (stx : Syntax) : M Syntax :=
  processCtorAppCore stx #[] #[] false
|
||
|
||
/-- Decide whether the identifier `stx` is a pattern variable or constructor-like (i.e., a constructor
or a constant tagged with the `[matchPattern]` attribute), and process it accordingly. -/
processId (stx : Syntax) : M Syntax := do
  -- Anything that does not resolve to a constant is a pattern variable.
  let some (Expr.const declName _) ← resolveId? stx "pattern"
    | processVar stx
  let env ← getEnv
  match env.find? declName with
  | none => throwCtorExpected
  | some (ConstantInfo.ctorInfo _) => processCtor stx
  | some _ =>
    if hasMatchPatternAttribute env declName then
      processCtor stx
    else
      processVar stx
|
||
|
||
/-- Append `arg` to `ctx.newArgs`. Pattern variables are collected from it only when
`accessible` is `true`; otherwise its syntax is stored unchanged. -/
pushNewArg (accessible : Bool) (ctx : Context) (arg : Arg) : M Context := do
  match arg with
  | Arg.stx argStx =>
    let processed ← if accessible then collect argStx else pure argStx
    return { ctx with newArgs := ctx.newArgs.push processed }
  | _ => unreachable!
|
||
|
||
/-- Consume the next positional argument for an explicit parameter. When no positional argument
is left, `_` is supplied if the application used `..`, and an error is raised otherwise. -/
processExplicitArg (accessible : Bool) (ctx : Context) : M Context := do
  match ctx.args with
  | arg :: remaining =>
    pushNewArg accessible { ctx with args := remaining } arg
  | [] =>
    unless ctx.ellipsis do
      throwError "explicit parameter is missing, unused named arguments {ctx.namedArgs.map fun narg => narg.name}"
    pushNewArg accessible ctx (Arg.stx (← `(_)))
|
||
|
||
/-- Handle an implicit parameter: with `@f` it consumes an argument like an explicit parameter,
otherwise `_` is supplied for it. -/
processImplicitArg (accessible : Bool) (ctx : Context) : M Context := do
  if !ctx.explicit then
    pushNewArg accessible ctx (Arg.stx (← `(_)))
  else
    processExplicitArg accessible ctx
|
||
|
||
/-- Main loop over the parameters of the applied function: each iteration produces the argument
for one parameter, and `finalize` builds the result once all parameters have been visited. -/
processCtorAppContext (ctx : Context) : M Syntax := do
  if isDone ctx then
    return (← finalize ctx)
  -- Accessibility must be computed before the parameter index is advanced.
  let accessible := isNextArgAccessible ctx
  let ((paramName, binderInfo), ctx) := getNextParam ctx
  let ctx ← match ctx.namedArgs.findIdx? fun namedArg => namedArg.name == paramName with
    | some idx =>
      -- A named argument for this parameter takes precedence over positional arguments.
      let namedArg := ctx.namedArgs[idx]!
      pushNewArg accessible { ctx with namedArgs := ctx.namedArgs.eraseIdx idx } namedArg.val
    | none =>
      match binderInfo with
      | BinderInfo.implicit | BinderInfo.strictImplicit | BinderInfo.instImplicit =>
        processImplicitArg accessible ctx
      | _ =>
        processExplicitArg accessible ctx
  processCtorAppContext ctx
|
||
|
||
/-- Process an application whose head `f` is `id` or `@id`. The identifier must resolve to a
constructor or to a constant tagged with `[matchPattern]`; otherwise an error is raised. -/
processCtorAppCore (f : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (ellipsis : Bool) : M Syntax := do
  let (fId, explicit) ← match f with
    | `(@$fId:ident) => pure (fId, true)
    | `($fId:ident)  => pure (fId, false)
    | _              => throwError "identifier expected"
  let some (Expr.const fName _) ← resolveId? fId "pattern" (withInfo := true)
    | throwCtorExpected
  let fInfo ← getConstInfo fName
  -- Names and binder information of the parameters in the type of `fName`.
  let paramDecls ← forallTelescopeReducing fInfo.type fun params _ =>
    params.mapM fun param => do
      let decl ← getFVarLocalDecl param
      return (decl.userName, decl.binderInfo)
  let ctorVal? : Option ConstructorVal ← match fInfo with
    | ConstantInfo.ctorInfo val => pure (some val)
    | _ =>
      unless hasMatchPatternAttribute (← getEnv) fName do
        throwCtorExpected
      pure none
  processCtorAppContext
    { funId := fId, explicit := explicit, ctorVal? := ctorVal?, paramDecls := paramDecls,
      namedArgs := namedArgs, args := args.toList, ellipsis := ellipsis }
|
||
|
||
/-- Run `collect` on every pattern of `alt`, in order, and return `alt` with its patterns
replaced by the collected ones. -/
def main (alt : MatchAltView) : M MatchAltView := do
  let mut collected := #[]
  for p in alt.patterns do
    trace[Elab.match] "collecting variables at pattern: {p}"
    collected := collected.push (← collect p)
  return { alt with patterns := collected }
|
||
|
||
end CollectPatternVars
|
||
|
||
/--
Collect the pattern variables occurring in the `match`-alternative view `alt`.
The result pairs the variables, in the order they were found, with the updated view.
-/
def collectPatternVars (alt : MatchAltView) : TermElabM (Array PatternVar × MatchAltView) := do
  let (updatedAlt, finalState) ← (CollectPatternVars.main alt).run {}
  return (finalState.vars, updatedAlt)
|
||
|
||
/--
Return the pattern variables in the given pattern, after expanding the macros occurring in it.
Remark: this method is not used by the main `match` elaborator, but in the precheck hook and other macros (e.g., at `Do.lean`).
-/
def getPatternVars (patternStx : Syntax) : TermElabM (Array PatternVar) := do
  let expanded ← liftMacroM <| expandMacros patternStx
  let (_, finalState) ← (CollectPatternVars.collect expanded).run {}
  return finalState.vars
|
||
|
||
/--
Return the pattern variables occurring in the given patterns. Macros are expanded in each
pattern before it is traversed, and all patterns share a single collector state.
This method is used in the `match` and `do` notation elaborators
-/
def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := do
  let collectAll : CollectPatternVars.M Unit :=
    patterns.forM fun pattern => do
      let expanded ← liftMacroM <| expandMacros pattern
      discard <| CollectPatternVars.collect expanded
  let (_, finalState) ← collectAll.run {}
  return finalState.vars
|
||
|
||
/-- Project each pattern variable to its name (`Syntax.getId`). -/
def getPatternVarNames (pvars : Array PatternVar) : Array Name :=
  pvars.map (·.getId)
|
||
|
||
end Lean.Elab.Term
|