feat: try? composite suggestions (#6979)
This PR adds support for more complex suggestions in `try?`. Example: ```lean example (as : List α) (a : α) : concat as a = as ++ [a] := by try? ``` suggestion ``` Try this: · induction as, a using concat.induct · rfl · simp_all ```
This commit is contained in:
parent
45d39422bc
commit
eab09084a3
6 changed files with 113 additions and 29 deletions
|
|
@ -27,7 +27,10 @@ namespace Lean.Parser.Tactic
|
|||
|
||||
syntax (name := tryTrace) "try?" optConfig : tactic
|
||||
|
||||
/-- Helper tactic for implementing the tactic `try?`. -/
|
||||
/-- Helper internal tactic for implementing the tactic `try?`. -/
|
||||
syntax (name := attemptAll) "attempt_all " withPosition((ppDedent(ppLine) colGe "| " tacticSeq)+) : tactic
|
||||
|
||||
/-- Helper internal tactic used to implement `evalSuggest` in `try?` -/
|
||||
syntax (name := tryResult) "try_suggestions " tactic* : tactic
|
||||
|
||||
end Lean.Parser.Tactic
|
||||
|
|
|
|||
|
|
@ -13,11 +13,6 @@ import Lean.Elab.Tactic.Config
|
|||
import Lean.Elab.Tactic.SimpTrace
|
||||
import Lean.Elab.Tactic.Grind
|
||||
|
||||
namespace Lean.Parser.Tactic
|
||||
/-- Internal tactic used to implement `evalSuggest` -/
|
||||
syntax (name := tryResult) "try_suggestions " tactic* : tactic
|
||||
end Lean.Parser.Tactic
|
||||
|
||||
namespace Lean.Elab.Tactic
|
||||
open Meta
|
||||
/-!
|
||||
|
|
@ -52,7 +47,7 @@ private def appendSeqResult (suggestionSeqs : Array (Array (TSyntax `tactic))) (
|
|||
/-- Returns a tactic representing all given suggestions `tacs`. -/
|
||||
private def mkTrySuggestions (tacs : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
|
||||
if tacs.isEmpty then
|
||||
throwError "`mkSuggestions` failed"
|
||||
throwError "`mkTrySuggestions` failed"
|
||||
else if tacs.size == 1 then
|
||||
return tacs[0]!
|
||||
else
|
||||
|
|
@ -130,16 +125,38 @@ private def getKindsSolvedAll (tacss : Array (Array (TSyntax `tactic))) : Array
|
|||
r := r.push k
|
||||
return r
|
||||
|
||||
private def mkChainResultCore (tac1 : TSyntax `tactic) (tacs2 : Array (TSyntax `tactic)) : TacticM (Array (TSyntax `tactic)) := do
|
||||
let tacs2 := tacs2.map getSuggestionsCore
|
||||
private def peekOne (tac1 : TSyntax `tactic) (tacss2 : Array (Array (TSyntax `tactic))) : TacticM (TSyntax `tactic) := do
|
||||
let mut tacs2 := #[]
|
||||
for s in tacss2 do
|
||||
if s.isEmpty then
|
||||
tacs2 := tacs2.push (← `(tactic| · sorry))
|
||||
else
|
||||
tacs2 := tacs2.push (← `(tactic| · $(s[0]!):tactic))
|
||||
`(tactic| · $tac1:tactic
|
||||
$tacs2*)
|
||||
|
||||
private def mkChainResultCore (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : TacticM (Array (TSyntax `tactic)) := do
|
||||
let tacss2 := tacss2.map getSuggestionsCore
|
||||
if (← isTracingEnabledFor `try.debug) then
|
||||
trace[try.debug] "mkChainResultCore tac1{indentD tac1}"
|
||||
let mut i : Nat := 0
|
||||
for tacs2 in tacss2 do
|
||||
i := i + 1
|
||||
trace[try.debug] "goal #{i} tactics"
|
||||
for tac2 in tacs2 do
|
||||
trace[try.debug] " {tac2}"
|
||||
trace[try.debug] "mkChainResult -----"
|
||||
let mut acc := #[]
|
||||
let solvedAll := getTacsSolvedAll tacs2
|
||||
let solvedAll := getTacsSolvedAll tacss2
|
||||
for tac2 in solvedAll do
|
||||
acc := acc.push (← `(tactic| $tac1 <;> $tac2))
|
||||
let tacs2 := eraseTacs tacs2 solvedAll
|
||||
let tacss2 := eraseTacs tacss2 solvedAll
|
||||
-- TODO: mixed cases
|
||||
trace[Meta.debug] "CHAIN tacs2: {tacs2}"
|
||||
trace[Meta.debug] "CHAIN kinds: {getKindsSolvedAll tacs2}"
|
||||
trace[try.debug] "kinds: {getKindsSolvedAll tacss2}"
|
||||
if (!acc.isEmpty && tacss2.all fun s => !s.isEmpty)
|
||||
-- We only include partial solutions if there are no other solutions.
|
||||
|| (acc.isEmpty && tacss2.any fun s => !s.isEmpty) then
|
||||
acc := acc.push <| (← peekOne tac1 tacss2)
|
||||
return acc
|
||||
|
||||
private def mkChainResult (tac1 : TSyntax `tactic) (tacs2 : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
|
||||
|
|
@ -178,6 +195,7 @@ private def evalSuggestGrindTrace (tac : TSyntax `tactic) : TacticM (TSyntax `ta
|
|||
let trace ← evalGrindCore tac config only params fallback?
|
||||
let tac ← grindTraceToGrind tac
|
||||
let tac' ← mkGrindOnly configStx fallback? trace
|
||||
trace[try.debug] "`grind` succeeded"
|
||||
mkTrySuggestions #[tac, tac']
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
|
|
@ -188,6 +206,7 @@ private def evalSuggestSimpTrace (tac : TSyntax `tactic) : TacticM (TSyntax `tac
|
|||
let { ctx, simprocs, .. } ← mkSimpContext tac (eraseLocal := false)
|
||||
let stats ← simpLocation ctx (simprocs := simprocs) none <| (loc.map expandLocation).getD (.targets #[] true)
|
||||
let tac' ← mkSimpCallStx tac stats.usedTheorems
|
||||
trace[try.debug] "`simp` succeeded"
|
||||
mkTrySuggestions #[tac, tac']
|
||||
| _ => throwUnsupportedSyntax
|
||||
|
||||
|
|
@ -215,11 +234,14 @@ private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : TacticM (TSyntax `t
|
|||
let goals ← getGoals
|
||||
setGoals []
|
||||
let mut tac2s := #[]
|
||||
let mut i : Nat := 0
|
||||
for goal in goals do
|
||||
setGoals [goal]
|
||||
let tac2' ← (evalSuggest tac2) <|> `(tactic| sorry)
|
||||
let tac2' : TSyntax `tactic ← (evalSuggest tac2) <|> `(tactic| sorry)
|
||||
i := i + 1
|
||||
trace[try.debug] "`<;>` goal #{i}, tactic{indentD tac2'}"
|
||||
unless (← getGoals).isEmpty do
|
||||
throwError "unsolved goals, `<;>` in `try?` requires all goals to be solved"
|
||||
throwError "unsolved goals, `<;>` in `try?` requires all goals to be solved{indentD tac2}\n{goalsToMessageData (← getGoals)}"
|
||||
tac2s := tac2s.push tac2'
|
||||
if tac2s.all isSorry then
|
||||
throwError "`<;>` failed"
|
||||
|
|
@ -269,8 +291,11 @@ where
|
|||
go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
|
||||
if i < tacs.size then
|
||||
match (← observing (evalSuggestTacticSeq tacs[i]!)) with
|
||||
| .ok tac s => go (i+1) (saved? <|> some s) (appendSuggestion acc tac)
|
||||
| _ => go (i+1) saved? acc
|
||||
| .ok tac s =>
|
||||
trace[try.debug] "`attempt_all` argument succeeded{indentD tac}"
|
||||
go (i+1) (saved? <|> some s) (appendSuggestion acc tac)
|
||||
| _ =>
|
||||
go (i+1) saved? acc
|
||||
else
|
||||
if let some saved := saved? then
|
||||
saved.restore
|
||||
|
|
@ -281,6 +306,7 @@ where
|
|||
-- `evalSuggest` implementation
|
||||
@[export lean_eval_suggest_tactic]
|
||||
private partial def evalSuggestImpl (tac : TSyntax `tactic) : TacticM (TSyntax `tactic) := do
|
||||
trace[try.debug] "{tac}"
|
||||
match tac with
|
||||
| `(tactic| $tac1 <;> $tac2) => evalSuggestChain tac1 tac2
|
||||
| `(tactic| first $[| $tacs]*) => evalSuggestFirst tacs
|
||||
|
|
@ -343,17 +369,17 @@ private def setGrindParams (tac : TSyntax `tactic) (params : Array (TSyntax ``Pa
|
|||
⟨tac.raw.setArg 3 (mkNullNode paramsStx)⟩
|
||||
|
||||
/-- Given a set of declaration names, returns `grind` parameters of the form `= <declName>` -/
|
||||
private def mkGrindEqnParams (declNames : Std.HashSet Name) : MetaM (Array (TSyntax ``Parser.Tactic.grindParam)) := do
|
||||
declNames.toArray.mapM fun declName => do
|
||||
private def mkGrindEqnParams (declNames : Array Name) : MetaM (Array (TSyntax ``Parser.Tactic.grindParam)) := do
|
||||
declNames.mapM fun declName => do
|
||||
`(Parser.Tactic.grindParam| = $(← toIdent declName))
|
||||
|
||||
private def mkGrindStx (info : Try.Info) : MetaM (TSyntax `tactic) := do
|
||||
let grind ← `(tactic| grind?)
|
||||
let mut tacs := #[grind]
|
||||
unless info.eqnCandidates.isEmpty do
|
||||
tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.eqnCandidates))
|
||||
tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.eqnCandidates.elems))
|
||||
unless info.unfoldCandidates.isEmpty do
|
||||
tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.unfoldCandidates))
|
||||
tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.unfoldCandidates.elems))
|
||||
mkFirstStx tacs
|
||||
|
||||
/-! Other generators -/
|
||||
|
|
@ -400,7 +426,7 @@ where
|
|||
`(tactic| induction $terms,* using $indFn <;> $cont)
|
||||
|
||||
private def mkAllFunIndStx (info : Try.Info) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do
|
||||
let tacs ← info.funIndCandidates.toArray.mapM (mkFunIndStx · cont)
|
||||
let tacs ← info.funIndCandidates.elems.mapM (mkFunIndStx · cont)
|
||||
mkFirstStx tacs
|
||||
|
||||
/-! Main code -/
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ builtin_initialize registerTraceClass `try
|
|||
builtin_initialize registerTraceClass `try.collect
|
||||
builtin_initialize registerTraceClass `try.collect.funInd
|
||||
|
||||
builtin_initialize registerTraceClass `try.debug
|
||||
builtin_initialize registerTraceClass `try.debug.funInd
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -21,19 +21,35 @@ structure FunIndCandidate where
|
|||
majors : Array FVarId
|
||||
deriving Hashable, BEq
|
||||
|
||||
/-- `Set` with insertion order preserved. -/
|
||||
structure OrdSet (α : Type) [Hashable α] [BEq α] where
|
||||
elems : Array α := #[]
|
||||
set : Std.HashSet α := {}
|
||||
deriving Inhabited
|
||||
|
||||
def OrdSet.insert {_ : Hashable α} {_ : BEq α} (s : OrdSet α) (a : α) : OrdSet α :=
|
||||
if s.set.contains a then
|
||||
s
|
||||
else
|
||||
let { elems, set } := s
|
||||
{ elems := elems.push a, set := set.insert a }
|
||||
|
||||
def OrdSet.isEmpty {_ : Hashable α} {_ : BEq α} (s : OrdSet α) : Bool :=
|
||||
s.elems.isEmpty
|
||||
|
||||
structure Result where
|
||||
/-- All constant symbols occurring in the gal. -/
|
||||
allConsts : Std.HashSet Name := {}
|
||||
allConsts : OrdSet Name := {}
|
||||
/-- Unfolding candiates. -/
|
||||
unfoldCandidates : Std.HashSet Name := {}
|
||||
unfoldCandidates : OrdSet Name := {}
|
||||
/-- Equation function candiates. -/
|
||||
eqnCandidates : Std.HashSet Name := {}
|
||||
eqnCandidates : OrdSet Name := {}
|
||||
/-- Function induction candidates. -/
|
||||
funIndCandidates : Std.HashSet FunIndCandidate := {}
|
||||
funIndCandidates : OrdSet FunIndCandidate := {}
|
||||
/-- Induction candidates. -/
|
||||
indCandidates : Array InductionCandidate := #[]
|
||||
/-- Relevant declarations by `libSearch` -/
|
||||
libSearchResults : Std.HashSet (Name × Grind.EMatchTheoremKind) := {}
|
||||
libSearchResults : OrdSet (Name × Grind.EMatchTheoremKind) := {}
|
||||
|
||||
structure Context where
|
||||
config : Try.Config
|
||||
|
|
|
|||
|
|
@ -205,7 +205,15 @@ def evalExpr (e : Expr) : EvalM Val := do
|
|||
@[grind] theorem UnaryOp.simplify_eval (op : UnaryOp) : (op.simplify a).eval σ = (Expr.una op a).eval σ := by
|
||||
grind [UnaryOp.simplify.eq_def]
|
||||
|
||||
/-- info: Try this: (induction e using Expr.simplify.induct) <;> grind -/
|
||||
/--
|
||||
info: Try these:
|
||||
• (induction e using Expr.simplify.induct) <;> grind
|
||||
• ·
|
||||
induction e using Expr.simplify.induct
|
||||
· grind only [Expr.simplify, BinOp.simplify, Expr.eval, BinaryOp.simplify_eval]
|
||||
· grind only [UnaryOp.simplify_eval, UnaryOp.simplify, Expr.simplify, Expr.eval]
|
||||
· simp
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
example (e : Expr) : e.simplify.eval σ = e.eval σ := by
|
||||
try?
|
||||
|
|
@ -304,7 +312,20 @@ theorem State.cons_le_of_eq (h₁ : σ' ≼ σ) (h₂ : σ.find? x = some v) : (
|
|||
@[grind] theorem State.join_le_left_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₁.join σ₃ ≼ σ₂ := by
|
||||
grind
|
||||
|
||||
/-- info: Try this: (induction σ₁, σ₂ using State.join.induct) <;> grind -/
|
||||
/--
|
||||
info: Try these:
|
||||
• (induction σ₁, σ₂ using State.join.induct) <;> grind
|
||||
• ·
|
||||
induction σ₁, σ₂ using State.join.induct
|
||||
·
|
||||
grind only [State.join_le_left, State.find?, State.join, State.join_le_left_of, State.le, = State.find?_nil,
|
||||
State.bot_le, State.le_refl]
|
||||
·
|
||||
grind only [State.join, State.join_le_left, State.length_erase_le, State.find?, State.join_le_left_of, State.le, =
|
||||
State.find?_erase_eq, State.erase_le, State.le_refl, cases Or]
|
||||
· grind only [State.join, State.join_le_left, State.length_erase_le, State.join_le_left_of, State.le, State.erase_le]
|
||||
· grind only [State.join, State.join_le_left, State.length_erase_le, State.join_le_left_of, State.le, State.erase_le]
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
example (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by
|
||||
try?
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
set_option grind.warning false
|
||||
%reset_grind_attrs
|
||||
|
||||
/--
|
||||
info: Try these:
|
||||
|
|
@ -97,3 +98,19 @@ example : app (app as bs) cs = app as (app bs cs) := by
|
|||
intro _ _ _
|
||||
-- `as`, `bs`, and `cs` now have inaccessible names.
|
||||
try?
|
||||
|
||||
def concat : List α → α → List α
|
||||
| .nil, b => .cons b .nil
|
||||
| .cons a as, b => .cons a (concat as b)
|
||||
|
||||
attribute [simp] concat
|
||||
|
||||
/--
|
||||
info: Try this: ·
|
||||
induction as, a using concat.induct
|
||||
· rfl
|
||||
· simp_all
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
example (as : List α) (a : α) : concat as a = as ++ [a] := by
|
||||
try?
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue