From f1eaebba312ed64ebecd9bb3ccc22d4a18081c31 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 5 Aug 2019 13:19:24 -0700 Subject: [PATCH] fix(library/compiler/csimp): bug at `float_cases_on` The scope of the expr2ctor cache updates was incorrect. This bug affects code of the form ``` let x := C.cases_on y ...; K[x] ``` when we try to float the `cases_on` application, and the continuation `K[x]` contains another `cases_on` application with major `y`. The new test exposes the bug. This commit also fixes the case where the continuation `K[x]` projects `y`. Fixes #26 --- src/library/compiler/csimp.cpp | 17 +++++---- tests/compiler/float_cases_bug.lean | 36 +++++++++++++++++++ .../float_cases_bug.lean.expected.out | 1 + 3 files changed, 48 insertions(+), 6 deletions(-) create mode 100644 tests/compiler/float_cases_bug.lean create mode 100644 tests/compiler/float_cases_bug.lean.expected.out diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index b5312ea42d..2682087dc2 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -637,9 +637,11 @@ class csimp_fn { buffer zs; unsigned saved_fvars_size = m_fvars.size(); expr minor_val = get_minor_body(minor, zs); - flet save_expr2ctor(m_expr2ctor, m_expr2ctor); - update_expr2ctor(major, c_fn, c_args, i, zs); - minor_val = visit(minor_val, false); + { + flet save_expr2ctor(m_expr2ctor, m_expr2ctor); + update_expr2ctor(major, c_fn, c_args, i, zs); + minor_val = visit(minor_val, false); + } expr new_minor; if (is_join_point_app(minor_val)) { buffer jp_args; @@ -1497,9 +1499,12 @@ class csimp_fn { unsigned saved_fvars_size = m_fvars.size(); buffer zs; minor = get_minor_body(minor, zs); - flet save_expr2ctor(m_expr2ctor, m_expr2ctor); - update_expr2ctor(major, c, args, cidx, zs); - expr new_minor = visit(minor, false); + expr new_minor; + { + flet save_expr2ctor(m_expr2ctor, m_expr2ctor); + update_expr2ctor(major, c, args, cidx, zs); + new_minor = visit(minor, false); + } new_minor = mk_let(zs, saved_fvars_size, new_minor, false); expr result_minor = mk_minor_lambda(zs, new_minor); if (all_equal_opt) { diff --git a/tests/compiler/float_cases_bug.lean b/tests/compiler/float_cases_bug.lean new file mode 100644 index 0000000000..c9d3591f96 --- /dev/null +++ b/tests/compiler/float_cases_bug.lean @@ -0,0 +1,36 @@ +inductive Term : Type +| const : Nat -> Term +| app : List Term -> Term + +namespace Term +instance : Inhabited Term := ⟨Term.const 0⟩ +partial def hasToString : Term -> String | (const n) := "CONST(" ++ toString n ++ ")" | (app ts) := "APP" +instance : HasToString Term := ⟨hasToString⟩ +end Term + +open Term + +structure MyState : Type := (ts : List Term) +def emit (t : Term) : State MyState Unit := modify (λ ms => ⟨t::ms.ts⟩) + +partial def foo : MyState -> Term -> Term -> List Term +| ms₀ t u := + let stateT : State MyState Unit := do { + + match t with + | const _ => pure () + | app _ => emit (const 1) *> pure () ; + + match t, u with + | app _, app _ => emit (app []) *> pure () + | _, _ => pure () ; + + match t, u with + | app _, app _ => emit (app []) *> pure () + | _, _ => emit (const 2) *> pure () + + } ; + + (stateT.run ⟨[]⟩).2.ts.reverse + +def main : IO Unit := IO.println $ foo ⟨[]⟩ (app []) (app []) diff --git a/tests/compiler/float_cases_bug.lean.expected.out b/tests/compiler/float_cases_bug.lean.expected.out new file mode 100644 index 0000000000..d11d51fe37 --- /dev/null +++ b/tests/compiler/float_cases_bug.lean.expected.out @@ -0,0 +1 @@ +[CONST(1), APP, APP]