feat: add simp-like parameters to grind (#6675)
This PR adds `simp`-like parameters to `grind`, and `grind only` similar to `simp only`.
This commit is contained in:
parent
60142c967c
commit
35a4da28ac
5 changed files with 202 additions and 47 deletions
|
|
@ -14,7 +14,9 @@ syntax grindEqRhs := atomic("=" "_")
|
|||
syntax grindBwd := "←"
|
||||
syntax grindFwd := "→"
|
||||
|
||||
syntax (name := grind) "grind" (grindEqBoth <|> grindEqRhs <|> grindEq <|> grindBwd <|> grindFwd)? : attr
|
||||
syntax grindThmMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindBwd <|> grindFwd
|
||||
|
||||
syntax (name := grind) "grind" (grindThmMod)? : attr
|
||||
|
||||
end Lean.Parser.Attr
|
||||
|
||||
|
|
@ -59,7 +61,13 @@ namespace Lean.Parser.Tactic
|
|||
`grind` tactic and related tactics.
|
||||
-/
|
||||
|
||||
-- TODO: parameters
|
||||
syntax (name := grind) "grind" optConfig ("on_failure " term)? : tactic
|
||||
syntax grindErase := "-" ident
|
||||
syntax grindLemma := (Attr.grindThmMod)? ident
|
||||
syntax grindParam := grindErase <|> grindLemma
|
||||
|
||||
syntax (name := grind)
|
||||
"grind" optConfig (&" only")?
|
||||
(" [" withoutPosition(grindParam,*) "]")?
|
||||
("on_failure " term)? : tactic
|
||||
|
||||
end Lean.Parser.Tactic
|
||||
|
|
|
|||
|
|
@ -34,8 +34,44 @@ def elabGrindPattern : CommandElab := fun stx => do
|
|||
Grind.addEMatchTheorem declName xs.size patterns.toList
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
def grind (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) (fallback : Grind.Fallback) : MetaM Unit := do
|
||||
let goals ← Grind.main mvarId config mainDeclName fallback
|
||||
def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.grindParam) : MetaM Grind.Params := do
|
||||
let mut params := params
|
||||
for p in ps do
|
||||
match p with
|
||||
| `(Parser.Tactic.grindParam| - $id:ident) =>
|
||||
let declName ← realizeGlobalConstNoOverloadWithInfo id
|
||||
if (← isInductivePredicate declName) then
|
||||
throwErrorAt p "NIY"
|
||||
else
|
||||
params := { params with ematch := (← params.ematch.eraseDecl declName) }
|
||||
| `(Parser.Tactic.grindParam| $[$mod?:grindThmMod]? $id:ident) =>
|
||||
let declName ← realizeGlobalConstNoOverloadWithInfo id
|
||||
let kind ← if let some mod := mod? then Grind.getTheoremKindCore mod else pure .default
|
||||
if (← isInductivePredicate declName) then
|
||||
throwErrorAt p "NIY"
|
||||
else if (← getConstInfo declName).isTheorem then
|
||||
params := { params with extra := params.extra.push (← Grind.mkEMatchTheoremForDecl declName kind) }
|
||||
else if let some eqns ← getEqnsFor? declName then
|
||||
for eqn in eqns do
|
||||
params := { params with extra := params.extra.push (← Grind.mkEMatchTheoremForDecl eqn kind) }
|
||||
else
|
||||
throwError "invalid `grind` parameter, `{declName}` is not a theorem, definition, or inductive type"
|
||||
| _ => throwError "unexpected `grind` parameter{indentD p}"
|
||||
return params
|
||||
|
||||
def mkGrindParams (config : Grind.Config) (only : Bool) (ps : TSyntaxArray ``Parser.Tactic.grindParam) : MetaM Grind.Params := do
|
||||
let params ← Grind.mkParams config
|
||||
let ematch ← if only then pure {} else Grind.getEMatchTheorems
|
||||
let params := { params with ematch }
|
||||
elabGrindParams params ps
|
||||
|
||||
def grind
|
||||
(mvarId : MVarId) (config : Grind.Config)
|
||||
(only : Bool)
|
||||
(ps : TSyntaxArray ``Parser.Tactic.grindParam)
|
||||
(mainDeclName : Name) (fallback : Grind.Fallback) : MetaM Unit := do
|
||||
let params ← mkGrindParams config only ps
|
||||
let goals ← Grind.main mvarId params mainDeclName fallback
|
||||
unless goals.isEmpty do
|
||||
throwError "`grind` failed\n{← Grind.goalsToMessageData goals config}"
|
||||
|
||||
|
|
@ -58,12 +94,14 @@ private def elabFallback (fallback? : Option Term) : TermElabM (Grind.GoalM Unit
|
|||
|
||||
@[builtin_tactic Lean.Parser.Tactic.grind] def evalApplyRfl : Tactic := fun stx => do
|
||||
match stx with
|
||||
| `(tactic| grind $config:optConfig $[on_failure $fallback?]?) =>
|
||||
| `(tactic| grind $config:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) =>
|
||||
let fallback ← elabFallback fallback?
|
||||
let only := only.isSome
|
||||
let params := if let some params := params then params.getElems else #[]
|
||||
logWarningAt stx "The `grind` tactic is experimental and still under development. Avoid using it in production projects"
|
||||
let declName := (← Term.getDeclName?).getD `_grind
|
||||
let config ← elabGrindConfig config
|
||||
withMainContext do liftMetaFinishingTactic (grind · config declName fallback)
|
||||
withMainContext do liftMetaFinishingTactic (grind · config only params declName fallback)
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
end Lean.Elab.Tactic
|
||||
|
|
|
|||
|
|
@ -551,7 +551,7 @@ def getEMatchTheorems : CoreM EMatchTheorems :=
|
|||
|
||||
inductive TheoremKind where
|
||||
| eqLhs | eqRhs | eqBoth | fwd | bwd | default
|
||||
deriving Inhabited, BEq
|
||||
deriving Inhabited, BEq, Repr
|
||||
|
||||
private def TheoremKind.toAttribute : TheoremKind → String
|
||||
| .eqLhs => "[grind =]"
|
||||
|
|
@ -677,19 +677,22 @@ where
|
|||
levelParams, origin
|
||||
}
|
||||
|
||||
private def getKind (stx : Syntax) : TheoremKind :=
|
||||
/-- Return theorem kind for `stx` of the form `Attr.grindThmMod` -/
|
||||
def getTheoremKindCore (stx : Syntax) : CoreM TheoremKind := do
|
||||
match stx with
|
||||
| `(Parser.Attr.grindThmMod| =) => return .eqLhs
|
||||
| `(Parser.Attr.grindThmMod| →) => return .fwd
|
||||
| `(Parser.Attr.grindThmMod| ←) => return .bwd
|
||||
| `(Parser.Attr.grindThmMod| =_) => return .eqRhs
|
||||
| `(Parser.Attr.grindThmMod| _=_) => return .eqBoth
|
||||
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"
|
||||
|
||||
/-- Return theorem kind for `stx` of the form `(Attr.grindThmMod)?` -/
|
||||
def getTheoremKindFromOpt (stx : Syntax) : CoreM TheoremKind := do
|
||||
if stx[1].isNone then
|
||||
.default
|
||||
else if stx[1][0].getKind == ``Parser.Attr.grindEq then
|
||||
.eqLhs
|
||||
else if stx[1][0].getKind == ``Parser.Attr.grindFwd then
|
||||
.fwd
|
||||
else if stx[1][0].getKind == ``Parser.Attr.grindEqRhs then
|
||||
.eqRhs
|
||||
else if stx[1][0].getKind == ``Parser.Attr.grindEqBoth then
|
||||
.eqBoth
|
||||
return .default
|
||||
else
|
||||
.bwd
|
||||
getTheoremKindCore stx[1][0]
|
||||
|
||||
private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind : TheoremKind) (useLhs := true) : MetaM Unit := do
|
||||
if (← getConstInfo declName).isTheorem then
|
||||
|
|
@ -702,6 +705,11 @@ private def addGrindEqAttr (declName : Name) (attrKind : AttributeKind) (thmKind
|
|||
else
|
||||
throwError s!"`{thmKind.toAttribute}` attribute can only be applied to equational theorems or function definitions"
|
||||
|
||||
def mkEMatchTheoremForDecl (declName : Name) (thmKind : TheoremKind) : MetaM EMatchTheorem := do
|
||||
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind
|
||||
| throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command"
|
||||
return thm
|
||||
|
||||
private def addGrindAttr (declName : Name) (attrKind : AttributeKind) (thmKind : TheoremKind) : MetaM Unit := do
|
||||
if thmKind == .eqLhs then
|
||||
addGrindEqAttr declName attrKind thmKind (useLhs := true)
|
||||
|
|
@ -713,10 +721,26 @@ private def addGrindAttr (declName : Name) (attrKind : AttributeKind) (thmKind :
|
|||
else if !(← getConstInfo declName).isTheorem then
|
||||
addGrindEqAttr declName attrKind thmKind
|
||||
else
|
||||
let some thm ← mkEMatchTheoremWithKind? (.decl declName) #[] (← getProofFor declName) thmKind
|
||||
| throwError "`@{thmKind.toAttribute} theorem {declName}` {thmKind.explainFailure}, consider using different options or the `grind_pattern` command"
|
||||
let thm ← mkEMatchTheoremForDecl declName thmKind
|
||||
ematchTheoremsExt.add thm attrKind
|
||||
|
||||
def EMatchTheorems.eraseDecl (s : EMatchTheorems) (declName : Name) : MetaM EMatchTheorems := do
|
||||
let throwErr {α} : MetaM α :=
|
||||
throwError "`{declName}` is not marked with the `[grind]` attribute"
|
||||
let info ← getConstInfo declName
|
||||
if !info.isTheorem then
|
||||
if let some eqns ← getEqnsFor? declName then
|
||||
let s := ematchTheoremsExt.getState (← getEnv)
|
||||
unless eqns.all fun eqn => s.contains (.decl eqn) do
|
||||
throwErr
|
||||
return eqns.foldl (init := s) fun s eqn => s.erase (.decl eqn)
|
||||
else
|
||||
throwErr
|
||||
else
|
||||
unless ematchTheoremsExt.getState (← getEnv) |>.contains (.decl declName) do
|
||||
throwErr
|
||||
return s.erase <| .decl declName
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinAttribute {
|
||||
name := `grind
|
||||
|
|
@ -739,7 +763,7 @@ builtin_initialize
|
|||
`grind` will add an instance of this theorem to the local context whenever it encounters the pattern `foo (foo x)`."
|
||||
applicationTime := .afterCompilation
|
||||
add := fun declName stx attrKind => do
|
||||
addGrindAttr declName attrKind (getKind stx) |>.run' {}
|
||||
addGrindAttr declName attrKind (← getTheoremKindFromOpt stx) |>.run' {}
|
||||
erase := fun declName => MetaM.run' do
|
||||
/-
|
||||
Remark: consider the following example
|
||||
|
|
@ -755,21 +779,9 @@ builtin_initialize
|
|||
attribute [-grind] foo -- ok
|
||||
```
|
||||
-/
|
||||
let throwErr := throwError "`{declName}` is not marked with the `[grind]` attribute"
|
||||
let info ← getConstInfo declName
|
||||
if !info.isTheorem then
|
||||
if let some eqns ← getEqnsFor? declName then
|
||||
let s := ematchTheoremsExt.getState (← getEnv)
|
||||
unless eqns.all fun eqn => s.contains (.decl eqn) do
|
||||
throwErr
|
||||
modifyEnv fun env => ematchTheoremsExt.modifyState env fun s =>
|
||||
eqns.foldl (init := s) fun s eqn => s.erase (.decl eqn)
|
||||
else
|
||||
throwErr
|
||||
else
|
||||
unless ematchTheoremsExt.getState (← getEnv) |>.contains (.decl declName) do
|
||||
throwErr
|
||||
modifyEnv fun env => ematchTheoremsExt.modifyState env fun s => s.erase (.decl declName)
|
||||
let s := ematchTheoremsExt.getState (← getEnv)
|
||||
let s ← s.eraseDecl declName
|
||||
modifyEnv fun env => ematchTheoremsExt.modifyState env fun _ => s
|
||||
}
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
|
|||
|
|
@ -20,6 +20,19 @@ import Lean.Meta.Tactic.Grind.SimpUtil
|
|||
|
||||
namespace Lean.Meta.Grind
|
||||
|
||||
structure Params where
|
||||
config : Grind.Config
|
||||
ematch : EMatchTheorems := {}
|
||||
extra : PArray EMatchTheorem := {}
|
||||
norm : Simp.Context
|
||||
normProcs : Array Simprocs
|
||||
-- TODO: inductives to split
|
||||
|
||||
def mkParams (config : Grind.Config) : MetaM Params := do
|
||||
let norm ← Grind.getSimpContext
|
||||
let normProcs ← Grind.getSimprocs
|
||||
return { config, norm, normProcs }
|
||||
|
||||
def mkMethods (fallback : Fallback) : CoreM Methods := do
|
||||
let builtinPropagators ← builtinPropagatorsRef.get
|
||||
return {
|
||||
|
|
@ -37,26 +50,29 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
|
|||
prop e
|
||||
}
|
||||
|
||||
def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) (fallback : Fallback) : MetaM α := do
|
||||
def GrindM.run (x : GrindM α) (mainDeclName : Name) (params : Params) (fallback : Fallback) : MetaM α := do
|
||||
let scState := ShareCommon.State.mk _
|
||||
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
|
||||
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
|
||||
let (natZExpr, scState) := ShareCommon.State.shareCommon scState (mkNatLit 0)
|
||||
let simprocs ← Grind.getSimprocs
|
||||
let simp ← Grind.getSimpContext
|
||||
let simprocs := params.normProcs
|
||||
let simp := params.norm
|
||||
let config := params.config
|
||||
x (← mkMethods fallback).toMethodsRef { mainDeclName, config, simprocs, simp } |>.run' { scState, trueExpr, falseExpr, natZExpr }
|
||||
|
||||
private def mkGoal (mvarId : MVarId) : GrindM Goal := do
|
||||
private def mkGoal (mvarId : MVarId) (params : Params) : GrindM Goal := do
|
||||
let trueExpr ← getTrueExpr
|
||||
let falseExpr ← getFalseExpr
|
||||
let natZeroExpr ← getNatZeroExpr
|
||||
let thmMap ← getEMatchTheorems
|
||||
let thmMap := params.ematch
|
||||
GoalM.run' { mvarId, thmMap } do
|
||||
mkENodeCore falseExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
mkENodeCore trueExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
mkENodeCore natZeroExpr (interpreted := true) (ctor := false) (generation := 0)
|
||||
for thm in params.extra do
|
||||
activateTheorem thm 0
|
||||
|
||||
private def initCore (mvarId : MVarId) : GrindM (List Goal) := do
|
||||
private def initCore (mvarId : MVarId) (params : Params) : GrindM (List Goal) := do
|
||||
mvarId.ensureProp
|
||||
-- TODO: abstract metavars
|
||||
mvarId.ensureNoMVar
|
||||
|
|
@ -65,13 +81,13 @@ private def initCore (mvarId : MVarId) : GrindM (List Goal) := do
|
|||
let mvarId ← mvarId.unfoldReducible
|
||||
let mvarId ← mvarId.betaReduce
|
||||
appendTagSuffix mvarId `grind
|
||||
let goals ← intros (← mkGoal mvarId) (generation := 0)
|
||||
let goals ← intros (← mkGoal mvarId params) (generation := 0)
|
||||
goals.forM (·.checkInvariants (expensive := true))
|
||||
return goals.filter fun goal => !goal.inconsistent
|
||||
|
||||
def main (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) (fallback : Fallback) : MetaM (List Goal) := do
|
||||
def main (mvarId : MVarId) (params : Params) (mainDeclName : Name) (fallback : Fallback) : MetaM (List Goal) := do
|
||||
let go : GrindM (List Goal) := do
|
||||
let goals ← initCore mvarId
|
||||
let goals ← initCore mvarId params
|
||||
let goals ← solve goals
|
||||
let goals ← goals.filterMapM fun goal => do
|
||||
if goal.inconsistent then return none
|
||||
|
|
@ -81,6 +97,6 @@ def main (mvarId : MVarId) (config : Grind.Config) (mainDeclName : Name) (fallba
|
|||
return some goal
|
||||
trace[grind.debug.final] "{← ppGoals goals}"
|
||||
return goals
|
||||
go.run mainDeclName config fallback
|
||||
go.run mainDeclName params fallback
|
||||
|
||||
end Lean.Meta.Grind
|
||||
|
|
|
|||
81
tests/lean/run/grind_params.lean
Normal file
81
tests/lean/run/grind_params.lean
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
def foo (x : Nat) := x + 2
|
||||
|
||||
example (f : Nat → Nat) : f (foo a) = b → f (c + 1) = d → c = a + 1 → b = d := by
|
||||
grind [foo]
|
||||
|
||||
opaque bla : Nat → Nat
|
||||
theorem blathm : bla (bla x) = bla x := sorry
|
||||
|
||||
example : bla (foo a) = b → bla b = bla (a + 2) := by
|
||||
grind [foo, blathm]
|
||||
|
||||
example : bla (foo a) = b → bla b = bla (a + 2) := by
|
||||
grind [foo, = blathm]
|
||||
|
||||
/--
|
||||
error: invalid `grind` forward theorem, theorem `blathm` does not have propositional hypotheses
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
example : bla (foo a) = b → bla b = bla (a + 2) := by
|
||||
grind [foo, → blathm]
|
||||
|
||||
opaque P : Nat → Prop
|
||||
opaque Q : Nat → Prop
|
||||
opaque R : Nat → Prop
|
||||
|
||||
theorem pq : P x → Q x := sorry
|
||||
theorem qr : Q x → R x := sorry
|
||||
|
||||
example : P x → R x := by
|
||||
grind [→ pq, → qr]
|
||||
|
||||
/--
|
||||
error: `grind` failed
|
||||
case grind
|
||||
x : Nat
|
||||
a✝ : P x
|
||||
x✝ : ¬R x
|
||||
⊢ False
|
||||
[grind] Diagnostics
|
||||
[facts] Asserted facts
|
||||
[prop] P x
|
||||
[prop] ¬R x
|
||||
[eqc] True propositions
|
||||
[prop] P x
|
||||
[eqc] False propositions
|
||||
[prop] R x
|
||||
[ematch] E-matching
|
||||
[thm] pq:
|
||||
∀ {x : Nat}, P x → Q x
|
||||
patterns: [Q #1]
|
||||
[thm] qr: ∀ {x : Nat}, Q x → R x patterns: [Q #1]
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
example : P x → R x := by
|
||||
grind [← pq, → qr]
|
||||
|
||||
example : P x → R x := by
|
||||
grind [← pq, ← qr]
|
||||
|
||||
attribute [grind] blathm
|
||||
|
||||
example : bla (bla (bla (bla x))) = bla x := by
|
||||
grind
|
||||
|
||||
example : bla (bla (bla (bla x))) = bla x := by
|
||||
fail_if_success grind [-blathm]
|
||||
sorry
|
||||
|
||||
example : bla (bla (bla (bla x))) = bla x := by
|
||||
grind only [blathm]
|
||||
|
||||
example : bla (bla (bla (bla x))) = bla x := by
|
||||
fail_if_success grind only
|
||||
sorry
|
||||
|
||||
/--
|
||||
error: `pq` is not marked with the `[grind]` attribute
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
example : P x → R x := by
|
||||
grind [-pq]
|
||||
Loading…
Add table
Reference in a new issue