feat: support for incrementally processing hypotheses in grind (#11787)

This PR adds support for incrementally processing local declarations in
`grind`. Instead of processing all hypotheses at once during goal
initialization, `grind` now tracks which local declarations have been
processed via `Goal.nextDeclIdx` and provides APIs to process new
hypotheses incrementally.
This feature will be used by the new `SymM` monad for efficient symbolic
simulation.
This commit is contained in:
Leonardo de Moura 2025-12-23 18:50:22 -08:00 committed by GitHub
parent c34e4cf0f7
commit ce56e2139e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 85 additions and 36 deletions

View file

@ -239,7 +239,7 @@ def grind
return ()
mvarId.withContext do
let params ← mkGrindParams config only ps mvarId
Grind.withProtectedMCtx config.abstractProof mvarId fun mvarId' => do
Grind.withProtectedMCtx config mvarId fun mvarId' => do
let finalize (result : Grind.Result) : TacticM Unit := do
if result.hasFailed then
throwError "`grind` failed\n{← result.toMessageData}"
@ -340,7 +340,7 @@ def evalGrindTraceCore (stx : Syntax) (trace := true) (verbose := true) (useSorr
| _ => true
let mvarId ← getMainGoal
let params ← mkGrindParams config only paramStxs mvarId
Grind.withProtectedMCtx config.abstractProof mvarId fun mvarId' => do
Grind.withProtectedMCtx config mvarId fun mvarId' => do
let (tacs, _) ← Grind.GrindTacticM.runAtGoal mvarId' params do
let finish ← Grind.Action.mkFinish
let goal :: _ ← Grind.getGoals

View file

@ -491,8 +491,8 @@ def getAt? (lctx : LocalContext) (i : Nat) : Option LocalDecl :=
| none => pure b
| some decl => f decl b
@[specialize] def forM [Monad m] (lctx : LocalContext) (f : LocalDecl → m PUnit) : m PUnit :=
lctx.decls.forM fun decl => match decl with
@[specialize] def forM [Monad m] (lctx : LocalContext) (f : LocalDecl → m PUnit) (start := 0) : m PUnit :=
lctx.decls.forM (start := start) fun decl => match decl with
| none => pure PUnit.unit
| some decl => f decl

View file

@ -244,19 +244,34 @@ def Result.toMessageData (result : Result) : MetaM MessageData := do
/--
When `Config.revert := false`, we preprocess the hypotheses, and add them to the `grind` state.
It starts at `goal.nextDeclIdx`. If `num?` is `some num`, then at most `num` local declarations are processed.
Otherwise, all remaining local declarations are processed.
Remark: this function assumes the local context does not contains holes with `none` in `decls`.
-/
private def addHypotheses (goal : Goal) : GrindM Goal := GoalM.run' goal do
let mvarDecl ← goal.mvarId.getDecl
for localDecl in mvarDecl.lctx do
if (← isInconsistent) then return ()
let type := localDecl.type
if (← isProp type) then
let r ← preprocessHypothesis type
match r.proof? with
| none => add r.expr localDecl.toExpr
| some h => add r.expr <| mkApp4 (mkConst ``Eq.mp [0]) type r.expr h localDecl.toExpr
else
internalizeLocalDecl localDecl
private def addHypotheses (goal : Goal) (num? : Option Nat := none) : GrindM Goal := GoalM.run' goal do
discard <| go.run
where
go : ExceptT Unit GoalM Unit := do
let mvarDecl ← goal.mvarId.getDecl
mvarDecl.lctx.forM (start := goal.nextDeclIdx) fun localDecl => do
if (← isInconsistent) then
setNextDeclToEnd
throwThe Unit () -- interrupt
if let some num := num? then
if localDecl.index >= goal.nextDeclIdx + num then
modify fun goal => { goal with nextDeclIdx := localDecl.index }
throwThe Unit () -- interrupt
unless localDecl.isImplementationDetail do
let type := localDecl.type
if (← isProp type) then
let r ← preprocessHypothesis type
match r.proof? with
| none => add r.expr localDecl.toExpr
| some h => add r.expr <| mkApp4 (mkConst ``Eq.mp [0]) type r.expr h localDecl.toExpr
else
internalizeLocalDecl localDecl
setNextDeclToEnd -- Processed all local decls
private def initCore (mvarId : MVarId) (params : Params) : GrindM Goal := do
/-
@ -310,14 +325,24 @@ See issue #11806 for a motivating example.
-/
def withProtectedMCtx [Monad m] [MonadControlT MetaM m] [MonadLiftT MetaM m]
[MonadExcept Exception m] [MonadRuntimeException m]
(abstractProof : Bool) (mvarId : MVarId) (k : MVarId → m α) : m α := do
(config : Grind.Config) (mvarId : MVarId) (k : MVarId → m α) : m α := do
/-
**Note**: `instantiateGoalMVars` here also instantiates mvars occurring in hypotheses.
This is particularly important when using `grind -revert`.
-/
let mvarId ← mvarId.instantiateGoalMVars
let mvarId ← mvarId.abstractMVars
let mvarId ← mvarId.clearImplDetails
let mut mvarId ← mvarId.instantiateGoalMVars
/-
**TODO**: It would be nice to remove the following step, but
some tests break with unknown metavariable error when this
step is removed. The main issue is the `withNewMCtxDepth` step at
`main`.
-/
mvarId ← mvarId.abstractMVars
if config.revert then
/-
**Note**: We now skip implementation details at `addHypotheses`
-/
mvarId ← mvarId.clearImplDetails
tryCatchRuntimeEx (main mvarId) fun ex => do
mvarId.admit
throw ex
@ -327,22 +352,23 @@ where
let (a, val) ← withNewMCtxDepth do
let mvar' ← mkFreshExprSyntheticOpaqueMVar type
let a ← k mvar'.mvarId!
let val ← finalize mvar'
let val ← instantiateMVarsProfiling mvar'
return (a, val)
let val ← finalize val
(mvarId.assign val : MetaM _)
return a
finalize (mvar' : Expr) : MetaM Expr := do
trace[grind.debug.proof] "{← instantiateMVars mvar'}"
let type ← inferType mvar'
finalize (val : Expr) : MetaM Expr := do
trace[grind.debug.proof] "{val}"
let type ← inferType val
-- `grind` proofs are often big, if `abstractProof` is true, we create an auxiliary theorem.
let val ← if !abstractProof then
instantiateMVarsProfiling mvar'
let val ← if !config.abstractProof then
pure val
else if (← isProp type) then
mkAuxTheorem type (← instantiateMVarsProfiling mvar') (zetaDelta := true)
mkAuxTheorem type val (zetaDelta := true)
else
let auxName ← mkAuxDeclName `grind
mkAuxDefinition auxName type (← instantiateMVarsProfiling mvar') (zetaDelta := true)
mkAuxDefinition auxName type val (zetaDelta := true)
return val
end Lean.Meta.Grind

View file

@ -945,6 +945,8 @@ structure DelayedTheoremInstance where
/-- The `grind` goal. -/
structure Goal where
mvarId : MVarId
/-- Next local declaration index to process. -/
nextDeclIdx : Nat := 0
canon : Canon.State := {}
enodeMap : ENodeMap := default
exprs : PArray Expr := {}
@ -1010,6 +1012,30 @@ abbrev GoalM := StateRefT Goal GrindM
@[inline] def GoalM.run' (goal : Goal) (x : GoalM Unit) : GrindM Goal :=
goal.mvarId.withContext do StateRefT'.run' (x *> get) goal
/--
Sets `nextDeclIdx` to point past the last local declaration in the local context.
This marks all existing local declarations as already processed by `grind`. Use this when
initializing a goal whose hypotheses should not be processed or after internalizing all of them.
-/
def Goal.setNextDeclToEnd (g : Goal) : MetaM Goal := do
let mvarDecl ← g.mvarId.getDecl
return { g with nextDeclIdx := mvarDecl.lctx.decls.size }
def setNextDeclToEnd : GoalM Unit := do
let mvarDecl ← (← get).mvarId.getDecl
modify fun g => { g with nextDeclIdx := mvarDecl.lctx.decls.size }
/--
Returns `true` if the goal has local declarations that have not yet been processed by `grind`.
A local declaration is "pending" if its index is greater than or equal to `nextDeclIdx`.
This is used to determine whether `grind` needs to internalize new hypotheses.
-/
def Goal.hasPendingLocalDecls (g : Goal) : MetaM Bool := do
let mvarDecl ← g.mvarId.getDecl
return g.nextDeclIdx < mvarDecl.lctx.decls.size
def updateLastTag : GoalM Unit := do
if (← isTracingEnabledFor `grind) then
let currTag ← (← get).mvarId.getTag
@ -1744,6 +1770,7 @@ inductive ActionResult where
`gs` are subgoals that could not be closed. They are used for producing error messages.
-/
stuck (gs : List Goal)
deriving Inhabited
abbrev ActionCont : Type :=
Goal → GrindM ActionResult

View file

@ -2,15 +2,11 @@ module
reset_grind_attrs%
attribute [grind] List.not_mem_nil
/--
error: Tactic `grind` failed: the goal mentions the declaration `incList`, which is being defined. To avoid circular reasoning, try rewriting the goal to eliminate `incList` before using `grind`.
as✝ : List Nat
a : Nat
as : List Nat
⊢ ∀ (a : Nat), a ∈ (incList as).val → a > 0
/-!
Note: the following definition used to fail because the goal mentions the
declaration `incList` being defined.
-/
#guard_msgs (error) in
def incList (as : List Nat) : { as : List Nat // ∀ a, a ∈ as → a > 0 } :=
match as with
| [] => ⟨[], by grind⟩