From 5440bf724da4696be87defe7148f4890c5da103b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 20 Dec 2025 12:17:09 -0800 Subject: [PATCH] fix: case-splitting selection in `grind` (#11749) This PR fixes a bug in the function `selectNextSplit?` used in `grind`. It was incorrectly computing the generation of each candidate. Closes #11697 --- src/Lean/Meta/Tactic/Grind/Split.lean | 6 +++++- src/Lean/Meta/Tactic/Grind/Types.lean | 8 ++++++++ tests/lean/run/grind_11697_a.lean | 10 ++++++++++ tests/lean/run/grind_11697_b.lean | 10 ++++++++++ 4 files changed, 33 insertions(+), 1 deletion(-) create mode 100644 tests/lean/run/grind_11697_a.lean create mode 100644 tests/lean/run/grind_11697_b.lean diff --git a/src/Lean/Meta/Tactic/Grind/Split.lean b/src/Lean/Meta/Tactic/Grind/Split.lean index b61a13eaee..6441916af9 100644 --- a/src/Lean/Meta/Tactic/Grind/Split.lean +++ b/src/Lean/Meta/Tactic/Grind/Split.lean @@ -216,7 +216,11 @@ where return false else if numCases == 1 && !isRec && numCases' > 1 then return true - if (← getGeneration c.getExpr) < (← getGeneration c'.getExpr) then + /- + **Note**: We used to use `getGeneration c.getExpr` instead of `c.getGeneration`. + This was incorrect. The expression returned by `c.getExpr` may have not been internalized yet. + -/ + else if (← c.getGeneration) < (← c'.getGeneration) then return true return numCases < numCases' if (← isBetter) then diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index ca38a4d143..41204e656c 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -1090,6 +1090,14 @@ def Goal.getGeneration (goal : Goal) (e : Expr) : Nat := else 0 +def SplitInfo.getGenerationCore (goal : Goal) : SplitInfo → Nat + | .default e _ => goal.getGeneration e + | .imp e h _ => goal.getGeneration (e.forallDomain h) + | .arg a b _ _ _ => max (goal.getGeneration a) (goal.getGeneration b) + +def SplitInfo.getGeneration (s : SplitInfo) : GoalM Nat := + return s.getGenerationCore (← get) + /-- Returns the generation of the given term. Is assumes it has been internalized -/ def getGeneration (e : Expr) : GoalM Nat := return (← get).getGeneration e diff --git a/tests/lean/run/grind_11697_a.lean b/tests/lean/run/grind_11697_a.lean new file mode 100644 index 0000000000..23d6adbc0c --- /dev/null +++ b/tests/lean/run/grind_11697_a.lean @@ -0,0 +1,10 @@ +namespace Nat + +@[grind =] +theorem testBit_shiftRight_shiftLeft_add {n j k : Nat} (x : Nat) : (x >>> n <<< (n + k)).testBit j = + (decide (n + k ≤ j) && x.testBit (j - k)) := by + grind + +theorem myTheorem {x : Nat} : x = x := by grind + +end Nat diff --git a/tests/lean/run/grind_11697_b.lean b/tests/lean/run/grind_11697_b.lean new file mode 100644 index 0000000000..e5cd1f74e4 --- /dev/null +++ b/tests/lean/run/grind_11697_b.lean @@ -0,0 +1,10 @@ +namespace Nat + +theorem myTheorem {x : Nat} : x = x := by grind + +@[grind =] +theorem testBit_shiftRight_shiftLeft_add {n j k : Nat} (x : Nat) : (x >>> n <<< (n + k)).testBit j = + (decide (n + k ≤ j) && x.testBit (j - k)) := by + grind + +end Nat