fix: case-splitting selection in grind (#11749)
This PR fixes a bug in the function `selectNextSplit?` used in `grind`. It was incorrectly computing the generation of each candidate. Closes #11697
This commit is contained in:
parent
c88ec35c0d
commit
5440bf724d
4 changed files with 33 additions and 1 deletions
|
|
@ -216,7 +216,11 @@ where
|
|||
return false
|
||||
else if numCases == 1 && !isRec && numCases' > 1 then
|
||||
return true
|
||||
if (← getGeneration c.getExpr) < (← getGeneration c'.getExpr) then
|
||||
/-
|
||||
**Note**: We used to use `getGeneration c.getExpr` instead of `c.getGeneration`.
|
||||
This was incorrect. The expression returned by `c.getExpr` may have not been internalized yet.
|
||||
-/
|
||||
else if (← c.getGeneration) < (← c'.getGeneration) then
|
||||
return true
|
||||
return numCases < numCases'
|
||||
if (← isBetter) then
|
||||
|
|
|
|||
|
|
@ -1090,6 +1090,14 @@ def Goal.getGeneration (goal : Goal) (e : Expr) : Nat :=
|
|||
else
|
||||
0
|
||||
|
||||
def SplitInfo.getGenerationCore (goal : Goal) : SplitInfo → Nat
|
||||
| .default e _ => goal.getGeneration e
|
||||
| .imp e h _ => goal.getGeneration (e.forallDomain h)
|
||||
| .arg a b _ _ _ => max (goal.getGeneration a) (goal.getGeneration b)
|
||||
|
||||
def SplitInfo.getGeneration (s : SplitInfo) : GoalM Nat :=
|
||||
return s.getGenerationCore (← get)
|
||||
|
||||
/-- Returns the generation of the given term. Is assumes it has been internalized -/
|
||||
def getGeneration (e : Expr) : GoalM Nat :=
|
||||
return (← get).getGeneration e
|
||||
|
|
|
|||
10
tests/lean/run/grind_11697_a.lean
Normal file
10
tests/lean/run/grind_11697_a.lean
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
namespace Nat
|
||||
|
||||
@[grind =]
|
||||
theorem testBit_shiftRight_shiftLeft_add {n j k : Nat} (x : Nat) : (x >>> n <<< (n + k)).testBit j =
|
||||
(decide (n + k ≤ j) && x.testBit (j - k)) := by
|
||||
grind
|
||||
|
||||
theorem myTheorem {x : Nat} : x = x := by grind
|
||||
|
||||
end Nat
|
||||
10
tests/lean/run/grind_11697_b.lean
Normal file
10
tests/lean/run/grind_11697_b.lean
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
namespace Nat
|
||||
|
||||
theorem myTheorem {x : Nat} : x = x := by grind
|
||||
|
||||
@[grind =]
|
||||
theorem testBit_shiftRight_shiftLeft_add {n j k : Nat} (x : Nat) : (x >>> n <<< (n + k)).testBit j =
|
||||
(decide (n + k ≤ j) && x.testBit (j - k)) := by
|
||||
grind
|
||||
|
||||
end Nat
|
||||
Loading…
Add table
Reference in a new issue