diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index b5312ea42d..2682087dc2 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -637,9 +637,11 @@ class csimp_fn { buffer zs; unsigned saved_fvars_size = m_fvars.size(); expr minor_val = get_minor_body(minor, zs); - flet save_expr2ctor(m_expr2ctor, m_expr2ctor); - update_expr2ctor(major, c_fn, c_args, i, zs); - minor_val = visit(minor_val, false); + { + flet save_expr2ctor(m_expr2ctor, m_expr2ctor); + update_expr2ctor(major, c_fn, c_args, i, zs); + minor_val = visit(minor_val, false); + } expr new_minor; if (is_join_point_app(minor_val)) { buffer jp_args; @@ -1497,9 +1499,12 @@ class csimp_fn { unsigned saved_fvars_size = m_fvars.size(); buffer zs; minor = get_minor_body(minor, zs); - flet save_expr2ctor(m_expr2ctor, m_expr2ctor); - update_expr2ctor(major, c, args, cidx, zs); - expr new_minor = visit(minor, false); + expr new_minor; + { + flet save_expr2ctor(m_expr2ctor, m_expr2ctor); + update_expr2ctor(major, c, args, cidx, zs); + new_minor = visit(minor, false); + } new_minor = mk_let(zs, saved_fvars_size, new_minor, false); expr result_minor = mk_minor_lambda(zs, new_minor); if (all_equal_opt) { diff --git a/tests/compiler/float_cases_bug.lean b/tests/compiler/float_cases_bug.lean new file mode 100644 index 0000000000..c9d3591f96 --- /dev/null +++ b/tests/compiler/float_cases_bug.lean @@ -0,0 +1,36 @@ +inductive Term : Type +| const : Nat -> Term +| app : List Term -> Term + +namespace Term +instance : Inhabited Term := ⟨Term.const 0⟩ +partial def hasToString : Term -> String | (const n) := "CONST(" ++ toString n ++ ")" | (app ts) := "APP" +instance : HasToString Term := ⟨hasToString⟩ +end Term + +open Term + +structure MyState : Type := (ts : List Term) +def emit (t : Term) : State MyState Unit := modify (λ ms => ⟨t::ms.ts⟩) + +partial def foo : MyState -> Term -> Term -> List Term +| ms₀ t u := + let stateT : State MyState Unit := do { + + match t with + | const _ => pure () + | app _ => emit (const 1) *> pure () ; + + match t, u with + | app _, app _ => emit (app []) *> pure () + | _, _ => pure () ; + + match t, u with + | app _, app _ => emit (app []) *> pure () + | _, _ => emit (const 2) *> pure () + + } ; + + (stateT.run ⟨[]⟩).2.ts.reverse + +def main : IO Unit := IO.println $ foo ⟨[]⟩ (app []) (app []) diff --git a/tests/compiler/float_cases_bug.lean.expected.out b/tests/compiler/float_cases_bug.lean.expected.out new file mode 100644 index 0000000000..d11d51fe37 --- /dev/null +++ b/tests/compiler/float_cases_bug.lean.expected.out @@ -0,0 +1 @@ +[CONST(1), APP, APP]