diff --git a/src/tests/library/rewriter/rewriter.cpp b/src/tests/library/rewriter/rewriter.cpp index 071efb0711..f5d362bfdf 100644 --- a/src/tests/library/rewriter/rewriter.cpp +++ b/src/tests/library/rewriter/rewriter.cpp @@ -638,11 +638,11 @@ static void depth_rewriter1_tst() { cout << "====================================================" << std::endl; } -static void lambda_rewriter1_tst() { - cout << "=== lambda_rewriter1_tst() ===" << std::endl; +static void lambda_body_rewriter_tst() { + cout << "=== lambda_body_rewriter_tst() ===" << std::endl; // Theorem: Pi(x y : N), x + y = y + x := ADD_COMM x y - // Term : f (a + b) - // Result : (f (b + a), ADD_COMM a b) + // Term : fun (x : Nat), (a + b) + // Result : fun (x : Nat), (b + a) expr a = Const("a"); // a : Nat expr b = Const("b"); // b : Nat expr f1 = Const("f1"); // f : Nat -> Nat @@ -670,58 +670,25 @@ static void lambda_rewriter1_tst() { rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body); rewriter lambda_rewriter = mk_lambda_body_rewriter(add_comm_thm_rewriter); context ctx; - cout << "RW = " << lambda_rewriter << std::endl; expr v = mk_lambda("x", Nat, nAdd(b, a)); pair result = lambda_rewriter(env, ctx, v); expr concl = mk_eq(v, result.first); expr proof = result.second; + cout << "v = " << v << std::endl; cout << "Concl = " << concl << std::endl << "Proof = " << proof << std::endl; lean_assert_eq(concl, mk_eq(v, mk_lambda("x", Nat, nAdd(a, b)))); env->add_theorem("lambda_rewriter1", concl, proof); - cout << "====================================================" << std::endl; -} -static void lambda_rewriter2_tst() { - cout << "=== lambda_rewriter2_tst() ===" << std::endl; // Theorem: Pi(x y : N), x + y = y + x := ADD_COMM x y - // Term : f (a + b) - // Result : (f (b + a), ADD_COMM a b) - expr a = Const("a"); // a : Nat - expr b = Const("b"); // b : Nat - expr f1 = Const("f1"); // f : Nat -> Nat - expr f2 = Const("f2"); // f : Nat -> Nat -> Nat - expr f3 = Const("f3"); // f : Nat -> Nat -> Nat -> Nat - expr f4 = Const("f4"); // f : Nat -> Nat -> Nat -> Nat -> Nat - expr zero = nVal(0); // zero : Nat - expr a_plus_b = nAdd(a, b); - expr b_plus_a = nAdd(b, a); - expr add_comm_thm_type = Pi("x", Nat, - Pi("y", Nat, - Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x"))))); - expr add_comm_thm_body = Const("ADD_COMM"); - - environment env = mk_toplevel(); - env->add_var("f1", Nat >> Nat); - env->add_var("f2", Nat >> (Nat >> Nat)); - env->add_var("f3", Nat >> (Nat >> (Nat >> Nat))); - env->add_var("f4", Nat >> (Nat >> (Nat >> (Nat >> Nat)))); - env->add_var("a", Nat); - env->add_var("b", Nat); - env->add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z - - // Rewriting - rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body); - rewriter lambda_rewriter = mk_lambda_body_rewriter(add_comm_thm_rewriter); - context ctx; - - cout << "RW = " << lambda_rewriter << std::endl; - - expr v = mk_lambda("x", Nat, nAdd(Var(0), a)); - pair result = lambda_rewriter(env, ctx, v); - expr concl = mk_eq(v, result.first); - expr proof = result.second; + // Term : fun (x : Nat), (x + a) + // Result : fun (x : Nat), (a + x) + v = mk_lambda("x", Nat, nAdd(Var(0), a)); + result = lambda_rewriter(env, ctx, v); + concl = mk_eq(v, result.first); + proof = result.second; + cout << "v = " << v << std::endl; cout << "Concl = " << concl << std::endl << "Proof = " << proof << std::endl; lean_assert_eq(concl, mk_eq(v, mk_lambda("x", Nat, nAdd(a, Var(0))))); @@ -729,6 +696,38 @@ static void lambda_rewriter2_tst() { cout << "====================================================" << std::endl; } +static void lambda_type_rewriter_tst() { + // Theorem: Pi(x y : N), x + y = y + x := ADD_COMM x y + // Term : fun (x : vec(Nat, a + b)), x + // Result : fun (x : vec(Nat, b + a)), x + cout << "=== lambda_type_rewriter_tst() ===" << std::endl; + context ctx; + environment env = mk_toplevel(); + expr a = Const("a"); // a : Nat + env->add_var("a", Nat); + expr b = Const("b"); // b : Nat + env->add_var("b", Nat); + expr vec = Const("vec"); + env->add_var("vec", Type() >> (Nat >> Type())); // vec : Type -> Nat -> Type + expr add_comm_thm_type = Pi("x", Nat, Pi("y", Nat, Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x"))))); + expr add_comm_thm_body = Const("ADD_COMM"); + env->add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z + rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body); + rewriter try_rewriter = mk_try_rewriter(add_comm_thm_rewriter); + rewriter depth_rewriter = mk_depth_rewriter(try_rewriter); + rewriter lambda_rewriter = mk_lambda_type_rewriter(depth_rewriter); + + expr v = mk_lambda("x", vec(Nat, nAdd(a, b)), Var(0)); + pair result = lambda_rewriter(env, ctx, v); + expr concl = mk_eq(v, result.first); + expr proof = result.second; + cout << "v = " << v << std::endl; + cout << "Concl = " << concl << std::endl + << "Proof = " << proof << std::endl; + lean_assert_eq(concl, mk_eq(v, mk_lambda("x", vec(Nat, nAdd(b, a)), Var(0)))); + env->add_theorem("lambda_type_rewriter", concl, proof); + cout << "====================================================" << std::endl; +} int main() { save_stack_info(); @@ -744,7 +743,7 @@ int main() { repeat_rewriter1_tst(); repeat_rewriter2_tst(); depth_rewriter1_tst(); - lambda_rewriter1_tst(); - lambda_rewriter2_tst(); + lambda_body_rewriter_tst(); + lambda_type_rewriter_tst(); return has_violations() ? 1 : 0; }