diff --git a/src/tests/library/rewriter/rewriter.cpp b/src/tests/library/rewriter/rewriter.cpp index 93f03e4fe0..ce8d16a3b2 100644 --- a/src/tests/library/rewriter/rewriter.cpp +++ b/src/tests/library/rewriter/rewriter.cpp @@ -590,6 +590,53 @@ static void repeat_rewriter2_tst() { env.add_theorem("repeat_thm2", concl, proof); } +static void depth_rewriter1_tst() { + cout << "=== depth_rewriter1_tst() ===" << std::endl; + // Theorem: Pi(x y : N), x + y = y + x := ADD_COMM x y + // Term : f (a + b) + // Result : (f (b + a), ADD_COMM a b) + expr a = Const("a"); // a : Nat + expr b = Const("b"); // b : Nat + expr f1 = Const("f1"); // f : Nat -> Nat + expr f2 = Const("f2"); // f : Nat -> Nat -> Nat + expr f3 = Const("f3"); // f : Nat -> Nat -> Nat -> Nat + expr f4 = Const("f4"); // f : Nat -> Nat -> Nat -> Nat -> Nat + expr zero = nVal(0); // zero : Nat + expr a_plus_b = nAdd(a, b); + expr b_plus_a = nAdd(b, a); + expr add_comm_thm_type = Pi("x", Nat, + Pi("y", Nat, + Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x"))))); + expr add_comm_thm_body = Const("ADD_COMM"); + + environment env = mk_toplevel(); + env.add_var("f1", Nat >> Nat); + env.add_var("f2", Nat >> (Nat >> Nat)); + env.add_var("f3", Nat >> (Nat >> (Nat >> Nat))); + env.add_var("f4", Nat >> (Nat >> (Nat >> (Nat >> Nat)))); + env.add_var("a", Nat); + env.add_var("b", Nat); + env.add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z + + // Rewriting + rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body); + rewriter try_rewriter = mk_try_rewriter(add_comm_thm_rewriter); + rewriter depth_rewriter = mk_depth_rewriter(try_rewriter); + context ctx; + + cout << "RW = " << depth_rewriter << std::endl; + + expr v = nAdd(f1(nAdd(a, b)), f3(a, b, nAdd(a, b))); + pair result = depth_rewriter(env, ctx, v); + expr concl = mk_eq(v, result.first); + expr proof = result.second; + cout << "Concl = " << concl << std::endl + << "Proof = " << proof << std::endl; + lean_assert_eq(concl, mk_eq(v, nAdd(f3(a, b, nAdd(b, a)), f1(nAdd(b, a))))); + env.add_theorem("depth_rewriter1", concl, proof); + cout << "====================================================" << std::endl; +} + int main() { theorem_rewriter1_tst(); theorem_rewriter2_tst(); @@ -602,5 +649,6 @@ int main() { app_rewriter1_tst(); repeat_rewriter1_tst(); repeat_rewriter2_tst(); + depth_rewriter1_tst(); return has_violations() ? 1 : 0; }