lean4-htt/tests/playground/uf1_new.lean
Sebastian Ullrich 8a02dfec4f feat: subsume variables under variable
/cc @leodemoura
2021-01-22 14:36:05 +01:00

126 lines
4.3 KiB
Text
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

-- A state-passing computation over the monad `m`: it receives a `Unit`-tagged
-- state of type `σ` and produces, inside `m`, a result of type `α` paired with
-- the next state.
def StateT (m : Type → Type) (σ : Type) (α : Type) := (Unit × σ) → m (α × σ)
namespace StateT
variable {m : Type → Type} [Monad m] {σ : Type} {α β : Type}
-- Return `a` and pass the incoming state through unchanged.
@[inline] protected def pure (a : α) : StateT m σ α := λ ⟨_, s⟩, pure (a, s)
-- Run `x`, then feed its result and the state it produced to `f`.
@[inline] protected def bind (x : StateT m σ α) (f : α → StateT m σ β) : StateT m σ β := λ p, do (a, s') ← x p, f a ((), s')
-- Return the current state as the result, leaving it in place.
@[inline] def read : StateT m σ σ := λ ⟨_, s⟩, pure (s, s)
-- Discard the incoming state and install `s'` instead.
@[inline] def write (s' : σ) : StateT m σ Unit := λ ⟨_, s⟩, pure ((), s')
-- Replace the state by `f` applied to it.
-- The parameter type must be the function type `σ → σ`: the body applies `f`
-- to the current state and stores the result as the new state.
@[inline] def updt (f : σ → σ) : StateT m σ Unit := λ ⟨_, s⟩, pure ((), f s)
instance : Monad (StateT m σ) :=
{pure := @StateT.pure _ _ _, bind := @StateT.bind _ _ _}
end StateT
-- A computation in `m` whose result is either an error of type `ε` or a value
-- of type `α`.
def ExceptT (m : Type → Type) (ε : Type) (α : Type) := m (Except ε α)
namespace ExceptT
variable {m : Type → Type} [Monad m] {ε : Type} {α β : Type}

-- Wrap `a` as a successful result.
@[inline] protected def pure (a : α) : ExceptT m ε α :=
  (pure (Except.ok a) : m (Except ε α))

-- Run `x`; on success continue with `f`, on error re-wrap the error without
-- calling `f`.
@[inline] protected def bind (x : ExceptT m ε α) (f : α → ExceptT m ε β) : ExceptT m ε β :=
  (do { outcome ← x,
        match outcome with
        | Except.error err := pure (Except.error err)
        | Except.ok val := f val } : m (Except ε β))

-- Produce a failed result carrying `e`.
@[inline] def error (e : ε) : ExceptT m ε α :=
  (pure (Except.error e) : m (Except ε α))

-- Run an action of the underlying monad and mark its result as successful.
@[inline] def lift (x : m α) : ExceptT m ε α :=
  (do { val ← x, pure (Except.ok val) } : m (Except ε α))

instance : Monad (ExceptT m ε) :=
{ pure := @ExceptT.pure _ _ _,
  bind := @ExceptT.bind _ _ _ }
end ExceptT
-- Nodes are natural numbers; the functions below use them as indices into the
-- `ufData` array.
abbreviation Node := Nat

-- Per-node record: `find` is the node this entry points to, `rank` defaults to 0.
structure nodeData :=
(find : Node)
(rank : Nat := 0)

-- The whole union-find state: one `nodeData` per node.
abbreviation ufData := Array nodeData

-- The monad used throughout: `ufData` state with `String` errors.
abbreviation M (α : Type) := ExceptT (StateT id ufData) String α
-- Fetch the current `ufData` state inside `M`.
@[inline] def read : M ufData :=
  ExceptT.lift (StateT.read)
-- Replace the current `ufData` state with `s`.
@[inline] def write (s : ufData) : M Unit :=
  ExceptT.lift $ StateT.write s
-- Apply `f` to the current `ufData` state and store the result.
@[inline] def updt (f : ufData → ufData) : M Unit :=
  ExceptT.lift $ StateT.updt f
-- Abort the `M` computation with the message `e`.
def error {α : Type} (e : String) : M α :=
  ExceptT.error e
-- Execute `x` starting from state `s` (an empty array by default) and return
-- the success-or-error outcome together with the final state.
def run {α : Type} (x : M α) (s : ufData := Array.nil) : Except String α × ufData :=
  let start : Unit × ufData := ((), s) in
  x start
-- Number of entries currently stored in the state array.
def capacity : M Nat :=
  do { arr ← read,
       pure arr.size }
-- Fuel-bounded lookup of the entry reached from node `n` by following `find`
-- links. The first argument is the remaining fuel.
--   * fuel exhausted          → fails with "out of fuel"
--   * `n` is not a valid index → fails with "invalid Node"
--   * entry points to itself  → that entry is returned
--   * otherwise               → recurse on `entry.find`, then overwrite slot `n`
--                               with the entry that was found and return it
def findEntryAux : Nat → Node → M nodeData
| 0 n := error "out of fuel"
| (fuel+1) n :=
  do arr ← read,
     if inBounds : n < arr.size then
       do { let entry := arr.read ⟨n, inBounds⟩,
            if entry.find = n then pure entry
            else do found ← findEntryAux fuel entry.find,
                    -- store the result at `n` so later lookups are shorter
                    updt (λ cur, cur.write' n found),
                    pure found }
     else error "invalid Node"
-- Look up the entry for `n`, using the current array size as the fuel bound.
def findEntry (n : Node) : M nodeData :=
  do { fuel ← capacity,
       findEntryAux fuel n }
-- Return the `find` field of the entry that `findEntry` yields for `n`.
def find (n : Node) : M Node :=
  do { entry ← findEntry n,
       pure entry.find }
-- Append a fresh node that points to itself (rank 1) and return its index,
-- which is the array size before the push.
def mk : M Node :=
  do { fresh ← capacity,
       updt (λ arr, arr.push {find := fresh, rank := 1}),
       pure fresh }
-- Merge the sets containing `n₁` and `n₂`.
-- Both entries are resolved first; if they already share the same `find`
-- nothing happens. Otherwise the entry with the smaller rank is redirected to
-- the other one; on equal ranks the first is redirected to the second and the
-- second's rank is incremented.
def union (n₁ n₂ : Node) : M Unit :=
  do { e₁ ← findEntry n₁,
       e₂ ← findEntry n₂,
       if e₁.find = e₂.find then pure ()
       else updt (λ arr,
         if e₁.rank < e₂.rank then arr.write' e₁.find { find := e₂.find }
         else if e₁.rank = e₂.rank then
           let linked := arr.write' e₁.find { find := e₂.find } in
           linked.write' e₂.find { rank := e₂.rank + 1, .. e₂ }
         else arr.write' e₂.find { find := e₁.find }) }
-- Create the given number of fresh nodes by calling `mk` repeatedly.
def mkNodes : Nat → M Unit
| 0 := pure ()
| (remaining+1) := mk *> mkNodes remaining
-- Fail with "nodes are not equal" unless `find` gives the same result for
-- both nodes.
def checkEq (n₁ n₂ : Node) : M Unit :=
  do { root₁ ← find n₁,
       root₂ ← find n₂,
       unless (root₁ = root₂) $ error "nodes are not equal" }
-- Fuel-bounded loop: starting at node `lo`, union each node with the node
-- `gap` positions after it, advancing `lo` by one each step. Stops when the
-- fuel runs out or `lo + gap` is no longer below the current capacity.
def mergePackAux : Nat → Nat → Nat → M Unit
| 0 _ _ := pure ()
| (fuel+1) lo gap :=
  do { cap ← capacity,
       if lo + gap < cap
       then union lo (lo + gap) *> mergePackAux fuel (lo + 1) gap
       else pure () }
-- Run `mergePackAux` from node 0 with distance `d`, using the current
-- capacity as fuel.
def mergePack (d : Nat) : M Unit :=
  do { fuel ← capacity,
       mergePackAux fuel 0 d }
-- Fuel-bounded scan over nodes starting at `node`, carrying a running count.
-- The count is incremented for every node whose `find` result differs from the
-- node itself. Stops when fuel runs out or `node` reaches the capacity.
def numEqsAux : Nat → Node → Nat → M Nat
| 0 _ acc := pure acc
| (fuel+1) node acc :=
  do { cap ← capacity,
       if node < cap
       then do { root ← find node,
                 numEqsAux fuel (node + 1) (if node = root then acc else acc + 1) }
       else pure acc }
-- Count, over all nodes, how many are not their own `find` result.
def numEqs : M Nat :=
  do { fuel ← capacity,
       numEqsAux fuel 0 0 }
-- Benchmark driver: rejects `n < 2`, otherwise creates `n` nodes, merges them
-- at distances 50000, 10000, 5000 and 1000 in that order, and returns the
-- result of `numEqs`.
def test (n : Nat) : M Nat :=
  if n < 2 then error "input must be greater than 1"
  else do {
    mkNodes n,
    mergePack 50000,
    mergePack 10000,
    mergePack 5000,
    mergePack 1000,
    numEqs }
-- Entry point: parse the first command-line argument as the node count, run
-- `test` from the default (empty) state, print the outcome, and return exit
-- code 0 on success or 1 on error. The final state is ignored.
def main (xs : List String) : IO UInt32 :=
  match run (test (xs.head.toNat)) with
  | (Except.ok count, _) := IO.println ("ok " ++ toString count) *> pure 0
  | (Except.error msg, _) := IO.println ("Error : " ++ msg) *> pure 1