feat: match equation theorem generation for new h : discr notation encoding
TODO: splitter theorem generation still needs to be fixed.
This commit is contained in:
parent
24417ed466
commit
89441aac2a
3 changed files with 88 additions and 37 deletions
|
|
@ -779,7 +779,7 @@ where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition. -
|
|||
def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := do
|
||||
let ⟨matcherName, matchType, discrInfos, lhss⟩ := input
|
||||
let numDiscrs := discrInfos.size
|
||||
let numEqs := (discrInfos.filter fun info => info.hName?.isSome).size
|
||||
let numEqs := getNumEqsFromDiscrInfos discrInfos
|
||||
checkNumPatterns numDiscrs lhss
|
||||
forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do
|
||||
/- We generate an matcher that can eliminate using different motives with different universe levels.
|
||||
|
|
|
|||
|
|
@ -62,8 +62,15 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do
|
|||
Meta.transform e (pre := visit)
|
||||
|
||||
/--
|
||||
Similar to `forallTelescopeReducing`, but eliminates arguments for named parameters and the associated
|
||||
equation proofs. The continuation `k` takes four arguments `ys args mask type`.
|
||||
Similar to `forallTelescopeReducing`, but
|
||||
|
||||
1. Eliminates arguments for named parameters and the associated equation proofs.
|
||||
|
||||
2. Equality parameters associated with the `h : discr` notation are replaced with `rfl` proofs.
|
||||
Recall that this kind of parameter always occurs after the parameters correspoting to pattern variables.
|
||||
`numNonEqParams` is the size of the prefix.
|
||||
|
||||
The continuation `k` takes four arguments `ys args mask type`.
|
||||
- `ys` are variables for the hypotheses that have not been eliminated.
|
||||
- `args` are the arguments for the alternative `alt` that has type `altType`. `ys.size <= args.size`
|
||||
- `mask[i]` is true if the hypotheses has not been eliminated. `mask.size == args.size`.
|
||||
|
|
@ -71,26 +78,31 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do
|
|||
|
||||
We use the `mask` to build the splitter proof. See `mkSplitterProof`.
|
||||
-/
|
||||
partial def forallAltTelescope (altType : Expr) (k : Array Expr → Array Expr → Array Bool → Expr → MetaM α) : MetaM α := do
|
||||
go #[] #[] #[] altType
|
||||
partial def forallAltTelescope (altType : Expr) (numNonEqParams : Nat) (k : Array Expr → Array Expr → Array Bool → Expr → MetaM α) : MetaM α := do
|
||||
go #[] #[] #[] 0 altType
|
||||
where
|
||||
go (ys : Array Expr) (args : Array Expr) (mask : Array Bool) (type : Expr) : MetaM α := do
|
||||
go (ys : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do
|
||||
let type ← whnfForall type
|
||||
match type with
|
||||
| Expr.forallE n d b .. =>
|
||||
let d ← unfoldNamedPattern d
|
||||
withLocalDeclD n d fun y => do
|
||||
let typeNew := b.instantiate1 y
|
||||
if let some (_, lhs, rhs) ← matchEq? d then
|
||||
if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then
|
||||
let some i := ys.getIdx? lhs | unreachable!
|
||||
let ys := ys.eraseIdx i
|
||||
let mask := mask.set! i false
|
||||
let args := args.map fun arg => if arg == lhs then rhs else arg
|
||||
let args := args.push (← mkEqRefl rhs)
|
||||
let typeNew := typeNew.replaceFVar lhs rhs
|
||||
return (← go ys args (mask.push false) typeNew)
|
||||
go (ys.push y) (args.push y) (mask.push true) typeNew
|
||||
if i < numNonEqParams then
|
||||
let d ← unfoldNamedPattern d
|
||||
withLocalDeclD n d fun y => do
|
||||
let typeNew := b.instantiate1 y
|
||||
if let some (_, lhs, rhs) ← matchEq? d then
|
||||
if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then
|
||||
let some i := ys.getIdx? lhs | unreachable!
|
||||
let ys := ys.eraseIdx i
|
||||
let mask := mask.set! i false
|
||||
let args := args.map fun arg => if arg == lhs then rhs else arg
|
||||
let args := args.push (← mkEqRefl rhs)
|
||||
let typeNew := typeNew.replaceFVar lhs rhs
|
||||
return (← go ys args (mask.push false) (i+1) typeNew)
|
||||
go (ys.push y) (args.push y) (mask.push true) (i+1) typeNew
|
||||
else
|
||||
let some (_, _, rhs) ← matchEq? d | throwError "unexpected match alternative type{indentExpr altType}"
|
||||
let arg ← mkEqRefl rhs
|
||||
go ys (args.push arg) (mask.push false) (i+1) (b.instantiate1 arg)
|
||||
| _ =>
|
||||
let type ← unfoldNamedPattern type
|
||||
/- Recall that alternatives that do not have variables have a `Unit` parameter to ensure
|
||||
|
|
@ -258,16 +270,17 @@ private def substSomeVar (mvarId : MVarId) : MetaM (Array MVarId) := withMVarCon
|
|||
/--
|
||||
Helper method for proving a conditional equational theorem associated with an alternative of
|
||||
the `match`-eliminator `matchDeclName`. `type` contains the type of the theorem. -/
|
||||
partial def proveCondEqThm (matchDeclName : Name) (type : Expr) : MetaM Expr := do
|
||||
partial def proveCondEqThm (matchDeclName : Name) (type : Expr) : MetaM Expr := withLCtx {} {} do
|
||||
let type ← instantiateMVars type
|
||||
forallTelescope type fun ys target => do
|
||||
let mvar0 ← mkFreshExprSyntheticOpaqueMVar target
|
||||
trace[Meta.Match.matchEqs] "proveCondEqThm {mvar0.mvarId!}"
|
||||
let mvarId ← deltaTarget mvar0.mvarId! (· == matchDeclName)
|
||||
trace[Meta.Match.matchEqs] "{MessageData.ofGoal mvarId}"
|
||||
withDefault <| go mvarId 0
|
||||
mkLambdaFVars ys (← instantiateMVars mvar0)
|
||||
where
|
||||
go (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do
|
||||
trace[Meta.Match.matchEqs] "proveCondEqThm.go {mvarId}"
|
||||
let mvarId' ← modifyTargetEqLHS mvarId whnfCore
|
||||
let mvarId := mvarId'
|
||||
let subgoals ←
|
||||
|
|
@ -395,6 +408,27 @@ where
|
|||
let mvarId ← tryClearMany mvarId (alts.map (·.fvarId!))
|
||||
proveSubgoalLoop mvarId
|
||||
|
||||
/--
|
||||
Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
|
||||
Recall that `alts` depends on `discrs` when `numDiscrEqs > 0`, where `numDiscrEqs` is the number of discriminants
|
||||
annotated with `h : discr`.
|
||||
-/
|
||||
private partial def withNewAlts (numDiscrEqs : Nat) (discrs : Array Expr) (patterns : Array Expr) (alts : Array Expr) (k : Array Expr → MetaM α) : MetaM α :=
|
||||
if numDiscrEqs == 0 then
|
||||
k alts
|
||||
else
|
||||
go 0 #[]
|
||||
where
|
||||
go (i : Nat) (altsNew : Array Expr) : MetaM α := do
|
||||
if h : i < alts.size then
|
||||
let alt := alts.get ⟨i, h⟩
|
||||
let altLocalDecl ← getFVarLocalDecl alt
|
||||
let typeNew := altLocalDecl.type.replaceFVars discrs patterns
|
||||
withLocalDecl altLocalDecl.userName altLocalDecl.binderInfo typeNew fun altNew =>
|
||||
go (i+1) (altsNew.push altNew)
|
||||
else
|
||||
k altsNew
|
||||
|
||||
/--
|
||||
Create conditional equations and splitter for the given match auxiliary declaration. -/
|
||||
private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := withLCtx {} {} do
|
||||
|
|
@ -404,6 +438,7 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns :=
|
|||
let constInfo ← getConstInfo matchDeclName
|
||||
let us := constInfo.levelParams.map mkLevelParam
|
||||
let some matchInfo ← getMatcherInfo? matchDeclName | throwError "'{matchDeclName}' is not a matcher function"
|
||||
let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos
|
||||
forallTelescopeReducing constInfo.type fun xs matchResultType => do
|
||||
let mut eqnNames := #[]
|
||||
let params := xs[:matchInfo.numParams]
|
||||
|
|
@ -416,10 +451,12 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns :=
|
|||
let mut splitterAltTypes := #[]
|
||||
let mut splitterAltNumParams := #[]
|
||||
let mut altArgMasks := #[] -- masks produced by `forallAltTelescope`
|
||||
for alt in alts do
|
||||
for i in [:alts.size] do
|
||||
let altNumParams := matchInfo.altNumParams[i]
|
||||
let altNonEqNumParams := altNumParams - numDiscrEqs
|
||||
let thmName := baseName ++ ((`eq).appendIndexAfter idx)
|
||||
eqnNames := eqnNames.push thmName
|
||||
let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← forallAltTelescope (← inferType alt) fun ys rhsArgs argMask altResultType => do
|
||||
let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← forallAltTelescope (← inferType alts[i]) altNonEqNumParams fun ys rhsArgs argMask altResultType => do
|
||||
let patterns := altResultType.getAppArgs
|
||||
let mut hs := #[]
|
||||
for notAlt in notAlts do
|
||||
|
|
@ -437,20 +474,24 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns :=
|
|||
else
|
||||
notAlt ← mkArrow (← mkHEq discr pattern) notAlt
|
||||
notAlt ← mkForallFVars (discrs ++ ys) notAlt
|
||||
let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts)
|
||||
let rhs := mkAppN alt rhsArgs
|
||||
let thmType ← mkEq lhs rhs
|
||||
let thmType ← hs.foldrM (init := thmType) mkArrow
|
||||
let thmType ← mkForallFVars (params ++ #[motive] ++ alts ++ ys) thmType
|
||||
let thmType ← unfoldNamedPattern thmType
|
||||
let thmVal ← proveCondEqThm matchDeclName thmType
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name := thmName
|
||||
levelParams := constInfo.levelParams
|
||||
type := thmType
|
||||
value := thmVal
|
||||
}
|
||||
return (notAlt, splitterAltType, splitterAltNumParam, argMask)
|
||||
/- Recall that when we use the `h : discr`, the alternative type depends on the discriminant.
|
||||
Thus, we need to create new `alts`. -/
|
||||
withNewAlts numDiscrEqs discrs patterns alts fun alts => do
|
||||
let alt := alts[i]
|
||||
let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts)
|
||||
let rhs := mkAppN alt rhsArgs
|
||||
let thmType ← mkEq lhs rhs
|
||||
let thmType ← hs.foldrM (init := thmType) mkArrow
|
||||
let thmType ← mkForallFVars (params ++ #[motive] ++ ys ++ alts) thmType
|
||||
let thmType ← unfoldNamedPattern thmType
|
||||
let thmVal ← proveCondEqThm matchDeclName thmType
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name := thmName
|
||||
levelParams := constInfo.levelParams
|
||||
type := thmType
|
||||
value := thmVal
|
||||
}
|
||||
return (notAlt, splitterAltType, splitterAltNumParam, argMask)
|
||||
notAlts := notAlts.push notAlt
|
||||
splitterAltTypes := splitterAltTypes.push splitterAltType
|
||||
splitterAltNumParams := splitterAltNumParams.push splitterAltNumParam
|
||||
|
|
|
|||
|
|
@ -47,6 +47,16 @@ def MatcherInfo.getFirstAltPos (info : MatcherInfo) : Nat :=
|
|||
def MatcherInfo.getMotivePos (info : MatcherInfo) : Nat :=
|
||||
info.numParams
|
||||
|
||||
def getNumEqsFromDiscrInfos (infos : Array DiscrInfo) : Nat := Id.run do
|
||||
let mut r := 0
|
||||
for info in infos do
|
||||
if info.hName?.isSome then
|
||||
r := r + 1
|
||||
return r
|
||||
|
||||
def MatcherInfo.getNumDiscrEqs (info : MatcherInfo) : Nat :=
|
||||
getNumEqsFromDiscrInfos info.discrInfos
|
||||
|
||||
namespace Extension
|
||||
|
||||
structure Entry where
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue