diff --git a/src/library/equations_compiler/wf_rec.cpp b/src/library/equations_compiler/wf_rec.cpp index a5e724c6a3..6a9d6e9d9d 100644 --- a/src/library/equations_compiler/wf_rec.cpp +++ b/src/library/equations_compiler/wf_rec.cpp @@ -370,26 +370,22 @@ struct wf_rec_fn { expr arg = app_arg(e); unsigned num_fns = ues.get_num_fns(); expr result_fn; - unsigned fn_idx = 0; + unsigned fn_idx = 0; + /* Recall that if we have 4 mutually recursive functions, we encode them as + + f_1 a = _mutual (inl a) + f_2 b = _mutual (inr (inl b)) + f_3 c = _mutual (inr (inr (inl c))) + f_4 d = _mutual (inr (inr (inr c))) + */ if (num_fns > 1) { - if (is_app_of(arg, get_psum_inr_name())) { - for (unsigned i = 0; i < num_fns - 1; i++) { - lean_assert(is_app_of(arg, get_psum_inr_name())); - arg = app_arg(arg); - } - result_fn = result_fns.back(); - fn_idx = num_fns - 1; - } else { - lean_assert(is_app_of(arg, get_psum_inl_name())); + while (is_app_of(arg, get_psum_inr_name())) { + fn_idx++; + arg = app_arg(arg); + } + if (is_app_of(arg, get_psum_inl_name())) { arg = app_arg(arg); - while (is_app_of(arg, get_psum_inr_name())) { - fn_idx++; - arg = app_arg(arg); - } - lean_assert(fn_idx < num_fns); } - } else { - fn_idx = 0; } result_fn = result_fns[fn_idx]; unsigned arity = ues.get_arity_of(fn_idx); diff --git a/tests/lean/run/1782.lean b/tests/lean/run/1782.lean new file mode 100644 index 0000000000..b633a00d12 --- /dev/null +++ b/tests/lean/run/1782.lean @@ -0,0 +1,37 @@ +mutual inductive a, b, c +with a : Type +| foo : a +with b : Type +| bar : b +with c : Type +| baz : c + +mutual def f, g, h +with f : a → nat +| a.foo := 0 +with g : b → nat +| b.bar := 1 +with h : c → nat +| c.baz := 2 + +example : f a.foo = 0 := rfl +example : g b.bar = 1 := rfl +example : h c.baz = 2 := rfl + + +mutual def f_1, f_2, f_3, f_4 +with f_1 : a → nat +| a.foo := 0 +with f_2 : b → nat +| b.bar := 1 +with f_3 : c → nat +| c.baz := 2 +with f_4 : nat → nat +| 0 := 3 +| _ := 4 + +example : f_1 a.foo = 0 := rfl +example : f_2 b.bar = 1 := rfl +example : f_3 c.baz = 2 := rfl +example : f_4 0 = 3 := rfl +example : f_4 1 = 4 := rfl