feat: extended using invariants and with syntax for mvcgen (#9927)

This PR implements extended `induction`-inspired syntax for `mvcgen`,
allowing optional `using invariants` and `with` sections.

```lean
  mvcgen
  using invariants
  | 1 => Invariant.withEarlyReturn
      (onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
      (onContinue := fun traversalState seen =>
        ⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
  with mleave -- mleave is a no-op here, but we are just testing the grammar
  | vc1 => grind
  | vc2 => grind
  | vc3 => grind
  | vc4 => grind
  | vc5 => grind
```
This commit is contained in:
Sebastian Graf 2025-08-15 14:25:01 +02:00 committed by GitHub
parent 4c562fc1a3
commit 9e1d97c261
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 301 additions and 26 deletions

View file

@ -6,21 +6,23 @@ Authors: Sebastian Graf
module
prelude
public import Std.Do.WP
public import Std.Do.Triple
public import Lean.Elab.Tactic.Simp
public import Lean.Elab.Tactic.Do.ProofMode.Basic
public import Lean.Elab.Tactic.Do.ProofMode.Intro
public import Lean.Elab.Tactic.Do.ProofMode.Revert
public import Lean.Elab.Tactic.Do.ProofMode.Cases
public import Lean.Elab.Tactic.Do.ProofMode.Specialize
public import Lean.Elab.Tactic.Do.ProofMode.Pure
public import Lean.Elab.Tactic.Do.LetElim
public import Lean.Elab.Tactic.Do.Spec
public import Lean.Elab.Tactic.Do.Attr
public import Lean.Elab.Tactic.Do.Syntax
import Std.Do.WP
import Std.Do.Triple
import Lean.Elab.Tactic.Do.VCGen.Split
import Lean.Elab.Tactic.Simp
import Lean.Elab.Tactic.Do.ProofMode.Basic
import Lean.Elab.Tactic.Do.ProofMode.Intro
import Lean.Elab.Tactic.Do.ProofMode.Revert
import Lean.Elab.Tactic.Do.ProofMode.Cases
import Lean.Elab.Tactic.Do.ProofMode.Specialize
import Lean.Elab.Tactic.Do.ProofMode.Pure
import Lean.Elab.Tactic.Do.LetElim
import Lean.Elab.Tactic.Do.Spec
import Lean.Elab.Tactic.Do.Attr
import Lean.Elab.Tactic.Do.Syntax
import Lean.Elab.Tactic.Induction
public import Lean.Elab.Tactic.Do.VCGen.Basic
public import Lean.Elab.Tactic.Do.VCGen.Split
public section
@ -40,7 +42,11 @@ private def ProofMode.MGoal.withNewProg (goal : MGoal) (e : Expr) : MGoal :=
namespace VCGen
partial def genVCs (goal : MVarId) (ctx : Context) (fuel : Fuel) : MetaM (Array MVarId) := do
structure Result where
invariants : Array MVarId
vcs : Array MVarId
partial def genVCs (goal : MVarId) (ctx : Context) (fuel : Fuel) : MetaM Result := do
let (mvar, goal) ← mStartMVar goal
mvar.withContext <| withReducible do
let (prf, state) ← StateRefT'.run (ReaderT.run (onGoal goal (← mvar.getTag)) ctx) { fuel }
@ -51,7 +57,7 @@ partial def genVCs (goal : MVarId) (ctx : Context) (fuel : Fuel) : MetaM (Array
for h : idx in [:state.vcs.size] do
let mv := state.vcs[idx]
mv.setTag (Name.mkSimple ("vc" ++ toString (idx + 1)) ++ (← mv.getTag))
return state.invariants ++ state.vcs
return { invariants := state.invariants, vcs := state.vcs }
where
onFail (goal : MGoal) (name : Name) : VCGenM Expr := do
-- trace[Elab.Tactic.Do.vcgen] "fail {goal.toExpr}"
@ -356,6 +362,62 @@ where
end VCGen
def elabInvariants (stx : Syntax) (invariants : Array MVarId) : TermElabM Unit := do
let some stx := stx.getOptional? | return ()
let stx : TSyntax ``invariantAlts := ⟨stx⟩
match stx with
| `(invariantAlts| using invariants $alts*) =>
for alt in alts do
match alt with
| `(invariantAlt| | $ns,* => $rhs) =>
for ref in ns.getElems do
let n := ref.getNat
if n = 0 then
logErrorAt ref "Invariant index 0 is invalid. Invariant indices start at 1 just as the case labels `inv<n>`."
continue
let some mv := invariants[n-1]? | do
logErrorAt ref m!"Invariant index {n} is out of bounds. Invariant indices start at 1 just as the case labels `inv<n>`. There were {invariants.size} invariants."
continue
if ← mv.isAssigned then
logErrorAt ref m!"Invariant {n} is already assigned"
continue
mv.assign (← mv.withContext <| Term.elabTerm rhs (← mv.getType))
| _ => logErrorAt alt "Expected invariantAlt, got {alt}"
| _ => logErrorAt stx "Expected invariantAlts, got {stx}"
private def patchVCAltIntoCaseTactic (alt : TSyntax ``vcAlt) : TSyntax ``case :=
-- syntax vcAlt := sepBy1(caseArg, " | ") " => " tacticSeq
-- syntax case := "case " sepBy1(caseArg, " | ") " => " tacticSeq
⟨alt.raw |>.setKind ``case |>.setArg 0 (mkAtom "case")⟩
partial def elabVCs (stx : Syntax) (vcs : Array MVarId) : TacticM (List MVarId) := do
let some stx := stx.getOptional? | return vcs.toList
match (⟨stx⟩ : TSyntax ``vcAlts) with
| `(vcAlts| with $(tactic)? $alts*) =>
let vcs ← applyPreTac vcs tactic
evalAlts vcs alts
| _ =>
logErrorAt stx "Expected inductionAlts, got {stx}"
return vcs.toList
where
applyPreTac (vcs : Array MVarId) (tactic : Option Syntax) : TacticM (Array MVarId) := do
let some tactic := tactic | return vcs
let mut newVCs := #[]
for vc in vcs do
let vcs ← try evalTacticAt tactic vc catch _ => pure [vc]
newVCs := newVCs ++ vcs
return newVCs
evalAlts (vcs : Array MVarId) (alts : TSyntaxArray ``vcAlt) : TacticM (List MVarId) := do
let oldGoals ← getGoals
try
setGoals vcs.toList
for alt in alts do withRef alt <| evalTactic <| patchVCAltIntoCaseTactic alt
pruneSolvedGoals
getGoals
finally
setGoals oldGoals
@[builtin_tactic Lean.Parser.Tactic.mvcgen]
def elabMVCGen : Tactic := fun stx => withMainContext do
if mvcgen.warning.get (← getOptions) then
@ -366,10 +428,13 @@ def elabMVCGen : Tactic := fun stx => withMainContext do
| none => .unlimited
let goal ← getMainGoal
let goal ← if ctx.config.elimLets then elimLets goal else pure goal
let vcs ← VCGen.genVCs goal ctx fuel
let { invariants, vcs } ← VCGen.genVCs goal ctx fuel
let runOnVCs (tac : TSyntax `tactic) (vcs : Array MVarId) : TermElabM (Array MVarId) :=
vcs.flatMapM fun vc => List.toArray <$> Term.withSynthesize do
Tactic.run vc (Tactic.evalTactic tac *> Tactic.pruneSolvedGoals)
let invariants ← Term.TermElabM.run' do
let invariants ← if ctx.config.leave then runOnVCs (← `(tactic| try mleave)) invariants else pure invariants
elabInvariants stx[3] invariants
let vcs ← Term.TermElabM.run' do
let vcs ← if ctx.config.trivial then runOnVCs (← `(tactic| try mvcgen_trivial)) vcs else pure vcs
let vcs ← if ctx.config.leave then runOnVCs (← `(tactic| try mleave)) vcs else pure vcs
@ -377,4 +442,5 @@ def elabMVCGen : Tactic := fun stx => withMainContext do
-- Eliminating lets here causes some metavariables in `mkFreshPair_triple` to become nonassignable
-- so we don't do it. Presumably some weird delayed assignment thing is going on.
-- let vcs ← if ctx.config.elimLets then liftMetaM <| vcs.mapM elimLets else pure vcs
replaceMainGoal vcs.toList
let vcs ← elabVCs stx[4] vcs
replaceMainGoal (invariants ++ vcs).toList

View file

@ -301,6 +301,14 @@ all_goals
macro (name := mspecNoSimp) "mspec_no_simp" spec:(ppSpace colGt term)? : tactic =>
`(tactic| ((try with_reducible mspec_no_bind $(mkIdent ``Std.Do.Spec.bind)) <;> mspec_no_bind $[$spec]?))
@[inherit_doc Lean.Parser.Tactic.mspecMacro]
macro (name := mspec) "mspec" spec:(ppSpace colGt term)? : tactic =>
`(tactic| (mspec_no_simp $[$spec]?
all_goals ((try simp only [
$(mkIdent ``Std.Do.SPred.true_intro_simp):term,
$(mkIdent ``Std.Do.SPred.apply_pure):term])
(try mpure_intro; trivial))))
syntax "mvcgen_trivial_extensible" : tactic
/--
@ -316,17 +324,34 @@ macro "mvcgen_trivial" : tactic =>
| try mvcgen_trivial_extensible
)
@[inherit_doc Lean.Parser.Tactic.mspecMacro]
macro (name := mspec) "mspec" spec:(ppSpace colGt term)? : tactic =>
`(tactic| (mspec_no_simp $[$spec]?
all_goals ((try simp only [
$(mkIdent ``Std.Do.SPred.true_intro_simp):term,
$(mkIdent ``Std.Do.SPred.apply_pure):term])
(try mpure_intro; trivial))))
/--
An invariant alternative of the form `| <n₁>, ..., <nₖ> => term`, where `nᵢ` are natural numbers
referring to numbered invariant goals.
-/
syntax invariantAlt := ppDedent(ppLine) withPosition("| " num,+) " => " term
/--
After `using`, there can be an optional ` invariants ` followed by a list of alternatives
`| 1 => term | ... | <n> => term`.
-/
syntax invariantAlts := " using" (&" invariants " withPosition((colGe invariantAlt)*))?
/--
In induction alternative, which can have 1 or more cases on the left
and `_`, `?_`, or a tactic sequence after the `=>`.
-/
syntax vcAlt := "| " sepBy1(caseArg, " | ") " => " tacticSeq -- `case` tactic has "case " instead of "| "
/--
After `with`, there is an optional tactic that runs on all branches, and
then a list of alternatives.
-/
syntax vcAlts := " with" (ppSpace colGt tactic)? withPosition((colGe vcAlt)*)
@[inherit_doc Lean.Parser.Tactic.mvcgenMacro]
syntax (name := mvcgen) "mvcgen" optConfig
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")? : tactic
(" [" withoutPosition((simpStar <|> simpErase <|> simpLemma),*,?) "]")?
(invariantAlts)? (vcAlts)? : tactic
/--
Like `mvcgen`, but does not attempt to prove trivial VCs via `mpure_intro; trivial`.

View file

@ -0,0 +1,184 @@
import Std.Tactic.Do
import Std
open Std Do
set_option grind.warning false
set_option mvcgen.warning false
def nodup (l : List Int) : Bool := Id.run do
let mut seen : HashSet Int := ∅
for x in l do
if x ∈ seen then
return false
seen := seen.insert x
return true
theorem nodup_correct_vanilla (l : List Int) : nodup l ↔ l.Nodup := by
generalize h : nodup l = r
apply Id.of_wp_run_eq h
mvcgen
case inv1 =>
exact Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
all_goals mleave; grind
theorem nodup_correct_using (l : List Int) : nodup l ↔ l.Nodup := by
generalize h : nodup l = r
apply Id.of_wp_run_eq h
mvcgen using invariants
| 1 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
all_goals grind
theorem nodup_correct_using_with_pretac (l : List Int) : nodup l ↔ l.Nodup := by
generalize h : nodup l = r
apply Id.of_wp_run_eq h
mvcgen using invariants
| 1 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
with grind
theorem nodup_correct_using_with_cases (l : List Int) : nodup l ↔ l.Nodup := by
generalize h : nodup l = r
apply Id.of_wp_run_eq h
mvcgen
using invariants
| 1 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
with
| vc1 => grind
| vc2 => grind
| vc3 => grind
| vc4 => grind
| vc5 => grind
theorem nodup_correct_using_with_pretac_cases (l : List Int) : nodup l ↔ l.Nodup := by
generalize h : nodup l = r
apply Id.of_wp_run_eq h
mvcgen
using invariants
| 1 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
with mleave -- mleave is a no-op here, but we are just testing the grammar
| vc1 => grind
| vc2 | vc3 | vc4 => grind
| vc5 => grind
/--
error: Case tag `vc3` not found.
Hint: The only available case tag is `vc5.a.post.success.h_2._@.Std.Do.WP.Basic._hyg.1626`.
vc3̵5̲.̲a̲.̲p̲o̲s̲t̲.̲s̲u̲c̲c̲e̲s̲s̲.̲h̲_̲2̲.̲_̲@̲.̲S̲t̲d̲.̲D̲o̲.̲W̲P̲.̲B̲a̲s̲i̲c̲.̲_̲h̲y̲g̲.̲1̲6̲2̲6̲
-/
#guard_msgs in
theorem nodup_correct_using_with_cases_error (l : List Int) : nodup l ↔ l.Nodup := by
generalize h : nodup l = r
apply Id.of_wp_run_eq h
mvcgen
using invariants
| 1 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
with mleave -- mleave is a no-op here, but we are just testing the grammar
| vc1 => grind
| vc2 | vc3 | vc4 => grind
| vc3 => grind
| vc5 => grind
theorem test_with_pretac {m : Option Nat} (h : m = some 4) :
⦃⌜True⌝⦄
(match m with
| some n => (set n : StateM Nat PUnit)
| none => set 0)
⦃⇓ r s => ⌜s = 4⌝⦄ := by
mvcgen with simp_all
theorem test_with_cases {m : Option Nat} (h : m = some 4) :
⦃⌜True⌝⦄
(match m with
| some n => (set n : StateM Nat PUnit)
| none => set 0)
⦃⇓ r s => ⌜s = 4⌝⦄ := by
mvcgen
with
| vc1 => grind
| vc2 => grind
theorem test_with_pretac_cases {m : Option Nat} (h : m = some 4) :
⦃⌜True⌝⦄
(match m with
| some n => (set n : StateM Nat PUnit)
| none => set 0)
⦃⇓ r s => ⌜s = 4⌝⦄ := by
mvcgen
with simp -- `simp` is a no-op on some goals, but it should not fail
| vc1 => grind
| vc2 => grind
def nodup_twice (l : List Int) : Bool := Id.run do
let mut seen : HashSet Int := ∅
for x in l do
if x ∈ seen then
return false
seen := seen.insert x
let mut seen2 : HashSet Int := ∅
for x in l do
if x ∈ seen2 then
return false
seen2 := seen2.insert x
return true
theorem nodup_twice_correct_using_with (l : List Int) : nodup_twice l ↔ l.Nodup := by
generalize h : nodup_twice l = r
apply Id.of_wp_run_eq h
mvcgen
using invariants
| 1 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
| 2 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
with grind
theorem nodup_twice_correct_using_multiple_with (l : List Int) : nodup_twice l ↔ l.Nodup := by
generalize h : nodup_twice l = r
apply Id.of_wp_run_eq h
mvcgen
using invariants
| 1, 2 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
with grind
/-- error: Invariant 2 is already assigned -/
#guard_msgs in
theorem nodup_twice_correct_using_multiple_error (l : List Int) : nodup_twice l ↔ l.Nodup := by
generalize h : nodup_twice l = r
apply Id.of_wp_run_eq h
mvcgen
using invariants
| 1, 2 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
| 2 => Invariant.withEarlyReturn
(onReturn := fun ret seen => ⌜ret = false ∧ ¬l.Nodup⌝)
(onContinue := fun traversalState seen =>
⌜(∀ x, x ∈ seen ↔ x ∈ traversalState.prefix) ∧ traversalState.prefix.Nodup⌝)
with grind