986 lines
44 KiB
Text
986 lines
44 KiB
Text
/-
|
||
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Leonardo de Moura
|
||
-/
|
||
import Lean.Util.CollectFVars
|
||
import Lean.Meta.Match.MatchPatternAttr
|
||
import Lean.Meta.Match.Match
|
||
import Lean.Meta.SortLocalDecls
|
||
import Lean.Meta.GeneralizeVars
|
||
import Lean.Elab.SyntheticMVars
|
||
import Lean.Elab.Arg
|
||
import Lean.Parser.Term
|
||
import Lean.Elab.PatternVar
|
||
|
||
namespace Lean.Elab.Term
|
||
open Meta
|
||
open Lean.Parser.Term
|
||
|
||
private def expandSimpleMatch (stx discr lhsVar rhs : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
||
let newStx ← `(let $lhsVar := $discr; $rhs)
|
||
withMacroExpansion stx newStx <| elabTerm newStx expectedType?
|
||
|
||
private def mkUserNameFor (e : Expr) : TermElabM Name := do
|
||
match e with
|
||
/- Remark: we use `mkFreshUserName` to make sure we don't add a variable to the local context that can be resolved to `e`. -/
|
||
| Expr.fvar fvarId _ => mkFreshUserName ((← getLocalDecl fvarId).userName)
|
||
| _ => mkFreshBinderName
|
||
|
||
/-- Return true iff `n` is an auxiliary variable created by `expandNonAtomicDiscrs?` -/
|
||
def isAuxDiscrName (n : Name) : Bool :=
|
||
n.hasMacroScopes && n.eraseMacroScopes == `_discr
|
||
|
||
/- We treat `@x` as atomic to avoid unnecessary extra local declarations from being
|
||
inserted into the local context. Recall that `expandMatchAltsIntoMatch` uses `@` modifier.
|
||
Thus this is kind of discriminant is quite common.
|
||
|
||
Remark: if the discriminat is `Systax.missing`, we abort the elaboration of the `match`-expression.
|
||
This can happen due to error recovery. Example
|
||
```
|
||
example : (p ∨ p) → p := fun h => match
|
||
```
|
||
If we don't abort, the elaborator loops because we will keep trying to expand
|
||
```
|
||
match
|
||
```
|
||
into
|
||
```
|
||
let d := <Syntax.missing>; match
|
||
```
|
||
Recall that `Syntax.setArg stx i arg` is a no-op when `i` is out-of-bounds. -/
|
||
def isAtomicDiscr? (discr : Syntax) : TermElabM (Option Expr) := do
|
||
match discr with
|
||
| `($x:ident) => isLocalIdent? x
|
||
| `(@$x:ident) => isLocalIdent? x
|
||
| _ => if discr.isMissing then throwAbortTerm else return none
|
||
|
||
-- See expandNonAtomicDiscrs?
|
||
private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do
|
||
let term := discr[1]
|
||
match (← isAtomicDiscr? term) with
|
||
| some e@(Expr.fvar fvarId _) =>
|
||
let localDecl ← getLocalDecl fvarId
|
||
if !isAuxDiscrName localDecl.userName then
|
||
return e -- it is not an auxiliary local created by `expandNonAtomicDiscrs?`
|
||
else
|
||
instantiateMVars localDecl.value
|
||
| _ => throwErrorAt discr "unexpected discriminant"
|
||
|
||
structure ElabMatchTypeAndDiscrsResult where
|
||
discrs : Array Expr
|
||
matchType : Expr
|
||
/- `true` when performing dependent elimination. We use this to decide whether we optimize the "match unit" case.
|
||
See `isMatchUnit?`. -/
|
||
isDep : Bool
|
||
alts : Array MatchAltView
|
||
|
||
private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptType : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr)
|
||
: TermElabM ElabMatchTypeAndDiscrsResult := do
|
||
let numDiscrs := discrStxs.size
|
||
if matchOptType.isNone then
|
||
elabDiscrs 0 #[]
|
||
else
|
||
let matchTypeStx := matchOptType[0][1]
|
||
let matchType ← elabType matchTypeStx
|
||
let (discrs, isDep) ← elabDiscrsWitMatchType matchType expectedType
|
||
return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews }
|
||
where
|
||
/- Easy case: elaborate discriminant when the match-type has been explicitly provided by the user. -/
|
||
elabDiscrsWitMatchType (matchType : Expr) (expectedType : Expr) : TermElabM (Array Expr × Bool) := do
|
||
let mut discrs := #[]
|
||
let mut i := 0
|
||
let mut matchType := matchType
|
||
let mut isDep := false
|
||
for discrStx in discrStxs do
|
||
i := i + 1
|
||
matchType ← whnf matchType
|
||
match matchType with
|
||
| Expr.forallE _ d b _ =>
|
||
let discr ← fullApproxDefEq <| elabTermEnsuringType discrStx[1] d
|
||
trace[Elab.match] "discr #{i} {discr} : {d}"
|
||
if b.hasLooseBVars then
|
||
isDep := true
|
||
matchType ← b.instantiate1 discr
|
||
discrs := discrs.push discr
|
||
| _ =>
|
||
throwError "invalid type provided to match-expression, function type with arity #{discrStxs.size} expected"
|
||
return (discrs, isDep)
|
||
|
||
markIsDep (r : ElabMatchTypeAndDiscrsResult) :=
|
||
{ r with isDep := true }
|
||
|
||
/- Elaborate discriminants inferring the match-type -/
|
||
elabDiscrs (i : Nat) (discrs : Array Expr) : TermElabM ElabMatchTypeAndDiscrsResult := do
|
||
if h : i < discrStxs.size then
|
||
let discrStx := discrStxs.get ⟨i, h⟩
|
||
let discr ← elabAtomicDiscr discrStx
|
||
let discr ← instantiateMVars discr
|
||
let discrType ← inferType discr
|
||
let discrType ← instantiateMVars discrType
|
||
let discrs := discrs.push discr
|
||
let userName ← mkUserNameFor discr
|
||
if discrStx[0].isNone then
|
||
let mut result ← elabDiscrs (i + 1) discrs
|
||
let matchTypeBody ← kabstract result.matchType discr
|
||
if matchTypeBody.hasLooseBVars then
|
||
result := markIsDep result
|
||
return { result with matchType := Lean.mkForall userName BinderInfo.default discrType matchTypeBody }
|
||
else
|
||
let discrs := discrs.push (← mkEqRefl discr)
|
||
let result ← elabDiscrs (i + 1) discrs
|
||
let result := markIsDep result
|
||
let identStx := discrStx[0][0]
|
||
withLocalDeclD userName discrType fun x => do
|
||
let eqType ← mkEq discr x
|
||
withLocalDeclD identStx.getId eqType fun h => do
|
||
let matchTypeBody ← kabstract result.matchType discr
|
||
let matchTypeBody := matchTypeBody.instantiate1 x
|
||
let matchType ← mkForallFVars #[x, h] matchTypeBody
|
||
return { result with
|
||
matchType := matchType
|
||
alts := result.alts.map fun altView => { altView with patterns := altView.patterns.insertAt (i+1) identStx }
|
||
}
|
||
else
|
||
return { discrs, alts := matchAltViews, isDep := false, matchType := expectedType }
|
||
|
||
def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array MatchAltView) := do
|
||
matchAlts.mapM fun matchAlt => do
|
||
let patterns ← matchAlt.patterns.mapM expandMacros
|
||
pure { matchAlt with patterns := patterns }
|
||
|
||
private def getMatchGeneralizing? : Syntax → Option Bool
|
||
| `(match (generalizing := true) $discrs,* $[: $ty?]? with $alts:matchAlt*) => some true
|
||
| `(match (generalizing := false) $discrs,* $[: $ty?]? with $alts:matchAlt*) => some false
|
||
| _ => none
|
||
|
||
/- Given `stx` a match-expression, return its alternatives. -/
|
||
private def getMatchAlts : Syntax → Array MatchAltView
|
||
| `(match $[$gen]? $discrs,* $[: $ty?]? with $alts:matchAlt*) =>
|
||
alts.filterMap fun alt => match alt with
|
||
| `(matchAltExpr| | $patterns,* => $rhs) => some {
|
||
ref := alt,
|
||
patterns := patterns,
|
||
rhs := rhs
|
||
}
|
||
| _ => none
|
||
| _ => #[]
|
||
|
||
builtin_initialize Parser.registerBuiltinNodeKind `MVarWithIdKind
|
||
|
||
open Meta.Match (mkInaccessible inaccessible?)
|
||
|
||
/--
|
||
The elaboration function for `Syntax` created using `mkMVarSyntax`.
|
||
It just converts the metavariable id wrapped by the Syntax into an `Expr`. -/
|
||
@[builtinTermElab MVarWithIdKind] def elabMVarWithIdKind : TermElab := fun stx expectedType? =>
|
||
return mkInaccessible <| mkMVar (getMVarSyntaxMVarId stx)
|
||
|
||
@[builtinTermElab inaccessible] def elabInaccessible : TermElab := fun stx expectedType? => do
|
||
let e ← elabTerm stx[1] expectedType?
|
||
return mkInaccessible e
|
||
|
||
open Lean.Elab.Term.Quotation in
|
||
@[builtinQuotPrecheck Lean.Parser.Term.match] def precheckMatch : Precheck
|
||
| `(match $[$discrs:term],* with $[| $[$patss],* => $rhss]*) => do
|
||
discrs.forM precheck
|
||
for (pats, rhs) in patss.zip rhss do
|
||
let vars ←
|
||
try
|
||
getPatternsVars pats
|
||
catch
|
||
| _ => return -- can happen in case of pattern antiquotations
|
||
Quotation.withNewLocals (getPatternVarNames vars) <| precheck rhs
|
||
| _ => throwUnsupportedSyntax
|
||
|
||
/- We convert the collected `PatternVar`s intro `PatternVarDecl` -/
|
||
inductive PatternVarDecl where
|
||
/- For `anonymousVar`, we create both a metavariable and a free variable. The free variable is used as an assignment for the metavariable
|
||
when it is not assigned during pattern elaboration. -/
|
||
| anonymousVar (mvarId : MVarId) (fvarId : FVarId)
|
||
| localVar (fvarId : FVarId)
|
||
|
||
private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α :=
|
||
let rec loop (i : Nat) (decls : Array PatternVarDecl) := do
|
||
if h : i < pVars.size then
|
||
match pVars.get ⟨i, h⟩ with
|
||
| PatternVar.anonymousVar mvarId =>
|
||
let type ← mkFreshTypeMVar
|
||
let userName ← mkFreshBinderName
|
||
withLocalDecl userName BinderInfo.default type fun x =>
|
||
loop (i+1) (decls.push (PatternVarDecl.anonymousVar mvarId x.fvarId!))
|
||
| PatternVar.localVar userName =>
|
||
let type ← mkFreshTypeMVar
|
||
withLocalDecl userName BinderInfo.default type fun x =>
|
||
loop (i+1) (decls.push (PatternVarDecl.localVar x.fvarId!))
|
||
else
|
||
/- We must create the metavariables for `PatternVar.anonymousVar` AFTER we create the new local decls using `withLocalDecl`.
|
||
Reason: their scope must include the new local decls since some of them are assigned by typing constraints. -/
|
||
decls.forM fun decl => match decl with
|
||
| PatternVarDecl.anonymousVar mvarId fvarId => do
|
||
let type ← inferType (mkFVar fvarId)
|
||
discard <| mkFreshExprMVarWithId mvarId type
|
||
| _ => pure ()
|
||
k decls
|
||
loop 0 #[]
|
||
|
||
/-
|
||
Remark: when performing dependent pattern matching, we often had to write code such as
|
||
|
||
```lean
|
||
def Vec.map' (f : α → β) (xs : Vec α n) : Vec β n :=
|
||
match n, xs with
|
||
| _, nil => nil
|
||
| _, cons a as => cons (f a) (map' f as)
|
||
```
|
||
We had to include `n` and the `_`s because the type of `xs` depends on `n`.
|
||
Moreover, `nil` and `cons a as` have different types.
|
||
This was quite tedious. So, we have implemented an automatic "discriminant refinement procedure".
|
||
The procedure is based on the observation that we get a type error whenenver we forget to include `_`s
|
||
and the indices a discriminant depends on. So, we catch the exception, check whether the type of the discriminant
|
||
is an indexed family, and add their indices as new discriminants.
|
||
|
||
The current implementation, adds indices as they are found, and does not
|
||
try to "sort" the new discriminants.
|
||
|
||
If the refinement process fails, we report the original error message.
|
||
-/
|
||
|
||
/- Auxiliary structure for storing an type mismatch exception when processing the
|
||
pattern #`idx` of some alternative. -/
|
||
structure PatternElabException where
|
||
ex : Exception
|
||
patternIdx : Nat -- Discriminant that sh
|
||
pathToIndex : List Nat -- Path to the problematic inductive type index that produced the type mismatch
|
||
|
||
/--
|
||
This method is part of the "discriminant refinement" procedure. It in invoked when the
|
||
type of the `pattern` does not match the expected type. The expected type is based on the
|
||
motive computed using the `match` discriminants.
|
||
It tries to compute a path to an index of the discriminant type.
|
||
For example, suppose the user has written
|
||
```
|
||
inductive Mem (a : α) : List α → Prop where
|
||
| head {as} : Mem a (a::as)
|
||
| tail {as} : Mem a as → Mem a (a'::as)
|
||
|
||
infix:50 " ∈ " => Mem
|
||
|
||
example (a b : Nat) (h : a ∈ [b]) : b = a :=
|
||
match h with
|
||
| Mem.head => rfl
|
||
```
|
||
The motive for the match is `a ∈ [b] → b = a`, and get a type mismatch between the type
|
||
of `Mem.head` and `a ∈ [b]`. This procedure return the path `[2, 1]` to the index `b`.
|
||
We use it to produce the following refinement
|
||
```
|
||
example (a b : Nat) (h : a ∈ [b]) : b = a :=
|
||
match b, h with
|
||
| _, Mem.head => rfl
|
||
```
|
||
which produces the new motive `(x : Nat) → a ∈ [x] → x = a`
|
||
After this refinement step, the `match` is elaborated successfully.
|
||
|
||
This method relies on the fact that the dependent pattern matcher compiler solves equations
|
||
between indices of indexed inductive families.
|
||
The following kinds of equations are supported by this compiler:
|
||
- `x = t`
|
||
- `t = x`
|
||
- `ctor ... = ctor ...`
|
||
|
||
where `x` is a free variable, `t` is an arbitrary term, and `ctor` is constructor.
|
||
Our procedure ensures that "information" is not lost, and will *not* succeed in an
|
||
example such as
|
||
```
|
||
example (a b : Nat) (f : Nat → Nat) (h : f a ∈ [f b]) : f b = f a :=
|
||
match h with
|
||
| Mem.head => rfl
|
||
```
|
||
and will not add `f b` as a new discriminant. We may add an option in the future to
|
||
enable this more liberal form of refinement.
|
||
-/
|
||
private partial def findDiscrRefinementPath (pattern : Expr) (expected : Expr) : OptionT MetaM (List Nat) := do
|
||
goType (← instantiateMVars (← inferType pattern)) expected
|
||
where
|
||
checkCompatibleApps (t d : Expr) : OptionT MetaM Unit := do
|
||
guard d.isApp
|
||
guard <| t.getAppNumArgs == d.getAppNumArgs
|
||
let tFn := t.getAppFn
|
||
let dFn := d.getAppFn
|
||
guard <| tFn.isConst && dFn.isConst
|
||
guard (← isDefEq tFn dFn)
|
||
|
||
-- Visitor for inductive types
|
||
goType (t d : Expr) : OptionT MetaM (List Nat) := do
|
||
trace[Meta.debug] "type {t} =?= {d}"
|
||
let t ← whnf t
|
||
let d ← whnf d
|
||
checkCompatibleApps t d
|
||
matchConstInduct t.getAppFn (fun _ => failure) fun info _ => do
|
||
let tArgs := t.getAppArgs
|
||
let dArgs := d.getAppArgs
|
||
for i in [:info.numParams] do
|
||
let tArg := tArgs[i]
|
||
let dArg := dArgs[i]
|
||
unless (← isDefEq tArg dArg) do
|
||
return i :: (← goType tArg dArg)
|
||
for i in [info.numParams : tArgs.size] do
|
||
let tArg := tArgs[i]
|
||
let dArg := dArgs[i]
|
||
unless (← isDefEq tArg dArg) do
|
||
return i :: (← goIndex tArg dArg)
|
||
failure
|
||
|
||
-- Visitor for indexed families
|
||
goIndex (t d : Expr) : OptionT MetaM (List Nat) := do
|
||
let t ← whnfD t
|
||
let d ← whnfD d
|
||
if t.isFVar || d.isFVar then
|
||
return [] -- Found refinement path
|
||
else
|
||
trace[Meta.debug] "index {t} =?= {d}"
|
||
checkCompatibleApps t d
|
||
matchConstCtor t.getAppFn (fun _ => failure) fun info _ => do
|
||
let tArgs := t.getAppArgs
|
||
let dArgs := d.getAppArgs
|
||
for i in [:info.numParams] do
|
||
let tArg := tArgs[i]
|
||
let dArg := dArgs[i]
|
||
unless (← isDefEq tArg dArg) do
|
||
failure
|
||
for i in [info.numParams : tArgs.size] do
|
||
let tArg := tArgs[i]
|
||
let dArg := dArgs[i]
|
||
unless (← isDefEq tArg dArg) do
|
||
return i :: (← goIndex tArg dArg)
|
||
failure
|
||
|
||
private partial def eraseIndices (type : Expr) : MetaM Expr := do
|
||
let type' ← whnfD type
|
||
matchConstInduct type'.getAppFn (fun _ => return type) fun info _ => do
|
||
let args := type'.getAppArgs
|
||
let params ← args[:info.numParams].toArray.mapM eraseIndices
|
||
let result := mkAppN type'.getAppFn params
|
||
let resultType ← inferType result
|
||
let (newIndices, _, _) ← forallMetaTelescopeReducing resultType (some (args.size - info.numParams))
|
||
return mkAppN result newIndices
|
||
|
||
private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : ExceptT PatternElabException TermElabM (Array Expr × Expr) :=
|
||
withReader (fun ctx => { ctx with implicitLambda := false }) do
|
||
let mut patterns := #[]
|
||
let mut matchType := matchType
|
||
for idx in [:patternStxs.size] do
|
||
let patternStx := patternStxs[idx]
|
||
matchType ← whnf matchType
|
||
match matchType with
|
||
| Expr.forallE _ d b _ =>
|
||
let pattern ← do
|
||
let s ← saveState
|
||
try
|
||
liftM <| withSynthesize <| withoutErrToSorry <| elabTermEnsuringType patternStx d
|
||
catch ex : Exception =>
|
||
restoreState s
|
||
match (← liftM <| commitIfNoErrors? <| withoutErrToSorry do elabTermAndSynthesize patternStx (← eraseIndices d)) with
|
||
| some pattern =>
|
||
match (← findDiscrRefinementPath pattern d |>.run) with
|
||
| some path =>
|
||
trace[Meta.debug] "refinement path: {path}"
|
||
restoreState s
|
||
-- Wrap the type mismatch exception for the "discriminant refinement" feature.
|
||
throwThe PatternElabException { ex := ex, patternIdx := idx, pathToIndex := path }
|
||
| none => restoreState s; throw ex
|
||
| none => throw ex
|
||
matchType := b.instantiate1 pattern
|
||
patterns := patterns.push pattern
|
||
| _ => throwError "unexpected match type"
|
||
return (patterns, matchType)
|
||
|
||
def finalizePatternDecls (patternVarDecls : Array PatternVarDecl) : TermElabM (Array LocalDecl) := do
|
||
let mut decls := #[]
|
||
for pdecl in patternVarDecls do
|
||
match pdecl with
|
||
| PatternVarDecl.localVar fvarId =>
|
||
let decl ← getLocalDecl fvarId
|
||
let decl ← instantiateLocalDeclMVars decl
|
||
decls := decls.push decl
|
||
| PatternVarDecl.anonymousVar mvarId fvarId =>
|
||
let e ← instantiateMVars (mkMVar mvarId);
|
||
trace[Elab.match] "finalizePatternDecls: mvarId: {mvarId} := {e}, fvar: {mkFVar fvarId}"
|
||
match e with
|
||
| Expr.mvar newMVarId _ =>
|
||
/- Metavariable was not assigned, or assigned to another metavariable. So,
|
||
we assign to the auxiliary free variable we created at `withPatternVars` to `newMVarId`. -/
|
||
assignExprMVar newMVarId (mkFVar fvarId)
|
||
trace[Elab.match] "finalizePatternDecls: {mkMVar newMVarId} := {mkFVar fvarId}"
|
||
let decl ← getLocalDecl fvarId
|
||
let decl ← instantiateLocalDeclMVars decl
|
||
decls := decls.push decl
|
||
| _ => pure ()
|
||
/- We perform a topological sort (dependecies) on `decls` because the pattern elaboration process may produce a sequence where a declaration d₁ may occur after d₂ when d₂ depends on d₁. -/
|
||
sortLocalDecls decls
|
||
|
||
open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS MatcherResult)
|
||
|
||
namespace ToDepElimPattern
|
||
|
||
structure State where
|
||
found : NameSet := {}
|
||
localDecls : Array LocalDecl
|
||
newLocals : NameSet := {}
|
||
|
||
abbrev M := StateRefT State TermElabM
|
||
|
||
private def alreadyVisited (fvarId : FVarId) : M Bool := do
|
||
let s ← get
|
||
return s.found.contains fvarId
|
||
|
||
private def markAsVisited (fvarId : FVarId) : M Unit :=
|
||
modify fun s => { s with found := s.found.insert fvarId }
|
||
|
||
private def throwInvalidPattern {α} (e : Expr) : M α :=
|
||
throwError "invalid pattern {indentExpr e}"
|
||
|
||
/- Create a new LocalDecl `x` for the metavariable `mvar`, and return `Pattern.var x` -/
|
||
private def mkLocalDeclFor (mvar : Expr) : M Pattern := do
|
||
let mvarId := mvar.mvarId!
|
||
let s ← get
|
||
match (← getExprMVarAssignment? mvarId) with
|
||
| some val => return Pattern.inaccessible val
|
||
| none =>
|
||
let fvarId ← mkFreshId
|
||
let type ← inferType mvar
|
||
/- HACK: `fvarId` is not in the scope of `mvarId`
|
||
If this generates problems in the future, we should update the metavariable declarations. -/
|
||
assignExprMVar mvarId (mkFVar fvarId)
|
||
let userName ← mkFreshBinderName
|
||
let newDecl := LocalDecl.cdecl arbitrary fvarId userName type BinderInfo.default;
|
||
modify fun s =>
|
||
{ s with
|
||
newLocals := s.newLocals.insert fvarId,
|
||
localDecls :=
|
||
match s.localDecls.findIdx? fun decl => mvar.occurs decl.type with
|
||
| none => s.localDecls.push newDecl -- None of the existing declarations depend on `mvar`
|
||
| some i => s.localDecls.insertAt i newDecl }
|
||
return Pattern.var fvarId
|
||
|
||
partial def main (e : Expr) : M Pattern := do
|
||
let isLocalDecl (fvarId : FVarId) : M Bool := do
|
||
return (← get).localDecls.any fun d => d.fvarId == fvarId
|
||
let mkPatternVar (fvarId : FVarId) (e : Expr) : M Pattern := do
|
||
if (← alreadyVisited fvarId) then
|
||
return Pattern.inaccessible e
|
||
else
|
||
markAsVisited fvarId
|
||
return Pattern.var e.fvarId!
|
||
let mkInaccessible (e : Expr) : M Pattern := do
|
||
match e with
|
||
| Expr.fvar fvarId _ =>
|
||
if (← isLocalDecl fvarId) then
|
||
mkPatternVar fvarId e
|
||
else
|
||
return Pattern.inaccessible e
|
||
| _ =>
|
||
return Pattern.inaccessible e
|
||
match inaccessible? e with
|
||
| some t => mkInaccessible t
|
||
| none =>
|
||
match e.arrayLit? with
|
||
| some (α, lits) =>
|
||
return Pattern.arrayLit α (← lits.mapM main)
|
||
| none =>
|
||
if e.isAppOfArity `namedPattern 3 then
|
||
let p ← main <| e.getArg! 2
|
||
match e.getArg! 1 with
|
||
| Expr.fvar fvarId _ => return Pattern.as fvarId p
|
||
| _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
|
||
else if e.isNatLit || e.isStringLit || e.isCharLit then
|
||
return Pattern.val e
|
||
else if e.isFVar then
|
||
let fvarId := e.fvarId!
|
||
unless (← isLocalDecl fvarId) do
|
||
throwInvalidPattern e
|
||
mkPatternVar fvarId e
|
||
else if e.isMVar then
|
||
mkLocalDeclFor e
|
||
else
|
||
let newE ← whnf e
|
||
if newE != e then
|
||
main newE
|
||
else matchConstCtor e.getAppFn (fun _ => throwInvalidPattern e) fun v us => do
|
||
let args := e.getAppArgs
|
||
unless args.size == v.numParams + v.numFields do
|
||
throwInvalidPattern e
|
||
let params := args.extract 0 v.numParams
|
||
let fields := args.extract v.numParams args.size
|
||
let fields ← fields.mapM main
|
||
return Pattern.ctor v.name us params.toList fields.toList
|
||
|
||
end ToDepElimPattern
|
||
|
||
def withDepElimPatterns {α} (localDecls : Array LocalDecl) (ps : Array Expr) (k : Array LocalDecl → Array Pattern → TermElabM α) : TermElabM α := do
|
||
let (patterns, s) ← (ps.mapM ToDepElimPattern.main).run { localDecls := localDecls }
|
||
let localDecls ← s.localDecls.mapM fun d => instantiateLocalDeclMVars d
|
||
/- toDepElimPatterns may have added new localDecls. Thus, we must update the local context before we execute `k` -/
|
||
let lctx ← getLCtx
|
||
let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.erase d.fvarId) lctx
|
||
let lctx := localDecls.foldl (fun (lctx : LocalContext) d => lctx.addDecl d) lctx
|
||
withTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) do
|
||
k localDecls patterns
|
||
|
||
private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr)
|
||
(k : AltLHS → Expr → TermElabM α) : ExceptT PatternElabException TermElabM α := do
|
||
let (patterns, matchType) ← withSynthesize <| elabPatterns patternStxs matchType
|
||
id (α := TermElabM α) do
|
||
let localDecls ← finalizePatternDecls patternVarDecls
|
||
let patterns ← patterns.mapM (instantiateMVars ·)
|
||
withDepElimPatterns localDecls patterns fun localDecls patterns =>
|
||
k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
|
||
|
||
private def elabMatchAltView (alt : MatchAltView) (matchType : Expr) : ExceptT PatternElabException TermElabM (AltLHS × Expr) := withRef alt.ref do
|
||
let (patternVars, alt) ← collectPatternVars alt
|
||
trace[Elab.match] "patternVars: {patternVars}"
|
||
withPatternVars patternVars fun patternVarDecls => do
|
||
withElaboratedLHS alt.ref patternVarDecls alt.patterns matchType fun altLHS matchType => do
|
||
let rhs ← elabTermEnsuringType alt.rhs matchType
|
||
let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr
|
||
let rhs ← if xs.isEmpty then pure <| mkSimpleThunk rhs else mkLambdaFVars xs rhs
|
||
trace[Elab.match] "rhs: {rhs}"
|
||
return (altLHS, rhs)
|
||
|
||
/--
|
||
Collect problematic index for the "discriminant refinement feature". This method is invoked
|
||
when we detect a type mismatch at a pattern #`idx` of some alternative. -/
|
||
private partial def getIndexToInclude? (discr : Expr) (pathToIndex : List Nat) : TermElabM (Option Expr) := do
|
||
go (← inferType discr) pathToIndex |>.run
|
||
where
|
||
go (e : Expr) (path : List Nat) : OptionT MetaM Expr := do
|
||
match path with
|
||
| [] => return e
|
||
| i::path =>
|
||
let e ← whnfD e
|
||
guard <| e.isApp && i < e.getAppNumArgs
|
||
go (e.getArg! i) path
|
||
|
||
/--
|
||
"Generalize" variables that depend on the discriminants.
|
||
|
||
Remarks and limitations:
|
||
- If `matchType` is a proposition, then we generalize even when the user did not provide `(generalizing := true)`.
|
||
Motivation: users should have control about the actual `match`-expressions in their programs.
|
||
- We currently do not generalize let-decls.
|
||
- We abort generalization if the new `matchType` is type incorrect.
|
||
- Only discriminants that are free variables are considered during specialization.
|
||
- We "generalize" by adding new discriminants and pattern variables. We do not "clear" the generalized variables,
|
||
but they become inaccessible since they are shadowed by the patterns variables. We assume this is ok since
|
||
this is the exact behavior users would get if they had written it by hand. Recall there is no `clear` in term mode.
|
||
-/
|
||
private def generalize (discrs : Array Expr) (matchType : Expr) (altViews : Array MatchAltView) (generalizing? : Option Bool) : TermElabM (Array Expr × Expr × Array MatchAltView × Bool) := do
|
||
let gen ←
|
||
match generalizing? with
|
||
| some g => pure g
|
||
| _ => isProp matchType
|
||
if !gen then
|
||
return (discrs, matchType, altViews, false)
|
||
else
|
||
let ysFVarIds ← getFVarsToGeneralize discrs
|
||
/- let-decls are currently being ignored by the generalizer. -/
|
||
let ysFVarIds ← ysFVarIds.filterM fun fvarId => return !(← getLocalDecl fvarId).isLet
|
||
if ysFVarIds.isEmpty then
|
||
return (discrs, matchType, altViews, false)
|
||
else
|
||
let ys := ysFVarIds.map mkFVar
|
||
-- trace[Meta.debug] "ys: {ys}, discrs: {discrs}"
|
||
let matchType' ← forallBoundedTelescope matchType discrs.size fun ds type => do
|
||
let type ← mkForallFVars ys type
|
||
let (discrs', ds') := Array.unzip <| Array.zip discrs ds |>.filter fun (di, d) => di.isFVar
|
||
let type := type.replaceFVars discrs' ds'
|
||
mkForallFVars ds type
|
||
-- trace[Meta.debug] "matchType': {matchType'}"
|
||
if (← isTypeCorrect matchType') then
|
||
let discrs := discrs ++ ys
|
||
let altViews ← altViews.mapM fun altView => do
|
||
let patternVars ← getPatternsVars altView.patterns
|
||
-- We traverse backwards because we want to keep the most recent names.
|
||
-- For example, if `ys` contains `#[h, h]`, we want to make sure `mkFreshUsername is applied to the first `h`,
|
||
-- since it is already shadowed by the second.
|
||
let ysUserNames ← ys.foldrM (init := #[]) fun ys ysUserNames => do
|
||
let yDecl ← getLocalDecl ys.fvarId!
|
||
let mut yUserName := yDecl.userName
|
||
if ysUserNames.contains yUserName then
|
||
yUserName ← mkFreshUserName yUserName
|
||
-- Explicitly provided pattern variables shadow `y`
|
||
else if patternVars.any fun | PatternVar.localVar x => x == yUserName | _ => false then
|
||
yUserName ← mkFreshUserName yUserName
|
||
return ysUserNames.push yUserName
|
||
let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n
|
||
return { altView with patterns := altView.patterns ++ ysIds }
|
||
return (discrs, matchType', altViews, true)
|
||
else
|
||
return (discrs, matchType, altViews, true)
|
||
|
||
private partial def elabMatchAltViews (generalizing? : Option Bool) (discrs : Array Expr) (matchType : Expr) (altViews : Array MatchAltView) : TermElabM (Array Expr × Expr × Array (AltLHS × Expr) × Bool) := do
|
||
loop discrs matchType altViews none
|
||
where
|
||
/-
|
||
"Discriminant refinement" main loop.
|
||
`first?` contains the first error message we found before updated the `discrs`. -/
|
||
loop (discrs : Array Expr) (matchType : Expr) (altViews : Array MatchAltView) (first? : Option (SavedState × Exception))
|
||
: TermElabM (Array Expr × Expr × Array (AltLHS × Expr) × Bool) := do
|
||
let s ← saveState
|
||
let (discrs', matchType', altViews', refined) ← generalize discrs matchType altViews generalizing?
|
||
match (← altViews'.mapM (fun altView => elabMatchAltView altView matchType') |>.run) with
|
||
| Except.ok alts => return (discrs', matchType', alts, first?.isSome || refined)
|
||
| Except.error { patternIdx := patternIdx, pathToIndex := pathToIndex, ex := ex } =>
|
||
trace[Meta.debug] "pathToIndex: {toString pathToIndex}"
|
||
let some index ← getIndexToInclude? discrs[patternIdx] pathToIndex
|
||
| throwEx (← updateFirst first? ex)
|
||
trace[Meta.debug] "index: {index}"
|
||
if (← discrs.anyM fun discr => isDefEq discr index) then
|
||
throwEx (← updateFirst first? ex)
|
||
let first ← updateFirst first? ex
|
||
s.restore
|
||
let indices ← collectDeps #[index] discrs
|
||
let matchType ←
|
||
try
|
||
updateMatchType indices matchType
|
||
catch ex =>
|
||
throwEx first
|
||
let altViews ← addWildcardPatterns indices.size altViews
|
||
let discrs := indices ++ discrs
|
||
loop discrs matchType altViews first
|
||
|
||
throwEx {α} (p : SavedState × Exception) : TermElabM α := do
|
||
p.1.restore; throw p.2
|
||
|
||
updateFirst (first? : Option (SavedState × Exception)) (ex : Exception) : TermElabM (SavedState × Exception) := do
|
||
match first? with
|
||
| none => return (← saveState, ex)
|
||
| some first => return first
|
||
|
||
containsFVar (es : Array Expr) (fvarId : FVarId) : Bool :=
|
||
es.any fun e => e.isFVar && e.fvarId! == fvarId
|
||
|
||
/- Update `indices` by including any free variable `x` s.t.
|
||
- Type of some `discr` depends on `x`.
|
||
- Type of `x` depends on some free variable in `indices`.
|
||
|
||
If we don't include these extra variables in indices, then
|
||
`updateMatchType` will generate a type incorrect term.
|
||
For example, suppose `discr` contains `h : @HEq α a α b`, and
|
||
`indices` is `#[α, b]`, and `matchType` is `@HEq α a α b → B`.
|
||
`updateMatchType indices matchType` produces the type
|
||
`(α' : Type) → (b : α') → @HEq α' a α' b → B` which is type incorrect
|
||
because we have `a : α`.
|
||
The method `collectDeps` will include `a` into `indices`.
|
||
|
||
This method does not handle dependencies among non-free variables.
|
||
We rely on the type checking method `check` at `updateMatchType`.
|
||
|
||
Remark: `indices : Array Expr` does not need to be an array anymore.
|
||
We should cleanup this code, and use `index : Expr` instead.
|
||
-/
|
||
collectDeps (indices : Array Expr) (discrs : Array Expr) : TermElabM (Array Expr) := do
|
||
let mut s : CollectFVars.State := {}
|
||
for discr in discrs do
|
||
s := collectFVars s (← instantiateMVars (← inferType discr))
|
||
let (indicesFVar, indicesNonFVar) := indices.split Expr.isFVar
|
||
let indicesFVar := indicesFVar.map Expr.fvarId!
|
||
let mut toAdd := #[]
|
||
for fvarId in s.fvarSet.toList do
|
||
unless containsFVar discrs fvarId || containsFVar indices fvarId do
|
||
let localDecl ← getLocalDecl fvarId
|
||
let mctx ← getMCtx
|
||
for indexFVarId in indicesFVar do
|
||
if mctx.localDeclDependsOn localDecl indexFVarId then
|
||
toAdd := toAdd.push fvarId
|
||
let lctx ← getLCtx
|
||
let indicesFVar := (indicesFVar ++ toAdd).qsort fun fvarId₁ fvarId₂ =>
|
||
(lctx.get! fvarId₁).index < (lctx.get! fvarId₂).index
|
||
return indicesFVar.map mkFVar ++ indicesNonFVar
|
||
|
||
updateMatchType (indices : Array Expr) (matchType : Expr) : TermElabM Expr := do
|
||
let matchType ← indices.foldrM (init := matchType) fun index matchType => do
|
||
let indexType ← inferType index
|
||
let matchTypeBody ← kabstract matchType index
|
||
let userName ← mkUserNameFor index
|
||
return Lean.mkForall userName BinderInfo.default indexType matchTypeBody
|
||
check matchType
|
||
return matchType
|
||
|
||
addWildcardPatterns (num : Nat) (altViews : Array MatchAltView) : TermElabM (Array MatchAltView) := do
|
||
let hole := mkHole (← getRef)
|
||
let wildcards := mkArray num hole
|
||
return altViews.map fun altView => { altView with patterns := wildcards ++ altView.patterns }
|
||
|
||
def mkMatcher (input : Meta.Match.MkMatcherInput) : TermElabM MatcherResult :=
|
||
Meta.Match.mkMatcher input
|
||
|
||
register_builtin_option match.ignoreUnusedAlts : Bool := {
|
||
defValue := false
|
||
descr := "if true, do not generate error if an alternative is not used"
|
||
}
|
||
|
||
def reportMatcherResultErrors (altLHSS : List AltLHS) (result : MatcherResult) : TermElabM Unit := do
|
||
unless result.counterExamples.isEmpty do
|
||
withHeadRefOnly <| logError m!"missing cases:\n{Meta.Match.counterExamplesToMessageData result.counterExamples}"
|
||
unless match.ignoreUnusedAlts.get (← getOptions) || result.unusedAltIdxs.isEmpty do
|
||
let mut i := 0
|
||
for alt in altLHSS do
|
||
if result.unusedAltIdxs.contains i then
|
||
withRef alt.ref do
|
||
logError "redundant alternative"
|
||
i := i + 1
|
||
|
||
/--
|
||
If `altLHSS + rhss` is encoding `| PUnit.unit => rhs[0]`, return `rhs[0]`
|
||
Otherwise, return none.
|
||
-/
|
||
private def isMatchUnit? (altLHSS : List Match.AltLHS) (rhss : Array Expr) : MetaM (Option Expr) := do
|
||
assert! altLHSS.length == rhss.size
|
||
match altLHSS with
|
||
| [ { fvarDecls := [], patterns := [ Pattern.ctor `PUnit.unit .. ], .. } ] =>
|
||
/- Recall that for alternatives of the form `| PUnit.unit => rhs`, `rhss[0]` is of the form `fun _ : Unit => b`. -/
|
||
match rhss[0] with
|
||
| Expr.lam _ _ b _ => return if b.hasLooseBVars then none else b
|
||
| _ => return none
|
||
| _ => return none
|
||
private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptType : Syntax) (expectedType : Expr)
|
||
: TermElabM Expr := do
|
||
let mut generalizing? := generalizing?
|
||
if !matchOptType.isNone then
|
||
if generalizing? == some true then
|
||
throwError "the '(generalizing := true)' parameter is not supported when the 'match' type is explicitly provided"
|
||
generalizing? := some false
|
||
let (discrs, matchType, altLHSS, isDep, rhss) ← commitIfDidNotPostpone do
|
||
let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptType altViews expectedType
|
||
let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews
|
||
trace[Elab.match] "matchType: {matchType}"
|
||
let (discrs, matchType, alts, refined) ← elabMatchAltViews generalizing? discrs matchType matchAlts
|
||
let isDep := isDep || refined
|
||
/-
|
||
We should not use `synthesizeSyntheticMVarsNoPostponing` here. Otherwise, we will not be
|
||
able to elaborate examples such as:
|
||
```
|
||
def f (x : Nat) : Option Nat := none
|
||
|
||
def g (xs : List (Nat × Nat)) : IO Unit :=
|
||
xs.forM fun x =>
|
||
match f x.fst with
|
||
| _ => pure ()
|
||
```
|
||
If `synthesizeSyntheticMVarsNoPostponing`, the example above fails at `x.fst` because
|
||
the type of `x` is only available after we proces the last argument of `List.forM`.
|
||
|
||
We apply pending default types to make sure we can process examples such as
|
||
```
|
||
let (a, b) := (0, 0)
|
||
```
|
||
-/
|
||
synthesizeSyntheticMVarsUsingDefault
|
||
let rhss := alts.map Prod.snd
|
||
let matchType ← instantiateMVars matchType
|
||
let altLHSS ← alts.toList.mapM fun alt => do
|
||
let altLHS ← Match.instantiateAltLHSMVars alt.1
|
||
/- Remark: we try to postpone before throwing an error.
|
||
The combinator `commitIfDidNotPostpone` ensures we backtrack any updates that have been performed.
|
||
The quick-check `waitExpectedTypeAndDiscrs` minimizes the number of scenarios where we have to postpone here.
|
||
Here is an example that passes the `waitExpectedTypeAndDiscrs` test, but postpones here.
|
||
```
|
||
def bad (ps : Array (Nat × Nat)) : Array (Nat × Nat) :=
|
||
(ps.filter fun (p : Prod _ _) =>
|
||
match p with
|
||
| (x, y) => x == 0)
|
||
++
|
||
ps
|
||
```
|
||
When we try to elaborate `fun (p : Prod _ _) => ...` for the first time, we haven't propagated the type of `ps` yet
|
||
because `Array.filter` has type `{α : Type u_1} → (α → Bool) → (as : Array α) → optParam Nat 0 → optParam Nat (Array.size as) → Array α`
|
||
However, the partial type annotation `(p : Prod _ _)` makes sure we succeed at the quick-check `waitExpectedTypeAndDiscrs`.
|
||
-/
|
||
withRef altLHS.ref do
|
||
for d in altLHS.fvarDecls do
|
||
if d.hasExprMVar then
|
||
withExistingLocalDecls altLHS.fvarDecls do
|
||
tryPostpone
|
||
throwMVarError m!"invalid match-expression, type of pattern variable '{d.toExpr}' contains metavariables{indentExpr d.type}"
|
||
for p in altLHS.patterns do
|
||
if p.hasExprMVar then
|
||
withExistingLocalDecls altLHS.fvarDecls do
|
||
tryPostpone
|
||
throwMVarError m!"invalid match-expression, pattern contains metavariables{indentExpr (← p.toExpr)}"
|
||
pure altLHS
|
||
return (discrs, matchType, altLHSS, isDep, rhss)
|
||
if let some r ← if isDep then pure none else isMatchUnit? altLHSS rhss then
|
||
return r
|
||
else
|
||
let numDiscrs := discrs.size
|
||
let matcherName ← mkAuxName `match
|
||
let matcherResult ← mkMatcher { matcherName, matchType, numDiscrs, lhss := altLHSS }
|
||
matcherResult.addMatcher
|
||
let motive ← forallBoundedTelescope matchType numDiscrs fun xs matchType => mkLambdaFVars xs matchType
|
||
reportMatcherResultErrors altLHSS matcherResult
|
||
let r := mkApp matcherResult.matcher motive
|
||
let r := mkAppN r discrs
|
||
let r := mkAppN r rhss
|
||
trace[Elab.match] "result: {r}"
|
||
return r
|
||
|
||
private def getDiscrs (matchStx : Syntax) : Array Syntax :=
|
||
matchStx[2].getSepArgs
|
||
|
||
private def getMatchOptType (matchStx : Syntax) : Syntax :=
|
||
matchStx[3]
|
||
|
||
private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Syntax) :=
|
||
let matchOptType := getMatchOptType matchStx;
|
||
if matchOptType.isNone then do
|
||
let discrs := getDiscrs matchStx;
|
||
let allLocal ← discrs.allM fun discr => Option.isSome <$> isAtomicDiscr? discr[1]
|
||
if allLocal then
|
||
return none
|
||
else
|
||
-- We use `foundFVars` to make sure the discriminants are distinct variables.
|
||
-- See: code for computing "matchType" at `elabMatchTypeAndDiscrs`
|
||
let rec loop (discrs : List Syntax) (discrsNew : Array Syntax) (foundFVars : NameSet) := do
|
||
match discrs with
|
||
| [] =>
|
||
let discrs := Syntax.mkSep discrsNew (mkAtomFrom matchStx ", ");
|
||
pure (matchStx.setArg 2 discrs)
|
||
| discr :: discrs =>
|
||
-- Recall that
|
||
-- matchDiscr := leading_parser optional (ident >> ":") >> termParser
|
||
let term := discr[1]
|
||
let addAux : TermElabM Syntax := withFreshMacroScope do
|
||
let d ← `(_discr);
|
||
unless isAuxDiscrName d.getId do -- Use assertion?
|
||
throwError "unexpected internal auxiliary discriminant name"
|
||
let discrNew := discr.setArg 1 d;
|
||
let r ← loop discrs (discrsNew.push discrNew) foundFVars
|
||
`(let _discr := $term; $r)
|
||
match (← isAtomicDiscr? term) with
|
||
| some x => if x.isFVar then loop discrs (discrsNew.push discr) (foundFVars.insert x.fvarId!) else addAux
|
||
| none => addAux
|
||
return some (← loop discrs.toList #[] {})
|
||
else
|
||
-- We do not pull non atomic discriminants when match type is provided explicitly by the user
|
||
return none
|
||
|
||
private def waitExpectedType (expectedType? : Option Expr) : TermElabM Expr := do
|
||
tryPostponeIfNoneOrMVar expectedType?
|
||
match expectedType? with
|
||
| some expectedType => pure expectedType
|
||
| none => mkFreshTypeMVar
|
||
|
||
private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit := do
|
||
-- We don't wait for the discriminants types when match type is provided by user
|
||
if getMatchOptType matchStx |>.isNone then
|
||
let discrs := getDiscrs matchStx
|
||
for discr in discrs do
|
||
let term := discr[1]
|
||
match (← isAtomicDiscr? term) with
|
||
| none => throwErrorAt discr "unexpected discriminant" -- see `expandNonAtomicDiscrs?
|
||
| some d =>
|
||
let dType ← inferType d
|
||
trace[Elab.match] "discr {d} : {dType}"
|
||
tryPostponeIfMVar dType
|
||
|
||
/-
|
||
We (try to) elaborate a `match` only when the expected type is available.
|
||
If the `matchType` has not been provided by the user, we also try to postpone elaboration if the type
|
||
of a discriminant is not available. That is, it is of the form `(?m ...)`.
|
||
We use `expandNonAtomicDiscrs?` to make sure all discriminants are local variables.
|
||
This is a standard trick we use in the elaborator, and it is also used to elaborate structure instances.
|
||
Suppose, we are trying to elaborate
|
||
```
|
||
match g x with
|
||
| ... => ...
|
||
```
|
||
`expandNonAtomicDiscrs?` converts it intro
|
||
```
|
||
let _discr := g x
|
||
match _discr with
|
||
| ... => ...
|
||
```
|
||
Thus, at `tryPostponeIfDiscrTypeIsMVar` we only need to check whether the type of `_discr` is not of the form `(?m ...)`.
|
||
Note that, the auxiliary variable `_discr` is expanded at `elabAtomicDiscr`.
|
||
|
||
This elaboration technique is needed to elaborate terms such as:
|
||
```lean
|
||
xs.filter fun (a, b) => a > b
|
||
```
|
||
which are syntax sugar for
|
||
```lean
|
||
List.filter (fun p => match p with | (a, b) => a > b) xs
|
||
```
|
||
When we visit `match p with | (a, b) => a > b`, we don't know the type of `p` yet.
|
||
-/
|
||
private def waitExpectedTypeAndDiscrs (matchStx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
||
tryPostponeIfNoneOrMVar expectedType?
|
||
tryPostponeIfDiscrTypeIsMVar matchStx
|
||
match expectedType? with
|
||
| some expectedType => return expectedType
|
||
| none => mkFreshTypeMVar
|
||
|
||
/-
|
||
```
|
||
leading_parser:leadPrec "match " >> sepBy1 matchDiscr ", " >> optType >> " with " >> matchAlts
|
||
```
|
||
Remark the `optIdent` must be `none` at `matchDiscr`. They are expanded by `expandMatchDiscr?`.
|
||
-/
|
||
private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
||
let expectedType ← waitExpectedTypeAndDiscrs stx expectedType?
|
||
let discrStxs := (getDiscrs stx).map fun d => d
|
||
let gen? := getMatchGeneralizing? stx
|
||
let altViews := getMatchAlts stx
|
||
let matchOptType := getMatchOptType stx
|
||
elabMatchAux gen? discrStxs altViews matchOptType expectedType
|
||
|
||
private def isPatternVar (stx : Syntax) : TermElabM Bool := do
|
||
match (← resolveId? stx "pattern") with
|
||
| none => isAtomicIdent stx
|
||
| some f => match f with
|
||
| Expr.const fName _ _ =>
|
||
match (← getEnv).find? fName with
|
||
| some (ConstantInfo.ctorInfo _) => return false
|
||
| some _ => return !hasMatchPatternAttribute (← getEnv) fName
|
||
| _ => isAtomicIdent stx
|
||
| _ => isAtomicIdent stx
|
||
where
|
||
isAtomicIdent (stx : Syntax) : Bool :=
|
||
stx.isIdent && stx.getId.eraseMacroScopes.isAtomic
|
||
|
||
-- leading_parser "match " >> sepBy1 termParser ", " >> optType >> " with " >> matchAlts
|
||
@[builtinTermElab «match»] def elabMatch : TermElab := fun stx expectedType? => do
|
||
match stx with
|
||
| `(match $discr:term with | $y:ident => $rhs:term) =>
|
||
if (← isPatternVar y) then expandSimpleMatch stx discr y rhs expectedType? else elabMatchDefault stx expectedType?
|
||
| _ => elabMatchDefault stx expectedType?
|
||
where
|
||
elabMatchDefault (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
|
||
match (← expandNonAtomicDiscrs? stx) with
|
||
| some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
|
||
| none =>
|
||
let discrs := getDiscrs stx;
|
||
let matchOptType := getMatchOptType stx;
|
||
if !matchOptType.isNone && discrs.any fun d => !d[0].isNone then
|
||
throwErrorAt matchOptType "match expected type should not be provided when discriminants with equality proofs are used"
|
||
elabMatchCore stx expectedType?
|
||
|
||
builtin_initialize
|
||
registerTraceClass `Elab.match
|
||
|
||
-- leading_parser:leadPrec "nomatch " >> termParser
|
||
@[builtinTermElab «nomatch»] def elabNoMatch : TermElab := fun stx expectedType? => do
|
||
match stx with
|
||
| `(nomatch $discrExpr) =>
|
||
match (← isLocalIdent? discrExpr) with
|
||
| some _ =>
|
||
let expectedType ← waitExpectedType expectedType?
|
||
let discr := Syntax.node ``Lean.Parser.Term.matchDiscr #[mkNullNode, discrExpr]
|
||
elabMatchAux none #[discr] #[] mkNullNode expectedType
|
||
| _ =>
|
||
let stxNew ← `(let _discr := $discrExpr; nomatch _discr)
|
||
withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
|
||
| _ => throwUnsupportedSyntax
|
||
|
||
end Lean.Elab.Term
|