From e2b5747f4bbca5803ade20939a4bd039f4c3ac2d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 18 Oct 2025 10:02:36 -0700 Subject: [PATCH] feat: `evalTactic` in `GrindM` (#10833) This PR implements infrastructure for evaluating `grind` tactics in the `GrindM` monad. We are going to use it to check whether auto-generated tactics can effectively close the original goal. --- src/Lean/Elab/Tactic/Grind/Basic.lean | 28 ++++++++++++++++++++++++- src/Lean/Meta/Tactic/Grind/Action.lean | 20 +++++++++++++++++- src/Lean/Meta/Tactic/Grind/Main.lean | 12 ++++++----- src/Lean/Meta/Tactic/Grind/Types.lean | 29 ++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 7 deletions(-) diff --git a/src/Lean/Elab/Tactic/Grind/Basic.lean b/src/Lean/Elab/Tactic/Grind/Basic.lean index 536a7f48fe..d1825c59c8 100644 --- a/src/Lean/Elab/Tactic/Grind/Basic.lean +++ b/src/Lean/Elab/Tactic/Grind/Basic.lean @@ -341,8 +341,34 @@ def liftSearchM (k : SearchM α) : GrindTacticM α := do replaceMainGoal [state.goal] return a +def GrindTacticM.run (x : GrindTacticM α) (ctx : Context) (s : State) : TermElabM (α × State) := + x ctx |>.run s + +def mkEvalTactic' (elaborator : Name) (params : Params) : TermElabM (Goal → TSyntax `grind → GrindM (List Goal)) := do + let termState ← getThe Term.State + let termCtx ← readThe Term.Context + let eval (goal : Goal) (stx : TSyntax `grind) : GrindM (List Goal) := do + let methods ← getMethods + let grindCtx ← readThe Meta.Grind.Context + let grindState ← get + -- **Note**: we discard changes to `Term.State` + let (subgoals, grindState') ← Term.TermElabM.run' (ctx := termCtx) (s := termState) do + let (_, s) ← GrindTacticM.run + (ctx := { methods, ctx := grindCtx, params, elaborator }) + (s := { state := grindState, goals := [goal] }) do + evalGrindTactic stx.raw + pruneSolvedGoals + return (s.goals, s.state) + set grindState' + return subgoals + return eval + +def mkEvalTactic (params : Params) : TacticM (Goal → TSyntax `grind → GrindM (List Goal)) := do + mkEvalTactic' (← read).elaborator params + def GrindTacticM.runAtGoal (mvarId : MVarId) (params : Params) (k : GrindTacticM α) : TacticM (α × State) := do - let (methods, ctx, state) ← liftMetaM <| GrindM.runAtGoal mvarId params fun goal => do + let evalTactic ← mkEvalTactic params + let (methods, ctx, state) ← liftMetaM <| GrindM.runAtGoal mvarId params (evalTactic? := some evalTactic) fun goal => do let methods ← getMethods -- **Note**: We use `withCheapCasesOnly` to ensure multiple goals are not created. -- We will add support for this case in the future. diff --git a/src/Lean/Meta/Tactic/Grind/Action.lean b/src/Lean/Meta/Tactic/Grind/Action.lean index 3dc321b65d..b988ebacab 100644 --- a/src/Lean/Meta/Tactic/Grind/Action.lean +++ b/src/Lean/Meta/Tactic/Grind/Action.lean @@ -247,7 +247,7 @@ A terminal action which closes the goal or not. This kind of action may make progress, but we only include `mkTac` into the resulting tactic sequence if it closed the goal. -/ -public def terminalAction (check : GoalM Bool) (mkTac : GrindM (TSyntax `grind)) : Action := fun goal kna kp => do +def terminalAction (check : GoalM Bool) (mkTac : GrindM (TSyntax `grind)) : Action := fun goal kna kp => do let (progress, goal') ← GoalM.run goal check if progress then if goal'.inconsistent then @@ -257,6 +257,24 @@ public def terminalAction (check : GoalM Bool) (mkTac : GrindM (TSyntax `grind)) else kna goal' +/-- +Helper action that checks whether the resulting tactic script produced by its continuation +can close the original goal. +-/ +def checkTactic : Action := fun goal _ kp => do + let s ← saveState + let r ← kp goal + match r with + | .closed seq => + let tac ← mkGrindNext seq + Lean.withoutModifyingState do + s.restore + let subgoals ← evalTactic goal tac + unless subgoals.isEmpty do + throwError "generated tactic cannot close the goal{indentD tac}\nInitial goal\n{goal.mvarId}\nPending subgoals\n{subgoals.map (·.mvarId)}" + return r + | _ => return r + section /-! Some sanity check properties. diff --git a/src/Lean/Meta/Tactic/Grind/Main.lean b/src/Lean/Meta/Tactic/Grind/Main.lean index 5edb0fbbe5..69edf84167 100644 --- a/src/Lean/Meta/Tactic/Grind/Main.lean +++ b/src/Lean/Meta/Tactic/Grind/Main.lean @@ -45,9 +45,11 @@ def mkParams (config : Grind.Config) : MetaM Params := do let symPrios ← getGlobalSymbolPriorities return { config, norm, normProcs, symPrios } -def mkMethods : CoreM Methods := do +def mkMethods (evalTactic? : Option EvalTactic := none) : CoreM Methods := do let builtinPropagators ← builtinPropagatorsRef.get + let evalTactic : EvalTactic := evalTactic?.getD EvalTactic.skip return { + evalTactic propagateUp := fun e => do propagateForallPropUp e propagateReflCmp e @@ -75,7 +77,7 @@ private def discharge? (e : Expr) : SimpM (Option Expr) := do else return none -def GrindM.run (x : GrindM α) (params : Params) : MetaM α := do +def GrindM.run (x : GrindM α) (params : Params) (evalTactic? : Option EvalTactic := none) : MetaM α := do let (falseExpr, scState) := shareCommonAlpha (mkConst ``False) {} let (trueExpr, scState) := shareCommonAlpha (mkConst ``True) scState let (bfalseExpr, scState) := shareCommonAlpha (mkConst ``Bool.false) scState @@ -88,7 +90,7 @@ def GrindM.run (x : GrindM α) (params : Params) : MetaM α := do let simp := params.norm let config := params.config let symPrios := params.symPrios - x (← mkMethods).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr, symPrios } + x (← mkMethods evalTactic?).toMethodsRef { config, simpMethods, simp, trueExpr, falseExpr, natZExpr, btrueExpr, bfalseExpr, ordEqExpr, intExpr, symPrios } |>.run' { scState } private def mkCleanState (mvarId : MVarId) (params : Params) : MetaM Clean.State := mvarId.withContext do @@ -217,11 +219,11 @@ def mkResult (params : Params) (failure? : Option Goal) : GrindM Result := do logInfo msg return { failure?, issues, config := params.config, trace, counters, simp, splitDiags } -def GrindM.runAtGoal (mvarId : MVarId) (params : Params) (k : Goal → GrindM α) : MetaM α := do +def GrindM.runAtGoal (mvarId : MVarId) (params : Params) (k : Goal → GrindM α) (evalTactic? : Option EvalTactic := none) : MetaM α := do let go : GrindM α := withReducible do let goal ← initCore mvarId params k goal - go.run params + go.run params (evalTactic? := evalTactic?) def main (mvarId : MVarId) (params : Params) : MetaM Result := do profileitM Exception "grind" (← getOptions) do GrindM.runAtGoal mvarId params fun goal => do diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 3095c781db..db5fbfd77c 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -219,6 +219,9 @@ structure State where -/ anchors : PHashMap ExprPtr UInt64 := {} +instance : Nonempty State := + .intro {} + private opaque MethodsRefPointed : NonemptyType.{0} def MethodsRef : Type := MethodsRefPointed.type instance : Nonempty MethodsRef := by exact MethodsRefPointed.property @@ -228,6 +231,26 @@ abbrev GrindM := ReaderT MethodsRef $ ReaderT Context $ StateRefT State MetaM @[inline] def mapGrindM [MonadControlT GrindM m] [Monad m] (f : {α : Type} → GrindM α → GrindM α) {α} (x : m α) : m α := controlAt GrindM fun runInBase => f <| runInBase x +/-- +Backtrackable state for the `GrindM` monad. +-/ +structure SavedState where + «meta» : Meta.SavedState + grind : State + deriving Nonempty + +protected def saveState : GrindM SavedState := + return { «meta» := (← Meta.saveState), grind := (← get) } + +/-- Restore backtrackable parts of the state. -/ +def SavedState.restore (b : SavedState) : GrindM Unit := do + b.meta.restore + set b.grind + +instance : MonadBacktrack SavedState GrindM where + saveState := Grind.saveState + restoreState s := s.restore + /-- `withoutReportingMVarIssues x` executes `x` without reporting metavariables found during internalization. See comment at `Grind.Context.reportMVarIssue` for additional details. @@ -1325,10 +1348,13 @@ def forEachEqcRoot (f : ENode → GoalM Unit) : GoalM Unit := do f n abbrev Propagator := Expr → GoalM Unit +abbrev EvalTactic := Goal → TSyntax `grind → GrindM (List Goal) +def EvalTactic.skip : EvalTactic := fun goal _ => return [goal] structure Methods where propagateUp : Propagator := fun _ => return () propagateDown : Propagator := fun _ => return () + evalTactic : EvalTactic := EvalTactic.skip deriving Inhabited def Methods.toMethodsRef (m : Methods) : MethodsRef := @@ -1346,6 +1372,9 @@ def propagateUp (e : Expr) : GoalM Unit := do def propagateDown (e : Expr) : GoalM Unit := do (← getMethods).propagateDown e +def evalTactic (goal : Goal) (stx : TSyntax `grind) : GrindM (List Goal) := do + (← getMethods).evalTactic goal stx + /-- Returns expressions in the given expression equivalence class. -/ partial def Goal.getEqc (goal : Goal) (e : Expr) (sort := false) : List Expr := let eqc := go e e #[]