From 2a70da50c1478bb146bac7bb6986c09f24d326ec Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 16 Oct 2025 20:07:13 -0700 Subject: [PATCH] feat: proper case-split anchor generation in `splitNext` for `grind?` and `finish?` (#10811) This PR implements proper case-split anchor generation in the `splitNext` action, which will be used to implement `grind?` and `finish?`. --- src/Lean/Elab/Tactic/Grind/Show.lean | 44 +------------------------- src/Lean/Meta/Tactic/Grind/Anchor.lean | 20 ++++++++++++ src/Lean/Meta/Tactic/Grind/Split.lean | 41 +++++++++++++++++++++--- 3 files changed, 57 insertions(+), 48 deletions(-) diff --git a/src/Lean/Elab/Tactic/Grind/Show.lean b/src/Lean/Elab/Tactic/Grind/Show.lean index 7a4bcd116a..17abebba2f 100644 --- a/src/Lean/Elab/Tactic/Grind/Show.lean +++ b/src/Lean/Elab/Tactic/Grind/Show.lean @@ -165,53 +165,11 @@ def pushIfSome (msgs : Array MessageData) (msg? : Option MessageData) : Array Me logInfo <| MessageData.trace { cls := `grind, collapsed := false } "Grind state" msgs | _ => throwUnsupportedSyntax -def truncateAnchors (es : Array (UInt64 × α)) : Array (UInt64 × α) × Nat := - go 4 -where - go (numDigits : Nat) : Array (UInt64 × α) × Nat := Id.run do - if 4*numDigits < 64 then - let shift := 64 - 4*numDigits - let mut found : Std.HashSet UInt64 := {} - let mut result := #[] - for (a, e) in es do - let a' := a >>> shift.toUInt64 - if found.contains a' then - return (← go (numDigits+1)) - else - found := found.insert a' - result := result.push (a', e) - return (result, numDigits) - else - return (es, numDigits) - termination_by 64 - 4*numDigits - -def anchorToString (numDigits : Nat) (anchor : UInt64) : String := - let cs := Nat.toDigits 16 anchor.toNat - let n := cs.length - let zs := List.replicate (numDigits - n) '0' - let cs := zs ++ cs - cs.asString - @[builtin_grind_tactic showSplits] def evalShowSplits : GrindTactic := fun stx => withMainContext do match stx with | `(grind| show_splits $[$filter?]?) => let filter ← elabFilter filter? - let goal ← getMainGoal - let candidates := goal.split.candidates - let candidates ← liftGoalM <| candidates.toArray.mapM fun c => do - let e := c.getExpr - let anchor ← getAnchor e - let status ← checkSplitStatus c - return (e, status, anchor) - let candidates ← liftGoalM <| candidates.filterM fun (e, status, _) => do - -- **Note**: we ignore case-splits that are not ready or have already been resolved. - -- We may consider adding an option for including "not-ready" splits in the future. - if status matches .resolved | .notReady then return false - filter.eval e - -- **TODO**: Add an option for including propositions that are only considered when using `+splitImp` - -- **TODO**: Add an option for including terms whose type is an inductive predicate or type - let candidates := candidates.map fun (e, _, anchor) => (anchor, e) - let (candidates, numDigits) := truncateAnchors candidates + let { candidates, numDigits } ← liftGoalM <| getSplitCandidateAnchors filter.eval if candidates.isEmpty then throwError "no case splits" let msgs := candidates.map fun (a, e) => diff --git a/src/Lean/Meta/Tactic/Grind/Anchor.lean b/src/Lean/Meta/Tactic/Grind/Anchor.lean index 8ba97448d4..387255da79 100644 --- a/src/Lean/Meta/Tactic/Grind/Anchor.lean +++ b/src/Lean/Meta/Tactic/Grind/Anchor.lean @@ -87,4 +87,24 @@ public def isAnchorPrefix (numHexDigits : Nat) (anchorPrefix : UInt64) (anchor : let shift := 64 - numHexDigits.toUInt64*4 anchorPrefix == anchor >>> shift +public def truncateAnchors (es : Array (UInt64 × α)) : Array (UInt64 × α) × Nat := + go 4 +where + go (numDigits : Nat) : Array (UInt64 × α) × Nat := Id.run do + if 4*numDigits < 64 then + let shift := 64 - 4*numDigits + let mut found : Std.HashSet UInt64 := {} + let mut result := #[] + for (a, e) in es do + let a' := a >>> shift.toUInt64 + if found.contains a' then + return (← go (numDigits+1)) + else + found := found.insert a' + result := result.push (a', e) + return (result, numDigits) + else + return (es, numDigits) + termination_by 64 - 4*numDigits + end Lean.Meta.Grind diff --git a/src/Lean/Meta/Tactic/Grind/Split.lean b/src/Lean/Meta/Tactic/Grind/Split.lean index c32fdb5e19..c6b0611c19 100644 --- a/src/Lean/Meta/Tactic/Grind/Split.lean +++ b/src/Lean/Meta/Tactic/Grind/Split.lean @@ -242,6 +242,34 @@ private def casesWithTrace (mvarId : MVarId) (major : Expr) : GoalM (List MVarId saveCases declName false cases mvarId major +structure SplitCandidateAnchors where + /-- Pairs `(anchor, split)` -/ + candidates : Array (UInt64 × Expr) + /-- Number of digits (≥ 4) sufficient for distinguishing anchors. We usually display only the first `numDigits`. -/ + numDigits : Nat + +/-- +Returns case-split candidates. Case-splits that are tagged as `.resolved` or `.notReady` are skipped. +Applies additional `filter` if provided. +-/ +def getSplitCandidateAnchors (filter : Expr → GoalM Bool := fun _ => return true) : GoalM SplitCandidateAnchors := do + let candidates := (← get).split.candidates + let candidates ← candidates.toArray.mapM fun c => do + let e := c.getExpr + let anchor ← getAnchor e + let status ← checkSplitStatus c + return (e, status, anchor) + let candidates ← candidates.filterM fun (e, status, _) => do + -- **Note**: we ignore case-splits that are not ready or have already been resolved. + -- We may consider adding an option for including "not-ready" splits in the future. + if status matches .resolved | .notReady then return false + filter e + -- **TODO**: Add an option for including propositions that are only considered when using `+splitImp` + -- **TODO**: Add an option for including terms whose type is an inductive predicate or type + let candidates := candidates.map fun (e, _, anchor) => (anchor, e) + let (candidates, numDigits) := truncateAnchors candidates + return { candidates, numDigits } + namespace Action /-- @@ -325,9 +353,10 @@ Remark: `numCases` and `isRec` are computed using `checkSplitStatus`. private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool) (stopAtFirstFailure : Bool) (compress : Bool) : Action := fun goal _ kp => do + let traceEnabled := (← getConfig).trace let mvarId ← goal.mkAuxMVar let cExpr := c.getExpr - let (mvarIds, goal) ← GoalM.run goal do + let ((mvarIds, numDigits), goal) ← GoalM.run goal do let gen ← getGeneration cExpr let genNew := if numCases > 1 || isRec then gen+1 else gen saveSplitDiagInfo cExpr genNew numCases c.source @@ -339,8 +368,12 @@ private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool) casesMatch mvarId cExpr else casesWithTrace mvarId (← mkCasesMajor cExpr) + let numDigits ← if traceEnabled then + pure (← getSplitCandidateAnchors).numDigits + else + pure 0 + return (mvarIds, numDigits) let subgoals := mvarIds.map fun mvarId => { goal with mvarId } - let traceEnabled := (← getConfig).trace let mut seqNew : Array (List (TSyntax `grind)) := #[] let mut stuckNew : Array Goal := #[] for subgoal in subgoals do @@ -369,9 +402,7 @@ private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool) if stuckNew.isEmpty then if traceEnabled then let anchor ← goal.withContext <| getAnchor cExpr - -- **TODO**: compute the exact number of digits - let numDigits := 4 - let anchorPrefix := anchor >>> (64 - 16) + let anchorPrefix := anchor >>> (64 - 4*numDigits.toUInt64) let hexnum := mkNode `hexnum #[mkAtom (anchorToString numDigits anchorPrefix)] let cases ← `(grind| cases #$hexnum) return .closed (← mkCasesResultSeq cases seqNew compress)