diff --git a/src/library/definitional/equations.cpp b/src/library/definitional/equations.cpp index eb24677c14..69299fbb2c 100644 --- a/src/library/definitional/equations.cpp +++ b/src/library/definitional/equations.cpp @@ -351,6 +351,7 @@ class equation_compiler_fn { } [[ noreturn ]] static void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); } + [[ noreturn ]] static void throw_error(sstream const & ss, expr const & src) { throw_generic_exception(ss, src); } [[ noreturn ]] static void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); } [[ noreturn ]] void throw_error(sstream const & ss) const { throw_generic_exception(ss, m_meta); } [[ noreturn ]] void throw_error(expr const & src, sstream const & ss) const { throw_generic_exception(ss, src); } @@ -493,19 +494,66 @@ class equation_compiler_fn { } } + // Store in \c arities the number of arguments of each function being defined. + // This procedure also makes sure that two different equations for the same function + // contain the same number of arguments in the left-hand-side. + // Remark: after executing this procedure the arity of m_fns[i] is stored in arities[i] + // if there is at least one equation for m_fns[i]. + void initialize_arities(expr const & eqns, buffer> & arities) { + lean_assert(arities.empty()); + buffer eqs; + to_equations(eqns, eqs); + lean_assert(!eqs.empty()); + arities.resize(m_fns.size()); + for (expr eq : eqs) { + if (is_lambda_equation(eq)) { + for (expr const & fn : m_fns) + eq = instantiate(binding_body(eq), fn); + while (is_lambda(eq)) + eq = binding_body(eq); + lean_assert(is_equation(eq)); + expr const & lhs = equation_lhs(eq); + buffer lhs_args; + expr const & lhs_fn = get_app_args(lhs, lhs_args); + if (!is_local(lhs_fn)) + throw_error(sstream() << "invalid recursive equation, " + << "left-hand-side is not one of the functions being defined", eq); + unsigned i = 0; + for (; i < m_fns.size(); i++) { + if (lhs_fn == m_fns[i]) { + if (arities[i] && *arities[i] != lhs_args.size()) + throw_error(sstream() << "invalid recursive equation for '" << lhs_fn << "' " + << "left-hand-side of different equations have different number of arguments", eq); + arities[i] = lhs_args.size(); + } + } + } + } + } + // Initialize the variable stack for each function that needs // to be compiled. // This method assumes m_fns has been already initialized. // This method also initialized the buffer prg, but the eqns // field of each program is not initialized by it. - void initialize_var_stack(buffer & prgs) { + // + // See initialize_arities for an explanation for \c arities. + void initialize_var_stack(buffer & prgs, buffer> const & arities) { lean_assert(!m_fns.empty()); lean_assert(prgs.empty()); - for (expr const & fn : m_fns) { + for (unsigned i = 0; i < m_fns.size(); i++) { + expr const & fn = m_fns[i]; buffer args; - expr r_type = to_telescope(mlocal_type(fn), args); + expr r_type = to_telescope(mlocal_type(fn), args); for (expr & arg : args) arg = update_mlocal(arg, whnf(mlocal_type(arg))); + if (arities[i]) { + unsigned arity = *arities[i]; + if (args.size() > arity) { + r_type = Pi(args.size() - arity, args.data() + arity, r_type); + args.shrink(arity); + } + } list ctx = to_list(args); list> vstack = map2>(ctx, [](expr const & e) { return optional(mlocal_name(e)); @@ -579,8 +627,10 @@ class equation_compiler_fn { // Create initial program state for each function being defined. void initialize(expr const & eqns, buffer & prg) { lean_assert(is_equations(eqns)); + buffer> arities; initialize_fns(eqns); - initialize_var_stack(prg); + initialize_arities(eqns, arities); + initialize_var_stack(prg, arities); buffer eqs; to_equations(eqns, eqs); buffer> res_eqns; diff --git a/tests/lean/run/match_fun.lean b/tests/lean/run/match_fun.lean new file mode 100644 index 0000000000..5e9b900630 --- /dev/null +++ b/tests/lean/run/match_fun.lean @@ -0,0 +1,22 @@ +open bool nat + +definition foo (b : bool) : nat → nat := +match b with +| tt := λ x : nat, zero +| ff := λ y : nat, (succ zero) +end + +example : foo tt 1 = zero := rfl +example : foo ff 1 = 1 := rfl + + +definition zero_fn := λ x : nat, zero + +definition foo2 : bool → nat → nat +| foo2 tt := succ +| foo2 ff := zero_fn + +example : foo2 tt 1 = 2 := rfl +example : foo2 tt 2 = 3 := rfl +example : foo2 ff 1 = 0 := rfl +example : foo2 ff 2 = 0 := rfl