From eab09084a3380665fff6004672ffdba6d2d139b3 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 6 Feb 2025 13:56:14 -0800 Subject: [PATCH] feat: `try?` composite suggestions (#6979) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds support for more complex suggestions in `try?`. Example: ```lean example (as : List α) (a : α) : concat as a = as ++ [a] := by try? ``` suggestion ``` Try this: · induction as, a using concat.induct · rfl · simp_all ``` --- src/Init/Try.lean | 5 +- src/Lean/Elab/Tactic/Try.lean | 68 ++++++++++++++++++--------- src/Lean/Meta/Tactic/Try.lean | 1 + src/Lean/Meta/Tactic/Try/Collect.lean | 26 ++++++++-- tests/lean/run/grind_constProp.lean | 25 +++++++++- tests/lean/run/try_trace1.lean | 17 +++++++ 6 files changed, 113 insertions(+), 29 deletions(-) diff --git a/src/Init/Try.lean b/src/Init/Try.lean index a7cf348367..8f84a3c25b 100644 --- a/src/Init/Try.lean +++ b/src/Init/Try.lean @@ -27,7 +27,10 @@ namespace Lean.Parser.Tactic syntax (name := tryTrace) "try?" optConfig : tactic -/-- Helper tactic for implementing the tactic `try?`. -/ +/-- Helper internal tactic for implementing the tactic `try?`. -/ syntax (name := attemptAll) "attempt_all " withPosition((ppDedent(ppLine) colGe "| " tacticSeq)+) : tactic +/-- Helper internal tactic used to implement `evalSuggest` in `try?` -/ +syntax (name := tryResult) "try_suggestions " tactic* : tactic + end Lean.Parser.Tactic diff --git a/src/Lean/Elab/Tactic/Try.lean b/src/Lean/Elab/Tactic/Try.lean index 35a806948c..9af3584378 100644 --- a/src/Lean/Elab/Tactic/Try.lean +++ b/src/Lean/Elab/Tactic/Try.lean @@ -13,11 +13,6 @@ import Lean.Elab.Tactic.Config import Lean.Elab.Tactic.SimpTrace import Lean.Elab.Tactic.Grind -namespace Lean.Parser.Tactic -/-- Internal tactic used to implement `evalSuggest` -/ -syntax (name := tryResult) "try_suggestions " tactic* : tactic -end Lean.Parser.Tactic - namespace Lean.Elab.Tactic open Meta /-! @@ -52,7 +47,7 @@ private def appendSeqResult (suggestionSeqs : Array (Array (TSyntax `tactic))) ( /-- Returns a tactic representing all given suggestions `tacs`. -/ private def mkTrySuggestions (tacs : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do if tacs.isEmpty then - throwError "`mkSuggestions` failed" + throwError "`mkTrySuggestions` failed" else if tacs.size == 1 then return tacs[0]! else @@ -130,16 +125,38 @@ private def getKindsSolvedAll (tacss : Array (Array (TSyntax `tactic))) : Array r := r.push k return r -private def mkChainResultCore (tac1 : TSyntax `tactic) (tacs2 : Array (TSyntax `tactic)) : TacticM (Array (TSyntax `tactic)) := do - let tacs2 := tacs2.map getSuggestionsCore +private def peekOne (tac1 : TSyntax `tactic) (tacss2 : Array (Array (TSyntax `tactic))) : TacticM (TSyntax `tactic) := do + let mut tacs2 := #[] + for s in tacss2 do + if s.isEmpty then + tacs2 := tacs2.push (← `(tactic| · sorry)) + else + tacs2 := tacs2.push (← `(tactic| · $(s[0]!):tactic)) + `(tactic| · $tac1:tactic + $tacs2*) + +private def mkChainResultCore (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : TacticM (Array (TSyntax `tactic)) := do + let tacss2 := tacss2.map getSuggestionsCore + if (← isTracingEnabledFor `try.debug) then + trace[try.debug] "mkChainResultCore tac1{indentD tac1}" + let mut i : Nat := 0 + for tacs2 in tacss2 do + i := i + 1 + trace[try.debug] "goal #{i} tactics" + for tac2 in tacs2 do + trace[try.debug] " {tac2}" + trace[try.debug] "mkChainResult -----" let mut acc := #[] - let solvedAll := getTacsSolvedAll tacs2 + let solvedAll := getTacsSolvedAll tacss2 for tac2 in solvedAll do acc := acc.push (← `(tactic| $tac1 <;> $tac2)) - let tacs2 := eraseTacs tacs2 solvedAll + let tacss2 := eraseTacs tacss2 solvedAll -- TODO: mixed cases - trace[Meta.debug] "CHAIN tacs2: {tacs2}" - trace[Meta.debug] "CHAIN kinds: {getKindsSolvedAll tacs2}" + trace[try.debug] "kinds: {getKindsSolvedAll tacss2}" + if (!acc.isEmpty && tacss2.all fun s => !s.isEmpty) + -- We only include partial solutions if there are no other solutions. + || (acc.isEmpty && tacss2.any fun s => !s.isEmpty) then + acc := acc.push <| (← peekOne tac1 tacss2) return acc private def mkChainResult (tac1 : TSyntax `tactic) (tacs2 : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do @@ -178,6 +195,7 @@ private def evalSuggestGrindTrace (tac : TSyntax `tactic) : TacticM (TSyntax `ta let trace ← evalGrindCore tac config only params fallback? let tac ← grindTraceToGrind tac let tac' ← mkGrindOnly configStx fallback? trace + trace[try.debug] "`grind` succeeded" mkTrySuggestions #[tac, tac'] | _ => throwUnsupportedSyntax @@ -188,6 +206,7 @@ private def evalSuggestSimpTrace (tac : TSyntax `tactic) : TacticM (TSyntax `tac let { ctx, simprocs, .. } ← mkSimpContext tac (eraseLocal := false) let stats ← simpLocation ctx (simprocs := simprocs) none <| (loc.map expandLocation).getD (.targets #[] true) let tac' ← mkSimpCallStx tac stats.usedTheorems + trace[try.debug] "`simp` succeeded" mkTrySuggestions #[tac, tac'] | _ => throwUnsupportedSyntax @@ -215,11 +234,14 @@ private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : TacticM (TSyntax `t let goals ← getGoals setGoals [] let mut tac2s := #[] + let mut i : Nat := 0 for goal in goals do setGoals [goal] - let tac2' ← (evalSuggest tac2) <|> `(tactic| sorry) + let tac2' : TSyntax `tactic ← (evalSuggest tac2) <|> `(tactic| sorry) + i := i + 1 + trace[try.debug] "`<;>` goal #{i}, tactic{indentD tac2'}" unless (← getGoals).isEmpty do - throwError "unsolved goals, `<;>` in `try?` requires all goals to be solved" + throwError "unsolved goals, `<;>` in `try?` requires all goals to be solved{indentD tac2}\n{goalsToMessageData (← getGoals)}" tac2s := tac2s.push tac2' if tac2s.all isSorry then throwError "`<;>` failed" @@ -269,8 +291,11 @@ where go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do if i < tacs.size then match (← observing (evalSuggestTacticSeq tacs[i]!)) with - | .ok tac s => go (i+1) (saved? <|> some s) (appendSuggestion acc tac) - | _ => go (i+1) saved? acc + | .ok tac s => + trace[try.debug] "`attempt_all` argument succeeded{indentD tac}" + go (i+1) (saved? <|> some s) (appendSuggestion acc tac) + | _ => + go (i+1) saved? acc else if let some saved := saved? then saved.restore @@ -281,6 +306,7 @@ where -- `evalSuggest` implementation @[export lean_eval_suggest_tactic] private partial def evalSuggestImpl (tac : TSyntax `tactic) : TacticM (TSyntax `tactic) := do + trace[try.debug] "{tac}" match tac with | `(tactic| $tac1 <;> $tac2) => evalSuggestChain tac1 tac2 | `(tactic| first $[| $tacs]*) => evalSuggestFirst tacs @@ -343,17 +369,17 @@ private def setGrindParams (tac : TSyntax `tactic) (params : Array (TSyntax ``Pa ⟨tac.raw.setArg 3 (mkNullNode paramsStx)⟩ /-- Given a set of declaration names, returns `grind` parameters of the form `= ` -/ -private def mkGrindEqnParams (declNames : Std.HashSet Name) : MetaM (Array (TSyntax ``Parser.Tactic.grindParam)) := do - declNames.toArray.mapM fun declName => do +private def mkGrindEqnParams (declNames : Array Name) : MetaM (Array (TSyntax ``Parser.Tactic.grindParam)) := do + declNames.mapM fun declName => do `(Parser.Tactic.grindParam| = $(← toIdent declName)) private def mkGrindStx (info : Try.Info) : MetaM (TSyntax `tactic) := do let grind ← `(tactic| grind?) let mut tacs := #[grind] unless info.eqnCandidates.isEmpty do - tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.eqnCandidates)) + tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.eqnCandidates.elems)) unless info.unfoldCandidates.isEmpty do - tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.unfoldCandidates)) + tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.unfoldCandidates.elems)) mkFirstStx tacs /-! Other generators -/ @@ -400,7 +426,7 @@ where `(tactic| induction $terms,* using $indFn <;> $cont) private def mkAllFunIndStx (info : Try.Info) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do - let tacs ← info.funIndCandidates.toArray.mapM (mkFunIndStx · cont) + let tacs ← info.funIndCandidates.elems.mapM (mkFunIndStx · cont) mkFirstStx tacs /-! Main code -/ diff --git a/src/Lean/Meta/Tactic/Try.lean b/src/Lean/Meta/Tactic/Try.lean index 18b0fbb463..fb34f3b20d 100644 --- a/src/Lean/Meta/Tactic/Try.lean +++ b/src/Lean/Meta/Tactic/Try.lean @@ -12,6 +12,7 @@ builtin_initialize registerTraceClass `try builtin_initialize registerTraceClass `try.collect builtin_initialize registerTraceClass `try.collect.funInd +builtin_initialize registerTraceClass `try.debug builtin_initialize registerTraceClass `try.debug.funInd end Lean diff --git a/src/Lean/Meta/Tactic/Try/Collect.lean b/src/Lean/Meta/Tactic/Try/Collect.lean index 2750721d86..38f5a496dd 100644 --- a/src/Lean/Meta/Tactic/Try/Collect.lean +++ b/src/Lean/Meta/Tactic/Try/Collect.lean @@ -21,19 +21,35 @@ structure FunIndCandidate where majors : Array FVarId deriving Hashable, BEq +/-- `Set` with insertion order preserved. -/ +structure OrdSet (α : Type) [Hashable α] [BEq α] where + elems : Array α := #[] + set : Std.HashSet α := {} + deriving Inhabited + +def OrdSet.insert {_ : Hashable α} {_ : BEq α} (s : OrdSet α) (a : α) : OrdSet α := + if s.set.contains a then + s + else + let { elems, set } := s + { elems := elems.push a, set := set.insert a } + +def OrdSet.isEmpty {_ : Hashable α} {_ : BEq α} (s : OrdSet α) : Bool := + s.elems.isEmpty + structure Result where /-- All constant symbols occurring in the gal. -/ - allConsts : Std.HashSet Name := {} + allConsts : OrdSet Name := {} /-- Unfolding candiates. -/ - unfoldCandidates : Std.HashSet Name := {} + unfoldCandidates : OrdSet Name := {} /-- Equation function candiates. -/ - eqnCandidates : Std.HashSet Name := {} + eqnCandidates : OrdSet Name := {} /-- Function induction candidates. -/ - funIndCandidates : Std.HashSet FunIndCandidate := {} + funIndCandidates : OrdSet FunIndCandidate := {} /-- Induction candidates. -/ indCandidates : Array InductionCandidate := #[] /-- Relevant declarations by `libSearch` -/ - libSearchResults : Std.HashSet (Name × Grind.EMatchTheoremKind) := {} + libSearchResults : OrdSet (Name × Grind.EMatchTheoremKind) := {} structure Context where config : Try.Config diff --git a/tests/lean/run/grind_constProp.lean b/tests/lean/run/grind_constProp.lean index 5c888fa359..8d6cb4abfb 100644 --- a/tests/lean/run/grind_constProp.lean +++ b/tests/lean/run/grind_constProp.lean @@ -205,7 +205,15 @@ def evalExpr (e : Expr) : EvalM Val := do @[grind] theorem UnaryOp.simplify_eval (op : UnaryOp) : (op.simplify a).eval σ = (Expr.una op a).eval σ := by grind [UnaryOp.simplify.eq_def] -/-- info: Try this: (induction e using Expr.simplify.induct) <;> grind -/ +/-- +info: Try these: +• (induction e using Expr.simplify.induct) <;> grind +• · + induction e using Expr.simplify.induct + · grind only [Expr.simplify, BinOp.simplify, Expr.eval, BinaryOp.simplify_eval] + · grind only [UnaryOp.simplify_eval, UnaryOp.simplify, Expr.simplify, Expr.eval] + · simp +-/ #guard_msgs (info) in example (e : Expr) : e.simplify.eval σ = e.eval σ := by try? @@ -304,7 +312,20 @@ theorem State.cons_le_of_eq (h₁ : σ' ≼ σ) (h₂ : σ.find? x = some v) : ( @[grind] theorem State.join_le_left_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₁.join σ₃ ≼ σ₂ := by grind -/-- info: Try this: (induction σ₁, σ₂ using State.join.induct) <;> grind -/ +/-- +info: Try these: +• (induction σ₁, σ₂ using State.join.induct) <;> grind +• · + induction σ₁, σ₂ using State.join.induct + · + grind only [State.join_le_left, State.find?, State.join, State.join_le_left_of, State.le, = State.find?_nil, + State.bot_le, State.le_refl] + · + grind only [State.join, State.join_le_left, State.length_erase_le, State.find?, State.join_le_left_of, State.le, = + State.find?_erase_eq, State.erase_le, State.le_refl, cases Or] + · grind only [State.join, State.join_le_left, State.length_erase_le, State.join_le_left_of, State.le, State.erase_le] + · grind only [State.join, State.join_le_left, State.length_erase_le, State.join_le_left_of, State.le, State.erase_le] +-/ #guard_msgs (info) in example (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by try? diff --git a/tests/lean/run/try_trace1.lean b/tests/lean/run/try_trace1.lean index 09d84f2253..fba4395a4a 100644 --- a/tests/lean/run/try_trace1.lean +++ b/tests/lean/run/try_trace1.lean @@ -1,4 +1,5 @@ set_option grind.warning false +%reset_grind_attrs /-- info: Try these: @@ -97,3 +98,19 @@ example : app (app as bs) cs = app as (app bs cs) := by intro _ _ _ -- `as`, `bs`, and `cs` now have inaccessible names. try? + +def concat : List α → α → List α + | .nil, b => .cons b .nil + | .cons a as, b => .cons a (concat as b) + +attribute [simp] concat + +/-- +info: Try this: · + induction as, a using concat.induct + · rfl + · simp_all +-/ +#guard_msgs (info) in +example (as : List α) (a : α) : concat as a = as ++ [a] := by + try?