lean4-htt/src/Lean/Meta/Match/Match.lean
Joachim Breitner e2f5938e74
refactor: some Meta.Match.Match refactorings (#11011)
This PR extracts some refactorings from #10763, including dropping dead
code and not failing in `inaccessibleAsCtor`, which leads to (slightly)
better error messages, and also on the grounds that the failing
alternative may actually be unreachable.
2025-10-29 23:24:57 +00:00

1086 lines
46 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Closure
public import Lean.Meta.Tactic.Contradiction
public import Lean.Meta.GeneralizeTelescope
public import Lean.Meta.Match.Basic
public import Lean.Meta.Match.MatcherApp.Basic
public import Lean.Meta.Match.MVarRenaming
public section
namespace Lean.Meta.Match
/--
Builds the message reporting that a match alternative has `actual` patterns where `expected`
were required. `discrepancyKind` is the leading phrase (e.g. "Too many"), and `pats` are the
offending patterns, rendered as a comma-separated, indented list.
-/
private def mkIncorrectNumberOfPatternsMsg [ToMessageData α]
    (discrepancyKind : String) (expected actual : Nat) (pats : List α) :=
  let rendered := pats.map toMessageData
  let listing := MessageData.joinSep rendered ", "
  m!"{discrepancyKind} patterns in match alternative: Expected {expected}, \
    but found {actual}:{indentD listing}"
/--
Aborts with an error, positioned at `ref`, stating that the alternative has the wrong number of
patterns (`actual` instead of `expected`).

The element type `α` of `pats` is left generic because this error can be raised either before or
after the pattern syntax has been elaborated.
-/
def throwIncorrectNumberOfPatternsAt [ToMessageData α]
    (ref : Syntax) (discrepancyKind : String) (expected actual : Nat) (pats : List α)
    : MetaM Unit := do
  let msg := mkIncorrectNumberOfPatternsMsg discrepancyKind expected actual pats
  throwErrorAt ref msg
/--
Records (without aborting) an error, positioned at `ref`, stating that the alternative has the
wrong number of patterns (`actual` instead of `expected`).

The element type `α` of `pats` is left generic because this error can be reported either before or
after the pattern syntax has been elaborated.
-/
def logIncorrectNumberOfPatternsAt [ToMessageData α]
    (ref : Syntax) (discrepancyKind : String) (expected actual : Nat) (pats : List α)
    : MetaM Unit :=
  logErrorAt ref <| mkIncorrectNumberOfPatternsMsg discrepancyKind expected actual pats
/--
Checks that every `AltLHS` in `lhss` has exactly `numDiscrs` patterns, throwing an error at the
first alternative that has too few or too many.
-/
private def checkNumPatterns (numDiscrs : Nat) (lhss : List AltLHS) : MetaM Unit := do
  for lhs in lhss do
    let numPats := lhs.patterns.length
    -- The patterns mention the alternative's local declarations, so bring them into scope
    -- before rendering the error message.
    let report (kind : String) : MetaM Unit :=
      withExistingLocalDecls lhs.fvarDecls do
        throwIncorrectNumberOfPatternsAt lhs.ref kind numDiscrs numPats
          (lhs.patterns.map Pattern.toMessageData)
    if numPats < numDiscrs then
      report "Not enough"
    else if numPats > numDiscrs then
      -- Expected to be impossible: an alternative with too many patterns should already have been
      -- rejected with an error in `Lean.Elab.Term.elabPatterns`.
      report "Too many"
/--
Runs `k hs`, where `hs` collects one fresh hypothesis `h : lhs[i] = rhs[i]` (built with `mkEqHEq`)
for every index `i` such that `discrInfos[i].hName? = some h`. Indices without a name are skipped.

Assumes `lhs.size == rhs.size == discrInfos.size`.
-/
private partial def withEqs (lhs rhs : Array Expr) (discrInfos : Array DiscrInfo) (k : Array Expr → MetaM α) : MetaM α :=
  go 0 #[]
where
  go (idx : Nat) (eqs : Array Expr) : MetaM α := do
    unless idx < lhs.size do
      -- All positions visited: hand the collected hypotheses to the continuation.
      return (← k eqs)
    match discrInfos[idx]!.hName? with
    | none => go (idx+1) eqs
    | some hName =>
      let eqType ← mkEqHEq lhs[idx]! rhs[idx]!
      withLocalDeclD hName eqType fun h =>
        go (idx+1) (eqs.push h)
/--
Given a list of `AltLHS`, create a minor premise for each one, convert them into `Alt`, and then execute `k`.

`k` receives the alternatives in the same order as `lhss`, plus one `(minor, numParams)` pair per
alternative, where `minor` is the local hypothesis (named `h_1`, `h_2`, ...) introduced for it.
-/
private def withAlts {α} (motive : Expr) (discrs : Array Expr) (discrInfos : Array DiscrInfo) (lhss : List AltLHS) (k : List Alt → Array (Expr × Nat) → MetaM α) : MetaM α :=
  loop lhss [] #[]
where
  -- Type of the minor premise for `lhs`: the motive applied to the alternative's patterns,
  -- abstracted over the pattern variables `xs` and the discriminant equalities from `withEqs`.
  mkMinorType (xs : Array Expr) (lhs : AltLHS) : MetaM Expr :=
    withExistingLocalDecls lhs.fvarDecls do
      let args ← lhs.patterns.toArray.mapM (Pattern.toExpr · (annotate := true))
      let minorType := mkAppN motive args
      withEqs discrs args discrInfos fun eqs => do
        mkForallFVars (xs ++ eqs) minorType

  -- `alts` is accumulated in reverse order and reversed once at the end.
  loop (lhss : List AltLHS) (alts : List Alt) (minors : Array (Expr × Nat)) : MetaM α := do
    match lhss with
    | [] => k alts.reverse minors
    | lhs::lhss =>
      let xs := lhs.fvarDecls.toArray.map LocalDecl.toExpr
      let minorType ← mkMinorType xs lhs
      -- A minor premise without pattern variables and without named equalities is turned into a
      -- thunk (`mkSimpleThunkType`) and later applied to `Unit.unit`.
      let hasParams := !xs.isEmpty || discrInfos.any fun info => info.hName?.isSome
      let (minorType, minorNumParams) := if hasParams then (minorType, xs.size) else (mkSimpleThunkType minorType, 1)
      let idx := alts.length
      let minorName := (`h).appendIndexAfter (idx+1)
      trace[Meta.Match.debug] "minor premise {minorName} : {minorType}"
      withLocalDeclD minorName minorType fun minor => do
        let rhs := if hasParams then mkAppN minor xs else mkApp minor (mkConst `Unit.unit)
        let minors := minors.push (minor, minorNumParams)
        let fvarDecls ← lhs.fvarDecls.mapM instantiateLocalDeclMVars
        let alts := { ref := lhs.ref, idx := idx, rhs := rhs, fvarDecls := fvarDecls, patterns := lhs.patterns, cnstrs := [] } :: alts
        loop lhss alts minors
/-- State threaded through the match compilation (used via `StateRefT State MetaM`). -/
structure State where
  /-- Indices (`Alt.idx`) of the alternatives that were used to close a leaf. -/
  used            : Std.HashSet Nat := {} -- used alternatives
  /-- One list of examples for each leaf that had no alternative left and could not be closed by contradiction. -/
  counterExamples : List (List Example) := []
/-- A (sub-)problem is solved once it has no variables left to match on. -/
private def isDone (p : Problem) : Bool :=
  match p.vars with
  | [] => true
  | _ :: _ => false
/-- Returns `true` iff the head of `p.vars` exists and is a free variable. -/
private def isNextVar (p : Problem) : Bool :=
  if let .fvar _ :: _ := p.vars then true else false
/-- Returns `true` iff some alternative of `p` starts with an as-pattern (`x@pat`). -/
private def hasAsPattern (p : Problem) : Bool :=
  p.alts.any fun alt =>
    if let .as .. :: _ := alt.patterns then true else false
/-- Returns `true` iff some alternative of `p` starts with a constructor pattern. -/
private def hasCtorPattern (p : Problem) : Bool :=
  p.alts.any fun alt =>
    if let .ctor .. :: _ := alt.patterns then true else false
/-- Returns `true` iff some alternative of `p` starts with a value pattern. -/
private def hasValPattern (p : Problem) : Bool :=
  p.alts.any fun alt =>
    if let .val _ :: _ := alt.patterns then true else false
/--
Returns `true` iff some alternative of `p` starts with a value pattern that `getNatValue?`
recognizes.
-/
private def hasNatValPattern (p : Problem) : MetaM Bool :=
  p.alts.anyM fun alt => do
    let .val v :: _ := alt.patterns | return false
    let nat? ← getNatValue? v
    return nat?.isSome
/--
Returns `true` iff some alternative of `p` starts with a value pattern that `getIntValue?`
recognizes.
-/
private def hasIntValPattern (p : Problem) : MetaM Bool :=
  p.alts.anyM fun alt => do
    let .val v :: _ := alt.patterns | return false
    let int? ← getIntValue? v
    return int?.isSome
/-- Returns `true` iff some alternative of `p` starts with a variable pattern. -/
private def hasVarPattern (p : Problem) : Bool :=
  p.alts.any fun alt =>
    if let .var _ :: _ := alt.patterns then true else false
/-- Returns `true` iff some alternative of `p` starts with an array-literal pattern. -/
private def hasArrayLitPattern (p : Problem) : Bool :=
  p.alts.any fun alt =>
    if let .arrayLit .. :: _ := alt.patterns then true else false
/--
The variable transition applies when every alternative starts with either an inaccessible pattern
or a variable pattern.
-/
private def isVariableTransition (p : Problem) : Bool :=
  p.alts.all fun alt =>
    match alt.patterns with
    | .inaccessible _ :: _ | .var _ :: _ => true
    | _ => false
/--
If the type of `x`, after `whnfD`, is an application of a constant that names an inductive type,
returns that type's `InductiveVal`; otherwise returns `none`.
-/
private def getInductiveVal? (x : Expr) : MetaM (Option InductiveVal) := do
  let xType ← whnfD (← inferType x)
  let .const constName _ := xType.getAppFn | return none
  let .inductInfo val ← getConstInfo constName | return none
  return some val
/--
Returns `true` iff `p` has a next variable and its type (inspected in the context of `p`'s goal)
is an inductive type according to `getInductiveVal?`.
-/
def isCurrVarInductive (p : Problem) : MetaM Bool := do
  let x :: _ := p.vars | return false
  withGoalOf p do
    return (← getInductiveVal? x).isSome
/--
The constructor transition applies when
- the current variable has an inductive type,
- there is at least one constructor pattern, or there are no alternatives at all, and
- every alternative starts with a constructor, variable, or inaccessible pattern.
-/
private def isConstructorTransition (p : Problem) : MetaM Bool := do
  unless (← isCurrVarInductive p) do
    return false
  unless hasCtorPattern p || p.alts.isEmpty do
    return false
  return p.alts.all fun alt =>
    match alt.patterns with
    | .ctor .. :: _ | .var _ :: _ | .inaccessible _ :: _ => true
    | _ => false
/--
The value transition applies when there is at least one variable pattern, at least one value
pattern, and every alternative starts with a value or a variable pattern.
-/
private def isValueTransition (p : Problem) : Bool :=
  let onlyValOrVar := p.alts.all fun alt =>
    match alt.patterns with
    | .val _ :: _ | .var _ :: _ => true
    | _ => false
  hasVarPattern p && hasValPattern p && onlyValOrVar
/--
Shared test for value-to-constructor expansions: succeeds when no alternative starts with a
variable pattern, at least one starts with a value pattern, and every alternative starts with
either a constructor pattern or a value pattern accepted by `isValue`.
-/
private def isValueOnlyTransitionCore (p : Problem) (isValue : Expr → MetaM Bool) : MetaM Bool := do
  if hasVarPattern p || !hasValPattern p then
    return false
  p.alts.allM fun alt => do
    match alt.patterns with
    | .ctor .. :: _ => return true
    | .val v :: _ => isValue v
    | _ => return false
/-- `isValueOnlyTransitionCore` specialized to values recognized by `getFinValue?`. -/
private def isFinValueTransition (p : Problem) : MetaM Bool :=
  isValueOnlyTransitionCore p fun e => do
    let fin? ← getFinValue? e
    return fin?.isSome
/-- `isValueOnlyTransitionCore` specialized to values recognized by `getBitVecValue?`. -/
private def isBitVecValueTransition (p : Problem) : MetaM Bool :=
  isValueOnlyTransitionCore p fun e => do
    let bv? ← getBitVecValue? e
    return bv?.isSome
/--
The array-literal transition applies when there is at least one array-literal pattern, at least
one variable pattern, and every alternative starts with one of those two kinds.
-/
private def isArrayLitTransition (p : Problem) : Bool :=
  let onlyLitOrVar := p.alts.all fun alt =>
    match alt.patterns with
    | .arrayLit .. :: _ | .var _ :: _ => true
    | _ => false
  hasArrayLitPattern p && hasVarPattern p && onlyLitOrVar
/--
Returns `true` if the next variable of `p` is not a free variable, or if some alternative starts
with a constructor or inaccessible pattern.
-/
private def hasCtorOrInaccessible (p : Problem) : Bool :=
  if isNextVar p then
    p.alts.any fun alt =>
      match alt.patterns with
      | .ctor .. :: _ | .inaccessible _ :: _ => true
      | _ => false
  else
    true
/--
`Nat` value patterns are expanded into constructors when one is present and
`hasCtorOrInaccessible p` holds.
-/
private def isNatValueTransition (p : Problem) : MetaM Bool := do
  unless (← hasNatValPattern p) do
    return false
  return hasCtorOrInaccessible p
/--
Decides whether `Int` value patterns must be expanded into constructor patterns.
This requires at least one `Int` value pattern, and then one of:
- `hasCtorOrInaccessible p` holds, i.e. value patterns are mixed with constructor or
  inaccessible patterns (or the next variable is not a free variable);
- there is no variable pattern, i.e. no `else`-case. This can happen when the non-value
  cases are unreachable.
-/
private def isIntValueTransition (p : Problem) : MetaM Bool := do
  if (← hasIntValPattern p) then
    return hasCtorOrInaccessible p || !hasVarPattern p
  else
    return false
/--
Drops the next variable `x` of `p` together with the leading pattern of every alternative,
recording the constraint `(x, e)` for each leading inaccessible pattern `e`.

Precondition: `p.vars` is non-empty and every alternative starts with an inaccessible pattern
(otherwise `unreachable!` is hit).
-/
private def processSkipInaccessible (p : Problem) : Problem := Id.run do
  let x :: xs := p.vars | unreachable!
  let alts := p.alts.map fun alt => Id.run do
    let .inaccessible e :: ps := alt.patterns | unreachable!
    { alt with patterns := ps, cnstrs := (x, e) :: alt.cnstrs }
  { p with alts := alts, vars := xs }
/--
Reorients constraints so that free variables end up on the left-hand side.
A constraint `e ≋ x`, where `x` is a free variable, is flipped to `x ≋ e` if either
- `x` is an `alt`-local declaration (according to `alt.isLocalDecl`), or
- `e` is not a free variable.
All other constraints are kept as they are.
-/
private def reorientCnstrs (alt : Alt) : Alt :=
  let shouldFlip (lhs rhs : Expr) : Bool :=
    -- `rhs.fvarId!` is only evaluated once `rhs.isFVar` has been established.
    rhs.isFVar && (alt.isLocalDecl rhs.fvarId! || !lhs.isFVar)
  let cnstrs := alt.cnstrs.map fun (lhs, rhs) =>
    if shouldFlip lhs rhs then (rhs, lhs) else (lhs, rhs)
  { alt with cnstrs }
/--
Drops every constraint `lhs ≋ rhs` whose sides are definitionally equal, as well as every
constraint whose `lhs` is a free variable.

Discarding an unsolved constraint with a free variable on the left looks unsound, but it only
leads to a later error about the type of the alternative not matching the goal type. That is
arguably friendlier than an error mentioning variable names that may be internal to match
compilation.
-/
private def filterTrivialCnstrs (alt : Alt) : MetaM Alt := do
  let cnstrs ← withExistingLocalDecls alt.fvarDecls do
    alt.cnstrs.filterM fun (lhs, rhs) => do
      -- The definitional-equality test runs first, exactly as before.
      if (← isDefEqGuarded lhs rhs) then
        return false
      return !lhs.isFVar
  return { alt with cnstrs }
/--
Find an alternative constraint of the form `x ≋ e` where `x` is an alternative
local declaration, `e` does not depend on `x`, and `x` and `e` have definitionally equal types.
Then, remove that constraint, replace `x` with `e` in the alternative, and return it.
Return `none` if the alternative does not contain a constraint of this form.
-/
private def solveSomeLocalFVarIdCnstr? (alt : Alt) : MetaM (Option Alt) :=
  withExistingLocalDecls alt.fvarDecls do
    let (some (fvarId, val), cnstrs) ← go alt.cnstrs | return none
    trace[Meta.Match.match] "found cnstr to solve {mkFVar fvarId} ↦ {val}"
    return some <| { alt with cnstrs }.replaceFVarId fvarId val
where
  -- Returns the first solvable constraint (if any) together with the remaining constraints,
  -- in their original order.
  go (cnstrs : List (Expr × Expr)) := do
    match cnstrs with
    | [] => return (none, [])
    | (lhs, rhs) :: cnstrs =>
      if lhs.isFVar && alt.isLocalDecl lhs.fvarId! then
        if !(← dependsOn rhs lhs.fvarId!) && (← isDefEqGuarded (← inferType lhs) (← inferType rhs)) then
          return (some (lhs.fvarId!, rhs), cnstrs)
      -- Not solvable: keep this constraint and continue with the tail.
      let (p, cnstrs) ← go cnstrs
      return (p, (lhs, rhs) :: cnstrs)
/--
Solve the pending constraints of `alt`. If all constraints can be solved and the type of `alt.rhs`
matches the goal type, perform the assignment `mvarId := alt.rhs` and mark `alt.idx` as used.
Otherwise, throw an error at `alt.ref`.
-/
private partial def solveCnstrs (mvarId : MVarId) (alt : Alt) : StateRefT State MetaM Unit := do
  go (reorientCnstrs alt)
where
  go (alt : Alt) : StateRefT State MetaM Unit := do
    -- Repeatedly eliminate constraints of the form `x ≋ e` by substitution.
    match (← solveSomeLocalFVarIdCnstr? alt) with
    | some alt => go alt
    | none =>
      let alt ← filterTrivialCnstrs alt
      if alt.cnstrs.isEmpty then
        let eType ← inferType alt.rhs
        let targetType ← mvarId.getType
        unless (← isDefEqGuarded targetType eType) do
          trace[Meta.Match.match] "assignGoalOf failed {eType} =?= {targetType}"
          throwErrorAt alt.ref "Dependent elimination failed: Type mismatch when solving this alternative: it {← mkHasTypeButIsExpectedMsg eType targetType}"
        mvarId.assign alt.rhs
        modify fun s => { s with used := s.used.insert alt.idx }
      else
        -- Some constraints remain: report all of them.
        trace[Meta.Match.match] "alt has unsolved cnstrs:\n{← alt.toMessageData}"
        let mut msg := m!"Dependent match elimination failed: Could not solve constraints"
        for (lhs, rhs) in alt.cnstrs do
          msg := msg ++ m!"\n {lhs} ≋ {rhs}"
        throwErrorAt alt.ref msg
/--
Close a leaf of the problem (no variables left).
- If there is at least one alternative, the first one is used: its pending constraints are solved
  with `solveCnstrs`, which throws if they cannot be resolved.
- If there are no alternatives, try to close the goal by contradiction; if that fails, admit the
  goal and record `p.examples` as a counterexample.
-/
private def processLeaf (p : Problem) : StateRefT State MetaM Unit :=
  p.mvarId.withContext do
    trace[Meta.Match.match] "local context at processLeaf:\n{(← mkFreshTypeMVar).mvarId!}"
    go p.alts
where
  go (alts : List Alt) : StateRefT State MetaM Unit := do
    match alts with
    | [] =>
      /- TODO: allow users to configure which tactic is used to close leaves. -/
      unless (← p.mvarId.contradictionCore {}) do
        trace[Meta.Match.match] "missing alternative"
        p.mvarId.admit
        modify fun s => { s with counterExamples := p.examples :: s.counterExamples }
    | alt :: _ =>
      solveCnstrs p.mvarId alt
/--
Eliminates leading as-patterns. For every alternative starting with `fvarId@p` (with equality
hypothesis `h`), the pattern is replaced by `p`, `fvarId` is replaced by the current variable `x`,
and `h` is replaced by `mkEqRefl x`. Other alternatives and `p.vars` are left unchanged.
-/
private def processAsPattern (p : Problem) : MetaM Problem := withGoalOf p do
  let x :: _ := p.vars | unreachable!
  let alts ← p.alts.mapM fun alt => do
    match alt.patterns with
    | .as fvarId p h :: ps =>
      /- We used to eagerly check the types here (using what was called `checkAndReplaceFVarId`),
         but `x` and `fvarId` can have different types when dependent types are being used.
         Let's consider the repro for issue #471
         ```
         inductive vec : Nat → Type
         | nil : vec 0
         | cons : Int → vec n → vec n.succ
         def vec_len : vec n → Nat
         | vec.nil => 0
         | x@(vec.cons h t) => vec_len t + 1
         ```
         we reach the state
         ```
         [Meta.Match.match] remaining variables: [x✝:(vec n✝)]
         alternatives:
         [n:(Nat), x:(vec (Nat.succ n)), h:(Int), t:(vec n)] |- [x@(vec.cons n h t)] => h_1 n x h t
         [x✝:(vec n✝)] |- [x✝] => h_2 n✝ x✝
         ```
         The variables `x✝:(vec n✝)` and `x:(vec (Nat.succ n))` have different types, but we perform the substitution anyway,
         because we claim the "discrepancy" will be corrected after we process the pattern `(vec.cons n h t)`.
         The right-hand-side is temporarily type incorrect, but we claim this is fine because it will be type correct again after
         we process the pattern `(vec.cons n h t)`. TODO: try to find a cleaner solution.
      -/
      let r ← mkEqRefl x
      return { alt with patterns := p :: ps }.replaceFVarId fvarId x |>.replaceFVarId h r
    | _ => return alt
  return { p with alts := alts }
/--
Variable transition: consumes the next variable `x` and the leading pattern of every alternative.
- An inaccessible pattern `e` is turned into the pending constraint `(x, e)`.
- A variable pattern `fvarId` is replaced by `x` when their types are definitionally equal;
  otherwise the pending constraint `(fvarId, x)` is recorded instead.

Precondition: every alternative starts with an inaccessible or a variable pattern.
-/
private def processVariable (p : Problem) : MetaM Problem := withGoalOf p do
  let x :: xs := p.vars | unreachable!
  let alts ← p.alts.mapM fun alt => do
    match alt.patterns with
    | .inaccessible e :: ps => return { alt with patterns := ps, cnstrs := (x, e) :: alt.cnstrs }
    | .var fvarId :: ps =>
      withExistingLocalDecls alt.fvarDecls do
        if (← isDefEqGuarded (← fvarId.getType) (← inferType x)) then
          return { alt with patterns := ps }.replaceFVarId fvarId x
        else
          return { alt with patterns := ps, cnstrs := (mkFVar fvarId, x) :: alt.cnstrs }
    | _ => unreachable!
  return { p with alts := alts, vars := xs }
/-!
Note that we decided to store pending constraints to address issues exposed by #1279 and #1361.
Here is a simplified version of the example on this issue (see test: `1279_simplified.lean`)
```lean
inductive Arrow : Type → Type → Type 1
| id : Arrow a a
| unit : Arrow Unit Unit
| comp : Arrow β γ → Arrow α β → Arrow α γ
deriving Repr
def Arrow.compose (f : Arrow β γ) (g : Arrow α β) : Arrow α γ :=
match f, g with
| id, g => g
| f, id => f
| f, g => comp f g
```
The initial state for the `match`-expression above is
```lean
[Meta.Match.match] remaining variables: [β✝:(Type), γ✝:(Type), f✝:(Arrow β✝ γ✝), g✝:(Arrow α β✝)]
alternatives:
[β:(Type), g:(Arrow α β)] |- [β, .(β), (Arrow.id .(β)), g] => h_1 β g
[γ:(Type), f:(Arrow α γ)] |- [.(α), γ, f, (Arrow.id .(α))] => h_2 γ f
[β:(Type), γ:(Type), f:(Arrow β γ), g:(Arrow α β)] |- [β, γ, f, g] => h_3 β γ f g
```
The first step is a variable-transition which replaces `β` with `β✝` in the first and third alternatives.
The constraint `β✝ ≋ α` in the second alternative used to be discarded. We now store it at the
alternative `cnstrs` field.
-/
/-- Returns `true` iff some declaration in `localDecls` has the free-variable id `fvarId`. -/
private def inLocalDecls (localDecls : List LocalDecl) (fvarId : FVarId) : Bool :=
  localDecls.any (·.fvarId == fvarId)
/--
Given `alt` whose next pattern is a variable `fvarId`, replaces that variable by an application of
the constructor `ctorName` to fresh field variables.

The universe levels and parameters of the constructor are taken from the (weak-head normalized)
type of `fvarId`. The fresh fields are added to the alternative's local declarations and are
prepended, as variable patterns, to the remaining patterns. If the constructor's result type is
not definitionally equal to the type of `fvarId`, the constraint `(resultType, expectedType)` is
recorded.
-/
private def expandVarIntoCtor (alt : Alt) (ctorName : Name) : MetaM Alt := do
  let .var fvarId :: ps := alt.patterns | unreachable!
  let alt := { alt with patterns := ps}
  withExistingLocalDecls alt.fvarDecls do
    trace[Meta.Match.unify] "expandVarIntoCtor fvarId: {mkFVar fvarId}, ctorName: {ctorName}, alt:\n{← alt.toMessageData}"
    let expectedType ← inferType (mkFVar fvarId)
    let expectedType ← whnfD expectedType
    let (ctorLevels, ctorParams) ← getInductiveUniverseAndParams expectedType
    let ctor := mkAppN (mkConst ctorName ctorLevels) ctorParams
    let ctorType ← inferType ctor
    -- Introduce one local variable per constructor field.
    forallTelescopeReducing ctorType fun ctorFields resultType => do
      let ctor := mkAppN ctor ctorFields
      let alt := alt.replaceFVarId fvarId ctor
      let ctorFieldDecls ← ctorFields.mapM fun ctorField => ctorField.fvarId!.getDecl
      let newAltDecls := ctorFieldDecls.toList ++ alt.fvarDecls
      let mut cnstrs := alt.cnstrs
      unless (← isDefEqGuarded resultType expectedType) do
        cnstrs := (resultType, expectedType) :: cnstrs
      trace[Meta.Match.unify] "expandVarIntoCtor {mkFVar fvarId} : {expectedType}, ctor: {ctor}"
      let ctorFieldPatterns := ctorFieldDecls.toList.map fun decl => Pattern.var decl.fvarId
      return { alt with fvarDecls := newAltDecls, patterns := ctorFieldPatterns ++ alt.patterns, cnstrs }
/--
Given `alt` whose next pattern is an inaccessible pattern `e`, replaces it by a variable pattern
for a fresh local `x` (of the type of `e`), adds `x` to the alternative's local declarations, and
records the constraint `(x, e)`.
-/
private def expandInaccessibleIntoVar (alt : Alt) : MetaM Alt := do
  let .inaccessible e :: ps := alt.patterns | unreachable!
  withExistingLocalDecls alt.fvarDecls do
    let type ← inferType e
    withLocalDeclD `x type fun x => do
      trace[Meta.Match.unify] "expandInaccessibleIntoVar {x} : {type} := {e}"
      return { alt with
        fvarDecls := (← x.fvarId!.getDecl) :: alt.fvarDecls
        patterns  := .var x.fvarId! :: ps
        cnstrs    := (x, e) :: alt.cnstrs
      }
/--
Returns `true` iff the type of `x` is an inductive type (see `getInductiveVal?`) whose `isRec`
flag is set.
-/
private def hasRecursiveType (x : Expr) : MetaM Bool := do
  let some val ← getInductiveVal? x | return false
  return val.isRec
/-- Given `alt` s.t. the next pattern is an inaccessible pattern `e`,
try to normalize `e` into a constructor application.
If it is a constructor application of `ctorName`,
update the next patterns with the fields of the constructor.
Otherwise, turn `e` into a fresh variable constrained to equal `e` (`expandInaccessibleIntoVar`)
and expand that variable into `ctorName` (`expandVarIntoCtor`). The constraint is kept in the
alternative's constraints, so that we fail unless some later step eliminates this alternative.
-/
def processInaccessibleAsCtor (alt : Alt) (ctorName : Name) : MetaM Alt := do
  let .inaccessible e :: ps := alt.patterns | unreachable!
  trace[Meta.Match.match] "inaccessible step {e} as ctor {ctorName}"
  withExistingLocalDecls alt.fvarDecls do
    -- Try to push inaccessible annotations.
    let e ← whnfD e
    if let some (ctorVal, ctorArgs) ← constructorApp? e then
      if ctorVal.name == ctorName then
        -- Skip the constructor's parameters; only the fields become patterns.
        let fields := ctorArgs.extract ctorVal.numParams ctorArgs.size
        let fields := fields.toList.map .inaccessible
        return { alt with patterns := fields ++ ps }
    -- Not (visibly) an application of `ctorName`: fall back to a constrained variable.
    let alt' ← expandInaccessibleIntoVar alt
    expandVarIntoCtor alt' ctorName
/-- Returns `true` iff some example of `p` is something other than `Example.underscore`. -/
private def hasNonTrivialExample (p : Problem) : Bool :=
  p.examples.any fun ex =>
    match ex with
    | .underscore => false
    | _ => true
/--
Rethrows `ex`. If it is an `.error`, its message is extended with the examples processed so far
(only when there is a non-trivial one) and with a summary of the kinds of equations the dependent
pattern matcher can solve. Any other exception is rethrown unchanged.
-/
private def throwCasesException (p : Problem) (ex : Exception) : MetaM α := do
  match ex with
  | .error ref msg =>
    let exampleMsg :=
      if hasNonTrivialExample p then m!" after processing{indentD <| examplesToMessageData p.examples}" else ""
    throw <| Exception.error ref <| msg.composePreservingKind <| m!"{exampleMsg}\n" ++
      "the dependent pattern matcher can solve the following kinds of equations\n" ++
      "- <var> = <term> and <term> = <var>\n" ++
      "- <term> = <term> where the terms are definitionally equal\n" ++
      "- <constructor> = <constructor>, examples: List.cons x xs = List.cons y ys, and List.cons x xs = List.nil"
  | _ => throw ex
/--
Constructor transition: performs `cases` on the next variable `x` and returns one sub-problem per
resulting subgoal.

For each subgoal, the constructor fields become the new leading variables, the examples are
refined with the subgoal's constructor, and the alternatives are restricted to those that are
compatible with it (same constructor, variable, or inaccessible pattern), each rewritten so that
its leading patterns correspond to the constructor fields.

If the `cases` step is abandoned (see the comments below), a single problem that simply drops `x`
is returned.
-/
private def processConstructor (p : Problem) : MetaM (Array Problem) := do
  trace[Meta.Match.match] "constructor step"
  let x :: xs := p.vars | unreachable!
  let subgoals? ← commitWhenSome? do
    let subgoals ←
      try
        p.mvarId.cases x.fvarId!
      catch ex =>
        if p.alts.isEmpty then
          /- If we have no alternatives and dependent pattern matching fails, then a "missing cases" error is better than a "stuck" error message. -/
          return none
        else
          throwCasesException p ex
    if subgoals.isEmpty then
      /- Easy case: we have solved problem `p` since there are no subgoals -/
      return some #[]
    else if !p.alts.isEmpty then
      return some subgoals
    else do
      let isRec ← withGoalOf p <| hasRecursiveType x
      /- If there are no alternatives and the type of the current variable is recursive, we do NOT consider
         a constructor-transition to avoid nontermination.
         TODO: implement a more general approach if this is not sufficient in practice -/
      if isRec then
        return none
      else
        return some subgoals
  let some subgoals := subgoals? | return #[{ p with vars := xs }]
  subgoals.mapM fun subgoal => subgoal.mvarId.withContext do
    let subst := subgoal.subst
    let fields := subgoal.fields.toList
    let newVars := fields ++ xs
    let newVars := newVars.map fun x => x.applyFVarSubst subst
    let subex := Example.ctor subgoal.ctorName <| fields.map fun field => match field with
      | .fvar fvarId => Example.var fvarId
      | _ => Example.underscore -- This case can happen due to dependent elimination
    let examples := p.examples.map <| Example.replaceFVarId x.fvarId! subex
    let examples := examples.map <| Example.applyFVarSubst subst
    -- Keep only the alternatives that can match this subgoal's constructor.
    let newAlts := p.alts.filter fun alt => match alt.patterns with
      | .ctor n .. :: _ => n == subgoal.ctorName
      | .var _ :: _ => true
      | .inaccessible _ :: _ => true
      | _ => false
    let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
    -- Replace the leading pattern of each alternative by patterns for the constructor fields.
    let newAlts ← newAlts.mapM fun alt => do
      match alt.patterns with
      | .ctor _ _ _ fields :: ps => return { alt with patterns := fields ++ ps }
      | .var _ :: _ => expandVarIntoCtor alt subgoal.ctorName
      | .inaccessible _ :: _ => processInaccessibleAsCtor alt subgoal.ctorName
      | _ => unreachable!
    return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
/--
Handles the case where the next "variable" `x` of `p` is not a free variable.

- If `x` is a constructor application (at default transparency) and some alternative starts with
  a constructor pattern, `x` is replaced by its constructor fields. Alternatives with a different
  constructor are dropped, and the remaining ones are rewritten to match on the fields.
- Otherwise, `x` and the leading pattern `pat` of every alternative are consumed, and the
  constraint `(x, pat.toExpr)` is recorded.
-/
private def processNonVariable (p : Problem) : MetaM Problem := withGoalOf p do
  let x :: xs := p.vars | unreachable!
  if let some (ctorVal, xArgs) ← withTransparency .default <| constructorApp'? x then
    if hasCtorPattern p then
      let alts ← p.alts.filterMapM fun alt => do
        match alt.patterns with
        | .ctor ctorName _ _ fields :: ps =>
          if ctorName != ctorVal.name then
            return none
          else
            return some { alt with patterns := fields ++ ps }
        | .inaccessible _ :: _ => processInaccessibleAsCtor alt ctorVal.name
        | .var _ :: _ => expandVarIntoCtor alt ctorVal.name
        | _ => unreachable!
      -- Skip the constructor's parameters; only the fields become new variables.
      let xFields := xArgs.extract ctorVal.numParams xArgs.size
      return { p with alts := alts, vars := xFields.toList ++ xs }
  -- Fallback: defer the match on `x` to a pending constraint per alternative.
  let alts ← p.alts.mapM fun alt => do
    match alt.patterns with
    | p :: ps => return { alt with patterns := ps, cnstrs := (x, ← p.toExpr) :: alt.cnstrs }
    | _ => unreachable!
  return { p with alts := alts, vars := xs }
/--
Collects, without duplicates and in order of first occurrence, the values of all leading value
patterns among the alternatives of `p`.
-/
private def collectValues (p : Problem) : Array Expr := Id.run do
  let mut values : Array Expr := #[]
  for alt in p.alts do
    if let .val v :: _ := alt.patterns then
      unless values.contains v do
        values := values.push v
  return values
/-- Returns `true` iff the first pattern of `alt` exists and is a variable pattern. -/
private def isFirstPatternVar (alt : Alt) : Bool :=
  if let .var _ :: _ := alt.patterns then true else false
/--
Value transition: splits the goal on the next variable `x` using `caseValues`, with one case per
distinct value collected by `collectValues`, followed by the remaining ("else") case(s).

- In the branch for `values[i]`, `x` is consumed. Alternatives starting with that value or with a
  variable are kept; a leading variable is replaced by the value.
- In any other branch, `x` is kept, and only the alternatives starting with a variable pattern
  are retained.
-/
private def processValue (p : Problem) : MetaM (Array Problem) := do
  trace[Meta.Match.match] "value step"
  let x :: xs := p.vars | unreachable!
  let values := collectValues p
  let subgoals ← caseValues p.mvarId x.fvarId! values (substNewEqs := true)
  subgoals.mapIdxM fun i subgoal => do
    trace[Meta.Match.match] "processValue subgoal\n{MessageData.ofGoal subgoal.mvarId}"
    if h : i < values.size then
      let value := values[i]
      -- (x = value) branch
      let subst := subgoal.subst
      trace[Meta.Match.match] "processValue subst: {subst.map.toList.map fun p => mkFVar p.1}, {subst.map.toList.map fun p => p.2}"
      let examples := p.examples.map <| Example.replaceFVarId x.fvarId! (Example.val value)
      let examples := examples.map <| Example.applyFVarSubst subst
      let newAlts := p.alts.filter fun alt => match alt.patterns with
        | .val v :: _ => v == value
        | .var _ :: _ => true
        | _ => false
      let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
      let newAlts := newAlts.map fun alt => match alt.patterns with
        | .val _ :: ps => { alt with patterns := ps }
        | .var fvarId :: ps =>
          let alt := { alt with patterns := ps }
          alt.replaceFVarId fvarId value
        | _ => unreachable!
      let newVars := xs.map fun x => x.applyFVarSubst subst
      return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
    else
      -- else branch for value
      let newAlts := p.alts.filter isFirstPatternVar
      return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs }
/--
Collects, without duplicates and in order of first occurrence, the number of element patterns of
all leading array-literal patterns among the alternatives of `p`.
-/
private def collectArraySizes (p : Problem) : Array Nat := Id.run do
  let mut sizes : Array Nat := #[]
  for alt in p.alts do
    if let .arrayLit _ ps :: _ := alt.patterns then
      let sz := ps.length
      unless sizes.contains sz do
        sizes := sizes.push sz
  return sizes
/--
Replaces the variable `fvarId` in `alt` by an array literal with `arraySize` fresh element
variables of type `arrayElemType`. The fresh variables are added to the alternative's local
declarations and prepended, as variable patterns, to its patterns.

The fresh variables are named after the user name of `fvarId`, with an index appended.
-/
private def expandVarIntoArrayLit (alt : Alt) (fvarId : FVarId) (arrayElemType : Expr) (arraySize : Nat) : MetaM Alt :=
  withExistingLocalDecls alt.fvarDecls do
    let fvarDecl ← fvarId.getDecl
    let varNamePrefix := fvarDecl.userName
    -- `n` counts down the number of element variables still to be introduced.
    let rec loop (n : Nat) (newVars : Array Expr) := do
      match n with
      | n+1 =>
        withLocalDeclD (varNamePrefix.appendIndexAfter (n+1)) arrayElemType fun x =>
          loop n (newVars.push x)
      | 0 =>
        let arrayLit ← mkArrayLit arrayElemType newVars.toList
        let alt := alt.replaceFVarId fvarId arrayLit
        let newDecls ← newVars.toList.mapM fun newVar => newVar.fvarId!.getDecl
        let newPatterns := newVars.toList.map fun newVar => .var newVar.fvarId!
        return { alt with fvarDecls := newDecls ++ alt.fvarDecls, patterns := newPatterns ++ alt.patterns }
    loop arraySize #[]
/--
Array-literal transition: splits the goal on the size of the next variable `x` using
`caseArraySizes`, with one case per size collected by `collectArraySizes`, followed by the
remaining ("else") case(s).

- In the branch for `sizes[i]`, `x` is replaced by the element variables of the subgoal.
  Alternatives with an array literal of that size or with a variable are kept; a leading variable
  is expanded into an array literal of fresh variables.
- In any other branch, `x` is kept, and only the alternatives starting with a variable pattern
  are retained.
-/
private def processArrayLit (p : Problem) : MetaM (Array Problem) := do
  trace[Meta.Match.match] "array literal step"
  let x :: xs := p.vars | unreachable!
  let sizes := collectArraySizes p
  let subgoals ← caseArraySizes p.mvarId x.fvarId! sizes
  subgoals.mapIdxM fun i subgoal => do
    if h : i < sizes.size then
      let size := sizes[i]
      let subst := subgoal.subst
      let elems := subgoal.elems.toList
      let newVars := elems.map mkFVar ++ xs
      let newVars := newVars.map fun x => x.applyFVarSubst subst
      let subex := Example.arrayLit <| elems.map Example.var
      let examples := p.examples.map <| Example.replaceFVarId x.fvarId! subex
      let examples := examples.map <| Example.applyFVarSubst subst
      let newAlts := p.alts.filter fun alt => match alt.patterns with
        | .arrayLit _ ps :: _ => ps.length == size
        | .var _ :: _ => true
        | _ => false
      let newAlts := newAlts.map fun alt => alt.applyFVarSubst subst
      let newAlts ← newAlts.mapM fun alt => do
        match alt.patterns with
        | .arrayLit _ pats :: ps => return { alt with patterns := pats ++ ps }
        | .var fvarId :: ps =>
          let α ← getArrayArgType <| subst.apply x
          expandVarIntoArrayLit { alt with patterns := ps } fvarId α size
        | _ => unreachable!
      return { mvarId := subgoal.mvarId, vars := newVars, alts := newAlts, examples := examples }
    else
      -- else branch
      let newAlts := p.alts.filter isFirstPatternVar
      return { p with mvarId := subgoal.mvarId, alts := newAlts, vars := x::xs }
/--
Rewrites every leading `Nat` value pattern into a constructor pattern: `0` becomes `Nat.zero`,
and `k+1` becomes `Nat.succ` applied to the value pattern `k`. Alternatives whose leading pattern
is not a value recognized by `getNatValue?` are left untouched.
-/
private def expandNatValuePattern (p : Problem) : MetaM Problem := do
  let newAlts ← p.alts.mapM fun alt => do
    let .val n :: ps := alt.patterns | return alt
    match (← getNatValue? n) with
    | some 0     => return { alt with patterns := .ctor ``Nat.zero [] [] [] :: ps }
    | some (k+1) => return { alt with patterns := .ctor ``Nat.succ [] [] [.val (toExpr k)] :: ps }
    | none       => return alt
  return { p with alts := newAlts }
/--
Rewrites every leading `Int` value pattern into a constructor pattern: a non-negative `i` becomes
`Int.ofNat` applied to the value pattern `i.toNat`, and a negative `i` becomes `Int.negSucc`
applied to the value pattern `(-(i + 1)).toNat`. Alternatives whose leading pattern is not a value
recognized by `getIntValue?` are left untouched.
-/
private def expandIntValuePattern (p : Problem) : MetaM Problem := do
  let newAlts ← p.alts.mapM fun alt => do
    let .val n :: ps := alt.patterns | return alt
    let some i ← getIntValue? n | return alt
    let ctorPat : Pattern :=
      if i >= 0 then
        .ctor ``Int.ofNat [] [] [.val (toExpr i.toNat)]
      else
        .ctor ``Int.negSucc [] [] [.val (toExpr (-(i + 1)).toNat)]
    return { alt with patterns := ctorPat :: ps }
  return { p with alts := newAlts }
/--
Rewrites every leading `Fin` value pattern into a `Fin.mk` constructor pattern whose fields are
the underlying natural number (as a value pattern) and a `decide`-based proof of the bound (as an
inaccessible pattern). Alternatives whose leading pattern is not a value recognized by
`getFinValue?` are left untouched.
-/
private def expandFinValuePattern (p : Problem) : MetaM Problem := do
  let newAlts ← p.alts.mapM fun alt => do
    match alt.patterns with
    | .val e :: ps =>
      match (← getFinValue? e) with
      | some ⟨bound, v⟩ =>
        let ltProp ← mkLt (toExpr v.val) (toExpr bound)
        let h ← mkDecideProof ltProp
        return { alt with patterns := .ctor ``Fin.mk [] [toExpr bound] [.val (toExpr v.val), .inaccessible h] :: ps }
      | none => return alt
    | _ => return alt
  return { p with alts := newAlts }
/--
Rewrites every leading `BitVec` value pattern into a `BitVec.ofFin` constructor pattern applied to
the value pattern `v.toFin`. Alternatives whose leading pattern is not a value recognized by
`getBitVecValue?` are left untouched.
-/
private def expandBitVecValuePattern (p : Problem) : MetaM Problem := do
  let newAlts ← p.alts.mapM fun alt => do
    match alt.patterns with
    | .val e :: ps =>
      match (← getBitVecValue? e) with
      | some ⟨_, v⟩ =>
        return { alt with patterns := .ctor ``BitVec.ofFin [] [] [.val (toExpr v.toFin)] :: ps }
      | none => return alt
    | _ => return alt
  return { p with alts := newAlts }
/-- Emits the trace message `"{msg} step"` under the `Meta.Match.match` trace class. -/
private def traceStep (msg : String) : StateRefT State MetaM Unit := do
  trace[Meta.Match.match] "{msg} step"
/-- Traces the current problem `p`, rendered in the context of its goal, under `Meta.Match.match`. -/
private def traceState (p : Problem) : MetaM Unit :=
  withGoalOf p <| traceM `Meta.Match.match p.toMessageData
/-- Aborts with a "stuck" error that shows the problem `p`, rendered in the context of its goal. -/
private def throwNonSupported (p : Problem) : MetaM Unit :=
  withGoalOf p do
    throwError "Failed to compile pattern matching: Stuck at{indentD (← p.toMessageData)}"
/--
Checks, for every alternative of `p`, that the type of its leading pattern is definitionally equal
to the type of the next variable, and throws a "Type mismatch in pattern" error (positioned at the
alternative) otherwise. Does nothing if `p` has no variables; alternatives without patterns are
skipped.
-/
private def checkNextPatternTypes (p : Problem) : MetaM Unit := do
  match p.vars with
  | [] => return ()
  | x::_ => withGoalOf p do
    for alt in p.alts do
      withRef alt.ref do
        match alt.patterns with
        | [] => return ()
        | p::_ =>
          let e ← p.toExpr
          let xType ← inferType x
          let eType ← inferType e
          unless (← isDefEq xType eType) do
            throwError "Type mismatch in pattern: Pattern{indentExpr e}\n{← mkHasTypeButIsExpectedMsg eType xType}"
/--
The ex-falso transition applies when `p` has no alternatives and the target type of its goal is
not already `False`.
-/
def isExFalsoTransition (p : Problem) : MetaM Bool := do
  unless p.alts.isEmpty do
    return false
  withGoalOf p do
    return !(← p.mvarId.getType).isFalse
/-- Replaces the goal of `p` with the goal produced by applying `exfalso` to it. -/
def processExFalso (p : Problem) : MetaM Problem := do
  return { p with mvarId := (← p.mvarId.exfalso) }
/--
Main loop of the match compiler. Traces the current problem, picks the first applicable
transition, applies it, and recurses on the resulting problem(s).

The transitions are tried in the order in which they appear below; this order matters. If none
applies, the types of the next patterns are checked (to produce a better error message) and a
"stuck" error is thrown.
-/
private partial def process (p : Problem) : StateRefT State MetaM Unit := do
  traceState p
  -- No variables left: close the goal with an alternative, or record a counterexample.
  if isDone p then
    traceStep ("leaf")
    processLeaf p
    return
  -- No alternatives left and the target is not `False` yet.
  if (← isExFalsoTransition p) then
    traceStep ("ex falso")
    let p ← processExFalso p
    process p
    return
  if hasAsPattern p then
    traceStep ("as-pattern")
    let p ← processAsPattern p
    process p
    return
  -- Literal values that have to be matched against constructors are expanded first.
  if (← isNatValueTransition p) then
    traceStep ("nat value to constructor")
    process (← expandNatValuePattern p)
    return
  if (← isIntValueTransition p) then
    traceStep ("int value to constructor")
    process (← expandIntValuePattern p)
    return
  if (← isFinValueTransition p) then
    traceStep ("fin value to constructor")
    process (← expandFinValuePattern p)
    return
  if (← isBitVecValueTransition p) then
    traceStep ("bitvec value to constructor")
    process (← expandBitVecValuePattern p)
    return
  -- From here on, the next variable is known to be a free variable.
  if !isNextVar p then
    traceStep ("non variable")
    let p ← processNonVariable p
    process p
    return
  if (← isConstructorTransition p) then
    let ps ← processConstructor p
    ps.forM process
    return
  if isVariableTransition p then
    traceStep ("variable")
    let p ← processVariable p
    process p
    return
  if isValueTransition p then
    let ps ← processValue p
    ps.forM process
    return
  if isArrayLitTransition p then
    let ps ← processArrayLit p
    ps.forM process
    return
  if (← hasNatValPattern p) then
    -- This branch is reachable when `p`, for example, is just values without an else-alternative.
    -- We added it just to get better error messages.
    traceStep ("nat value to constructor")
    process (← expandNatValuePattern p)
    return
  checkNextPatternTypes p
  throwNonSupported p
/--
Returns the position of the universe level `uElim` in `matcherLevels`.
Returns `none` when `uElim` is `levelZero`, and throws an error when `uElim` is not
`levelZero` and does not occur in `matcherLevels`.
-/
private def getUElimPos? (matcherLevels : List Level) (uElim : Level) : MetaM (Option Nat) := do
  if uElim == levelZero then
    return none
  let some pos := matcherLevels.idxOf? uElim
    | throwError "Dependent match elimination failed: Universe level not found"
  return some pos
/-
Controls whether `mkMatcherAuxDefinition` compiles the auxiliary matcher it creates:
the declaration is passed to `compileDecl` only when this option is `true` (the default).
See the comment in `mkMatcher`, right before the call to `mkMatcherAuxDefinition`, for why
this option exists.
-/
register_builtin_option bootstrap.genMatcherCode : Bool := {
  defValue := true
  group := "bootstrap"
  descr := "disable code generation for auxiliary matcher function"
}
/-- Key of the matcher cache `matcherExt`. Two matchers are shared only if all fields agree. -/
private structure MatcherKey where
  /-- The value of the matcher, as produced by `Closure.mkValueTypeClosure`. -/
  value     : Expr
  /-- The value of the option `bootstrap.genMatcherCode` when the matcher was created. -/
  compile   : Bool
  -- When a matcher is created in a private context and thus may contain private references, we must
  -- not reuse it in an exported context.
  isPrivate : Bool
  deriving BEq, Hashable
-- Cache mapping a `MatcherKey` to the name of an already generated matcher declaration.
-- It is consulted and updated by `mkMatcherAuxDefinition`, and starts out empty.
private builtin_initialize matcherExt : EnvExtension (PHashMap MatcherKey Name) ←
  registerEnvExtension (pure {}) (asyncMode := .local) -- mere cache, keep it local
/--
Similar to `mkAuxDefinition`, but uses the cache `matcherExt`.

Returns the matcher (a constant applied to the closure arguments) together with an optional action:
* `none` if a matcher with the same key was found in the cache and is being reused;
* `some addMatcher` if a new matcher is needed. Nothing is added to the environment until
  `addMatcher` is run with the `MatcherInfo` of the new matcher. It then adds the declaration,
  records it in the cache, registers the matcher info, sets the inline attribute, enables
  realizations for the constant, and compiles the declaration if `bootstrap.genMatcherCode` is set.
-/
def mkMatcherAuxDefinition (name : Name) (type : Expr) (value : Expr) : MetaM (Expr × Option (MatcherInfo → MetaM Unit)) := do
  trace[Meta.Match.debug] "{name} : {type} := {value}"
  let compile := bootstrap.genMatcherCode.get (← getOptions)
  let result ← Closure.mkValueTypeClosure type value (zetaDelta := false)
  let env ← getEnv
  -- Builds the application of the matcher constant `name` to the arguments computed by the closure.
  let mkMatcherConst name :=
    mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
  let key := { value := result.value, compile, isPrivate := env.header.isModule && isPrivateName name }
  let mut nameNew? := (matcherExt.getState env).find? key
  if nameNew?.isNone && key.isPrivate then
    -- private contexts may reuse public matchers
    nameNew? := (matcherExt.getState env).find? { key with isPrivate := false }
  match nameNew? with
  | some nameNew => return (mkMatcherConst nameNew, none)
  | none =>
    let decl := Declaration.defnDecl (← mkDefinitionValInferringUnsafe name result.levelParams.toList
      result.type result.value .abbrev)
    trace[Meta.Match.debug] "{name} : {result.type} := {result.value}"
    -- Deferred action: the caller decides when (and whether) the new matcher is actually added.
    let addMatcher : MatcherInfo → MetaM Unit := fun mi => do
      -- matcher bodies should always be exported, if not private anyway
      withExporting do
        addDecl decl
      modifyEnv fun env => matcherExt.modifyState env fun s => s.insert key name
      addMatcherInfo name mi
      setInlineAttribute name
      enableRealizationsForConst name
      if compile then
        compileDecl decl
    return (mkMatcherConst name, some addMatcher)
/-- The input of `mkMatcher`. -/
structure MkMatcherInput where
  /-- Name used for the auxiliary matcher declaration. -/
  matcherName : Name
  /-- Type of the match: a function type over the discriminants, see the documentation of `mkMatcher`. -/
  matchType   : Expr
  /-- One entry per discriminant; its size is the number of discriminants. -/
  discrInfos  : Array DiscrInfo
  /-- The left-hand sides of the `match` alternatives. -/
  lhss        : List AltLHS
/-- The number of discriminants of `m`, i.e. the size of `m.discrInfos`. -/
def MkMatcherInput.numDiscrs (m : MkMatcherInput) : Nat := m.discrInfos.size
/-- Collects the free variables of the match type of `m` and then of each of its left-hand sides. -/
def MkMatcherInput.collectFVars (m : MkMatcherInput) : StateRefT CollectFVars.State MetaM Unit := do
  m.matchType.collectFVars
  for lhs in m.lhss do
    lhs.collectFVars
/--
Returns the set of free variables collected by `m.collectFVars`, extended by `addDependencies`.
-/
def MkMatcherInput.collectDependencies (m : MkMatcherInput) : MetaM FVarIdSet := do
  -- Start from an empty collector state.
  let (_, collected) ← m.collectFVars |>.run {}
  return (← collected.addDependencies).fvarSet
/--
Helper for `mkMatcher`: runs `k` in a local context restricted to the local declarations
(and local instances) that `m` depends on.

The restriction matters because dependent elimination may otherwise "refine" the types of
unrelated declarations and thereby introduce unnecessary dependencies into the auto-generated
auxiliary declaration. This is more than an optimization: such unnecessary dependencies may
prevent the termination checker from succeeding. For an example, see issue #1237.
-/
def withCleanLCtxFor (m : MkMatcherInput) (k : MetaM α) : MetaM α := do
  let deps ← m.collectDependencies
  let fullLCtx ← getLCtx
  -- Visit every declaration of the ambient context and erase the ones `m` does not depend on.
  let prunedLCtx := fullLCtx.foldr (init := fullLCtx) fun decl acc =>
    if deps.contains decl.fvarId then acc else acc.erase decl.fvarId
  let keptInstances := (← getLocalInstances).filter fun inst => deps.contains inst.fvar.fvarId!
  withLCtx prunedLCtx keptInstances k
/--
Create a dependent matcher for `matchType` where `matchType` is of the form
`(a_1 : A_1) -> (a_2 : A_2[a_1]) -> ... -> (a_n : A_n[a_1, a_2, ... a_{n-1}]) -> B[a_1, ..., a_n]`
where `n = numDiscrs`, and the `lhss` are the left-hand-sides of the `match`-expression alternatives.
Each `AltLHS` has a list of local declarations and a list of patterns.
The number of patterns must be the same in each `AltLHS`.
The generated matcher has the structure described at `MatcherInfo`. The motive argument is of the form
`(motive : (a_1 : A_1) -> (a_2 : A_2[a_1]) -> ... -> (a_n : A_n[a_1, a_2, ... a_{n-1}]) -> Sort v)`
where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition.
-/
def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := withCleanLCtxFor input do
  let ⟨matcherName, matchType, discrInfos, lhss⟩ := input
  let numDiscrs := discrInfos.size
  let numEqs := getNumEqsFromDiscrInfos discrInfos
  checkNumPatterns numDiscrs lhss
  forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do
    /- We generate a matcher that can eliminate using different motives with different universe levels.
       `uElim` is the universe level the caller wants to eliminate to.
       If it is not levelZero, we create a matcher that can eliminate in any universe level.
       This is useful for implementing `MatcherApp.addArg` because it may have to change the universe level. -/
    let uElim ← getLevel matchTypeBody
    let uElimGen ← if uElim == levelZero then pure levelZero else mkFreshLevelMVar
    -- Local helper shared by both branches below: turns the matcher value/type and the final
    -- compilation state `s` into the `MatcherResult`.
    let mkMatcher (type val : Expr) (minors : Array (Expr × Nat)) (s : State) : MetaM MatcherResult := do
      trace[Meta.Match.debug] "matcher value: {val}\ntype: {type}"
      trace[Meta.Match.debug] "minors num params: {minors.map (·.2)}"
      /- The option `bootstrap.genMatcherCode` is a helper hack. It is useful, for example,
         for compiling `src/Init/Data/Int`. It is needed because the compiler uses `Int.decLt`
         for generating code for `Int.casesOn` applications, but `Int.casesOn` is used to
         give the reference implementation for
         ```
         @[extern "lean_int_neg"] def neg (n : @& Int) : Int :=
         match n with
         | ofNat n   => negOfNat n
         | negSucc n => succ n
         ```
         which is defined **before** `Int.decLt` -/
      let (matcher, addMatcher) ← mkMatcherAuxDefinition matcherName type val
      trace[Meta.Match.debug] "matcher levels: {matcher.getAppFn.constLevels!}, uElim: {uElimGen}"
      -- The position is looked up using `uElimGen` *before* it is unified with `uElim` below.
      let uElimPos? ← getUElimPos? matcher.getAppFn.constLevels! uElimGen
      discard <| isLevelDefEq uElimGen uElim
      -- If a new matcher declaration is needed, supply its `MatcherInfo`; otherwise there is nothing to add.
      let addMatcher :=
        match addMatcher with
        | some addMatcher => addMatcher <|
          { numParams := matcher.getAppNumArgs
            altNumParams := minors.map fun minor => minor.2 + numEqs
            discrInfos
            numDiscrs
            uElimPos?
          }
        | none => pure ()
      trace[Meta.Match.debug] "matcher: {matcher}"
      -- Indices of the alternatives that were not marked as used in `s`, in increasing order
      -- (the fold conses each index onto the front, hence the `reverse` below).
      let unusedAltIdxs := lhss.length.fold (init := []) fun i _ r =>
        if s.used.contains i then r else i::r
      return {
        matcher,
        counterExamples := s.counterExamples,
        unusedAltIdxs := unusedAltIdxs.reverse,
        addMatcher
      }
    let motiveType ← mkForallFVars discrs (mkSort uElimGen)
    trace[Meta.Match.debug] "motiveType: {motiveType}"
    withLocalDeclD `motive motiveType fun motive => do
      if discrInfos.any fun info => info.hName?.isSome then
        -- Some discriminant has an `hName?`: the goal is stated over a second copy `discrs'` of the
        -- discriminants, with the hypotheses introduced by `withEqs` relating `discrs` and `discrs'`.
        forallBoundedTelescope matchType numDiscrs fun discrs' _ => do
          let (mvarType, isEqMask) ← withEqs discrs discrs' discrInfos fun eqs => do
            let mvarType ← mkForallFVars eqs (mkAppN motive discrs')
            -- Remember, for each hypothesis, whether its type is an `Eq` (as opposed to something else, e.g. `HEq`).
            let isEqMask ← eqs.mapM fun eq => return (← inferType eq).isEq
            return (mvarType, isEqMask)
          trace[Meta.Match.debug] "target: {mvarType}"
          withAlts motive discrs discrInfos lhss fun alts minors => do
            let mvar ← mkFreshExprMVar mvarType
            trace[Meta.Match.debug] "goal\n{mvar.mvarId!}"
            let examples := discrs'.toList.map fun discr => Example.var discr.fvarId!
            let (_, s) ← (process { mvarId := mvar.mvarId!, vars := discrs'.toList, alts := alts, examples := examples }).run {}
            let val ← mkLambdaFVars discrs' mvar
            trace[Meta.Match.debug] "matcher\nvalue: {val}\ntype: {← inferType val}"
            -- Build one reflexivity proof per discriminant that has an `hName?`, using `isEqMask`
            -- to choose between `Eq.refl` and `HEq.refl`.
            let mut rfls := #[]
            let mut isEqMaskIdx := 0
            for discr in discrs, info in discrInfos do
              if info.hName?.isSome then
                if isEqMask[isEqMaskIdx]! then
                  rfls := rfls.push (← mkEqRefl discr)
                else
                  rfls := rfls.push (← mkHEqRefl discr)
                isEqMaskIdx := isEqMaskIdx + 1
            -- Instantiate the copy `discrs'` with the actual discriminants and discharge the hypotheses.
            let val := mkAppN (mkAppN val discrs) rfls
            let args := #[motive] ++ discrs ++ minors.map Prod.fst
            let val ← mkLambdaFVars args val
            let type ← mkForallFVars args (mkAppN motive discrs)
            mkMatcher type val minors s
      else
        -- No discriminant has an `hName?`: the goal is simply the motive applied to the discriminants.
        let mvarType := mkAppN motive discrs
        trace[Meta.Match.debug] "target: {mvarType}"
        withAlts motive discrs discrInfos lhss fun alts minors => do
          let mvar ← mkFreshExprMVar mvarType
          let examples := discrs.toList.map fun discr => Example.var discr.fvarId!
          let (_, s) ← (process { mvarId := mvar.mvarId!, vars := discrs.toList, alts := alts, examples := examples }).run {}
          let args := #[motive] ++ discrs ++ minors.map Prod.fst
          let type ← mkForallFVars args mvarType
          let val ← mkLambdaFVars args mvar
          mkMatcher type val minors s
/--
Reconstructs a `MkMatcherInput` from the matcher application `matcherApp`.
Throws an error if no `MatcherInfo` is registered for the matcher's name.

The returned input reuses the matcher's name and `discrInfos`. Its `matchType` abstracts over the
discriminants and ends in `PUnit`, and its `lhss` are read back from the types of the alternatives.
-/
def getMkMatcherInputInContext (matcherApp : MatcherApp) : MetaM MkMatcherInput := do
  let matcherName := matcherApp.matcherName
  let some matcherInfo ← getMatcherInfo? matcherName
    | throwError "Internal error during match expression elaboration: Could not find a matcher named `{matcherName}`"
  let matcherConst ← getConstVal matcherName
  -- Specialize the matcher's type to the universe levels, parameters and motive of this application.
  let matcherType ← instantiateTypeLevelParams matcherConst matcherApp.matcherLevels.toList
  let matcherType ← instantiateForall matcherType <| matcherApp.params ++ #[matcherApp.motive]
  let matchType ← do
    -- Universe of the `PUnit` result: the matcher's elimination level parameter if it has one,
    -- `levelZero` otherwise.
    let u :=
      if let some idx := matcherInfo.uElimPos?
      then mkLevelParam matcherConst.levelParams.toArray[idx]!
      else levelZero
    forallBoundedTelescope matcherType (some matcherInfo.numDiscrs) fun discrs _ => do
      mkForallFVars discrs (mkConst ``PUnit [u])
  let matcherType ← instantiateForall matcherType matcherApp.discrs
  -- What remains of `matcherType` starts with one binder per alternative; each binder type is
  -- turned into an `AltLHS`.
  let lhss ← forallBoundedTelescope matcherType (some matcherApp.alts.size) fun alts _ =>
    alts.mapM fun alt => do
      let ty ← inferType alt
      forallTelescope ty fun xs body => do
        -- Keep only the binders that the body depends on.
        let xs ← xs.filterM fun x => dependsOn body x.fvarId!
        -- The arguments of the application in the body are converted to patterns.
        body.withApp fun _ args => do
          let ctx ← getLCtx
          let localDecls := xs.map ctx.getFVar!
          let patterns ← args.mapM Match.toPattern
          return {
            ref := Syntax.missing
            fvarDecls := localDecls.toList
            patterns := patterns.toList : Match.AltLHS }
  return { matcherName, matchType, discrInfos := matcherInfo.discrInfos, lhss := lhss.toList }
/--
This function is only used for testing purposes.

Runs `k` on the `MkMatcherInput` reconstructed for the matcher `matcherName`. The matcher constant
is applied to the variables bound by its own type (up to the arity recorded in its `MatcherInfo`),
and `k` is executed inside that telescope. Throws an error if `matcherName` has no `MatcherInfo`
or if the application is not recognized as a matcher application.
-/
def withMkMatcherInput (matcherName : Name) (k : MkMatcherInput → MetaM α) : MetaM α := do
  let some matcherInfo ← getMatcherInfo? matcherName
    | throwError "Internal error during match expression elaboration: Could not find a matcher named `{matcherName}`"
  let matcherConst ← getConstInfo matcherName
  forallBoundedTelescope matcherConst.type (some matcherInfo.arity) fun xs _ => do
    -- `matcherApp` is the matcher constant applied to the telescope variables.
    let matcherApp := mkAppN (← mkConstWithLevelParams matcherConst.name) xs
    let some app ← matchMatcherApp? matcherApp
      | throwError "Internal error during match expression elaboration: Could not find a matcher app named `{matcherApp}`"
    k (← getMkMatcherInputInContext app)
end Match
-- Register the trace classes `Meta.Match.match`, `Meta.Match.debug` and `Meta.Match.unify`.
builtin_initialize
  registerTraceClass `Meta.Match.match
  registerTraceClass `Meta.Match.debug
  registerTraceClass `Meta.Match.unify
end Lean.Meta