feat: improve case-split heuristic in grind (#11609)
This PR improves the case-split heuristics in `grind`. In this PR, we do not increment the number of case splits in the first case. The idea is to leverage non-chronological backtracking: if the first case is solved using a proof that doesn't depend on the case hypothesis, we backtrack and close the original goal directly. In this scenario, the case-split was "free", it didn't contribute to the proof. By not counting it, we allow deeper exploration when case-splits turn out to be irrelevant. The new heuristic addresses the second example in #11545
This commit is contained in:
parent
b8c53b1d29
commit
ea10bdf154
3 changed files with 64 additions and 10 deletions
|
|
@ -185,13 +185,10 @@ where
|
|||
match cs with
|
||||
| [] =>
|
||||
modify fun s => { s with split.candidates := cs'.reverse }
|
||||
if let .some _ numCases isRec _ := c? then
|
||||
let numSplits := (← get).split.num
|
||||
-- We only increase the number of splits if there is more than one case or it is recursive.
|
||||
let numSplits := if numCases > 1 || isRec then numSplits + 1 else numSplits
|
||||
if let .some .. := c? then
|
||||
-- Remark: we reset `numEmatch` after each case split.
|
||||
-- We should consider other strategies in the future.
|
||||
modify fun s => { s with split.num := numSplits, ematch.num := 0 }
|
||||
modify fun s => { s with ematch.num := 0 }
|
||||
return c?
|
||||
| c::cs =>
|
||||
if !(← checkAnchorRefs c) then
|
||||
|
|
@ -422,10 +419,24 @@ def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool)
|
|||
pure 0
|
||||
return (mvarIds, numDigits)
|
||||
let numSubgoals := mvarIds.length
|
||||
let subgoals := mvarIds.mapIdx fun i mvarId => { goal with
|
||||
mvarId
|
||||
split.trace := { expr := cExpr, i, num := numSubgoals, source := c.source } :: goal.split.trace
|
||||
}
|
||||
/-
|
||||
**Split counter heuristic**: We do not increment `numSplits` for the first case (`i = 0`)
|
||||
of a non-recursive split. This leverages non-chronological backtracking: if the first case
|
||||
is solved using a proof that doesn't depend on the case hypothesis, we backtrack and close
|
||||
the original goal directly. In this scenario, the case-split was "free", it didn't contribute
|
||||
to the proof. By not counting it, we allow deeper exploration when case-splits turn out to be
|
||||
irrelevant.
|
||||
|
||||
For recursive types or subsequent cases (`i > 0`), we always increment the counter since
|
||||
these represent genuine branches in the proof search.
|
||||
-/
|
||||
let subgoals := mvarIds.mapIdx fun i mvarId =>
|
||||
let numSplits := goal.split.num
|
||||
let numSplits := if i > 0 || isRec then numSplits + 1 else numSplits
|
||||
{ goal with
|
||||
mvarId
|
||||
split.num := numSplits
|
||||
split.trace := { expr := cExpr, i, num := numSubgoals, source := c.source } :: goal.split.trace }
|
||||
let mut seqNew : Array (List (TSyntax `grind)) := #[]
|
||||
let mut stuckNew : Array Goal := #[]
|
||||
for subgoal in subgoals do
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ open List
|
|||
|
||||
/--
|
||||
error: `grind` failed
|
||||
case grind.1.1.1.1.1.1.1.1.1
|
||||
case grind.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1
|
||||
α : Type
|
||||
inst : DecidableEq α
|
||||
l₁ l₂ : List α
|
||||
|
|
@ -66,6 +66,44 @@ left_8 : l₁ ~ l₁.diff l₂
|
|||
right_8 : ∀ (a : α), count a l₁ = count a (l₁.diff l₂)
|
||||
left_9 : l₁ ~ l₂
|
||||
right_9 : ∀ (a : α), count a l₁ = count a l₂
|
||||
left_10 : filter p l₁ ~ filter p (l₁.diff l₂ ++ l₂)
|
||||
right_10 : ∀ (a : α), count a (filter p l₁) = count a (filter p (l₁.diff l₂ ++ l₂))
|
||||
left_11 : filter p (l₁.diff l₂ ++ l₂) ~ filter p l₁
|
||||
right_11 : ∀ (a : α), count a (filter p (l₁.diff l₂ ++ l₂)) = count a (filter p l₁)
|
||||
left_12 : l₁.diff l₂ ++ l₂ ~ l₂ ++ (l₁.diff l₂ ++ l₂)
|
||||
right_12 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₂ ++ (l₁.diff l₂ ++ l₂))
|
||||
left_13 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₁
|
||||
right_13 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₁)
|
||||
left_14 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂)
|
||||
right_14 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂))
|
||||
left_15 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₂
|
||||
right_15 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₂)
|
||||
left_16 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₁.diff l₂
|
||||
right_16 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₁.diff l₂)
|
||||
left_17 : l₁.diff l₂ ++ l₂ ~ l₁ ++ (l₁.diff l₂ ++ l₂)
|
||||
right_17 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁ ++ (l₁.diff l₂ ++ l₂))
|
||||
left_18 : filter p (l₁.diff l₂ ++ l₂) ~ filter p (l₁.diff l₂)
|
||||
right_18 : ∀ (a : α), count a (filter p (l₁.diff l₂ ++ l₂)) = count a (filter p (l₁.diff l₂))
|
||||
left_19 : filter p (l₁.diff l₂) ~ filter p (l₁.diff l₂ ++ l₂)
|
||||
right_19 : ∀ (a : α), count a (filter p (l₁.diff l₂)) = count a (filter p (l₁.diff l₂ ++ l₂))
|
||||
left_20 : (filter p (l₁.diff l₂ ++ l₂)).Subperm (filter p l₁)
|
||||
right_20 : (filter p (l₁.diff l₂ ++ l₂)).Subperm (filter p (l₁.diff l₂ ++ l₂))
|
||||
left_21 : l₁.diff l₂ ++ l₂ ++ l₁.diff l₂ ~ l₁.diff l₂ ++ l₂
|
||||
right_21 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₁.diff l₂) = count a (l₁.diff l₂ ++ l₂)
|
||||
left_22 : l₁.diff l₂ ++ l₂ ++ l₁ ~ l₁.diff l₂ ++ l₂
|
||||
right_22 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₁) = count a (l₁.diff l₂ ++ l₂)
|
||||
left_23 : l₁.diff l₂ ++ l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂
|
||||
right_23 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂)
|
||||
left_24 : l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
|
||||
right_24 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
|
||||
left_25 : l₁ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
|
||||
right_25 : ∀ (a : α), count a (l₁ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
|
||||
left_26 : l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
|
||||
right_26 : ∀ (a : α), count a (l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
|
||||
left_27 : l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂
|
||||
right_27 : ∀ (a : α), count a (l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂)
|
||||
left_28 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂)
|
||||
right_28 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂))
|
||||
⊢ False
|
||||
-/
|
||||
#guard_msgs in
|
||||
|
|
|
|||
5
tests/lean/run/grind_11539_2.lean
Normal file
5
tests/lean/run/grind_11539_2.lean
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
example (a b : Nat) (f g : Nat → Nat)
|
||||
(hf : (∀ i ≤ a, f i ≤ f (i + 1)) ∧ f 0 = 0)
|
||||
(hg : (∀ i ≤ b, g i ≤ g (i + 1)) ∧ g 0 = 0 ∧ g b = 0) :
|
||||
g (a + b - a) = 0 := by
|
||||
grind
|
||||
Loading…
Add table
Reference in a new issue