fix: bug at unsafe umapMAux implementation

closes #125
This commit is contained in:
Leonardo de Moura 2020-03-14 13:27:20 -07:00
parent 2bb9755e1c
commit 2d7ec0b49c
3 changed files with 42 additions and 9 deletions

View file

@ -436,11 +436,14 @@ def std.prec.maxPlus : Nat := std.prec.max + 10
infixr `×` := Prod
-- notation for n-ary tuples
/- Some type that is not a scalar value in our runtime.
TODO: mark opaque -/
/- Some type that is not a scalar value in our runtime. -/
structure NonScalar :=
(val : Nat)
/- Some type that is not a scalar value in our runtime and is universe polymorphic. -/
inductive PNonScalar : Type u
| mk (v : Nat) : PNonScalar
/- For numeric literals notation -/
class HasOfNat (α : Type u) :=
(ofNat : Nat → α)
@ -1098,6 +1101,8 @@ instance : Inhabited Nat := ⟨0⟩
instance : Inhabited NonScalar := ⟨⟨arbitrary _⟩⟩
instance : Inhabited PNonScalar.{u} := ⟨⟨arbitrary _⟩⟩
instance : Inhabited PointedType := ⟨{type := PUnit, val := ⟨⟩}⟩
class inductive Nonempty (α : Sort u) : Prop

View file

@ -445,21 +445,21 @@ section
variables {m : Type u → Type w} [Monad m]
variable {β : Type u}
@[specialize] unsafe partial def umapMAux (f : Nat → α → m β) : Nat → Array α → m (Array β)
@[specialize] unsafe partial def umapMAux (f : Nat → α → m β) : Nat → Array (PNonScalar.{u}) → m (Array (PNonScalar.{u}))
| i, a =>
if h : i < a.size then
let idx : Fin a.size := ⟨i, h⟩;
let v : α := a.get idx;
let a := a.set idx (unsafeCast ());
do newV ← f i v; umapMAux (i+1) (a.set idx (unsafeCast newV))
let v : PNonScalar := a.get idx;
let a := a.set idx (arbitrary _);
do newV ← f i (unsafeCast v); umapMAux (i+1) (a.set idx (unsafeCast newV))
else
pure (unsafeCast a)
pure a
@[inline] unsafe partial def umapM (f : α → m β) (as : Array α) : m (Array β) :=
umapMAux (fun i a => f a) 0 as
@unsafeCast (m (Array PNonScalar.{u})) (m (Array β)) $ umapMAux (fun i a => f a) 0 (unsafeCast as)
@[inline] unsafe partial def umapIdxM (f : Nat → α → m β) (as : Array α) : m (Array β) :=
umapMAux f 0 as
@unsafeCast (m (Array PNonScalar.{u})) (m (Array β)) $ umapMAux f 0 (unsafeCast as)
@[implementedBy Array.umapM] def mapM (f : α → m β) (as : Array α) : m (Array β) :=
as.foldlM (fun bs a => do b ← f a; pure (bs.push b)) (mkEmpty as.size)

28
tests/lean/run/125.lean Normal file
View file

@ -0,0 +1,28 @@
class HasElems (α : Type) : Type := (elems : Array α)
def elems (α : Type) [HasElems α] : Array α := HasElems.elems α
inductive Foo : Type
| mk1 : Bool → Foo
| mk2 : Bool → Foo
open Foo
instance BoolElems : HasElems Bool := ⟨#[false, true]⟩
instance FooElems : HasElems Foo := ⟨(elems Bool).map mk1 ++ (elems Bool).map mk2⟩
def fooRepr (foo : Foo) :=
match foo with
| mk1 b => "OH " ++ toString b
| mk2 b => "DR " ++ toString b
instance : HasRepr Foo := ⟨fooRepr⟩
#eval elems Foo
#eval #[false, true].map Foo.mk1
def Foo.toBool : Foo → Bool
| Foo.mk1 b => b
| Foo.mk2 b => b
#eval #[Foo.mk1 false, Foo.mk2 true].map Foo.toBool