diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index 6f63f52fc2..7140faa614 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -853,22 +853,31 @@ expr elaborator::visit_elim_app(expr const & fn, elim_info const & info, buffer< return r; } -optional elaborator::visit_app_with_expected(expr const & fn, buffer const & args, - expr const & expected_type, expr const & ref) { - snapshot C(*this); +struct elaborator::first_pass_info { buffer args_mvars; buffer args_expected_types; buffer new_args; /* new_args_size[i] contains size of new_args after args_mvars[i] was pushed. We need this information for producing error messages. */ buffer new_args_size; - expr fn_type = infer_type(fn); - expr type_before_whnf = fn_type; - expr type = whnf(fn_type); buffer new_instances; /* new_instances_size[i] contains the size of new_instances before (and after) args_mvars[i] was pushed. */ buffer new_instances_size; +}; + +/* Check if fn args resulting type matches the expected type, and fill + first_pass_info & info with information collected in this first pass. + Return true iff the types match. + + Remark: the arguments \c args are *not* visited in this first pass. + They are only used in this method to provide location information. */ +bool elaborator::first_pass(expr const & fn, buffer const & args, + expr const & expected_type, expr const & ref, + first_pass_info & info) { + expr fn_type = infer_type(fn); + expr type_before_whnf = fn_type; + expr type = whnf(fn_type); unsigned i = 0; /* First pass: compute type for an fn-application, and unify it with expected_type. We don't visit expelicit arguments at this point. */ @@ -882,7 +891,7 @@ optional elaborator::visit_app_with_expected(expr const & fn, buffer // implicit argument new_arg = mk_metavar(d, ref); if (bi.is_inst_implicit()) - new_instances.push_back(new_arg); + info.new_instances.push_back(new_arg); // implicit arguments are tagged as inaccessible in patterns if (m_in_pattern) new_arg = copy_tag(ref, mk_inaccessible(new_arg)); @@ -890,15 +899,15 @@ optional elaborator::visit_app_with_expected(expr const & fn, buffer // explicit argument expr const & arg_ref = args[i]; i++; - args_expected_types.push_back(d); + info.args_expected_types.push_back(d); new_arg = mk_metavar(d, arg_ref); - args_mvars.push_back(new_arg); - new_args_size.push_back(new_args.size()); - new_instances_size.push_back(new_instances.size()); + info.args_mvars.push_back(new_arg); + info.new_args_size.push_back(info.new_args.size()); + info.new_instances_size.push_back(info.new_instances.size()); } else { break; } - new_args.push_back(new_arg); + info.new_args.push_back(new_arg); /* See comment above at visit_base_app_core */ type_before_whnf = instantiate(binding_body(type), new_arg); type = whnf(type_before_whnf); @@ -906,60 +915,72 @@ optional elaborator::visit_app_with_expected(expr const & fn, buffer type = type_before_whnf; if (i != args.size()) { /* failed to consume all explicit arguments, use base elaboration for applications */ - C.restore(*this); - return none_expr(); + return false; } - if (!is_def_eq(expected_type, type)) { - /* failed to unify expected_type and computed type, use base elaboration for applications */ - C.restore(*this); - return none_expr(); - } - lean_assert(args_expected_types.size() == args.size()); - lean_assert(args_expected_types.size() == args_mvars.size()); - lean_assert(args_expected_types.size() == new_args_size.size()); - lean_assert(args_expected_types.size() == new_instances_size.size()); + lean_assert(args.size() == args_expected_types.size()); + lean_assert(args.size() == args_mvars.size()); + lean_assert(args.size() == new_args_size.size()); + lean_assert(args.size() == new_instances_size.size()); + return is_def_eq(expected_type, type); +} + +/* Using the information colllected in the first-pass, visit the arguments args. + And then, create resulting application */ +expr elaborator::second_pass(expr const & fn, buffer const & args, + expr const & ref, first_pass_info & info) { unsigned j = 0; /* for traversing new_instances */ /* Second pass: visit explicit arguments using the information we gained about their expected types */ for (unsigned i = 0; i < args.size(); i++) { /* Process type class instances upto args[i] */ - for (; j < new_instances_size[i]; j++) { - expr const & mvar = new_instances[j]; + for (; j < info.new_instances_size[i]; j++) { + expr const & mvar = info.new_instances[j]; if (!try_synthesize_type_class_instance(mvar)) m_instances = cons(mvar, m_instances); } expr ref_arg = args[i]; - expr new_arg = visit(args[i], some_expr(args_expected_types[i])); + expr new_arg = visit(args[i], some_expr(info.args_expected_types[i])); expr new_arg_type = infer_type(new_arg); - if (optional new_new_arg = ensure_has_type(new_arg, new_arg_type, args_expected_types[i], ref_arg)) { + if (optional new_new_arg = ensure_has_type(new_arg, new_arg_type, info.args_expected_types[i], ref_arg)) { new_arg = *new_new_arg; } else { - new_args.shrink(new_args_size[i]); - new_args.push_back(new_arg); - format msg = mk_app_type_mismatch_error(mk_app(fn, new_args.size(), new_args.data()), - new_arg, new_arg_type, args_expected_types[i]); + info.new_args.shrink(info.new_args_size[i]); + info.new_args.push_back(new_arg); + format msg = mk_app_type_mismatch_error(mk_app(fn, info.new_args.size(), info.new_args.data()), + new_arg, new_arg_type, info.args_expected_types[i]); throw elaborator_exception(ref, msg); } - if (!is_def_eq(args_mvars[i], new_arg)) { + if (!is_def_eq(info.args_mvars[i], new_arg)) { auto pp_fn = mk_pp_ctx(); throw elaborator_exception(ref_arg, format("invalid application, type mismatch") + pp_indent(pp_fn, new_arg) + line() + format("has type") + pp_indent(pp_fn, infer_type(new_arg)) + line() + format("failed to be unified with") + - pp_indent(pp_fn, args_mvars[i]) + + pp_indent(pp_fn, info.args_mvars[i]) + line() + format("has type") + - pp_indent(pp_fn, infer_type(args_mvars[i]))); + pp_indent(pp_fn, infer_type(info.args_mvars[i]))); } else { - new_args[new_args_size[i]] = new_arg; + info.new_args[info.new_args_size[i]] = new_arg; } } - for (; j < new_instances.size(); j++) { - expr const & mvar = new_instances[j]; + for (; j < info.new_instances.size(); j++) { + expr const & mvar = info.new_instances[j]; if (!try_synthesize_type_class_instance(mvar)) m_instances = cons(mvar, m_instances); } - return some_expr(mk_app(fn, new_args.size(), new_args.data())); + return mk_app(fn, info.new_args.size(), info.new_args.data()); +} + +optional elaborator::visit_app_with_expected(expr const & fn, buffer const & args, + expr const & expected_type, expr const & ref) { + snapshot C(*this); + first_pass_info info; + if (!first_pass(fn, args, expected_type, ref, info)) { + C.restore(*this); + return none_expr(); + } + return some_expr(second_pass(fn, args, ref, info)); } bool elaborator::is_with_expected_candidate(expr const & fn) { @@ -1069,7 +1090,7 @@ expr elaborator::visit_base_app_core(expr const & _fn, arg_mask amask, buffer const & args, - optional const & expected_type, expr const & ref) { + optional const & expected_type, expr const & ref) { return visit_base_app_core(fn, amask, args, false, expected_type, ref); } diff --git a/src/frontends/lean/elaborator.h b/src/frontends/lean/elaborator.h index 54dcba4c35..612c8a3df3 100644 --- a/src/frontends/lean/elaborator.h +++ b/src/frontends/lean/elaborator.h @@ -162,6 +162,9 @@ private: expr const & ref); bool is_with_expected_candidate(expr const & fn); + struct first_pass_info; + bool first_pass(expr const & fn, buffer const & args, expr const & expected_type, expr const & ref, first_pass_info & info); + expr second_pass(expr const & fn, buffer const & args, expr const & ref, first_pass_info & info); optional visit_app_with_expected(expr const & fn, buffer const & args, expr const & expected_type, expr const & ref); expr visit_base_app_core(expr const & fn, arg_mask amask, buffer const & args,