diff --git a/src/kernel/metavar.cpp b/src/kernel/metavar.cpp index 163b6fbf05..251f95b30b 100644 --- a/src/kernel/metavar.cpp +++ b/src/kernel/metavar.cpp @@ -135,8 +135,12 @@ expr instantiate_metavars(expr const & e, metavar_env const & env) { return replace_fn(f)(e); } +meta_ctx add_lift(meta_ctx const & ctx, unsigned s, unsigned n) { + return cons(mk_lift(s, n), ctx); +} + expr add_lift(expr const & m, unsigned s, unsigned n) { - return mk_metavar(metavar_idx(m), cons(mk_lift(s, n), metavar_ctx(m))); + return mk_metavar(metavar_idx(m), add_lift(metavar_ctx(m), s, n)); } meta_ctx add_lower(meta_ctx const & ctx, unsigned s2, unsigned n2) { diff --git a/src/kernel/metavar.h b/src/kernel/metavar.h index 16043937bb..584609b3e3 100644 --- a/src/kernel/metavar.h +++ b/src/kernel/metavar.h @@ -150,6 +150,11 @@ public: */ expr instantiate_metavars(expr const & e, metavar_env const & env); +/** + \brief Extend the context \c ctx with the entry lift:s:n +*/ +meta_ctx add_lift(meta_ctx const & ctx, unsigned s, unsigned n); + /** \brief Add a lift:s:n operation to the context of the given metavariable. @@ -165,6 +170,11 @@ expr add_lift(expr const & m, unsigned s, unsigned n); */ expr add_lower(expr const & m, unsigned s, unsigned n); +/** + \brief Extend the context \c ctx with the entry lower:s:n +*/ +meta_ctx add_lower(meta_ctx const & ctx, unsigned s, unsigned n); + /** \brief Add a subst:s:v operation to the context of the given metavariable. @@ -173,6 +183,12 @@ expr add_lower(expr const & m, unsigned s, unsigned n); */ expr add_subst(expr const & m, unsigned s, expr const & v); + +/** + \brief Extend the context \c ctx with the entry subst:s v +*/ +meta_ctx add_subst(meta_ctx const & ctx, unsigned s, expr const & v); + /** \brief Return true iff the given metavariable has a non-empty context associated with it. diff --git a/src/kernel/normalizer.cpp b/src/kernel/normalizer.cpp index ea1afada46..349db4f50d 100644 --- a/src/kernel/normalizer.cpp +++ b/src/kernel/normalizer.cpp @@ -141,6 +141,46 @@ class normalizer::imp { return expr(); } + + bool is_identity_stack(value_stack const & s, unsigned k) { + if (length(s) != k) + return false; + unsigned i = 0; + for (auto e : s) { + if (e.kind() != svalue_kind::BoundedVar || k - to_bvar(e) - 1 != i) + return false; + ++i; + } + return true; + } + + /** + \brief Update the metavariable context for \c m based on the + value_stack \c s and the number of binders \c k. + + \pre is_metavar(m) + */ + expr updt_metavar(expr const & m, value_stack const & s, unsigned k) { + lean_assert(is_metavar(m)); + if (is_identity_stack(s, k)) + return m; + meta_ctx ctx = metavar_ctx(m); + unsigned midx = metavar_idx(m); + unsigned s_len = length(s); + unsigned i = 0; + ctx = add_lift(ctx, s_len, s_len); + for (auto e : s) { + ctx = add_subst(ctx, i, lift_free_vars(reify(e, k), s_len)); + ++i; + } + ctx = add_lower(ctx, s_len, s_len); + unsigned m_ctx_len = m_ctx.size(); + lean_assert(s_len + m_ctx_len >= k); + if (s_len + m_ctx_len > k) + ctx = add_lower(ctx, s_len, s_len + m_ctx_len - k); + return mk_metavar(midx, ctx); + } + /** \brief Normalize the expression \c a in a context composed of stack \c s and \c k binders. */ svalue normalize(expr const & a, value_stack const & s, unsigned k) { flet l(m_depth, m_depth+1); @@ -161,8 +201,7 @@ class normalizer::imp { if (m_menv && m_menv->contains(a) && m_menv->is_assigned(a)) { r = normalize(m_menv->get_subst(a), s, k); } else { - // TODO(Leo) We must store in the metavariable the implicit substitution stored in value_stack. - r = svalue(a); + r = svalue(updt_metavar(a, s, k)); } break; case expr_kind::Var: diff --git a/src/tests/kernel/metavar.cpp b/src/tests/kernel/metavar.cpp index ae9ddbf3bb..1ecf75247f 100644 --- a/src/tests/kernel/metavar.cpp +++ b/src/tests/kernel/metavar.cpp @@ -249,6 +249,14 @@ static void tst11() { } static void tst12() { + metavar_env menv; + expr m = menv.mk_metavar(); + expr f = Const("f"); + std::cout << instantiate(f(m), {Var(0), Var(1)}) << "\n"; + std::cout << instantiate(f(m), {Var(1), Var(0)}) << "\n"; +} + +static void tst13() { environment env; metavar_env menv; expr m = menv.mk_metavar(); @@ -262,6 +270,119 @@ static void tst12() { expr F = Fun({x, N}, f(m))(a); normalizer norm(env); std::cout << norm(F) << "\n"; + menv.assign(0, Var(0)); + std::cout << norm(instantiate_metavars(F, menv)) << "\n"; + lean_assert(norm(instantiate_metavars(F, menv)) == + instantiate_metavars(norm(F), menv)); +} + +static void tst14() { + environment env; + metavar_env menv; + expr m1 = menv.mk_metavar(); + expr m2 = menv.mk_metavar(); + expr N = Const("N"); + expr f = Const("f"); + expr h = Const("h"); + expr a = Const("a"); + expr b = Const("b"); + expr x = Const("x"); + expr y = Const("y"); + env.add_var("h", Pi({N, Type()}, N >> (N >> N))); + expr F1 = Fun({{N, Type()}, {a, N}, {f, N >> N}}, + (Fun({{x, N}, {y, N}}, Eq(f(m1), y)))(a)); + metavar_env menv2 = menv; + menv2.assign(0, h(Var(4), Var(1), Var(3))); + normalizer norm(env); + env.add_var("M", Type()); + expr M = Const("M"); + std::cout << norm(F1) << "\n"; + std::cout << instantiate_metavars(norm(F1), menv2) << "\n"; + std::cout << instantiate_metavars(F1, menv2) << "\n"; + std::cout << norm(instantiate_metavars(F1, menv2)) << "\n"; + lean_assert(instantiate_metavars(norm(F1), menv2) == + norm(instantiate_metavars(F1, menv2))); + expr F2 = (Fun({{N, Type()}, {f, N >> N}, {a, N}, {b, N}}, + (Fun({{x, N}, {y, N}}, Eq(f(m1), y)))(a, m2)))(M); + std::cout << norm(F2) << "\n"; + expr F3 = (Fun({{N, Type()}, {f, N >> N}, {a, N}, {b, N}}, + (Fun({{x, N}, {y, N}}, Eq(f(m1), y)))(b, m2)))(M); + std::cout << norm(F3) << "\n"; +} + +static void tst15() { + environment env; + metavar_env menv; + normalizer norm(env); + expr m1 = menv.mk_metavar(); + expr f = Const("f"); + expr x = Const("x"); + expr y = Const("y"); + expr z = Const("z"); + expr N = Const("N"); + env.add_var("N", Type()); + env.add_var("f", Type() >> Type()); + expr F = Fun({z, Type()}, Fun({{x, Type()}, {y, Type()}}, f(m1))(N, N)); + menv.assign(0, Var(2)); + std::cout << norm(F) << "\n"; + std::cout << instantiate_metavars(norm(F), menv) << "\n"; + std::cout << norm(instantiate_metavars(F, menv)) << "\n"; + lean_assert(instantiate_metavars(norm(F), menv) == + norm(instantiate_metavars(F, menv))); +} + +static void tst16() { + environment env; + metavar_env menv; + normalizer norm(env); + context ctx; + ctx = extend(ctx, "w", Type()); + expr m1 = menv.mk_metavar(); + expr f = Const("f"); + expr x = Const("x"); + expr y = Const("y"); + expr z = Const("z"); + expr N = Const("N"); + env.add_var("N", Type()); + expr F = Fun({z, Type()}, Fun({{x, Type()}, {y, Type()}}, m1)(N, N)); + menv.assign(0, Var(3)); + std::cout << norm(F, ctx) << "\n"; + std::cout << instantiate_metavars(norm(F, ctx), menv) << "\n"; + std::cout << norm(instantiate_metavars(F, menv), ctx) << "\n"; +} + +static void tst17() { + environment env; + metavar_env menv; + normalizer norm(env); + context ctx; + ctx = extend(ctx, "w1", Type()); + ctx = extend(ctx, "w2", Type()); + ctx = extend(ctx, "w3", Type()); + ctx = extend(ctx, "w4", Type()); + expr m1 = menv.mk_metavar(); + expr f = Const("f"); + expr x = Const("x"); + expr y = Const("y"); + expr z = Const("z"); + expr N = Const("N"); + env.add_var("N", Type()); + expr F = Fun({z, Type()}, Fun({{x, Type()}, {y, Type()}}, m1)(N, N)); + metavar_env menv2 = menv; + menv.assign(0, Var(3)); + std::cout << norm(F, ctx) << "\n"; + std::cout << instantiate_metavars(norm(F, ctx), menv) << "\n"; + std::cout << norm(instantiate_metavars(F, menv), ctx) << "\n"; + F = Fun({z, Type()}, Fun({{x, Type()}, {y, Type()}, {x, Type()}, {y, Type()}, {x, Type()}}, m1)(N, N, N, N, N)); + lean_assert(instantiate_metavars(norm(F, ctx), menv) == + norm(instantiate_metavars(F, menv), ctx)); + std::cout << "----------------------\n"; + menv2.assign(0, Var(8)); + std::cout << norm(F, ctx) << "\n"; + std::cout << instantiate_metavars(norm(F, ctx), menv2) << "\n"; + std::cout << norm(instantiate_metavars(F, menv2), ctx) << "\n"; + lean_assert(instantiate_metavars(norm(F, ctx), menv2) == + norm(instantiate_metavars(F, menv2), ctx)); } int main() { @@ -277,5 +398,11 @@ int main() { tst10(); tst11(); tst12(); + tst13(); + tst14(); + tst15(); + tst16(); + tst17(); return has_violations() ? 1 : 0; } +