From 007bd18bcb0eed2cf9b718101f4a988397cd38e2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 5 Apr 2025 18:01:37 -0700 Subject: [PATCH] feat: extensible `evalAndSuggest` for `try?` (#7831) This PR adds extensibility to the `evalAndSuggest` procedure used to implement `try?`. Users can now implement their own handlers for any tactic. The new test demonstrates how this feature works. --- src/Lean/Elab/Tactic/Try.lean | 88 ++++++++++++++++++++-------- tests/lean/run/grind_try_extend.lean | 31 ++++++++++ 2 files changed, 96 insertions(+), 23 deletions(-) create mode 100644 tests/lean/run/grind_try_extend.lean diff --git a/src/Lean/Elab/Tactic/Try.lean b/src/Lean/Elab/Tactic/Try.lean index 2ddff5b6ff..649cfd1f25 100644 --- a/src/Lean/Elab/Tactic/Try.lean +++ b/src/Lean/Elab/Tactic/Try.lean @@ -216,19 +216,26 @@ structure Ctx where terminal : Bool config : Try.Config -abbrev M := ReaderT Ctx TacticM +abbrev TryTacticM := ReaderT Ctx TacticM +abbrev TryTactic := TSyntax `tactic → TryTacticM (TSyntax `tactic) -instance : MonadBacktrack SavedState M where +instance : MonadBacktrack SavedState TryTacticM where saveState := fun _ => saveState restoreState s := fun _ => restoreState s -abbrev withNonTerminal (x : M α) : M α := +abbrev withNonTerminal (x : TryTacticM α) : TryTacticM α := withReader (fun c => { c with terminal := false}) x --- TODO: polymorphic `Tactic.focus` -abbrev focus (x : M α) : M α := fun ctx => Tactic.focus (x ctx) +builtin_initialize tryTacticElabAttribute : KeyedDeclsAttribute TryTactic ← do + unsafe mkElabAttribute TryTactic `builtin_try_tactic `try_tactic `Lean.Parser.Tactic `Lean.Elab.Tactic.Try.TryTactic "try_tactic" -def observing (x : M α) : M (TacticResult α) := do +private def getEvalFns (kind : SyntaxNodeKind) : CoreM (List (KeyedDeclsAttribute.AttributeEntry TryTactic)) := do + return tryTacticElabAttribute.getEntries (← getEnv) kind + +-- TODO: polymorphic `Tactic.focus` +abbrev focus (x : TryTacticM α) : TryTacticM α := fun ctx => Tactic.focus (x ctx) + +def observing (x : TryTacticM α) : TryTacticM (TacticResult α) := do let s ← saveState try let e ← x @@ -271,7 +278,7 @@ private def merge? (tac1 tac2 : TSyntax `tactic) : Option (TSyntax `tactic) := else none -private def mergeAll? (tacs : Array (TSyntax `tactic)) : M (Option (TSyntax `tactic)) := do +private def mergeAll? (tacs : Array (TSyntax `tactic)) : TryTacticM (Option (TSyntax `tactic)) := do if !(← read).config.merge || tacs.isEmpty then return none let tac0 := tacs[0]! @@ -304,7 +311,7 @@ private def isOnlyAndNonOnly (tacs2 : Array (TSyntax `tactic)) : Bool := Id.run else return false -private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do +private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : TryTacticM (TSyntax `tactic) := do let tacss2 := tacss2.map getSuggestionsCore if (← isTracingEnabledFor `try.debug) then trace[try.debug] "mkChainResultCore tac1{indentD tac1}" @@ -343,7 +350,7 @@ private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tac (_, acc) ← go tacss2 0 [] none |>.run acc mkTrySuggestions acc where - go (tacss2 : Array (Array (TSyntax `tactic))) (i : Nat) (acc : List (TSyntax `tactic)) (kind? : Option SyntaxNodeKind) : StateT (Array (TSyntax `tactic)) M Unit := do + go (tacss2 : Array (Array (TSyntax `tactic))) (i : Nat) (acc : List (TSyntax `tactic)) (kind? : Option SyntaxNodeKind) : StateT (Array (TSyntax `tactic)) TryTacticM Unit := do if (← get).size > (← read).config.max then return () else if h : i < tacss2.size then @@ -371,7 +378,7 @@ where $tacs2*) modify (·.push tac) -private def evalSuggestGrindTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) := do +private def evalSuggestGrindTrace : TryTactic := fun tac => do match tac with | `(tactic| grind? $configStx:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) => let config ← elabGrindConfig configStx @@ -386,7 +393,7 @@ private def evalSuggestGrindTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) return tac | _ => throwUnsupportedSyntax -private def evalSuggestSimpTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) := do (← getMainGoal).withContext do +private def evalSuggestSimpTrace : TryTactic := fun tac => do (← getMainGoal).withContext do match tac with | `(tactic| simp? $_:optConfig $[only%$only]? $[[$args,*]]? $(loc)?) => let tac ← simpTraceToSimp tac @@ -401,10 +408,10 @@ private def evalSuggestSimpTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) : | _ => throwUnsupportedSyntax @[extern "lean_eval_suggest_tactic"] -- forward definition to avoid mutual block -opaque evalSuggest (tac : TSyntax `tactic) : M (TSyntax `tactic) +opaque evalSuggest : TryTactic /-- `evalSuggest` for `tac1 <;> tac2` -/ -private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : M (TSyntax `tactic) := focus do +private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : TryTacticM (TSyntax `tactic) := focus do unless (← read).terminal do throwError "invalid `<;>` occurrence in non-terminal position for `try?` script{indentD (← read).root}" let tac1 ← withNonTerminal do evalSuggest tac1 @@ -422,7 +429,7 @@ private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : M (TSyntax `tactic) mkChainResult tac1 tac2s /-- `evalSuggest` for a sequence of tactics. -/ -private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do +private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : TryTacticM (TSyntax `tactic) := do if (← read).terminal then let mut result := #[] for i in [:tacs.size - 1] do @@ -433,10 +440,10 @@ private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : M (TSyntax `tactic else mkSeq (← tacs.mapM evalSuggest) (terminal := false) -private def evalSuggestSeqCore (tacs : Array Syntax) : M (TSyntax `tactic) := do +private def evalSuggestSeqCore (tacs : Array Syntax) : TryTacticM (TSyntax `tactic) := do evalSuggestSeq (tacs.map fun tac => ⟨tac⟩) -private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : M (TSyntax `tactic) := do +private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : TryTacticM (TSyntax `tactic) := do let tacs ← match s with | `(tacticSeq| { $t;* }) => pure t.getElems | `(tacticSeq| $t;*) => pure t.getElems @@ -444,30 +451,30 @@ private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : M (TS evalSuggestSeq tacs /-- `evalSuggest` for `first` tactic. -/ -private partial def evalSuggestFirst (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : M (TSyntax `tactic) := do +private partial def evalSuggestFirst (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TryTacticM (TSyntax `tactic) := do if tacs.size == 0 then throwError "`first` expects at least one argument" go 0 where - go (i : Nat) : M (TSyntax `tactic) := do + go (i : Nat) : TryTacticM (TSyntax `tactic) := do if i = tacs.size - 1 then evalSuggestTacticSeq tacs[i]! else evalSuggestTacticSeq tacs[i]! <|> go (i+1) /-- `evalSuggest` for `try` tactic. -/ -private partial def evalSuggestTry (tac : TSyntax ``Parser.Tactic.tacticSeq) : M (TSyntax `tactic) := do +private partial def evalSuggestTry (tac : TSyntax ``Parser.Tactic.tacticSeq) : TryTacticM (TSyntax `tactic) := do (do evalSuggestTacticSeq tac) <|> `(tactic| skip) /-- `evalSuggest` for `attempt_all` tactic. -/ -private partial def evalSuggestAttemptAll (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : M (TSyntax `tactic) := do +private partial def evalSuggestAttemptAll (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TryTacticM (TSyntax `tactic) := do unless (← read).terminal do throwError "invalid occurrence of `attempt_all` in non-terminal position for `try?` script{indentD (← read).root}" go 0 none #[] where - go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do + go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : TryTacticM (TSyntax `tactic) := do -- Remark: we considered using `acc.size < (← read).config.max` here to truncate the search, -- but it had a negative effect when using `<;>`. We could miss a preferred solution `induction e <;> grind` -- because only a subset of the goals were solved by simpler tactics such as `rfl` and `simp`. @@ -485,10 +492,45 @@ where else throwError "`attempt_all` failed" +private partial def evalSuggestDefault (tac : TSyntax `tactic) : TryTacticM (TSyntax `tactic) := do + let kind := tac.raw.getKind + match (← getEvalFns kind) with + | [] => evalSuggestAtomic tac -- lift regular tactic + | evalFns => eval (← Tactic.saveState) evalFns #[] +where + throwExs (failures : Array EvalTacticFailure) : TryTacticM (TSyntax `tactic) := do + if h : 0 < failures.size then + let fail := failures[failures.size - 1] + fail.state.restore (restoreInfo := true) + throw fail.exception + else + throwErrorAt tac "unexpected syntax {indentD tac}" + + eval (s : SavedState) (evalFns : List _) (failures : Array EvalTacticFailure) : TryTacticM (TSyntax `tactic) := do + match evalFns with + | [] => throwExs failures + | evalFn::evalFns => + try + withTheReader Tactic.Context ({ · with elaborator := evalFn.declName }) do + evalFn.value tac + catch ex => match ex with + | .error .. => + let failures := failures.push ⟨ex, ← Tactic.saveState⟩ + s.restore (restoreInfo := true); eval s evalFns failures + | .internal id _ => + if id == unsupportedSyntaxExceptionId then + s.restore (restoreInfo := true); eval s evalFns failures + else if id == abortTacticExceptionId then + let failures := failures.push ⟨ex, ← Tactic.saveState⟩ + s.restore (restoreInfo := true); eval s evalFns failures + else + throw ex + -- `evalSuggest` implementation @[export lean_eval_suggest_tactic] -private partial def evalSuggestImpl (tac : TSyntax `tactic) : M (TSyntax `tactic) := do +private partial def evalSuggestImpl : TryTactic := fun tac => do trace[try.debug] "{tac}" + -- TODO: Implement builtin cases using `[builtin_try_tactic]` after update-stage0 match tac with | `(tactic| $tac1 <;> $tac2) => evalSuggestChain tac1 tac2 | `(tactic| first $[| $tacs]*) => evalSuggestFirst tacs @@ -507,7 +549,7 @@ private partial def evalSuggestImpl (tac : TSyntax `tactic) : M (TSyntax `tactic else if k == ``Parser.Tactic.exact? then evalSuggestExact else - evalSuggestAtomic tac + evalSuggestDefault tac if (← read).terminal then unless (← getGoals).isEmpty do throwError "unsolved goals" diff --git a/tests/lean/run/grind_try_extend.lean b/tests/lean/run/grind_try_extend.lean new file mode 100644 index 0000000000..2caff6deed --- /dev/null +++ b/tests/lean/run/grind_try_extend.lean @@ -0,0 +1,31 @@ +import Lean + +open Lean Meta Elab Tactic Try + +-- Install a `TryTactic` handler for `assumption` +@[try_tactic assumption] +def evalTryApply : TryTactic := fun tac => do + -- We just use the default implementation, but return a different tactic. + evalAssumption tac + `(tactic| (trace "worked"; assumption)) + +/-- info: Try this: · trace "worked"; assumption -/ +#guard_msgs (info) in +example (h : False) : False := by + try? (max := 1) -- at most one solution + +-- `try?` uses `evalAndSuggest` the attribute `[try_tactic]` is used to extend `evalAndSuggest`. +-- Let's define our own `try?` that uses `evalAndSuggest` +elab stx:"my_try?" : tactic => do + -- Things to try + let toTry ← `(tactic| attempt_all | assumption | apply True | rfl) + evalAndSuggest stx toTry + +/-- +info: Try these: +• · trace "worked"; assumption +• rfl +-/ +#guard_msgs (info) in +example (a : Nat) (h : a = a) : a = a := by + my_try?