feat: extensible evalAndSuggest for try? (#7831)
This PR adds extensibility to the `evalAndSuggest` procedure used to implement `try?`. Users can now implement their own handlers for any tactic. The new test demonstrates how this feature works.
This commit is contained in:
parent
c851cdb21e
commit
007bd18bcb
2 changed files with 96 additions and 23 deletions
|
|
@ -216,19 +216,26 @@ structure Ctx where
|
|||
terminal : Bool
|
||||
config : Try.Config
|
||||
|
||||
abbrev M := ReaderT Ctx TacticM
|
||||
abbrev TryTacticM := ReaderT Ctx TacticM
|
||||
abbrev TryTactic := TSyntax `tactic → TryTacticM (TSyntax `tactic)
|
||||
|
||||
instance : MonadBacktrack SavedState M where
|
||||
instance : MonadBacktrack SavedState TryTacticM where
|
||||
saveState := fun _ => saveState
|
||||
restoreState s := fun _ => restoreState s
|
||||
|
||||
abbrev withNonTerminal (x : M α) : M α :=
|
||||
abbrev withNonTerminal (x : TryTacticM α) : TryTacticM α :=
|
||||
withReader (fun c => { c with terminal := false}) x
|
||||
|
||||
-- TODO: polymorphic `Tactic.focus`
|
||||
abbrev focus (x : M α) : M α := fun ctx => Tactic.focus (x ctx)
|
||||
builtin_initialize tryTacticElabAttribute : KeyedDeclsAttribute TryTactic ← do
|
||||
unsafe mkElabAttribute TryTactic `builtin_try_tactic `try_tactic `Lean.Parser.Tactic `Lean.Elab.Tactic.Try.TryTactic "try_tactic"
|
||||
|
||||
def observing (x : M α) : M (TacticResult α) := do
|
||||
private def getEvalFns (kind : SyntaxNodeKind) : CoreM (List (KeyedDeclsAttribute.AttributeEntry TryTactic)) := do
|
||||
return tryTacticElabAttribute.getEntries (← getEnv) kind
|
||||
|
||||
-- TODO: polymorphic `Tactic.focus`
|
||||
abbrev focus (x : TryTacticM α) : TryTacticM α := fun ctx => Tactic.focus (x ctx)
|
||||
|
||||
def observing (x : TryTacticM α) : TryTacticM (TacticResult α) := do
|
||||
let s ← saveState
|
||||
try
|
||||
let e ← x
|
||||
|
|
@ -271,7 +278,7 @@ private def merge? (tac1 tac2 : TSyntax `tactic) : Option (TSyntax `tactic) :=
|
|||
else
|
||||
none
|
||||
|
||||
private def mergeAll? (tacs : Array (TSyntax `tactic)) : M (Option (TSyntax `tactic)) := do
|
||||
private def mergeAll? (tacs : Array (TSyntax `tactic)) : TryTacticM (Option (TSyntax `tactic)) := do
|
||||
if !(← read).config.merge || tacs.isEmpty then
|
||||
return none
|
||||
let tac0 := tacs[0]!
|
||||
|
|
@ -304,7 +311,7 @@ private def isOnlyAndNonOnly (tacs2 : Array (TSyntax `tactic)) : Bool := Id.run
|
|||
else
|
||||
return false
|
||||
|
||||
private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do
|
||||
private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : TryTacticM (TSyntax `tactic) := do
|
||||
let tacss2 := tacss2.map getSuggestionsCore
|
||||
if (← isTracingEnabledFor `try.debug) then
|
||||
trace[try.debug] "mkChainResultCore tac1{indentD tac1}"
|
||||
|
|
@ -343,7 +350,7 @@ private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tac
|
|||
(_, acc) ← go tacss2 0 [] none |>.run acc
|
||||
mkTrySuggestions acc
|
||||
where
|
||||
go (tacss2 : Array (Array (TSyntax `tactic))) (i : Nat) (acc : List (TSyntax `tactic)) (kind? : Option SyntaxNodeKind) : StateT (Array (TSyntax `tactic)) M Unit := do
|
||||
go (tacss2 : Array (Array (TSyntax `tactic))) (i : Nat) (acc : List (TSyntax `tactic)) (kind? : Option SyntaxNodeKind) : StateT (Array (TSyntax `tactic)) TryTacticM Unit := do
|
||||
if (← get).size > (← read).config.max then
|
||||
return ()
|
||||
else if h : i < tacss2.size then
|
||||
|
|
@ -371,7 +378,7 @@ where
|
|||
$tacs2*)
|
||||
modify (·.push tac)
|
||||
|
||||
private def evalSuggestGrindTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) := do
|
||||
private def evalSuggestGrindTrace : TryTactic := fun tac => do
|
||||
match tac with
|
||||
| `(tactic| grind? $configStx:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[on_failure $fallback?]?) =>
|
||||
let config ← elabGrindConfig configStx
|
||||
|
|
@ -386,7 +393,7 @@ private def evalSuggestGrindTrace (tac : TSyntax `tactic) : M (TSyntax `tactic)
|
|||
return tac
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
private def evalSuggestSimpTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) := do (← getMainGoal).withContext do
|
||||
private def evalSuggestSimpTrace : TryTactic := fun tac => do (← getMainGoal).withContext do
|
||||
match tac with
|
||||
| `(tactic| simp? $_:optConfig $[only%$only]? $[[$args,*]]? $(loc)?) =>
|
||||
let tac ← simpTraceToSimp tac
|
||||
|
|
@ -401,10 +408,10 @@ private def evalSuggestSimpTrace (tac : TSyntax `tactic) : M (TSyntax `tactic) :
|
|||
| _ => throwUnsupportedSyntax
|
||||
|
||||
@[extern "lean_eval_suggest_tactic"] -- forward definition to avoid mutual block
|
||||
opaque evalSuggest (tac : TSyntax `tactic) : M (TSyntax `tactic)
|
||||
opaque evalSuggest : TryTactic
|
||||
|
||||
/-- `evalSuggest` for `tac1 <;> tac2` -/
|
||||
private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : M (TSyntax `tactic) := focus do
|
||||
private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : TryTacticM (TSyntax `tactic) := focus do
|
||||
unless (← read).terminal do
|
||||
throwError "invalid `<;>` occurrence in non-terminal position for `try?` script{indentD (← read).root}"
|
||||
let tac1 ← withNonTerminal do evalSuggest tac1
|
||||
|
|
@ -422,7 +429,7 @@ private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : M (TSyntax `tactic)
|
|||
mkChainResult tac1 tac2s
|
||||
|
||||
/-- `evalSuggest` for a sequence of tactics. -/
|
||||
private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do
|
||||
private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : TryTacticM (TSyntax `tactic) := do
|
||||
if (← read).terminal then
|
||||
let mut result := #[]
|
||||
for i in [:tacs.size - 1] do
|
||||
|
|
@ -433,10 +440,10 @@ private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : M (TSyntax `tactic
|
|||
else
|
||||
mkSeq (← tacs.mapM evalSuggest) (terminal := false)
|
||||
|
||||
private def evalSuggestSeqCore (tacs : Array Syntax) : M (TSyntax `tactic) := do
|
||||
private def evalSuggestSeqCore (tacs : Array Syntax) : TryTacticM (TSyntax `tactic) := do
|
||||
evalSuggestSeq (tacs.map fun tac => ⟨tac⟩)
|
||||
|
||||
private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : M (TSyntax `tactic) := do
|
||||
private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : TryTacticM (TSyntax `tactic) := do
|
||||
let tacs ← match s with
|
||||
| `(tacticSeq| { $t;* }) => pure t.getElems
|
||||
| `(tacticSeq| $t;*) => pure t.getElems
|
||||
|
|
@ -444,30 +451,30 @@ private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : M (TS
|
|||
evalSuggestSeq tacs
|
||||
|
||||
/-- `evalSuggest` for `first` tactic. -/
|
||||
private partial def evalSuggestFirst (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : M (TSyntax `tactic) := do
|
||||
private partial def evalSuggestFirst (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TryTacticM (TSyntax `tactic) := do
|
||||
if tacs.size == 0 then
|
||||
throwError "`first` expects at least one argument"
|
||||
go 0
|
||||
where
|
||||
go (i : Nat) : M (TSyntax `tactic) := do
|
||||
go (i : Nat) : TryTacticM (TSyntax `tactic) := do
|
||||
if i = tacs.size - 1 then
|
||||
evalSuggestTacticSeq tacs[i]!
|
||||
else
|
||||
evalSuggestTacticSeq tacs[i]! <|> go (i+1)
|
||||
|
||||
/-- `evalSuggest` for `try` tactic. -/
|
||||
private partial def evalSuggestTry (tac : TSyntax ``Parser.Tactic.tacticSeq) : M (TSyntax `tactic) := do
|
||||
private partial def evalSuggestTry (tac : TSyntax ``Parser.Tactic.tacticSeq) : TryTacticM (TSyntax `tactic) := do
|
||||
(do evalSuggestTacticSeq tac)
|
||||
<|>
|
||||
`(tactic| skip)
|
||||
|
||||
/-- `evalSuggest` for `attempt_all` tactic. -/
|
||||
private partial def evalSuggestAttemptAll (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : M (TSyntax `tactic) := do
|
||||
private partial def evalSuggestAttemptAll (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TryTacticM (TSyntax `tactic) := do
|
||||
unless (← read).terminal do
|
||||
throwError "invalid occurrence of `attempt_all` in non-terminal position for `try?` script{indentD (← read).root}"
|
||||
go 0 none #[]
|
||||
where
|
||||
go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do
|
||||
go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : TryTacticM (TSyntax `tactic) := do
|
||||
-- Remark: we considered using `acc.size < (← read).config.max` here to truncate the search,
|
||||
-- but it had a negative effect when using `<;>`. We could miss a preferred solution `induction e <;> grind`
|
||||
-- because only a subset of the goals were solved by simpler tactics such as `rfl` and `simp`.
|
||||
|
|
@ -485,10 +492,45 @@ where
|
|||
else
|
||||
throwError "`attempt_all` failed"
|
||||
|
||||
private partial def evalSuggestDefault (tac : TSyntax `tactic) : TryTacticM (TSyntax `tactic) := do
|
||||
let kind := tac.raw.getKind
|
||||
match (← getEvalFns kind) with
|
||||
| [] => evalSuggestAtomic tac -- lift regular tactic
|
||||
| evalFns => eval (← Tactic.saveState) evalFns #[]
|
||||
where
|
||||
throwExs (failures : Array EvalTacticFailure) : TryTacticM (TSyntax `tactic) := do
|
||||
if h : 0 < failures.size then
|
||||
let fail := failures[failures.size - 1]
|
||||
fail.state.restore (restoreInfo := true)
|
||||
throw fail.exception
|
||||
else
|
||||
throwErrorAt tac "unexpected syntax {indentD tac}"
|
||||
|
||||
eval (s : SavedState) (evalFns : List _) (failures : Array EvalTacticFailure) : TryTacticM (TSyntax `tactic) := do
|
||||
match evalFns with
|
||||
| [] => throwExs failures
|
||||
| evalFn::evalFns =>
|
||||
try
|
||||
withTheReader Tactic.Context ({ · with elaborator := evalFn.declName }) do
|
||||
evalFn.value tac
|
||||
catch ex => match ex with
|
||||
| .error .. =>
|
||||
let failures := failures.push ⟨ex, ← Tactic.saveState⟩
|
||||
s.restore (restoreInfo := true); eval s evalFns failures
|
||||
| .internal id _ =>
|
||||
if id == unsupportedSyntaxExceptionId then
|
||||
s.restore (restoreInfo := true); eval s evalFns failures
|
||||
else if id == abortTacticExceptionId then
|
||||
let failures := failures.push ⟨ex, ← Tactic.saveState⟩
|
||||
s.restore (restoreInfo := true); eval s evalFns failures
|
||||
else
|
||||
throw ex
|
||||
|
||||
-- `evalSuggest` implementation
|
||||
@[export lean_eval_suggest_tactic]
|
||||
private partial def evalSuggestImpl (tac : TSyntax `tactic) : M (TSyntax `tactic) := do
|
||||
private partial def evalSuggestImpl : TryTactic := fun tac => do
|
||||
trace[try.debug] "{tac}"
|
||||
-- TODO: Implement builtin cases using `[builtin_try_tactic]` after update-stage0
|
||||
match tac with
|
||||
| `(tactic| $tac1 <;> $tac2) => evalSuggestChain tac1 tac2
|
||||
| `(tactic| first $[| $tacs]*) => evalSuggestFirst tacs
|
||||
|
|
@ -507,7 +549,7 @@ private partial def evalSuggestImpl (tac : TSyntax `tactic) : M (TSyntax `tactic
|
|||
else if k == ``Parser.Tactic.exact? then
|
||||
evalSuggestExact
|
||||
else
|
||||
evalSuggestAtomic tac
|
||||
evalSuggestDefault tac
|
||||
if (← read).terminal then
|
||||
unless (← getGoals).isEmpty do
|
||||
throwError "unsolved goals"
|
||||
|
|
|
|||
31
tests/lean/run/grind_try_extend.lean
Normal file
31
tests/lean/run/grind_try_extend.lean
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import Lean
|
||||
|
||||
open Lean Meta Elab Tactic Try
|
||||
|
||||
-- Install a `TryTactic` handler for `assumption`
|
||||
@[try_tactic assumption]
|
||||
def evalTryApply : TryTactic := fun tac => do
|
||||
-- We just use the default implementation, but return a different tactic.
|
||||
evalAssumption tac
|
||||
`(tactic| (trace "worked"; assumption))
|
||||
|
||||
/-- info: Try this: · trace "worked"; assumption -/
|
||||
#guard_msgs (info) in
|
||||
example (h : False) : False := by
|
||||
try? (max := 1) -- at most one solution
|
||||
|
||||
-- `try?` uses `evalAndSuggest` the attribute `[try_tactic]` is used to extend `evalAndSuggest`.
|
||||
-- Let's define our own `try?` that uses `evalAndSuggest`
|
||||
elab stx:"my_try?" : tactic => do
|
||||
-- Things to try
|
||||
let toTry ← `(tactic| attempt_all | assumption | apply True | rfl)
|
||||
evalAndSuggest stx toTry
|
||||
|
||||
/--
|
||||
info: Try these:
|
||||
• · trace "worked"; assumption
|
||||
• rfl
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
example (a : Nat) (h : a = a) : a = a := by
|
||||
my_try?
|
||||
Loading…
Add table
Reference in a new issue