feat: proper case-split anchor generation in splitNext for grind? and finish? (#10811)
This PR implements proper case-split anchor generation in the `splitNext` action, which will be used to implement `grind?` and `finish?`.
This commit is contained in:
parent
effde06296
commit
2a70da50c1
3 changed files with 57 additions and 48 deletions
|
|
@ -165,53 +165,11 @@ def pushIfSome (msgs : Array MessageData) (msg? : Option MessageData) : Array Me
|
|||
logInfo <| MessageData.trace { cls := `grind, collapsed := false } "Grind state" msgs
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
def truncateAnchors (es : Array (UInt64 × α)) : Array (UInt64 × α) × Nat :=
|
||||
go 4
|
||||
where
|
||||
go (numDigits : Nat) : Array (UInt64 × α) × Nat := Id.run do
|
||||
if 4*numDigits < 64 then
|
||||
let shift := 64 - 4*numDigits
|
||||
let mut found : Std.HashSet UInt64 := {}
|
||||
let mut result := #[]
|
||||
for (a, e) in es do
|
||||
let a' := a >>> shift.toUInt64
|
||||
if found.contains a' then
|
||||
return (← go (numDigits+1))
|
||||
else
|
||||
found := found.insert a'
|
||||
result := result.push (a', e)
|
||||
return (result, numDigits)
|
||||
else
|
||||
return (es, numDigits)
|
||||
termination_by 64 - 4*numDigits
|
||||
|
||||
def anchorToString (numDigits : Nat) (anchor : UInt64) : String :=
|
||||
let cs := Nat.toDigits 16 anchor.toNat
|
||||
let n := cs.length
|
||||
let zs := List.replicate (numDigits - n) '0'
|
||||
let cs := zs ++ cs
|
||||
cs.asString
|
||||
|
||||
@[builtin_grind_tactic showSplits] def evalShowSplits : GrindTactic := fun stx => withMainContext do
|
||||
match stx with
|
||||
| `(grind| show_splits $[$filter?]?) =>
|
||||
let filter ← elabFilter filter?
|
||||
let goal ← getMainGoal
|
||||
let candidates := goal.split.candidates
|
||||
let candidates ← liftGoalM <| candidates.toArray.mapM fun c => do
|
||||
let e := c.getExpr
|
||||
let anchor ← getAnchor e
|
||||
let status ← checkSplitStatus c
|
||||
return (e, status, anchor)
|
||||
let candidates ← liftGoalM <| candidates.filterM fun (e, status, _) => do
|
||||
-- **Note**: we ignore case-splits that are not ready or have already been resolved.
|
||||
-- We may consider adding an option for including "not-ready" splits in the future.
|
||||
if status matches .resolved | .notReady then return false
|
||||
filter.eval e
|
||||
-- **TODO**: Add an option for including propositions that are only considered when using `+splitImp`
|
||||
-- **TODO**: Add an option for including terms whose type is an inductive predicate or type
|
||||
let candidates := candidates.map fun (e, _, anchor) => (anchor, e)
|
||||
let (candidates, numDigits) := truncateAnchors candidates
|
||||
let { candidates, numDigits } ← liftGoalM <| getSplitCandidateAnchors filter.eval
|
||||
if candidates.isEmpty then
|
||||
throwError "no case splits"
|
||||
let msgs := candidates.map fun (a, e) =>
|
||||
|
|
|
|||
|
|
@ -87,4 +87,24 @@ public def isAnchorPrefix (numHexDigits : Nat) (anchorPrefix : UInt64) (anchor :
|
|||
let shift := 64 - numHexDigits.toUInt64*4
|
||||
anchorPrefix == anchor >>> shift
|
||||
|
||||
public def truncateAnchors (es : Array (UInt64 × α)) : Array (UInt64 × α) × Nat :=
|
||||
go 4
|
||||
where
|
||||
go (numDigits : Nat) : Array (UInt64 × α) × Nat := Id.run do
|
||||
if 4*numDigits < 64 then
|
||||
let shift := 64 - 4*numDigits
|
||||
let mut found : Std.HashSet UInt64 := {}
|
||||
let mut result := #[]
|
||||
for (a, e) in es do
|
||||
let a' := a >>> shift.toUInt64
|
||||
if found.contains a' then
|
||||
return (← go (numDigits+1))
|
||||
else
|
||||
found := found.insert a'
|
||||
result := result.push (a', e)
|
||||
return (result, numDigits)
|
||||
else
|
||||
return (es, numDigits)
|
||||
termination_by 64 - 4*numDigits
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
|
|||
|
|
@ -242,6 +242,34 @@ private def casesWithTrace (mvarId : MVarId) (major : Expr) : GoalM (List MVarId
|
|||
saveCases declName false
|
||||
cases mvarId major
|
||||
|
||||
structure SplitCandidateAnchors where
|
||||
/-- Pairs `(anchor, split)` -/
|
||||
candidates : Array (UInt64 × Expr)
|
||||
/-- Number of digits (≥ 4) sufficient for distinguishing anchors. We usually display only the first `numDigits`. -/
|
||||
numDigits : Nat
|
||||
|
||||
/--
|
||||
Returns case-split candidates. Case-splits that are tagged as `.resolved` or `.notReady` are skipped.
|
||||
Applies additional `filter` if provided.
|
||||
-/
|
||||
def getSplitCandidateAnchors (filter : Expr → GoalM Bool := fun _ => return true) : GoalM SplitCandidateAnchors := do
|
||||
let candidates := (← get).split.candidates
|
||||
let candidates ← candidates.toArray.mapM fun c => do
|
||||
let e := c.getExpr
|
||||
let anchor ← getAnchor e
|
||||
let status ← checkSplitStatus c
|
||||
return (e, status, anchor)
|
||||
let candidates ← candidates.filterM fun (e, status, _) => do
|
||||
-- **Note**: we ignore case-splits that are not ready or have already been resolved.
|
||||
-- We may consider adding an option for including "not-ready" splits in the future.
|
||||
if status matches .resolved | .notReady then return false
|
||||
filter e
|
||||
-- **TODO**: Add an option for including propositions that are only considered when using `+splitImp`
|
||||
-- **TODO**: Add an option for including terms whose type is an inductive predicate or type
|
||||
let candidates := candidates.map fun (e, _, anchor) => (anchor, e)
|
||||
let (candidates, numDigits) := truncateAnchors candidates
|
||||
return { candidates, numDigits }
|
||||
|
||||
namespace Action
|
||||
|
||||
/--
|
||||
|
|
@ -325,9 +353,10 @@ Remark: `numCases` and `isRec` are computed using `checkSplitStatus`.
|
|||
private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool)
|
||||
(stopAtFirstFailure : Bool)
|
||||
(compress : Bool) : Action := fun goal _ kp => do
|
||||
let traceEnabled := (← getConfig).trace
|
||||
let mvarId ← goal.mkAuxMVar
|
||||
let cExpr := c.getExpr
|
||||
let (mvarIds, goal) ← GoalM.run goal do
|
||||
let ((mvarIds, numDigits), goal) ← GoalM.run goal do
|
||||
let gen ← getGeneration cExpr
|
||||
let genNew := if numCases > 1 || isRec then gen+1 else gen
|
||||
saveSplitDiagInfo cExpr genNew numCases c.source
|
||||
|
|
@ -339,8 +368,12 @@ private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool)
|
|||
casesMatch mvarId cExpr
|
||||
else
|
||||
casesWithTrace mvarId (← mkCasesMajor cExpr)
|
||||
let numDigits ← if traceEnabled then
|
||||
pure (← getSplitCandidateAnchors).numDigits
|
||||
else
|
||||
pure 0
|
||||
return (mvarIds, numDigits)
|
||||
let subgoals := mvarIds.map fun mvarId => { goal with mvarId }
|
||||
let traceEnabled := (← getConfig).trace
|
||||
let mut seqNew : Array (List (TSyntax `grind)) := #[]
|
||||
let mut stuckNew : Array Goal := #[]
|
||||
for subgoal in subgoals do
|
||||
|
|
@ -369,9 +402,7 @@ private def splitCore (c : SplitInfo) (numCases : Nat) (isRec : Bool)
|
|||
if stuckNew.isEmpty then
|
||||
if traceEnabled then
|
||||
let anchor ← goal.withContext <| getAnchor cExpr
|
||||
-- **TODO**: compute the exact number of digits
|
||||
let numDigits := 4
|
||||
let anchorPrefix := anchor >>> (64 - 16)
|
||||
let anchorPrefix := anchor >>> (64 - 4*numDigits.toUInt64)
|
||||
let hexnum := mkNode `hexnum #[mkAtom (anchorToString numDigits anchorPrefix)]
|
||||
let cases ← `(grind| cases #$hexnum)
|
||||
return .closed (← mkCasesResultSeq cases seqNew compress)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue