feat(library/equations_compiler): add support for partial definitions
This commit is contained in:
parent
ef5fac1481
commit
9a071c18e7
10 changed files with 152 additions and 6 deletions
|
|
@ -451,6 +451,9 @@ instance : HasOne Nat := ⟨Nat.succ (Nat.zero)⟩
|
|||
|
||||
instance : HasAdd Nat := ⟨Nat.add⟩
|
||||
|
||||
/- Auxiliary constant used by equation compiler. -/
|
||||
constant hugeFuel : Nat := 10000
|
||||
|
||||
def std.priority.default : Nat := 1000
|
||||
def std.priority.max : Nat := 0xFFFFFFFF
|
||||
|
||||
|
|
|
|||
|
|
@ -568,7 +568,8 @@ static environment elab_single_def(parser & p, decl_cmd_kind const & kind, cmd_m
|
|||
environment new_env = env_n.first;
|
||||
name c_real_name = env_n.second;
|
||||
new_env = add_local_ref(p, new_env, c_name, c_real_name, lp_names, params);
|
||||
if (!meta.m_modifiers.m_is_unsafe && (kind == decl_cmd_kind::Definition || kind == decl_cmd_kind::Instance)) {
|
||||
if (!meta.m_modifiers.m_is_unsafe && !meta.m_modifiers.m_is_partial &&
|
||||
(kind == decl_cmd_kind::Definition || kind == decl_cmd_kind::Instance)) {
|
||||
new_env = mk_smart_unfolding_definition(new_env, p.get_options(), c_real_name);
|
||||
}
|
||||
/* Apply attributes last so that they may access any information on the new decl */
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ name const * g_heq_refl = nullptr;
|
|||
name const * g_heq_symm = nullptr;
|
||||
name const * g_heq_trans = nullptr;
|
||||
name const * g_heq_of_eq = nullptr;
|
||||
name const * g_huge_fuel = nullptr;
|
||||
name const * g_id = nullptr;
|
||||
name const * g_id_rhs = nullptr;
|
||||
name const * g_id_delta = nullptr;
|
||||
|
|
@ -93,6 +94,8 @@ name const * g_iff_refl = nullptr;
|
|||
name const * g_iff_symm = nullptr;
|
||||
name const * g_iff_trans = nullptr;
|
||||
name const * g_iff_true_intro = nullptr;
|
||||
name const * g_inhabited = nullptr;
|
||||
name const * g_inhabited_default = nullptr;
|
||||
name const * g_int = nullptr;
|
||||
name const * g_int_nat_abs = nullptr;
|
||||
name const * g_int_lt = nullptr;
|
||||
|
|
@ -261,6 +264,7 @@ void initialize_constants() {
|
|||
g_heq_symm = new name{"Heq", "symm"};
|
||||
g_heq_trans = new name{"Heq", "trans"};
|
||||
g_heq_of_eq = new name{"heqOfEq"};
|
||||
g_huge_fuel = new name{"hugeFuel"};
|
||||
g_id = new name{"id"};
|
||||
g_id_rhs = new name{"idRhs"};
|
||||
g_id_delta = new name{"idDelta"};
|
||||
|
|
@ -272,6 +276,8 @@ void initialize_constants() {
|
|||
g_iff_symm = new name{"Iff", "symm"};
|
||||
g_iff_trans = new name{"Iff", "trans"};
|
||||
g_iff_true_intro = new name{"iffTrueIntro"};
|
||||
g_inhabited = new name{"Inhabited"};
|
||||
g_inhabited_default = new name{"Inhabited", "default"};
|
||||
g_int = new name{"Int"};
|
||||
g_int_nat_abs = new name{"Int", "natAbs"};
|
||||
g_int_lt = new name{"Int", "lt"};
|
||||
|
|
@ -441,6 +447,7 @@ void finalize_constants() {
|
|||
delete g_heq_symm;
|
||||
delete g_heq_trans;
|
||||
delete g_heq_of_eq;
|
||||
delete g_huge_fuel;
|
||||
delete g_id;
|
||||
delete g_id_rhs;
|
||||
delete g_id_delta;
|
||||
|
|
@ -452,6 +459,8 @@ void finalize_constants() {
|
|||
delete g_iff_symm;
|
||||
delete g_iff_trans;
|
||||
delete g_iff_true_intro;
|
||||
delete g_inhabited;
|
||||
delete g_inhabited_default;
|
||||
delete g_int;
|
||||
delete g_int_nat_abs;
|
||||
delete g_int_lt;
|
||||
|
|
@ -620,6 +629,7 @@ name const & get_heq_refl_name() { return *g_heq_refl; }
|
|||
name const & get_heq_symm_name() { return *g_heq_symm; }
|
||||
name const & get_heq_trans_name() { return *g_heq_trans; }
|
||||
name const & get_heq_of_eq_name() { return *g_heq_of_eq; }
|
||||
name const & get_huge_fuel_name() { return *g_huge_fuel; }
|
||||
name const & get_id_name() { return *g_id; }
|
||||
name const & get_id_rhs_name() { return *g_id_rhs; }
|
||||
name const & get_id_delta_name() { return *g_id_delta; }
|
||||
|
|
@ -631,6 +641,8 @@ name const & get_iff_refl_name() { return *g_iff_refl; }
|
|||
name const & get_iff_symm_name() { return *g_iff_symm; }
|
||||
name const & get_iff_trans_name() { return *g_iff_trans; }
|
||||
name const & get_iff_true_intro_name() { return *g_iff_true_intro; }
|
||||
name const & get_inhabited_name() { return *g_inhabited; }
|
||||
name const & get_inhabited_default_name() { return *g_inhabited_default; }
|
||||
name const & get_int_name() { return *g_int; }
|
||||
name const & get_int_nat_abs_name() { return *g_int_nat_abs; }
|
||||
name const & get_int_lt_name() { return *g_int_lt; }
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ name const & get_heq_refl_name();
|
|||
name const & get_heq_symm_name();
|
||||
name const & get_heq_trans_name();
|
||||
name const & get_heq_of_eq_name();
|
||||
name const & get_huge_fuel_name();
|
||||
name const & get_id_name();
|
||||
name const & get_id_rhs_name();
|
||||
name const & get_id_delta_name();
|
||||
|
|
@ -95,6 +96,8 @@ name const & get_iff_refl_name();
|
|||
name const & get_iff_symm_name();
|
||||
name const & get_iff_trans_name();
|
||||
name const & get_iff_true_intro_name();
|
||||
name const & get_inhabited_name();
|
||||
name const & get_inhabited_default_name();
|
||||
name const & get_int_name();
|
||||
name const & get_int_nat_abs_name();
|
||||
name const & get_int_lt_name();
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ Heq.refl
|
|||
Heq.symm
|
||||
Heq.trans
|
||||
heqOfEq
|
||||
hugeFuel
|
||||
id
|
||||
idRhs
|
||||
idDelta
|
||||
|
|
@ -88,6 +89,8 @@ Iff.refl
|
|||
Iff.symm
|
||||
Iff.trans
|
||||
iffTrueIntro
|
||||
Inhabited
|
||||
Inhabited.default
|
||||
Int
|
||||
Int.natAbs
|
||||
Int.lt
|
||||
|
|
|
|||
|
|
@ -4,8 +4,14 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
|
||||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include "kernel/instantiate.h"
|
||||
#include "library/trace.h"
|
||||
#include "library/locals.h"
|
||||
#include "library/util.h"
|
||||
#include "library/constants.h"
|
||||
#include "library/type_context.h"
|
||||
#include "library/equations_compiler/structural_rec.h"
|
||||
#include "frontends/lean/elaborator.h"
|
||||
|
||||
namespace lean {
|
||||
struct partial_rec_fn {
|
||||
|
|
@ -22,14 +28,104 @@ struct partial_rec_fn {
|
|||
m_env(env), m_elab(elab), m_mctx(mctx), m_lctx(lctx) {
|
||||
}
|
||||
|
||||
options const & get_options() const {
|
||||
return m_elab.get_options();
|
||||
}
|
||||
|
||||
type_context_old mk_type_context(local_context const & lctx) {
|
||||
return type_context_old(m_env, m_mctx, lctx, m_elab.get_cache(), transparency_mode::Semireducible);
|
||||
}
|
||||
|
||||
type_context_old mk_type_context() {
|
||||
return mk_type_context(m_lctx);
|
||||
}
|
||||
|
||||
expr mk_base_case_eq(type_context_old & ctx, expr const & fn, unsigned arity, expr const & new_fn) {
|
||||
expr fn_type = ctx.infer(fn);
|
||||
expr result_type = fn_type;
|
||||
type_context_old::tmp_locals args(ctx);
|
||||
for (unsigned i = 0; i < arity; i++) {
|
||||
result_type = ctx.relaxed_whnf(result_type);
|
||||
expr arg = args.push_local_from_binding(result_type);
|
||||
result_type = instantiate(binding_body(result_type), arg);
|
||||
}
|
||||
level result_level = get_level(ctx, result_type);
|
||||
expr inhabited_result = mk_app(mk_constant(get_inhabited_name(), {result_level}), result_type);
|
||||
optional<expr> inhabitant = ctx.mk_class_instance(inhabited_result);
|
||||
if (!inhabitant) {
|
||||
throw generic_exception(m_ref, "failed to compile partial definition, failed to synthesize result type inhabitant");
|
||||
}
|
||||
expr rhs = mk_app(mk_constant(get_inhabited_default_name(), {result_level}), result_type, *inhabitant);
|
||||
expr zero = mk_constant(get_nat_zero_name());
|
||||
expr lhs = mk_app(mk_app(new_fn, zero), args.as_buffer());
|
||||
expr new_eq = mk_equation(lhs, rhs);
|
||||
return ctx.mk_lambda(args.as_buffer(), new_eq);
|
||||
}
|
||||
|
||||
void update_eqs(type_context_old & ctx, unpack_eqns & ues, expr const & fn, expr const & new_fn) {
|
||||
unsigned arity = ues.get_arity_of(0);
|
||||
buffer<expr> & eqns = ues.get_eqns_of(0);
|
||||
buffer<expr> new_eqns;
|
||||
/* Add base case */
|
||||
new_eqns.push_back(mk_base_case_eq(ctx, fn, arity, new_fn));
|
||||
/* Add (succ fuel) pattern to each equation */
|
||||
for (expr const & eqn : eqns) {
|
||||
type_context_old::tmp_locals locals(ctx);
|
||||
unpack_eqn ue(ctx, eqn);
|
||||
expr lhs = ue.lhs();
|
||||
expr rhs = ue.rhs();
|
||||
expr fuel = ue.add_var_front("fuel", mk_nat_type());
|
||||
expr succ_fuel = mk_app(mk_constant(get_nat_succ_name()), fuel);
|
||||
buffer<expr> lhs_args;
|
||||
get_app_args(lhs, lhs_args);
|
||||
expr new_lhs = mk_app(mk_app(new_fn, succ_fuel), lhs_args);
|
||||
expr new_fn_fuel = mk_app(new_fn, fuel);
|
||||
expr new_rhs = replace_local(rhs, fn, new_fn_fuel);
|
||||
ue.lhs() = new_lhs;
|
||||
ue.rhs() = new_rhs;
|
||||
new_eqns.push_back(ue.repack());
|
||||
}
|
||||
eqns = new_eqns;
|
||||
}
|
||||
|
||||
expr add_fuel_param(expr const & eqns) {
|
||||
// TODO(Leo):
|
||||
return eqns;
|
||||
type_context_old ctx = mk_type_context();
|
||||
unpack_eqns ues(ctx, eqns);
|
||||
if (ues.get_num_fns() != 1) {
|
||||
throw generic_exception(m_ref, "failed to compile partial definition, mutual recursion is not supported yet");
|
||||
}
|
||||
expr fn = ues.get_fn(0);
|
||||
expr fn_type = ctx.infer(fn);
|
||||
expr new_fn_type = mk_arrow(mk_nat_type(), fn_type);
|
||||
expr new_fn = ues.update_fn_type(0, new_fn_type);
|
||||
update_eqs(ctx, ues, fn, new_fn);
|
||||
expr new_eqns = ues.repack();
|
||||
m_mctx = ctx.mctx();
|
||||
/* Add `_fueled` suffix */
|
||||
equations_header header = get_equations_header(new_eqns);
|
||||
equations_header new_header = header;
|
||||
new_header.m_fn_names = names(name(head(header.m_fn_names), "_fueled"));
|
||||
new_header.m_fn_actual_names = names(name(head(header.m_fn_actual_names), "_fueled"));
|
||||
return update_equations(new_eqns, new_header);
|
||||
}
|
||||
|
||||
list<expr> add_some_fuel(list<expr> const & fns) {
|
||||
// TODO(Leo):
|
||||
return fns;
|
||||
type_context_old ctx = mk_type_context();
|
||||
names fn_names = m_header.m_fn_names;
|
||||
names fn_actual_names = m_header.m_fn_actual_names;
|
||||
expr huge_fuel = mk_constant(get_huge_fuel_name());
|
||||
return map(fns, [&](expr const & fueled_fn) {
|
||||
name fn_name = head(fn_names);
|
||||
name fn_actual_name = head(fn_actual_names);
|
||||
fn_names = tail(fn_names);
|
||||
fn_actual_names = tail(fn_actual_names);
|
||||
expr fn_val = mk_app(fueled_fn, huge_fuel);
|
||||
expr fn_type = ctx.infer(fn_val);
|
||||
expr r;
|
||||
std::tie(m_env, r) = mk_aux_definition(m_env, get_options(), m_mctx, m_lctx, m_header,
|
||||
fn_name, fn_actual_name, fn_type, fn_val);
|
||||
return r;
|
||||
});
|
||||
}
|
||||
|
||||
eqn_compiler_result operator()(expr const & eqns) {
|
||||
|
|
@ -49,6 +145,10 @@ struct partial_rec_fn {
|
|||
eqn_compiler_result partial_rec(environment & env, elaborator & elab,
|
||||
metavar_context & mctx, local_context const & lctx,
|
||||
expr const & eqns) {
|
||||
return partial_rec_fn(env, elab, mctx, lctx)(eqns);
|
||||
partial_rec_fn proc(env, elab, mctx, lctx);
|
||||
auto r = proc(eqns);
|
||||
env = proc.m_env;
|
||||
mctx = proc.m_mctx;
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -170,6 +170,13 @@ expr unpack_eqn::add_var(name const & n, expr const & type) {
|
|||
return m_vars.back();
|
||||
}
|
||||
|
||||
expr unpack_eqn::add_var_front(name const & n, expr const & type) {
|
||||
m_modified_vars = true;
|
||||
expr x = m_locals.push_local(n, type);
|
||||
m_vars.insert(0, x);
|
||||
return x;
|
||||
}
|
||||
|
||||
expr unpack_eqn::repack() {
|
||||
if (!m_modified_vars &&
|
||||
equation_lhs(m_nested_src) == m_lhs &&
|
||||
|
|
|
|||
|
|
@ -65,6 +65,7 @@ class unpack_eqn {
|
|||
public:
|
||||
unpack_eqn(type_context_old & ctx, expr const & eqn);
|
||||
expr add_var(name const & n, expr const & type);
|
||||
expr add_var_front(name const & n, expr const & type);
|
||||
buffer<expr> & get_vars() { return m_vars; }
|
||||
expr & lhs() { return m_lhs; }
|
||||
expr & rhs() { return m_rhs; }
|
||||
|
|
|
|||
14
tests/compiler/partial.lean
Normal file
14
tests/compiler/partial.lean
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
set_option pp.implicit true
|
||||
set_option pp.binder_types false
|
||||
-- set_option trace.compiler.boxed true
|
||||
|
||||
partial def contains : String → Char → Nat → Bool
|
||||
| s c i :=
|
||||
if s.atEnd i then false
|
||||
else if s.get i == c then true
|
||||
else contains s c (s.next i)
|
||||
|
||||
def main : IO Unit :=
|
||||
let s1 := "hello" in
|
||||
IO.println (contains s1 'a' 0) *>
|
||||
IO.println (contains s1 'o' 0)
|
||||
2
tests/compiler/partial.lean.expected.out
Normal file
2
tests/compiler/partial.lean.expected.out
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
false
|
||||
true
|
||||
Loading…
Add table
Reference in a new issue