feat: add simp-like parameters to grind (#6675)

This PR adds `simp`-like parameters to `grind`, and `grind only` similar
to `simp only`.
This commit is contained in:
Leonardo de Moura 2025-01-16 17:08:45 -08:00 committed by GitHub
parent 60142c967c
commit 35a4da28ac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 202 additions and 47 deletions

View file

@ -14,7 +14,9 @@ syntax grindEqRhs := atomic("=" "_")
syntax grindBwd := "←"
syntax grindFwd := "→"
syntax (name := grind) "grind" (grindEqBoth <|> grindEqRhs <|> grindEq <|> grindBwd <|> grindFwd)? : attr
syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindBwd <|> grindFwd
syntax (name := grind) "grind" (grindThmMod)? : attr
end Lean.Parser.Attr
@ -59,7 +61,13 @@ namespace Lean.Parser.Tactic
`grind` tactic and related tactics.
-/
-- TODO: parameters
syntax (name := grind) "grind" optConfig ("on_failure " term)? : tactic
syntax grindErase := "-" ident
syntax grindLemma := (Attr.grindThmMod)? ident
syntax grindParam := grindErase <|> grindLemma
syntax (name := grind)
"grind" optConfig (&" only")?
(" [" withoutPosition(grindParam,*) "]")?
("on_failure " term)? : tactic
end Lean.Parser.Tactic

View file

@ -34,8 +34,44 @@ def elabGrindPattern : CommandElab := fun stx => do
Grind.addEMatchTheorem declName xs.size patterns.toList
| _ => throwUnsupportedSyntax
def grind (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) (fallback : Grind.Fallback) : MetaM Unit := do
let goals ← Grind.main mvarId config mainDeclName fallback
def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.grindParam) : MetaM Grind.Params := do
let mut params := params
for p in ps do
match p with
| `(Parser.Tactic.grindParam| - $id:ident) =>
let declName ← realizeGlobalConstNoOverloadWithInfo id
if (← isInductivePredicate declName) then
throwErrorAt p "NIY"
else
params := { params with ematch := (← params.ematch.eraseDecl declName) }
| `(Parser.Tactic.grindParam| $[$mod?:grindThmMod]? $id:ident) =>
let declName ← realizeGlobalConstNoOverloadWithInfo id
let kind ← if let some mod := mod? then Grind.getTheoremKindCore mod else pure .default
if (← isInductivePredicate declName) then
throwErrorAt p "NIY"
else if (← getConstInfo declName).isTheorem then
params := { params with extra := params.extra.push (← Grind.mkEMatchTheoremForDecl declName kind) }
else if let some eqns ← getEqnsFor? declName then
for eqn in eqns do
params := { params with extra := params.extra.push (← Grind.mkEMatchTheoremForDecl eqn kind) }
else
throwError "invalid `grind` parameter, `{declName}` is not a theorem, definition, or inductive type"
| _ => throwError "unexpected `grind` parameter{indentD p}"
return params
def mkGrindParams (config : Grind.Config) (only : Bool) (ps : TSyntaxArray ``Parser.Tactic.grindParam) : MetaM Grind.Params := do
let params ← Grind.mkParams config
let ematch ← if only then pure {} else Grind.getEMatchTheorems
let params := { params with ematch }
elabGrindParams params ps
def grind
(mvarId : MVarId) (config : Grind.Config)
(only : Bool)
(ps : TSyntaxArray ``Parser.Tactic.grindParam)
(mainDeclName : Name) (fallback : Grind.Fallback) : MetaM Unit := do
let params ← mkGrindParams config only ps
let goals ← Grind.main mvarId params mainDeclName fallback
unless goals.isEmpty do
throwError "`grind` failed\n{← Grind.goalsToMessageData goals config}"
@ -58,12 +94,14 @@ private def elabFallback (fallback? : Option Term) : TermElabM (Grind.GoalM Unit
@[builtin_tactic Lean.Parser.Tactic.grind] def evalApplyRfl : Tactic := fun stx => do
match stx with
| `(tactic| grind $config:optConfig $[on_failure $fallback?]?) =>
| `(tactic| grind $config:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) =>
let fallback ← elabFallback fallback?
let only := only.isSome
let params := if let some params := params then params.getElems else #[]
logWarningAt stx "The `grind` tactic is experimental and still under development. Avoid using it in production projects"
let declName := (← Term.getDeclName?).getD `_grind
let config ← elabGrindConfig config
withMainContext do liftMetaFinishingTactic (grind · config declName fallback)
withMainContext do liftMetaFinishingTactic (grind · config only params declName fallback)
| _ => throwUnsupportedSyntax
end Lean.Elab.Tactic

View file

@ -551,7 +551,7 @@ def getEMatchTheorems : CoreM EMatchTheorems :=
inductive TheoremKind where
| eqLhs | eqRhs | eqBoth | fwd | bwd | default
deriving Inhabited, BEq
deriving Inhabited, BEq, Repr
private def TheoremKind.toAttribute : TheoremKind → String
| .eqLhs => "[grind =]"
@ -677,19 +677,22 @@ where
levelParams, origin
}
private def getKind (stx : Syntax) : TheoremKind :=
/-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/
def getTheoremKindCore (stx : Syntax) : CoreM TheoremKind := do
match stx with
| `(Parser.Attr.grindThmMod| =) => return .eqLhs
| `(Parser.Attr.grindThmMod| →) => return .fwd
| `(Parser.Attr.grindThmMod| ←) => return .bwd
| `(Parser.Attr.grindThmMod| =_) => return .eqRhs
| `(Parser.Attr.grindThmMod| _=_) => return .eqBoth
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"
/-- Return theorem kind for `stx` of the form `(Attr.grindThmMod)?` -/
def getTheoremKindFromOpt (stx : Syntax) : CoreM TheoremKind := do
if stx[1].isNone then
.default
else if stx[1][0].getKind == ``Parser.Attr.grindEq then
.eqLhs
else if stx[1][0].getKind == ``Parser.Attr.grindFwd then
.fwd
else if stx[1][0].getKind == ``Parser.Attr.grindEqRhs then
.eqRhs
else if stx[1][0].getKind == ``Parser.Attr.grindEqBoth then
.eqBoth
return .default
else
.bwd
getTheoremKindCore stx[1][0]
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : TheoremKind) (useLhs := true) : MetaM Unit := do
if (← getConstInfo declName).isTheorem then
@ -702,6 +705,11 @@ private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind
else
throwError s!"`{thmKind.toAttribute}` attribute can only be applied to equational theorems or function definitions"
def mkEMatchTheoremForDecl (declName : Name) (thmKind : TheoremKind) : MetaM EMatchTheorem := do
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind
| throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command"
return thm
private def addGrindAttr (declName : Name) (attrKind : AttributeKind) (thmKind : TheoremKind) : MetaM Unit := do
if thmKind == .eqLhs then
addGrindEqAttr declName attrKind thmKind (useLhs := true)
@ -713,10 +721,26 @@ private def addGrindAttr (declName : Name) (attrKind : AttributeKind) (thmKind :
else if !(← getConstInfo declName).isTheorem then
addGrindEqAttr declName attrKind thmKind
else
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind
| throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command"
let thm ← mkEMatchTheoremForDecl declName thmKind
ematchTheoremsExt.add thm attrKind
def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMatchTheorems := do
let throwErr {α} : MetaM α :=
throwError "`{declName}` is not marked with the `[grind]` attribute"
let info ← getConstInfo declName
if !info.isTheorem then
if let some eqns ← getEqnsFor? declName then
let s := ematchTheoremsExt.getState (← getEnv)
unless eqns.all fun eqn => s.contains (.decl eqn) do
throwErr
return eqns.foldl (init := s) fun s eqn => s.erase (.decl eqn)
else
throwErr
else
unless ematchTheoremsExt.getState (← getEnv) |>.contains (.decl declName) do
throwErr
return s.erase <| .decl declName
builtin_initialize
registerBuiltinAttribute {
name := `grind
@ -739,7 +763,7 @@ builtin_initialize
`grind` will add an instance of this theorem to the local context whenever it encounters the pattern `foo (foo x)`."
applicationTime := .afterCompilation
add := fun declName stx attrKind => do
addGrindAttr declName attrKind (getKind stx) |>.run' {}
addGrindAttr declName attrKind (← getTheoremKindFromOpt stx) |>.run' {}
erase := fun declName => MetaM.run' do
/-
Remark: consider the following example
@ -755,21 +779,9 @@ builtin_initialize
attribute [-grind] foo -- ok
```
-/
let throwErr := throwError "`{declName}` is not marked with the `[grind]` attribute"
let info ← getConstInfo declName
if !info.isTheorem then
if let some eqns ← getEqnsFor? declName then
let s := ematchTheoremsExt.getState (← getEnv)
unless eqns.all fun eqn => s.contains (.decl eqn) do
throwErr
modifyEnv fun env => ematchTheoremsExt.modifyState env fun s =>
eqns.foldl (init := s) fun s eqn => s.erase (.decl eqn)
else
throwErr
else
unless ematchTheoremsExt.getState (← getEnv) |>.contains (.decl declName) do
throwErr
modifyEnv fun env => ematchTheoremsExt.modifyState env fun s => s.erase (.decl declName)
let s := ematchTheoremsExt.getState (← getEnv)
let s ← s.eraseDecl declName
modifyEnv fun env => ematchTheoremsExt.modifyState env fun _ => s
}
end Lean.Meta.Grind

View file

@ -20,6 +20,19 @@ import Lean.Meta.Tactic.Grind.SimpUtil
namespace Lean.Meta.Grind
structure Params where
config : Grind.Config
ematch : EMatchTheorems := {}
extra : PArray EMatchTheorem := {}
norm : Simp.Context
normProcs : Array Simprocs
-- TODO: inductives to split
def mkParams (config : Grind.Config) : MetaM Params := do
let norm ← Grind.getSimpContext
let normProcs ← Grind.getSimprocs
return { config, norm, normProcs }
def mkMethods (fallback : Fallback) : CoreM Methods := do
let builtinPropagators ← builtinPropagatorsRef.get
return {
@ -37,26 +50,29 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
prop e
}
def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) (fallback : Fallback) : MetaM α := do
def GrindM.run (x : GrindM α) (mainDeclName : Name) (params : Params) (fallback : Fallback) : MetaM α := do
let scState := ShareCommon.State.mk _
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
let (natZExpr, scState) := ShareCommon.State.shareCommon scState (mkNatLit 0)
let simprocs ← Grind.getSimprocs
let simp ← Grind.getSimpContext
let simprocs := params.normProcs
let simp := params.norm
let config := params.config
x (← mkMethods fallback).toMethodsRef { mainDeclName, config, simprocs, simp } |>.run' { scState, trueExpr, falseExpr, natZExpr }
private def mkGoal (mvarId : MVarId) : GrindM Goal := do
private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
let trueExpr ← getTrueExpr
let falseExpr ← getFalseExpr
let natZeroExpr ← getNatZeroExpr
let thmMap ← getEMatchTheorems
let thmMap := params.ematch
GoalM.run' { mvarId, thmMap } do
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0)
for thm in params.extra do
activateTheorem thm 0
private def initCore (mvarId : MVarId) : GrindM (List Goal) := do
private def initCore (mvarId : MVarId) (params : Params) : GrindM (List Goal) := do
mvarId.ensureProp
-- TODO: abstract metavars
mvarId.ensureNoMVar
@ -65,13 +81,13 @@ private def initCore (mvarId : MVarId) : GrindM (List Goal) := do
let mvarId ← mvarId.unfoldReducible
let mvarId ← mvarId.betaReduce
appendTagSuffix mvarId `grind
let goals ← intros (← mkGoal mvarId) (generation := 0)
let goals ← intros (← mkGoal mvarId params) (generation := 0)
goals.forM (·.checkInvariants (expensive := true))
return goals.filter fun goal => !goal.inconsistent
def main (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) (fallback : Fallback) : MetaM (List Goal) := do
def main (mvarId : MVarId) (params : Params) (mainDeclName : Name) (fallback : Fallback) : MetaM (List Goal) := do
let go : GrindM (List Goal) := do
let goals ← initCore mvarId
let goals ← initCore mvarId params
let goals ← solve goals
let goals ← goals.filterMapM fun goal => do
if goal.inconsistent then return none
@ -81,6 +97,6 @@ def main (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) (fallba
return some goal
trace[grind.debug.final] "{← ppGoals goals}"
return goals
go.run mainDeclName config fallback
go.run mainDeclName params fallback
end Lean.Meta.Grind

View file

@ -0,0 +1,81 @@
def foo (x : Nat) := x + 2
example (f : Nat → Nat) : f (foo a) = b → f (c + 1) = d → c = a + 1 → b = d := by
grind [foo]
opaque bla : Nat → Nat
theorem blathm : bla (bla x) = bla x := sorry
example : bla (foo a) = b → bla b = bla (a + 2) := by
grind [foo, blathm]
example : bla (foo a) = b → bla b = bla (a + 2) := by
grind [foo, = blathm]
/--
error: invalid `grind` forward theorem, theorem `blathm` does not have propositional hypotheses
-/
#guard_msgs (error) in
example : bla (foo a) = b → bla b = bla (a + 2) := by
grind [foo, → blathm]
opaque P : Nat → Prop
opaque Q : Nat → Prop
opaque R : Nat → Prop
theorem pq : P x → Q x := sorry
theorem qr : Q x → R x := sorry
example : P x → R x := by
grind [→ pq, → qr]
/--
error: `grind` failed
case grind
x : Nat
a✝ : P x
x✝ : ¬R x
⊢ False
[grind] Diagnostics
[facts] Asserted facts
[prop] P x
[prop] ¬R x
[eqc] True propositions
[prop] P x
[eqc] False propositions
[prop] R x
[ematch] E-matching
[thm] pq:
∀ {x : Nat}, P x → Q x
patterns: [Q #1]
[thm] qr: ∀ {x : Nat}, Q x → R x patterns: [Q #1]
-/
#guard_msgs (error) in
example : P x → R x := by
grind [← pq, → qr]
example : P x → R x := by
grind [← pq, ← qr]
attribute [grind] blathm
example : bla (bla (bla (bla x))) = bla x := by
grind
example : bla (bla (bla (bla x))) = bla x := by
fail_if_success grind [-blathm]
sorry
example : bla (bla (bla (bla x))) = bla x := by
grind only [blathm]
example : bla (bla (bla (bla x))) = bla x := by
fail_if_success grind only
sorry
/--
error: `pq` is not marked with the `[grind]` attribute
-/
#guard_msgs (error) in
example : P x → R x := by
grind [-pq]