lean4-htt/tests/elab/bhaviksSampler.lean
Sebastian Graf 40e8f4c5fb
chore: turn on new do elaborator in Core (#12656)
This PR turns on the new `do` elaborator in Init, Lean, Std, Lake and
the testsuite.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 12:38:33 +00:00

179 lines
6 KiB
Text
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import Std.Data.TreeMap
import Std.Tactic.Do
-- Elaborate `do` blocks in this file with the new `do` elaborator rather than the legacy one.
set_option backward.do.legacy false
/-!
This test is based on code by Bhavik Mehta.
It demonstrates that much of the boilerplate he had to write for his proof
is now automated by the `mvcgen` tactic.
-/
-- NOTE(review): presumably these silence the warnings that `grind` and `mvcgen` emit on use
-- (e.g. experimental-feature notices) — confirm.
set_option grind.warning false
set_option mvcgen.warning false
section VendoredFromMathlib

/-- Stand-in for Mathlib's `ℕ`: a reducible alias for `Nat`. -/
abbrev ℕ := Nat

/-- A monad transformer to generate random objects using the generic generator type `g`.
The generator is threaded through as `StateT` state, wrapped in `ULift`. -/
abbrev RandGT (g : Type) := StateT (ULift g)

/-- A monad to generate random objects using the generator type `g`. -/
abbrev RandG (g : Type) := RandGT g Id

/-- A monad transformer to generate random objects using the generator type `StdGen`.
`RandT m α` should be thought of as a random value in `m α`. -/
abbrev RandT := RandGT StdGen

/-- `Random m α` gives us machinery to generate values of type `α` in the monad `m`.
Note that `m` is a parameter as some types may only be sampleable with access to a certain monad. -/
class Random (m) (α : Type u) where
  /-- Sample an element of this type from the provided generator. -/
  random [RandomGen g] : RandGT g m α

/-- `BoundedRandom m α` gives us machinery to generate values of type `α` between certain bounds in
the monad `m`. -/
class BoundedRandom (m) (α : Type u) [LE α] where
  /-- Sample an element of this type between `lo` and `hi` (inclusive) from the provided
  generator. The result carries a proof that it lies within the bounds. -/
  randomR {g : Type} (lo hi : α) (h : lo ≤ hi) [RandomGen g] : RandGT g m {a // lo ≤ a ∧ a ≤ hi}

namespace Rand

/-- Generate a random `Nat` with `RandomGen.next`, storing the advanced generator back into
the state. -/
def next [RandomGen g] [Monad m] : RandGT g m Nat := do
  let rng := (← get).down
  let (res, new) := RandomGen.next rng
  set (ULift.up new)
  pure res

/-- Create a new random number generator distinct from the one stored in the state.
The state keeps the first generator produced by `RandomGen.split`; the second is returned. -/
def split {g : Type} [RandomGen g] [Monad m] : RandGT g m g := do
  let rng := (← get).down
  let (r1, r2) := RandomGen.split rng
  set (ULift.up r1)
  pure r2

/-- Get the range of `Nat` that can be generated by the generator `g`.
The stored generator is left unchanged. -/
def range {g : Type} [RandomGen g] [Monad m] : RandGT g m (Nat × Nat) := do
  let rng := (← get).down
  pure <| RandomGen.range rng

end Rand

namespace Random

open Rand

variable [Monad m]

/-- Generate a random value of type `α` using its `Random` instance. -/
def rand (α : Type u) [Random m α] [RandomGen g] : RandGT g m α := Random.random

/-- Generate a random value of type `α` between `lo` and `hi` inclusive, using its
`BoundedRandom` instance. -/
def randBound (α : Type u)
    [LE α] [BoundedRandom m α] (lo hi : α) (h : lo ≤ hi) [RandomGen g] :
    RandGT g m {a // lo ≤ a ∧ a ≤ hi} :=
  (BoundedRandom.randomR lo hi h : RandGT g _ _)

/-- Generate a random `Fin n`: draw a `Nat` in `[0, n - 1]` with `randNat`, convert it with
`Fin.ofNat n`, and store the advanced generator back into the state. -/
def randFin {n : Nat} [NeZero n] [RandomGen g] : RandGT g m (Fin n) :=
  fun ⟨g⟩ ↦ pure <| randNat g 0 (n - 1) |>.map (Fin.ofNat n) ULift.up

/-- Random `Fin n` values, via `randFin`. -/
instance {n : Nat} [NeZero n] : Random m (Fin n) where
  random := randFin

/-- Bounded random `Nat`s: draw `z : Fin (hi - lo + 1)` and return `lo + z`, together with the
proof that it lies in `[lo, hi]`. -/
instance : BoundedRandom m Nat where
  randomR lo hi h _ := do
    let z ← rand (Fin (hi - lo + 1))
    pure ⟨
      lo + z.val, Nat.le_add_right _ _,
      Nat.add_le_of_le_sub' h (Nat.le_of_lt_add_one z.isLt)
    ⟩

end Random

end VendoredFromMathlib
open Random

/-- Take `k` samples, without replacement, from `[0..n-1]`.

The loop runs a partial Fisher–Yates shuffle over a virtual array whose position `p` holds
`h.getD p ⟨p, _⟩`, so the map `h` only stores positions that have been overwritten.
In iteration `i`, a random index `j ∈ [i, n - 1]` is drawn, the value at position `j` becomes
sample `i`, and the value at position `i` is moved to position `j`. -/
def sampler {m} [Monad m] (n k : ℕ) [NeZero n] (h : k ≤ n) : RandT m (Vector (Fin n) k) := do
  let mut x : Vector (Fin n) k := Vector.replicate _ 0
  -- Note: this mutable `h` shadows the hypothesis `h : k ≤ n`.
  let mut h : Std.TreeMap ℕ (Fin n) := Std.TreeMap.empty
  for hi : i in [0:k] do
    -- The bound `i ≤ n - 1` follows from `i < k` and `k ≤ n`.
    let j ← Subtype.val <$> randBound ℕ i (n - 1) (have : i < k := hi.upper; by grind)
    -- The `sorry`s stand in for the `Fin n` bound proofs; `Subtype.val` above discards the
    -- bound that `randBound` provides.
    x := x.set i (h.getD j ⟨j, sorry⟩)
    h := h.insert j (h.getD i ⟨i, sorry⟩)
  return x
-- Variables shared by the declarations below.
variable {m : Type → Type u} [Monad m] [LawfulMonad m] {n k : ℕ}
/-- The loop state of `sampler`: the output vector built so far, and the map from positions
to displaced values (the mutable `x` and `h` of `sampler`). -/
abbrev Midway (n k : ℕ) : Type := Prod (Vector (Fin n) k) (Std.TreeMap ℕ (Fin n))
/-- The initial loop state, matching `sampler`'s initial `x` and `h`.
`[NeZero n]` is needed for the filler value `0 : Fin n`. -/
def init (n k : ℕ) [NeZero n] : Midway n k :=
  ⟨Vector.replicate _ 0, Std.TreeMap.empty⟩
variable [NeZero n]

/-- One iteration of `sampler`'s loop body at index `i`, with the random draw `j` passed in
explicitly instead of being sampled. -/
def next (data : Midway n k) (i : ℕ) (hi : i < k) (j : ℕ) : Midway n k :=
  let (x, h) := data
  -- `hi` is not referenced explicitly; presumably it discharges the `i < k` bound of `x.set i`.
  -- The `sorry`s stand in for the `Fin n` bound proofs, as in `sampler`.
  ⟨x.set i (h.getD j ⟨j, sorry⟩), h.insert j (h.getD i ⟨i, sorry⟩)⟩
/-- Loop invariant for `sampler` at loop position `i` (the `inv1` invariant of
`sampler_correct`). Only `nodup_take` is stated; further conjuncts are drafted below. -/
structure Midway.valid (data : Midway n k) (i : ℕ) : Prop where
  /-- The first `i` entries of the output vector have no duplicates. -/
  nodup_take : (data.1.toList.take i).Nodup
  -- NOTE(review): the drafts below apply `getD j j` to `data.1` and `toList` to `data.2`, but in
  -- `Midway` the `Vector` is `data.1` and the `TreeMap` is `data.2`; the components look swapped.
  -- disjoint : ∀ j, i ≤ j → j ≤ n - 1 → data.1.getD j j ∉ data.2.toList.take i
  -- injOn : Set.InjOn (fun j ↦ data.1.getD j j) {j | i ≤ j ∧ j ≤ n - 1}
/-- The initial state `init n k` satisfies the loop invariant at position `0`. -/
theorem valid_init : Midway.valid (init n k) 0 :=
  -- Left as `sorry`: this test exercises `mvcgen`, not the domain proofs.
  sorry -- domain-specific
/-- One loop step preserves the invariant: if `data` is valid at `i` and the draw `j` lies in
`[i, n - 1]`, then `next data i hi j` is valid at `i + 1`. -/
theorem Midway.valid_next (data : Midway n k) (i : ℕ) (hi : i < k)
    (j : ℕ) (hij : i ≤ j) (hjn : j ≤ n - 1)
    (h : Midway.valid data i) : Midway.valid (next data i hi j) (i + 1) :=
  -- Left as `sorry`: this test exercises `mvcgen`, not the domain proofs.
  sorry -- domain-specific
open Std.Do

/-- `randFin` preserves any assertion `P` that does not mention the generator state: if `P`
holds before, it holds afterwards, whatever value is drawn. Tagged `@[spec]` for use by
`mvcgen`. -/
@[spec]
theorem randFin_total {m : Type → Type u} [Monad m] [WPMonad m ps] {n : ℕ} [NeZero n] :
    ⦃fun _ => P⦄ -- it's unfortunate that we have to "guess" the frame `fun _ => P` ourselves. TODO: autogeneralize based on "parametricity" in `m`?
    randFin (n:=n) (m:=m) (g:=StdGen)
    ⦃⇓ _ _ => P⦄ := by
  unfold randFin
  mintro hs ∀s
  simp [wp, StateT.run]
/-- `randBound ℕ lo hi h` preserves any assertion `P` that does not mention the generator
state, whatever value is drawn. Tagged `@[spec]` so that `mvcgen` can use it inside
`sampler`'s loop. -/
@[spec]
theorem randBound_spec {m : Type → Type u} [Monad m] [WPMonad m ps] (h : lo ≤ hi) :
    ⦃fun _ => P⦄
    randBound (m:=m) (g:=StdGen) ℕ lo hi h
    ⦃⇓ _ _ => P⦄ := by
  -- Unfold down to the `Nat` instance of `BoundedRandom`, whose body uses `rand`/`random`
  -- and hence `randFin` (covered by `randFin_total`).
  mvcgen [randBound, BoundedRandom.randomR, rand, random]
/-- `sampler` returns a vector with no duplicate entries.
This relies on `valid_init` and `Midway.valid_next`, which are still `sorry`, and the final
step below is also `sorry`. -/
theorem sampler_correct {m : Type → Type u} {k h} [Monad m] [WPMonad m ps] :
    ⦃⌜True⌝⦄
    sampler (m:=m) n k h
    ⦃⇓ xs => ⌜xs.toList.Nodup⌝⦄ := by
  -- Generate verification conditions for `sampler` (`-leave` turns off mvcgen's `leave` option);
  -- the resulting goals are handled by case name below.
  mvcgen -leave [sampler]
  -- Loop invariant: the loop state satisfies `Midway.valid` at the current cursor position.
  case inv1 => exact (⇓ (xs, midway) => ⌜Midway.valid midway xs.pos⌝)
  -- Preservation: one loop iteration maps a valid state to a valid state.
  case vc1 pref cur _ _ _ _ _ _ r _ _ _ =>
    dsimp
    mframe
    rename_i hinv
    mpure_intro
    simp only [List.length_append, List.length_cons, List.length_nil, Nat.zero_add]
    -- Relate the loop index `cur` to the length of the already-processed prefix `pref`.
    have : cur = pref.length := by grind
    subst this
    -- `r` is the draw from `randBound`; its property supplies `i ≤ j` and `j ≤ n - 1`.
    apply Midway.valid_next _ pref.length _ r.val r.property.1 r.property.2 hinv
  -- Initialization: the starting state is valid at position 0.
  case vc2 =>
    mpure_intro
    exact valid_init
  -- Exit: derive the postcondition from the invariant after the last iteration.
  case vc3 =>
    dsimp
    mrename_i h
    mpure h
    mpure_intro
    have h := h.nodup_take
    simp at h
    -- prove List.take k r.snd.toList = r.snd.toList for r.snd : Vector (Fin n) k
    sorry
  -- Remaining side goal, closed by `simp`.
  case vc4 => simp