-- Hand-rolled state monad transformer: a computation takes the current state
-- `σ` and returns, inside `m`, its result paired with the new state.
def StateT' (m : Type → Type) (σ : Type) (α : Type) := σ → m (α × σ)

namespace StateT'
variables {m : Type → Type} [Monad m] {σ : Type} {α β : Type}

-- Return `a`, leaving the state unchanged.
@[inline] protected def pure (a : α) : StateT' m σ α :=
λ s, pure (a, s)

-- Run `x`, then pass its result and the updated state on to `f`.
@[inline] protected def bind (x : StateT' m σ α) (f : α → StateT' m σ β) : StateT' m σ β :=
λ s, do (a, s') ← x s, f a s'

-- Return the current state, leaving it unchanged.
@[inline] def read : StateT' m σ σ :=
λ s, pure (s, s)

-- Replace the state with `s'`.
@[inline] def write (s' : σ) : StateT' m σ Unit :=
λ s, pure ((), s')

-- Replace the state `s` with `f s`.
@[inline] def updt (f : σ → σ) : StateT' m σ Unit :=
λ s, pure ((), f s)

instance : Monad (StateT' m σ) :=
{pure := @StateT'.pure _ _ _, bind := @StateT'.bind _ _ _}
end StateT'

-- Hand-rolled exception monad transformer: a computation is an `m`-action
-- that yields either `Except.error e` or `Except.ok a`.
def ExceptT' (m : Type → Type) (ε : Type) (α : Type) := m (Except ε α)

namespace ExceptT'
variables {m : Type → Type} [Monad m] {ε : Type} {α β : Type}

-- Succeed with `a`.
@[inline] protected def pure (a : α) : ExceptT' m ε α :=
(pure (Except.ok a) : m (Except ε α))

-- Run `x`; an `Except.error e` is propagated without calling `f`,
-- an `Except.ok a` continues with `f a`.
@[inline] protected def bind (x : ExceptT' m ε α) (f : α → ExceptT' m ε β) : ExceptT' m ε β :=
(do { v ← x,
      match v with
      | Except.error e := pure (Except.error e)
      | Except.ok a := f a } : m (Except ε β))

-- Fail with error `e`.
@[inline] def error (e : ε) : ExceptT' m ε α :=
(pure (Except.error e) : m (Except ε α))

-- Run the inner action `x` and wrap its result as a success.
@[inline] def lift (x : m α) : ExceptT' m ε α :=
(do {a ← x, pure (Except.ok a) } : m (Except ε α))

instance : Monad (ExceptT' m ε) :=
{pure := @ExceptT'.pure _ _ _, bind := @ExceptT'.bind _ _ _}
end ExceptT'

-- Union-find nodes are identified by their index in the node array.
abbreviation Node := Nat

-- Per-node record. `find` is the parent link; a node is a root when `find`
-- equals its own index. `rank` is compared by `union` to decide which root
-- is linked under the other.
structure nodeData :=
(find : Node) (rank : Nat := 0)

abbreviation ufData := Array nodeData

-- Benchmark monad: `String` errors layered over state holding the node array.
abbreviation M (α : Type) := ExceptT' (StateT' Id ufData) String α

-- State and error primitives specialised to `M`.
@[inline] def read : M ufData := ExceptT'.lift StateT'.read
@[inline] def write (s : ufData) : M Unit := ExceptT'.lift (StateT'.write s)
@[inline] def updt (f : ufData → ufData) : M Unit := ExceptT'.lift (StateT'.updt f)
@[inline] def error {α : Type} (e : String) : M α := ExceptT'.error e

-- Run `x` starting from node array `s` (empty by default). Returns the
-- outcome together with the final node array.
def run {α : Type} (x : M α) (s : ufData := ∅) : Except String α × ufData :=
x s

-- Number of nodes created so far (size of the node array).
def capacity : M Nat :=
do d ← read, pure d.size

-- Follow `find` links from `n` to its root, spending one unit of fuel (first
-- argument) per node visited. Returns the root's entry. Each non-root node on
-- the path is overwritten with that root entry (path compression).
-- Fails with "invalid Node" if a visited index is out of range, or with
-- "out of fuel" if the fuel runs out first.
def findEntryAux : Nat → Node → M nodeData
| 0     n := error "out of fuel"
| (i+1) n := do
  s ← read,
  if h : n < s.size then do {
    let e := s.fget ⟨n, h⟩,
    if e.find = n then pure e
    else do
      e₁ ← findEntryAux i e.find,
      updt (λ s, s.set n e₁),
      pure e₁ }
  else error "invalid Node"

-- Root entry of the set containing `n`.
-- With `c` nodes, an acyclic path to a root visits at most `c` nodes, so `c`
-- units of fuel suffice for any valid `n`. The extra unit makes sure the
-- bounds check runs even when the array is empty (`c = 0`). As a result, an
-- out-of-range `n` always reports "invalid Node" rather than "out of fuel".
def findEntry (n : Node) : M nodeData :=
do c ← capacity, findEntryAux (c+1) n

-- Representative (root index) of the set containing `n`.
def find (n : Node) : M Node :=
do e ← findEntry n, pure e.find

-- Append a fresh singleton node (its own root, rank 1) and return its index.
def mk : M Node :=
do n ← capacity,
   updt $ λ s, s.push {find := n, rank := 1},
   pure n

-- Merge the sets containing `n₁` and `n₂`; a no-op if they share a root.
-- The lower-rank root is linked under the higher-rank one. On a tie, `n₁`'s
-- root is linked under `n₂`'s root and that root's rank is incremented.
def union (n₁ n₂ : Node) : M Unit :=
do r₁ ← findEntry n₁,
   r₂ ← findEntry n₂,
   if r₁.find = r₂.find then pure ()
   else updt $ λ s,
     if r₁.rank < r₂.rank then s.set r₁.find { find := r₂.find }
     else if r₁.rank = r₂.rank then
       let s₁ := s.set r₁.find { find := r₂.find } in
       s₁.set r₂.find { rank := r₂.rank + 1, .. r₂}
     else s.set r₂.find { find := r₁.find }

-- Create `n` fresh nodes.
def mkNodes : Nat → M Unit
| 0     := pure ()
| (n+1) := mk *> mkNodes n

-- Fail with "nodes are not equal" unless `n₁` and `n₂` have the same root.
def checkEq (n₁ n₂ : Node) : M Unit :=
do r₁ ← find n₁,
   r₂ ← find n₂,
   unless (r₁ = r₂) $ error "nodes are not equal"

-- Union node `n` with `n+d`, then `n+1` with `n+1+d`, and so on, while `n+d`
-- is below the current capacity. The first argument is fuel bounding the
-- number of steps.
def mergePackAux : Nat → Nat → Nat → M Unit
| 0     _ _ := pure ()
| (i+1) n d := do
  c ← capacity,
  if (n+d) < c
  then union n (n+d) *> mergePackAux i (n+1) d
  else pure ()

-- Union each node `k` with `k+d`, starting at `k = 0`, while `k+d` is in range.
def mergePack (d : Nat) : M Unit :=
do c ← capacity, mergePackAux c 0 d

-- Starting at index `n`, add to accumulator `r` one for every node that is
-- not its own root. The first argument is fuel. Calls `find`, so it also
-- path-compresses each visited node.
def numEqsAux : Nat → Node → Nat → M Nat
| 0     _ r := pure r
| (i+1) n r := do
  c ← capacity,
  if n < c then do {
    n₁ ← find n,
    numEqsAux i (n+1) (if n = n₁ then r else r+1) }
  else pure r

-- Number of nodes that are not their own root. Each successful `union` turns
-- exactly one root into a non-root.
def numEqs : M Nat :=
do c ← capacity, numEqsAux c 0 0

-- Benchmark driver. Creates `n` nodes, merges them at strides 50000, 10000,
-- 5000 and 1000, and returns the number of non-root nodes. Rejects `n < 2`.
def test (n : Nat) : M Nat :=
if n < 2 then error "input must be greater than 1"
else do
  mkNodes n,
  mergePack 50000,
  mergePack 10000,
  mergePack 5000,
  mergePack 1000,
  numEqs

-- Entry point. Parses the node count from the head of the argument list and
-- runs `test`. Prints "ok <count>" and returns 0 on success, or prints the
-- error and returns 1.
def main (xs : List String) : IO UInt32 :=
let n := xs.head.toNat in
match run (test n) with
| (Except.ok v, s) := IO.println ("ok " ++ toString v) *> pure 0
| (Except.error e, s) := IO.println ("Error : " ++ e) *> pure 1