lean4-htt/src/Lean/Elab/Match.lean

1291 lines
58 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Util.CollectFVars
import Lean.Parser.Term
import Lean.Meta.Match.MatchPatternAttr
import Lean.Meta.Match.Match
import Lean.Meta.GeneralizeVars
import Lean.Meta.ForEachExpr
import Lean.Elab.SyntheticMVars
import Lean.Elab.Arg
import Lean.Elab.PatternVar
import Lean.Elab.AuxDiscr
import Lean.Elab.BindersUtil
namespace Lean.Elab.Term
open Meta
open Lean.Parser.Term
/--
Elaborate `stx` by rewriting it into `let $lhsVar := $discr; $rhs` and elaborating the result
against `expectedType?`. The rewrite is registered with `withMacroExpansion`, so it is attributed to `stx`.
-/
private def expandSimpleMatch (stx discr lhsVar rhs : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
  let letStx ← `(let $lhsVar := $discr; $rhs)
  withMacroExpansion stx letStx (elabTerm letStx expectedType?)
/--
Pick a user-facing binder name for the discriminant `e`: a fresh name derived from the user name of `e`
when `e` is a free variable, and a fresh binder name otherwise.
-/
private def mkUserNameFor (e : Expr) : TermElabM Name := do
  if let Expr.fvar fvarId _ := e then
    -- `mkFreshUserName` makes sure the new local cannot be resolved back to `e` by name.
    let localDecl ← getLocalDecl fvarId
    mkFreshUserName localDecl.userName
  else
    mkFreshBinderName
/--
Return the result of `isLocalIdent? x` when the discriminant `discr` has the form `x` or `@x`,
and `none` for any other (non-missing) discriminant.

We treat `@x` as atomic to avoid inserting unnecessary extra local declarations into the local context.
Recall that `expandMatchAltsIntoMatch` uses the `@` modifier, so this kind of discriminant is quite common.

Remark: if the discriminant is `Syntax.missing`, we abort the elaboration of the `match`-expression.
This can happen due to error recovery. Example
```
example : (p p) → p := fun h => match
```
If we don't abort, the elaborator loops because we will keep trying to expand
```
match
```
into
```
let d := <Syntax.missing>; match
```
Recall that `Syntax.setArg stx i arg` is a no-op when `i` is out-of-bounds. -/
def isAtomicDiscr? (discr : Syntax) : TermElabM (Option Expr) := do
  match discr with
  | `($x:ident)  => isLocalIdent? x
  | `(@$x:ident) => isLocalIdent? x
  | _ =>
    -- A missing discriminant comes from error recovery; abort instead of looping (see above).
    if discr.isMissing then throwAbortTerm
    return none
/--
Elaborate an atomic discriminant (see `isAtomicDiscr?` and `expandNonAtomicDiscrs?`).
`discr[1]` is the discriminant term. An auxiliary local created by `expandNonAtomicDiscrs?` is replaced
by its (instantiated) value. Any other local is returned after recording term info for `discr`.
Throws "unexpected discriminant" if the term is not a local variable.
-/
private def elabAtomicDiscr (discr : Syntax) : TermElabM Expr := do
  let some e@(Expr.fvar fvarId _) ← isAtomicDiscr? discr[1]
    | throwErrorAt discr "unexpected discriminant"
  let localDecl ← getLocalDecl fvarId
  if isAuxDiscrName localDecl.userName then
    -- auxiliary local introduced by `expandNonAtomicDiscrs?`: use the value it abbreviates
    instantiateMVars localDecl.value
  else
    addTermInfo discr e
/-- An elaborated `match` discriminant. -/
structure Discr where
  /-- The elaborated discriminant term. -/
  expr : Expr
  /-- `some h` if discriminant is annotated with the `h : ` notation. -/
  h? : Option Syntax := none
  deriving Inhabited
/-- Result of `elabMatchTypeAndDiscrs`. -/
structure ElabMatchTypeAndDiscrsResult where
  /-- The elaborated discriminants, in order. -/
  discrs    : Array Discr
  /--
  The type of the `match`: a function type with one argument per discriminant. It is either the
  user-provided motive or obtained by abstracting the discriminants in the expected type. -/
  matchType : Expr
  /-- `true` when performing dependent elimination. We use this to decide whether we optimize the "match unit" case.
     See `isMatchUnit?`. -/
  isDep     : Bool
  /-- The `match` alternatives. -/
  alts      : Array MatchAltView
/--
Elaborate the discriminants `discrStxs` and compute the type of the `match`-expression.

- If no motive is provided (`matchOptMotive` is empty), the discriminants are elaborated first.
  The match type is then obtained by abstracting them (with `kabstract`) in `expectedType`,
  from the last discriminant to the first.
- Otherwise, the user-provided motive is elaborated, and each discriminant is elaborated against the
  corresponding domain of the motive.

In both cases the `h :` annotation of each discriminant (if any) is recorded in `Discr.h?`.
-/
private partial def elabMatchTypeAndDiscrs (discrStxs : Array Syntax) (matchOptMotive : Syntax) (matchAltViews : Array MatchAltView) (expectedType : Expr)
    : TermElabM ElabMatchTypeAndDiscrsResult := do
  if matchOptMotive.isNone then
    elabDiscrs 0 #[]
  else
    -- motive := leading_parser atomic ("(" >> nonReservedSymbol "motive" >> " := ") >> termParser >> ")"
    let matchTypeStx := matchOptMotive[0][3]
    let matchType ← elabType matchTypeStx
    let (discrs, isDep) ← elabDiscrsWithMatchType matchType
    return { discrs := discrs, matchType := matchType, isDep := isDep, alts := matchAltViews }
where
  /- Return `some h` if the discriminant `discrStx` is annotated with the `h : ` notation. -/
  getDiscrAnnotation? (discrStx : Syntax) : Option Syntax :=
    if discrStx[0].isNone then none else some discrStx[0][0]

  /- Easy case: elaborate discriminant when the match-type has been explicitly provided by the user.
     Each discriminant is elaborated against the corresponding domain of `matchType`, which is then
     instantiated with it. The returned flag is `true` if some codomain depends on its discriminant. -/
  elabDiscrsWithMatchType (matchType : Expr) : TermElabM (Array Discr × Bool) := do
    let mut discrs := #[]
    let mut i := 0
    let mut matchType := matchType
    let mut isDep := false
    for discrStx in discrStxs do
      i := i + 1
      matchType ← whnf matchType
      match matchType with
      | Expr.forallE _ d b _ =>
        let discr ← fullApproxDefEq <| elabTermEnsuringType discrStx[1] d
        trace[Elab.match] "discr #{i} {discr} : {d}"
        if b.hasLooseBVars then
          isDep := true
        matchType := b.instantiate1 discr
        -- Record the `h :` annotation so that `withEqs` also introduces the equation when a motive is provided.
        discrs := discrs.push { expr := discr, h? := getDiscrAnnotation? discrStx }
      | _ =>
        throwError "invalid motive provided to match-expression, function type with arity #{discrStxs.size} expected"
    return (discrs, isDep)

  markIsDep (r : ElabMatchTypeAndDiscrsResult) :=
    { r with isDep := true }

  /- Elaborate discriminants inferring the match-type: the discriminants at positions `i`, `i+1`, ... are
     elaborated, and then abstracted (from right to left) in the match type computed for the remaining ones. -/
  elabDiscrs (i : Nat) (discrs : Array Discr) : TermElabM ElabMatchTypeAndDiscrsResult := do
    if h : i < discrStxs.size then
      let discrStx := discrStxs.get ⟨i, h⟩
      let discr ← elabAtomicDiscr discrStx
      let discr ← instantiateMVars discr
      let discrType ← inferType discr
      let discrType ← instantiateMVars discrType
      let userName ← mkUserNameFor discr
      let discrs := discrs.push { expr := discr, h? := getDiscrAnnotation? discrStx }
      let mut result ← elabDiscrs (i + 1) discrs
      let matchTypeBody ← kabstract result.matchType discr
      if matchTypeBody.hasLooseBVars then
        result := markIsDep result
      return { result with matchType := Lean.mkForall userName BinderInfo.default discrType matchTypeBody }
    else
      return { discrs, alts := matchAltViews, isDep := false, matchType := expectedType }
/-- Macro-expand every pattern of every alternative in `matchAlts`. The other fields of each alternative are kept. -/
def expandMacrosInPatterns (matchAlts : Array MatchAltView) : MacroM (Array MatchAltView) :=
  matchAlts.mapM fun alt => do
    let expanded ← alt.patterns.mapM expandMacros
    return { alt with patterns := expanded }
/--
Return the value of the `(generalizing := true/false)` option of the `match`-expression `stx`.
Return `none` if the option is absent or `stx` does not have the expected shape.
-/
private def getMatchGeneralizing? : Syntax → Option Bool
  | `(match (generalizing := true) $[$motive]? $discrs,* with $alts:matchAlt*) => some true
  | `(match (generalizing := false) $[$motive]? $discrs,* with $alts:matchAlt*) => some false
  | _ => none
/--
Given `stx` a match-expression, return its alternatives.
Only alternatives of the form `| ps,* => rhs` are kept; any other alternative is skipped by the `filterMap`.
If `stx` does not have the expected shape, `#[]` is returned. -/
private def getMatchAlts : Syntax → Array MatchAltView
  | `(match $[$gen]? $[$motive]? $discrs,* with $alts:matchAlt*) =>
    alts.filterMap fun alt => match alt with
      | `(matchAltExpr| | $patterns,* => $rhs) => some {
          ref := alt,
          patterns := patterns,
          rhs := rhs
        }
      | _ => none
  | _ => #[]
/-- Elaborate an `inaccessible` term: its argument `stx[1]` is elaborated and the result is wrapped with `mkInaccessible`. -/
@[builtinTermElab inaccessible] def elabInaccessible : TermElab := fun stx expectedType? =>
  mkInaccessible <$> elabTerm stx[1] expectedType?
open Lean.Elab.Term.Quotation in
/--
Quotation precheck for `match`: precheck the discriminants, and then precheck each right-hand side
with the variables bound by its patterns added as new locals.
-/
@[builtinQuotPrecheck Lean.Parser.Term.match] def precheckMatch : Precheck
  | `(match $[$discrs:term],* with $[| $[$patss],* => $rhss]*) => do
    discrs.forM precheck
    for (pats, rhs) in patss.zip rhss do
      -- NOTE(review): the `return` in the handler below exits `precheckMatch` itself, so the remaining
      -- alternatives are not prechecked either; confirm this is intended (`continue` would skip only this one).
      let vars ←
        try
          getPatternsVars pats
        catch
          | _ => return -- can happen in case of pattern antiquotations
      Quotation.withNewLocals (getPatternVarNames vars) <| precheck rhs
  | _ => throwUnsupportedSyntax
/-- We convert the collected `PatternVar`s into `PatternVarDecl`s (see `withPatternVars`). -/
structure PatternVarDecl where
  /-- The free variable introduced for the pattern variable. -/
  fvarId : FVarId
/--
Introduce a local declaration for each pattern variable in `pVars`, in order. Each declaration has a
fresh type metavariable as its type. Then run `k` with the resulting declarations.
-/
private partial def withPatternVars {α} (pVars : Array PatternVar) (k : Array PatternVarDecl → TermElabM α) : TermElabM α :=
  let rec loop (i : Nat) (decls : Array PatternVarDecl) : TermElabM α := do
    if h : i < pVars.size then
      let var := pVars.get ⟨i, h⟩
      -- The type is not known yet; presumably it gets solved while the patterns are elaborated.
      let type ← mkFreshTypeMVar
      withLocalDecl var.getId BinderInfo.default type fun x =>
        loop (i+1) (decls.push { fvarId := x.fvarId! })
    else
      k decls
  loop 0 #[]
/-
Remark: when performing dependent pattern matching, we often had to write code such as
```lean
def Vec.map' (f : α → β) (xs : Vec α n) : Vec β n :=
match n, xs with
| _, nil => nil
| _, cons a as => cons (f a) (map' f as)
```
We had to include `n` and the `_`s because the type of `xs` depends on `n`.
Moreover, `nil` and `cons a as` have different types.
This was quite tedious. So, we have implemented an automatic "discriminant refinement procedure".
The procedure is based on the observation that we get a type error whenever we forget to include `_`s
and the indices a discriminant depends on. So, we catch the exception, check whether the type of the discriminant
is an indexed family, and add its indices as new discriminants.
The current implementation adds indices as they are found and does not
try to "sort" the new discriminants.
If the refinement process fails, we report the original error message.
-/
/--
Auxiliary structure for storing a type mismatch exception when processing the
pattern #`idx` of some alternative. -/
structure PatternElabException where
  /-- The original type mismatch exception. -/
  ex          : Exception
  /-- Index of the pattern (and hence of the corresponding discriminant) that produced the type mismatch. -/
  patternIdx  : Nat
  /-- Path to the problematic inductive type index that produced the type mismatch. -/
  pathToIndex : List Nat
/--
This method is part of the "discriminant refinement" procedure. It is invoked when the
type of the `pattern` does not match the expected type. The expected type is based on the
motive computed using the `match` discriminants.
It tries to compute a path to an index of the discriminant type.
For example, suppose the user has written
```
inductive Mem (a : α) : List α → Prop where
  | head {as} : Mem a (a::as)
  | tail {as} : Mem a as → Mem a (a'::as)
infix:50 " ∈ " => Mem
example (a b : Nat) (h : a ∈ [b]) : b = a :=
  match h with
  | Mem.head => rfl
```
The motive for the match is `a ∈ [b] → b = a`, and we get a type mismatch between the type
of `Mem.head` and `a ∈ [b]`. This procedure returns the path `[2, 1]` to the index `b`.
We use it to produce the following refinement
```
example (a b : Nat) (h : a ∈ [b]) : b = a :=
  match b, h with
  | _, Mem.head => rfl
```
which produces the new motive `(x : Nat) → a ∈ [x] → x = a`
After this refinement step, the `match` is elaborated successfully.
This method relies on the fact that the dependent pattern matcher compiler solves equations
between indices of indexed inductive families.
The following kinds of equations are supported by this compiler:
- `x = t`
- `t = x`
- `ctor ... = ctor ...`
where `x` is a free variable, `t` is an arbitrary term, and `ctor` is a constructor.
Our procedure ensures that "information" is not lost, and will *not* succeed in an
example such as
```
example (a b : Nat) (f : Nat → Nat) (h : f a ∈ [f b]) : f b = f a :=
  match h with
  | Mem.head => rfl
```
and will not add `f b` as a new discriminant. We may add an option in the future to
enable this more liberal form of refinement.
-/
private partial def findDiscrRefinementPath (pattern : Expr) (expected : Expr) : OptionT MetaM (List Nat) := do
  goType (← instantiateMVars (← inferType pattern)) expected
where
  -- Succeed only if `t` and `d` are applications of definitionally equal constants with the same number of arguments.
  checkCompatibleApps (t d : Expr) : OptionT MetaM Unit := do
    guard d.isApp
    guard <| t.getAppNumArgs == d.getAppNumArgs
    let tFn := t.getAppFn
    let dFn := d.getAppFn
    guard <| tFn.isConst && dFn.isConst
    guard (← isDefEq tFn dFn)

  -- Visitor for inductive types.
  -- Finds the first argument where `t` and `d` disagree: parameters are visited with `goType`, indices with `goIndex`.
  goType (t d : Expr) : OptionT MetaM (List Nat) := do
    let t ← whnf t
    let d ← whnf d
    checkCompatibleApps t d
    matchConstInduct t.getAppFn (fun _ => failure) fun info _ => do
      let tArgs := t.getAppArgs
      let dArgs := d.getAppArgs
      for i in [:info.numParams] do
        let tArg := tArgs[i]
        let dArg := dArgs[i]
        unless (← isDefEq tArg dArg) do
          return i :: (← goType tArg dArg)
      for i in [info.numParams : tArgs.size] do
        let tArg := tArgs[i]
        let dArg := dArgs[i]
        unless (← isDefEq tArg dArg) do
          return i :: (← goIndex tArg dArg)
      failure

  -- Visitor for indexed families.
  -- The path ends at a free variable. Otherwise both sides must be applications of the same constructor
  -- with defeq parameters, and we descend into the first field where they disagree.
  goIndex (t d : Expr) : OptionT MetaM (List Nat) := do
    let t ← whnfD t
    let d ← whnfD d
    if t.isFVar || d.isFVar then
      return [] -- Found refinement path
    else
      checkCompatibleApps t d
      matchConstCtor t.getAppFn (fun _ => failure) fun info _ => do
        let tArgs := t.getAppArgs
        let dArgs := d.getAppArgs
        for i in [:info.numParams] do
          let tArg := tArgs[i]
          let dArg := dArgs[i]
          unless (← isDefEq tArg dArg) do
            failure
        for i in [info.numParams : tArgs.size] do
          let tArg := tArgs[i]
          let dArg := dArgs[i]
          unless (← isDefEq tArg dArg) do
            return i :: (← goIndex tArg dArg)
        failure
/--
Replace the indices of the inductive type application `type` with fresh metavariables.
Indices are also erased, recursively, inside its parameters. `type` is put in weak head normal form
(`whnfD`) first. If it is not an application of an inductive type, the original `type` is returned unchanged.
-/
private partial def eraseIndices (type : Expr) : MetaM Expr := do
  let type' ← whnfD type
  matchConstInduct type'.getAppFn (fun _ => return type) fun info _ => do
    let args := type'.getAppArgs
    let params ← args[:info.numParams].toArray.mapM eraseIndices
    let result := mkAppN type'.getAppFn params
    let resultType ← inferType result
    -- one fresh metavariable for each remaining argument (the indices)
    let (newIndices, _, _) ← forallMetaTelescopeReducing resultType (some (args.size - info.numParams))
    return mkAppN result newIndices
/-- Run `x` in pattern mode (`inPattern := true`) with `errToSorry` disabled, i.e. errors are not turned into `sorry`. -/
private def withPatternElabConfig (x : TermElabM α) : TermElabM α :=
  withoutErrToSorry (withReader (fun ctx => { ctx with inPattern := true }) x)
/--
Elaborate the patterns `patternStxs` against `matchType`, whose domains are the expected types of the
patterns. Implicit lambdas are disabled during elaboration.
Return the elaborated patterns and `matchType` instantiated with them.

If a pattern fails to elaborate against its expected type `d`, we retry with the indices of `d` erased
(see `eraseIndices`). Suppose that succeeds and `findDiscrRefinementPath` finds a path to the offending index.
Then a `PatternElabException` is thrown for the "discriminant refinement" feature.
Otherwise, the original exception is rethrown. The state is restored before either exception is thrown.
-/
private def elabPatterns (patternStxs : Array Syntax) (matchType : Expr) : ExceptT PatternElabException TermElabM (Array Expr × Expr) :=
  withReader (fun ctx => { ctx with implicitLambda := false }) do
    let mut patterns := #[]
    let mut matchType := matchType
    for idx in [:patternStxs.size] do
      let patternStx := patternStxs[idx]
      matchType ← whnf matchType
      match matchType with
      | Expr.forallE _ d b _ =>
        let pattern ← do
          let s ← saveState
          try
            liftM <| withSynthesize <| withPatternElabConfig <| elabTermEnsuringType patternStx d
          catch ex : Exception =>
            restoreState s
            -- Retry against `d` with its indices replaced by metavariables.
            match (← liftM <| commitIfNoErrors? <| withPatternElabConfig do elabTermAndSynthesize patternStx (← eraseIndices d)) with
            | some pattern =>
              match (← findDiscrRefinementPath pattern d |>.run) with
              | some path =>
                restoreState s
                -- Wrap the type mismatch exception for the "discriminant refinement" feature.
                throwThe PatternElabException { ex := ex, patternIdx := idx, pathToIndex := path }
              | none => restoreState s; throw ex
            | none => throw ex
        matchType := b.instantiate1 pattern
        patterns := patterns.push pattern
      | _ => throwError "unexpected match type"
    return (patterns, matchType)
open Meta.Match (Pattern Pattern.var Pattern.inaccessible Pattern.ctor Pattern.as Pattern.val Pattern.arrayLit AltLHS MatcherResult)
namespace ToDepElimPattern
/-- Throw an `invalid pattern` error that displays `e` indented on the following line. -/
private def throwInvalidPattern (e : Expr) : MetaM α := do
  let msg := m!"invalid pattern {indentExpr e}"
  throwError msg
/-- State of the `ToDepElimPattern.M` monad. -/
structure State where
  /-- Pattern variables collected so far by `normalize` (free variables and metavariables). -/
  patternVars : Array Expr := #[]
/-- Context of the `ToDepElimPattern.M` monad. -/
structure Context where
  /--
  When visiting a metavariable that has a user name, we save that name here (see `withMVar`).
  We want to preserve these user names when generating new pattern variables. -/
  userName : Name := Name.anonymous
  /--
  Pattern variables that were explicitly provided by the user.
  Recall that implicit parameters and `_` are elaborated as metavariables, and then converted into pattern variables
  by the `normalize` procedure.
  -/
  explicitPatternVars : Array FVarId := #[]
/-- Monad used by `normalize`: reads a `Context` and updates a `State`, on top of `TermElabM`. -/
abbrev M := ReaderT Context $ StateRefT State TermElabM
/--
Return true iff `e` is an explicit pattern variable provided by the user, i.e. a free variable listed in
`Context.explicitPatternVars`.
-/
def isExplicitPatternVar (e : Expr) : M Bool := do
  unless e.isFVar do
    return false
  let fvarId := e.fvarId!
  return (← read).explicitPatternVars.contains fvarId
/--
Run `x` after "saving" the user name associated with `mvarId` in the context.
This happens only when that name is not anonymous and no user name has been saved yet.
The auto generalization feature uses synthetic holes to preserve the name of the free variable included during generalization.
For example, if we are generalizing a free variable `bla`, we add the synthetic hole `?bla` for the pattern. We use a synthetic hole
because we don't know whether `?bla` will become an inaccessible pattern or not.
The `withMVar` method makes sure we don't "lose" this name when `isDefEq` performs assignments of the form `?bla := ?m` where `?m` has no user name.
This can happen, for example, when the user provides a `_` pattern, or for implicit fields.
-/
private def withMVar (mvarId : MVarId) (x : M α) : M α := do
  let mvarDecl ← getMVarDecl mvarId
  let savedName := (← read).userName
  -- An already saved name wins; an anonymous name is never saved.
  if mvarDecl.userName.isAnonymous || !savedName.isAnonymous then
    x
  else
    withReader (fun ctx => { ctx with userName := mvarDecl.userName }) x
/--
Build a map `b ↦ e'` for every subterm `e'` of `e` such that `patternWithRef? e' = some (stx, b)`.
This is a helper function for `whnfPreservingPatternRef`. -/
private def mkPatternRefMap (e : Expr) : ExprMap Expr :=
  runST go
where
  go (σ) : ST σ (ExprMap Expr) := do
    let acc : ST.Ref σ (ExprMap Expr) ← ST.mkRef {}
    e.forEach fun sub => do
      if let some (_, body) := patternWithRef? sub then
        acc.modify fun m => m.insert body sub
    acc.get
/--
Try to restore `Syntax` ref information stored in `map` after
applying `whnf` at `whnfPreservingPatternRef`.
A subterm that already has a `patternWithRef` annotation is kept as is. Any other subterm found in `map` is
replaced by its annotated version. The traversal does not descend into replaced subterms.
It assumes `map` has been constructed using `mkPatternRefMap`.
-/
private def applyRefMap (e : Expr) (map : ExprMap Expr) : Expr :=
  e.replace fun sub =>
    if (patternWithRef? sub).isSome then
      some sub       -- already annotated: keep it and stop
    else
      map.find? sub  -- annotated copy (stop) or `none` (continue into children)
/--
Apply `whnf` to `e` while trying to preserve `PatternWithRef` information.
If the result is a constructor application it is returned directly. Otherwise, the annotations found
in `e` are reapplied with `applyRefMap`.
This is a bit hackish, but it is necessary for providing proper
jump-to-definition information in examples such as
```
def f (x : Nat) : Nat :=
  match x with
  | 0 => 1
  | y + 1 => y
```
Without this trick, the `PatternWithRef` is lost for the `y` at the pattern `y+1`.
-/
private def whnfPreservingPatternRef (e : Expr) : MetaM Expr := do
  let eNew ← whnf e
  if !eNew.isConstructorApp (← getEnv) then
    return applyRefMap eNew (mkPatternRefMap e)
  return eNew
/--
Normalize the pattern and collect all patterns variables (explicit and implicit).
This method is the one that decides where the inaccessible annotations must be inserted.
The pattern variables are both free variables (for explicit pattern variables) and metavariables (for implicit ones).
Recall that `mkLambdaFVars` now allows us to abstract both free variables and metavariables.
-/
partial def normalize (e : Expr) : M Expr := do
  match inaccessible? e with
  | some e => processInaccessible e
  | none =>
    match patternWithRef? e with
    | some (ref, e) => return mkPatternWithRef (← normalize e) ref
    | none =>
      match e.arrayLit? with
      | some (α, lits) => mkArrayLit α (← lits.mapM normalize)
      | none =>
        if let some e := Match.isNamedPattern? e then
          -- `namedPattern x p h`: `x` and `h` must be free variables and both become pattern variables.
          let x := e.getArg! 1
          let p := e.getArg! 2
          let h := e.getArg! 3
          unless x.consumeMData.isFVar && h.consumeMData.isFVar do
            throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
          addVar x
          let p ← normalize p
          addVar h
          return mkApp4 e.getAppFn (e.getArg! 0) x p h
        else if isMatchValue e then
          return e
        else if e.isFVar then
          -- Only explicit pattern variables are variables; any other local is inaccessible.
          if (← isExplicitPatternVar e) then
            processVar e
          else
            return mkInaccessible e
        else if e.getAppFn.isMVar then
          -- Assigned metavariables are instantiated; an unassigned one becomes an (implicit) pattern variable.
          let eNew ← instantiateMVars e
          if eNew != e then
            withMVar e.getAppFn.mvarId! <| normalize eNew
          else if e.isMVar then
            withMVar e.mvarId! <| processVar e
          else
            throwInvalidPattern e
        else
          let eNew ← whnfPreservingPatternRef e
          if eNew != e then
            normalize eNew
          else
            -- Non-constructor applications become inaccessible; for constructors, only the fields are
            -- normalized (parameters are just instantiated).
            matchConstCtor e.getAppFn
              (fun _ => return mkInaccessible (← eraseInaccessibleAnnotations (← instantiateMVars e)))
              (fun v us => do
                let args := e.getAppArgs
                unless args.size == v.numParams + v.numFields do
                  throwInvalidPattern e
                let params := args.extract 0 v.numParams
                let params ← params.mapM fun p => instantiateMVars p
                let fields := args.extract v.numParams args.size
                let fields ← fields.mapM normalize
                return mkAppN e.getAppFn (params ++ fields))
where
  -- Record `e` (without pattern-ref annotations) as a pattern variable, unless it is already one.
  addVar (e : Expr) : M Unit := do
    let e ← erasePatternRefAnnotations e
    unless (← get).patternVars.contains e do
      modify fun s => { s with patternVars := s.patternVars.push e }

  -- The first occurrence of a variable is recorded and kept; later occurrences become inaccessible.
  -- For a metavariable, its tag is set to the user name saved in the context.
  processVar (e : Expr) : M Expr := do
    let e' ← erasePatternRefAnnotations e
    if (← get).patternVars.contains e' then
      return mkInaccessible (← eraseInaccessibleAnnotations e)
    else
      if e'.isMVar then
        setMVarTag e'.mvarId! (← read).userName
      modify fun s => { s with patternVars := s.patternVars.push e' }
      return e

  -- Process the body of an inaccessible annotation. Explicit pattern variables and unassigned
  -- metavariables are still processed as variables; everything else stays inaccessible.
  processInaccessible (e : Expr) : M Expr := do
    let e' ← erasePatternRefAnnotations e
    match e' with
    | Expr.fvar fvarId _ =>
      if (← isExplicitPatternVar e') then
        processVar e
      else
        return mkInaccessible e
    | _ =>
      if e'.getAppFn.isMVar then
        let eNew ← instantiateMVars e'
        if eNew != e' then
          withMVar e'.getAppFn.mvarId! <| processInaccessible eNew
        else if e'.isMVar then
          withMVar e'.mvarId! <| processVar e'
        else
          throwInvalidPattern e
      else
        return mkInaccessible (← eraseInaccessibleAnnotations (← instantiateMVars e))
/--
Auxiliary function for combining the `matchType` and all patterns into a single expression
`PProd.mk (... (PProd.mk matchType p₁) ...) pₙ`.
We use it before we abstract all patterns variables. -/
private partial def packMatchTypePatterns (matchType : Expr) (ps : Array Expr) : MetaM Expr := do
  let mut packed := matchType
  for p in ps do
    packed ← mkAppM ``PProd.mk #[packed, p]
  return packed
/--
The inverse of `packMatchTypePatterns`: peel the nested `PProd.mk` applications and return the match type
together with the patterns in their original order.
-/
private partial def unpackMatchTypePatterns (p : Expr) : Expr × Array Expr :=
  go p #[]
where
  -- `revPs` collects the patterns from the outermost (last) one inwards.
  go (e : Expr) (revPs : Array Expr) : Expr × Array Expr :=
    if e.isAppOf ``PProd.mk then
      go (e.getArg! 2) (revPs.push (e.getArg! 3))
    else
      (e, revPs.reverse)
/--
Convert a (normalized) pattern encoded as an `Expr` into a `Pattern`.
This method assumes that `e` has been normalized and the explicit and implicit (i.e., metavariables) pattern variables have
already been abstracted and converted back into new free variables.
-/
private partial def toPattern (e : Expr) : MetaM Pattern := do
  match inaccessible? e with
  | some e => return Pattern.inaccessible e
  | none =>
    match e.arrayLit? with
    | some (α, lits) => return Pattern.arrayLit α (← lits.mapM toPattern)
    | none =>
      if let some e := Match.isNamedPattern? e then
        let p ← toPattern <| e.getArg! 2
        match e.getArg! 1, e.getArg! 3 with
        | Expr.fvar x _, Expr.fvar h _ => return Pattern.as x p h
        | _, _ => throwError "unexpected occurrence of auxiliary declaration 'namedPattern'"
      else if isMatchValue e then
        return Pattern.val e
      else if e.isFVar then
        return Pattern.var e.fvarId!
      else
        -- After `normalize`, any remaining application is headed by a constructor.
        matchConstCtor e.getAppFn (fun _ => unreachable!) fun v us => do
          let args := e.getAppArgs
          let params := args.extract 0 v.numParams
          let params ← params.mapM fun p => instantiateMVars p
          let fields := args.extract v.numParams args.size
          let fields ← fields.mapM toPattern
          return Pattern.ctor v.name us params.toList fields.toList
/-- State for `topSort`. -/
structure TopSort.State where
  /-- Free variables already visited. -/
  visitedFVars : FVarIdSet := {}
  /-- Metavariables already visited. -/
  visitedMVars : MVarIdSet := {}
  /-- Pattern variables in dependency order: each one appears after the pattern variables occurring in its type. -/
  result       : Array Expr := #[]
/-- Monad for `topSort`: `TermElabM` extended with a `TopSort.State`. -/
abbrev TopSortM := StateRefT TopSort.State TermElabM
/--
Topological sort. We need it because inaccessible patterns may contain pattern variables that are declared later.
That is, processing patterns from left to right does not guarantee that the pattern variables are collected in the
"right" order. "Right" here means pattern `x` must occur before pattern `y` if the type of `y` depends on `x`.
-/
private partial def topSort (patternVars : Array Expr) : TermElabM (Array Expr) := do
  let (_, s) ← patternVars.mapM visit |>.run {}
  return s.result
where
  -- Depth-first traversal: a pattern variable is pushed to `result` only after visiting its type.
  visit (e : Expr) : TopSortM Unit := do
    match e with
    | Expr.proj _ _ e _ => visit e
    | Expr.forallE _ d b _ => visit d; visit b
    | Expr.lam _ d b _ => visit d; visit b
    | Expr.letE _ t v b _ => visit t; visit v; visit b
    | Expr.app f a _ => visit f; visit a
    | Expr.mdata _ b _ => visit b
    | Expr.mvar mvarId _ =>
      let v ← instantiateMVars e
      if !v.isMVar then
        visit v
      else if patternVars.contains e then
        unless (← get).visitedMVars.contains mvarId do
          modify fun s => { s with visitedMVars := s.visitedMVars.insert mvarId }
          let mvarDecl ← getMVarDecl mvarId
          visit mvarDecl.type
          modify fun s => { s with result := s.result.push e }
    | Expr.fvar fvarId _ =>
      if patternVars.contains e then
        unless (← get).visitedFVars.contains fvarId do
          modify fun s => { s with visitedFVars := s.visitedFVars.insert fvarId }
          let localDecl ← getLocalDecl fvarId
          visit localDecl.type
          modify fun s => { s with result := s.result.push e }
    | _ => return ()
/--
Save pattern information in the info tree, and remove `patternWithRef?` annotations.
Each `patternWithRef` annotation produces a term info node for its syntax. The node is marked as a binder
when the annotated term is a free variable that does not occur inside an inaccessible pattern.
-/
partial def savePatternInfo (p : Expr) : TermElabM Expr :=
  go p |>.run false
where
  /- The `Bool` context is true iff we are inside of an "inaccessible" pattern. -/
  go (p : Expr) : ReaderT Bool TermElabM Expr := do
    match p with
    -- Rebuild binders with their own binder info (`p.binderInfo`), not the body's.
    | .forallE n d b _ => withLocalDecl n p.binderInfo (← go d) fun x => do mkForallFVars #[x] (← go (b.instantiate1 x))
    | .lam n d b _ => withLocalDecl n p.binderInfo (← go d) fun x => do mkLambdaFVars #[x] (← go (b.instantiate1 x))
    | .letE n t v b .. => withLetDecl n (← go t) (← go v) fun x => do mkLetFVars #[x] (← go (b.instantiate1 x))
    | .app f a _ => return mkApp (← go f) (← go a)
    | .proj _ _ b _ => return p.updateProj! (← go b)
    | .mdata k b _ =>
      if inaccessible? p |>.isSome then
        -- Entering an inaccessible pattern: variables below this point are not binders.
        return mkMData k (← withReader (fun _ => true) (go b))
      else if let some (stx, p) := patternWithRef? p then
        let p ← go p
        if p.isFVar && !(← read) then
          /- If `p` is a free variable and we are not inside of an "inaccessible" pattern, this `p` is a binder. -/
          addTermInfo stx p (isBinder := true)
        else
          addTermInfo stx p
      else
        return mkMData k (← go b)
    | _ => return p
/--
Main method for `withDepElimPatterns`.
- `patternVarDecls`: are the explicit pattern variables provided by the user.
- `ps`: are the patterns provided by the user.
- `matchType`: the expected type for this branch. It depends on the explicit pattern variables and the implicit ones that are still represented as metavariables,
  and are found by this function.
- `k` is the continuation that is executed in an updated local context with all the pattern variables (explicit and implicit). Note that `patternVarDecls` are all
  replaced since they may depend on implicit pattern variables (i.e., metavariables) that are converted into new free variables by this method.

Throws an error if an explicit pattern variable only occurs in positions that are inaccessible to pattern matching.
-/
partial def main (patternVarDecls : Array PatternVarDecl) (ps : Array Expr) (matchType : Expr) (k : Array LocalDecl → Array Pattern → Expr → TermElabM α) : TermElabM α := do
  let explicitPatternVars := patternVarDecls.map fun decl => decl.fvarId
  let (ps, s) ← ps.mapM normalize |>.run { explicitPatternVars } |>.run {}
  let patternVars ← topSort s.patternVars
  trace[Elab.match] "patternVars after topSort: {patternVars}"
  for explicit in explicitPatternVars do
    unless patternVars.any (· == mkFVar explicit) do
      withInPattern do
        throwError "invalid patterns, `{mkFVar explicit}` is an explicit pattern variable, but it only occurs in positions that are inaccessible to pattern matching{indentD (MessageData.joinSep (ps.toList.map (MessageData.ofExpr .)) m!"\n\n")}"
  let packed ← pack patternVars ps matchType
  trace[Elab.match] "packed: {packed}"
  -- Remove the explicit pattern variables from the local context: `unpack` reintroduces every
  -- pattern variable as a new local (one per binder of `packed`).
  let lctx := explicitPatternVars.foldl (init := (← getLCtx)) fun lctx d => lctx.erase d
  withTheReader Meta.Context (fun ctx => { ctx with lctx := lctx }) do
    check packed
    unpack packed fun patternVars patterns matchType => do
      let localDecls ← patternVars.mapM fun x => getLocalDecl x.fvarId!
      trace[Elab.match] "patternVars: {patternVars}, matchType: {matchType}"
      k localDecls (← patterns.mapM fun p => toPattern p) matchType
where
  pack (patternVars : Array Expr) (ps : Array Expr) (matchType : Expr) : MetaM Expr := do
    /-
      Recall that some of the `patternVars` are metavariables without a user facing name.
      Thus, this method tries to infer names for them using `ps` before performing the `mkLambdaFVars` abstraction.
      Let `?m` be a metavariable in `patternVars` without a user facing name.
      The heuristic uses the patterns `ps`. We traverse the patterns from right to left searching for applications
      `f ... ?m`. The name for the corresponding `f`-parameter is used to name `?m`.
      We search from right to left to make sure we visit a pattern before visiting its indices. Example:
      ```
      #[@List.cons α i ?m, @HList.cons α β i ?m a as, @Member.head α i ?m]
      ```
      The user names set here are reset afterwards (see the `finally` below).
    -/
    let setMVarsAt (e : Expr) : StateRefT (Array MVarId) MetaM Unit := do
      let mvarIds ← setMVarUserNamesAt (← erasePatternRefAnnotations e) patternVars
      modify (· ++ mvarIds)
    let go : StateRefT (Array MVarId) MetaM Expr := do
      try
        for p in ps.reverse do
          setMVarsAt p
        mkLambdaFVars patternVars (← packMatchTypePatterns matchType ps) (binderInfoForMVars := BinderInfo.default)
      finally
        resetMVarUserNames (← get)
    go |>.run' #[]

  -- Introduce a local for each lambda binder of `packed`, then split the body back into the match type and patterns.
  unpack (packed : Expr) (k : (patternVars : Array Expr) → (patterns : Array Expr) → (matchType : Expr) → TermElabM α) : TermElabM α :=
    let rec go (packed : Expr) (patternVars : Array Expr) : TermElabM α := do
      match packed with
      | .lam n d b _ =>
        withLocalDeclD n (← erasePatternRefAnnotations (← eraseInaccessibleAnnotations d)) fun patternVar =>
          go (b.instantiate1 patternVar) (patternVars.push patternVar)
      | _ =>
        let (matchType, patterns) := unpackMatchTypePatterns packed
        let matchType ← erasePatternRefAnnotations (← eraseInaccessibleAnnotations matchType)
        let patterns ← patterns.mapM (savePatternInfo ·)
        k patternVars patterns matchType
    go packed #[]
end ToDepElimPattern
/--
Convert the elaborated patterns `ps` into dependent elimination `Pattern`s (see `ToDepElimPattern.main`).
Then run `k` with the local declarations of all pattern variables, the patterns, and the updated `matchType`.
-/
def withDepElimPatterns (patternVarDecls : Array PatternVarDecl) (ps : Array Expr) (matchType : Expr) (k : Array LocalDecl → Array Pattern → Expr → TermElabM α) : TermElabM α :=
  ToDepElimPattern.main patternVarDecls ps matchType k
/--
Elaborate the patterns `patternStxs` of an alternative, and run `k` with the resulting `AltLHS` and
the match type instantiated with the patterns.
A `PatternElabException` may be thrown by `elabPatterns` (discriminant refinement).
-/
private def withElaboratedLHS {α} (ref : Syntax) (patternVarDecls : Array PatternVarDecl) (patternStxs : Array Syntax) (matchType : Expr)
    (k : AltLHS → Expr → TermElabM α) : ExceptT PatternElabException TermElabM α := do
  let (patterns, matchType) ← withSynthesize <| elabPatterns patternStxs matchType
  -- The rest is elaborated as a `TermElabM` action, which is then lifted into `ExceptT`.
  id (α := TermElabM α) do
    trace[Elab.match] "patterns: {patterns}"
    withDepElimPatterns patternVarDecls patterns matchType fun localDecls patterns matchType => do
      k { ref := ref, fvarDecls := localDecls.toList, patterns := patterns.toList } matchType
/--
Try to clear the free variables in `toClear`, and then execute `k` in the updated local context.
If `type` or another local variable depends on a free variable in `toClear`, then it is not cleared.
Local instances for cleared variables are removed as well.
-/
private def withToClear (toClear : Array FVarId) (type : Expr) (k : TermElabM α) : TermElabM α := do
  if toClear.isEmpty then
    k
  else
    let toClear ← sortFVarIds toClear
    trace[Elab.match] ">> toClear {toClear.map mkFVar}"
    let mut lctx ← getLCtx
    let mut localInsts ← getLocalInstances
    -- Visited in reverse order (presumably `sortFVarIds` sorts by declaration order — confirm). The dependency
    -- check uses the updated `lctx`, so a variable that was just cleared no longer blocks the ones it depends on.
    for fvarId in toClear.reverse do
      if !(← dependsOn type fvarId) then
        if !(← lctx.anyM fun localDecl => pure (localDecl.fvarId != fvarId) <&&> localDeclDependsOn localDecl fvarId) then
          lctx := lctx.erase fvarId
          localInsts := localInsts.filter fun localInst => localInst.fvar.fvarId! != fvarId
    withLCtx lctx localInsts k
/--
Try to clear every auxiliary discriminant local (see `isAuxDiscrName` and `isAuxFunDiscrName`) from the local
context, using `withToClear` with `matchType`, and then run `k`.
-/
private def withoutAuxDiscrs (matchType : Expr) (k : TermElabM α) : TermElabM α := do
  let auxDiscrs := (← getLCtx).foldl (init := #[]) fun (acc : Array FVarId) localDecl =>
    if isAuxDiscrName localDecl.userName || isAuxFunDiscrName localDecl.userName then
      acc.push localDecl.fvarId
    else
      acc
  withToClear auxDiscrs matchType k
/--
Generate equalities `h : discr = pattern` (built with `mkEqHEq`) for discriminants annotated with `h :`.
Discriminants and patterns are paired positionally. Then run `k` with the new equality locals.
We use these equalities to elaborate the right-hand-side of a `match` alternative.
Patterns without a corresponding discriminant are ignored.
-/
private def withEqs (discrs : Array Discr) (patterns : List Pattern) (k : Array Expr → TermElabM α) : TermElabM α :=
  loop 0 patterns #[]
where
  loop (idx : Nat) (pending : List Pattern) (acc : Array Expr) : TermElabM α := do
    match pending with
    | [] => k acc
    | p :: rest =>
      if h : idx < discrs.size then
        let discr := discrs.get ⟨idx, h⟩
        match discr.h? with
        | none => loop (idx+1) rest acc
        | some hStx =>
          withLocalDeclD hStx.getId (← mkEqHEq discr.expr (← p.toExpr)) fun eq => do
            addTermInfo' hStx eq (isBinder := true)
            loop (idx+1) rest (acc.push eq)
      else
        k acc
/--
Elaborate the `match` alternative `alt` using the given `matchType`.
The array `toClear` contains variables that must be cleared before elaborating the `rhs` because
they have been generalized/refined.

The pattern variables of `alt` are introduced, the patterns are elaborated (which may throw a
`PatternElabException`), and equalities for `h :` discriminants are added (see `withEqs`).
The right-hand side is then elaborated and abstracted over the pattern variables and those equalities.
If there is nothing to abstract, it is wrapped with `mkSimpleThunk`.
-/
private def elabMatchAltView (discrs : Array Discr) (alt : MatchAltView) (matchType : Expr) (toClear : Array FVarId) : ExceptT PatternElabException TermElabM (AltLHS × Expr) := withRef alt.ref do
  withoutAuxDiscrs matchType do
    let (patternVars, alt) ← collectPatternVars alt
    trace[Elab.match] "patternVars: {patternVars}"
    withPatternVars patternVars fun patternVarDecls => do
      withElaboratedLHS alt.ref patternVarDecls alt.patterns matchType fun altLHS matchType =>
        withEqs discrs altLHS.patterns fun eqs =>
          withLocalInstances altLHS.fvarDecls do
            trace[Elab.match] "elabMatchAltView: {matchType}"
            let matchType ← instantiateMVars matchType
            -- If `matchType` is of the form `@m ...`, we create a new metavariable with the current scope.
            -- This improves the effectiveness of the `isDefEq` default approximations
            let matchType' ← if matchType.getAppFn.isMVar then mkFreshTypeMVar else pure matchType
            withToClear toClear matchType' do
              let rhs ← elabTermEnsuringType alt.rhs matchType'
              -- We use all approximations to ensure the auxiliary type is defeq to the original one.
              unless (← fullApproxDefEq <| isDefEq matchType' matchType) do
                throwError "type mismatch, alternative {← mkHasTypeButIsExpectedMsg matchType' matchType}"
              let xs := altLHS.fvarDecls.toArray.map LocalDecl.toExpr ++ eqs
              let rhs ← if xs.isEmpty then pure <| mkSimpleThunk rhs else mkLambdaFVars xs rhs
              trace[Elab.match] "rhs: {rhs}"
              return (altLHS, rhs)
/--
Collect the problematic index for the "discriminant refinement feature". This method is invoked
when we detect a type mismatch at a pattern #`idx` of some alternative.

Starting from the type of `discr`, we follow `pathToIndex`. At each step, the current type is
reduced with `whnfD` and the argument at the given position is selected. The result is `none`
if some step does not land on an application with enough arguments.
-/
private partial def getIndexToInclude? (discr : Expr) (pathToIndex : List Nat) : TermElabM (Option Expr) := do
  let discrType ← inferType discr
  (descend discrType pathToIndex).run
where
  /-- Walk down `type` along the argument positions in the given path. -/
  descend (type : Expr) : List Nat → OptionT MetaM Expr
    | [] => pure type
    | argIdx :: rest => do
      let type ← whnfD type
      -- Give up unless `type` is an application with an argument at position `argIdx`.
      guard <| type.isApp && argIdx < type.getAppNumArgs
      descend (type.getArg! argIdx) rest
/-- Result of `generalize`: the possibly extended discriminants, `match` type and alternatives. -/
structure GeneralizeResult where
  /-- Discriminants; when generalization happens, the generalized variables are appended as new discriminants. -/
  discrs : Array Discr
  /-- `FVarId`s of the variables that have been generalized. We store them so they can be cleared in each branch. -/
  toClear : Array FVarId := #[]
  /-- The `match` type; when generalization happens, it also quantifies over the generalized variables. -/
  matchType : Expr
  /-- Alternatives; when generalization happens, each one gets one extra pattern variable per generalized variable. -/
  altViews : Array MatchAltView
  /-- Set to `true` by `generalize` only when it actually generalized some variables. -/
  refined : Bool := false
/--
"Generalize" variables that depend on the discriminants.

Generalization is skipped when `generalizing? = some false`. When `generalizing?` is `none`,
generalization is attempted.

Remarks and limitations:
- We currently do not generalize let-decls.
- We abort generalization if the new `matchType` is type incorrect.
- Only discriminants that are free variables are abstracted in the new `matchType`.
- We "generalize" by adding new discriminants and pattern variables. Inside the patterns, the
  generalized variables are shadowed by the new pattern variables. This matches what users would get
  if they had written it by hand. The generalized variables are also returned in `toClear`, and
  `elabMatchAltView` passes them to `withToClear` when elaborating each alternative's right-hand side.
-/
private def generalize (discrs : Array Discr) (matchType : Expr) (altViews : Array MatchAltView) (generalizing? : Option Bool) : TermElabM GeneralizeResult := do
  let gen := generalizing?.getD true
  if !gen then
    return { discrs, matchType, altViews }
  else
    let discrExprs := discrs.map (·.expr)
    /- let-decls are currently being ignored by the generalizer. -/
    let ysFVarIds ← getFVarsToGeneralize discrExprs (ignoreLetDecls := true)
    if ysFVarIds.isEmpty then
      return { discrs, matchType, altViews }
    else
      let ys := ysFVarIds.map mkFVar
      let matchType' ← forallBoundedTelescope matchType discrs.size fun ds type => do
        let type ← mkForallFVars ys type
        -- Replace the free-variable discriminants by the corresponding telescope variables `ds`.
        let (discrs', ds') := Array.unzip <| Array.zip discrExprs ds |>.filter fun (di, d) => di.isFVar
        let type := type.replaceFVars discrs' ds'
        mkForallFVars ds type
      if (← isTypeCorrect matchType') then
        let discrs := discrs ++ ys.map fun y => { expr := y : Discr }
        let altViews ← altViews.mapM fun altView => do
          let patternVars ← getPatternsVars altView.patterns
          -- We traverse backwards because we want to keep the most recent names.
          -- For example, if `ys` contains `#[h, h]`, we want to make sure `mkFreshUserName` is applied to the first `h`,
          -- since it is already shadowed by the second.
          let ysUserNames ← ys.foldrM (init := #[]) fun y ysUserNames => do
            let yDecl ← getLocalDecl y.fvarId!
            let mut yUserName := yDecl.userName
            if ysUserNames.contains yUserName then
              yUserName ← mkFreshUserName yUserName
            -- Explicitly provided pattern variables shadow `y`
            else if patternVars.any fun x => x.getId == yUserName then
              yUserName ← mkFreshUserName yUserName
            return ysUserNames.push yUserName
          let ysIds ← ysUserNames.reverse.mapM fun n => return mkIdentFrom (← getRef) n
          return { altView with patterns := altView.patterns ++ ysIds }
        return { discrs, toClear := ysFVarIds, matchType := matchType', altViews, refined := true }
      else
        return { discrs, matchType, altViews }
/--
Elaborate the alternatives `altViews`. Each round first calls `generalize`. If an alternative fails
with a `PatternElabException` whose path leads to an index (`getIndexToInclude?`), the state is restored
and we retry with that index, plus the variables computed by `collectDeps`, added as new leading
discriminants ("discriminant refinement").

Returns the final discriminants, the final `matchType`, the elaborated alternatives, and a flag.
The flag is `true` when some refinement round happened or `generalize` reported `refined`.
-/
private partial def elabMatchAltViews (generalizing? : Option Bool) (discrs : Array Discr) (matchType : Expr) (altViews : Array MatchAltView) : TermElabM (Array Discr × Expr × Array (AltLHS × Expr) × Bool) := do
  loop discrs #[] matchType altViews none
where
  /-
    "Discriminant refinement" main loop.
    `first?` contains the first error message we found before updating the `discrs`. -/
  loop (discrs : Array Discr) (toClear : Array FVarId) (matchType : Expr) (altViews : Array MatchAltView) (first? : Option (SavedState × Exception))
      : TermElabM (Array Discr × Expr × Array (AltLHS × Expr) × Bool) := do
    let s ← saveState
    let { discrs := discrs', toClear := toClear', matchType := matchType', altViews := altViews', refined } ← generalize discrs matchType altViews generalizing?
    match (← altViews'.mapM (fun altView => elabMatchAltView discrs' altView matchType' (toClear ++ toClear')) |>.run) with
    | Except.ok alts => return (discrs', matchType', alts, first?.isSome || refined)
    | Except.error { patternIdx := patternIdx, pathToIndex := pathToIndex, ex := ex } =>
      let discr := discrs[patternIdx]
      let some index ← getIndexToInclude? discr.expr pathToIndex
        | throwEx (← updateFirst first? ex)
      trace[Elab.match] "index to include: {index}"
      -- If the index is already defeq to one of the discriminants, adding it again cannot help: report the error.
      if (← discrs.anyM fun discr => isDefEq discr.expr index) then
        throwEx (← updateFirst first? ex)
      let first ← updateFirst first? ex
      s.restore (restoreInfo := true)
      let indices ← collectDeps #[index] (discrs.map (·.expr))
      let matchType ←
        try
          updateMatchType indices matchType
        catch ex =>
          -- Report the first error we found rather than the failure to build the new `matchType`.
          throwEx first
      let ref ← getRef
      trace[Elab.match] "new indices to add as discriminants: {indices}"
      -- One new leading pattern per new discriminant: `?x` built from the user name for free variables
      -- whose name has no macro scopes, and a hole `_` otherwise.
      let wildcards ← indices.mapM fun index => do
        if index.isFVar then
          let localDecl ← getLocalDecl index.fvarId!
          if localDecl.userName.hasMacroScopes then
            return mkHole ref
          else
            let id := mkIdentFrom ref localDecl.userName
            `(?$id:ident)
        else
          return mkHole ref
      let altViews := altViews.map fun altView => { altView with patterns := wildcards ++ altView.patterns }
      let indDiscrs ← indices.mapM fun i => do
        match discr.h? with
        | none => return { expr := i : Discr }
        | some h =>
          -- If the discriminant that introduced this index is annotated with `h : discr`, then we should annotate the new discriminant too.
          let h := mkIdentFrom h (← mkFreshUserName `h)
          return { expr := i, h? := h : Discr }
      let discrs := indDiscrs ++ discrs
      let indexFVarIds := indices.filterMap fun | .fvar fvarId .. => some fvarId | _ => none
      loop discrs (toClear ++ indexFVarIds) matchType altViews first
  /-- Restore the saved state `p.1` (including info trees) and rethrow the exception `p.2`. -/
  throwEx {α} (p : SavedState × Exception) : TermElabM α := do
    p.1.restore (restoreInfo := true); throw p.2
  /-- Keep the first recorded failure if there is one; otherwise record the current state together with `ex`. -/
  updateFirst (first? : Option (SavedState × Exception)) (ex : Exception) : TermElabM (SavedState × Exception) := do
    match first? with
    | none => return (← saveState, ex)
    | some first => return first
  /-- `true` if some element of `es` is the free variable `fvarId`. -/
  containsFVar (es : Array Expr) (fvarId : FVarId) : Bool :=
    es.any fun e => e.isFVar && e.fvarId! == fvarId
  /- Update `indices` by including any free variable `x` s.t.
     - Type of some `discr` depends on `x`.
     - Type of `x` depends on some free variable in `indices`.
     If we don't include these extra variables in indices, then
     `updateMatchType` will generate a type incorrect term.
     For example, suppose `discr` contains `h : @HEq α a α b`, and
     `indices` is `#[α, b]`, and `matchType` is `@HEq α a α b → B`.
     `updateMatchType indices matchType` produces the type
     `(α' : Type) → (b : α') → @HEq α' a α' b → B` which is type incorrect
     because we have `a : α`.
     The method `collectDeps` will include `a` into `indices`.
     This method does not handle dependencies among non-free variables.
     We rely on the type checking method `check` at `updateMatchType`.
     Remark: `indices : Array Expr` does not need to be an array anymore.
     We should cleanup this code, and use `index : Expr` instead.
  -/
  collectDeps (indices : Array Expr) (discrs : Array Expr) : TermElabM (Array Expr) := do
    let mut s : CollectFVars.State := {}
    for discr in discrs do
      s := collectFVars s (← instantiateMVars (← inferType discr))
    let (indicesFVar, indicesNonFVar) := indices.split Expr.isFVar
    let indicesFVar := indicesFVar.map Expr.fvarId!
    let mut toAdd := #[]
    for fvarId in s.fvarSet.toList do
      unless containsFVar discrs fvarId || containsFVar indices fvarId do
        let localDecl ← getLocalDecl fvarId
        let mctx ← getMCtx
        -- NOTE(review): `fvarId` is pushed once per index it depends on; with more than one free-variable
        -- index this could add duplicates. Today the only caller passes `#[index]`.
        for indexFVarId in indicesFVar do
          if mctx.localDeclDependsOn localDecl indexFVarId then
            toAdd := toAdd.push fvarId
    let indicesFVar ← sortFVarIds (indicesFVar ++ toAdd)
    return indicesFVar.map mkFVar ++ indicesNonFVar
  /-- Abstract each index out of `matchType` (right to left, using `kabstract`) as a new leading `∀`, then type-check the result. -/
  updateMatchType (indices : Array Expr) (matchType : Expr) : TermElabM Expr := do
    let matchType ← indices.foldrM (init := matchType) fun index matchType => do
      let indexType ← inferType index
      let matchTypeBody ← kabstract matchType index
      let userName ← mkUserNameFor index
      return Lean.mkForall userName BinderInfo.default indexType matchTypeBody
    check matchType
    return matchType
/-- Create the auxiliary matcher described by `input` by delegating to `Meta.Match.mkMatcher`. -/
def mkMatcher (input : Meta.Match.MkMatcherInput) : TermElabM MatcherResult := do
  let result ← Meta.Match.mkMatcher input
  return result
-- When set, `reportMatcherResultErrors` does not log "redundant alternative" errors (default: `false`).
register_builtin_option match.ignoreUnusedAlts : Bool := {
  defValue := false
  descr := "if true, do not generate error if an alternative is not used"
}
/--
Report the problems detected while compiling a `match`:
- a "missing cases" error listing `result.counterExamples`, when there are any;
- a "redundant alternative" error at each alternative whose position is in `result.unusedAltIdxs`,
  unless the option `match.ignoreUnusedAlts` is set.
-/
def reportMatcherResultErrors (altLHSS : List AltLHS) (result : MatcherResult) : TermElabM Unit := do
  if !result.counterExamples.isEmpty then
    withHeadRefOnly <| logError m!"missing cases:\n{Meta.Match.counterExamplesToMessageData result.counterExamples}"
  let ignoreUnused := match.ignoreUnusedAlts.get (← getOptions)
  if ignoreUnused || result.unusedAltIdxs.isEmpty then
    return
  for (alt, idx) in altLHSS.zip (List.range altLHSS.length) do
    if result.unusedAltIdxs.contains idx then
      withRef alt.ref <| logError "redundant alternative"
/--
Recognize a `match` consisting of the single alternative `| PUnit.unit => rhs` with no pattern variables.
For such alternatives `rhss[0]` has the form `fun _ : Unit => b`. When the body `b` does not refer to
the bound variable, return `b`; in every other case return `none`.
-/
private def isMatchUnit? (altLHSS : List Match.AltLHS) (rhss : Array Expr) : MetaM (Option Expr) := do
  assert! altLHSS.length == rhss.size
  let [ { fvarDecls := [], patterns := [ Pattern.ctor `PUnit.unit .. ], .. } ] := altLHSS
    | return none
  let Expr.lam _ _ body _ := rhss[0]
    | return none
  return if body.hasLooseBVars then none else some body
/--
Elaborate a `match` from the discriminant syntax `discrStxs`, the alternatives `altViews`, the
optional motive `matchOptMotive`, and the expected type `expectedType`.

When a motive is provided, `(generalizing := true)` is rejected and generalization is disabled.
The discriminants and alternatives are elaborated inside `commitIfDidNotPostpone`. Afterwards:
- a non-dependent match whose only alternative is `| PUnit.unit => b` is replaced by `b` (see `isMatchUnit?`);
- otherwise an auxiliary matcher is created with `mkMatcher` and applied to the motive, the
  discriminants, and the right-hand sides.
-/
private def elabMatchAux (generalizing? : Option Bool) (discrStxs : Array Syntax) (altViews : Array MatchAltView) (matchOptMotive : Syntax) (expectedType : Expr)
    : TermElabM Expr := do
  let mut generalizing? := generalizing?
  if !matchOptMotive.isNone then
    if generalizing? == some true then
      throwError "the '(generalizing := true)' parameter is not supported when the 'match' motive is explicitly provided"
    generalizing? := some false
  let (discrs, matchType, altLHSS, isDep, rhss) ← commitIfDidNotPostpone do
    let ⟨discrs, matchType, isDep, altViews⟩ ← elabMatchTypeAndDiscrs discrStxs matchOptMotive altViews expectedType
    let matchAlts ← liftMacroM <| expandMacrosInPatterns altViews
    trace[Elab.match] "matchType: {matchType}"
    let (discrs, matchType, alts, refined) ← elabMatchAltViews generalizing? discrs matchType matchAlts
    let isDep := isDep || refined
    /-
     We should not use `synthesizeSyntheticMVarsNoPostponing` here. Otherwise, we will not be
     able to elaborate examples such as:
     ```
     def f (x : Nat) : Option Nat := none
     def g (xs : List (Nat × Nat)) : IO Unit :=
       xs.forM fun x =>
         match f x.fst with
         | _ => pure ()
     ```
     If we used `synthesizeSyntheticMVarsNoPostponing`, the example above would fail at `x.fst` because
     the type of `x` is only available after we process the last argument of `List.forM`.
     We apply pending default types to make sure we can process examples such as
     ```
     let (a, b) := (0, 0)
     ```
    -/
    synthesizeSyntheticMVarsUsingDefault
    let rhss := alts.map Prod.snd
    let matchType ← instantiateMVars matchType
    let altLHSS ← alts.toList.mapM fun alt => do
      let altLHS ← Match.instantiateAltLHSMVars alt.1
      /- Remark: we try to postpone before throwing an error.
         The combinator `commitIfDidNotPostpone` ensures we backtrack any updates that have been performed.
         The quick-check `waitExpectedTypeAndDiscrs` minimizes the number of scenarios where we have to postpone here.
         Here is an example that passes the `waitExpectedTypeAndDiscrs` test, but postpones here.
         ```
          def bad (ps : Array (Nat × Nat)) : Array (Nat × Nat) :=
            (ps.filter fun (p : Prod _ _) =>
              match p with
              | (x, y) => x == 0)
            ++
            ps
         ```
         When we try to elaborate `fun (p : Prod _ _) => ...` for the first time, we haven't propagated the type of `ps` yet
         because `Array.filter` has type `{α : Type u_1} → (α → Bool) → (as : Array α) → optParam Nat 0 → optParam Nat (Array.size as) → Array α`
         However, the partial type annotation `(p : Prod _ _)` makes sure we succeed at the quick-check `waitExpectedTypeAndDiscrs`.
      -/
      withRef altLHS.ref do
        -- Pattern-variable types and patterns must be free of metavariables before building the matcher.
        for d in altLHS.fvarDecls do
          if d.hasExprMVar then
            withExistingLocalDecls altLHS.fvarDecls do
              tryPostpone
              throwMVarError m!"invalid match-expression, type of pattern variable '{d.toExpr}' contains metavariables{indentExpr d.type}"
        for p in altLHS.patterns do
          if p.hasExprMVar then
            withExistingLocalDecls altLHS.fvarDecls do
              tryPostpone
              throwMVarError m!"invalid match-expression, pattern contains metavariables{indentExpr (← p.toExpr)}"
        pure altLHS
    return (discrs, matchType, altLHSS, isDep, rhss)
  if let some r ← if isDep then pure none else isMatchUnit? altLHSS rhss then
    return r
  else
    let numDiscrs := discrs.size
    let matcherName ← mkAuxName `match
    let matcherResult ← mkMatcher { matcherName, matchType, discrInfos := discrs.map fun discr => { hName? := discr.h?.map (·.getId) }, lhss := altLHSS }
    matcherResult.addMatcher
    -- The motive abstracts the first `numDiscrs` binders of `matchType`.
    let motive ← forallBoundedTelescope matchType numDiscrs fun xs matchType => mkLambdaFVars xs matchType
    reportMatcherResultErrors altLHSS matcherResult
    -- Result: `matcher motive discrs... rhss...`
    let r := mkApp matcherResult.matcher motive
    let r := mkAppN r (discrs.map (·.expr))
    let r := mkAppN r rhss
    trace[Elab.match] "result: {r}"
    return r
-- leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
/-- The `matchDiscr` nodes of a `match` term: argument 3 of the node, with the `", "` separators dropped. -/
private def getDiscrs (matchStx : Syntax) : Array Syntax :=
  (matchStx.getArg 3).getSepArgs

/-- The optional `motive` node of a `match` term: argument 2 of the node (`isNone` when absent). -/
private def getMatchOptMotive (matchStx : Syntax) : Syntax :=
  matchStx.getArg 2
/--
Pull non-atomic discriminants out of a `match` into auxiliary `let`s.

Returns `none` when:
- an explicit motive is provided, or
- every discriminant term is atomic (`isAtomicDiscr?` succeeds).

Otherwise every discriminant that is not an atomic free variable is replaced by a fresh auxiliary
identifier `d` (created by `mkAuxDiscr`), and the whole `match` is wrapped in `let d := term; ...`.
-/
private def expandNonAtomicDiscrs? (matchStx : Syntax) : TermElabM (Option Syntax) :=
  let matchOptMotive := getMatchOptMotive matchStx
  if matchOptMotive.isNone then do
    let discrs := getDiscrs matchStx
    let allLocal ← discrs.allM fun discr => Option.isSome <$> isAtomicDiscr? discr[1]
    if allLocal then
      return none
    else
      -- We use `foundFVars` to make sure the discriminants are distinct variables.
      -- See: code for computing "matchType" at `elabMatchTypeAndDiscrs`
      -- NOTE(review): `foundFVars` is accumulated below but never consulted in this function, so repeated
      -- free-variable discriminants are kept as-is here. Confirm whether the distinctness check is still needed.
      let rec loop (discrs : List Syntax) (discrsNew : Array Syntax) (foundFVars : FVarIdSet) := do
        match discrs with
        | [] =>
          let discrs := Syntax.mkSep discrsNew (mkAtomFrom matchStx ", ")
          pure (matchStx.setArg 3 discrs)
        | discr :: discrs =>
          -- Recall that
          -- matchDiscr := leading_parser optional (ident >> ":") >> termParser
          let term := discr[1]
          -- Bind `term` to a fresh auxiliary discriminant and use that identifier in the `match` instead.
          let addAux : TermElabM Syntax := withFreshMacroScope do
            let d ← mkAuxDiscr
            unless isAuxDiscrName d.getId do -- Use assertion?
              throwError "unexpected internal auxiliary discriminant name"
            let discrNew := discr.setArg 1 d
            let r ← loop discrs (discrsNew.push discrNew) foundFVars
            `(let $d:ident := $term; $r)
          match (← isAtomicDiscr? term) with
          | some x => if x.isFVar then loop discrs (discrsNew.push discr) (foundFVars.insert x.fvarId!) else addAux
          | none => addAux
      return some (← loop discrs.toList #[] {})
  else
    -- We do not pull non atomic discriminants when match type is provided explicitly by the user
    return none
/--
Postpone elaboration if the expected type is missing or is a metavariable application.
If we continue, return the expected type, or a fresh type metavariable when there is none.
-/
private def waitExpectedType (expectedType? : Option Expr) : TermElabM Expr := do
  tryPostponeIfNoneOrMVar expectedType?
  if let some expectedType := expectedType? then
    return expectedType
  mkFreshTypeMVar
/--
When the `match` has no explicit motive, try to postpone if the type of some discriminant is a
metavariable application. Discriminants must already be atomic at this point.
-/
private def tryPostponeIfDiscrTypeIsMVar (matchStx : Syntax) : TermElabM Unit := do
  -- A user-provided motive fixes the match type, so we do not wait for the discriminant types.
  unless (getMatchOptMotive matchStx).isNone do
    return
  for discr in getDiscrs matchStx do
    let some d ← isAtomicDiscr? discr[1]
      | throwErrorAt discr "unexpected discriminant" -- see `expandNonAtomicDiscrs?`
    let dType ← inferType d
    trace[Elab.match] "discr {d} : {dType}"
    tryPostponeIfMVar dType
/-
  We (try to) elaborate a `match` only when the expected type is available.
  If the `matchType` has not been provided by the user, we also try to postpone elaboration when the
  type of a discriminant is not yet known, i.e., when it is of the form `(?m ...)`.

  `expandNonAtomicDiscrs?` makes sure all discriminants are local variables. This is a standard trick
  in the elaborator; it is also used to elaborate structure instances. For example,
  ```
  match g x with
  | ... => ...
  ```
  is converted into
  ```
  let _discr := g x
  match _discr with
  | ... => ...
  ```
  so `tryPostponeIfDiscrTypeIsMVar` only needs to check whether the type of `_discr` is of the form `(?m ...)`.
  Note that the auxiliary variable `_discr` is expanded at `elabAtomicDiscr`.

  This technique is needed to elaborate terms such as
  ```lean
  xs.filter fun (a, b) => a > b
  ```
  which are syntax sugar for
  ```lean
  List.filter (fun p => match p with | (a, b) => a > b) xs
  ```
  When we visit `match p with | (a, b) => a > b`, we don't know the type of `p` yet.
-/
private def waitExpectedTypeAndDiscrs (matchStx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
  tryPostponeIfNoneOrMVar expectedType?
  tryPostponeIfDiscrTypeIsMVar matchStx
  if let some expectedType := expectedType? then
    return expectedType
  mkFreshTypeMVar
/-
  Elaborate a `match` term whose syntax is
  ```
  leading_parser "match " >> optional generalizingParam >> optional motive >> sepBy1 matchDiscr ", " >> " with " >> ppDedent matchAlts
  ```
  after waiting for the expected type and the discriminant types (`waitExpectedTypeAndDiscrs`).
  Discriminants may still carry an `h :` annotation here: `elabMatch` rejects them only when an
  explicit motive is given.
-/
private def elabMatchCore (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
  let expectedType ← waitExpectedTypeAndDiscrs stx expectedType?
  let discrStxs := getDiscrs stx
  let gen? := getMatchGeneralizing? stx
  let altViews := getMatchAlts stx
  let matchOptMotive := getMatchOptMotive stx
  elabMatchAux gen? discrStxs altViews matchOptMotive expectedType
/--
Decide whether the identifier `stx` in `match e with | stx => rhs` binds a pattern variable.
- If `stx` resolves (via `resolveId?` with the `"pattern"` context) to a constructor, it does not.
- If it resolves to another known constant, it does unless that constant has the match-pattern
  attribute (`hasMatchPatternAttribute`).
- In all remaining cases, it does iff `stx` is an atomic identifier.
-/
private def isPatternVar (stx : Syntax) : TermElabM Bool := do
  let some (Expr.const fName _ _) ← resolveId? stx "pattern"
    | return isAtomicIdent stx
  let env ← getEnv
  match env.find? fName with
  | some (ConstantInfo.ctorInfo _) => return false
  | some _ => return !hasMatchPatternAttribute env fName
  | none => return isAtomicIdent stx
where
  /-- An identifier whose name, after erasing macro scopes, has a single component. -/
  isAtomicIdent (stx : Syntax) : Bool :=
    stx.isIdent && stx.getId.eraseMacroScopes.isAtomic
/--
Pattern matching. `match e, ... with | p, ... => f | ...` matches each given
term `e` against each pattern `p` of a match alternative. When all patterns
of an alternative match, the `match` term evaluates to the value of the
corresponding right-hand side `f` with the pattern variables bound to the
respective matched values.
If used as `match h : e, ... with | p, ... => f | ...`, `h : e = p` is available within `f`.

When not constructing a proof, `match` does not automatically substitute variables
matched on in dependent variables' types. Use `match (generalizing := true) ...` to
enforce this. -/
-- NOTE(review): `generalize` treats a missing `(generalizing := ...)` as `true`; confirm that the
-- docstring's claim about the non-proof default still matches the implementation.
@[builtinTermElab «match»] def elabMatch : TermElab := fun stx expectedType? => do
  match stx with
  -- `match e with | y => rhs` where `y` is a pattern variable is just `let y := e; rhs`.
  | `(match $discr:term with | $y:ident => $rhs:term) =>
    if (← isPatternVar y) then expandSimpleMatch stx discr y rhs expectedType? else elabMatchDefault stx expectedType?
  | _ => elabMatchDefault stx expectedType?
where
  /--
  General case: expand match-alternative macros (`expandMatchAlts?`), then pull out non-atomic
  discriminants (`expandNonAtomicDiscrs?`); once neither applies, reject `h :` annotations combined
  with an explicit motive and elaborate with `elabMatchCore`.
  -/
  elabMatchDefault (stx : Syntax) (expectedType? : Option Expr) : TermElabM Expr := do
    match (← liftMacroM <| expandMatchAlts? stx) with
    | some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
    | none =>
      match (← expandNonAtomicDiscrs? stx) with
      | some stxNew => withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
      | none =>
        let discrs := getDiscrs stx
        let matchOptMotive := getMatchOptMotive stx
        if !matchOptMotive.isNone && discrs.any fun d => !d[0].isNone then
          throwErrorAt matchOptMotive "match motive should not be provided when discriminants with equality proofs are used"
        elabMatchCore stx expectedType?
-- Register the `Elab.match` trace class used by the `trace[Elab.match]` messages in this file.
builtin_initialize
  registerTraceClass `Elab.match
-- leading_parser:leadPrec "nomatch " >> termParser
/-- Empty match/ex falso. `nomatch e` is of arbitrary type `α : Sort u` if
Lean can show that an empty set of patterns is exhaustive given `e`'s type,
e.g. because it has no constructors. -/
@[builtinTermElab «nomatch»] def elabNoMatch : TermElab := fun stx expectedType? => do
  match stx with
  | `(nomatch $discrExpr) =>
    match (← isLocalIdent? discrExpr) with
    | some _ =>
      -- Local discriminant: elaborate as a `match` with this single discriminant (empty `h :` part),
      -- no alternatives, and no motive.
      let expectedType ← waitExpectedType expectedType?
      let discr := mkNode ``Lean.Parser.Term.matchDiscr #[mkNullNode, discrExpr]
      elabMatchAux none #[discr] #[] mkNullNode expectedType
    | _ =>
      -- Otherwise bind the discriminant to an auxiliary local and elaborate `nomatch` on that local.
      let d ← mkAuxDiscr
      let stxNew ← `(let $d:ident := $discrExpr; nomatch $d:ident)
      withMacroExpansion stx stxNew <| elabTerm stxNew expectedType?
  | _ => throwUnsupportedSyntax
end Lean.Elab.Term