test: add nondet example
This commit is contained in:
parent
4641542647
commit
2a5b18fde4
1 changed files with 238 additions and 0 deletions
238
tests/playground/nondet.lean
Normal file
238
tests/playground/nondet.lean
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
universes u v w
|
||||
|
||||
namespace Nondet
|
||||
|
||||
def AltsCore := PNonScalar.{u}
|
||||
|
||||
inductive Alts (m : Type u → Type u) (α : Type u)
|
||||
| nil : Alts m α
|
||||
| cons (head : α) (tail : m AltsCore) : Alts m α
|
||||
| delayed (t : Thunk (m AltsCore)) : Alts m α
|
||||
|
||||
namespace Alts
|
||||
|
||||
@[inline]
|
||||
unsafe def ofAltsCore {m : Type u → Type u} {α : Type u} (x : m AltsCore.{u}) : m (Alts m α) :=
|
||||
unsafeCast x
|
||||
|
||||
@[inline]
|
||||
unsafe def toAltsCore {m : Type u → Type u} {α : Type u} (x : m (Alts m α)) : m AltsCore.{u} :=
|
||||
unsafeCast x
|
||||
|
||||
mutual
|
||||
unsafe def run' {m α} [Monad m] : Alts m α → m (Option (α × m (Alts m α)))
|
||||
| nil => pure none
|
||||
| cons x xs => pure $ some (x, ofAltsCore xs)
|
||||
| delayed xs => runUnsafe (ofAltsCore xs.get)
|
||||
|
||||
unsafe def runUnsafe {m α} [Monad m] (x : m (Alts m α)) : m (Option (α × m (Alts m α))) := do
|
||||
run' (← x)
|
||||
end
|
||||
|
||||
unsafe def ofList {m α} [Monad m] : List α → m (Alts m α)
|
||||
| [] => pure Alts.nil
|
||||
| a::as => pure $ Alts.cons a (toAltsCore (ofList as))
|
||||
|
||||
mutual
|
||||
unsafe def append' {m α} [Monad m] : Alts m α → m (Alts m α) → m (Alts m α)
|
||||
| nil, bs => bs
|
||||
| cons a as, bs => pure $ delayed $ Thunk.mk fun _ => toAltsCore $ pure $ cons a $ toAltsCore $ append (ofAltsCore as) bs
|
||||
| delayed as, bs => pure $ delayed $ Thunk.mk fun _ => toAltsCore $ append (ofAltsCore as.get) bs
|
||||
|
||||
unsafe def append {m α} [Monad m] (xs ys : m (Alts m α)) : m (Alts m α) := do
|
||||
append' (← xs) ys
|
||||
end
|
||||
|
||||
mutual
|
||||
unsafe def join' {m α} [Monad m] : Alts m (m (Alts m α)) → m (Alts m α)
|
||||
| nil => pure nil
|
||||
| cons x xs => pure $ delayed $ Thunk.mk fun _ => toAltsCore $ append x $ join $ ofAltsCore xs
|
||||
| delayed xs => pure $ delayed $ Thunk.mk fun _ => toAltsCore (α := α) $ join (ofAltsCore xs.get)
|
||||
|
||||
unsafe def join {m α} [Monad m] (xs : m (Alts m (m (Alts m α)))) : m (Alts m α) := do
|
||||
join' (← xs)
|
||||
end
|
||||
|
||||
mutual
|
||||
unsafe def map' {m α β} [Monad m] (f : α → β) : Alts m α → m (Alts m β)
|
||||
| nil => pure nil
|
||||
| cons x xs => pure $ delayed $ Thunk.mk fun _ => toAltsCore $ pure $ cons (f x) $ toAltsCore $ map f (ofAltsCore xs)
|
||||
| delayed xs => pure $ delayed $ Thunk.mk fun _ => toAltsCore $ map f (ofAltsCore xs.get)
|
||||
|
||||
unsafe def map {m α β} [Monad m] (f : α → β) (xs : m (Alts m α)) : m (Alts m β) := do
|
||||
map' f (← xs)
|
||||
end
|
||||
|
||||
@[inline]
|
||||
protected unsafe def pure {m α} [Monad m] (a : α) : m (Alts m α) := do
|
||||
pure $ Alts.cons a $ toAltsCore (α := α) Alts.nil
|
||||
|
||||
end Alts
|
||||
|
||||
def NondetT (m : Type u → Type u) (α : Type u) := m (Alts m α)
|
||||
|
||||
instance {m α} [Pure m] : Inhabited (NondetT m α) :=
|
||||
⟨(pure Alts.nil : m (Alts m α))⟩
|
||||
|
||||
@[inline]
|
||||
unsafe def appendUnsafe {m α} [Monad m] (x y : NondetT m α) : NondetT m α :=
|
||||
Alts.append x y
|
||||
|
||||
@[implementedBy appendUnsafe]
|
||||
protected constant append {m α} [Monad m] (x y : NondetT m α) : NondetT m α
|
||||
|
||||
@[inline]
|
||||
unsafe def joinUnsafe {m α} [Monad m] (x : NondetT m (NondetT m α)) : NondetT m α :=
|
||||
Alts.join x
|
||||
|
||||
@[implementedBy joinUnsafe]
|
||||
protected constant join {m α} [Monad m] (x : NondetT m (NondetT m α)) : NondetT m α
|
||||
|
||||
@[inline]
|
||||
unsafe def mapUnsafe {m α β} [Monad m] (f : α → β) (x : NondetT m α) : NondetT m β :=
|
||||
Alts.map f x
|
||||
|
||||
@[implementedBy mapUnsafe]
|
||||
protected constant map {m α β} [Monad m] (f : α → β) (x : NondetT m α) : NondetT m β
|
||||
|
||||
@[inline]
|
||||
unsafe def pureUnsafe {m α} [Monad m] (a : α) : NondetT m α :=
|
||||
Alts.pure a
|
||||
|
||||
@[implementedBy pureUnsafe]
|
||||
protected constant pure {m α} [Monad m] (a : α) : NondetT m α
|
||||
|
||||
@[inline]
|
||||
protected def bind {m} [Monad m] {α β} (x : NondetT m α) (f : α → NondetT m β) : NondetT m β :=
|
||||
Nondet.join (Nondet.map f x)
|
||||
|
||||
@[inline]
|
||||
protected def failure {m α} [Monad m] : NondetT m α := id (α := m (Alts m α)) $
|
||||
pure Alts.nil
|
||||
|
||||
instance {m} [Monad m] : Monad (NondetT m) := {
|
||||
bind := Nondet.bind,
|
||||
pure := Nondet.pure,
|
||||
map := Nondet.map
|
||||
}
|
||||
|
||||
instance {m} [Monad m] : Alternative (NondetT m) := {
|
||||
failure := Nondet.failure,
|
||||
orelse := Nondet.append
|
||||
}
|
||||
|
||||
@[inline]
|
||||
unsafe def runUnsafe {m α} [Monad m] (x : NondetT m α) : m (Option (α × NondetT m α)) :=
|
||||
Alts.runUnsafe x
|
||||
|
||||
@[implementedBy runUnsafe]
|
||||
protected constant NondetT.run {m α} [Monad m] (x : NondetT m α) : m (Option (α × NondetT m α))
|
||||
|
||||
protected def NondetT.run' {m α} [Monad m] (x : NondetT m α) : m (Option α) := do
|
||||
match (← x.run) with
|
||||
| some (a, _) => pure (some a)
|
||||
| none => pure none
|
||||
|
||||
@[inline]
|
||||
unsafe def chooseUnsafe {m α} [Monad m] (as : List α) : NondetT m α :=
|
||||
Alts.ofList as
|
||||
|
||||
@[implementedBy chooseUnsafe]
|
||||
constant chooseImp {m α} [Monad m] (as : List α) : NondetT m α
|
||||
|
||||
class Choose (m : Type u → Type u) :=
|
||||
(choose : {α : Type u} → List α → m α)
|
||||
|
||||
export Choose (choose)
|
||||
|
||||
instance {m} [Monad m] : Choose (NondetT m) := {
|
||||
choose := chooseImp
|
||||
}
|
||||
|
||||
instance {m n} [Choose m] [MonadLift m n] : Choose n := {
|
||||
choose := fun as => liftM (m := m) $ choose as
|
||||
}
|
||||
|
||||
@[inline]
|
||||
unsafe def liftUnsafe {m α} [Monad m] (x : m α) : NondetT m α := id (α := m (Alts m α)) do
|
||||
let a ← x
|
||||
Alts.pure a
|
||||
|
||||
@[implementedBy liftUnsafe]
|
||||
protected constant lift {m α} [Monad m] (x : m α) : NondetT m α
|
||||
|
||||
instance {m} [Monad m] : MonadLift m (NondetT m) := ⟨Nondet.lift⟩
|
||||
|
||||
end Nondet
|
||||
|
||||
export Nondet (NondetT)
|
||||
|
||||
namespace ex1
|
||||
-- The state is not backtrackable
|
||||
abbrev M := NondetT $ StateT Nat $ IO
|
||||
|
||||
def checkVal (x : Nat) : M Nat := do
|
||||
IO.println s!"x: {x}"
|
||||
guard (x % 2 == 0)
|
||||
pure x
|
||||
|
||||
def f : M Nat := do
|
||||
let x ← Nondet.choose [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
checkVal x
|
||||
|
||||
def h : M Nat := do
|
||||
let a ← f
|
||||
modify (· + 1)
|
||||
guard (a % 3 == 0)
|
||||
pure (a+1)
|
||||
|
||||
def g : IO Unit := do
|
||||
let (some (x, _), s) ← h.run.run 0 | throw $ IO.userError "failed"
|
||||
IO.println s!"x: {x}, s: {s}"
|
||||
|
||||
#eval g
|
||||
end ex1
|
||||
|
||||
namespace ex2
|
||||
-- The state is backtrackable
|
||||
abbrev M := StateT Nat $ NondetT $ IO
|
||||
|
||||
def checkVal (x : Nat) : M Nat := do
|
||||
IO.println s!"x: {x}"
|
||||
guard (x % 2 == 0)
|
||||
pure x
|
||||
|
||||
def f : M Nat := do
|
||||
let x ← Nondet.choose [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
checkVal x
|
||||
|
||||
def h : M Nat := do
|
||||
let a ← f
|
||||
modify (· + 1)
|
||||
guard (a % 3 == 0)
|
||||
pure (a+1)
|
||||
|
||||
def g : IO Unit := do
|
||||
let some ((x, s), _) ← (h.run 0).run | throw $ IO.userError "failed"
|
||||
IO.println s!"x: {x}, s: {s}"
|
||||
|
||||
#eval g
|
||||
end ex2
|
||||
|
||||
namespace ex3
|
||||
abbrev M := NondetT IO
|
||||
|
||||
def a : M Unit := IO.println "a"
|
||||
def b : M Unit := IO.println "b"
|
||||
def c : M Unit := IO.println "c"
|
||||
|
||||
def t1 : M Unit :=
|
||||
((a <|> a) *> b) *> c
|
||||
|
||||
def t2 : M Unit :=
|
||||
(a <|> a) *> (b *> c)
|
||||
|
||||
#eval t1.run'
|
||||
#eval t2.run'
|
||||
|
||||
end ex3
|
||||
Loading…
Add table
Reference in a new issue