From b85d95b7b66d8ae343dee35e7402019d99431d2b Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 28 Sep 2021 17:46:19 -0700 Subject: [PATCH] fix: `panic` in monadic polymorphic code fixes #695 --- src/Init/Prelude.lean | 14 +++++++++++++- tests/compiler/overflow1.lean | 15 +++++++++++++++ tests/compiler/overflow1.lean.expected.out | 1 + tests/compiler/overflow2.lean | 18 ++++++++++++++++++ tests/compiler/overflow2.lean.expected.out | 1 + tests/compiler/overflow3.lean | 20 ++++++++++++++++++++ tests/compiler/overflow3.lean.expected.out | 1 + 7 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 tests/compiler/overflow1.lean create mode 100644 tests/compiler/overflow1.lean.expected.out create mode 100644 tests/compiler/overflow2.lean create mode 100644 tests/compiler/overflow2.lean.expected.out create mode 100644 tests/compiler/overflow3.lean create mode 100644 tests/compiler/overflow3.lean.expected.out diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index 438b61e9d2..f53b6093a6 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -1141,7 +1141,19 @@ unsafe def unsafeCast {α : Type u} {β : Type v} (a : α) : β := cast lcProof (PUnit.{v}) @[neverExtract, extern "lean_panic_fn"] -constant panic {α : Type u} [Inhabited α] (msg : String) : α +constant panicCore {α : Type u} [Inhabited α] (msg : String) : α + +/-- + This is workaround for `panic` occurring in monadic code. See issue #695. + The `panicCore` definition cannot be specialized since it is an extern. + When `panic` occurs in monadic code, the `Inhabited α` parameter depends on a `[inst : Monad m]` instance. + The `inst` parameter will not be eliminated during specialization if it occurs inside of a binder (to avoid work duplication), and + will prevent the the actual monad from being "copied" to the code being specialized. When we reimplement the specializer, we + may consider copying `inst` if it also occurs outside binders or if it is an instance. +-/ +@[noinline, neverExtract] +def panic {α : Type u} [Inhabited α] (msg : String) : α := + panicCore msg /- The Compiler has special support for arrays. diff --git a/tests/compiler/overflow1.lean b/tests/compiler/overflow1.lean new file mode 100644 index 0000000000..22c2af12f3 --- /dev/null +++ b/tests/compiler/overflow1.lean @@ -0,0 +1,15 @@ +def longArray (n : Nat := 50000) (xs : Array Char := #[]) : Array Char := +match n with +| 0 => xs +| n+1 => longArray n (xs.push 'a') + +def OverflowIte + {m : Type -> Type} + [inst1: Monad m] + (xs: Array Char) : + StateT Nat m Nat := + xs.foldlM (fun (len : Nat) (s : Char) => if s = 'z' then panic "z" else return len + 1) 0 + +def main : IO Unit := + let x := (StateT.run (@OverflowIte Id _ longArray) 0).fst + IO.println x diff --git a/tests/compiler/overflow1.lean.expected.out b/tests/compiler/overflow1.lean.expected.out new file mode 100644 index 0000000000..ccfc37a15d --- /dev/null +++ b/tests/compiler/overflow1.lean.expected.out @@ -0,0 +1 @@ +50000 diff --git a/tests/compiler/overflow2.lean b/tests/compiler/overflow2.lean new file mode 100644 index 0000000000..b2cfb8046a --- /dev/null +++ b/tests/compiler/overflow2.lean @@ -0,0 +1,18 @@ +def longArray (n : Nat := 50000) (xs : Array Char := #[]) : Array Char := +match n with +| 0 => xs +| n+1 => longArray n (xs.push 'a') + +def OverflowFold + {m : Type -> Type} + [inst1: Monad m] + (xs: Array Char) : + StateT Nat m Nat := + xs.foldlM (fun (len : Nat) (s : Char) => + match s with + | 'z' => panic "z" + | _ => return len + 1) 0 + +def main : IO Unit := + let x := (StateT.run (@OverflowFold Id _ longArray) 0).fst + IO.println x diff --git a/tests/compiler/overflow2.lean.expected.out b/tests/compiler/overflow2.lean.expected.out new file mode 100644 index 0000000000..ccfc37a15d --- /dev/null +++ b/tests/compiler/overflow2.lean.expected.out @@ -0,0 +1 @@ +50000 diff --git a/tests/compiler/overflow3.lean b/tests/compiler/overflow3.lean new file mode 100644 index 0000000000..47cb10e42e --- /dev/null +++ b/tests/compiler/overflow3.lean @@ -0,0 +1,20 @@ +def longArray (n : Nat := 50000) (xs : Array Char := #[]) : Array Char := +match n with +| 0 => xs +| n+1 => longArray n (xs.push 'a') + +def OverflowLoop + {m : Type -> Type} + [inst1: Monad m] + (xs: Array Char) : + StateT Nat m Nat := do + let mut out := 0 + for c in xs do + match c with + | 'z' => panic "z" + | _ => out := out + 1 + return out + +def main : IO Unit := + let x := (StateT.run (@OverflowLoop Id _ longArray) 0).fst + IO.println x diff --git a/tests/compiler/overflow3.lean.expected.out b/tests/compiler/overflow3.lean.expected.out new file mode 100644 index 0000000000..ccfc37a15d --- /dev/null +++ b/tests/compiler/overflow3.lean.expected.out @@ -0,0 +1 @@ +50000