feat: grind? infrastructure (#6785)
This PR adds infrastructure for the `grind?` tactic. It also adds the new modifier `usr` which allows users to write `grind only [usr thmName]` to instruct `grind` to only use theorem `thmName`, but using the patterns specified with the command `grind_pattern`.
This commit is contained in:
parent
98bd162ad4
commit
69a73a18fb
5 changed files with 186 additions and 52 deletions
|
|
@ -14,10 +14,11 @@ syntax grindEqRhs := atomic("=" "_")
|
|||
syntax grindEqBwd := atomic("←" "=")
|
||||
syntax grindBwd := "←"
|
||||
syntax grindFwd := "→"
|
||||
syntax grindUsr := &"usr"
|
||||
syntax grindCases := &"cases"
|
||||
syntax grindCasesEager := atomic(&"cases" &"eager")
|
||||
|
||||
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindCasesEager <|> grindCases
|
||||
syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindUsr <|> grindCasesEager <|> grindCases
|
||||
|
||||
syntax (name := grind) "grind" (grindMod)? : attr
|
||||
|
||||
|
|
@ -75,4 +76,10 @@ syntax (name := grind)
|
|||
(" [" withoutPosition(grindParam,*) "]")?
|
||||
("on_failure " term)? : tactic
|
||||
|
||||
|
||||
syntax (name := grindTrace)
|
||||
"grind?" optConfig (&" only")?
|
||||
(" [" withoutPosition(grindParam,*) "]")?
|
||||
("on_failure " term)? : tactic
|
||||
|
||||
end Lean.Parser.Tactic
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def elabGrindPattern : CommandElab := fun stx => do
|
|||
let pattern ← instantiateMVars pattern
|
||||
let pattern ← Grind.preprocessPattern pattern
|
||||
return pattern.abstract xs
|
||||
Grind.addEMatchTheorem declName xs.size patterns.toList
|
||||
Grind.addEMatchTheorem declName xs.size patterns.toList .user
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
open Command Term in
|
||||
|
|
@ -45,7 +45,7 @@ def elabInitGrindNorm : CommandElab := fun stx =>
|
|||
Grind.registerNormTheorems pre post
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.grindParam) : MetaM Grind.Params := do
|
||||
def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.grindParam) (only : Bool) : MetaM Grind.Params := do
|
||||
let mut params := params
|
||||
for p in ps do
|
||||
match p with
|
||||
|
|
@ -59,6 +59,16 @@ def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.
|
|||
let declName ← realizeGlobalConstNoOverloadWithInfo id
|
||||
let kind ← if let some mod := mod? then Grind.getAttrKindCore mod else pure .infer
|
||||
match kind with
|
||||
| .ematch .user =>
|
||||
unless only do
|
||||
withRef p <| Grind.throwInvalidUsrModifier
|
||||
let s ← Grind.getEMatchTheorems
|
||||
let thms := s.find (.decl declName)
|
||||
let thms := thms.filter fun thm => thm.kind == .user
|
||||
if thms.isEmpty then
|
||||
throwErrorAt p "invalid use of `usr` modifier, `{declName}` does not have patterns specified with the command `grind_pattern`"
|
||||
for thm in thms do
|
||||
params := { params with extra := params.extra.push thm }
|
||||
| .ematch kind =>
|
||||
params ← withRef p <| addEMatchTheorem params declName kind
|
||||
| .cases eager =>
|
||||
|
|
@ -97,7 +107,7 @@ def mkGrindParams (config : Grind.Config) (only : Bool) (ps : TSyntaxArray ``Pa
|
|||
let ematch ← if only then pure {} else Grind.getEMatchTheorems
|
||||
let casesTypes ← if only then pure {} else Grind.getCasesTypes
|
||||
let params := { params with ematch, casesTypes }
|
||||
elabGrindParams params ps
|
||||
elabGrindParams params ps only
|
||||
|
||||
def grind
|
||||
(mvarId : MVarId) (config : Grind.Config)
|
||||
|
|
@ -126,16 +136,32 @@ private def elabFallback (fallback? : Option Term) : TermElabM (Grind.GoalM Unit
|
|||
pure auxDeclName
|
||||
unsafe evalConst (Grind.GoalM Unit) auxDeclName
|
||||
|
||||
private def evalGrindCore
|
||||
(ref : Syntax)
|
||||
(config : TSyntax `Lean.Parser.Tactic.optConfig)
|
||||
(only : Option Syntax)
|
||||
(params : Option (Syntax.TSepArray `Lean.Parser.Tactic.grindParam ","))
|
||||
(fallback? : Option Term)
|
||||
(_trace : Bool) -- TODO
|
||||
: TacticM Unit := do
|
||||
let fallback ← elabFallback fallback?
|
||||
let only := only.isSome
|
||||
let params := if let some params := params then params.getElems else #[]
|
||||
logWarningAt ref "The `grind` tactic is experimental and still under development. Avoid using it in production projects"
|
||||
let declName := (← Term.getDeclName?).getD `_grind
|
||||
let config ← elabGrindConfig config
|
||||
withMainContext do liftMetaFinishingTactic (grind · config only params declName fallback)
|
||||
|
||||
@[builtin_tactic Lean.Parser.Tactic.grind] def evalGrind : Tactic := fun stx => do
|
||||
match stx with
|
||||
| `(tactic| grind $config:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) =>
|
||||
let fallback ← elabFallback fallback?
|
||||
let only := only.isSome
|
||||
let params := if let some params := params then params.getElems else #[]
|
||||
logWarningAt stx "The `grind` tactic is experimental and still under development. Avoid using it in production projects"
|
||||
let declName := (← Term.getDeclName?).getD `_grind
|
||||
let config ← elabGrindConfig config
|
||||
withMainContext do liftMetaFinishingTactic (grind · config only params declName fallback)
|
||||
evalGrindCore stx config only params fallback? false
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
@[builtin_tactic Lean.Parser.Tactic.grindTrace] def evalGrindTrace : Tactic := fun stx => do
|
||||
match stx with
|
||||
| `(tactic| grind? $config:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) =>
|
||||
evalGrindCore stx config only params fallback? true
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
end Lean.Elab.Tactic
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do
|
|||
| `(Parser.Attr.grindMod| =_) => return .ematch .eqRhs
|
||||
| `(Parser.Attr.grindMod| _=_) => return .ematch .eqBoth
|
||||
| `(Parser.Attr.grindMod| ←=) => return .ematch .eqBwd
|
||||
| `(Parser.Attr.grindMod| usr) => return .ematch .user
|
||||
| `(Parser.Attr.grindMod| cases) => return .cases false
|
||||
| `(Parser.Attr.grindMod| cases eager) => return .cases true
|
||||
| _ => throwError "unexpected `grind` theorem kind: `{stx}`"
|
||||
|
|
@ -34,6 +35,9 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do
|
|||
else
|
||||
getAttrKindCore stx[1][0]
|
||||
|
||||
def throwInvalidUsrModifier : CoreM α :=
|
||||
throwError "the modifier `usr` is only relevant in parameters for `grind only`"
|
||||
|
||||
builtin_initialize
|
||||
registerBuiltinAttribute {
|
||||
name := `grind
|
||||
|
|
@ -57,6 +61,7 @@ builtin_initialize
|
|||
applicationTime := .afterCompilation
|
||||
add := fun declName stx attrKind => MetaM.run' do
|
||||
match (← getAttrKindFromOpt stx) with
|
||||
| .ematch .user => throwInvalidUsrModifier
|
||||
| .ematch k => addEMatchAttr declName attrKind k
|
||||
| .cases eager => addCasesAttr declName eager attrKind
|
||||
| .infer =>
|
||||
|
|
|
|||
|
|
@ -92,6 +92,30 @@ instance : BEq Origin where
|
|||
instance : Hashable Origin where
|
||||
hash a := hash a.key
|
||||
|
||||
inductive TheoremKind where
|
||||
| eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default | user /- pattern specified using `grind_pattern` command -/
|
||||
deriving Inhabited, BEq, Repr
|
||||
|
||||
private def TheoremKind.toAttribute : TheoremKind → String
|
||||
| .eqLhs => "[grind =]"
|
||||
| .eqRhs => "[grind =_]"
|
||||
| .eqBoth => "[grind _=_]"
|
||||
| .eqBwd => "[grind ←=]"
|
||||
| .fwd => "[grind →]"
|
||||
| .bwd => "[grind ←]"
|
||||
| .default => "[grind]"
|
||||
| .user => "[grind]"
|
||||
|
||||
private def TheoremKind.explainFailure : TheoremKind → String
|
||||
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
|
||||
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
|
||||
| .eqBoth => unreachable! -- eqBoth is a macro
|
||||
| .eqBwd => "failed to use theorem's conclusion as a pattern"
|
||||
| .fwd => "failed to find patterns in the antecedents of the theorem"
|
||||
| .bwd => "failed to find patterns in the theorem's conclusion"
|
||||
| .default => "failed to find patterns"
|
||||
| .user => unreachable!
|
||||
|
||||
/-- A theorem for heuristic instantiation based on E-matching. -/
|
||||
structure EMatchTheorem where
|
||||
/--
|
||||
|
|
@ -106,16 +130,20 @@ structure EMatchTheorem where
|
|||
/-- Contains all symbols used in `pattterns`. -/
|
||||
symbols : List HeadIndex
|
||||
origin : Origin
|
||||
/-- The `kind` is used for generating the `patterns`. We save it here to implement `grind?`. -/
|
||||
kind : TheoremKind
|
||||
deriving Inhabited
|
||||
|
||||
/-- Set of E-matching theorems. -/
|
||||
structure EMatchTheorems where
|
||||
/-- The key is a symbol from `EMatchTheorem.symbols`. -/
|
||||
private map : PHashMap Name (List EMatchTheorem) := {}
|
||||
private smap : PHashMap Name (List EMatchTheorem) := {}
|
||||
/-- Set of theorem ids that have been inserted using `insert`. -/
|
||||
private origins : PHashSet Origin := {}
|
||||
/-- Theorems that have been marked as erased -/
|
||||
private erased : PHashSet Origin := {}
|
||||
/-- Mapping from origin to E-matching theorems associated with this origin. -/
|
||||
private omap : PHashMap Origin (List EMatchTheorem) := {}
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
|
|
@ -130,13 +158,19 @@ def EMatchTheorems.insert (s : EMatchTheorems) (thm : EMatchTheorem) : EMatchThe
|
|||
let .const declName :: syms := thm.symbols
|
||||
| unreachable!
|
||||
let thm := { thm with symbols := syms }
|
||||
let { map, origins, erased } := s
|
||||
let origins := origins.insert thm.origin
|
||||
let erased := erased.erase thm.origin
|
||||
if let some thms := map.find? declName then
|
||||
return { map := map.insert declName (thm::thms), origins, erased }
|
||||
let { smap, origins, erased, omap } := s
|
||||
let origin := thm.origin
|
||||
let origins := origins.insert origin
|
||||
let erased := erased.erase origin
|
||||
let smap := if let some thms := smap.find? declName then
|
||||
smap.insert declName (thm::thms)
|
||||
else
|
||||
return { map := map.insert declName [thm], origins, erased }
|
||||
smap.insert declName [thm]
|
||||
let omap := if let some thms := omap.find? origin then
|
||||
omap.insert origin (thm::thms)
|
||||
else
|
||||
omap.insert origin [thm]
|
||||
return { smap, origins, erased, omap }
|
||||
|
||||
/-- Returns `true` if `s` contains a theorem with the given origin. -/
|
||||
def EMatchTheorems.contains (s : EMatchTheorems) (origin : Origin) : Bool :=
|
||||
|
|
@ -156,11 +190,20 @@ The theorems are removed from `s`.
|
|||
-/
|
||||
@[inline]
|
||||
def EMatchTheorems.retrieve? (s : EMatchTheorems) (sym : Name) : Option (List EMatchTheorem × EMatchTheorems) :=
|
||||
if let some thms := s.map.find? sym then
|
||||
some (thms, { s with map := s.map.erase sym })
|
||||
if let some thms := s.smap.find? sym then
|
||||
some (thms, { s with smap := s.smap.erase sym })
|
||||
else
|
||||
none
|
||||
|
||||
/--
|
||||
Returns theorems associated with the given origin.
|
||||
-/
|
||||
def EMatchTheorems.find (s : EMatchTheorems) (origin : Origin) : List EMatchTheorem :=
|
||||
if let some thms := s.omap.find? origin then
|
||||
thms
|
||||
else
|
||||
[]
|
||||
|
||||
def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do
|
||||
if thm.proof.isConst && thm.levelParams.isEmpty then
|
||||
let declName := thm.proof.constName!
|
||||
|
|
@ -491,7 +534,7 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) :
|
|||
Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns.
|
||||
Pattern variables are represented using de Bruijn indices.
|
||||
-/
|
||||
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) : MetaM EMatchTheorem := do
|
||||
def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) (kind : TheoremKind): MetaM EMatchTheorem := do
|
||||
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
|
||||
if symbols.isEmpty then
|
||||
throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing"
|
||||
|
|
@ -501,7 +544,7 @@ def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams
|
|||
throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
|
||||
return {
|
||||
proof, patterns, numParams, symbols
|
||||
levelParams, origin
|
||||
levelParams, origin, kind
|
||||
}
|
||||
|
||||
private def getProofFor (declName : Name) : CoreM Expr := do
|
||||
|
|
@ -514,8 +557,8 @@ private def getProofFor (declName : Name) : CoreM Expr := do
|
|||
Creates an E-matching theorem for `declName` with `numParams` parameters, and the given set of patterns.
|
||||
Pattern variables are represented using de Bruijn indices.
|
||||
-/
|
||||
def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM EMatchTheorem := do
|
||||
mkEMatchTheoremCore (.decl declName) #[] numParams (← getProofFor declName) patterns
|
||||
def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) (kind : TheoremKind) : MetaM EMatchTheorem := do
|
||||
mkEMatchTheoremCore (.decl declName) #[] numParams (← getProofFor declName) patterns kind
|
||||
|
||||
/--
|
||||
Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
|
||||
|
|
@ -535,15 +578,15 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof :
|
|||
trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: after preprocessing: {pat}, {← normalize pat}"
|
||||
let pats := splitWhileForbidden (pat.abstract xs)
|
||||
return (xs.size, pats)
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs)
|
||||
|
||||
def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do
|
||||
let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do
|
||||
let_expr f@Eq α lhs rhs := type
|
||||
| throwError "invalid E-matching `≠` theorem, conclusion must be an equality{indentExpr type}"
|
||||
| throwError "invalid E-matching `←=` theorem, conclusion must be an equality{indentExpr type}"
|
||||
let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs)
|
||||
return (xs.size, [pat.abstract xs])
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns
|
||||
mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd
|
||||
|
||||
/--
|
||||
Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`,
|
||||
|
|
@ -559,8 +602,8 @@ def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Boo
|
|||
Adds an E-matching theorem to the environment.
|
||||
See `mkEMatchTheorem`.
|
||||
-/
|
||||
def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
|
||||
ematchTheoremsExt.add (← mkEMatchTheorem declName numParams patterns)
|
||||
def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) (kind : TheoremKind) : MetaM Unit := do
|
||||
ematchTheoremsExt.add (← mkEMatchTheorem declName numParams patterns kind)
|
||||
|
||||
/--
|
||||
Adds an E-matching equality theorem to the environment.
|
||||
|
|
@ -573,28 +616,6 @@ def addEMatchEqTheorem (declName : Name) : MetaM Unit := do
|
|||
def getEMatchTheorems : CoreM EMatchTheorems :=
|
||||
return ematchTheoremsExt.getState (← getEnv)
|
||||
|
||||
inductive TheoremKind where
|
||||
| eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default
|
||||
deriving Inhabited, BEq, Repr
|
||||
|
||||
private def TheoremKind.toAttribute : TheoremKind → String
|
||||
| .eqLhs => "[grind =]"
|
||||
| .eqRhs => "[grind =_]"
|
||||
| .eqBoth => "[grind _=_]"
|
||||
| .eqBwd => "[grind ←=]"
|
||||
| .fwd => "[grind →]"
|
||||
| .bwd => "[grind ←]"
|
||||
| .default => "[grind]"
|
||||
|
||||
private def TheoremKind.explainFailure : TheoremKind → String
|
||||
| .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion"
|
||||
| .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion"
|
||||
| .eqBoth => unreachable! -- eqBoth is a macro
|
||||
| .eqBwd => "failed to use theorem's conclusion as a pattern"
|
||||
| .fwd => "failed to find patterns in the antecedents of the theorem"
|
||||
| .bwd => "failed to find patterns in the theorem's conclusion"
|
||||
| .default => "failed to find patterns"
|
||||
|
||||
/-- Returns the types of `xs` that are propositions. -/
|
||||
private def getPropTypes (xs : Array Expr) : MetaM (Array Expr) :=
|
||||
xs.filterMapM fun x => do
|
||||
|
|
@ -702,7 +723,7 @@ where
|
|||
trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}"
|
||||
return some {
|
||||
proof, patterns, numParams, symbols
|
||||
levelParams, origin
|
||||
levelParams, origin, kind
|
||||
}
|
||||
|
||||
def mkEMatchTheoremForDecl (declName : Name) (thmKind : TheoremKind) : MetaM EMatchTheorem := do
|
||||
|
|
|
|||
75
tests/lean/run/grind_usr.lean
Normal file
75
tests/lean/run/grind_usr.lean
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
opaque f : Nat → Nat
|
||||
|
||||
/--
|
||||
error: the modifier `usr` is only relevant in parameters for `grind only`
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
@[grind usr]
|
||||
theorem fthm : f (f x) = f x := sorry
|
||||
|
||||
/--
|
||||
info: [grind.ematch.pattern] fthm: [f (f #0)]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
set_option trace.grind.ematch.pattern true in
|
||||
example : f (f (f x)) = f x := by
|
||||
grind only [fthm]
|
||||
|
||||
/--
|
||||
info: [grind.ematch.instance] fthm: f (f (f x)) = f (f x)
|
||||
[grind.ematch.instance] fthm: f (f x) = f x
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example : f (f (f x)) = f x := by
|
||||
grind only [fthm]
|
||||
|
||||
#guard_msgs (info) in
|
||||
-- should not instantiate anything using pattern `f (f #0)`
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example : f x = x := by
|
||||
fail_if_success grind only [fthm]
|
||||
sorry
|
||||
|
||||
/--
|
||||
error: the modifier `usr` is only relevant in parameters for `grind only`
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
example : f (f (f x)) = f x := by
|
||||
grind [usr fthm]
|
||||
|
||||
/--
|
||||
error: invalid use of `usr` modifier, `fthm` does not have patterns specified with the command `grind_pattern`
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
example : f (f (f x)) = f x := by
|
||||
grind only [usr fthm]
|
||||
|
||||
grind_pattern fthm => f x
|
||||
|
||||
example : f (f (f x)) = f x := by
|
||||
grind only [usr fthm]
|
||||
|
||||
#guard_msgs (info) in
|
||||
-- should not instantiate anything using pattern `f (f #0)`
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example : f x = x := by
|
||||
fail_if_success grind only [fthm]
|
||||
sorry
|
||||
|
||||
/--
|
||||
info: [grind.ematch.instance] fthm: f (f x) = f x
|
||||
[grind.ematch.instance] fthm: f (f (f x)) = f (f x)
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
set_option trace.grind.ematch.instance true in
|
||||
example : f x = x := by
|
||||
fail_if_success grind only [usr fthm]
|
||||
sorry
|
||||
|
||||
/--
|
||||
error: the modifier `usr` is only relevant in parameters for `grind only`
|
||||
-/
|
||||
#guard_msgs (error) in
|
||||
example : f (f (f x)) = f x := by
|
||||
grind [usr fthm]
|
||||
Loading…
Add table
Reference in a new issue