fix: splitMatch tactic

Improve how we compute the motive for match-splitter eliminator.

closes #986
This commit is contained in:
Leonardo de Moura 2022-02-02 15:02:55 -08:00
parent d7f085976f
commit 188f0eb70f
6 changed files with 26 additions and 23 deletions

View file

@ -79,29 +79,31 @@ private def generalizeMatchDiscrs (mvarId : MVarId) (discrs : Array Expr) : Meta
return (result, mvarId)
def applyMatchSplitter (mvarId : MVarId) (matcherDeclName : Name) (us : Array Level) (params : Array Expr) (discrs : Array Expr) : MetaM (List MVarId) := do
let (discrFVarIds, mvarId) ← generalizeMatchDiscrs mvarId discrs
let (reverted, mvarId) ← revert mvarId discrFVarIds (preserveOrder := true)
let (discrFVarIds, mvarId) ← introNP mvarId discrFVarIds.size
let numExtra := reverted.size - discrFVarIds.size
let discrs := discrFVarIds.map mkFVar
let some info ← getMatcherInfo? matcherDeclName | throwError "'applyMatchSplitter' failed, '{matcherDeclName}' is not a 'match' auxiliary declaration."
let matchEqns ← Match.getEquationsFor matcherDeclName
let mut us := us
if let some uElimPos := info.uElimPos? then
-- Set universe elimination level to zero (Prop).
us := us.set! uElimPos levelZero
let splitter := mkAppN (mkConst matchEqns.splitterName us.toList) params
let motiveType := (← whnfForall (← inferType splitter)).bindingDomain!
let (discrFVarIds, mvarId) ← generalizeMatchDiscrs mvarId discrs
let mvarId ← generalizeTargetsEq mvarId motiveType (discrFVarIds.map mkFVar)
let numEqs := discrs.size
let (discrFVarIdsNew, mvarId) ← introN mvarId discrs.size
let discrsNew := discrFVarIdsNew.map mkFVar
withMVarContext mvarId do
let motive ← mkLambdaFVars discrs (← getMVarType mvarId)
-- Fix universe
let mut us := us
if let some uElimPos := info.uElimPos? then
-- Set universe elimination level to zero (Prop).
us := us.set! uElimPos levelZero
let splitter := mkAppN (mkConst matchEqns.splitterName us.toList) params
let splitter := mkAppN (mkApp splitter motive) discrs
check splitter -- TODO
let motive ← mkLambdaFVars discrsNew (← getMVarType mvarId)
let splitter := mkAppN (mkApp splitter motive) discrsNew
check splitter
let mvarIds ← apply mvarId splitter
let (_, mvarIds) ← mvarIds.foldlM (init := (0, [])) fun (i, mvarIds) mvarId => do
let numParams := matchEqns.splitterAltNumParams[i]
let (_, mvarId) ← introN mvarId numParams
let (_, mvarId) ← introNP mvarId numExtra
return (i+1, mvarId::mvarIds)
match (← Cases.unifyEqs numEqs mvarId {}) with
| none => return (i+1, mvarIds) -- case was solved
| some (mvarId, _) =>
return (i+1, mvarId::mvarIds)
return mvarIds.reverse
def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do

View file

@ -1,4 +1,4 @@
f._eq_1 : f 0 = 1
f._eq_2 : f 100 = 2
f._eq_3 : f 1000 = 3
f._eq_4 : ∀ (x_1 : Nat), (x_1 = 99 → False) → (x_1 = 999 → False) → f (Nat.succ x_1) = f x_1
f._eq_4 : ∀ (x_2 : Nat), (x_2 = 99 → False) → (x_2 = 999 → False) → f (Nat.succ x_2) = f x_2

View file

@ -1,7 +1,7 @@
iota._eq_1 : iota 0 = []
iota._eq_2 : ∀ (n : Nat), iota (Nat.succ n) = Nat.succ n :: iota n
f._eq_1 : ∀ (x y : Nat), f [x, y] = ([x, y], [y])
f._eq_2 : ∀ (x y : Nat) (zs : List Nat), (zs = [] → False) → f (x :: y :: zs) = f zs
f._eq_1 : ∀ (x_1 y : Nat), f [x_1, y] = ([x_1, y], [y])
f._eq_2 : ∀ (x_1 y : Nat) (zs : List Nat), (zs = [] → False) → f (x_1 :: y :: zs) = f zs
f._eq_3 : ∀ (x : List Nat),
(∀ (x_1 y : Nat), x = [x_1, y] → False) →
(∀ (x_2 y : Nat) (zs : List Nat), x = x_2 :: y :: zs → False) → f x = ([], [])
(∀ (x_1 y : Nat) (zs : List Nat), x = x_1 :: y :: zs → False) → f x = ([], [])

1
tests/lean/run/986.lean Normal file
View file

@ -0,0 +1 @@
attribute [simp] Array.insertionSort.swapLoop

View file

@ -37,6 +37,6 @@ def g (xs ys : List Nat) : Nat :=
example (xs ys : List Nat) : g xs ys > 0 := by
simp [g]
split
next a b _ => show Nat.succ (a + b) > 0; apply Nat.zero_lt_succ
next a b => show Nat.succ (a + b) > 0; apply Nat.zero_lt_succ
next xs b c _ => show Nat.succ b > 0; apply Nat.zero_lt_succ
next => decide

View file

@ -8,10 +8,10 @@ example (a b : Bool) (x y z : Nat) (xs : List Nat) (h1 : (if a then x else y) =
simp [g]
repeat any_goals (split at *)
any_goals (first | decide | contradiction | injections)
next b c _ _ =>
next b c _ =>
show Nat.succ b = 1
simp [List.head!] at h2; simp [h2]
next b c _ _ =>
next b c _ =>
show Nat.succ b = 1
simp [List.head!] at h2; simp [h2]