feat: case splitting in grind (#6717)
This PR introduces a new feature that allows users to specify which inductive datatypes the `grind` tactic should perform case splits on. The configuration option `splitIndPred` is now set to `false` by default. The attribute `[grind cases]` is used to mark inductive datatypes and predicates that `grind` may case split on during the search. Additionally, the attribute `[grind cases eager]` can be used to mark datatypes and predicates for case splitting both during pre-processing and the search. Users can also write `grind [HasType]` or `grind [cases HasType]` to instruct `grind` to perform case splitting on the inductive predicate `HasType` in a specific instance. Similarly, `grind [-Or]` can be used to instruct `grind` not to case split on disjunctions. Co-authored-by: Leonardo de Moura <leodemoura@amazon.com>
This commit is contained in:
parent
c07f64a621
commit
189f5d41fb
14 changed files with 76 additions and 124 deletions
|
|
@ -5,11 +5,7 @@ Authors: Leonardo de Moura
|
|||
-/
|
||||
prelude
|
||||
import Init.Core
|
||||
import Init.Grind.Tactics
|
||||
|
||||
attribute [grind_cases] And Prod False Empty True Unit Exists
|
||||
|
||||
namespace Lean.Grind.Eager
|
||||
|
||||
attribute [scoped grind_cases] Or
|
||||
|
||||
end Lean.Grind.Eager
|
||||
attribute [grind cases eager] And Prod False Empty True Unit Exists
|
||||
attribute [grind cases] Or
|
||||
|
|
|
|||
|
|
@ -48,8 +48,8 @@ structure Config where
|
|||
splitIte : Bool := true
|
||||
/--
|
||||
If `splitIndPred` is `true`, `grind` performs case-splitting on inductive predicates.
|
||||
Otherwise, it performs case-splitting only on types marked with `[grind_split]` attribute. -/
|
||||
splitIndPred : Bool := true
|
||||
Otherwise, it performs case-splitting only on types marked with `[grind cases]` attribute. -/
|
||||
splitIndPred : Bool := false
|
||||
/-- By default, `grind` halts as soon as it encounters a sub-goal where no further progress can be made. -/
|
||||
failures : Nat := 1
|
||||
/-- Maximum number of heartbeats (in thousands) the canonicalizer can spend per definitional equality test. -/
|
||||
|
|
|
|||
|
|
@ -8,57 +8,6 @@ import Lean.Meta.Tactic.Grind.EMatchTheorem
|
|||
import Lean.Meta.Tactic.Grind.Cases
|
||||
|
||||
namespace Lean.Meta.Grind
|
||||
--- TODO: delete
|
||||
builtin_initialize grindCasesExt : SimpleScopedEnvExtension Name NameSet ←
|
||||
registerSimpleScopedEnvExtension {
|
||||
initial := {}
|
||||
addEntry := fun s declName => s.insert declName
|
||||
}
|
||||
|
||||
/--
|
||||
Returns `true` if `declName` has been tagged with attribute `[grind_cases]`.
|
||||
-/
|
||||
def isGrindCasesTarget (declName : Name) : CoreM Bool :=
|
||||
return grindCasesExt.getState (← getEnv) |>.contains declName
|
||||
|
||||
private def getAlias? (value : Expr) : MetaM (Option Name) :=
|
||||
lambdaTelescope value fun _ body => do
|
||||
if let .const declName _ := body.getAppFn' then
|
||||
return some declName
|
||||
else
|
||||
return none
|
||||
|
||||
/--
|
||||
Throws an error if `declName` cannot be annotated with attribute `[grind_cases]`.
|
||||
We support the following cases:
|
||||
- `declName` is a non-recursive datatype.
|
||||
- `declName` is an abbreviation for a non-recursive datatype.
|
||||
-/
|
||||
private partial def validateGrindCasesAttr (declName : Name) : CoreM Unit := do
|
||||
match (← getConstInfo declName) with
|
||||
| .inductInfo info =>
|
||||
if info.isRec then
|
||||
throwError "`{declName}` is a recursive datatype"
|
||||
| .defnInfo info =>
|
||||
let failed := throwError "`{declName}` is a definition, but it is not an alias/abbreviation for an inductive datatype"
|
||||
let some declName ← getAlias? info.value |>.run' {} {}
|
||||
| failed
|
||||
try
|
||||
validateGrindCasesAttr declName
|
||||
catch _ =>
|
||||
failed
|
||||
| _ =>
|
||||
throwError "`{declName}` is not an inductive datatype or an alias for one"
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinAttribute {
|
||||
name := `grind_cases
|
||||
descr := "`grind` tactic applies `cases` to (non-recursive) inductives during pre-processing step"
|
||||
add := fun declName _ attrKind => do
|
||||
validateGrindCasesAttr declName
|
||||
grindCasesExt.add declName attrKind
|
||||
}
|
||||
--- END of TODO: detele
|
||||
|
||||
inductive AttrKind where
|
||||
| ematch (k : TheoremKind)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,12 @@ def CasesTypes.insert (s : CasesTypes) (declName : Name) (eager : Bool) : CasesT
|
|||
def CasesTypes.find? (s : CasesTypes) (declName : Name) : Option Bool :=
|
||||
s.casesMap.find? declName
|
||||
|
||||
def CasesTypes.isEagerSplit (s : CasesTypes) (declName : Name) : Bool :=
|
||||
s.casesMap.find? declName |>.getD false
|
||||
|
||||
def CasesTypes.isSplit (s : CasesTypes) (declName : Name) : Bool :=
|
||||
s.casesMap.find? declName |>.isSome
|
||||
|
||||
builtin_initialize casesExt : SimpleScopedEnvExtension CasesEntry CasesTypes ←
|
||||
registerSimpleScopedEnvExtension {
|
||||
initial := {}
|
||||
|
|
|
|||
|
|
@ -53,7 +53,6 @@ private def addSplitCandidate (e : Expr) : GoalM Unit := do
|
|||
trace_goal[grind.split.candidate] "{e}"
|
||||
modify fun s => { s with splitCandidates := e :: s.splitCandidates }
|
||||
|
||||
-- TODO: add attribute to make this extensible
|
||||
private def forbiddenSplitTypes := [``Eq, ``HEq, ``True, ``False]
|
||||
|
||||
/-- Returns `true` if `e` is of the form `@Eq Prop a b` -/
|
||||
|
|
@ -63,29 +62,37 @@ def isMorallyIff (e : Expr) : Bool :=
|
|||
|
||||
/-- Inserts `e` into the list of case-split candidates if applicable. -/
|
||||
private def checkAndAddSplitCandidate (e : Expr) : GoalM Unit := do
|
||||
unless e.isApp do return ()
|
||||
if (← getConfig).splitIte && (e.isIte || e.isDIte) then
|
||||
addSplitCandidate e
|
||||
return ()
|
||||
if isMorallyIff e then
|
||||
addSplitCandidate e
|
||||
return ()
|
||||
if (← getConfig).splitMatch then
|
||||
if (← isMatcherApp e) then
|
||||
if let .reduced _ ← reduceMatcher? e then
|
||||
-- When instantiating `match`-equations, we add `match`-applications that can be reduced,
|
||||
-- and consequently don't need to be splitted.
|
||||
match e with
|
||||
| .app .. =>
|
||||
if (← getConfig).splitIte && (e.isIte || e.isDIte) then
|
||||
addSplitCandidate e
|
||||
return ()
|
||||
if isMorallyIff e then
|
||||
addSplitCandidate e
|
||||
return ()
|
||||
if (← getConfig).splitMatch then
|
||||
if (← isMatcherApp e) then
|
||||
if let .reduced _ ← reduceMatcher? e then
|
||||
-- When instantiating `match`-equations, we add `match`-applications that can be reduced,
|
||||
-- and consequently don't need to be splitted.
|
||||
return ()
|
||||
else
|
||||
addSplitCandidate e
|
||||
return ()
|
||||
let .const declName _ := e.getAppFn | return ()
|
||||
if forbiddenSplitTypes.contains declName then
|
||||
return ()
|
||||
else
|
||||
addSplitCandidate e
|
||||
unless (← isInductivePredicate declName) do
|
||||
return ()
|
||||
let .const declName _ := e.getAppFn | return ()
|
||||
if forbiddenSplitTypes.contains declName then return ()
|
||||
-- We should have a mechanism for letting users to select types to case-split.
|
||||
-- Right now, we just consider inductive predicates that are not in the forbidden list
|
||||
if (← getConfig).splitIndPred then
|
||||
if (← isInductivePredicate declName) then
|
||||
if (← get).casesTypes.isSplit declName then
|
||||
addSplitCandidate e
|
||||
else if (← getConfig).splitIndPred then
|
||||
addSplitCandidate e
|
||||
| .fvar .. =>
|
||||
let .const declName _ := (← inferType e).getAppFn | return ()
|
||||
if (← get).casesTypes.isSplit declName then
|
||||
addSplitCandidate e
|
||||
| _ => pure ()
|
||||
|
||||
/--
|
||||
If `e` is a `cast`-like term (e.g., `cast h a`), add `HEq e a` to the to-do list.
|
||||
|
|
|
|||
|
|
@ -74,12 +74,12 @@ private def introNext (goal : Goal) (generation : Nat) : GrindM IntroResult := d
|
|||
else
|
||||
return .done
|
||||
|
||||
private def isCasesCandidate (type : Expr) : MetaM Bool := do
|
||||
def isEagerCasesCandidate (goal : Goal) (type : Expr) : Bool := Id.run do
|
||||
let .const declName _ := type.getAppFn | return false
|
||||
isGrindCasesTarget declName
|
||||
return goal.casesTypes.isEagerSplit declName
|
||||
|
||||
private def applyCases? (goal : Goal) (fvarId : FVarId) : MetaM (Option (List Goal)) := goal.mvarId.withContext do
|
||||
if (← isCasesCandidate (← fvarId.getType)) then
|
||||
if isEagerCasesCandidate goal (← fvarId.getType) then
|
||||
let mvarIds ← cases goal.mvarId (mkFVar fvarId)
|
||||
return mvarIds.map fun mvarId => { goal with mvarId }
|
||||
else
|
||||
|
|
@ -121,7 +121,7 @@ partial def intros (generation : Nat) : GrindTactic' := fun goal => do
|
|||
|
||||
/-- Asserts a new fact `prop` with proof `proof` to the given `goal`. -/
|
||||
def assertAt (proof : Expr) (prop : Expr) (generation : Nat) : GrindTactic' := fun goal => do
|
||||
if (← isCasesCandidate prop) then
|
||||
if isEagerCasesCandidate goal prop then
|
||||
let mvarId ← goal.mvarId.assert (← mkFreshUserName `h) prop proof
|
||||
let goal := { goal with mvarId }
|
||||
intros generation goal
|
||||
|
|
|
|||
|
|
@ -67,7 +67,8 @@ private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
|
|||
let falseExpr ← getFalseExpr
|
||||
let natZeroExpr ← getNatZeroExpr
|
||||
let thmMap := params.ematch
|
||||
GoalM.run' { mvarId, thmMap } do
|
||||
let casesTypes := params.casesTypes
|
||||
GoalM.run' { mvarId, thmMap, casesTypes } do
|
||||
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
|
|
|
|||
|
|
@ -101,6 +101,13 @@ private def checkCaseSplitStatus (e : Expr) : GoalM CaseSplitStatus := do
|
|||
if let some info ← isInductivePredicate? declName then
|
||||
if (← isEqTrue e) then
|
||||
return .ready info.ctors.length info.isRec
|
||||
if e.isFVar then
|
||||
let type ← whnfD (← inferType e)
|
||||
let report : GoalM Unit := do
|
||||
reportIssue "cannot perform case-split on {e}, unexpected type{indentExpr type}"
|
||||
let .const declName _ := type.getAppFn | report; return .resolved
|
||||
let .inductInfo info ← getConstInfo declName | report; return .resolved
|
||||
return .ready info.ctors.length info.isRec
|
||||
return .notReady
|
||||
|
||||
private inductive SplitCandidate where
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import Lean.Meta.Tactic.Util
|
|||
import Lean.Meta.Tactic.Ext
|
||||
import Lean.Meta.Tactic.Grind.ENodeKey
|
||||
import Lean.Meta.Tactic.Grind.Attr
|
||||
import Lean.Meta.Tactic.Grind.Cases
|
||||
import Lean.Meta.Tactic.Grind.Arith.Types
|
||||
import Lean.Meta.Tactic.Grind.EMatchTheorem
|
||||
|
||||
|
|
@ -362,6 +363,8 @@ structure Goal where
|
|||
nextIdx : Nat := 0
|
||||
/-- State of arithmetic procedures -/
|
||||
arith : Arith.State := {}
|
||||
/-- Inductive datatypes marked for case-splitting -/
|
||||
casesTypes : CasesTypes := {}
|
||||
/-- Active theorems that we have performed ematching at least once. -/
|
||||
thms : PArray EMatchTheorem := {}
|
||||
/-- Active theorems that we have not performed any round of ematching yet. -/
|
||||
|
|
|
|||
|
|
@ -1,35 +1,3 @@
|
|||
/--
|
||||
error: `List` is a recursive datatype
|
||||
-/
|
||||
#guard_msgs in
|
||||
attribute [grind_cases] List
|
||||
|
||||
/--
|
||||
error: `Prod.mk` is not an inductive datatype or an alias for one
|
||||
-/
|
||||
#guard_msgs in
|
||||
attribute [grind_cases] Prod.mk
|
||||
|
||||
/--
|
||||
error: `List.append` is a definition, but it is not an alias/abbreviation for an inductive datatype
|
||||
-/
|
||||
#guard_msgs in
|
||||
attribute [grind_cases] List.append
|
||||
|
||||
attribute [grind_cases] Prod
|
||||
|
||||
def Foo (α : Type u) := Sum α α
|
||||
|
||||
attribute [grind_cases] Foo
|
||||
|
||||
attribute [grind_cases] And
|
||||
|
||||
attribute [grind_cases] False
|
||||
|
||||
attribute [grind_cases] Empty
|
||||
|
||||
-- TODO: delete everything above
|
||||
|
||||
/--
|
||||
error: invalid `[grind cases eager]`, `List` is not a non-recursive inductive datatype or an alias for one
|
||||
-/
|
||||
|
|
|
|||
|
|
@ -39,7 +39,9 @@ h : c = true
|
|||
theorem ex (h : (f a && (b || f (f c))) = true) (h' : p ∧ q) : b && a := by
|
||||
grind
|
||||
|
||||
open Lean.Grind.Eager in
|
||||
section
|
||||
attribute [local grind cases eager] Or
|
||||
|
||||
/--
|
||||
error: `grind` failed
|
||||
case grind.2.1
|
||||
|
|
@ -69,6 +71,8 @@ h : b = false
|
|||
theorem ex2 (h : (f a && (b || f (f c))) = true) (h' : p ∧ q) : b && a := by
|
||||
grind
|
||||
|
||||
end
|
||||
|
||||
def g (i : Nat) (j : Nat) (_ : i > j := by omega) := i + j
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -54,10 +54,13 @@ inductive HasType : Expr → Ty → Prop
|
|||
|
||||
set_option trace.grind true
|
||||
theorem HasType.det (h₁ : HasType e t₁) (h₂ : HasType e t₂) : t₁ = t₂ := by
|
||||
grind
|
||||
grind [HasType]
|
||||
|
||||
theorem HasType.det' (h₁ : HasType e t₁) (h₂ : HasType e t₂) : t₁ = t₂ := by
|
||||
fail_if_success grind (splitIndPred := false)
|
||||
example (h₁ : HasType e t₁) (h₂ : HasType e t₂) : t₁ = t₂ := by
|
||||
grind +splitIndPred
|
||||
|
||||
example (h₁ : HasType e t₁) (h₂ : HasType e t₂) : t₁ = t₂ := by
|
||||
fail_if_success grind
|
||||
sorry
|
||||
|
||||
end grind_test_induct_pred
|
||||
|
|
|
|||
|
|
@ -47,4 +47,4 @@ h : HEq ⋯ ⋯
|
|||
-/
|
||||
#guard_msgs (error) in
|
||||
example {c : Nat} (q : X c 0) : False := by
|
||||
grind
|
||||
grind [X]
|
||||
|
|
|
|||
|
|
@ -296,7 +296,15 @@ example {α β} (f : α → β) (a : α) : ∃ a', f a' = f a := by
|
|||
|
||||
open List in
|
||||
example : (replicate n a).map f = replicate n (f a) := by
|
||||
grind only [Option.map_some', Option.map_none', getElem?_map, getElem?_replicate]
|
||||
grind +splitIndPred only [Option.map_some', Option.map_none', getElem?_map, getElem?_replicate]
|
||||
|
||||
open List in
|
||||
example : (replicate n a).map f = replicate n (f a) := by
|
||||
grind only [Exists, Option.map_some', Option.map_none', getElem?_map, getElem?_replicate]
|
||||
|
||||
open List in
|
||||
example : (replicate n a).map f = replicate n (f a) := by
|
||||
grind only [cases Exists, Option.map_some', Option.map_none', getElem?_map, getElem?_replicate]
|
||||
|
||||
open List in
|
||||
example : (replicate n a).map f = replicate n (f a) := by
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue