895 lines
37 KiB
Text
895 lines
37 KiB
Text
/-
|
||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Leonardo de Moura
|
||
-/
|
||
import Lean.Meta.Match.MatchPatternAttr
|
||
import Lean.Meta.Match.Match
|
||
import Lean.Elab.SyntheticMVars
|
||
import Lean.Elab.App
|
||
import Lean.Parser.Term
|
||
|
||
namespace Lean.Elab.Term
|
||
open Meta
|
||
open Lean.Parser.Term
|
||
|
||
/- This module assumes "match"-expressions use the following syntax.
|
||
|
||
```lean
|
||
def matchDiscr := parser! optional (try (ident >> checkNoWsBefore "no space before ':'" >> ":")) >> termParser
|
||
|
||
def «match» := parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts
|
||
```
|
||
-/
|
||
|
||
/-- Syntactic view of a single alternative `| patterns => rhs` of a match-expression. -/
structure MatchAltView where
  ref      : Syntax        -- syntax of the alternative this view was built from
  patterns : Array Syntax  -- patterns on the left-hand side of `=>`
  rhs      : Syntax        -- term on the right-hand side of `=>`
  deriving Inhabited
|
||
|
||
/-- Rewrite `stx` as ``let $lhsVar := $discr; $rhs`` and elaborate the rewritten term,
    recording the rewrite as a macro expansion of `stx`. -/
private def expandSimpleMatch (stx discr lhsVar rhs : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
  let letStx ← `(let $lhsVar := $discr; $rhs)
  withMacroExpansion stx letStx <| elabTerm letStx expectedType?
|
||
|
||
/-- Rewrite `stx` as ``let $lhsVar : $type := $discr; $rhs`` and elaborate the rewritten term,
    recording the rewrite as a macro expansion of `stx`. -/
private def expandSimpleMatchWithType (stx discr lhsVar type rhs : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
  let letStx ← `(let $lhsVar : $type := $discr; $rhs)
  withMacroExpansion stx letStx <| elabTerm letStx expectedType?
|
||
|
||
/--
  Elaborate the discriminants `discrStxs` against an explicitly provided `matchType`.

  For each discriminant, `matchType` is first put in weak head normal form. It must then be a
  `forall`: its domain is used as the expected type of the discriminant term (`discrStx[1]`),
  and its body, instantiated with the elaborated discriminant, becomes the `matchType` used for
  the next discriminant. An error is thrown as soon as `matchType` is not a function type.

  Returns the elaborated discriminants in source order. `expectedType` is not used here. -/
private def elabDiscrsWitMatchType (discrStxs : Array Syntax) (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr) := do
  let mut discrs := #[]
  let mut i := 0  -- 1-based position of the current discriminant, only used for tracing
  let mut matchType := matchType
  for discrStx in discrStxs do
    i := i + 1
    matchType ← whnf matchType
    match matchType with
    | Expr.forallE _ d b _ =>
      let discr ← fullApproxDefEq <| elabTermEnsuringType discrStx[1] d
      trace[Elab.match]! "discr #{i} {discr} : {d}"
      -- `instantiate1` is a pure function, so this is a plain reassignment (as in `elabPatterns`)
      matchType := b.instantiate1 discr
      discrs := discrs.push discr
    | _ =>
      throwError! "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected"
  pure discrs
|
||
|
||
/-- Binder name to use for the discriminant `e`: the user name of its local declaration when `e`
    is a free variable, and a fresh binder name otherwise. -/
private def mkUserNameFor (e : Expr) : TermElabM Name := do
  match e with
  | Expr.fvar fvarId _ =>
    let localDecl ← getLocalDecl fvarId
    pure localDecl.userName
  | _ => mkFreshBinderName
|
||
|
||
/-- `expandNonAtomicDiscrs?` creates auxiliary variables with base name `_discr`.
    Return `true` if `n`, once its macro scopes are erased, is that base name. -/
private def isAuxDiscrName (n : Name) : Bool :=
  n.eraseMacroScopes == `_discr
|
||
|
||
/--
  Elaborate a discriminant whose term (`discr[1]`) is expected to be a local identifier
  (see `expandNonAtomicDiscrs?`).
  If the local is one of the auxiliary `_discr` variables, the value stored in its local
  declaration is returned instead of the free variable itself.
  Throws "unexpected discriminant" if the term is not a local free variable. -/
private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do
  let term := discr[1]
  match (← isLocalIdent? term) with
  | some e@(Expr.fvar fvarId _) =>
    let localDecl ← getLocalDecl fvarId
    if isAuxDiscrName localDecl.userName then
      -- auxiliary local created by `expandNonAtomicDiscrs?`: use the value it is bound to
      pure localDecl.value
    else
      pure e
  | _ => throwErrorAt discr "unexpected discriminant"
|
||
|
||
/--
  Elaborate the discriminants of a match-expression and compute its match type.
  Returns the elaborated discriminants, the match type, and the (possibly updated) alternatives.

  * If no type was provided (`matchOptType.isNone`), the discriminants are processed from the last
    one to the first. Each one is elaborated with `elabAtomicDiscr`, abstracted in the current
    match type with `kabstract`, and a new `forall` binder is added for it, starting from
    `expectedType`. When the discriminant is annotated with an identifier (`discrStx[0]` is not
    empty), an extra binder for the equation `discr = x` is added, `Eq.refl discr` is supplied as
    an extra discriminant, and the identifier is inserted as an extra pattern (at position `i+1`)
    in every alternative.
  * Otherwise the provided type is elaborated and the discriminants are elaborated against it
    with `elabDiscrsWitMatchType`; the alternatives are returned unchanged. -/
private def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr)
    : TermElabM (Array Expr × Expr × Array MatchAltView) := do
  if matchOptType.isNone then
    -- `i` is the number of discriminants still to be processed; `discrs` is accumulated in
    -- reverse order and reversed at the end.
    let rec loop (i : Nat) (discrs : Array Expr) (matchType : Expr) (matchAltViews : Array MatchAltView) := do
      match i with
      | 0   => pure (discrs.reverse, matchType, matchAltViews)
      | i+1 =>
        let discrStx := discrStxs[i]
        let discr ← elabAtomicDiscr discrStx
        let discr ← instantiateMVars discr
        let discrType ← inferType discr
        let discrType ← instantiateMVars discrType
        let matchTypeBody ← kabstract matchType discr
        let userName ← mkUserNameFor discr
        if discrStx[0].isNone then
          loop i (discrs.push discr) (Lean.mkForall userName BinderInfo.default discrType matchTypeBody) matchAltViews
        else
          -- the discriminant is annotated with an identifier naming the equation `discr = x`
          let identStx := discrStx[0][0]
          withLocalDeclD userName discrType fun x => do
            let eqType ← mkEq discr x
            withLocalDeclD identStx.getId eqType fun h => do
              let matchTypeBody := matchTypeBody.instantiate1 x
              let matchType ← mkForallFVars #[x, h] matchTypeBody
              let refl ← mkEqRefl discr
              -- pushed in reverse order, since `discrs` is reversed at the end
              let discrs := (discrs.push refl).push discr
              let matchAltViews := matchAltViews.map fun altView =>
                { altView with patterns := altView.patterns.insertAt (i+1) identStx }
              loop i discrs matchType matchAltViews
    loop discrStxs.size #[] expectedType matchAltViews
  else
    let matchTypeStx := matchOptType[0][1]
    let matchType ← elabType matchTypeStx
    let discrs ← elabDiscrsWitMatchType discrStxs matchType expectedType
    pure (discrs, matchType, matchAltViews)
|
||
|
||
/-- Apply `expandMacros` to every pattern of every alternative in `matchAlts`. -/
def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array MatchAltView) :=
  matchAlts.mapM fun alt => do
    let expanded ← alt.patterns.mapM expandMacros
    pure { alt with patterns := expanded }
|
||
|
||
/-- Given `stx` a match-expression, return a `MatchAltView` for each of its alternatives.
    Any other shape of syntax is treated as unreachable. -/
private def getMatchAlts : Syntax → Array MatchAltView
  | `(match $discrs,* $[: $ty?]? with $alts:matchAlt*) =>
    alts.map fun alt => match alt with
      | `(matchAltExpr| | $patterns,* => $rhs) => { ref := alt, patterns := patterns, rhs := rhs }
      | _ => unreachable!
  | _ => unreachable!
|
||
|
||
/--
  Wrap `e` with the auxiliary `_inaccessible` annotation. The annotation marks terms written
  with the "inaccessible" notation `.(t)` and `_` occurring in patterns. -/
def mkInaccessible (e : Expr) : Expr :=
  mkAnnotation `_inaccessible e
|
||
|
||
/-- Inverse of `mkInaccessible`: query `e` for the `_inaccessible` annotation using `annotation?`. -/
def inaccessible? (e : Expr) : Option Expr :=
  annotation? `_inaccessible e
|
||
|
||
/-- A variable introduced by a pattern, as collected before the pattern is elaborated. -/
inductive PatternVar where
  -- a named pattern variable
  | localVar     (userName : Name)
  -- anonymous variables (`_`) are encoded using metavariables
  | anonymousVar (mvarId   : MVarId)
|
||
|
||
/-- Named variables are shown by name; anonymous ones as `?m` followed by the metavariable id. -/
instance : ToString PatternVar := ⟨fun pVar =>
  match pVar with
  | PatternVar.localVar x          => toString x
  | PatternVar.anonymousVar mvarId => s!"?m{mvarId}"⟩
|
||
|
||
-- Register `MVarWithIdKind` as a builtin syntax node kind (see `mkMVarSyntax`).
builtin_initialize Parser.registerBuiltinNodeKind `MVarWithIdKind
|
||
|
||
/--
  Build an auxiliary `MVarWithIdKind` syntax node that wraps a fresh metavariable id.
  The id is stored as the node kind of the single (childless) child node.
  This kind of syntax represents `_` occurring in patterns; the ids are created before the
  patterns are elaborated into `Expr`s. -/
private def mkMVarSyntax : TermElabM Syntax := do
  let mvarId ← mkFreshId
  let idNode := Syntax.node mvarId #[]
  pure <| Syntax.node `MVarWithIdKind #[idNode]
|
||
|
||
/-- Recover the `MVarId` stored in a syntax node built by `mkMVarSyntax`:
    it is the node kind of the first child. -/
private def getMVarSyntaxMVarId (stx : Syntax) : MVarId :=
  stx[0].getKind
|
||
|
||
/--
  Elaborator for the syntax produced by `mkMVarSyntax`.
  It turns the wrapped metavariable id into a metavariable `Expr` and marks it as inaccessible.
  The expected type is ignored. -/
@[builtinTermElab MVarWithIdKind] def elabMVarWithIdKind : TermElab := fun stx _ =>
  pure <| mkInaccessible <| mkMVar <| getMVarSyntaxMVarId stx
|
||
|
||
/-- Elaborator for inaccessible terms: elaborate the wrapped term (`stx[1]`) against the
    expected type and mark the result with `mkInaccessible`. -/
@[builtinTermElab inaccessible] def elabInaccessible : TermElab := fun stx expectedType? => do
  let body ← elabTerm stx[1] expectedType?
  pure (mkInaccessible body)
|
||
|
||
/-
|
||
Patterns define new local variables.
|
||
This module collects them and preprocesses `_` occurring in patterns.
|
||
Recall that an `_` may represent anonymous variables or inaccessible terms
|
||
that are implied by typing constraints. Thus, we represent them with fresh named holes `?x`.
|
||
After we elaborate the pattern, if the metavariable remains unassigned, we transform it into
|
||
a regular pattern variable. Otherwise, it becomes an inaccessible term.
|
||
|
||
Macros occurring in patterns are expanded before the `collectPatternVars` method is executed.
|
||
The following kinds of Syntax are handled by this module
|
||
- Constructor applications
|
||
- Applications of functions tagged with the `[matchPattern]` attribute
|
||
- Identifiers
|
||
- Anonymous constructors
|
||
- Structure instances
|
||
- Inaccessible terms
|
||
- Named patterns
|
||
- Tuple literals
|
||
- Type ascriptions
|
||
- Literals: num, string and char
|
||
-/
|
||
namespace CollectPatternVars
|
||
|
||
/-- State accumulated while collecting pattern variables. -/
structure State where
  found : NameSet          := {}   -- names of the local pattern variables seen so far
  vars  : Array PatternVar := #[]  -- all pattern variables, in the order they were found
|
||
|
||
-- `TermElabM` extended with the `State` of the pattern-variable collector.
abbrev M := StateRefT State TermElabM
|
||
|
||
/-- Fail because the head of a pattern is neither a constructor nor a `[matchPattern]` constant. -/
private def throwCtorExpected {α} : M α :=
  throwError "invalid pattern, constructor or constant marked with '[matchPattern]' expected"
|
||
|
||
/-- Count how many of the first `ctorVal.nparams` binders of the constructor type are explicit. -/
private def getNumExplicitCtorParams (ctorVal : ConstructorVal) : TermElabM Nat :=
  forallBoundedTelescope ctorVal.type ctorVal.nparams fun params _ => do
    let mut numExplicit := 0
    for param in params do
      let paramDecl ← getLocalDecl param.fvarId!
      if paramDecl.binderInfo.isExplicit then
        numExplicit := numExplicit + 1
    pure numExplicit
|
||
|
||
/-- Fail with the generic "invalid pattern" error. -/
private def throwInvalidPattern {α} : M α :=
  throwError "invalid pattern"
|
||
|
||
namespace CtorApp
|
||
|
||
/-
|
||
An application in a pattern can be
|
||
|
||
1- A constructor application
|
||
The elaborator assumes fields are accessible and inductive parameters are not accessible.
|
||
|
||
2- A regular application `(f ...)` where `f` is tagged with `[matchPattern]`.
|
||
The elaborator assumes implicit arguments are not accessible and explicit ones are accessible.
|
||
-/
|
||
|
||
/-- Working state threaded through the processing of one application pattern. -/
structure Context where
  funId        : Syntax                       -- identifier at the head of the application
  ctorVal?     : Option ConstructorVal        -- It is `some`, if constructor application
  explicit     : Bool                         -- `true` if the head was written `@f`
  ellipsis     : Bool                         -- if `true`, missing explicit arguments are filled with `_`
  paramDecls   : Array LocalDecl              -- declarations of the parameters of the head function
  paramDeclIdx : Nat                   := 0   -- index in `paramDecls` of the next parameter to process
  namedArgs    : Array NamedArg               -- named arguments not consumed yet
  args         : List Arg                     -- positional arguments not consumed yet
  newArgs      : Array Syntax          := #[] -- arguments of the resulting application, in order
  deriving Inhabited
|
||
|
||
/-- `true` once every parameter in `ctx.paramDecls` has been processed. -/
private def isDone (ctx : Context) : Bool :=
  ctx.paramDeclIdx ≥ ctx.paramDecls.size
|
||
|
||
/-- Build the resulting application `@f newArgs...`.
    Fails with "too many arguments" if named or positional arguments are left over. -/
private def finalize (ctx : Context) : M Syntax := do
  unless ctx.namedArgs.isEmpty && ctx.args.isEmpty do
    throwError "too many arguments"
  let fStx ← `(@$(ctx.funId):ident)
  pure <| Syntax.mkApp fStx ctx.newArgs
|
||
|
||
/-- Decide whether the argument for the next parameter is accessible, i.e., should itself be
    processed as a pattern. -/
private def isNextArgAccessible (ctx : Context) : Bool :=
  let i := ctx.paramDeclIdx
  match ctx.ctorVal? with
  -- For constructor applications only fields are accessible
  | some ctorVal => i ≥ ctorVal.nparams
  | none =>
    if h : i < ctx.paramDecls.size then
      -- For `[matchPattern]` applications, only explicit parameters are accessible.
      (ctx.paramDecls.get ⟨i, h⟩).binderInfo.isExplicit
    else
      false
|
||
|
||
/-- Return the declaration of the next parameter together with the context advanced past it. -/
private def getNextParam (ctx : Context) : LocalDecl × Context :=
  let idx := ctx.paramDeclIdx
  (ctx.paramDecls[idx], { ctx with paramDeclIdx := idx + 1 })
|
||
|
||
/-- Append `arg` to `ctx.newArgs`. If `accessible`, the argument is first processed with
    `collect`; otherwise it is kept as is. Only `Arg.stx` arguments are expected. -/
private def pushNewArg (collect : Syntax → M Syntax) (accessible : Bool) (ctx : Context) (arg : Arg) : M Context :=
  match arg with
  | Arg.stx argStx => do
    let newArg ← if accessible then collect argStx else pure argStx
    pure { ctx with newArgs := ctx.newArgs.push newArg }
  | _ => unreachable!
|
||
|
||
/-- Consume the next positional argument for an explicit parameter.
    If none is left, a hole `_` is used when `ctx.ellipsis` is set; otherwise an error is thrown. -/
private def processExplicitArg (collect : Syntax → M Syntax) (accessible : Bool) (ctx : Context) : M Context :=
  match ctx.args with
  | arg :: rest =>
    pushNewArg collect accessible { ctx with args := rest } arg
  | [] =>
    if ctx.ellipsis then do
      let hole ← `(_)
      pushNewArg collect accessible ctx (Arg.stx hole)
    else
      throwError! "explicit parameter is missing, unused named arguments {ctx.namedArgs.map fun narg => narg.name}"
|
||
|
||
/-- Handle an implicit parameter: with `@f` (`ctx.explicit`) it consumes a positional argument
    like an explicit parameter; otherwise a hole `_` is supplied for it. -/
private def processImplicitArg (collect : Syntax → M Syntax) (accessible : Bool) (ctx : Context) : M Context :=
  if ctx.explicit then
    processExplicitArg collect accessible ctx
  else do
    let hole ← `(_)
    pushNewArg collect accessible ctx (Arg.stx hole)
|
||
|
||
/--
  Main loop: process the parameters of the head function one at a time.
  A named argument matching the parameter name takes precedence; otherwise the parameter is
  handled according to its binder info. When no parameter is left, `finalize` builds the result. -/
private partial def processCtorAppAux (collect : Syntax → M Syntax) (ctx : Context) : M Syntax := do
  if isDone ctx then
    finalize ctx
  else
    -- must be computed before `getNextParam` advances `paramDeclIdx`
    let accessible := isNextArgAccessible ctx
    let (d, ctx)   := getNextParam ctx
    match ctx.namedArgs.findIdx? fun namedArg => namedArg.name == d.userName with
    | some idx =>
      let namedArg := ctx.namedArgs[idx]
      let ctx := { ctx with namedArgs := ctx.namedArgs.eraseIdx idx }
      let ctx ← pushNewArg collect accessible ctx namedArg.val
      processCtorAppAux collect ctx
    | none =>
      let ctx ← match d.binderInfo with
        | BinderInfo.implicit     => processImplicitArg collect accessible ctx
        | BinderInfo.instImplicit => processImplicitArg collect accessible ctx
        | _                       => processExplicitArg collect accessible ctx
      processCtorAppAux collect ctx
|
||
|
||
/--
  Process the application pattern `f namedArgs args`.
  `f` must be an identifier (optionally prefixed with `@`) that resolves to a constant which is
  either a constructor or tagged with `[matchPattern]`; otherwise an error is thrown. -/
def processCtorApp (collect : Syntax → M Syntax) (f : Syntax) (namedArgs : Array NamedArg) (args : Array Arg) (ellipsis : Bool) : M Syntax := do
  let (fId, explicit) ← match f with
    | `($fId:ident)  => pure (fId, false)
    | `(@$fId:ident) => pure (fId, true)
    | _              => throwError "identifier expected"
  let some (Expr.const fName _ _) ← resolveId? fId "pattern" | throwCtorExpected
  let fInfo ← getConstInfo fName
  let args := args.toList
  forallTelescopeReducing fInfo.type fun xs _ => do
    let paramDecls ← xs.mapM (getFVarLocalDecl ·)
    match fInfo with
    | ConstantInfo.ctorInfo val =>
      processCtorAppAux collect
        { funId := fId, explicit := explicit, ctorVal? := val, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis }
    | _ =>
      if hasMatchPatternAttribute (← getEnv) fName then
        processCtorAppAux collect
          { funId := fId, explicit := explicit, ctorVal? := none, paramDecls := paramDecls, namedArgs := namedArgs, args := args, ellipsis := ellipsis }
      else
        throwCtorExpected
|
||
|
||
end CtorApp
|
||
|
||
/-- Split the application `stx` with `expandApp` and process it with `CtorApp.processCtorApp`. -/
def processCtorApp (collect : Syntax → M Syntax) (stx : Syntax) : M Syntax := do
  let (f, namedArgs, args, ellipsis) ← liftM <| expandApp stx true
  CtorApp.processCtorApp collect f namedArgs args ellipsis
|
||
|
||
/-- Process `stx` as a constructor-like head applied to no arguments. -/
def processCtor (collect : Syntax → M Syntax) (stx : Syntax) : M Syntax :=
  CtorApp.processCtorApp collect stx #[] #[] false
|
||
|
||
/--
  Register the identifier `idStx` as a new local pattern variable and return it unchanged.
  Fails if `idStx` is not an identifier, if its name (macro scopes erased) is not atomic,
  or if the variable was already collected. -/
private def processVar (idStx : Syntax) : M Syntax := do
  unless idStx.isIdent do
    throwErrorAt idStx "identifier expected"
  let id := idStx.getId
  unless id.eraseMacroScopes.isAtomic do
    throwError "invalid pattern variable, must be atomic"
  if (← get).found.contains id then
    throwError! "invalid pattern, variable '{id}' occurred more than once"
  modify fun s => { s with vars := s.vars.push (PatternVar.localVar id), found := s.found.insert id }
  pure idStx
|
||
|
||
/--
  Decide whether the identifier `stx` is a pattern variable or constructor-like
  (i.e., a constructor or a constant tagged with the `[matchPattern]` attribute),
  and process it accordingly. -/
private def processId (collect : Syntax → M Syntax) (stx : Syntax) : M Syntax := do
  let env ← getEnv
  match (← resolveId? stx "pattern") with
  | none   => processVar stx
  | some f => match f with
    | Expr.const fName _ _ =>
      match env.find? fName with
      | some (ConstantInfo.ctorInfo _) => processCtor collect stx
      | some _ =>
        -- other constants are constructor-like only when tagged with `[matchPattern]`
        if hasMatchPatternAttribute env fName then
          processCtor collect stx
        else
          processVar stx
      | none => throwCtorExpected
    | _ => processVar stx
|
||
|
||
/-- Convert a `Name` into pattern syntax built from the constructors `Name.anonymous`,
    `Name.str` and `Name.num`; the last constructor argument is left as `_`. -/
private def nameToPattern : Name → TermElabM Syntax
  | Name.anonymous => `(Name.anonymous)
  | Name.str p s _ => do
    let prefixPat ← nameToPattern p
    `(Name.str $prefixPat $(quote s) _)
  | Name.num p n _ => do
    let prefixPat ← nameToPattern p
    `(Name.num $prefixPat $(quote n) _)
|
||
|
||
/-- Convert a quoted-name syntax node into a constructor pattern using `nameToPattern`.
    Fails if `stx[0]` is not a name literal. -/
private def quotedNameToPattern (stx : Syntax) : TermElabM Syntax :=
  match stx[0].isNameLit? with
  | none      => throwIllFormedSyntax
  | some name => nameToPattern name
|
||
|
||
/--
  Traverse the pattern `stx`, recording the pattern variables it introduces in the state, and
  return the preprocessed pattern:
  - applications and identifiers are handled by `processCtorApp` / `processId`;
  - each hole `_` is replaced with a fresh `MVarWithIdKind` node and recorded as an anonymous variable;
  - the sub-patterns of anonymous constructors, structure instances, parenthesized terms,
    tuples and named patterns are processed recursively;
  - inaccessible terms and literals are returned unchanged;
  - quoted names are converted into constructor patterns.
  Any other syntax is rejected with an error. -/
partial def collect : Syntax → M Syntax
  | stx@(Syntax.node k args) => withRef stx $ withFreshMacroScope do
    if k == `Lean.Parser.Term.app then
      processCtorApp collect stx
    else if k == `Lean.Parser.Term.anonymousCtor then
      let elems ← args[1].getArgs.mapSepElemsM collect
      pure $ Syntax.node k (args.set! 1 $ mkNullNode elems)
    else if k == `Lean.Parser.Term.structInst then
      /-
      ```
      parser! "{" >> optional (atomic (termParser >> " with "))
                  >> manyIndent (group (structInstField >> optional ", "))
                  >> optional ".."
                  >> optional (" : " >> termParser)
                  >> " }"
      ``` -/
      let withMod := args[1]
      unless withMod.isNone do
        throwErrorAt withMod "invalid struct instance pattern, 'with' is not allowed in patterns"
      let fields ← args[2].getArgs.mapM fun p => do
          -- p is of the form (group (structInstField >> optional ", "))
          let field := p[0]
          -- parser! structInstLVal >> " := " >> termParser
          let newVal ← collect field[2]
          let field := field.setArg 2 newVal
          -- Put the updated field back into its `group` node. (Using `field.setArg 0 field`
          -- here would nest the field inside itself and drop the `group` wrapper.)
          pure <| p.setArg 0 field
      pure <| Syntax.node k (args.set! 2 <| mkNullNode fields)
    else if k == `Lean.Parser.Term.hole then
      let r ← mkMVarSyntax
      modify fun s => { s with vars := s.vars.push $ PatternVar.anonymousVar $ getMVarSyntaxMVarId r }
      pure r
    else if k == `Lean.Parser.Term.paren then
      let arg := args[1]
      if arg.isNone then
        pure stx -- `()`
      else
        let t := arg[0]
        let s := arg[1]
        if s.isNone || s[0].getKind == `Lean.Parser.Term.typeAscription then
          -- Ignore `s`, since it is empty or it is a type ascription
          let t ← collect t
          let arg := arg.setArg 0 t
          pure $ Syntax.node k (args.set! 1 arg)
        else
          -- Tuple literal is a constructor
          let t ← collect t
          let arg := arg.setArg 0 t
          let tupleTail := s[0]
          let tupleTailElems := tupleTail[1].getArgs
          let tupleTailElems ← tupleTailElems.mapSepElemsM collect
          let tupleTail := tupleTail.setArg 1 $ mkNullNode tupleTailElems
          let s         := s.setArg 0 tupleTail
          let arg       := arg.setArg 1 s
          pure $ Syntax.node k (args.set! 1 arg)
    else if k == `Lean.Parser.Term.explicitUniv then
      processCtor collect stx[0]
    else if k == `Lean.Parser.Term.namedPattern then
      /- Recall that
         def namedPattern := check... >> tparser! "@" >> termParser -/
      let id := stx[0]
      discard <| processVar id
      let pat := stx[2]
      let pat ← collect pat
      `(_root_.namedPattern $id $pat)
    else if k == `Lean.Parser.Term.inaccessible then
      pure stx
    else if k == strLitKind then
      pure stx
    else if k == numLitKind then
      pure stx
    else if k == scientificLitKind then
      pure stx
    else if k == charLitKind then
      pure stx
    else if k == `Lean.Parser.Term.quotedName then
      /- Quoted names have an elaboration function associated with them, and they will not be macro expanded.
         Note that macro expansion is not a good option since it produces a term using the smart constructors `Name.mkStr`, `Name.mkNum`
         instead of the constructors `Name.str` and `Name.num` -/
      quotedNameToPattern stx
    else if k == choiceKind then
      throwError "invalid pattern, notation is ambiguous"
    else
      throwInvalidPattern
  | stx@(Syntax.ident _ _ _ _) =>
    processId collect stx
  | stx =>
    throwInvalidPattern
|
||
|
||
/-- Run `collect` on every pattern of `alt` and return `alt` with the preprocessed patterns. -/
def main (alt : MatchAltView) : M MatchAltView := do
  let newPatterns ← alt.patterns.mapM fun p => do
    trace[Elab.match]! "collecting variables at pattern: {p}"
    collect p
  pure { alt with patterns := newPatterns }
|
||
|
||
end CollectPatternVars
|
||
|
||
/-- Run `CollectPatternVars.main` on `alt`, starting from an empty state.
    Returns the collected pattern variables and the preprocessed alternative. -/
private def collectPatternVars (alt : MatchAltView) : TermElabM (Array PatternVar × MatchAltView) := do
  let (newAlt, finalState) ← (CollectPatternVars.main alt).run {}
  pure (finalState.vars, newAlt)
|
||
|
||
/-- Return the pattern variables in the given pattern, after expanding the macros in it.
    Remark: this method is not used here, but in other macros (e.g., at `Do.lean`). -/
def getPatternVars (patternStx : Syntax) : TermElabM (Array PatternVar) := do
  let expanded ← liftMacroM <| expandMacros patternStx
  let (_, finalState) ← (CollectPatternVars.collect expanded).run {}
  pure finalState.vars
|
||
|
||
/-- Return the pattern variables of all the given patterns, collected in a single shared state.
    Macros in each pattern are expanded before it is processed. -/
def getPatternsVars (patterns : Array Syntax) : TermElabM (Array PatternVar) := do
  let collectAll : CollectPatternVars.M Unit := do
    for pattern in patterns do
      let expanded ← liftMacroM <| expandMacros pattern
      discard <| CollectPatternVars.collect expanded
  let (_, finalState) ← collectAll.run {}
  pure finalState.vars
|
||
|
||
/-- We convert the collected `PatternVar`s into `PatternVarDecl`s. -/
inductive PatternVarDecl where
  /- For `anonymousVar`, we create both a metavariable and a free variable. The free variable is used as an assignment for the metavariable
     when it is not assigned during pattern elaboration. -/
  | anonymousVar (mvarId : MVarId) (fvarId : FVarId)
  -- a named pattern variable, represented by the free variable created for it
  | localVar     (fvarId : FVarId)
|
||
|
||
/--
  Declare a local variable, whose type is a fresh type metavariable, for each pattern variable
  in `pVars`, and run `k` with the resulting declarations in scope.
  Anonymous variables get a fresh binder name; the metavariable associated with each of them is
  created only once all the local variables have been declared. -/
private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α :=
  let rec loop (i : Nat) (decls : Array PatternVarDecl) := do
    if h : i < pVars.size then
      match pVars.get ⟨i, h⟩ with
      | PatternVar.anonymousVar mvarId =>
        let type ← mkFreshTypeMVar
        let userName ← mkFreshBinderName
        withLocalDeclD userName type fun x =>
          loop (i+1) (decls.push (PatternVarDecl.anonymousVar mvarId x.fvarId!))
      | PatternVar.localVar userName =>
        let type ← mkFreshTypeMVar
        withLocalDeclD userName type fun x =>
          loop (i+1) (decls.push (PatternVarDecl.localVar x.fvarId!))
    else
      /- We must create the metavariables for `PatternVar.anonymousVar` AFTER we create the new local decls using `withLocalDecl`.
         Reason: their scope must include the new local decls since some of them are assigned by typing constraints. -/
      decls.forM fun decl => match decl with
        | PatternVarDecl.anonymousVar mvarId fvarId => do
          let fvarType ← inferType (mkFVar fvarId)
          discard <| mkFreshExprMVarWithId mvarId fvarType
        | _ => pure ()
      k decls
  loop 0 #[]
|
||
|
||
/--
  Elaborate the patterns `patternStxs` against `matchType`.
  For each pattern, `matchType` (in weak head normal form) must be a `forall`: the pattern is
  elaborated against its domain, and its body is instantiated with the elaborated pattern.
  Returns the elaborated patterns and the remaining type. -/
private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : TermElabM (Array Expr × Expr) := do
  let mut patterns  := #[]
  let mut matchType := matchType
  for patternStx in patternStxs do
    matchType ← whnf matchType
    match matchType with
    | Expr.forallE _ domain body _ =>
      let pat ← elabTermEnsuringType patternStx domain
      matchType := body.instantiate1 pat
      patterns  := patterns.push pat
    | _ => throwError "unexpected match type"
  pure (patterns, matchType)
|
||
|
||
/--
  Compute the local declarations of the pattern variables after the patterns were elaborated.
  Named variables are always kept. An anonymous variable is kept only if its metavariable still
  instantiates to a metavariable; in that case that metavariable is assigned to the auxiliary
  free variable. Metavariables in the returned declarations are instantiated. -/
def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := do
  let mut decls := #[]
  for pdecl in patternVarDecls do
    match pdecl with
    | PatternVarDecl.localVar fvarId =>
      decls := decls.push (← instantiateLocalDeclMVars (← getLocalDecl fvarId))
    | PatternVarDecl.anonymousVar mvarId fvarId =>
      let e ← instantiateMVars (mkMVar mvarId)
      trace[Elab.match]! "finalizePatternDecls: mvarId: {mvarId} := {e}, fvar: {mkFVar fvarId}"
      match e with
      | Expr.mvar newMVarId _ =>
        /- Metavariable was not assigned, or assigned to another metavariable. So,
           we assign the auxiliary free variable we created at `withPatternVars` to `newMVarId`. -/
        assignExprMVar newMVarId (mkFVar fvarId)
        trace[Elab.match]! "finalizePatternDecls: {mkMVar newMVarId} := {mkFVar fvarId}"
        decls := decls.push (← instantiateLocalDeclMVars (← getLocalDecl fvarId))
      | _ => pure ()
  pure decls
|
||
|
||
open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS MatcherResult)
|
||
|
||
namespace ToDepElimPattern
|
||
|
||
/-- State used while converting elaborated patterns into `Pattern`s. -/
structure State where
  found      : NameSet := {}       -- free variables already visited (see `markAsVisited`)
  localDecls : Array LocalDecl     -- declarations of the pattern variables
  newLocals  : NameSet := {}       -- free variables created by `mkLocalDeclFor`
|
||
|
||
-- `TermElabM` extended with the `State` used by the conversion to `Pattern`.
abbrev M := StateRefT State TermElabM
|
||
|
||
/-- `true` if `fvarId` was already marked with `markAsVisited`. -/
private def alreadyVisited (fvarId : FVarId) : M Bool := do
  pure <| (← get).found.contains fvarId
|
||
|
||
/-- Record `fvarId` in the set of visited free variables. -/
private def markAsVisited (fvarId : FVarId) : M Unit :=
  modify fun st => { st with found := st.found.insert fvarId }
|
||
|
||
/-- Fail with an "invalid pattern" error that displays the expression `e`. -/
private def throwInvalidPattern {α} (e : Expr) : M α :=
  throwError! "invalid pattern {indentExpr e}"
|
||
|
||
/--
  Create a new LocalDecl `x` for the metavariable `mvar`, and return `Pattern.var x`.
  If `mvar` is already assigned, return `Pattern.inaccessible` of its value instead.
  The new declaration is inserted before the first existing declaration whose type mentions
  `mvar`, or appended at the end when there is none. -/
private def mkLocalDeclFor (mvar : Expr) : M Pattern := do
  let mvarId := mvar.mvarId!
  match (← getExprMVarAssignment? mvarId) with
  | some val => pure $ Pattern.inaccessible val
  | none =>
    let fvarId ← mkFreshId
    let type   ← inferType mvar
    /- HACK: `fvarId` is not in the scope of `mvarId`
       If this generates problems in the future, we should update the metavariable declarations. -/
    assignExprMVar mvarId (mkFVar fvarId)
    let userName ← liftM $ mkFreshBinderName
    let newDecl := LocalDecl.cdecl arbitrary fvarId userName type BinderInfo.default
    modify fun s =>
      { s with
        newLocals  := s.newLocals.insert fvarId,
        localDecls :=
        match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with
          | none   => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar`
          | some i => s.localDecls.insertAt i newDecl }
    pure $ Pattern.var fvarId
|
||
|
||
/--
  Convert the elaborated pattern `e` into a `Pattern`.
  The cases are tried in this order: inaccessible annotation, array literal, `namedPattern`
  application, nat/string/char literal, free variable, metavariable. Anything else is first
  reduced with `whnf`; if that makes no progress it must be a constructor application with
  exactly `nparams + nfields` arguments, whose fields are converted recursively. -/
partial def main (e : Expr) : M Pattern := do
  -- `true` if `fvarId` is one of the pattern variable declarations
  let isLocalDecl (fvarId : FVarId) : M Bool := do
    pure <| (← get).localDecls.any fun d => d.fvarId == fvarId
  -- The first occurrence of a pattern variable is a variable; later ones are inaccessible
  let mkPatternVar (fvarId : FVarId) (e : Expr) : M Pattern := do
    if (← alreadyVisited fvarId) then
      pure (Pattern.inaccessible e)
    else
      markAsVisited fvarId
      pure (Pattern.var e.fvarId!)
  let mkInaccessible (e : Expr) : M Pattern := do
    match e with
    | Expr.fvar fvarId _ =>
      if (← isLocalDecl fvarId) then
        mkPatternVar fvarId e
      else
        pure (Pattern.inaccessible e)
    | _ =>
      pure (Pattern.inaccessible e)
  match inaccessible? e with
  | some t => mkInaccessible t
  | none   =>
    match e.arrayLit? with
    | some (α, lits) =>
      let ps ← lits.mapM main
      pure (Pattern.arrayLit α ps)
    | none =>
      if e.isAppOfArity `namedPattern 3 then
        let p ← main <| e.getArg! 2
        match e.getArg! 1 with
        | Expr.fvar fvarId _ => pure (Pattern.as fvarId p)
        | _                  => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
      else if e.isNatLit || e.isStringLit || e.isCharLit then
        pure (Pattern.val e)
      else if e.isFVar then
        let fvarId := e.fvarId!
        unless (← isLocalDecl fvarId) do throwInvalidPattern e
        mkPatternVar fvarId e
      else if e.isMVar then
        mkLocalDeclFor e
      else
        let newE ← whnf e
        if newE != e then
          main newE
        else matchConstCtor e.getAppFn (fun _ => throwInvalidPattern e) fun v us => do
          let args := e.getAppArgs
          unless args.size == v.nparams + v.nfields do
            throwInvalidPattern e
          let params := args.extract 0 v.nparams
          let fields := args.extract v.nparams args.size
          let fields ← fields.mapM main
          pure (Pattern.ctor v.name us params.toList fields.toList)
|
||
|
||
end ToDepElimPattern
|
||
|
||
/--
  Convert the elaborated patterns `ps` into `Pattern`s using `ToDepElimPattern.main`, and run
  `k` with the resulting local declarations and patterns. The local context used for `k` has
  every resulting declaration erased and then added again, so that it reflects the updated
  declarations (metavariables instantiated, new declarations included). -/
def withDepElimPatterns {α} (localDecls : Array LocalDecl) (ps : Array Expr) (k : Array LocalDecl → Array Pattern → TermElabM α) : TermElabM α := do
  let (patterns, finalState) ← (ps.mapM ToDepElimPattern.main).run { localDecls := localDecls }
  let localDecls ← finalState.localDecls.mapM fun decl => instantiateLocalDeclMVars decl
  /- toDepElimPatterns may have added new localDecls. Thus, we must update the local context before we execute `k` -/
  let oldLCtx ← getLCtx
  let erased  := localDecls.foldl (fun (lctx : LocalContext) decl => lctx.erase decl.fvarId) oldLCtx
  let newLCtx := localDecls.foldl (fun (lctx : LocalContext) decl => lctx.addDecl decl) erased
  withTheReader Meta.Context (fun ctx => { ctx with lctx := newLCtx }) <| k localDecls patterns
|
||
|
||
/--
  Elaborate the patterns `patternStxs` of one alternative against `matchType`, build the
  corresponding `AltLHS`, and run `k` with it and with the type remaining after all patterns
  were consumed. -/
private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr)
    (k : AltLHS → Expr → TermElabM α) : TermElabM α := do
  let (elabedPatterns, matchType) ← withSynthesize <| elabPatterns patternStxs matchType
  let varDecls ← finalizePatternDecls patternVarDecls
  let elabedPatterns ← elabedPatterns.mapM (instantiateMVars ·)
  withDepElimPatterns varDecls elabedPatterns fun localDecls patterns =>
    k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
|
||
|
||
/--
  Elaborate one alternative against `matchType`.
  Returns its left-hand side and its right-hand side abstracted over the pattern variables;
  when there are no pattern variables, the right-hand side is wrapped with `mkSimpleThunk`. -/
def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : TermElabM (AltLHS × Expr) := withRef alt.ref do
  let (patternVars, alt) ← collectPatternVars alt
  trace[Elab.match]! "patternVars: {patternVars}"
  withPatternVars patternVars fun patternVarDecls => do
    withElaboratedLHS alt.ref patternVarDecls alt.patterns matchType fun altLHS matchType => do
      let rhsBody ← elabTermEnsuringType alt.rhs matchType
      let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr
      let rhs ←
        if xs.isEmpty then
          pure (mkSimpleThunk rhsBody)
        else
          mkLambdaFVars xs rhsBody
      trace[Elab.match]! "rhs: {rhs}"
      pure (altLHS, rhs)
|
||
|
||
/-- `Meta.Match.mkMatcher` lifted to `TermElabM`. -/
def mkMatcher (elimName : Name) (matchType : Expr) (numDiscrs : Nat) (lhss : List AltLHS) : TermElabM MatcherResult :=
  liftMetaM <| Meta.Match.mkMatcher elimName matchType numDiscrs lhss
|
||
|
||
-- Register the option `match.ignoreUnusedAlts` (default `false`); it is read by `ignoreUnusedAlts`.
builtin_initialize
  registerOption `match.ignoreUnusedAlts {
    defValue := false,
    group    := "",
    descr    := "if true, do not generate error if an alternative is not used"
  }
|
||
|
||
/-- Value of the option `match.ignoreUnusedAlts` in `opts`; `false` when it is not set. -/
def ignoreUnusedAlts (opts : Options) : Bool :=
  opts.get `match.ignoreUnusedAlts false
|
||
|
||
/--
  Report the problems found in `result`:
  - throw a "missing cases" error if there are counterexamples;
  - unless the option `match.ignoreUnusedAlts` is set, log a "redundant alternative" error at
    each alternative whose index is listed in `result.unusedAltIdxs`. -/
def reportMatcherResultErrors (altLHSS : List AltLHS) (result : MatcherResult) : TermElabM Unit := do
  unless result.counterExamples.isEmpty do
    throwError! "missing cases:\n{Meta.Match.counterExamplesToMessageData result.counterExamples}"
  unless ignoreUnusedAlts (← getOptions) || result.unusedAltIdxs.isEmpty do
    let mut altIdx := 0
    for altLHS in altLHSS do
      if result.unusedAltIdxs.contains altIdx then
        withRef altLHS.ref do
          logError "redundant alternative"
      altIdx := altIdx + 1
|
||
|
||
/--
  Core of the match elaborator: elaborate the discriminants and the match type, elaborate every
  alternative, check that no metavariables are left in the left-hand sides, build the matcher
  with `mkMatcher`, report missing cases and redundant alternatives, and return the application
  of the matcher to the motive, the discriminants and the right-hand sides. -/
private def elabMatchAux (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr)
    : TermElabM Expr := do
  let (discrs, matchType, altViews) ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
  let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews
  trace[Elab.match]! "matchType: {matchType}"
  let alts ← matchAlts.mapM fun alt => elabMatchAltView alt matchType
  /-
   We should not use `synthesizeSyntheticMVarsNoPostponing` here. Otherwise, we will not be
   able to elaborate examples such as:
   ```
   def f (x : Nat) : Option Nat := none

   def g (xs : List (Nat × Nat)) : IO Unit :=
   xs.forM fun x =>
     match f x.fst with
     | _ => pure ()
   ```
   If `synthesizeSyntheticMVarsNoPostponing`, the example above fails at `x.fst` because
   the type of `x` is only available after we process the last argument of `List.forM`.

   We apply pending default types to make sure we can process examples such as
   ```
   let (a, b) := (0, 0)
   ```
  -/
  synthesizeSyntheticMVarsUsingDefault
  let rhss := alts.map fun alt => alt.2
  let matchType ← instantiateMVars matchType
  let altLHSS ← alts.toList.mapM fun alt => do
    let altLHS ← Match.instantiateAltLHSMVars alt.1
    -- neither the pattern variable types nor the patterns may contain metavariables at this point
    withRef altLHS.ref do
      for d in altLHS.fvarDecls do
        if d.hasExprMVar then
          withExistingLocalDecls altLHS.fvarDecls do
            throwMVarError m!"invalid match-expression, type of pattern variable '{d.toExpr}' contains metavariables{indentExpr d.type}"
      for p in altLHS.patterns do
        if p.hasExprMVar then
          withExistingLocalDecls altLHS.fvarDecls do
            throwMVarError m!"invalid match-expression, pattern contains metavariables{indentExpr (← p.toExpr)}"
      pure altLHS
  let numDiscrs := discrs.size
  let matcherName ← mkAuxName `match
  let matcherResult ← mkMatcher matcherName matchType numDiscrs altLHSS
  let motive ← forallBoundedTelescope matchType numDiscrs fun xs body => mkLambdaFVars xs body
  reportMatcherResultErrors altLHSS matcherResult
  let r := mkAppN (mkAppN (mkApp matcherResult.matcher motive) discrs) rhss
  trace[Elab.match]! "result: {r}"
  pure r
|
||
|
||
/- Discriminants of a `match` node: argument 1 (`sepBy1 matchDiscr ", "` in the grammar) passed through `getSepArgs`. -/
private def getDiscrs (matchStx : Syntax) : Array Syntax :=
  (matchStx.getArg 1).getSepArgs
|
||
|
||
/- The `optType` component of a `match` node, i.e., its argument 2. -/
private def getMatchOptType (matchStx : Syntax) : Syntax :=
  matchStx.getArg 2
|
||
|
||
/-
  If the `match` has no explicit type and some discriminant is not a local identifier, return a new
  syntax tree where every such discriminant `t` has been replaced by a fresh auxiliary identifier
  `_discr` bound by an enclosing `let _discr := t; ...`. Return `none` when nothing has to be done.
-/
private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Syntax) := do
  if !(getMatchOptType matchStx).isNone then
    -- We do not pull non atomic discriminants when match type is provided explicitly by the user
    pure none
  else
    let discrs := getDiscrs matchStx
    if (← discrs.allM fun discr => Option.isSome <$> isLocalIdent? discr[1]) then
      pure none
    else
      -- `pending` holds the discriminants still to be visited, `acc` the rewritten ones.
      let rec loop (pending : List Syntax) (acc : Array Syntax) := do
        match pending with
        | [] =>
          let sepDiscrs := Syntax.mkSep acc (mkAtomFrom matchStx ", ")
          pure (matchStx.setArg 1 sepDiscrs)
        | discr :: rest =>
          -- Recall that
          -- matchDiscr := parser! optional (ident >> ":") >> termParser
          let discrTerm := discr[1]
          match (← isLocalIdent? discrTerm) with
          | none => withFreshMacroScope do
            let auxId ← `(_discr)
            unless isAuxDiscrName auxId.getId do -- Use assertion?
              throwError "unexpected internal auxiliary discriminant name"
            let body ← loop rest (acc.push (discr.setArg 1 auxId))
            `(let _discr := $discrTerm; $body)
          | some _ => loop rest (acc.push discr)
      pure (some (← loop discrs.toList #[]))
|
||
|
||
/-
  Try to postpone elaboration while the expected type is missing or is a metavariable.
  Afterwards return the expected type, or a fresh type metavariable if none was given.
-/
private def waitExpectedType (expectedType? : Option Expr) : TermElabM Expr := do
  tryPostponeIfNoneOrMVar expectedType?
  match expectedType? with
  | none      => mkFreshTypeMVar
  | some type => pure type
|
||
|
||
/-
  When the user did not provide the match type, visit every discriminant (each must already be a
  local identifier), infer its type, and try to postpone elaboration if that type is a metavariable.
-/
private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit := do
  -- We don't wait for the discriminants types when match type is provided by user
  if (getMatchOptType matchStx).isNone then
    for discr in getDiscrs matchStx do
      match (← isLocalIdent? discr[1]) with
      | some localDiscr =>
        let localDiscrType ← inferType localDiscr
        trace[Elab.match]! "discr {localDiscr} : {localDiscrType}"
        tryPostponeIfMVar localDiscrType
      | none => throwErrorAt discr "unexpected discriminant" -- see `expandNonAtomicDiscrs?`
|
||
|
||
/-
  We (try to) elaborate a `match` only when the expected type is available.
  If the `matchType` has not been provided by the user, we also try to postpone elaboration if the type
  of a discriminant is not available. That is, it is of the form `(?m ...)`.
  We use `expandNonAtomicDiscrs?` to make sure all discriminants are local variables.
  This is a standard trick we use in the elaborator, and it is also used to elaborate structure instances.
  Suppose we are trying to elaborate
  ```
  match g x with
  | ... => ...
  ```
  `expandNonAtomicDiscrs?` converts it into
  ```
  let _discr := g x
  match _discr with
  | ... => ...
  ```
  Thus, at `tryPostponeIfDiscrTypeIsMVar` we only need to check whether the type of `_discr` is not of the form `(?m ...)`.
  Note that the auxiliary variable `_discr` is expanded at `elabAtomicDiscr`.

  This elaboration technique is needed to elaborate terms such as:
  ```lean
  xs.filter fun (a, b) => a > b
  ```
  which are syntax sugar for
  ```lean
  List.filter (fun p => match p with | (a, b) => a > b) xs
  ```
  When we visit `match p with | (a, b) => a > b`, we don't know the type of `p` yet.
-/
private def waitExpectedTypeAndDiscrs (matchStx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
  -- First wait for the expected type, then for the types of the discriminants.
  tryPostponeIfNoneOrMVar expectedType?
  tryPostponeIfDiscrTypeIsMVar matchStx
  match expectedType? with
  | none      => mkFreshTypeMVar
  | some type => pure type
|
||
|
||
/-
  ```
  parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts
  ```
  Remark: the `optIdent` must be `none` at `matchDiscr`. They are expanded by `expandMatchDiscr?`.

  Waits for the expected type and the discriminant types (see `waitExpectedTypeAndDiscrs`), then
  passes the discriminants, alternatives, and optional match type of `stx` to `elabMatchAux`.
-/
private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
  let expectedType ← waitExpectedTypeAndDiscrs stx expectedType?
  -- The discriminant nodes are passed through unchanged; mapping the identity over them only copied the array.
  let discrStxs := getDiscrs stx
  let altViews := getMatchAlts stx
  let matchOptType := getMatchOptType stx
  elabMatchAux discrStxs altViews matchOptType expectedType
|
||
|
||
-- parser!:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts
/-
  Elaborator for `match`.
  - A single-discriminant `match` whose only alternative is an identifier pattern is expanded into a `let`.
  - Otherwise, non-atomic discriminants are first pulled out with `expandNonAtomicDiscrs?` and the
    resulting syntax is elaborated again.
  - Once nothing is left to expand, `elabMatchCore` does the work, after rejecting the combination
    of an explicit match type with a discriminant whose optional-identifier slot is non-empty.
-/
@[builtinTermElab «match»] def elabMatch : TermElab := fun stx expectedType? =>
  match stx with
  | `(match $discr:term with | $y:ident => $rhs:term) =>
    expandSimpleMatch stx discr y rhs expectedType?
  | `(match $discr:term : $type with | $y:ident => $rhs:term) =>
    expandSimpleMatchWithType stx discr y type rhs expectedType?
  | _ => do
    match (← expandNonAtomicDiscrs? stx) with
    | none =>
      let optType := getMatchOptType stx
      -- Argument 0 of a `matchDiscr` node is the optional `ident ":"` prefix.
      let hasNamedDiscr := (getDiscrs stx).any fun d => !d[0].isNone
      if !optType.isNone && hasNamedDiscr then
        throwErrorAt optType "match expected type should not be provided when discriminants with equality proofs are used"
      elabMatchCore stx expectedType?
    | some stxNew => withMacroExpansion stx stxNew $ elabTerm stxNew expectedType?
|
||
|
||
/- Registers the `Elab.match` trace class used by the `trace[Elab.match]!` calls in this module. -/
@[builtinInit] private def regTraceClasses : IO Unit :=
  registerTraceClass `Elab.match
|
||
|
||
-- parser!:leadPrec "nomatch " >> termParser
/-
  Elaborator for `nomatch e`: waits for the expected type and then calls `elabMatchAux` with `e`
  as the only discriminant, no alternatives, and no explicit match type.
-/
@[builtinTermElab «nomatch»] def elabNoMatch : TermElab := fun stx expectedType? =>
  match stx with
  | `(nomatch $discrExpr) => do
    let expectedType ← waitExpectedType expectedType?
    -- Build a `matchDiscr` node by hand: an empty optional-identifier slot followed by the term.
    let discrStx := Syntax.node `Lean.Parser.Term.matchDiscr #[mkNullNode, discrExpr]
    let noAlts : Array MatchAltView := #[]
    elabMatchAux #[discrStx] noAlts mkNullNode expectedType
  | _ => throwUnsupportedSyntax
|
||
|
||
end Term
|
||
end Elab
|
||
end Lean
|