diff --git a/src/Lean/Meta/Tactic/Grind/Split.lean b/src/Lean/Meta/Tactic/Grind/Split.lean index 8b66155b68..03540d3102 100644 --- a/src/Lean/Meta/Tactic/Grind/Split.lean +++ b/src/Lean/Meta/Tactic/Grind/Split.lean @@ -185,13 +185,10 @@ where match cs with | [] => modify fun s => { s with split.candidates := cs'.reverse } - if let .some _ numCases isRec _ := c? then - let numSplits := (← get).split.num - -- We only increase the number of splits if there is more than one case or it is recursive. - let numSplits := if numCases > 1 || isRec then numSplits + 1 else numSplits + if let .some .. := c? then -- Remark: we reset `numEmatch` after each case split. -- We should consider other strategies in the future. - modify fun s => { s with split.num := numSplits, ematch.num := 0 } + modify fun s => { s with ematch.num := 0 } return c? | c::cs => if !(← checkAnchorRefs c) then @@ -422,10 +419,24 @@ def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool) pure 0 return (mvarIds, numDigits) let numSubgoals := mvarIds.length - let subgoals := mvarIds.mapIdx fun i mvarId => { goal with - mvarId - split.trace := { expr := cExpr, i, num := numSubgoals, source := c.source } :: goal.split.trace - } + /- + **Split counter heuristic**: We do not increment `numSplits` for the first case (`i = 0`) + of a non-recursive split. This leverages non-chronological backtracking: if the first case + is solved using a proof that doesn't depend on the case hypothesis, we backtrack and close + the original goal directly. In this scenario, the case-split was "free", it didn't contribute + to the proof. By not counting it, we allow deeper exploration when case-splits turn out to be + irrelevant. + + For recursive types or subsequent cases (`i > 0`), we always increment the counter since + these represent genuine branches in the proof search. + -/ + let subgoals := mvarIds.mapIdx fun i mvarId => + let numSplits := goal.split.num + let numSplits := if i > 0 || isRec then numSplits + 1 else numSplits + { goal with + mvarId + split.num := numSplits + split.trace := { expr := cExpr, i, num := numSubgoals, source := c.source } :: goal.split.trace } let mut seqNew : Array (List (TSyntax `grind)) := #[] let mut stuckNew : Array Goal := #[] for subgoal in subgoals do diff --git a/tests/lean/run/grind_11081.lean b/tests/lean/run/grind_11081.lean index 97a07dceea..397485a9ab 100644 --- a/tests/lean/run/grind_11081.lean +++ b/tests/lean/run/grind_11081.lean @@ -25,7 +25,7 @@ open List /-- error: `grind` failed -case grind.1.1.1.1.1.1.1.1.1 +case grind.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1.1 α : Type inst : DecidableEq α l₁ l₂ : List α @@ -66,6 +66,44 @@ left_8 : l₁ ~ l₁.diff l₂ right_8 : ∀ (a : α), count a l₁ = count a (l₁.diff l₂) left_9 : l₁ ~ l₂ right_9 : ∀ (a : α), count a l₁ = count a l₂ +left_10 : filter p l₁ ~ filter p (l₁.diff l₂ ++ l₂) +right_10 : ∀ (a : α), count a (filter p l₁) = count a (filter p (l₁.diff l₂ ++ l₂)) +left_11 : filter p (l₁.diff l₂ ++ l₂) ~ filter p l₁ +right_11 : ∀ (a : α), count a (filter p (l₁.diff l₂ ++ l₂)) = count a (filter p l₁) +left_12 : l₁.diff l₂ ++ l₂ ~ l₂ ++ (l₁.diff l₂ ++ l₂) +right_12 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₂ ++ (l₁.diff l₂ ++ l₂)) +left_13 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₁ +right_13 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₁) +left_14 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂) +right_14 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂)) +left_15 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₂ +right_15 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₂) +left_16 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ ++ l₁.diff l₂ +right_16 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂ ++ l₁.diff l₂) +left_17 : l₁.diff l₂ ++ l₂ ~ l₁ ++ (l₁.diff l₂ ++ l₂) +right_17 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁ ++ (l₁.diff l₂ ++ l₂)) +left_18 : filter p (l₁.diff l₂ ++ l₂) ~ filter p (l₁.diff l₂) +right_18 : ∀ (a : α), count a (filter p (l₁.diff l₂ ++ l₂)) = count a (filter p (l₁.diff l₂)) +left_19 : filter p (l₁.diff l₂) ~ filter p (l₁.diff l₂ ++ l₂) +right_19 : ∀ (a : α), count a (filter p (l₁.diff l₂)) = count a (filter p (l₁.diff l₂ ++ l₂)) +left_20 : (filter p (l₁.diff l₂ ++ l₂)).Subperm (filter p l₁) +right_20 : (filter p (l₁.diff l₂ ++ l₂)).Subperm (filter p (l₁.diff l₂ ++ l₂)) +left_21 : l₁.diff l₂ ++ l₂ ++ l₁.diff l₂ ~ l₁.diff l₂ ++ l₂ +right_21 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₁.diff l₂) = count a (l₁.diff l₂ ++ l₂) +left_22 : l₁.diff l₂ ++ l₂ ++ l₁ ~ l₁.diff l₂ ++ l₂ +right_22 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₁) = count a (l₁.diff l₂ ++ l₂) +left_23 : l₁.diff l₂ ++ l₂ ++ l₂ ~ l₁.diff l₂ ++ l₂ +right_23 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ l₂) = count a (l₁.diff l₂ ++ l₂) +left_24 : l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂ +right_24 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂) +left_25 : l₁ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂ +right_25 : ∀ (a : α), count a (l₁ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂) +left_26 : l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂ +right_26 : ∀ (a : α), count a (l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂) +left_27 : l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂) ~ l₁.diff l₂ ++ l₂ +right_27 : ∀ (a : α), count a (l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂)) = count a (l₁.diff l₂ ++ l₂) +left_28 : l₁.diff l₂ ++ l₂ ~ l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂) +right_28 : ∀ (a : α), count a (l₁.diff l₂ ++ l₂) = count a (l₁.diff l₂ ++ (l₁.diff l₂ ++ l₂)) ⊢ False -/ #guard_msgs in diff --git a/tests/lean/run/grind_11539_2.lean b/tests/lean/run/grind_11539_2.lean new file mode 100644 index 0000000000..f013176b33 --- /dev/null +++ b/tests/lean/run/grind_11539_2.lean @@ -0,0 +1,5 @@ +example (a b : Nat) (f g : Nat → Nat) + (hf : (∀ i ≤ a, f i ≤ f (i + 1)) ∧ f 0 = 0) + (hg : (∀ i ≤ b, g i ≤ g (i + 1)) ∧ g 0 = 0 ∧ g b = 0) : + g (a + b - a) = 0 := by + grind