diff --git a/src/Init/Grind/Tactics.lean b/src/Init/Grind/Tactics.lean index f53b5d81bc..bd339f4a98 100644 --- a/src/Init/Grind/Tactics.lean +++ b/src/Init/Grind/Tactics.lean @@ -14,10 +14,11 @@ syntax grindEqRhs := atomic("=" "_") syntax grindEqBwd := atomic("←" "=") syntax grindBwd := "←" syntax grindFwd := "→" +syntax grindUsr := &"usr" syntax grindCases := &"cases" syntax grindCasesEager := atomic(&"cases" &"eager") -syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindCasesEager <|> grindCases +syntax grindMod := grindEqBoth <|> grindEqRhs <|> grindEq <|> grindEqBwd <|> grindBwd <|> grindFwd <|> grindUsr <|> grindCasesEager <|> grindCases syntax (name := grind) "grind" (grindMod)? : attr @@ -75,4 +76,10 @@ syntax (name := grind) (" [" withoutPosition(grindParam,*) "]")? ("on_failure " term)? : tactic + +syntax (name := grindTrace) + "grind?" optConfig (&" only")? + (" [" withoutPosition(grindParam,*) "]")? + ("on_failure " term)? : tactic + end Lean.Parser.Tactic diff --git a/src/Lean/Elab/Tactic/Grind.lean b/src/Lean/Elab/Tactic/Grind.lean index db5483ac3f..fa1ce1d6de 100644 --- a/src/Lean/Elab/Tactic/Grind.lean +++ b/src/Lean/Elab/Tactic/Grind.lean @@ -31,7 +31,7 @@ def elabGrindPattern : CommandElab := fun stx => do let pattern ← instantiateMVars pattern let pattern ← Grind.preprocessPattern pattern return pattern.abstract xs - Grind.addEMatchTheorem declName xs.size patterns.toList + Grind.addEMatchTheorem declName xs.size patterns.toList .user | _ => throwUnsupportedSyntax open Command Term in @@ -45,7 +45,7 @@ def elabInitGrindNorm : CommandElab := fun stx => Grind.registerNormTheorems pre post | _ => throwUnsupportedSyntax -def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.grindParam) : MetaM Grind.Params := do +def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic.grindParam) (only : Bool) : MetaM Grind.Params := do let mut params := params for p in ps do match p with @@ -59,6 +59,16 @@ def elabGrindParams (params : Grind.Params) (ps : TSyntaxArray ``Parser.Tactic. let declName ← realizeGlobalConstNoOverloadWithInfo id let kind ← if let some mod := mod? then Grind.getAttrKindCore mod else pure .infer match kind with + | .ematch .user => + unless only do + withRef p <| Grind.throwInvalidUsrModifier + let s ← Grind.getEMatchTheorems + let thms := s.find (.decl declName) + let thms := thms.filter fun thm => thm.kind == .user + if thms.isEmpty then + throwErrorAt p "invalid use of `usr` modifier, `{declName}` does not have patterns specified with the command `grind_pattern`" + for thm in thms do + params := { params with extra := params.extra.push thm } | .ematch kind => params ← withRef p <| addEMatchTheorem params declName kind | .cases eager => @@ -97,7 +107,7 @@ def mkGrindParams (config : Grind.Config) (only : Bool) (ps : TSyntaxArray ``Pa let ematch ← if only then pure {} else Grind.getEMatchTheorems let casesTypes ← if only then pure {} else Grind.getCasesTypes let params := { params with ematch, casesTypes } - elabGrindParams params ps + elabGrindParams params ps only def grind (mvarId : MVarId) (config : Grind.Config) @@ -126,16 +136,32 @@ private def elabFallback (fallback? : Option Term) : TermElabM (Grind.GoalM Unit pure auxDeclName unsafe evalConst (Grind.GoalM Unit) auxDeclName +private def evalGrindCore + (ref : Syntax) + (config : TSyntax `Lean.Parser.Tactic.optConfig) + (only : Option Syntax) + (params : Option (Syntax.TSepArray `Lean.Parser.Tactic.grindParam ",")) + (fallback? : Option Term) + (_trace : Bool) -- TODO + : TacticM Unit := do + let fallback ← elabFallback fallback? + let only := only.isSome + let params := if let some params := params then params.getElems else #[] + logWarningAt ref "The `grind` tactic is experimental and still under development. Avoid using it in production projects" + let declName := (← Term.getDeclName?).getD `_grind + let config ← elabGrindConfig config + withMainContext do liftMetaFinishingTactic (grind · config only params declName fallback) + @[builtin_tactic Lean.Parser.Tactic.grind] def evalGrind : Tactic := fun stx => do match stx with | `(tactic| grind $config:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) => - let fallback ← elabFallback fallback? - let only := only.isSome - let params := if let some params := params then params.getElems else #[] - logWarningAt stx "The `grind` tactic is experimental and still under development. Avoid using it in production projects" - let declName := (← Term.getDeclName?).getD `_grind - let config ← elabGrindConfig config - withMainContext do liftMetaFinishingTactic (grind · config only params declName fallback) + evalGrindCore stx config only params fallback? false + | _ => throwUnsupportedSyntax + +@[builtin_tactic Lean.Parser.Tactic.grindTrace] def evalGrindTrace : Tactic := fun stx => do + match stx with + | `(tactic| grind? $config:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) => + evalGrindCore stx config only params fallback? true | _ => throwUnsupportedSyntax end Lean.Elab.Tactic diff --git a/src/Lean/Meta/Tactic/Grind/Attr.lean b/src/Lean/Meta/Tactic/Grind/Attr.lean index 1da60408df..ee9f06501b 100644 --- a/src/Lean/Meta/Tactic/Grind/Attr.lean +++ b/src/Lean/Meta/Tactic/Grind/Attr.lean @@ -23,6 +23,7 @@ def getAttrKindCore (stx : Syntax) : CoreM AttrKind := do | `(Parser.Attr.grindMod| =_) => return .ematch .eqRhs | `(Parser.Attr.grindMod| _=_) => return .ematch .eqBoth | `(Parser.Attr.grindMod| ←=) => return .ematch .eqBwd + | `(Parser.Attr.grindMod| usr) => return .ematch .user | `(Parser.Attr.grindMod| cases) => return .cases false | `(Parser.Attr.grindMod| cases eager) => return .cases true | _ => throwError "unexpected `grind` theorem kind: `{stx}`" @@ -34,6 +35,9 @@ def getAttrKindFromOpt (stx : Syntax) : CoreM AttrKind := do else getAttrKindCore stx[1][0] +def throwInvalidUsrModifier : CoreM α := + throwError "the modifier `usr` is only relevant in parameters for `grind only`" + builtin_initialize registerBuiltinAttribute { name := `grind @@ -57,6 +61,7 @@ builtin_initialize applicationTime := .afterCompilation add := fun declName stx attrKind => MetaM.run' do match (← getAttrKindFromOpt stx) with + | .ematch .user => throwInvalidUsrModifier | .ematch k => addEMatchAttr declName attrKind k | .cases eager => addCasesAttr declName eager attrKind | .infer => diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index 51611f66b6..de7334bd99 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -92,6 +92,30 @@ instance : BEq Origin where instance : Hashable Origin where hash a := hash a.key +inductive TheoremKind where + | eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default | user /- pattern specified using `grind_pattern` command -/ + deriving Inhabited, BEq, Repr + +private def TheoremKind.toAttribute : TheoremKind → String + | .eqLhs => "[grind =]" + | .eqRhs => "[grind =_]" + | .eqBoth => "[grind _=_]" + | .eqBwd => "[grind ←=]" + | .fwd => "[grind →]" + | .bwd => "[grind ←]" + | .default => "[grind]" + | .user => "[grind]" + +private def TheoremKind.explainFailure : TheoremKind → String + | .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion" + | .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion" + | .eqBoth => unreachable! -- eqBoth is a macro + | .eqBwd => "failed to use theorem's conclusion as a pattern" + | .fwd => "failed to find patterns in the antecedents of the theorem" + | .bwd => "failed to find patterns in the theorem's conclusion" + | .default => "failed to find patterns" + | .user => unreachable! + /-- A theorem for heuristic instantiation based on E-matching. -/ structure EMatchTheorem where /-- @@ -106,16 +130,20 @@ structure EMatchTheorem where /-- Contains all symbols used in `pattterns`. -/ symbols : List HeadIndex origin : Origin + /-- The `kind` is used for generating the `patterns`. We save it here to implement `grind?`. -/ + kind : TheoremKind deriving Inhabited /-- Set of E-matching theorems. -/ structure EMatchTheorems where /-- The key is a symbol from `EMatchTheorem.symbols`. -/ - private map : PHashMap Name (List EMatchTheorem) := {} + private smap : PHashMap Name (List EMatchTheorem) := {} /-- Set of theorem ids that have been inserted using `insert`. -/ private origins : PHashSet Origin := {} /-- Theorems that have been marked as erased -/ private erased : PHashSet Origin := {} + /-- Mapping from origin to E-matching theorems associated with this origin. -/ + private omap : PHashMap Origin (List EMatchTheorem) := {} deriving Inhabited /-- @@ -130,13 +158,19 @@ def EMatchTheorems.insert (s : EMatchTheorems) (thm : EMatchTheorem) : EMatchThe let .const declName :: syms := thm.symbols | unreachable! let thm := { thm with symbols := syms } - let { map, origins, erased } := s - let origins := origins.insert thm.origin - let erased := erased.erase thm.origin - if let some thms := map.find? declName then - return { map := map.insert declName (thm::thms), origins, erased } + let { smap, origins, erased, omap } := s + let origin := thm.origin + let origins := origins.insert origin + let erased := erased.erase origin + let smap := if let some thms := smap.find? declName then + smap.insert declName (thm::thms) else - return { map := map.insert declName [thm], origins, erased } + smap.insert declName [thm] + let omap := if let some thms := omap.find? origin then + omap.insert origin (thm::thms) + else + omap.insert origin [thm] + return { smap, origins, erased, omap } /-- Returns `true` if `s` contains a theorem with the given origin. -/ def EMatchTheorems.contains (s : EMatchTheorems) (origin : Origin) : Bool := @@ -156,11 +190,20 @@ The theorems are removed from `s`. -/ @[inline] def EMatchTheorems.retrieve? (s : EMatchTheorems) (sym : Name) : Option (List EMatchTheorem × EMatchTheorems) := - if let some thms := s.map.find? sym then - some (thms, { s with map := s.map.erase sym }) + if let some thms := s.smap.find? sym then + some (thms, { s with smap := s.smap.erase sym }) else none +/-- +Returns theorems associated with the given origin. +-/ +def EMatchTheorems.find (s : EMatchTheorems) (origin : Origin) : List EMatchTheorem := + if let some thms := s.omap.find? origin then + thms + else + [] + def EMatchTheorem.getProofWithFreshMVarLevels (thm : EMatchTheorem) : MetaM Expr := do if thm.proof.isConst && thm.levelParams.isEmpty then let declName := thm.proof.constName! @@ -491,7 +534,7 @@ private def ppParamsAt (proof : Expr) (numParams : Nat) (paramPos : List Nat) : Creates an E-matching theorem for a theorem with proof `proof`, `numParams` parameters, and the given set of patterns. Pattern variables are represented using de Bruijn indices. -/ -def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) : MetaM EMatchTheorem := do +def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams : Nat) (proof : Expr) (patterns : List Expr) (kind : TheoremKind): MetaM EMatchTheorem := do let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns if symbols.isEmpty then throwError "invalid pattern for `{← origin.pp}`{indentD (patterns.map ppPattern)}\nthe pattern does not contain constant symbols for indexing" @@ -501,7 +544,7 @@ def mkEMatchTheoremCore (origin : Origin) (levelParams : Array Name) (numParams throwError "invalid pattern(s) for `{← origin.pp}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}" return { proof, patterns, numParams, symbols - levelParams, origin + levelParams, origin, kind } private def getProofFor (declName : Name) : CoreM Expr := do @@ -514,8 +557,8 @@ private def getProofFor (declName : Name) : CoreM Expr := do Creates an E-matching theorem for `declName` with `numParams` parameters, and the given set of patterns. Pattern variables are represented using de Bruijn indices. -/ -def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM EMatchTheorem := do - mkEMatchTheoremCore (.decl declName) #[] numParams (← getProofFor declName) patterns +def mkEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) (kind : TheoremKind) : MetaM EMatchTheorem := do + mkEMatchTheoremCore (.decl declName) #[] numParams (← getProofFor declName) patterns kind /-- Given a theorem with proof `proof` and type of the form `∀ (a_1 ... a_n), lhs = rhs`, @@ -535,15 +578,15 @@ def mkEMatchEqTheoremCore (origin : Origin) (levelParams : Array Name) (proof : trace[grind.debug.ematch.pattern] "mkEMatchEqTheoremCore: after preprocessing: {pat}, {← normalize pat}" let pats := splitWhileForbidden (pat.abstract xs) return (xs.size, pats) - mkEMatchTheoremCore origin levelParams numParams proof patterns + mkEMatchTheoremCore origin levelParams numParams proof patterns (if useLhs then .eqLhs else .eqRhs) def mkEMatchEqBwdTheoremCore (origin : Origin) (levelParams : Array Name) (proof : Expr) : MetaM EMatchTheorem := do let (numParams, patterns) ← forallTelescopeReducing (← inferType proof) fun xs type => do let_expr f@Eq α lhs rhs := type - | throwError "invalid E-matching `≠` theorem, conclusion must be an equality{indentExpr type}" + | throwError "invalid E-matching `←=` theorem, conclusion must be an equality{indentExpr type}" let pat ← preprocessPattern (mkEqBwdPattern f.constLevels! α lhs rhs) return (xs.size, [pat.abstract xs]) - mkEMatchTheoremCore origin levelParams numParams proof patterns + mkEMatchTheoremCore origin levelParams numParams proof patterns .eqBwd /-- Given theorem with name `declName` and type of the form `∀ (a_1 ... a_n), lhs = rhs`, @@ -559,8 +602,8 @@ def mkEMatchEqTheorem (declName : Name) (normalizePattern := true) (useLhs : Boo Adds an E-matching theorem to the environment. See `mkEMatchTheorem`. -/ -def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do - ematchTheoremsExt.add (← mkEMatchTheorem declName numParams patterns) +def addEMatchTheorem (declName : Name) (numParams : Nat) (patterns : List Expr) (kind : TheoremKind) : MetaM Unit := do + ematchTheoremsExt.add (← mkEMatchTheorem declName numParams patterns kind) /-- Adds an E-matching equality theorem to the environment. @@ -573,28 +616,6 @@ def addEMatchEqTheorem (declName : Name) : MetaM Unit := do def getEMatchTheorems : CoreM EMatchTheorems := return ematchTheoremsExt.getState (← getEnv) -inductive TheoremKind where - | eqLhs | eqRhs | eqBoth | eqBwd | fwd | bwd | default - deriving Inhabited, BEq, Repr - -private def TheoremKind.toAttribute : TheoremKind → String - | .eqLhs => "[grind =]" - | .eqRhs => "[grind =_]" - | .eqBoth => "[grind _=_]" - | .eqBwd => "[grind ←=]" - | .fwd => "[grind →]" - | .bwd => "[grind ←]" - | .default => "[grind]" - -private def TheoremKind.explainFailure : TheoremKind → String - | .eqLhs => "failed to find pattern in the left-hand side of the theorem's conclusion" - | .eqRhs => "failed to find pattern in the right-hand side of the theorem's conclusion" - | .eqBoth => unreachable! -- eqBoth is a macro - | .eqBwd => "failed to use theorem's conclusion as a pattern" - | .fwd => "failed to find patterns in the antecedents of the theorem" - | .bwd => "failed to find patterns in the theorem's conclusion" - | .default => "failed to find patterns" - /-- Returns the types of `xs` that are propositions. -/ private def getPropTypes (xs : Array Expr) : MetaM (Array Expr) := xs.filterMapM fun x => do @@ -702,7 +723,7 @@ where trace[grind.ematch.pattern] "{← origin.pp}: {patterns.map ppPattern}" return some { proof, patterns, numParams, symbols - levelParams, origin + levelParams, origin, kind } def mkEMatchTheoremForDecl (declName : Name) (thmKind : TheoremKind) : MetaM EMatchTheorem := do diff --git a/tests/lean/run/grind_usr.lean b/tests/lean/run/grind_usr.lean new file mode 100644 index 0000000000..4586c74002 --- /dev/null +++ b/tests/lean/run/grind_usr.lean @@ -0,0 +1,75 @@ +opaque f : Nat → Nat + +/-- +error: the modifier `usr` is only relevant in parameters for `grind only` +-/ +#guard_msgs (error) in +@[grind usr] +theorem fthm : f (f x) = f x := sorry + +/-- +info: [grind.ematch.pattern] fthm: [f (f #0)] +-/ +#guard_msgs (info) in +set_option trace.grind.ematch.pattern true in +example : f (f (f x)) = f x := by + grind only [fthm] + +/-- +info: [grind.ematch.instance] fthm: f (f (f x)) = f (f x) +[grind.ematch.instance] fthm: f (f x) = f x +-/ +#guard_msgs (info) in +set_option trace.grind.ematch.instance true in +example : f (f (f x)) = f x := by + grind only [fthm] + +#guard_msgs (info) in +-- should not instantiate anything using pattern `f (f #0)` +set_option trace.grind.ematch.instance true in +example : f x = x := by + fail_if_success grind only [fthm] + sorry + +/-- +error: the modifier `usr` is only relevant in parameters for `grind only` +-/ +#guard_msgs (error) in +example : f (f (f x)) = f x := by + grind [usr fthm] + +/-- +error: invalid use of `usr` modifier, `fthm` does not have patterns specified with the command `grind_pattern` +-/ +#guard_msgs (error) in +example : f (f (f x)) = f x := by + grind only [usr fthm] + +grind_pattern fthm => f x + +example : f (f (f x)) = f x := by + grind only [usr fthm] + +#guard_msgs (info) in +-- should not instantiate anything using pattern `f (f #0)` +set_option trace.grind.ematch.instance true in +example : f x = x := by + fail_if_success grind only [fthm] + sorry + +/-- +info: [grind.ematch.instance] fthm: f (f x) = f x +[grind.ematch.instance] fthm: f (f (f x)) = f (f x) +-/ +#guard_msgs (info) in +set_option trace.grind.ematch.instance true in +example : f x = x := by + fail_if_success grind only [usr fthm] + sorry + +/-- +error: the modifier `usr` is only relevant in parameters for `grind only` +-/ +#guard_msgs (error) in +example : f (f (f x)) = f x := by + grind [usr fthm]