fix: panic in monadic polymorphic code

fixes #695
This commit is contained in:
Leonardo de Moura 2021-09-28 17:46:19 -07:00
parent d0462153a0
commit b85d95b7b6
7 changed files with 69 additions and 1 deletions

View file

@ -1141,7 +1141,19 @@ unsafe def unsafeCast {α : Type u} {β : Type v} (a : α) : β :=
cast lcProof (PUnit.{v})
@[neverExtract, extern "lean_panic_fn"]
constant panic {α : Type u} [Inhabited α] (msg : String) : α
constant panicCore {α : Type u} [Inhabited α] (msg : String) : α
/--
This is workaround for `panic` occurring in monadic code. See issue #695.
The `panicCore` definition cannot be specialized since it is an extern.
When `panic` occurs in monadic code, the `Inhabited α` parameter depends on a `[inst : Monad m]` instance.
The `inst` parameter will not be eliminated during specialization if it occurs inside of a binder (to avoid work duplication), and
will prevent the the actual monad from being "copied" to the code being specialized. When we reimplement the specializer, we
may consider copying `inst` if it also occurs outside binders or if it is an instance.
-/
@[noinline, neverExtract]
def panic {α : Type u} [Inhabited α] (msg : String) : α :=
panicCore msg
/-
The Compiler has special support for arrays.

View file

@ -0,0 +1,15 @@
def longArray (n : Nat := 50000) (xs : Array Char := #[]) : Array Char :=
match n with
| 0 => xs
| n+1 => longArray n (xs.push 'a')
def OverflowIte
{m : Type -> Type}
[inst1: Monad m]
(xs: Array Char) :
StateT Nat m Nat :=
xs.foldlM (fun (len : Nat) (s : Char) => if s = 'z' then panic "z" else return len + 1) 0
def main : IO Unit :=
let x := (StateT.run (@OverflowIte Id _ longArray) 0).fst
IO.println x

View file

@ -0,0 +1 @@
50000

View file

@ -0,0 +1,18 @@
def longArray (n : Nat := 50000) (xs : Array Char := #[]) : Array Char :=
match n with
| 0 => xs
| n+1 => longArray n (xs.push 'a')
def OverflowFold
{m : Type -> Type}
[inst1: Monad m]
(xs: Array Char) :
StateT Nat m Nat :=
xs.foldlM (fun (len : Nat) (s : Char) =>
match s with
| 'z' => panic "z"
| _ => return len + 1) 0
def main : IO Unit :=
let x := (StateT.run (@OverflowFold Id _ longArray) 0).fst
IO.println x

View file

@ -0,0 +1 @@
50000

View file

@ -0,0 +1,20 @@
def longArray (n : Nat := 50000) (xs : Array Char := #[]) : Array Char :=
match n with
| 0 => xs
| n+1 => longArray n (xs.push 'a')
def OverflowLoop
{m : Type -> Type}
[inst1: Monad m]
(xs: Array Char) :
StateT Nat m Nat := do
let mut out := 0
for c in xs do
match c with
| 'z' => panic "z"
| _ => out := out + 1
return out
def main : IO Unit :=
let x := (StateT.run (@OverflowLoop Id _ longArray) 0).fst
IO.println x

View file

@ -0,0 +1 @@
50000