lean4-htt/src/Lean/Meta/Match/MatchEqs.lean
2022-06-13 17:10:14 -07:00

707 lines
32 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2021 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Match.Match
import Lean.Meta.Match.MatchEqsExt
import Lean.Meta.Tactic.Apply
import Lean.Meta.Tactic.Delta
import Lean.Meta.Tactic.SplitIf
import Lean.Meta.Tactic.Injection
import Lean.Meta.Tactic.Contradiction
namespace Lean.Meta
/--
Helper method for `proveCondEqThm`. Given a goal of the form `C.rec ... xMajor = rhs`,
apply `cases xMajor` and return the goals produced by `cases`.
Fails with `'casesOnStuckLHS' failed` when the target is not an equation, or when no suitable
free variable is found in its left-hand side. -/
partial def casesOnStuckLHS (mvarId : MVarId) : MetaM (Array MVarId) := do
  let goalType ← getMVarType mvarId
  let some (_, lhs, _) ← matchEq? goalType | throwError "'casesOnStuckLHS' failed"
  let some majorFVar ← findFVar? lhs | throwError "'casesOnStuckLHS' failed"
  let subgoals ← cases mvarId majorFVar
  return subgoals.map fun s => s.mvarId
where
  /--
  Look for the free variable to perform case analysis on.
  - For a primitive projection `Expr.proj`, search the projected structure.
  - For an application of a projection function, search the argument at position `numParams`.
  - Otherwise `matchConstRec` decides: its success continuation returns the argument at the
    recursor's major-premise index, provided that argument exists and is a free variable. -/
  findFVar? (e : Expr) : MetaM (Option FVarId) := do
    match e.getAppFn with
    | Expr.proj _ _ struct _ => findFVar? struct
    | fn =>
      unless fn.isConst do
        return none
      let fnArgs := e.getAppArgs
      match (← getProjectionFnInfo? fn.constName!) with
      | some info =>
        if info.numParams < fnArgs.size then
          findFVar? fnArgs[info.numParams]
        else
          return none
      | none =>
        matchConstRec fn (fun _ => return none) fun recVal _ => do
          let majorIdx := recVal.getMajorIdx
          -- The recursor may be partially applied; then there is no major premise to inspect.
          if majorIdx >= fnArgs.size then
            return none
          let major := fnArgs[majorIdx]
          return if major.isFVar then some major.fvarId! else none
/-- Variant of `casesOnStuckLHS` that returns `none` instead of propagating the exception. -/
def casesOnStuckLHS? (mvarId : MVarId) : MetaM (Option (Array MVarId)) := do
  try
    let subgoals ← casesOnStuckLHS mvarId
    return some subgoals
  catch _ =>
    return none
namespace Match
/--
Traverse `e` with `Meta.transform` and, at every subterm recognized by `isNamedPattern?`,
replace the subterm with the result of `unfoldDefinition?` when unfolding succeeds.
Every other subterm is visited unchanged. -/
def unfoldNamedPattern (e : Expr) : MetaM Expr := do
  let unfoldStep (sub : Expr) : MetaM TransformStep := do
    match isNamedPattern? sub with
    | none => return TransformStep.visit sub
    | some pat =>
      match (← unfoldDefinition? pat) with
      | some unfolded => return TransformStep.visit unfolded
      -- Unfolding failed: keep the original subterm (not `pat`).
      | none => return TransformStep.visit sub
  Meta.transform e (pre := unfoldStep)
/--
Similar to `forallTelescopeReducing`, but

1. Eliminates arguments for named parameters and the associated equation proofs.

2. Equality parameters associated with the `h : discr` notation are replaced with `rfl` proofs.
   Recall that this kind of parameter always occurs after the parameters corresponding to pattern variables.
   `numNonEqParams` is the size of the prefix.

The continuation `k` takes five arguments `ys eqs args mask type`.
- `ys` are variables for the hypotheses that have not been eliminated.
- `eqs` are variables for equality hypotheses associated with discriminants annotated with `h : discr`.
- `args` are the arguments for the alternative `alt` that has type `altType`. `ys.size <= args.size`
- `mask[i]` is true if the hypotheses has not been eliminated. `mask.size == args.size`.
- `type` is the resulting type for `altType`.

We use the `mask` to build the splitter proof. See `mkSplitterProof`.
-/
partial def forallAltTelescope (altType : Expr) (numNonEqParams : Nat)
    (k : (ys : Array Expr) → (eqs : Array Expr) → (args : Array Expr) → (mask : Array Bool) → (type : Expr) → MetaM α)
    : MetaM α := do
  go #[] #[] #[] #[] 0 altType
where
  /--
  Worker loop. `i` is the number of binders of `altType` consumed so far; binders with
  `i < numNonEqParams` are pattern variables / named-pattern proofs, the remaining ones are the
  `h : discr` equalities. -/
  go (ys : Array Expr) (eqs : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do
    let type ← whnfForall type
    match type with
    | Expr.forallE n d b .. =>
      if i < numNonEqParams then
        let d ← unfoldNamedPattern d
        withLocalDeclD n d fun y => do
          let typeNew := b.instantiate1 y
          if let some (_, lhs, rhs) ← matchEq? d then
            if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then
              -- `y : lhs = rhs` is a named-pattern proof: eliminate `lhs` and replace `y` with `rfl`.
              -- NOTE: these indices must not be called `i`; shadowing the binder counter `i` here
              -- made the recursive call below continue from the wrong position.
              let some ysIdx := ys.getIdx? lhs | unreachable!
              let ys := ys.eraseIdx ysIdx
              let some argIdx := args.getIdx? lhs | unreachable!
              let mask := mask.set! argIdx false
              let args := args.map fun arg => if arg == lhs then rhs else arg
              let args := args.push (← mkEqRefl rhs)
              let typeNew := typeNew.replaceFVar lhs rhs
              return (← go ys eqs args (mask.push false) (i+1) typeNew)
          go (ys.push y) eqs (args.push y) (mask.push true) (i+1) typeNew
      else
        -- Binder for an `h : discr` equality: the argument passed to the alternative is a reflexivity proof.
        let arg ← if let some (_, _, rhs) ← matchEq? d then
          mkEqRefl rhs
        else if let some (_, _, _, rhs) ← matchHEq? d then
          mkHEqRefl rhs
        else
          throwError "unexpected match alternative type{indentExpr altType}"
        withLocalDeclD n d fun eq => do
          let typeNew := b.instantiate1 eq
          go ys (eqs.push eq) (args.push arg) (mask.push false) (i+1) typeNew
    | _ =>
      let type ← unfoldNamedPattern type
      /- Recall that alternatives that do not have variables have a `Unit` parameter to ensure
         they are not eagerly evaluated. -/
      if ys.size == 1 then
        if (← inferType ys[0]).isConstOf ``Unit && !(← dependsOn type ys[0].fvarId!) then
          return (← k #[] #[] #[mkConst ``Unit.unit] #[false] type)
      k ys eqs args mask type

  /-- Return `true` if `type` contains a named-pattern subterm (see `isNamedPattern?`) whose last argument is `h`. -/
  isNamedPatternProof (type : Expr) (h : Expr) : Bool :=
    Option.isSome <| type.find? fun e =>
      if let some e := isNamedPattern? e then
        e.appArg! == h
      else
        false
namespace SimpH
/--
State for the equational theorem hypothesis simplifier.

Recall that each equation contains additional hypotheses to ensure the associated case is not taken by previous cases.
We have one hypothesis for each previous case.

Each hypothesis is of the form `forall xs, eqs → False`

We use tactics to minimize code duplication.
-/
structure State where
  mvarId : MVarId -- Goal representing the hypothesis
  xs : List FVarId -- Pattern variables for a previous case
  eqs : List FVarId -- Equations to be processed
  eqsNew : List FVarId := [] -- Simplified (already processed) equations
/-- Monad of the hypothesis simplifier: `MetaM` extended with a mutable `State`. -/
abbrev M := StateRefT State MetaM
/--
Apply the given substitution to `fvarIds`, keeping only the entries whose image is still a
free variable. This is an auxiliary method for `substRHS`.
-/
private def applySubst (s : FVarSubst) (fvarIds : List FVarId) : List FVarId :=
  fvarIds.filterMap fun oldId =>
    let image := s.apply (mkFVar oldId)
    if image.isFVar then some image.fvarId! else none
/--
Given an equation of the form `lhs = rhs` where `rhs` is a variable in `xs`,
replace it everywhere with `lhs`.
The goal and all tracked free-variable lists are updated with the resulting substitution.
-/
private def substRHS (eq : FVarId) (rhs : FVarId) : M Unit := do
  let st ← get
  assert! st.xs.contains rhs
  let (subst, mvarId) ← substCore st.mvarId eq (symm := true)
  modify fun s => { s with
    mvarId,
    -- `rhs` has been eliminated, so drop it before renaming the remaining variables.
    xs := applySubst subst (s.xs.erase rhs),
    eqs := applySubst subst s.eqs,
    eqsNew := applySubst subst s.eqsNew
  }
/-- Return `true` when there are no equations left to process. -/
private def isDone : M Bool := do
  let st ← get
  return st.eqs.isEmpty
/-- Customized `contradiction` tactic for `simpH?`: `contradictionCore` with `genDiseq` and `emptyType` switched off. -/
private def contradiction (mvarId : MVarId) : MetaM Bool := do
  let closed ← contradictionCore mvarId { genDiseq := false, emptyType := false }
  return closed
/--
Auxiliary tactic that tries to replace as many variables as possible and then apply `contradiction`.
We use it to discard redundant hypotheses.
The whole attempt runs under `commitWhen`.
-/
partial def trySubstVarsAndContradiction (mvarId : MVarId) : MetaM Bool :=
  commitWhen do
    let substituted ← substVars mvarId
    -- `injections` returning `none` means the goal has been closed.
    let some afterInj ← injections substituted | return true
    if afterInj != substituted then
      -- Progress was made: iterate on the new goal.
      trySubstVarsAndContradiction afterInj
    else
      contradiction substituted
/--
Process the next pending equation of the state (if any).
Returns `false` when the hypothesis being simplified can be discarded (the goal is contradictory,
or `injection` closes it), and `true` when processing should continue.
An equation that `injection` cannot handle (it throws) is moved to `eqsNew`.
-/
private def processNextEq : M Bool := do
  let s ← get
  withMVarContext s.mvarId do
    -- If the goal is contradictory, the hypothesis is redundant.
    if (← contradiction s.mvarId) then
      return false
    if let eq :: eqs := s.eqs then
      -- Pop `eq` from the pending list before processing it.
      modify fun s => { s with eqs }
      let eqType ← inferType (mkFVar eq)
      -- See `substRHS`. Recall that if `rhs` is a variable then it must be in `s.xs`
      if let some (_, lhs, rhs) ← matchEq? eqType then
        -- A trivial equation carries no information: just drop it.
        if (← isDefEq lhs rhs) then
          return true
        if rhs.isFVar then
          substRHS eq rhs.fvarId!
          return true
      if let some (α, lhs, β, rhs) ← matchHEq? eqType then
        -- Try to convert `HEq` into `Eq`
        if (← isDefEq α β) then
          let (eqNew, mvarId) ← heqToEq s.mvarId eq (tryToClear := true)
          -- The converted equation goes back to the front of the pending list.
          modify fun s => { s with mvarId, eqs := eqNew :: s.eqs }
          return true
        -- If it is not possible, we try to show the hypothesis is redundant by substituting even variables that are not at `s.xs`, and then use contradiction.
        else
          match lhs.isConstructorApp? (← getEnv), rhs.isConstructorApp? (← getEnv) with
          | some lhsCtor, some rhsCtor =>
            if lhsCtor.name != rhsCtor.name then
              return false -- If the constructors are different, we can discard the hypothesis even if it is a heterogeneous equality
            else if (← trySubstVarsAndContradiction s.mvarId) then
              return false
          | _, _ =>
            if (← trySubstVarsAndContradiction s.mvarId) then
              return false
      -- Reached when none of the branches above returned.
      try
        -- Try to simplify equation using `injection` tactic.
        match (← injection s.mvarId eq) with
        | InjectionResult.solved => return false
        | InjectionResult.subgoal mvarId eqNews .. =>
          modify fun s => { s with mvarId, eqs := eqNews.toList ++ s.eqs }
      catch _ =>
        -- `injection` is not applicable: keep the equation as is.
        modify fun s => { s with eqsNew := eq :: s.eqsNew }
    return true
/--
Run `processNextEq` until no equations are left.
Returns `true` if every equation was processed, and `false` as soon as `processNextEq` returns `false`. -/
partial def go : M Bool := do
  if (← isDone) then
    return true
  unless (← processNextEq) do
    return false
  go
end SimpH
/--
Auxiliary method for simplifying equational theorem hypotheses.

Recall that each equation contains additional hypotheses to ensure the associated case was not taken by previous cases.
We have one hypothesis for each previous case.

`h` is introduced as a goal whose last `numEqs` binders are equations; `SimpH.go` then simplifies it.
Returns `none` when `SimpH.go` reports that the hypothesis can be discarded, and otherwise the
simplified hypothesis rebuilt from the remaining equations and the variables they depend on.
-/
private partial def simpH? (h : Expr) (numEqs : Nat) : MetaM (Option Expr) := withDefault do
  let numVars ← forallTelescope h fun binders _ => pure (binders.size - numEqs)
  let goal ← mkFreshExprSyntheticOpaqueMVar h
  let (xs, mvarId) ← introN goal.mvarId! numVars
  let (eqs, mvarId) ← introN mvarId numEqs
  let (ok, final) ← SimpH.go |>.run { mvarId, xs := xs.toList, eqs := eqs.toList }
  unless ok do
    return none
  withMVarContext final.mvarId do
    -- `eqsNew` is accumulated in reverse order.
    let remainingEqs := final.eqsNew.reverse.toArray.map mkFVar
    let mut r ← mkForallFVars remainingEqs (mkConst ``False)
    /- We only include variables in `xs` if there is a dependency. -/
    for x in final.xs.reverse do
      if (← dependsOn r x) then
        r ← mkForallFVars #[mkFVar x] r
    trace[Meta.Match.matchEqs] "simplified hypothesis{indentExpr r}"
    check r
    return some r
/--
Scan the local context for a hypothesis `x = rhs` where `x` is a free variable that does not
occur in `rhs`, and return the goal produced by the first successful `subst? mvarId x`.
Fails with `substSomeVar failed` if no hypothesis qualifies. -/
private def substSomeVar (mvarId : MVarId) : MetaM (Array MVarId) := withMVarContext mvarId do
  for decl in (← getLCtx) do
    let some (_, lhs, rhs) ← matchEq? decl.type | continue
    unless lhs.isFVar do
      continue
    if (← dependsOn rhs lhs.fvarId!) then
      continue
    if let some mvarIdNew ← subst? mvarId lhs.fvarId! then
      return #[mvarIdNew]
  throwError "substSomeVar failed"
/--
Helper method for proving a conditional equational theorem associated with an alternative of
the `match`-eliminator `matchDeclName`. `type` contains the type of the theorem.
The proof is built in an empty local context: the binders of `type` are introduced, `matchDeclName`
is delta-expanded in the target, and the goal is then solved by the loop `go`. -/
partial def proveCondEqThm (matchDeclName : Name) (type : Expr) : MetaM Expr := withLCtx {} {} do
  let type ← instantiateMVars type
  forallTelescope type fun hyps target => do
    let mvar0 ← mkFreshExprSyntheticOpaqueMVar target
    trace[Meta.Match.matchEqs] "proveCondEqThm {mvar0.mvarId!}"
    let mvarId ← deltaTarget mvar0.mvarId! fun declName => declName == matchDeclName
    withDefault <| go mvarId 0
    let proof ← instantiateMVars mvar0
    mkLambdaFVars hyps proof
where
  /--
  Put the left-hand side of the target in `whnfCore` form and then try, in order: reflexivity,
  `contradiction`, `casesOnStuckLHS`, `simpIfTarget`, `splitIfTarget?`, and `substSomeVar`.
  The first strategy that succeeds yields the subgoals to recurse on; if all fail, an error is thrown. -/
  go (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do
    trace[Meta.Match.matchEqs] "proveCondEqThm.go {mvarId}"
    let mvarId ← modifyTargetEqLHS mvarId whnfCore
    let byRefl : MetaM (Array MVarId) := do
      applyRefl mvarId
      return #[]
    let byContradiction : MetaM (Array MVarId) := do
      contradiction mvarId { genDiseq := true }
      return #[]
    let bySimpIf : MetaM (Array MVarId) := do
      let mvarId' ← simpIfTarget mvarId (useDecide := true)
      -- No progress counts as failure so that the next strategy is tried.
      if mvarId' == mvarId then throwError "simpIf failed"
      return #[mvarId']
    let bySplitIf : MetaM (Array MVarId) := do
      let some (s₁, s₂) ← splitIfTarget? mvarId | throwError "spliIf failed"
      let mvarId₁ ← trySubst s₁.mvarId s₁.fvarId
      return #[mvarId₁, s₂.mvarId]
    let giveUp : MetaM (Array MVarId) :=
      throwError "failed to generate equality theorems for `match` expression `{matchDeclName}`\n{MessageData.ofGoal mvarId}"
    let subgoals ←
      byRefl <|> byContradiction <|> casesOnStuckLHS mvarId <|> bySimpIf <|> bySplitIf <|> substSomeVar mvarId <|> giveUp
    for subgoal in subgoals do
      go subgoal (depth+1)
/--
Construct new local declarations `xs` with types `altTypes`, and then execute `f xs`.
The declarations are named `h_1`, `h_2`, ... in order. -/
private partial def withSplitterAlts (altTypes : Array Expr) (f : Array Expr → MetaM α) : MetaM α := do
  let rec loop (pos : Nat) (acc : Array Expr) : MetaM α := do
    if inBounds : pos < altTypes.size then
      let altType := altTypes.get ⟨pos, inBounds⟩
      withLocalDeclD ((`h).appendIndexAfter (pos+1)) altType fun x =>
        loop (pos+1) (acc.push x)
    else
      f acc
  loop 0 #[]
/-- Result of `injectionAny`. -/
inductive InjectionAnyResult where
  /-- `injection` closed the goal. -/
  | solved
  /-- No hypothesis in the local context could be processed with `injection`. -/
  | failed
  /-- `injection` succeeded on some hypothesis and produced the new goal `mvarId`. -/
  | subgoal (mvarId : MVarId)
/--
If `type` is `lhs = rhs`, or `HEq lhs rhs` where both sides have definitionally equal types,
return `some (lhs, rhs)`; otherwise return `none`. -/
private def injectionAnyCandidate? (type : Expr) : MetaM (Option (Expr × Expr)) := do
  match (← matchEq? type) with
  | some (_, lhs, rhs) => return some (lhs, rhs)
  | none =>
    let some (α, lhs, β, rhs) ← matchHEq? type | return none
    if (← isDefEq α β) then
      return some (lhs, rhs)
    else
      return none
/--
Try `injection` on the hypotheses of `mvarId`, in local-context order, and stop at the first success.
Hypotheses are skipped when they are not candidates (see `injectionAnyCandidate?`), when both sides
are definitionally equal, or when both sides reduce to `Nat` literals.
Failures of `injection` are traced and the search continues with the next hypothesis. -/
private def injectionAny (mvarId : MVarId) : MetaM InjectionAnyResult :=
  withMVarContext mvarId do
    for localDecl in (← getLCtx) do
      let some (lhs, rhs) ← injectionAnyCandidate? localDecl.type | continue
      if (← isDefEq lhs rhs) then
        continue
      let lhsWhnf ← whnf lhs
      let rhsWhnf ← whnf rhs
      if lhsWhnf.isNatLit && rhsWhnf.isNatLit then
        continue
      try
        match (← injection mvarId localDecl.fvarId) with
        | InjectionResult.solved => return InjectionAnyResult.solved
        | InjectionResult.subgoal mvarIdNew .. => return InjectionAnyResult.subgoal mvarIdNew
      catch ex =>
        trace[Meta.Match.matchEqs] "injectionAnyFailed at {localDecl.userName}, error\n{ex.toMessageData}"
        pure ()
    return InjectionAnyResult.failed
/--
Monad used to convert the `match` template into the splitter proof.
- Reader: maps the free variable of an original alternative to the triple
  `(new alternative, number of parameters of the new alternative, argument mask)`.
- State: the metavariables created for the extra hypotheses of the new alternatives,
  to be proved after the conversion. -/
private abbrev ConvertM := ReaderT (FVarIdMap (Expr × Nat × Array Bool)) $ StateRefT (Array MVarId) MetaM
/--
Construct a proof for the splitter generated by `mkEquationsFor`.
The proof uses the definition of the `match`-declaration as a template (argument `template`).
- `alts` are free variables corresponding to alternatives of the `match` auxiliary declaration being processed.
- `altsNew` are the new free variables which contain additional hypotheses that ensure they are only used
  when the previous overlapping alternatives are not applicable.
- `altsNewNumParams` and `altArgMasks` give, for each alternative, the number of parameters of the
  new alternative and the argument mask used when rewriting its occurrences in the template. -/
private partial def mkSplitterProof (matchDeclName : Name) (template : Expr) (alts altsNew : Array Expr)
    (altsNewNumParams : Array Nat)
    (altArgMasks : Array (Array Bool)) : MetaM Expr := do
  trace[Meta.Match.matchEqs] "proof template: {template}"
  -- Rewrite the template; `mvarIds` collects the side goals created along the way.
  let (proof, mvarIds) ← convertTemplate template |>.run mkMap |>.run #[]
  trace[Meta.Match.matchEqs] "splitter proof: {proof}"
  mvarIds.forM proveSubgoal
  instantiateMVars proof
where
/-- Map each original alternative (by free variable id) to its new alternative, parameter count, and argument mask. -/
mkMap : FVarIdMap (Expr × Nat × Array Bool) := Id.run do
  let mut altMap : FVarIdMap (Expr × Nat × Array Bool) := {}
  for oldAlt in alts, newAlt in altsNew, paramCount in altsNewNumParams, mask in altArgMasks do
    altMap := altMap.insert oldAlt.fvarId! (newAlt, paramCount, mask)
  return altMap
/-- Remove all trailing `false` entries from `argMask`. -/
trimFalseTrail (argMask : Array Bool) : Array Bool :=
  -- Stop on an empty mask or as soon as the last entry is `true`.
  if argMask.isEmpty || argMask.back then
    argMask
  else
    trimFalseTrail argMask.pop
/--
Auxiliary function used at `convertTemplate` to decide whether to use `convertCastEqRec`.
See `convertCastEqRec`.
Returns `false` unless `e` is an application of `Eq.ndrec` with more than 6 arguments. -/
isCastEqRec (e : Expr) : ConvertM Bool := do
  -- TODO: we do not handle `Eq.rec` since we never found an example that needed it.
  -- If we find one we must extend `convertCastEqRec`.
  if !e.isAppOf ``Eq.ndrec then
    return false
  if e.getAppNumArgs ≤ 6 then
    return false
  let altMap ← read
  for extraArg in e.getAppArgs[6:] do
    if extraArg.isFVar && altMap.contains extraArg.fvarId! then
      return true
  -- NOTE(review): the result is `true` even when no extra argument is a known alternative,
  -- so the loop above does not affect the outcome — confirm whether `false` was intended here.
  return true
/--
Auxiliary function used at `convertTemplate`. It is needed when the auxiliary `match` declaration had to refine the type of its
minor premises during dependent pattern match. For an example, consider
```
inductive Foo : Nat → Type _
  | nil : Foo 0
  | cons (t: Foo l): Foo l

def Foo.bar (t₁: Foo l₁): Foo l₂ → Bool
  | cons s₁ => t₁.bar s₁
  | _ => false

attribute [simp] Foo.bar
```
The auxiliary `Foo.bar.match_1` is of the form
```
def Foo.bar.match_1.{u_1} : {l₂ : Nat} →
  (t₂ : Foo l₂) →
    (motive : Foo l₂ → Sort u_1) →
      (t₂ : Foo l₂) → ((s₁ : Foo l₂) → motive (Foo.cons s₁)) → ((x : Foo l₂) → motive x) → motive t₂ :=
fun {l₂} t₂ motive t₂_1 h_1 h_2 =>
  (fun t₂_2 =>
      Foo.casesOn (motive := fun a x => l₂ = a → HEq t₂_1 x → motive t₂_1) t₂_2
        (fun h =>
          Eq.ndrec (motive := fun {l₂} =>
            (t₂ t₂ : Foo l₂) →
              (motive : Foo l₂ → Sort u_1) →
                ((s₁ : Foo l₂) → motive (Foo.cons s₁)) → ((x : Foo l₂) → motive x) → HEq t₂ Foo.nil → motive t₂)
            (fun t₂ t₂ motive h_1 h_2 h => Eq.symm (eq_of_heq h) ▸ h_2 Foo.nil) (Eq.symm h) t₂ t₂_1 motive h_1 h_2) --- HERE
        fun {l} t h =>
        Eq.ndrec (motive := fun {l} => (t : Foo l) → HEq t₂_1 (Foo.cons t) → motive t₂_1)
          (fun t h => Eq.symm (eq_of_heq h) ▸ h_1 t) h t)
    t₂_1 (Eq.refl l₂) (HEq.refl t₂_1)
```
The `HERE` comment marks the place where the type of `Foo.bar.match_1` minor premises `h_1` and `h_2` is being "refined"
using `Eq.ndrec`.

This function will adjust the motive and minor premise of the `Eq.ndrec` to reflect the new minor premises used in the
corresponding splitter theorem.

We may have to extend this function to handle `Eq.rec` too.

This function was added to address issue #1179
-/
convertCastEqRec (e : Expr) : ConvertM Expr := do
  assert! (← isCastEqRec e)
  e.withApp fun f args => do
    -- Step 1: process the extra arguments (positions `6..`). Known alternatives are replaced with the
    -- new alternative; every other argument is converted recursively. `isAlt[i]` records which case applied.
    let mut argsNew := args
    let mut isAlt := #[]
    for i in [6:args.size] do
      let arg := argsNew[i]
      if arg.isFVar then
        match (← read).find? arg.fvarId! with
        | some (altNew, _, _) =>
          argsNew := argsNew.set! i altNew
          trace[Meta.Match.matchEqs] "arg: {arg} : {← inferType arg}, altNew: {altNew} : {← inferType altNew}"
          isAlt := isAlt.push true
        | none =>
          argsNew := argsNew.set! i (← convertTemplate arg)
          isAlt := isAlt.push false
      else
        argsNew := argsNew.set! i (← convertTemplate arg)
        isAlt := isAlt.push false
    assert! isAlt.size == args.size - 6
    let rhs := args[4]
    let motive := args[2]
    -- Step 2:
    -- Construct new motive using the splitter theorem minor premise types.
    let motiveNew ← lambdaTelescope motive fun motiveArgs body => do
      unless motiveArgs.size == 1 do
        throwError "unexpected `Eq.ndrec` motive while creating splitter/eliminator theorem for `{matchDeclName}`, expected lambda with 1 binder{indentExpr motive}"
      let x := motiveArgs[0]
      forallTelescopeReducing body fun motiveTypeArgs resultType => do
        unless motiveTypeArgs.size >= isAlt.size do
          throwError "unexpected `Eq.ndrec` motive while creating splitter/eliminator theorem for `{matchDeclName}`, expected arrow with at least #{isAlt.size} binders{indentExpr body}"
        -- Rebuild the binders of the motive body one by one; binders that correspond to alternatives
        -- get the type of the new alternative (abstracted over `rhs` and the previous extra arguments).
        let rec go (i : Nat) (motiveTypeArgsNew : Array Expr) : ConvertM Expr := do
          assert! motiveTypeArgsNew.size == i
          if h : i < motiveTypeArgs.size then
            let motiveTypeArg := motiveTypeArgs.get ⟨i, h⟩
            if i < isAlt.size && isAlt[i] then
              let altNew := argsNew[6+i] -- Recall that `Eq.ndrec` has 6 arguments
              let altTypeNew ← inferType altNew
              trace[Meta.Match.matchEqs] "altNew: {altNew} : {altTypeNew}"
              -- Replace `rhs` with `x` (the lambda binder in the motive)
              let mut altTypeNewAbst := (← kabstract altTypeNew rhs).instantiate1 x
              -- Replace args[6:6+i] with `motiveTypeArgsNew`
              for j in [:i] do
                altTypeNewAbst := (← kabstract altTypeNewAbst argsNew[6+j]).instantiate1 motiveTypeArgsNew[j]
              let localDecl ← getLocalDecl motiveTypeArg.fvarId!
              withLocalDecl localDecl.userName localDecl.binderInfo altTypeNewAbst fun motiveTypeArgNew =>
                go (i+1) (motiveTypeArgsNew.push motiveTypeArgNew)
            else
              go (i+1) (motiveTypeArgsNew.push motiveTypeArg)
          else
            mkLambdaFVars motiveArgs (← mkForallFVars motiveTypeArgsNew resultType)
        go 0 #[]
    trace[Meta.Match.matchEqs] "new motive: {motiveNew}"
    unless (← isTypeCorrect motiveNew) do
      throwError "failed to construct new type correct motive for `Eq.ndrec` while creating splitter/eliminator theorem for `{matchDeclName}`{indentExpr motiveNew}"
    argsNew := argsNew.set! 2 motiveNew
    -- Step 3:
    -- Construct the new minor premise for the `Eq.ndrec` application.
    -- First, we use `eqRecNewPrefix` to infer the new minor premise binders for `Eq.ndrec`
    let eqRecNewPrefix := mkAppN f argsNew[:3] -- `Eq.ndrec` minor premise is the fourth argument.
    let .forallE _ minorTypeNew .. ← whnf (← inferType eqRecNewPrefix) | unreachable!
    trace[Meta.Match.matchEqs] "new minor type: {minorTypeNew}"
    let minor := args[3]
    let minorNew ← forallBoundedTelescope minorTypeNew isAlt.size fun minorArgsNew _ => do
      let mut minorBodyNew := minor
      -- We have to extend the mapping to make sure `convertTemplate` can "fix" occurrences of the refined minor premises
      let mut m ← read
      for i in [:isAlt.size] do
        if isAlt[i] then
          -- `convertTemplate` will correct occurrences of the alternative
          let alt := args[6+i] -- Recall that `Eq.ndrec` has 6 arguments
          let some (_, numParams, argMask) := m.find? alt.fvarId! | unreachable!
          -- We add a new entry to `m` to make sure `convertTemplate` will correct the occurrences of the alternative
          m := m.insert minorArgsNew[i].fvarId! (minorArgsNew[i], numParams, argMask)
        -- Strip one lambda binder of the old minor premise per extra argument.
        unless minorBodyNew.isLambda do
          throwError "unexpected `Eq.ndrec` minor premise while creating splitter/eliminator theorem for `{matchDeclName}`, expected lambda with at least #{isAlt.size} binders{indentExpr minor}"
        minorBodyNew := minorBodyNew.bindingBody!
      minorBodyNew := minorBodyNew.instantiateRev minorArgsNew
      trace[Meta.Match.matchEqs] "minor premise new body before convertTemplate:{indentExpr minorBodyNew}"
      minorBodyNew ← withReader (fun _ => m) <| convertTemplate minorBodyNew
      trace[Meta.Match.matchEqs] "minor premise new body after convertTemplate:{indentExpr minorBodyNew}"
      mkLambdaFVars minorArgsNew minorBodyNew
    unless (← isTypeCorrect minorNew) do
      throwError "failed to construct new type correct minor premise for `Eq.ndrec` while creating splitter/eliminator theorem for `{matchDeclName}`{indentExpr minorNew}"
    argsNew := argsNew.set! 3 minorNew
    -- trace[Meta.Match.matchEqs] "argsNew: {argsNew}"
    trace[Meta.Match.matchEqs] "found cast target {e}"
    return mkAppN f argsNew
/--
Rewrite the template `e`: `Eq.ndrec` casts accepted by `isCastEqRec` are handled by `convertCastEqRec`,
and every application whose head is an original alternative is replaced with an application of the
corresponding new alternative. Arguments are filtered with the argument mask, and the remaining
parameters of the new alternative become fresh synthetic opaque metavariables that are recorded
in the state. -/
convertTemplate (e : Expr) : ConvertM Expr :=
  transform e fun e => do
    if (← isCastEqRec e) then
      return .done (← convertCastEqRec e)
    let head := e.getAppFn
    unless head.isFVar do
      return TransformStep.visit e
    let some (altNew, numParams, argMask) := (← read).find? head.fvarId! | return TransformStep.visit e
    trace[Meta.Match.matchEqs] ">> argMask: {argMask}, e: {e}, {altNew}"
    let trimmedMask := trimFalseTrail argMask
    unless e.getAppNumArgs ≥ trimmedMask.size do
      throwError "unexpected occurrence of `match`-expression alternative (aka minor premise) while creating splitter/eliminator theorem for `{matchDeclName}`, minor premise is partially applied{indentExpr e}\npossible solution if you are matching on inductive families: add its indices as additional discriminants"
    -- Keep only the arguments selected by the mask.
    let mut keptArgs := #[]
    for arg in e.getAppArgs, includeArg in trimmedMask do
      if includeArg then
        keptArgs := keptArgs.push arg
    let partialApp := mkAppN altNew keptArgs
    /- Recall that `numParams` does not include the equalities associated with discriminants of the form `h : discr`. -/
    let (mvars, _, _) ← forallMetaBoundedTelescope (← inferType partialApp) (numParams - keptArgs.size) (kind := MetavarKind.syntheticOpaque)
    modify fun pending => pending ++ (mvars.map fun mvar => mvar.mvarId!)
    return TransformStep.done (mkAppN partialApp mvars)
/--
Close `mvarId` by repeatedly applying `injectionAny` and, when it fails, `substVars`.
When neither makes progress, `contradictionCore` is tried and an error is thrown if it fails too. -/
proveSubgoalLoop (mvarId : MVarId) : MetaM Unit := do
  trace[Meta.Match.matchEqs] "proveSubgoalLoop\n{mvarId}"
  match (← injectionAny mvarId) with
  | InjectionAnyResult.solved => return ()
  | InjectionAnyResult.subgoal mvarIdNew => proveSubgoalLoop mvarIdNew
  | InjectionAnyResult.failed =>
    let substituted ← substVars mvarId
    if substituted != mvarId then
      proveSubgoalLoop substituted
    else
      unless (← contradictionCore mvarId {}) do
        throwError "failed to generate splitter for match auxiliary declaration '{matchDeclName}', unsolved subgoal:\n{MessageData.ofGoal mvarId}"
/--
Prove one side goal created by `convertTemplate`: introduce its hypotheses, try to clear the
original alternatives from the context, and run `proveSubgoalLoop`. -/
proveSubgoal (mvarId : MVarId) : MetaM Unit := do
  trace[Meta.Match.matchEqs] "subgoal {mkMVar mvarId}, {repr (← getMVarDecl mvarId).kind}, {← isExprMVarAssigned mvarId}\n{MessageData.ofGoal mvarId}"
  let (_, introduced) ← intros mvarId
  let oldAltIds := alts.map fun alt => alt.fvarId!
  let cleared ← tryClearMany introduced oldAltIds
  proveSubgoalLoop cleared
/--
Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
Recall that `alts` depends on `discrs` when `numDiscrEqs > 0`, where `numDiscrEqs` is the number of discriminants
annotated with `h : discr`.
When `numDiscrEqs == 0`, `k` is applied to the original `alts`.
-/
private partial def withNewAlts (numDiscrEqs : Nat) (discrs : Array Expr) (patterns : Array Expr) (alts : Array Expr) (k : Array Expr → MetaM α) : MetaM α :=
  if numDiscrEqs != 0 then
    go 0 #[]
  else
    k alts
where
  /-- Declare a replacement for each alternative from position `pos` on, then call `k` on the accumulated list. -/
  go (pos : Nat) (acc : Array Expr) : MetaM α := do
    if inBounds : pos < alts.size then
      let decl ← getFVarLocalDecl (alts.get ⟨pos, inBounds⟩)
      let refinedType := decl.type.replaceFVars discrs patterns
      withLocalDecl decl.userName decl.binderInfo refinedType fun altNew =>
        go (pos+1) (acc.push altNew)
    else
      k acc
/--
Create conditional equations and splitter for the given match auxiliary declaration.

For each alternative, an equation theorem named `eq_<idx>` (under a private name derived from
`matchDeclName`) is proved with `proveCondEqThm` and added to the environment. Afterwards a
`splitter` definition is added and compiled, its proof built by `mkSplitterProof`, and the result
is recorded with `registerMatchEqns`. -/
private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := withLCtx {} {} do
  trace[Meta.Match.matchEqs] "mkEquationsFor '{matchDeclName}'"
  withConfig (fun c => { c with etaStruct := .none }) do
    let baseName := mkPrivateName (← getEnv) matchDeclName
    let constInfo ← getConstInfo matchDeclName
    let us := constInfo.levelParams.map mkLevelParam
    let some matchInfo ← getMatcherInfo? matchDeclName | throwError "'{matchDeclName}' is not a matcher function"
    let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos
    forallTelescopeReducing constInfo.type fun xs matchResultType => do
      let mut eqnNames := #[]
      -- `xs` is split as: parameters, motive, discriminants, ..., and the alternatives at the end.
      let params := xs[:matchInfo.numParams]
      let motive := xs[matchInfo.getMotivePos]
      let alts := xs[xs.size - matchInfo.numAlts:]
      let firstDiscrIdx := matchInfo.numParams + 1
      let discrs := xs[firstDiscrIdx : firstDiscrIdx + matchInfo.numDiscrs]
      -- `notAlts[j]` is the proposition stating that the discriminants do not match alternative `j`.
      let mut notAlts := #[]
      -- 1-based counter used for theorem names; incremented once per loop iteration.
      let mut idx := 1
      let mut splitterAltTypes := #[]
      let mut splitterAltNumParams := #[]
      let mut altArgMasks := #[] -- masks produced by `forallAltTelescope`
      for i in [:alts.size] do
        let altNumParams := matchInfo.altNumParams[i]
        let altNonEqNumParams := altNumParams - numDiscrEqs
        let thmName := baseName ++ ((`eq).appendIndexAfter idx)
        eqnNames := eqnNames.push thmName
        let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← forallAltTelescope (← inferType alts[i]) altNonEqNumParams fun ys eqs rhsArgs argMask altResultType => do
          let patterns := altResultType.getAppArgs
          -- Hypotheses stating that the previous alternatives do not apply (simplified, redundant ones dropped).
          let mut hs := #[]
          for notAlt in notAlts do
            let h ← instantiateForall notAlt patterns
            if let some h ← simpH? h patterns.size then
              hs := hs.push h
          trace[Meta.Match.matchEqs] "hs: {hs}"
          let splitterAltType ← mkForallFVars ys (← hs.foldrM (init := (← mkForallFVars eqs altResultType)) mkArrow)
          let splitterAltNumParam := hs.size + ys.size
          -- Create a proposition for representing terms that do not match `patterns`
          let mut notAlt := mkConst ``False
          for discr in discrs.toArray.reverse, pattern in patterns.reverse do
            notAlt ← mkArrow (← mkEqHEq discr pattern) notAlt
          notAlt ← mkForallFVars (discrs ++ ys) notAlt
          /- Recall that when we use the `h : discr`, the alternative type depends on the discriminant.
             Thus, we need to create new `alts`. -/
          withNewAlts numDiscrEqs discrs patterns alts fun alts => do
            -- From here on `alts` refers to the (possibly refined) alternatives supplied by `withNewAlts`.
            let alt := alts[i]
            let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts)
            let rhs := mkAppN alt rhsArgs
            let thmType ← mkEq lhs rhs
            let thmType ← hs.foldrM (init := thmType) mkArrow
            let thmType ← mkForallFVars (params ++ #[motive] ++ ys ++ alts) thmType
            let thmType ← unfoldNamedPattern thmType
            let thmVal ← proveCondEqThm matchDeclName thmType
            addDecl <| Declaration.thmDecl {
              name := thmName
              levelParams := constInfo.levelParams
              type := thmType
              value := thmVal
            }
            return (notAlt, splitterAltType, splitterAltNumParam, argMask)
        notAlts := notAlts.push notAlt
        splitterAltTypes := splitterAltTypes.push splitterAltType
        splitterAltNumParams := splitterAltNumParams.push splitterAltNumParam
        altArgMasks := altArgMasks.push argMask
        trace[Meta.Match.matchEqs] "splitterAltType: {splitterAltType}"
        idx := idx + 1
      -- Define splitter with conditional/refined alternatives
      withSplitterAlts splitterAltTypes fun altsNew => do
        let splitterParams := params.toArray ++ #[motive] ++ discrs.toArray ++ altsNew
        let splitterType ← mkForallFVars splitterParams matchResultType
        trace[Meta.Match.matchEqs] "splitterType: {splitterType}"
        -- The proof template is the body of the matcher applied to its own variables.
        let template := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ discrs ++ alts)
        let template ← deltaExpand template (· == constInfo.name)
        let template := template.headBeta
        let splitterVal ← mkLambdaFVars splitterParams (← mkSplitterProof matchDeclName template alts altsNew splitterAltNumParams altArgMasks)
        let splitterName := baseName ++ `splitter
        addAndCompile <| Declaration.defnDecl {
          name := splitterName
          levelParams := constInfo.levelParams
          type := splitterType
          value := splitterVal
          hints := .abbrev
          safety := .safe
        }
        setInlineAttribute splitterName
        let result := { eqnNames, splitterName, splitterAltNumParams }
        registerMatchEqns matchDeclName result
        return result
/- See header at `MatchEqsExt.lean`.
   Returns the equations recorded in `matchEqnsExt` for `matchDeclName`, generating them with
   `mkEquationsFor` when there is no entry yet. -/
@[export lean_get_match_equations_for]
def getEquationsForImpl (matchDeclName : Name) : MetaM MatchEqns := do
  let known := matchEqnsExt.getState (← getEnv) |>.map
  if let some matchEqns := known.find? matchDeclName then
    return matchEqns
  mkEquationsFor matchDeclName
-- Register the trace class used by the `trace[Meta.Match.matchEqs]` messages in this file.
builtin_initialize registerTraceClass `Meta.Match.matchEqs
end Lean.Meta.Match