From 9e1d97c261249af61ca60b42ead8d19633466321 Mon Sep 17 00:00:00 2001 From: Sebastian Graf Date: Fri, 15 Aug 2025 14:25:01 +0200 Subject: [PATCH] feat: extended `using invariants` and `with` syntax for `mvcgen` (#9927) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements extended `induction`-inspired syntax for `mvcgen`, allowing optional `using invariants` and `with` sections. ```lean mvcgen using invariants | 1 => Invariant.withEarlyReturn (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) (onContinue := fun traversalState seen => ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) with mleave -- mleave is a no-op here, but we are just testing the grammar | vc1 => grind | vc2 => grind | vc3 => grind | vc4 => grind | vc5 => grind ``` --- src/Lean/Elab/Tactic/Do/VCGen.lean | 102 ++++++++++++--- src/Std/Tactic/Do/Syntax.lean | 41 +++++-- tests/lean/run/mvcgenUsingWith.lean | 184 ++++++++++++++++++++++++++++ 3 files changed, 301 insertions(+), 26 deletions(-) create mode 100644 tests/lean/run/mvcgenUsingWith.lean diff --git a/src/Lean/Elab/Tactic/Do/VCGen.lean b/src/Lean/Elab/Tactic/Do/VCGen.lean index 961d8837ac..ec5338d29f 100644 --- a/src/Lean/Elab/Tactic/Do/VCGen.lean +++ b/src/Lean/Elab/Tactic/Do/VCGen.lean @@ -6,21 +6,23 @@ Authors: Sebastian Graf module prelude -public import Std.Do.WP -public import Std.Do.Triple -public import Lean.Elab.Tactic.Simp -public import Lean.Elab.Tactic.Do.ProofMode.Basic -public import Lean.Elab.Tactic.Do.ProofMode.Intro -public import Lean.Elab.Tactic.Do.ProofMode.Revert -public import Lean.Elab.Tactic.Do.ProofMode.Cases -public import Lean.Elab.Tactic.Do.ProofMode.Specialize -public import Lean.Elab.Tactic.Do.ProofMode.Pure -public import Lean.Elab.Tactic.Do.LetElim -public import Lean.Elab.Tactic.Do.Spec -public import Lean.Elab.Tactic.Do.Attr -public import Lean.Elab.Tactic.Do.Syntax +import Std.Do.WP +import Std.Do.Triple +import Lean.Elab.Tactic.Do.VCGen.Split +import Lean.Elab.Tactic.Simp +import Lean.Elab.Tactic.Do.ProofMode.Basic +import Lean.Elab.Tactic.Do.ProofMode.Intro +import Lean.Elab.Tactic.Do.ProofMode.Revert +import Lean.Elab.Tactic.Do.ProofMode.Cases +import Lean.Elab.Tactic.Do.ProofMode.Specialize +import Lean.Elab.Tactic.Do.ProofMode.Pure +import Lean.Elab.Tactic.Do.LetElim +import Lean.Elab.Tactic.Do.Spec +import Lean.Elab.Tactic.Do.Attr +import Lean.Elab.Tactic.Do.Syntax +import Lean.Elab.Tactic.Induction + public import Lean.Elab.Tactic.Do.VCGen.Basic -public import Lean.Elab.Tactic.Do.VCGen.Split public section @@ -40,7 +42,11 @@ private def ProofMode.MGoal.withNewProg (goal : MGoal) (e : Expr) : MGoal := namespace VCGen -partial def genVCs (goal : MVarId) (ctx : Context) (fuel : Fuel) : MetaM (Array MVarId) := do +structure Result where + invariants : Array MVarId + vcs : Array MVarId + +partial def genVCs (goal : MVarId) (ctx : Context) (fuel : Fuel) : MetaM Result := do let (mvar, goal) ← mStartMVar goal mvar.withContext <| withReducible do let (prf, state) ← StateRefT'.run (ReaderT.run (onGoal goal (← mvar.getTag)) ctx) { fuel } @@ -51,7 +57,7 @@ partial def genVCs (goal : MVarId) (ctx : Context) (fuel : Fuel) : MetaM (Array for h : idx in [:state.vcs.size] do let mv := state.vcs[idx] mv.setTag (Name.mkSimple ("vc" ++ toString (idx + 1)) ++ (← mv.getTag)) - return state.invariants ++ state.vcs + return { invariants := state.invariants, vcs := state.vcs } where onFail (goal : MGoal) (name : Name) : VCGenM Expr := do -- trace[Elab.Tactic.Do.vcgen] "fail {goal.toExpr}" @@ -356,6 +362,62 @@ where end VCGen +def elabInvariants (stx : Syntax) (invariants : Array MVarId) : TermElabM Unit := do + let some stx := stx.getOptional? | return () + let stx : TSyntax ``invariantAlts := ⟨stx⟩ + match stx with + | `(invariantAlts| using invariants $alts*) => + for alt in alts do + match alt with + | `(invariantAlt| | $ns,* => $rhs) => + for ref in ns.getElems do + let n := ref.getNat + if n = 0 then + logErrorAt ref "Invariant index 0 is invalid. Invariant indices start at 1 just as the case labels `inv`." + continue + let some mv := invariants[n-1]? | do + logErrorAt ref m!"Invariant index {n} is out of bounds. Invariant indices start at 1 just as the case labels `inv`. There were {invariants.size} invariants." + continue + if ← mv.isAssigned then + logErrorAt ref m!"Invariant {n} is already assigned" + continue + mv.assign (← mv.withContext <| Term.elabTerm rhs (← mv.getType)) + | _ => logErrorAt alt "Expected invariantAlt, got {alt}" + | _ => logErrorAt stx "Expected invariantAlts, got {stx}" + +private def patchVCAltIntoCaseTactic (alt : TSyntax ``vcAlt) : TSyntax ``case := + -- syntax vcAlt := sepBy1(caseArg, " | ") " => " tacticSeq + -- syntax case := "case " sepBy1(caseArg, " | ") " => " tacticSeq + ⟨alt.raw |>.setKind ``case |>.setArg 0 (mkAtom "case")⟩ + +partial def elabVCs (stx : Syntax) (vcs : Array MVarId) : TacticM (List MVarId) := do + let some stx := stx.getOptional? | return vcs.toList + match (⟨stx⟩ : TSyntax ``vcAlts) with + | `(vcAlts| with $(tactic)? $alts*) => + let vcs ← applyPreTac vcs tactic + evalAlts vcs alts + | _ => + logErrorAt stx "Expected inductionAlts, got {stx}" + return vcs.toList +where + applyPreTac (vcs : Array MVarId) (tactic : Option Syntax) : TacticM (Array MVarId) := do + let some tactic := tactic | return vcs + let mut newVCs := #[] + for vc in vcs do + let vcs ← try evalTacticAt tactic vc catch _ => pure [vc] + newVCs := newVCs ++ vcs + return newVCs + + evalAlts (vcs : Array MVarId) (alts : TSyntaxArray ``vcAlt) : TacticM (List MVarId) := do + let oldGoals ← getGoals + try + setGoals vcs.toList + for alt in alts do withRef alt <| evalTactic <| patchVCAltIntoCaseTactic alt + pruneSolvedGoals + getGoals + finally + setGoals oldGoals + @[builtin_tactic Lean.Parser.Tactic.mvcgen] def elabMVCGen : Tactic := fun stx => withMainContext do if mvcgen.warning.get (← getOptions) then @@ -366,10 +428,13 @@ def elabMVCGen : Tactic := fun stx => withMainContext do | none => .unlimited let goal ← getMainGoal let goal ← if ctx.config.elimLets then elimLets goal else pure goal - let vcs ← VCGen.genVCs goal ctx fuel + let { invariants, vcs } ← VCGen.genVCs goal ctx fuel let runOnVCs (tac : TSyntax `tactic) (vcs : Array MVarId) : TermElabM (Array MVarId) := vcs.flatMapM fun vc => List.toArray <$> Term.withSynthesize do Tactic.run vc (Tactic.evalTactic tac *> Tactic.pruneSolvedGoals) + let invariants ← Term.TermElabM.run' do + let invariants ← if ctx.config.leave then runOnVCs (← `(tactic| try mleave)) invariants else pure invariants + elabInvariants stx[3] invariants let vcs ← Term.TermElabM.run' do let vcs ← if ctx.config.trivial then runOnVCs (← `(tactic| try mvcgen_trivial)) vcs else pure vcs let vcs ← if ctx.config.leave then runOnVCs (← `(tactic| try mleave)) vcs else pure vcs @@ -377,4 +442,5 @@ def elabMVCGen : Tactic := fun stx => withMainContext do -- Eliminating lets here causes some metavariables in `mkFreshPair_triple` to become nonassignable -- so we don't do it. Presumably some weird delayed assignment thing is going on. -- let vcs ← if ctx.config.elimLets then liftMetaM <| vcs.mapM elimLets else pure vcs - replaceMainGoal vcs.toList + let vcs ← elabVCs stx[4] vcs + replaceMainGoal (invariants ++ vcs).toList diff --git a/src/Std/Tactic/Do/Syntax.lean b/src/Std/Tactic/Do/Syntax.lean index c0eb4c0e2a..9559cfbe36 100644 --- a/src/Std/Tactic/Do/Syntax.lean +++ b/src/Std/Tactic/Do/Syntax.lean @@ -301,6 +301,14 @@ all_goals macro (name := mspecNoSimp) "mspec_no_simp" spec:(ppSpace colGt term)? : tactic => `(tactic| ((try with_reducible mspec_no_bind $(mkIdent ``Std.Do.Spec.bind)) <;> mspec_no_bind $[$spec]?)) +@[inherit_doc Lean.Parser.Tactic.mspecMacro] +macro (name := mspec) "mspec" spec:(ppSpace colGt term)? : tactic => + `(tactic| (mspec_no_simp $[$spec]? + all_goals ((try simp only [ + $(mkIdent ``Std.Do.SPred.true_intro_simp):term, + $(mkIdent ``Std.Do.SPred.apply_pure):term]) + (try mpure_intro; trivial)))) + syntax "mvcgen_trivial_extensible" : tactic /-- @@ -316,17 +324,34 @@ macro "mvcgen_trivial" : tactic => | try mvcgen_trivial_extensible ) -@[inherit_doc Lean.Parser.Tactic.mspecMacro] -macro (name := mspec) "mspec" spec:(ppSpace colGt term)? : tactic => - `(tactic| (mspec_no_simp $[$spec]? - all_goals ((try simp only [ - $(mkIdent ``Std.Do.SPred.true_intro_simp):term, - $(mkIdent ``Std.Do.SPred.apply_pure):term]) - (try mpure_intro; trivial)))) +/-- +An invariant alternative of the form `| , ..., => term`, where `nᵢ` are natural numbers +referring to numbered invariant goals. +-/ +syntax invariantAlt := ppDedent(ppLine) withPosition("| " num,+) " => " term + +/-- +After `using`, there can be an optional ` invariants ` followed by a list of alternatives +`| 1 => term | ... | => term`. +-/ +syntax invariantAlts := " using" (&" invariants " withPosition((colGe invariantAlt)*))? + +/-- +In induction alternative, which can have 1 or more cases on the left +and `_`, `?_`, or a tactic sequence after the `=>`. +-/ +syntax vcAlt := "| " sepBy1(caseArg, " | ") " => " tacticSeq -- `case` tactic has "case " instead of "| " + +/-- +After `with`, there is an optional tactic that runs on all branches, and +then a list of alternatives. +-/ +syntax vcAlts := " with" (ppSpace colGt tactic)? withPosition((colGe vcAlt)*) @[inherit_doc Lean.Parser.Tactic.mvcgenMacro] syntax (name := mvcgen) "mvcgen" optConfig - (" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? : tactic + (" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? + (invariantAlts)? (vcAlts)? : tactic /-- Like `mvcgen`, but does not attempt to prove trivial VCs via `mpure_intro; trivial`. diff --git a/tests/lean/run/mvcgenUsingWith.lean b/tests/lean/run/mvcgenUsingWith.lean new file mode 100644 index 0000000000..4a5337f9e6 --- /dev/null +++ b/tests/lean/run/mvcgenUsingWith.lean @@ -0,0 +1,184 @@ +import Std.Tactic.Do +import Std + +open Std Do + +set_option grind.warning false +set_option mvcgen.warning false + +def nodup (l : List Int) : Bool := Id.run do + let mut seen : HashSet Int := ∅ + for x in l do + if x ∈ seen then + return false + seen := seen.insert x + return true + +theorem nodup_correct_vanilla (l : List Int) : nodup l ↔ l.Nodup := by + generalize h : nodup l = r + apply Id.of_wp_run_eq h + mvcgen + case inv1 => + exact Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + all_goals mleave; grind + +theorem nodup_correct_using (l : List Int) : nodup l ↔ l.Nodup := by + generalize h : nodup l = r + apply Id.of_wp_run_eq h + mvcgen using invariants + | 1 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + all_goals grind + +theorem nodup_correct_using_with_pretac (l : List Int) : nodup l ↔ l.Nodup := by + generalize h : nodup l = r + apply Id.of_wp_run_eq h + mvcgen using invariants + | 1 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + with grind + +theorem nodup_correct_using_with_cases (l : List Int) : nodup l ↔ l.Nodup := by + generalize h : nodup l = r + apply Id.of_wp_run_eq h + mvcgen + using invariants + | 1 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + with + | vc1 => grind + | vc2 => grind + | vc3 => grind + | vc4 => grind + | vc5 => grind + +theorem nodup_correct_using_with_pretac_cases (l : List Int) : nodup l ↔ l.Nodup := by + generalize h : nodup l = r + apply Id.of_wp_run_eq h + mvcgen + using invariants + | 1 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + with mleave -- mleave is a no-op here, but we are just testing the grammar + | vc1 => grind + | vc2 | vc3 | vc4 => grind + | vc5 => grind + +/-- +error: Case tag `vc3` not found. + +Hint: The only available case tag is `vc5.a.post.success.h_2._@.Std.Do.WP.Basic._hyg.1626`. + vc3̵5̲.̲a̲.̲p̲o̲s̲t̲.̲s̲u̲c̲c̲e̲s̲s̲.̲h̲_̲2̲.̲_̲@̲.̲S̲t̲d̲.̲D̲o̲.̲W̲P̲.̲B̲a̲s̲i̲c̲.̲_̲h̲y̲g̲.̲1̲6̲2̲6̲ +-/ +#guard_msgs in +theorem nodup_correct_using_with_cases_error (l : List Int) : nodup l ↔ l.Nodup := by + generalize h : nodup l = r + apply Id.of_wp_run_eq h + mvcgen + using invariants + | 1 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + with mleave -- mleave is a no-op here, but we are just testing the grammar + | vc1 => grind + | vc2 | vc3 | vc4 => grind + | vc3 => grind + | vc5 => grind + +theorem test_with_pretac {m : Option Nat} (h : m = some 4) : + ⦃⌜True⌝⦄ + (match m with + | some n => (set n : StateM Nat PUnit) + | none => set 0) + ⦃⇓ r s => ⌜s = 4⌝⦄ := by + mvcgen with simp_all + +theorem test_with_cases {m : Option Nat} (h : m = some 4) : + ⦃⌜True⌝⦄ + (match m with + | some n => (set n : StateM Nat PUnit) + | none => set 0) + ⦃⇓ r s => ⌜s = 4⌝⦄ := by + mvcgen + with + | vc1 => grind + | vc2 => grind + +theorem test_with_pretac_cases {m : Option Nat} (h : m = some 4) : + ⦃⌜True⌝⦄ + (match m with + | some n => (set n : StateM Nat PUnit) + | none => set 0) + ⦃⇓ r s => ⌜s = 4⌝⦄ := by + mvcgen + with simp -- `simp` is a no-op on some goals, but it should not fail + | vc1 => grind + | vc2 => grind + +def nodup_twice (l : List Int) : Bool := Id.run do + let mut seen : HashSet Int := ∅ + for x in l do + if x ∈ seen then + return false + seen := seen.insert x + let mut seen2 : HashSet Int := ∅ + for x in l do + if x ∈ seen2 then + return false + seen2 := seen2.insert x + return true + +theorem nodup_twice_correct_using_with (l : List Int) : nodup_twice l ↔ l.Nodup := by + generalize h : nodup_twice l = r + apply Id.of_wp_run_eq h + mvcgen + using invariants + | 1 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + | 2 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + with grind + +theorem nodup_twice_correct_using_multiple_with (l : List Int) : nodup_twice l ↔ l.Nodup := by + generalize h : nodup_twice l = r + apply Id.of_wp_run_eq h + mvcgen + using invariants + | 1, 2 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + with grind + +/-- error: Invariant 2 is already assigned -/ +#guard_msgs in +theorem nodup_twice_correct_using_multiple_error (l : List Int) : nodup_twice l ↔ l.Nodup := by + generalize h : nodup_twice l = r + apply Id.of_wp_run_eq h + mvcgen + using invariants + | 1, 2 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + | 2 => Invariant.withEarlyReturn + (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝) + (onContinue := fun traversalState seen => + ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝) + with grind