feat: case splitting in grind (#6717)

This PR introduces a new feature that allows users to specify which
inductive datatypes the `grind` tactic should perform case splits on.
The configuration option `splitIndPred` is now set to `false` by
default. The attribute `[grind cases]` is used to mark inductive
datatypes and predicates that `grind` may case split on during the
search. Additionally, the attribute `[grind cases eager]` can be used to
mark datatypes and predicates for case splitting both during
pre-processing and the search.

Users can also write `grind [HasType]` or `grind [cases HasType]` to
instruct `grind` to perform case splitting on the inductive predicate
`HasType` in a specific instance. Similarly, `grind [-Or]` can be used
to instruct `grind` not to case split on disjunctions.

Co-authored-by: Leonardo de Moura <leodemoura@amazon.com>
This commit is contained in:
Leonardo de Moura 2025-01-20 14:44:56 -08:00 committed by GitHub
parent c07f64a621
commit 189f5d41fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 76 additions and 124 deletions

View file

@ -5,11 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import Init.Core
import Init.Grind.Tactics
attribute [grind_cases] And Prod False Empty True Unit Exists
namespace Lean.Grind.Eager
attribute [scoped grind_cases] Or
end Lean.Grind.Eager
attribute [grind cases eager] And Prod False Empty True Unit Exists
attribute [grind cases] Or

View file

@ -48,8 +48,8 @@ structure Config where
splitIte : Bool := true
/--
If `splitIndPred` is `true`, `grind` performs case-splitting on inductive predicates.
Otherwise, it performs case-splitting only on types marked with `[grind_split]` attribute. -/
splitIndPred : Bool := true
Otherwise, it performs case-splitting only on types marked with `[grind cases]` attribute. -/
splitIndPred : Bool := false
/-- By default, `grind` halts as soon as it encounters a sub-goal where no further progress can be made. -/
failures : Nat := 1
/-- Maximum number of heartbeats (in thousands) the canonicalizer can spend per definitional equality test. -/

View file

@ -8,57 +8,6 @@ import Lean.Meta.Tactic.Grind.EMatchTheorem
import Lean.Meta.Tactic.Grind.Cases
namespace Lean.Meta.Grind
--- TODO: delete
builtin_initialize grindCasesExt : SimpleScopedEnvExtension Name NameSet ←
registerSimpleScopedEnvExtension {
initial := {}
addEntry := fun s declName => s.insert declName
}
/--
Returns `true` if `declName` has been tagged with attribute `[grind_cases]`.
-/
def isGrindCasesTarget (declName : Name) : CoreM Bool :=
return grindCasesExt.getState (← getEnv) |>.contains declName
private def getAlias? (value : Expr) : MetaM (Option Name) :=
lambdaTelescope value fun _ body => do
if let .const declName _ := body.getAppFn' then
return some declName
else
return none
/--
Throws an error if `declName` cannot be annotated with attribute `[grind_cases]`.
We support the following cases:
- `declName` is a non-recursive datatype.
- `declName` is an abbreviation for a non-recursive datatype.
-/
private partial def validateGrindCasesAttr (declName : Name) : CoreM Unit := do
match (← getConstInfo declName) with
| .inductInfo info =>
if info.isRec then
throwError "`{declName}` is a recursive datatype"
| .defnInfo info =>
let failed := throwError "`{declName}` is a definition, but it is not an alias/abbreviation for an inductive datatype"
let some declName ← getAlias? info.value |>.run' {} {}
| failed
try
validateGrindCasesAttr declName
catch _ =>
failed
| _ =>
throwError "`{declName}` is not an inductive datatype or an alias for one"
builtin_initialize
registerBuiltinAttribute {
name := `grind_cases
descr := "`grind` tactic applies `cases` to (non-recursive) inductives during pre-processing step"
add := fun declName _ attrKind => do
validateGrindCasesAttr declName
grindCasesExt.add declName attrKind
}
--- END of TODO: detele
inductive AttrKind where
| ematch (k : TheoremKind)

View file

@ -32,6 +32,12 @@ def CasesTypes.insert (s : CasesTypes) (declName : Name) (eager : Bool) : CasesT
def CasesTypes.find? (s : CasesTypes) (declName : Name) : Option Bool :=
s.casesMap.find? declName
def CasesTypes.isEagerSplit (s : CasesTypes) (declName : Name) : Bool :=
s.casesMap.find? declName |>.getD false
def CasesTypes.isSplit (s : CasesTypes) (declName : Name) : Bool :=
s.casesMap.find? declName |>.isSome
builtin_initialize casesExt : SimpleScopedEnvExtension CasesEntry CasesTypes ←
registerSimpleScopedEnvExtension {
initial := {}

View file

@ -53,7 +53,6 @@ private def addSplitCandidate (e : Expr) : GoalM Unit := do
trace_goal[grind.split.candidate] "{e}"
modify fun s => { s with splitCandidates := e :: s.splitCandidates }
-- TODO: add attribute to make this extensible
private def forbiddenSplitTypes := [``Eq, ``HEq, ``True, ``False]
/-- Returns `true` if `e` is of the form `@Eq Prop a b` -/
@ -63,29 +62,37 @@ def isMorallyIff (e : Expr) : Bool :=
/-- Inserts `e` into the list of case-split candidates if applicable. -/
private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
unless e.isApp do return ()
if (← getConfig).splitIte && (e.isIte || e.isDIte) then
addSplitCandidate e
return ()
if isMorallyIff e then
addSplitCandidate e
return ()
if (← getConfig).splitMatch then
if (← isMatcherApp e) then
if let .reduced _ ← reduceMatcher? e then
-- When instantiating `match`-equations, we add `match`-applications that can be reduced,
-- and consequently don't need to be splitted.
match e with
| .app .. =>
if (← getConfig).splitIte && (e.isIte || e.isDIte) then
addSplitCandidate e
return ()
if isMorallyIff e then
addSplitCandidate e
return ()
if (← getConfig).splitMatch then
if (← isMatcherApp e) then
if let .reduced _ ← reduceMatcher? e then
-- When instantiating `match`-equations, we add `match`-applications that can be reduced,
-- and consequently don't need to be splitted.
return ()
else
addSplitCandidate e
return ()
let .const declName _ := e.getAppFn | return ()
if forbiddenSplitTypes.contains declName then
return ()
else
addSplitCandidate e
unless (← isInductivePredicate declName) do
return ()
let .const declName _ := e.getAppFn | return ()
if forbiddenSplitTypes.contains declName then return ()
-- We should have a mechanism for letting users to select types to case-split.
-- Right now, we just consider inductive predicates that are not in the forbidden list
if (← getConfig).splitIndPred then
if (← isInductivePredicate declName) then
if (← get).casesTypes.isSplit declName then
addSplitCandidate e
else if (← getConfig).splitIndPred then
addSplitCandidate e
| .fvar .. =>
let .const declName _ := (← inferType e).getAppFn | return ()
if (← get).casesTypes.isSplit declName then
addSplitCandidate e
| _ => pure ()
/--
If `e` is a `cast`-like term (e.g., `cast h a`), add `HEq e a` to the to-do list.

View file

@ -74,12 +74,12 @@ private def introNext (goal : Goal) (generation : Nat) : GrindM IntroResult := d
else
return .done
private def isCasesCandidate (type : Expr) : MetaM Bool := do
def isEagerCasesCandidate (goal : Goal) (type : Expr) : Bool := Id.run do
let .const declName _ := type.getAppFn | return false
isGrindCasesTarget declName
return goal.casesTypes.isEagerSplit declName
private def applyCases? (goal : Goal) (fvarId : FVarId) : MetaM (Option (List Goal)) := goal.mvarId.withContext do
if (← isCasesCandidate (← fvarId.getType)) then
if isEagerCasesCandidate goal (← fvarId.getType) then
let mvarIds ← cases goal.mvarId (mkFVar fvarId)
return mvarIds.map fun mvarId => { goal with mvarId }
else
@ -121,7 +121,7 @@ partial def intros (generation : Nat) : GrindTactic' := fun goal => do
/-- Asserts a new fact `prop` with proof `proof` to the given `goal`. -/
def assertAt (proof : Expr) (prop : Expr) (generation : Nat) : GrindTactic' := fun goal => do
if (← isCasesCandidate prop) then
if isEagerCasesCandidate goal prop then
let mvarId ← goal.mvarId.assert (← mkFreshUserName `h) prop proof
let goal := { goal with mvarId }
intros generation goal

View file

@ -67,7 +67,8 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
let falseExpr ← getFalseExpr
let natZeroExpr ← getNatZeroExpr
let thmMap := params.ematch
GoalM.run' { mvarId, thmMap } do
let casesTypes := params.casesTypes
GoalM.run' { mvarId, thmMap, casesTypes } do
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0)

View file

@ -101,6 +101,13 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
if let some info ← isInductivePredicate? declName then
if (← isEqTrue e) then
return .ready info.ctors.length info.isRec
if e.isFVar then
let type ← whnfD (← inferType e)
let report : GoalM Unit := do
reportIssue "cannot perform case-split on {e}, unexpected type{indentExpr type}"
let .const declName _ := type.getAppFn | report; return .resolved
let .inductInfo info ← getConstInfo declName | report; return .resolved
return .ready info.ctors.length info.isRec
return .notReady
private inductive SplitCandidate where

View file

@ -16,6 +16,7 @@ import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Ext
import Lean.Meta.Tactic.Grind.ENodeKey
import Lean.Meta.Tactic.Grind.Attr
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.Arith.Types
import Lean.Meta.Tactic.Grind.EMatchTheorem
@ -362,6 +363,8 @@ structure Goal where
nextIdx : Nat := 0
/-- State of arithmetic procedures -/
arith : Arith.State := {}
/-- Inductive datatypes marked for case-splitting -/
casesTypes : CasesTypes := {}
/-- Active theorems that we have performed ematching at least once. -/
thms : PArray EMatchTheorem := {}
/-- Active theorems that we have not performed any round of ematching yet. -/

View file

@ -1,35 +1,3 @@
/--
error: `List` is a recursive datatype
-/
#guard_msgs in
attribute [grind_cases] List
/--
error: `Prod.mk` is not an inductive datatype or an alias for one
-/
#guard_msgs in
attribute [grind_cases] Prod.mk
/--
error: `List.append` is a definition, but it is not an alias/abbreviation for an inductive datatype
-/
#guard_msgs in
attribute [grind_cases] List.append
attribute [grind_cases] Prod
def Foo (α : Type u) := Sum α α
attribute [grind_cases] Foo
attribute [grind_cases] And
attribute [grind_cases] False
attribute [grind_cases] Empty
-- TODO: delete everything above
/--
error: invalid `[grind cases eager]`, `List` is not a non-recursive inductive datatype or an alias for one
-/

View file

@ -39,7 +39,9 @@ h : c = true
theorem ex (h : (f a && (b || f (f c))) = true) (h' : p ∧ q) : b && a := by
grind
open Lean.Grind.Eager in
section
attribute [local grind cases eager] Or
/--
error: `grind` failed
case grind.2.1
@ -69,6 +71,8 @@ h : b = false
theorem ex2 (h : (f a && (b || f (f c))) = true) (h' : p ∧ q) : b && a := by
grind
end
def g (i : Nat) (j : Nat) (_ : i > j := by omega) := i + j
/--

View file

@ -54,10 +54,13 @@ inductive HasType : Expr → Ty → Prop
set_option trace.grind true
theorem HasType.det (h₁ : HasType e t₁) (h₂ : HasType e t₂) : t₁ = t₂ := by
grind
grind [HasType]
theorem HasType.det' (h₁ : HasType e t₁) (h₂ : HasType e t₂) : t₁ = t₂ := by
fail_if_success grind (splitIndPred := false)
example (h₁ : HasType e t₁) (h₂ : HasType e t₂) : t₁ = t₂ := by
grind +splitIndPred
example (h₁ : HasType e t₁) (h₂ : HasType e t₂) : t₁ = t₂ := by
fail_if_success grind
sorry
end grind_test_induct_pred

View file

@ -47,4 +47,4 @@ h : HEq ⋯ ⋯
-/
#guard_msgs (error) in
example {c : Nat} (q : X c 0) : False := by
grind
grind [X]

View file

@ -296,7 +296,15 @@ example {α β} (f : α → β) (a : α) : ∃ a', f a' = f a := by
open List in
example : (replicate n a).map f = replicate n (f a) := by
grind only [Option.map_some', Option.map_none', getElem?_map, getElem?_replicate]
grind +splitIndPred only [Option.map_some', Option.map_none', getElem?_map, getElem?_replicate]
open List in
example : (replicate n a).map f = replicate n (f a) := by
grind only [Exists, Option.map_some', Option.map_none', getElem?_map, getElem?_replicate]
open List in
example : (replicate n a).map f = replicate n (f a) := by
grind only [cases Exists, Option.map_some', Option.map_none', getElem?_map, getElem?_replicate]
open List in
example : (replicate n a).map f = replicate n (f a) := by