This PR changes the signature of `Array.get` to take a `Nat` and a proof, rather than a `Fin`, for consistency with the rest of the (planned) Array API. Note that, because of bootstrapping issues, we can't provide `get_elem_tactic` as an auto-param for the proof. As users will mostly use the `xs[i]` notation provided by `GetElem`, this hopefully isn't a problem. We may restore `Fin`-based versions, either here or downstream, as needed, but they won't be the "main" functions.

---------

Co-authored-by: David Thrane Christiansen <david@davidchristiansen.dk>
126 lines
4.3 KiB
Text
Executable file
126 lines
4.3 KiB
Text
Executable file
/-- A hand-rolled state transformer: a function from the incoming state to an
`m`-action producing the result together with the outgoing state. -/
def StateT' (m : Type → Type) (σ : Type) (α : Type) := σ → m (α × σ)

namespace StateT'

variable {m : Type → Type} [Monad m] {σ : Type} {α β : Type}

/-- Produce `a` and hand the state through unchanged. -/
@[inline] protected def pure (a : α) : StateT' m σ α :=
  fun st => pure (a, st)

/-- Run `x` on the incoming state, then continue with `f` applied to its result
and the state `x` left behind. -/
@[inline] protected def bind (x : StateT' m σ α) (f : α → StateT' m σ β) : StateT' m σ β :=
  fun st => x st >>= fun (a, st') => f a st'

/-- Yield the current state as the result, keeping it unchanged. -/
@[inline] def read : StateT' m σ σ :=
  fun st => pure (st, st)

/-- Discard the current state and install `s'`. -/
@[inline] def write (s' : σ) : StateT' m σ Unit :=
  fun _ => pure ((), s')

/-- Replace the current state with its image under `f`. -/
@[inline] def updt (f : σ → σ) : StateT' m σ Unit :=
  fun st => pure ((), f st)

instance : Monad (StateT' m σ) where
  pure := StateT'.pure
  bind := StateT'.bind

end StateT'

/-- A hand-rolled exception transformer: an `m`-action whose result is either
`Except.error` or `Except.ok`. -/
def ExceptT' (m : Type → Type) (ε : Type) (α : Type) := m (Except ε α)

namespace ExceptT'

variable {m : Type → Type} [Monad m] {ε : Type} {α β : Type}

/-- Succeed with `a`. -/
@[inline] protected def pure (a : α) : ExceptT' m ε α :=
  show m (Except ε α) from pure (.ok a)

/-- Run `x`; an error is passed on unchanged (without calling `f`), while a
successful value is handed to `f`. -/
@[inline] protected def bind (x : ExceptT' m ε α) (f : α → ExceptT' m ε β) : ExceptT' m ε β :=
  show m (Except ε β) from do
    match (← x) with
    | .ok a => f a
    | .error e => pure (.error e)

/-- Fail with `e`. -/
@[inline] def error (e : ε) : ExceptT' m ε α :=
  show m (Except ε α) from pure (.error e)

/-- Run the underlying action `x` and report its result as a success. -/
@[inline] def lift (x : m α) : ExceptT' m ε α :=
  show m (Except ε α) from x >>= fun a => pure (.ok a)

instance : Monad (ExceptT' m ε) where
  pure := ExceptT'.pure
  bind := ExceptT'.bind

end ExceptT'

/-- A node is referred to by its index in the node table. -/
abbrev Node := Nat

/-- The record stored for each node of the union-find table. -/
structure nodeData where
  /-- The node this one links to; a node whose `find` equals its own index is a root. -/
  find : Node
  /-- Rank used by union by rank; `union` only compares it between root records. -/
  rank : Nat := 0

/-- The node table, indexed by `Node`. -/
abbrev ufData := Array nodeData

/-- The union-find monad: `String` errors layered over a threaded `ufData` state. -/
abbrev M (α : Type) := ExceptT' (StateT' Id ufData) String α

/-- Return the whole node table. -/
@[inline] def read : M ufData :=
  ExceptT'.lift StateT'.read

/-- Install `s` as the node table. -/
@[inline] def write (s : ufData) : M Unit :=
  ExceptT'.lift <| StateT'.write s

/-- Transform the node table with `f`. -/
@[inline] def updt (f : ufData → ufData) : M Unit :=
  ExceptT'.lift <| StateT'.updt f

/-- Abort the computation with message `e`. -/
@[inline] def error {α : Type} (e : String) : M α :=
  ExceptT'.error e

/-- Execute `x` starting from node table `s` (empty unless given), returning the
outcome paired with the final table; the table is returned even on failure. -/
def run {α : Type} (x : M α) (s : ufData := ∅) : Except String α × ufData :=
  Id.run (x s)

/-- The number of entries currently in the node table. -/
def capacity : M Nat := do
  return (← read).size

/--
Return the root record reachable from node `n` by following `find` links.

The first argument is recursion fuel (`findEntry` passes the table size).
Exhausting it fails with "out of fuel", and an index outside the table fails
with "invalid Node". Each non-root node visited is overwritten with its root's
record, so later lookups from it reach the root in one step.
-/
def findEntryAux : Nat → Node → M nodeData
  | 0, _ => error "out of fuel"
  | fuel+1, n => do
    let table ← read
    if h : n < table.size then
      let entry := table[n]
      -- A self-link marks a root.
      if entry.find = n then
        return entry
      let root ← findEntryAux fuel entry.find
      -- Path compression: cache the root's record at `n`.
      updt fun arr => arr.set! n root
      return root
    else
      error "invalid Node"

/-- Return the root record for `n`, using the table size as recursion fuel. -/
def findEntry (n : Node) : M nodeData :=
  capacity >>= (findEntryAux · n)

/-- Return the index of the root of the set containing `n`. -/
def find (n : Node) : M Node :=
  findEntry n >>= fun entry => pure entry.find

/-- Append a new node that is its own root with rank 1, and return its index. -/
def mk : M Node := do
  let fresh ← capacity
  updt fun arr => arr.push { find := fresh, rank := 1 }
  return fresh

/--
Merge the sets containing `n₁` and `n₂` using union by rank.

Both roots are looked up first. If they coincide, the table is left unchanged.
Otherwise the lower-rank root is overwritten with a record pointing at the
other root (its rank reset to the default). On a tie, `n₁`'s root is the one
re-pointed, and `n₂`'s root has its rank incremented.
-/
def union (n₁ n₂ : Node) : M Unit := do
  let r₁ ← findEntry n₁
  let r₂ ← findEntry n₂
  if r₁.find = r₂.find then
    return
  updt fun arr =>
    match compare r₁.rank r₂.rank with
    | .lt => arr.set! r₁.find { find := r₂.find }
    | .eq =>
      let linked := arr.set! r₁.find { find := r₂.find }
      linked.set! r₂.find { r₂ with rank := r₂.rank + 1 }
    | .gt => arr.set! r₂.find { find := r₁.find }

/-- Append the given number of fresh nodes, discarding their indices. -/
def mkNodes : Nat → M Unit
  | 0 => pure ()
  | k+1 => do
    let _ ← mk
    mkNodes k

/-- Fail with "nodes are not equal" unless `n₁` and `n₂` have the same root. -/
def checkEq (n₁ n₂ : Node) : M Unit := do
  let root₁ ← find n₁
  let root₂ ← find n₂
  if root₁ ≠ root₂ then
    error "nodes are not equal"

/--
Starting from node `n`, union each node `i` with `i + d`, moving upward one node
at a time. Stops once `i + d` falls outside the table or the fuel (first
argument) runs out.
-/
def mergePackAux : Nat → Nat → Nat → M Unit
  | 0, _, _ => pure ()
  | fuel+1, n, d => do
    unless n + d < (← capacity) do
      return
    union n (n + d)
    mergePackAux fuel (n + 1) d

/-- Union every node `i` with `i + d` across the whole table, starting at node 0. -/
def mergePack (d : Nat) : M Unit :=
  capacity >>= (mergePackAux · 0 d)

/--
Scan nodes upward from `n` (for at most the given fuel, stopping at the end of
the table). Adds to the accumulator one for every scanned node whose root
differs from itself.
-/
def numEqsAux : Nat → Node → Nat → M Nat
  | 0, _, acc => pure acc
  | fuel+1, n, acc => do
    if n ≥ (← capacity) then
      return acc
    let root ← find n
    numEqsAux fuel (n + 1) (if n = root then acc else acc + 1)

/-- Count the nodes in the table that are not their own root. -/
def numEqs : M Nat :=
  capacity >>= (numEqsAux · 0 0)

/--
Benchmark body. Fails when `n < 2`. Otherwise it creates `n` nodes, merges them
in packs at distances 50000, 10000, 5000 and 1000 (in that order), and returns
the number of non-root nodes.
-/
def test (n : Nat) : M Nat := do
  if n < 2 then
    error "input must be greater than 1"
  else
    mkNodes n
    for dist in [50000, 10000, 5000, 1000] do
      mergePack dist
    numEqs

/--
Benchmark entry point; the first command-line argument is the number of nodes.

Prints `ok <count>` and returns 0 on success. Prints `Error : <message>` and
returns 1 when the argument is missing or not a natural number, or when the
benchmark itself fails.
-/
def main (xs : List String) : IO UInt32 := do
  -- `head!`/`toNat!` panic here, and Lean continues after a panic with default
  -- values, which produced a misleading error; report the bad input instead.
  let some arg := xs.head?
    | IO.println "Error : missing node count argument" *> pure 1
  let some n := arg.toNat?
    | IO.println ("Error : not a natural number: " ++ arg) *> pure 1
  match run (test n) with
  | (Except.ok v, _) => IO.println ("ok " ++ toString v) *> pure 0
  | (Except.error e, _) => IO.println ("Error : " ++ e) *> pure 1