feat: generate conditional equation theorems for match expressions
This commit is contained in:
parent
37f2f7d472
commit
49520aa2ee
2 changed files with 59 additions and 37 deletions
|
|
@ -38,6 +38,7 @@ partial def mkEquationsFor (matchDeclName : Name) : MetaM Unit := do
|
|||
let firstDiscrIdx := matchInfo.numParams + 1
|
||||
let discrs := xs[firstDiscrIdx : firstDiscrIdx + matchInfo.numDiscrs]
|
||||
let mut notAlts := #[]
|
||||
let mut idx := 1
|
||||
for alt in alts do
|
||||
let altType ← inferType alt
|
||||
trace[Meta.debug] ">> {altType}"
|
||||
|
|
@ -62,8 +63,15 @@ partial def mkEquationsFor (matchDeclName : Name) : MetaM Unit := do
|
|||
let thmType ← mkForallFVars (params ++ #[motive] ++ alts ++ ys) thmType
|
||||
let thmVal ← prove thmType
|
||||
trace[Meta.debug] "thmVal: {thmVal}"
|
||||
-- check thmVal -- TODO remove
|
||||
let thmName := matchDeclName ++ ((`eq).appendIndexAfter idx)
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name := thmName
|
||||
levelParams := constInfo.levelParams
|
||||
type := thmType
|
||||
value := thmVal
|
||||
}
|
||||
return notAlts.push notAlt
|
||||
idx := idx + 1
|
||||
where
|
||||
toFVarsRHSArgs (ys : Array Expr) : MetaM (Array Expr × Array Expr) := do
|
||||
if ys.size == 1 && (← inferType ys[0]).isConstOf ``Unit then
|
||||
|
|
@ -116,29 +124,31 @@ where
|
|||
none
|
||||
|
||||
proveLoop (mvarId : MVarId) (depth : Nat) : MetaM Unit := withIncRecDepth do
|
||||
let mvarId ← modifyTargetEqLHS mvarId whnfCore
|
||||
let mvarId' ← modifyTargetEqLHS mvarId whnfCore
|
||||
let mvarId := mvarId'
|
||||
trace[Meta.debug] "proveLoop [{depth}]\n{MessageData.ofGoal mvarId}"
|
||||
(applyRefl mvarId)
|
||||
<|>
|
||||
(contradiction mvarId { genDiseq := true })
|
||||
<|>
|
||||
(do (← casesOnStuckLHS mvarId).forM (proveLoop . (depth + 1)))
|
||||
<|>
|
||||
(do let mvarId' ← simpIfTarget mvarId (useDecide := true)
|
||||
trace[Meta.debug] "simpIfTarget\n{MessageData.ofGoal mvarId'}"
|
||||
if mvarId' == mvarId then throwError "simpIf failed"
|
||||
proveLoop mvarId' (depth+1))
|
||||
<|>
|
||||
(do if let some (s₁, s₂) ← splitIfTarget? mvarId then
|
||||
let mvarId₁ ← trySubst s₁.mvarId s₁.fvarId
|
||||
proveLoop mvarId₁ (depth+1)
|
||||
proveLoop s₂.mvarId (depth+1)
|
||||
else
|
||||
throwError "spliIf failed")
|
||||
<|>
|
||||
(throwError "failed to generate equality theorems for `match` expression, support for array literals has not been implemented yet{MessageData.ofGoal mvarId}")
|
||||
let subgoals ←
|
||||
(do applyRefl mvarId; return #[])
|
||||
<|>
|
||||
(do contradiction mvarId { genDiseq := true }; return #[])
|
||||
<|>
|
||||
(casesOnStuckLHS mvarId)
|
||||
<|>
|
||||
(do let mvarId' ← simpIfTarget mvarId (useDecide := true)
|
||||
if mvarId' == mvarId then throwError "simpIf failed"
|
||||
return #[mvarId'])
|
||||
<|>
|
||||
(do if let some (s₁, s₂) ← splitIfTarget? mvarId then
|
||||
let mvarId₁ ← trySubst s₁.mvarId s₁.fvarId
|
||||
return #[mvarId₁, s₂.mvarId]
|
||||
else
|
||||
throwError "spliIf failed")
|
||||
<|>
|
||||
(throwError "failed to generate equality theorems for `match` expression, support for array literals has not been implemented yet{MessageData.ofGoal mvarId}")
|
||||
subgoals.forM (proveLoop . (depth+1))
|
||||
|
||||
prove (type : Expr) : MetaM Expr :=
|
||||
prove (type : Expr) : MetaM Expr := do
|
||||
let type ← instantiateMVars type
|
||||
withLCtx {} {} <| forallTelescope type fun ys target => do
|
||||
let mvar0 ← mkFreshExprSyntheticOpaqueMVar target
|
||||
let mvarId ← deltaTarget mvar0.mvarId! (. == matchDeclName)
|
||||
|
|
|
|||
|
|
@ -27,22 +27,34 @@ def h (x y : Nat) : Nat :=
|
|||
| Nat.zero, y+1 => 44
|
||||
| _, _ => 1
|
||||
|
||||
-- theorem ex1 : h 10000 1 = 0 :=
|
||||
-- rfl
|
||||
theorem ex1 : h 10000 1 = 0 := rfl
|
||||
theorem ex2 : h 10002 1 = 3 := rfl
|
||||
|
||||
-- theorem ex2 : h 10002 1 = 3 :=
|
||||
-- rfl
|
||||
|
||||
def g (xs ys : Array Nat) : Nat :=
|
||||
match xs, ys with
|
||||
| #[], #[] => 0
|
||||
| _, #[0, y+1] => 1
|
||||
| _, #[x, y] => 2
|
||||
| _, _ => 3
|
||||
|
||||
set_option trace.Meta.debug true
|
||||
set_option pp.proofs true
|
||||
-- set_option trace.Meta.debug true
|
||||
-- set_option pp.proofs true
|
||||
-- set_option trace.Meta.debug truen
|
||||
test% f.match_1
|
||||
set_option pp.analyze false
|
||||
#check @f.match_1.eq_1
|
||||
#check @f.match_1.eq_2
|
||||
#check @f.match_1.eq_3
|
||||
#check @f.match_1.eq_4
|
||||
|
||||
test% h.match_1
|
||||
-- test% g.match_1
|
||||
#check @h.match_1.eq_1
|
||||
#check @h.match_1.eq_2
|
||||
#check @h.match_1.eq_3
|
||||
#check @h.match_1.eq_4
|
||||
#check @h.match_1.eq_5
|
||||
#check @h.match_1.eq_6
|
||||
|
||||
def g (xs ys : List (Nat × String)) : Nat :=
|
||||
match xs, ys with
|
||||
| _, [(a,b)] => 0
|
||||
| [(c, d)], _ => 1
|
||||
| _, _ => 2
|
||||
|
||||
test% g.match_1
|
||||
#check @g.match_1.eq_1
|
||||
#check @g.match_1.eq_2
|
||||
#check @g.match_1.eq_3
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue