From 9a071c18e7b21183506113ec5750417143f94b92 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 27 Mar 2019 11:09:32 -0700 Subject: [PATCH] feat(library/equations_compiler): add support for `partial` definitions --- library/init/core.lean | 3 + src/frontends/lean/definition_cmds.cpp | 3 +- src/library/constants.cpp | 12 ++ src/library/constants.h | 3 + src/library/constants.txt | 3 + .../equations_compiler/partial_rec.cpp | 110 +++++++++++++++++- src/library/equations_compiler/util.cpp | 7 ++ src/library/equations_compiler/util.h | 1 + tests/compiler/partial.lean | 14 +++ tests/compiler/partial.lean.expected.out | 2 + 10 files changed, 152 insertions(+), 6 deletions(-) create mode 100644 tests/compiler/partial.lean create mode 100644 tests/compiler/partial.lean.expected.out diff --git a/library/init/core.lean b/library/init/core.lean index c077abe9b7..0cc29d6b89 100644 --- a/library/init/core.lean +++ b/library/init/core.lean @@ -451,6 +451,9 @@ instance : HasOne Nat := ⟨Nat.succ (Nat.zero)⟩ instance : HasAdd Nat := ⟨Nat.add⟩ +/- Auxiliary constant used by equation compiler. -/ +constant hugeFuel : Nat := 10000 + def std.priority.default : Nat := 1000 def std.priority.max : Nat := 0xFFFFFFFF diff --git a/src/frontends/lean/definition_cmds.cpp b/src/frontends/lean/definition_cmds.cpp index 5b3dd0e465..1789bfc145 100644 --- a/src/frontends/lean/definition_cmds.cpp +++ b/src/frontends/lean/definition_cmds.cpp @@ -568,7 +568,8 @@ static environment elab_single_def(parser & p, decl_cmd_kind const & kind, cmd_m environment new_env = env_n.first; name c_real_name = env_n.second; new_env = add_local_ref(p, new_env, c_name, c_real_name, lp_names, params); - if (!meta.m_modifiers.m_is_unsafe && (kind == decl_cmd_kind::Definition || kind == decl_cmd_kind::Instance)) { + if (!meta.m_modifiers.m_is_unsafe && !meta.m_modifiers.m_is_partial && + (kind == decl_cmd_kind::Definition || kind == decl_cmd_kind::Instance)) { new_env = mk_smart_unfolding_definition(new_env, p.get_options(), c_real_name); } /* Apply attributes last so that they may access any information on the new decl */ diff --git a/src/library/constants.cpp b/src/library/constants.cpp index b0182c3325..5992633f9e 100644 --- a/src/library/constants.cpp +++ b/src/library/constants.cpp @@ -82,6 +82,7 @@ name const * g_heq_refl = nullptr; name const * g_heq_symm = nullptr; name const * g_heq_trans = nullptr; name const * g_heq_of_eq = nullptr; +name const * g_huge_fuel = nullptr; name const * g_id = nullptr; name const * g_id_rhs = nullptr; name const * g_id_delta = nullptr; @@ -93,6 +94,8 @@ name const * g_iff_refl = nullptr; name const * g_iff_symm = nullptr; name const * g_iff_trans = nullptr; name const * g_iff_true_intro = nullptr; +name const * g_inhabited = nullptr; +name const * g_inhabited_default = nullptr; name const * g_int = nullptr; name const * g_int_nat_abs = nullptr; name const * g_int_lt = nullptr; @@ -261,6 +264,7 @@ void initialize_constants() { g_heq_symm = new name{"Heq", "symm"}; g_heq_trans = new name{"Heq", "trans"}; g_heq_of_eq = new name{"heqOfEq"}; + g_huge_fuel = new name{"hugeFuel"}; g_id = new name{"id"}; g_id_rhs = new name{"idRhs"}; g_id_delta = new name{"idDelta"}; @@ -272,6 +276,8 @@ void initialize_constants() { g_iff_symm = new name{"Iff", "symm"}; g_iff_trans = new name{"Iff", "trans"}; g_iff_true_intro = new name{"iffTrueIntro"}; + g_inhabited = new name{"Inhabited"}; + g_inhabited_default = new name{"Inhabited", "default"}; g_int = new name{"Int"}; g_int_nat_abs = new name{"Int", "natAbs"}; g_int_lt = new name{"Int", "lt"}; @@ -441,6 +447,7 @@ void finalize_constants() { delete g_heq_symm; delete g_heq_trans; delete g_heq_of_eq; + delete g_huge_fuel; delete g_id; delete g_id_rhs; delete g_id_delta; @@ -452,6 +459,8 @@ void finalize_constants() { delete g_iff_symm; delete g_iff_trans; delete g_iff_true_intro; + delete g_inhabited; + delete g_inhabited_default; delete g_int; delete g_int_nat_abs; delete g_int_lt; @@ -620,6 +629,7 @@ name const & get_heq_refl_name() { return *g_heq_refl; } name const & get_heq_symm_name() { return *g_heq_symm; } name const & get_heq_trans_name() { return *g_heq_trans; } name const & get_heq_of_eq_name() { return *g_heq_of_eq; } +name const & get_huge_fuel_name() { return *g_huge_fuel; } name const & get_id_name() { return *g_id; } name const & get_id_rhs_name() { return *g_id_rhs; } name const & get_id_delta_name() { return *g_id_delta; } @@ -631,6 +641,8 @@ name const & get_iff_refl_name() { return *g_iff_refl; } name const & get_iff_symm_name() { return *g_iff_symm; } name const & get_iff_trans_name() { return *g_iff_trans; } name const & get_iff_true_intro_name() { return *g_iff_true_intro; } +name const & get_inhabited_name() { return *g_inhabited; } +name const & get_inhabited_default_name() { return *g_inhabited_default; } name const & get_int_name() { return *g_int; } name const & get_int_nat_abs_name() { return *g_int_nat_abs; } name const & get_int_lt_name() { return *g_int_lt; } diff --git a/src/library/constants.h b/src/library/constants.h index 935727952c..832e7f39db 100644 --- a/src/library/constants.h +++ b/src/library/constants.h @@ -84,6 +84,7 @@ name const & get_heq_refl_name(); name const & get_heq_symm_name(); name const & get_heq_trans_name(); name const & get_heq_of_eq_name(); +name const & get_huge_fuel_name(); name const & get_id_name(); name const & get_id_rhs_name(); name const & get_id_delta_name(); @@ -95,6 +96,8 @@ name const & get_iff_refl_name(); name const & get_iff_symm_name(); name const & get_iff_trans_name(); name const & get_iff_true_intro_name(); +name const & get_inhabited_name(); +name const & get_inhabited_default_name(); name const & get_int_name(); name const & get_int_nat_abs_name(); name const & get_int_lt_name(); diff --git a/src/library/constants.txt b/src/library/constants.txt index c3dd72fa22..72bfbebb87 100644 --- a/src/library/constants.txt +++ b/src/library/constants.txt @@ -77,6 +77,7 @@ Heq.refl Heq.symm Heq.trans heqOfEq +hugeFuel id idRhs idDelta @@ -88,6 +89,8 @@ Iff.refl Iff.symm Iff.trans iffTrueIntro +Inhabited +Inhabited.default Int Int.natAbs Int.lt diff --git a/src/library/equations_compiler/partial_rec.cpp b/src/library/equations_compiler/partial_rec.cpp index 3cfa8a69e0..4a805e305f 100644 --- a/src/library/equations_compiler/partial_rec.cpp +++ b/src/library/equations_compiler/partial_rec.cpp @@ -4,8 +4,14 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include "kernel/instantiate.h" #include "library/trace.h" +#include "library/locals.h" +#include "library/util.h" +#include "library/constants.h" +#include "library/type_context.h" #include "library/equations_compiler/structural_rec.h" +#include "frontends/lean/elaborator.h" namespace lean { struct partial_rec_fn { @@ -22,14 +28,104 @@ struct partial_rec_fn { m_env(env), m_elab(elab), m_mctx(mctx), m_lctx(lctx) { } + options const & get_options() const { + return m_elab.get_options(); + } + + type_context_old mk_type_context(local_context const & lctx) { + return type_context_old(m_env, m_mctx, lctx, m_elab.get_cache(), transparency_mode::Semireducible); + } + + type_context_old mk_type_context() { + return mk_type_context(m_lctx); + } + + expr mk_base_case_eq(type_context_old & ctx, expr const & fn, unsigned arity, expr const & new_fn) { + expr fn_type = ctx.infer(fn); + expr result_type = fn_type; + type_context_old::tmp_locals args(ctx); + for (unsigned i = 0; i < arity; i++) { + result_type = ctx.relaxed_whnf(result_type); + expr arg = args.push_local_from_binding(result_type); + result_type = instantiate(binding_body(result_type), arg); + } + level result_level = get_level(ctx, result_type); + expr inhabited_result = mk_app(mk_constant(get_inhabited_name(), {result_level}), result_type); + optional inhabitant = ctx.mk_class_instance(inhabited_result); + if (!inhabitant) { + throw generic_exception(m_ref, "failed to compile partial definition, failed to synthesize result type inhabitant"); + } + expr rhs = mk_app(mk_constant(get_inhabited_default_name(), {result_level}), result_type, *inhabitant); + expr zero = mk_constant(get_nat_zero_name()); + expr lhs = mk_app(mk_app(new_fn, zero), args.as_buffer()); + expr new_eq = mk_equation(lhs, rhs); + return ctx.mk_lambda(args.as_buffer(), new_eq); + } + + void update_eqs(type_context_old & ctx, unpack_eqns & ues, expr const & fn, expr const & new_fn) { + unsigned arity = ues.get_arity_of(0); + buffer & eqns = ues.get_eqns_of(0); + buffer new_eqns; + /* Add base case */ + new_eqns.push_back(mk_base_case_eq(ctx, fn, arity, new_fn)); + /* Add (succ fuel) pattern to each equation */ + for (expr const & eqn : eqns) { + type_context_old::tmp_locals locals(ctx); + unpack_eqn ue(ctx, eqn); + expr lhs = ue.lhs(); + expr rhs = ue.rhs(); + expr fuel = ue.add_var_front("fuel", mk_nat_type()); + expr succ_fuel = mk_app(mk_constant(get_nat_succ_name()), fuel); + buffer lhs_args; + get_app_args(lhs, lhs_args); + expr new_lhs = mk_app(mk_app(new_fn, succ_fuel), lhs_args); + expr new_fn_fuel = mk_app(new_fn, fuel); + expr new_rhs = replace_local(rhs, fn, new_fn_fuel); + ue.lhs() = new_lhs; + ue.rhs() = new_rhs; + new_eqns.push_back(ue.repack()); + } + eqns = new_eqns; + } + expr add_fuel_param(expr const & eqns) { - // TODO(Leo): - return eqns; + type_context_old ctx = mk_type_context(); + unpack_eqns ues(ctx, eqns); + if (ues.get_num_fns() != 1) { + throw generic_exception(m_ref, "failed to compile partial definition, mutual recursion is not supported yet"); + } + expr fn = ues.get_fn(0); + expr fn_type = ctx.infer(fn); + expr new_fn_type = mk_arrow(mk_nat_type(), fn_type); + expr new_fn = ues.update_fn_type(0, new_fn_type); + update_eqs(ctx, ues, fn, new_fn); + expr new_eqns = ues.repack(); + m_mctx = ctx.mctx(); + /* Add `_fueled` suffix */ + equations_header header = get_equations_header(new_eqns); + equations_header new_header = header; + new_header.m_fn_names = names(name(head(header.m_fn_names), "_fueled")); + new_header.m_fn_actual_names = names(name(head(header.m_fn_actual_names), "_fueled")); + return update_equations(new_eqns, new_header); } list add_some_fuel(list const & fns) { - // TODO(Leo): - return fns; + type_context_old ctx = mk_type_context(); + names fn_names = m_header.m_fn_names; + names fn_actual_names = m_header.m_fn_actual_names; + expr huge_fuel = mk_constant(get_huge_fuel_name()); + return map(fns, [&](expr const & fueled_fn) { + name fn_name = head(fn_names); + name fn_actual_name = head(fn_actual_names); + fn_names = tail(fn_names); + fn_actual_names = tail(fn_actual_names); + expr fn_val = mk_app(fueled_fn, huge_fuel); + expr fn_type = ctx.infer(fn_val); + expr r; + std::tie(m_env, r) = mk_aux_definition(m_env, get_options(), m_mctx, m_lctx, m_header, + fn_name, fn_actual_name, fn_type, fn_val); + return r; + }); } eqn_compiler_result operator()(expr const & eqns) { @@ -49,6 +145,10 @@ struct partial_rec_fn { eqn_compiler_result partial_rec(environment & env, elaborator & elab, metavar_context & mctx, local_context const & lctx, expr const & eqns) { - return partial_rec_fn(env, elab, mctx, lctx)(eqns); + partial_rec_fn proc(env, elab, mctx, lctx); + auto r = proc(eqns); + env = proc.m_env; + mctx = proc.m_mctx; + return r; } } diff --git a/src/library/equations_compiler/util.cpp b/src/library/equations_compiler/util.cpp index 09de7fb61c..cf656701bf 100644 --- a/src/library/equations_compiler/util.cpp +++ b/src/library/equations_compiler/util.cpp @@ -170,6 +170,13 @@ expr unpack_eqn::add_var(name const & n, expr const & type) { return m_vars.back(); } +expr unpack_eqn::add_var_front(name const & n, expr const & type) { + m_modified_vars = true; + expr x = m_locals.push_local(n, type); + m_vars.insert(0, x); + return x; +} + expr unpack_eqn::repack() { if (!m_modified_vars && equation_lhs(m_nested_src) == m_lhs && diff --git a/src/library/equations_compiler/util.h b/src/library/equations_compiler/util.h index 9b349d2e4b..aa8b617001 100644 --- a/src/library/equations_compiler/util.h +++ b/src/library/equations_compiler/util.h @@ -65,6 +65,7 @@ class unpack_eqn { public: unpack_eqn(type_context_old & ctx, expr const & eqn); expr add_var(name const & n, expr const & type); + expr add_var_front(name const & n, expr const & type); buffer & get_vars() { return m_vars; } expr & lhs() { return m_lhs; } expr & rhs() { return m_rhs; } diff --git a/tests/compiler/partial.lean b/tests/compiler/partial.lean new file mode 100644 index 0000000000..9ccaa06017 --- /dev/null +++ b/tests/compiler/partial.lean @@ -0,0 +1,14 @@ +set_option pp.implicit true +set_option pp.binder_types false +-- set_option trace.compiler.boxed true + +partial def contains : String → Char → Nat → Bool +| s c i := + if s.atEnd i then false + else if s.get i == c then true + else contains s c (s.next i) + +def main : IO Unit := +let s1 := "hello" in +IO.println (contains s1 'a' 0) *> +IO.println (contains s1 'o' 0) diff --git a/tests/compiler/partial.lean.expected.out b/tests/compiler/partial.lean.expected.out new file mode 100644 index 0000000000..1d474d5255 --- /dev/null +++ b/tests/compiler/partial.lean.expected.out @@ -0,0 +1,2 @@ +false +true