feat(library/equations_compiler): add support for partial definitions

This commit is contained in:
Leonardo de Moura 2019-03-27 11:09:32 -07:00
parent ef5fac1481
commit 9a071c18e7
10 changed files with 152 additions and 6 deletions

View file

@ -451,6 +451,9 @@ instance : HasOne Nat := ⟨Nat.succ (Nat.zero)⟩
instance : HasAdd Nat := ⟨Nat.add⟩
/- Auxiliary constant used by equation compiler. -/
constant hugeFuel : Nat := 10000
def std.priority.default : Nat := 1000
def std.priority.max : Nat := 0xFFFFFFFF

View file

@ -568,7 +568,8 @@ static environment elab_single_def(parser & p, decl_cmd_kind const & kind, cmd_m
environment new_env = env_n.first;
name c_real_name = env_n.second;
new_env = add_local_ref(p, new_env, c_name, c_real_name, lp_names, params);
if (!meta.m_modifiers.m_is_unsafe && (kind == decl_cmd_kind::Definition || kind == decl_cmd_kind::Instance)) {
if (!meta.m_modifiers.m_is_unsafe && !meta.m_modifiers.m_is_partial &&
(kind == decl_cmd_kind::Definition || kind == decl_cmd_kind::Instance)) {
new_env = mk_smart_unfolding_definition(new_env, p.get_options(), c_real_name);
}
/* Apply attributes last so that they may access any information on the new decl */

View file

@ -82,6 +82,7 @@ name const * g_heq_refl = nullptr;
name const * g_heq_symm = nullptr;
name const * g_heq_trans = nullptr;
name const * g_heq_of_eq = nullptr;
name const * g_huge_fuel = nullptr;
name const * g_id = nullptr;
name const * g_id_rhs = nullptr;
name const * g_id_delta = nullptr;
@ -93,6 +94,8 @@ name const * g_iff_refl = nullptr;
name const * g_iff_symm = nullptr;
name const * g_iff_trans = nullptr;
name const * g_iff_true_intro = nullptr;
name const * g_inhabited = nullptr;
name const * g_inhabited_default = nullptr;
name const * g_int = nullptr;
name const * g_int_nat_abs = nullptr;
name const * g_int_lt = nullptr;
@ -261,6 +264,7 @@ void initialize_constants() {
g_heq_symm = new name{"Heq", "symm"};
g_heq_trans = new name{"Heq", "trans"};
g_heq_of_eq = new name{"heqOfEq"};
g_huge_fuel = new name{"hugeFuel"};
g_id = new name{"id"};
g_id_rhs = new name{"idRhs"};
g_id_delta = new name{"idDelta"};
@ -272,6 +276,8 @@ void initialize_constants() {
g_iff_symm = new name{"Iff", "symm"};
g_iff_trans = new name{"Iff", "trans"};
g_iff_true_intro = new name{"iffTrueIntro"};
g_inhabited = new name{"Inhabited"};
g_inhabited_default = new name{"Inhabited", "default"};
g_int = new name{"Int"};
g_int_nat_abs = new name{"Int", "natAbs"};
g_int_lt = new name{"Int", "lt"};
@ -441,6 +447,7 @@ void finalize_constants() {
delete g_heq_symm;
delete g_heq_trans;
delete g_heq_of_eq;
delete g_huge_fuel;
delete g_id;
delete g_id_rhs;
delete g_id_delta;
@ -452,6 +459,8 @@ void finalize_constants() {
delete g_iff_symm;
delete g_iff_trans;
delete g_iff_true_intro;
delete g_inhabited;
delete g_inhabited_default;
delete g_int;
delete g_int_nat_abs;
delete g_int_lt;
@ -620,6 +629,7 @@ name const & get_heq_refl_name() { return *g_heq_refl; }
name const & get_heq_symm_name() { return *g_heq_symm; }
name const & get_heq_trans_name() { return *g_heq_trans; }
name const & get_heq_of_eq_name() { return *g_heq_of_eq; }
name const & get_huge_fuel_name() { return *g_huge_fuel; }
name const & get_id_name() { return *g_id; }
name const & get_id_rhs_name() { return *g_id_rhs; }
name const & get_id_delta_name() { return *g_id_delta; }
@ -631,6 +641,8 @@ name const & get_iff_refl_name() { return *g_iff_refl; }
name const & get_iff_symm_name() { return *g_iff_symm; }
name const & get_iff_trans_name() { return *g_iff_trans; }
name const & get_iff_true_intro_name() { return *g_iff_true_intro; }
name const & get_inhabited_name() { return *g_inhabited; }
name const & get_inhabited_default_name() { return *g_inhabited_default; }
name const & get_int_name() { return *g_int; }
name const & get_int_nat_abs_name() { return *g_int_nat_abs; }
name const & get_int_lt_name() { return *g_int_lt; }

View file

@ -84,6 +84,7 @@ name const & get_heq_refl_name();
name const & get_heq_symm_name();
name const & get_heq_trans_name();
name const & get_heq_of_eq_name();
name const & get_huge_fuel_name();
name const & get_id_name();
name const & get_id_rhs_name();
name const & get_id_delta_name();
@ -95,6 +96,8 @@ name const & get_iff_refl_name();
name const & get_iff_symm_name();
name const & get_iff_trans_name();
name const & get_iff_true_intro_name();
name const & get_inhabited_name();
name const & get_inhabited_default_name();
name const & get_int_name();
name const & get_int_nat_abs_name();
name const & get_int_lt_name();

View file

@ -77,6 +77,7 @@ Heq.refl
Heq.symm
Heq.trans
heqOfEq
hugeFuel
id
idRhs
idDelta
@ -88,6 +89,8 @@ Iff.refl
Iff.symm
Iff.trans
iffTrueIntro
Inhabited
Inhabited.default
Int
Int.natAbs
Int.lt

View file

@ -4,8 +4,14 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "kernel/instantiate.h"
#include "library/trace.h"
#include "library/locals.h"
#include "library/util.h"
#include "library/constants.h"
#include "library/type_context.h"
#include "library/equations_compiler/structural_rec.h"
#include "frontends/lean/elaborator.h"
namespace lean {
struct partial_rec_fn {
@ -22,14 +28,104 @@ struct partial_rec_fn {
m_env(env), m_elab(elab), m_mctx(mctx), m_lctx(lctx) {
}
options const & get_options() const {
return m_elab.get_options();
}
type_context_old mk_type_context(local_context const & lctx) {
return type_context_old(m_env, m_mctx, lctx, m_elab.get_cache(), transparency_mode::Semireducible);
}
type_context_old mk_type_context() {
return mk_type_context(m_lctx);
}
expr mk_base_case_eq(type_context_old & ctx, expr const & fn, unsigned arity, expr const & new_fn) {
expr fn_type = ctx.infer(fn);
expr result_type = fn_type;
type_context_old::tmp_locals args(ctx);
for (unsigned i = 0; i < arity; i++) {
result_type = ctx.relaxed_whnf(result_type);
expr arg = args.push_local_from_binding(result_type);
result_type = instantiate(binding_body(result_type), arg);
}
level result_level = get_level(ctx, result_type);
expr inhabited_result = mk_app(mk_constant(get_inhabited_name(), {result_level}), result_type);
optional<expr> inhabitant = ctx.mk_class_instance(inhabited_result);
if (!inhabitant) {
throw generic_exception(m_ref, "failed to compile partial definition, failed to synthesize result type inhabitant");
}
expr rhs = mk_app(mk_constant(get_inhabited_default_name(), {result_level}), result_type, *inhabitant);
expr zero = mk_constant(get_nat_zero_name());
expr lhs = mk_app(mk_app(new_fn, zero), args.as_buffer());
expr new_eq = mk_equation(lhs, rhs);
return ctx.mk_lambda(args.as_buffer(), new_eq);
}
void update_eqs(type_context_old & ctx, unpack_eqns & ues, expr const & fn, expr const & new_fn) {
unsigned arity = ues.get_arity_of(0);
buffer<expr> & eqns = ues.get_eqns_of(0);
buffer<expr> new_eqns;
/* Add base case */
new_eqns.push_back(mk_base_case_eq(ctx, fn, arity, new_fn));
/* Add (succ fuel) pattern to each equation */
for (expr const & eqn : eqns) {
type_context_old::tmp_locals locals(ctx);
unpack_eqn ue(ctx, eqn);
expr lhs = ue.lhs();
expr rhs = ue.rhs();
expr fuel = ue.add_var_front("fuel", mk_nat_type());
expr succ_fuel = mk_app(mk_constant(get_nat_succ_name()), fuel);
buffer<expr> lhs_args;
get_app_args(lhs, lhs_args);
expr new_lhs = mk_app(mk_app(new_fn, succ_fuel), lhs_args);
expr new_fn_fuel = mk_app(new_fn, fuel);
expr new_rhs = replace_local(rhs, fn, new_fn_fuel);
ue.lhs() = new_lhs;
ue.rhs() = new_rhs;
new_eqns.push_back(ue.repack());
}
eqns = new_eqns;
}
expr add_fuel_param(expr const & eqns) {
// TODO(Leo):
return eqns;
type_context_old ctx = mk_type_context();
unpack_eqns ues(ctx, eqns);
if (ues.get_num_fns() != 1) {
throw generic_exception(m_ref, "failed to compile partial definition, mutual recursion is not supported yet");
}
expr fn = ues.get_fn(0);
expr fn_type = ctx.infer(fn);
expr new_fn_type = mk_arrow(mk_nat_type(), fn_type);
expr new_fn = ues.update_fn_type(0, new_fn_type);
update_eqs(ctx, ues, fn, new_fn);
expr new_eqns = ues.repack();
m_mctx = ctx.mctx();
/* Add `_fueled` suffix */
equations_header header = get_equations_header(new_eqns);
equations_header new_header = header;
new_header.m_fn_names = names(name(head(header.m_fn_names), "_fueled"));
new_header.m_fn_actual_names = names(name(head(header.m_fn_actual_names), "_fueled"));
return update_equations(new_eqns, new_header);
}
list<expr> add_some_fuel(list<expr> const & fns) {
// TODO(Leo):
return fns;
type_context_old ctx = mk_type_context();
names fn_names = m_header.m_fn_names;
names fn_actual_names = m_header.m_fn_actual_names;
expr huge_fuel = mk_constant(get_huge_fuel_name());
return map(fns, [&](expr const & fueled_fn) {
name fn_name = head(fn_names);
name fn_actual_name = head(fn_actual_names);
fn_names = tail(fn_names);
fn_actual_names = tail(fn_actual_names);
expr fn_val = mk_app(fueled_fn, huge_fuel);
expr fn_type = ctx.infer(fn_val);
expr r;
std::tie(m_env, r) = mk_aux_definition(m_env, get_options(), m_mctx, m_lctx, m_header,
fn_name, fn_actual_name, fn_type, fn_val);
return r;
});
}
eqn_compiler_result operator()(expr const & eqns) {
@ -49,6 +145,10 @@ struct partial_rec_fn {
eqn_compiler_result partial_rec(environment & env, elaborator & elab,
metavar_context & mctx, local_context const & lctx,
expr const & eqns) {
return partial_rec_fn(env, elab, mctx, lctx)(eqns);
partial_rec_fn proc(env, elab, mctx, lctx);
auto r = proc(eqns);
env = proc.m_env;
mctx = proc.m_mctx;
return r;
}
}

View file

@ -170,6 +170,13 @@ expr unpack_eqn::add_var(name const & n, expr const & type) {
return m_vars.back();
}
expr unpack_eqn::add_var_front(name const & n, expr const & type) {
m_modified_vars = true;
expr x = m_locals.push_local(n, type);
m_vars.insert(0, x);
return x;
}
expr unpack_eqn::repack() {
if (!m_modified_vars &&
equation_lhs(m_nested_src) == m_lhs &&

View file

@ -65,6 +65,7 @@ class unpack_eqn {
public:
unpack_eqn(type_context_old & ctx, expr const & eqn);
expr add_var(name const & n, expr const & type);
expr add_var_front(name const & n, expr const & type);
buffer<expr> & get_vars() { return m_vars; }
expr & lhs() { return m_lhs; }
expr & rhs() { return m_rhs; }

View file

@ -0,0 +1,14 @@
set_option pp.implicit true
set_option pp.binder_types false
-- set_option trace.compiler.boxed true
partial def contains : String → Char → Nat → Bool
| s c i :=
if s.atEnd i then false
else if s.get i == c then true
else contains s c (s.next i)
def main : IO Unit :=
let s1 := "hello" in
IO.println (contains s1 'a' 0) *>
IO.println (contains s1 'o' 0)

View file

@ -0,0 +1,2 @@
false
true