lean4-htt/old_tests/tests/lean/run/cont.lean
2018-04-10 12:56:55 -07:00

93 lines
3.2 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2018 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sebastian Ullrich
The continuation monad transformer.
-/
import init.category.alternative init.category.combinators init.category.lift
import system.io
universes u v w
/-- An implementation of [ContT](https://hackage.haskell.org/package/mtl-2.2.2/docs/Control-Monad-Cont.html#t:ContT).
A `cont_t ρ m α` wraps a single function `run`: given a continuation `α → m ρ` that consumes
the intermediate result, it produces the final answer `m ρ`. -/
structure cont_t (ρ : Type u) (m : Type u → Type v) (α : Type u) : Type (max u v) :=
(run : (α → m ρ) → m ρ)
-- Pretty-printing attribute for `cont_t` values (anonymous-constructor form, going by its name).
attribute [pp_using_anonymous_constructor] cont_t
/-- An implementation of [MonadCont done right](https://wiki.haskell.org/MonadCont_done_right).
A type constructor `m` in this class provides `call_cc`. The continuation handed to the
callback is polymorphic in its result type `β`, so it can be invoked at any type inside
the callback. -/
class monad_cont (m : Type u → Type v) :=
/- Call a function with the current continuation (cc) as its argument, which can be called to
exit the function from anywhere inside it. -/
(call_cc {α : Type u} : ((∀ {β}, α → m β) → m α) → m α)
-- Make `call_cc` usable without the `monad_cont.` prefix.
export monad_cont (call_cc)
/-- The plain continuation type: `cont_t` specialised to the identity type constructor `id`,
so a continuation returns a bare `ρ` instead of an `m ρ`. -/
@[reducible]
def cont (ρ α : Type u) : Type u :=
cont_t ρ id α
namespace cont_t
section
parameters {ρ : Type u} {m : Type u → Type v}

/-- Inject a value: hand `a` directly to the continuation. -/
protected def pure {α : Type u} (a : α) : cont_t ρ m α :=
  ⟨λ k, k a⟩

/-- Sequence two computations: run `ma` with a continuation that feeds its result to `f`
and then runs the resulting computation with the outer continuation. -/
protected def bind {α β : Type u} (ma : cont_t ρ m α) (f : α → cont_t ρ m β) : cont_t ρ m β :=
  ⟨λ k, ma.run (λ r, (f r).run k)⟩

-- Monad structure built from the two operations above.
instance : monad (cont_t ρ m) :=
  { pure := @pure, bind := @bind }

/-- Call `f` with an escape function. The escape function ignores whatever continuation it is
given and resumes the captured continuation `k` with the supplied value instead. -/
protected def call_cc {α : Type u} (f : (∀ {β}, α → cont_t ρ m β) → cont_t ρ m α) : cont_t ρ m α :=
  ⟨λ k, (f (λ _ r, ⟨λ _, k r⟩)).run k⟩

-- `cont_t ρ m` supports `call_cc` through the definition above.
instance : monad_cont (cont_t ρ m) :=
  ⟨@call_cc⟩

/-- Lift an action of the underlying monad `m`: run `x` and pass its result to the
continuation using `m`'s bind. -/
protected def lift [_root_.monad m] {α : Type u} (x : m α) : cont_t ρ m α :=
  ⟨λ k, x >>= k⟩

-- Register `lift` as the monad lift from `m` into `cont_t ρ m`.
instance [_root_.monad m] : has_monad_lift m (cont_t ρ m) :=
  ⟨@cont_t.lift _⟩

-- there is NO instance of `monad_functor` for `cont_t`
end
end cont_t
-- Basic lemmas about `cont_t.run`, used below to prove the monad laws.
namespace cont_t
variable {ρ : Type u}
variable {m : Type u → Type v}
variables {α β : Type u}
-- `x` is an explicit argument of the lemmas below that mention it (`run_bind`, `run_map`).
variables (x : cont_t ρ m α)
-- Extensionality: two `cont_t` values are equal when their `run` functions agree on every
-- continuation. Proved by case-splitting both values and simplifying with `funext h`.
lemma ext {x x' : cont_t ρ m α} (h : ∀ cc, x.run cc = x'.run cc) : x = x' :=
by cases x; cases x'; simp [show x = x', from funext h]
-- Running `pure a` applies the continuation to `a`; holds by `rfl`.
@[simp] lemma run_pure (a : α) (cc : α → m ρ) : (pure a : cont_t ρ m α).run cc = cc a := rfl
-- Running `x >>= f` runs `x` with a continuation that runs `f a` with `cc`; holds by `rfl`.
@[simp] lemma run_bind (f : α → cont_t ρ m β) (cc : β → m ρ) :
(x >>= f).run cc = x.run (λ a, (f a).run cc) := rfl
-- Running `f <$> x` runs `x` with the continuation pre-composed with `f`; holds by `rfl`.
@[simp] lemma run_map (f : α → β) (cc : β → m ρ) : (f <$> x).run cc = x.run (cc ∘ f) := rfl
end cont_t
-- `cont_t ρ m` satisfies the monad laws. Each law listed here is proved the same way: introduce
-- the variables, reduce equality of `cont_t` values to equality of their `run` functions with
-- `cont_t.ext`, then let `simp` finish the goal.
-- NOTE(review): `run_pure`, `run_bind` and `run_map` are stated without any instance on `m`,
-- so the `[monad m] [is_lawful_monad m]` assumptions may be unnecessary; confirm before relaxing.
instance (ρ : Type u) (m : Type u → Type v) [monad m] [is_lawful_monad m] : is_lawful_monad (cont_t ρ m) :=
{ id_map := by intros; apply cont_t.ext; simp,
pure_bind := by intros; apply cont_t.ext; simp,
bind_assoc := by intros; apply cont_t.ext; simp }
-- Count the odd numbers from 0 to 7 (1, 3, 5, 7) in a horrible, imperative way.
-- `break` and `continue` are continuations captured with `call_cc`:
--   * `continue ()` leaves the inner `call_cc`, skipping the rest of the current iteration;
--   * `break ()` leaves the outer `call_cc`, abandoning the whole traversal.
-- The counter lives in the `state_t ℕ` layer and is returned at the end via `get`.
def cont_example : cont_t unit (state_t ℕ io) ℕ :=
do call_cc $ λ break,
     (list.range 10).mmap' $ λ i,
       call_cc $ λ continue, do {
         -- even numbers are skipped, so they never reach the increment below
         when (i % 2 = 0) $
           continue (),
         -- the first odd number above 7 (i.e. 9) ends the loop
         when (i > 7) $
           break (),
         modify (+1) >> pure ()
       },
   get
-- Run the example with a final continuation that discards the result, starting from counter 0.
-- The pattern `((), 4)` on the left of `←` expects the final counter to be 4.
#eval do ((), 4) ← (cont_example.run (λ _, pure ())).run 0,
pure ()