feat: match equation theorem generation for new h : discr notation encoding

TODO: splitter theorem generation still needs to be fixed.
This commit is contained in:
Leonardo de Moura 2022-04-29 11:14:22 -07:00
parent 24417ed466
commit 89441aac2a
3 changed files with 88 additions and 37 deletions

View file

@ -779,7 +779,7 @@ where `v` is a universe parameter or 0 if `B[a_1, ..., a_n]` is a proposition. -
def mkMatcher (input : MkMatcherInput) : MetaM MatcherResult := do
let ⟨matcherName, matchType, discrInfos, lhss⟩ := input
let numDiscrs := discrInfos.size
let numEqs := (discrInfos.filter fun info => info.hName?.isSome).size
let numEqs := getNumEqsFromDiscrInfos discrInfos
checkNumPatterns numDiscrs lhss
forallBoundedTelescope matchType numDiscrs fun discrs matchTypeBody => do
/- We generate an matcher that can eliminate using different motives with different universe levels.

View file

@ -62,8 +62,15 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do
Meta.transform e (pre := visit)
/--
Similar to `forallTelescopeReducing`, but eliminates arguments for named parameters and the associated
equation proofs. The continuation `k` takes four arguments `ys args mask type`.
Similar to `forallTelescopeReducing`, but
1. Eliminates arguments for named parameters and the associated equation proofs.
2. Equality parameters associated with the `h : discr` notation are replaced with `rfl` proofs.
Recall that this kind of parameter always occurs after the parameters correspoting to pattern variables.
`numNonEqParams` is the size of the prefix.
The continuation `k` takes four arguments `ys args mask type`.
- `ys` are variables for the hypotheses that have not been eliminated.
- `args` are the arguments for the alternative `alt` that has type `altType`. `ys.size <= args.size`
- `mask[i]` is true if the hypotheses has not been eliminated. `mask.size == args.size`.
@ -71,26 +78,31 @@ def unfoldNamedPattern (e : Expr) : MetaM Expr := do
We use the `mask` to build the splitter proof. See `mkSplitterProof`.
-/
partial def forallAltTelescope (altType : Expr) (k : Array Expr → Array Expr → Array Bool → Expr → MetaM α) : MetaM α := do
go #[] #[] #[] altType
partial def forallAltTelescope (altType : Expr) (numNonEqParams : Nat) (k : Array Expr → Array Expr → Array Bool → Expr → MetaM α) : MetaM α := do
go #[] #[] #[] 0 altType
where
go (ys : Array Expr) (args : Array Expr) (mask : Array Bool) (type : Expr) : MetaM α := do
go (ys : Array Expr) (args : Array Expr) (mask : Array Bool) (i : Nat) (type : Expr) : MetaM α := do
let type ← whnfForall type
match type with
| Expr.forallE n d b .. =>
let d ← unfoldNamedPattern d
withLocalDeclD n d fun y => do
let typeNew := b.instantiate1 y
if let some (_, lhs, rhs) ← matchEq? d then
if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then
let some i := ys.getIdx? lhs | unreachable!
let ys := ys.eraseIdx i
let mask := mask.set! i false
let args := args.map fun arg => if arg == lhs then rhs else arg
let args := args.push (← mkEqRefl rhs)
let typeNew := typeNew.replaceFVar lhs rhs
return (← go ys args (mask.push false) typeNew)
go (ys.push y) (args.push y) (mask.push true) typeNew
if i < numNonEqParams then
let d ← unfoldNamedPattern d
withLocalDeclD n d fun y => do
let typeNew := b.instantiate1 y
if let some (_, lhs, rhs) ← matchEq? d then
if lhs.isFVar && ys.contains lhs && args.contains lhs && isNamedPatternProof typeNew y then
let some i := ys.getIdx? lhs | unreachable!
let ys := ys.eraseIdx i
let mask := mask.set! i false
let args := args.map fun arg => if arg == lhs then rhs else arg
let args := args.push (← mkEqRefl rhs)
let typeNew := typeNew.replaceFVar lhs rhs
return (← go ys args (mask.push false) (i+1) typeNew)
go (ys.push y) (args.push y) (mask.push true) (i+1) typeNew
else
let some (_, _, rhs) ← matchEq? d | throwError "unexpected match alternative type{indentExpr altType}"
let arg ← mkEqRefl rhs
go ys (args.push arg) (mask.push false) (i+1) (b.instantiate1 arg)
| _ =>
let type ← unfoldNamedPattern type
/- Recall that alternatives that do not have variables have a `Unit` parameter to ensure
@ -258,16 +270,17 @@ private def substSomeVar (mvarId : MVarId) : MetaM (Array MVarId) := withMVarCon
/--
Helper method for proving a conditional equational theorem associated with an alternative of
the `match`-eliminator `matchDeclName`. `type` contains the type of the theorem. -/
partial def proveCondEqThm (matchDeclName : Name) (type : Expr) : MetaM Expr := do
partial def proveCondEqThm (matchDeclName : Name) (type : Expr) : MetaM Expr := withLCtx {} {} do
let type ← instantiateMVars type
forallTelescope type fun ys target => do
let mvar0 ← mkFreshExprSyntheticOpaqueMVar target
trace[Meta.Match.matchEqs] "proveCondEqThm {mvar0.mvarId!}"
let mvarId ← deltaTarget mvar0.mvarId! (· == matchDeclName)
trace[Meta.Match.matchEqs] "{MessageData.ofGoal mvarId}"
withDefault <| go mvarId 0
mkLambdaFVars ys (← instantiateMVars mvar0)
where
go (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do
trace[Meta.Match.matchEqs] "proveCondEqThm.go {mvarId}"
let mvarId' ← modifyTargetEqLHS mvarId whnfCore
let mvarId := mvarId'
let subgoals ←
@ -395,6 +408,27 @@ where
let mvarId ← tryClearMany mvarId (alts.map (·.fvarId!))
proveSubgoalLoop mvarId
/--
Create new alternatives (aka minor premises) by replacing `discrs` with `patterns` at `alts`.
Recall that `alts` depends on `discrs` when `numDiscrEqs > 0`, where `numDiscrEqs` is the number of discriminants
annotated with `h : discr`.
-/
private partial def withNewAlts (numDiscrEqs : Nat) (discrs : Array Expr) (patterns : Array Expr) (alts : Array Expr) (k : Array Expr → MetaM α) : MetaM α :=
if numDiscrEqs == 0 then
k alts
else
go 0 #[]
where
go (i : Nat) (altsNew : Array Expr) : MetaM α := do
if h : i < alts.size then
let alt := alts.get ⟨i, h⟩
let altLocalDecl ← getFVarLocalDecl alt
let typeNew := altLocalDecl.type.replaceFVars discrs patterns
withLocalDecl altLocalDecl.userName altLocalDecl.binderInfo typeNew fun altNew =>
go (i+1) (altsNew.push altNew)
else
k altsNew
/--
Create conditional equations and splitter for the given match auxiliary declaration. -/
private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns := withLCtx {} {} do
@ -404,6 +438,7 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns :=
let constInfo ← getConstInfo matchDeclName
let us := constInfo.levelParams.map mkLevelParam
let some matchInfo ← getMatcherInfo? matchDeclName | throwError "'{matchDeclName}' is not a matcher function"
let numDiscrEqs := getNumEqsFromDiscrInfos matchInfo.discrInfos
forallTelescopeReducing constInfo.type fun xs matchResultType => do
let mut eqnNames := #[]
let params := xs[:matchInfo.numParams]
@ -416,10 +451,12 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns :=
let mut splitterAltTypes := #[]
let mut splitterAltNumParams := #[]
let mut altArgMasks := #[] -- masks produced by `forallAltTelescope`
for alt in alts do
for i in [:alts.size] do
let altNumParams := matchInfo.altNumParams[i]
let altNonEqNumParams := altNumParams - numDiscrEqs
let thmName := baseName ++ ((`eq).appendIndexAfter idx)
eqnNames := eqnNames.push thmName
let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← forallAltTelescope (← inferType alt) fun ys rhsArgs argMask altResultType => do
let (notAlt, splitterAltType, splitterAltNumParam, argMask) ← forallAltTelescope (← inferType alts[i]) altNonEqNumParams fun ys rhsArgs argMask altResultType => do
let patterns := altResultType.getAppArgs
let mut hs := #[]
for notAlt in notAlts do
@ -437,20 +474,24 @@ private partial def mkEquationsFor (matchDeclName : Name) : MetaM MatchEqns :=
else
notAlt ← mkArrow (← mkHEq discr pattern) notAlt
notAlt ← mkForallFVars (discrs ++ ys) notAlt
let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts)
let rhs := mkAppN alt rhsArgs
let thmType ← mkEq lhs rhs
let thmType ← hs.foldrM (init := thmType) mkArrow
let thmType ← mkForallFVars (params ++ #[motive] ++ alts ++ ys) thmType
let thmType ← unfoldNamedPattern thmType
let thmVal ← proveCondEqThm matchDeclName thmType
addDecl <| Declaration.thmDecl {
name := thmName
levelParams := constInfo.levelParams
type := thmType
value := thmVal
}
return (notAlt, splitterAltType, splitterAltNumParam, argMask)
/- Recall that when we use the `h : discr`, the alternative type depends on the discriminant.
Thus, we need to create new `alts`. -/
withNewAlts numDiscrEqs discrs patterns alts fun alts => do
let alt := alts[i]
let lhs := mkAppN (mkConst constInfo.name us) (params ++ #[motive] ++ patterns ++ alts)
let rhs := mkAppN alt rhsArgs
let thmType ← mkEq lhs rhs
let thmType ← hs.foldrM (init := thmType) mkArrow
let thmType ← mkForallFVars (params ++ #[motive] ++ ys ++ alts) thmType
let thmType ← unfoldNamedPattern thmType
let thmVal ← proveCondEqThm matchDeclName thmType
addDecl <| Declaration.thmDecl {
name := thmName
levelParams := constInfo.levelParams
type := thmType
value := thmVal
}
return (notAlt, splitterAltType, splitterAltNumParam, argMask)
notAlts := notAlts.push notAlt
splitterAltTypes := splitterAltTypes.push splitterAltType
splitterAltNumParams := splitterAltNumParams.push splitterAltNumParam

View file

@ -47,6 +47,16 @@ def MatcherInfo.getFirstAltPos (info : MatcherInfo) : Nat :=
def MatcherInfo.getMotivePos (info : MatcherInfo) : Nat :=
info.numParams
def getNumEqsFromDiscrInfos (infos : Array DiscrInfo) : Nat := Id.run do
let mut r := 0
for info in infos do
if info.hName?.isSome then
r := r + 1
return r
def MatcherInfo.getNumDiscrEqs (info : MatcherInfo) : Nat :=
getNumEqsFromDiscrInfos info.discrInfos
namespace Extension
structure Entry where