lean4-htt/src/Lean/Elab/Tactic/Grind/Basic.lean
Leonardo de Moura cffacf1b10
feat: support local hypotheses in simp [h] for sym => mode (#13042)
This PR extends the `simp` tactic in `sym =>` mode to support local
hypotheses in the extra theorem list.

`simp myVariant [h]` now resolves `h` against the local context first,
falling back to global constants. Local hypotheses are converted to
rewrite rules via `mkTheoremFromExpr`, which applies the `eq_true`/
`eq_false`/`propext` adapter from #13041.

- Add `ExtraTheorem` inductive (`.const` / `.fvar`) for cache keying
- Add `resolveExtraTheorems` that checks the local context before
globals
- Update `addExtraTheorems`, `mkDefaultMethods`, `elabVariant`
signatures

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-22 21:50:31 +00:00

435 lines
16 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Elab.Tactic.Basic
public import Lean.Meta.Tactic.Grind.Main
import Lean.Meta.Tactic.Grind.Intro
public import Lean.Meta.Sym.Apply
public import Lean.Meta.Sym.Util
public import Lean.Meta.Sym.Simp.SimpM
import Init.Omega
public section
namespace Lean.Elab.Tactic.Grind
open Meta
/--
Reader context for `GrindTacticM`. Extends `Tactic.Context` with the `grind`/`Sym` contexts,
the `grind` methods, and the `grind` parameters.
-/
structure Context extends Tactic.Context where
  /-- The `grind` context. -/
  ctx : Meta.Grind.Context
  /-- The `Sym` context. -/
  sctx : Meta.Sym.Context
  /-- The `grind` methods; `liftGrindM` passes them to `GrindM` actions as a `MethodsRef`. -/
  methods : Grind.Methods
  /-- The `grind` parameters. -/
  params : Grind.Params
  /-- Copy of the `sym` flag given to `GrindTacticM.runAtGoal` (`true` selects sym mode there). -/
  sym : Bool := false
open Meta.Grind (Goal)
/-- An extra theorem passed to `simp` in `sym =>` mode. -/
inductive ExtraTheorem where
  /-- A theorem identified by a declaration name. -/
  | const (declName : Name)
  /-- A theorem identified by a free variable id. -/
  | fvar (fvarId : FVarId)
  deriving BEq, Hashable
/-- Cache key for `Sym.simp` variant invocations. -/
structure SimpCacheKey where
  /-- Name of the `simp` variant. -/
  variant : Name
  /-- Extra theorems supplied to the invocation (see `ExtraTheorem`). -/
  extras : Array ExtraTheorem
  deriving BEq, Hashable
/-- Caches stored in the `GrindTacticM` state (see `State.cache`). -/
structure Cache where
  /-- Cache for `BackwardRule`s created from declaration names (sym mode only). -/
  backwardRuleName : PHashMap Name Sym.BackwardRule := {}
  /-- Cache for `BackwardRule`s created from elaborated terms, keyed by syntax byte position range (sym mode only). -/
  backwardRuleSyntax : PHashMap (Nat × Nat) Sym.BackwardRule := {}
  /--
  Per-variant persistent `Sym.simp` cache. Keyed by variant name + the extra theorems
  (each one a `.const` or a `.fvar`, see `ExtraTheorem`).
  -/
  simpState : Std.HashMap SimpCacheKey Sym.Simp.State := {}
/-- Mutable state of `GrindTacticM`. -/
structure State where
  /-- The `Sym` state. -/
  symState : Meta.Sym.State
  /-- The `grind` state. -/
  grindState : Meta.Grind.State
  /-- Current goal list. It may contain goals marked `inconsistent`; see `pruneSolvedGoals`. -/
  goals : List Goal
  /-- Caches (see `Cache`). Starts out empty. -/
  cache : Cache := {}
/-- Snapshot used for backtracking: the term elaborator saved state plus the `GrindTacticM` state. -/
structure SavedState where
  /-- Saved state of the term elaborator, as produced by `Term.saveState`. -/
  term : Term.SavedState
  /-- Saved `GrindTacticM` state. -/
  tactic : State
/-- Monad for `grind` tactics: `TermElabM` extended with a `Context` reader and a mutable `State`. -/
abbrev GrindTacticM := ReaderT Context $ StateRefT State TermElabM
/-- Type of a `grind` tactic elaborator: it receives the tactic syntax and runs in `GrindTacticM`. -/
abbrev GrindTactic := Syntax → GrindTacticM Unit
/-- Returns the current goal list. Goals in it may or may not already be assigned. -/
def getGoals : GrindTacticM (List Goal) := do
  let st ← get
  return st.goals
/-- Replaces the goal list stored in the state with `goals`. -/
def setGoals (goals : List Goal) : GrindTacticM Unit :=
  modify fun st => { st with goals := goals }
/-- Removes every goal marked `inconsistent` from the goal list. -/
def pruneSolvedGoals : GrindTacticM Unit := do
  let pending := (← getGoals).filter fun goal => !goal.inconsistent
  setGoals pending
/-- Prunes the goals marked `inconsistent`, then returns the remaining goal list. -/
def getUnsolvedGoals : GrindTacticM (List Goal) := do
  pruneSolvedGoals
  let pending ← getGoals
  return pending
/-- Prunes the goals marked `inconsistent`, then returns the `mvarId` of each remaining goal. -/
def getUnsolvedGoalMVarIds : GrindTacticM (List MVarId) := do
  pruneSolvedGoals
  return (← getGoals).map fun goal => goal.mvarId
/-- Captures the term elaborator state together with the current `GrindTacticM` state. -/
protected def saveState : GrindTacticM SavedState := do
  let term ← Term.saveState
  let tactic ← get
  return { term, tactic }
/--
Restores the snapshot `b`: first the term elaborator state (forwarding `restoreInfo`),
then the `GrindTacticM` state.
-/
def SavedState.restore (b : SavedState) (restoreInfo := false) : GrindTacticM Unit := do
  let ⟨termState, tacticState⟩ := b
  termState.restore restoreInfo
  set tacticState
/--
`Monad` instance for `GrindTacticM`, assembled from only the `pure` and `bind` of the instance
found by `inferInstance`; every other `Monad` operation is left to its default definition.
-/
-- NOTE(review): presumably done to keep the instance compact for the compiler — confirm.
@[always_inline]
instance : Monad GrindTacticM :=
  let i : Monad GrindTacticM := inferInstance
  { pure := i.pure, bind := i.bind }
/-- The default `GrindTacticM` action ignores its context and state reference and yields `default`. -/
instance : Inhabited (GrindTacticM α) :=
  ⟨fun _ _ => default⟩
/--
Keyed attribute holding the registered `GrindTactic` elaborators.
`evalGrindTactic` queries it with `getEntries` using the syntax kind as key.
-/
-- NOTE(review): `builtin_grind_tactic` / `grind_tactic` are presumably the builtin and user-facing
-- attribute names — confirm against `mkElabAttribute`.
unsafe builtin_initialize grindTacElabAttribute : KeyedDeclsAttribute GrindTactic ←
  mkElabAttribute GrindTactic `builtin_grind_tactic `grind_tactic
    `Lean.Parser.Tactic.Grind `Lean.Elab.Tactic.Grind.GrindTactic "grind"
/--
Builds the tactic `Info` node for `stx`. The "before" data is supplied by the caller; the "after"
data (metavariable context and unsolved goal ids) and the elaborator name are read when this runs.
-/
def mkTacticInfo (mctxBefore : MetavarContext) (goalsBefore : List MVarId) (stx : Syntax) : GrindTacticM Info := do
  let elaborator := (← read).elaborator
  let mctxAfter ← getMCtx
  let goalsAfter ← getUnsolvedGoalMVarIds
  return Info.ofTacticInfo { elaborator, mctxBefore, goalsBefore, stx, mctxAfter, goalsAfter }
/--
Records the current metavariable context and unsolved goal ids as the "before" data, and returns
an action that completes the tactic `Info` for `stx` when it is run later (see `mkTacticInfo`).
-/
def mkInitialTacticInfo (stx : Syntax) : GrindTacticM (GrindTacticM Info) := do
  return mkTacticInfo (← getMCtx) (← getUnsolvedGoalMVarIds) stx
/-- Runs `x` inside an info context whose node is the tactic `Info` for `stx`. -/
@[inline] def withTacticInfoContext (stx : Syntax) (x : GrindTacticM α) : GrindTacticM α := do
  let mkInfo ← mkInitialTacticInfo stx
  withInfoContext x mkInfo
/-- A failed attempt recorded by `evalGrindTactic` while trying macros and elaborators in turn. -/
structure EvalTacticFailure where
  /-- The exception raised by the attempt. -/
  exception : Exception
  /-- State captured when the exception was handled, before backtracking to the pre-attempt state. -/
  state : SavedState
/--
Evaluates the `grind` tactic syntax `stx`.

* A `nullKind` node is treated as a sequence: each argument is evaluated in order.
* For any other node, the macros and the `grind_tactic` elaborators registered for the syntax kind
  are tried in turn (macros first, see `expandEval`), backtracking to the state saved before the
  first attempt after each failure.
* `.missing` is a no-op; any other syntax is an error.
-/
partial def evalGrindTactic (stx : Syntax) : GrindTacticM Unit := do
  checkSystem "grind tactic execution"
  profileitM Exception "grind tactic execution" (decl := stx.getKind) (← getOptions) <|
  withRef stx <| withIncRecDepth <| withFreshMacroScope <| match stx with
    | .node _ k _ =>
      if k == nullKind then
        -- Sequence of tactics: evaluate each argument under a single info context for `stx`.
        Term.withoutTacticIncrementality true <| withTacticInfoContext stx do
          stx.getArgs.forM evalGrindTactic
      else withTraceNode `Elab.step (fun _ => return stx) (tag := stx.getKind.toString) do
        let evalFns := grindTacElabAttribute.getEntries (← getEnv) stx.getKind
        let macros := macroAttribute.getEntries (← getEnv) stx.getKind
        if evalFns.isEmpty && macros.isEmpty then
          throwErrorAt stx "Grind tactic `{stx.getKind}` has not been implemented"
        -- Every failed attempt below backtracks to this snapshot.
        let s ← Grind.saveState
        expandEval s macros evalFns #[]
    | .missing => pure ()
    | _ => throwError m!"Unexpected grind tactic{indentD stx}"
where
  /-- Called when all candidates have been tried: re-raises the last recorded failure, if any. -/
  throwExs (failures : Array EvalTacticFailure) : GrindTacticM Unit := do
    if h : 0 < failures.size then
      -- For macros we want to report the error from the first registered / last tried rule (#3770)
      let fail := failures[failures.size - 1]
      -- Reinstate the state captured when that failure was handled before re-throwing it.
      fail.state.restore (restoreInfo := true)
      throw fail.exception
    else
      throwErrorAt stx "Unexpected syntax{indentD stx}"

  /--
  Handles the exception `ex` raised by one attempt: decides whether to record it in `failures`,
  backtracks to `s`, and continues with `k`. Internal exceptions other than "unsupported syntax"
  and "abort tactic" are re-thrown.
  -/
  @[inline] handleEx (s : SavedState) (failures : Array EvalTacticFailure) (ex : Exception) (k : Array EvalTacticFailure → GrindTacticM Unit) := do
    match ex with
    | .error .. =>
      trace[Elab.tactic.backtrack] ex.toMessageData
      -- Snapshot the state of the failed attempt before backtracking; `throwExs` may restore it.
      let failures := failures.push ⟨ex, ← Grind.saveState⟩
      s.restore (restoreInfo := true); k failures
    | .internal id _ =>
      if id == unsupportedSyntaxExceptionId then
        -- We do not store `unsupportedSyntaxExceptionId`, see throwExs
        s.restore (restoreInfo := true); k failures
      else if id == abortTacticExceptionId then
        -- Trace the messages in the log at this point, then record the failure like an `.error`.
        for msg in (← Core.getMessageLog).toList do
          trace[Elab.tactic.backtrack] msg.data
        let failures := failures.push ⟨ex, ← Grind.saveState⟩
        s.restore (restoreInfo := true); k failures
      else
        throw ex -- (*)

  /-- Tries the macros in order; once they are exhausted, moves on to the elaborators via `eval`. -/
  expandEval (s : SavedState) (macros : List _) (evalFns : List _)
      (failures : Array EvalTacticFailure) : GrindTacticM Unit :=
    match macros with
    | [] => eval s evalFns failures
    | m :: ms =>
      try
        withReader ({ · with elaborator := m.declName }) do
          withTacticInfoContext stx do
            let stx' ← adaptMacro m.value stx
            evalGrindTactic stx'
      catch ex => handleEx s failures ex (expandEval s ms evalFns)

  /-- Tries the elaborators in order; when none is left, reports via `throwExs`. -/
  eval (s : SavedState) (evalFns : List _) (failures : Array EvalTacticFailure) : GrindTacticM Unit := do
    match evalFns with
    | [] => throwExs failures
    | evalFn::evalFns => do
      try
        Term.withoutTacticIncrementality true do
          withReader ({ · with elaborator := evalFn.declName }) do
            withTacticInfoContext stx do
              evalFn.value stx
      catch ex => handleEx s failures ex (eval s evalFns)
/-- Throws the error reported when a tactic needs a goal but none is left. -/
def throwNoGoalsToBeSolved : GrindTacticM α :=
  throwError m!"No goals to be solved"
/-- Backtracking support for `GrindTacticM`, built on `Grind.saveState` and `SavedState.restore`. -/
instance : MonadBacktrack SavedState GrindTacticM :=
  { saveState := Grind.saveState, restoreState := fun saved => saved.restore }
/--
Runs `k` with the first unsolved goal as the only goal, then puts the goals `k` left unsolved
back in front of the other goals.
Fails if there are no goals to be solved.
-/
def focus (k : GrindTacticM α) : GrindTacticM α := do
  match (← getUnsolvedGoals) with
  | [] => throwNoGoalsToBeSolved
  | first :: others =>
    setGoals [first]
    let result ← k
    let remaining ← getUnsolvedGoals
    setGoals (remaining ++ others)
    return result
/--
`try`/`catch` that does not backtrack: the handler `h` runs without any state being restored.
-/
@[inline] protected def tryCatch {α} (x : GrindTacticM α) (h : Exception → GrindTacticM α) : GrindTacticM α := do
  try
    x
  catch err =>
    h err
/--
`try`/`catch` that backtracks: the state is saved before running `x` and restored before the
handler `h` runs. This is the `tryCatch` of the `MonadExcept` instance for `GrindTacticM`.
-/
@[inline] protected def tryCatchRestore {α} (x : GrindTacticM α) (h : Exception → GrindTacticM α)
    : GrindTacticM α := do
  let saved ← saveState
  try
    x
  catch err =>
    saved.restore
    h err
/-- `MonadExcept` instance whose `tryCatch` backtracks the state (see `Grind.tryCatchRestore`). -/
instance : MonadExcept Exception GrindTacticM :=
  { throw := throw, tryCatch := Grind.tryCatchRestore }
/-- Runs `x` with the `recover` flag of the context set to `false`, i.e. with error recovery disabled. -/
def withoutRecover (x : GrindTacticM α) : GrindTacticM α :=
  withReader ({ · with recover := false }) x
/--
Reports `msg` at `ref`: logs it as an error when recovery is enabled, throws it otherwise.
-/
def throwOrLogErrorAt (ref : Syntax) (msg : MessageData) : GrindTacticM Unit := do
  let recovering := (← read).recover
  if recovering then
    logErrorAt ref msg
  else
    throwErrorAt ref msg
/--
Same as `throwOrLogErrorAt`, using the current reference syntax as the error position.
-/
def throwOrLogError (msg : MessageData) : GrindTacticM Unit := do
  let ref ← getRef
  throwOrLogErrorAt ref msg
/-- Runs `x` with error recovery disabled; if it throws, runs `y ()` instead. -/
@[inline] protected def orElse (x : GrindTacticM α) (y : Unit → GrindTacticM α) : GrindTacticM α := do
  try
    withoutRecover x
  catch _ =>
    y ()
/-- `<|>` on `GrindTacticM` actions is `Grind.orElse`. -/
instance : OrElse (GrindTacticM α) :=
  ⟨Grind.orElse⟩
/-- `Alternative` instance: `failure` throws the error "Failed", and `orElse` is `Grind.orElse`. -/
instance : Alternative GrindTacticM :=
  { failure := fun {_} => throwError "Failed", orElse := Grind.orElse }
/--
Records the current tactic state in the info tree for the token `stx`.
Does nothing when `stx` carries no position information.
Used to save the tactic state at punctuation such as `;`
-/
def saveTacticInfoForToken (stx : Syntax) : GrindTacticM Unit := do
  if stx.getPos?.isSome then
    withTacticInfoContext stx (pure ())
/--
Runs `x` with the expansion `beforeStx ↦ afterStx` pushed onto the macro stack of the term
context, inside the matching macro-expansion info node.
-/
@[inline]
def withMacroExpansion (beforeStx afterStx : Syntax) (x : GrindTacticM α) : GrindTacticM α :=
  let pushFrame (ctx : Term.Context) : Term.Context :=
    { ctx with macroStack := { before := beforeStx, after := afterStx } :: ctx.macroStack }
  withMacroExpansionInfo beforeStx afterStx <| withTheReader Term.Context pushFrame x
/--
Turns the syntax transformation `exp` into a tactic evaluator: the input is expanded and the
result is evaluated with the expansion recorded on the macro stack.
-/
def adaptExpander (exp : Syntax → GrindTacticM Syntax) : GrindTactic := fun stx => do
  let expanded ← exp stx
  withMacroExpansion stx expanded (evalGrindTactic expanded)
/--
Returns the first goal that is not marked `inconsistent`. Inconsistent goals in front of it are
dropped from the goal list. Fails if no such goal exists.
-/
def getMainGoal : GrindTacticM Goal := do
  loop (← getGoals)
where
  /-- Scans the given goals for the first one that is not `inconsistent`. -/
  loop : List Goal → GrindTacticM Goal
    | [] => throwNoGoalsToBeSolved
    | candidate :: rest => do
      if !candidate.inconsistent then
        setGoals (candidate :: rest)
        return candidate
      else
        loop rest
/-- Runs `x` in the local context and instances of the main goal's metavariable. -/
def withMainContext (x : GrindTacticM α) : GrindTacticM α := do
  let mainGoal ← getMainGoal
  mainGoal.mvarId.withContext x
/-- Runs `tac` and returns `some` of its result, or `none` if it throws. -/
def tryTactic? (tac : GrindTacticM α) : GrindTacticM (Option α) := do
  try
    let result ← tac
    return some result
  catch _ =>
    return none
/-- Runs `tac`, discarding its result. Returns `true` if it succeeds and `false` if it throws. -/
def tryTactic (tac : GrindTacticM α) : GrindTacticM Bool := do
  try
    let _ ← tac
    return true
  catch _ =>
    return false
open Grind
/-
**Note**: Recall that `grind` uses the reducibility specified at `Config.reducible`
-/
/--
Runs the `GrindM` action `k` inside `GrindTacticM`. `k` is wrapped in `Grind.withGTransparency`
and executed with the methods and contexts from the reader context, starting from the stored
`grindState`/`symState`. The resulting states are written back afterwards.
-/
def liftGrindM (k : GrindM α) : GrindTacticM α := do
  let ctx ← read
  let s ← get
  -- Supply, in order: the methods reference, the `grind` context, the `grind` state, the `Sym` context, and the `Sym` state.
  let ((a, grindState), symState) ← liftMetaM <| StateRefT'.run (((Grind.withGTransparency k) ctx.methods.toMethodsRef ctx.ctx |>.run s.grindState) ctx.sctx) s.symState
  modify fun s => { s with grindState, symState }
  return a
/--
Replaces the first goal of the goal list with those of `goals` that are not marked `inconsistent`.
Fails if the goal list is empty.
-/
def replaceMainGoal (goals : List Goal) : GrindTacticM Unit := do
  let consistent := goals.filter fun goal => !goal.inconsistent
  match (← getGoals) with
  | [] => throwNoGoalsToBeSolved
  | _ :: rest => modify fun s => { s with goals := consistent ++ rest }
/-- Runs the `GoalM` action `k` on the main goal and stores the updated goal in its place. -/
def liftGoalM (k : GoalM α) : GrindTacticM α := do
  let current ← getMainGoal
  let (result, updated) ← liftGrindM (k.run current)
  replaceMainGoal [updated]
  return result
/--
Runs the action `a` on the main goal and replaces the main goal with the outcome: no goals when
the action closed it, otherwise the goals the action got stuck on.
-/
def liftAction (a : Action) : GrindTacticM Unit := do
  let goal ← getMainGoal
  -- First continuation passed to `a`: fails with an error.
  let ka := fun _ => throwError "tactic is not applicable"
  -- Second continuation passed to `a`: reports the goal it receives as stuck.
  let kp := fun goal => return .stuck [goal]
  let outcome ← liftGrindM (a goal ka kp)
  match outcome with
  | .stuck gs => replaceMainGoal gs
  | .closed _ => replaceMainGoal []
/--
Succeeds if no unsolved goal is left. Otherwise logs an "unsolved goals" error describing each
remaining goal, admits those goals, and aborts the tactic.
-/
def done : GrindTacticM Unit := do
  pruneSolvedGoals
  let remaining ← getGoals
  if remaining.isEmpty then
    return
  let params := (← read).params
  let results ← liftGrindM <| remaining.mapM fun goal => Grind.mkResult params (some goal)
  let msgs ← results.mapM fun result => result.toMessageData
  let msg := MessageData.joinSep msgs m!"\n\n"
  logError <| MessageData.tagged `Tactic.unsolvedGoals <| m!"unsolved goals\n{msg}"
  for goal in remaining do
    admitGoal goal.mvarId
  throwAbortTactic
/--
Runs `tactic` with the first unsolved goal as the only goal, and expects it to leave no goals.
Fails if there are no goals to be solved.
-/
def focusAndDone (tactic : GrindTacticM α) : GrindTacticM α :=
  focus do
    let result ← tactic
    done
    return result
/--
Closes the main goal using `tac`. If `tac` fails and recovery is enabled, the exception is logged,
the goal is admitted, and the remaining goals become the goal list; with recovery disabled the
exception is re-thrown.
-/
def closeUsingOrAdmit (tac : GrindTacticM Unit) : GrindTacticM Unit := do
  /- The exception is caught with `tryCatchRuntimeEx`, not with `try`/`catch`.
     NOTE(review): in this file the `MonadExcept` instance for `GrindTacticM` (whose `tryCatch`
     backtracks the state, including error messages) is declared before this definition.
     Confirm that `tryCatchRuntimeEx` does not go through that instance. -/
  let goal :: goals ← getUnsolvedGoals | throwNoGoalsToBeSolved
  let recoverFrom (ex : Exception) : GrindTacticM Unit := do
    unless (← read).recover do
      throw ex
    logException ex
    admitGoal goal.mvarId
    setGoals goals
  tryCatchRuntimeEx (focusAndDone tac) recoverFrom
/-- Runs `x` with context `ctx` and initial state `s`, returning its result and the final state. -/
def GrindTacticM.run (x : GrindTacticM α) (ctx : Context) (s : State) : TermElabM (α × State) :=
  StateRefT'.run (x ctx) s
/--
Builds a function that evaluates a `grind` tactic syntax on a goal from inside `GrindM` and
returns the resulting goals.

The `Term.State` and `Term.Context` are captured once, when `mkEvalTactic'` runs. Each evaluation
starts from that captured term state, runs with `recover := false`, and writes the resulting
`grind` and `Sym` states back into the calling `GrindM` computation.
-/
def mkEvalTactic' (elaborator : Name) (params : Params) : TermElabM (Goal → TSyntax `grind → GrindM (List Goal)) := do
  let termState ← getThe Term.State
  let termCtx ← readThe Term.Context
  let eval (goal : Goal) (stx : TSyntax `grind) : GrindM (List Goal) := do
    -- Collect the caller's methods, contexts and states to seed the nested `GrindTacticM` run.
    let methods ← getMethods
    let grindCtx ← readThe Meta.Grind.Context
    let symCtx ← readThe Meta.Sym.Context
    let grindState ← get
    let symState ← getThe Sym.State
    -- **Note**: we discard changes to `Term.State`
    let (subgoals, grindState', symState') ← Term.TermElabM.run' (ctx := termCtx) (s := termState) do
      let (_, s) ← GrindTacticM.run
          (ctx := { recover := false, methods, ctx := grindCtx, sctx := symCtx, params, elaborator })
          (s := { grindState, symState, goals := [goal] }) do
        evalGrindTactic stx.raw
        pruneSolvedGoals
      return (s.goals, s.grindState, s.symState)
    -- Propagate the states produced by the nested run back to the caller.
    set grindState'
    set symState'
    return subgoals
  return eval
/-- `mkEvalTactic'` specialised to `TacticM`: the elaborator name is taken from the tactic context. -/
def mkEvalTactic (params : Params) : TacticM (Goal → TSyntax `grind → GrindM (List Goal)) := do
  let tacticCtx ← read
  mkEvalTactic' tacticCtx.elaborator params
/--
Runs the `GrindTacticM` action `k` on the goal `mvarId` and returns its result together with the
final `State`.

The initial goals are computed with `GrindM.runAtGoal` using `config.useSorry := false`:
* with `sym := true`, the goal is only preprocessed with `Sym.preprocessMVar`;
* otherwise `Action.intros 0 >> Action.assertAll` is run on it.

`k` then runs with the original `params` (and the original config in the `grind` context).
-/
def GrindTacticM.runAtGoal (mvarId : MVarId) (params : Params) (k : GrindTacticM α) (sym : Bool := false) : TacticM (α × State) := do
  let evalTactic ← mkEvalTactic params
  /-
  **Note**: We don't want to close branches using `sorry` after applying `intros + assertAll`.
  Reconsider the option `useSorry`.
  -/
  let params' := { params with config.useSorry := false }
  let (methods, ctx, sctx, state) ← liftMetaM <| GrindM.runAtGoal mvarId params' (evalTactic? := some evalTactic) fun goal => do
    let goals ←
      if sym then
        /- In sym mode, skip eager intros + by-contradiction. The user controls intro/internalize.
           Preprocess for maximal term sharing, required by Sym operations (introN, BackwardRule.apply, etc.). -/
        let mvarId ← Sym.preprocessMVar goal.mvarId
        pure [{ goal with mvarId }]
      else
        let a : Action := Action.intros 0 >> Action.assertAll
        match (← a.run goal) with
        | .closed _ => pure []
        | .stuck gs => pure gs
    let methods ← getMethods
    let ctx ← readThe Meta.Grind.Context
    /- Restore original config -/
    let ctx := { ctx with config := params.config }
    let sctx ← readThe Meta.Sym.Context
    let grindState ← get
    let symState ← getThe Sym.State
    return (methods, ctx, sctx, { grindState, symState, goals })
  -- Extend the ambient tactic context with the `grind`-specific fields and run `k`.
  let tctx ← read
  k { tctx with methods, ctx, sctx, params, sym } |>.run state
end Lean.Elab.Tactic.Grind