From 42fbe3c18c6fd86b7d6098d8d73ad573aa1ececa Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 27 Mar 2019 17:13:53 -0700 Subject: [PATCH] chore(library/init,runtime,library/compiler): add `fix` primitive back The new `partial def`s allow us to define `fix` in Lean, but the Lean implementation is not as efficient as the native one. The native one in C++ use weak pointers to prevent a closure allocation at every recursive invocation. This commit also fixes the `fixCore` helper functions that were broken after we switched to camelCase. We have updated the test `fix1.lean` to demonstrate the native implementation is faster. Here are the numbers on my desktop. ``` ./run.sh fix1.lean 24 721420279 Time for 'native fix': 816ms 721420279 Time for 'fix in lean': 1.34s ``` --- library/init/default.lean | 1 + library/init/fix.lean | 80 +++++++++++++++++++++++ library/init/lean/parser/rec.lean | 9 +-- src/library/compiler/csimp.cpp | 42 +++++++++++++ src/library/compiler/util.cpp | 10 +-- src/runtime/object.cpp | 101 ++++++++++++++++++++++++++++++ src/runtime/object.h | 10 +++ tests/playground/fix1.lean | 21 +++++++ 8 files changed, 264 insertions(+), 10 deletions(-) create mode 100644 library/init/fix.lean create mode 100644 tests/playground/fix1.lean diff --git a/library/init/default.lean b/library/init/default.lean index dc245dda4d..502cdfb017 100644 --- a/library/init/default.lean +++ b/library/init/default.lean @@ -6,3 +6,4 @@ Authors: Leonardo de Moura prelude import init.core init.control init.data.basic import init.coe init.wf init.data init.io init.util +import init.fix diff --git a/library/init/fix.lean b/library/init/fix.lean new file mode 100644 index 0000000000..0599738219 --- /dev/null +++ b/library/init/fix.lean @@ -0,0 +1,80 @@ +/- +Copyright (c) 2019 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import init.data.uint +universe u + +def bfix1 {α β : Type u} (base : α → β) (rec : (α → β) → α → β) : Nat → α → β +| 0 a := base a +| (n+1) a := rec (bfix1 n) a + +@[extern cpp inline "lean::fixpoint(#4, #5)"] +def fixCore1 {α β : Type u} (base : @& (α → β)) (rec : (α → β) → α → β) : α → β := +bfix1 base rec usizeSz + +@[inline] def fixCore {α β : Type u} (base : @& (α → β)) (rec : (α → β) → α → β) : α → β := +fixCore1 base rec + +@[inline] def fix1 {α β : Type u} [Inhabited β] (rec : (α → β) → α → β) : α → β := +fixCore1 (λ _, default β) rec + +@[inline] def fix {α β : Type u} [Inhabited β] (rec : (α → β) → α → β) : α → β := +fixCore1 (λ _, default β) rec + +def bfix2 {α₁ α₂ β : Type u} (base : α₁ → α₂ → β) (rec : (α₁ → α₂ → β) → α₁ → α₂ → β) : Nat → α₁ → α₂ → β +| 0 a₁ a₂ := base a₁ a₂ +| (n+1) a₁ a₂ := rec (bfix2 n) a₁ a₂ + +@[extern cpp inline "lean::fixpoint2(#5, #6, #7)"] +def fixCore2 {α₁ α₂ β : Type u} (base : α₁ → α₂ → β) (rec : (α₁ → α₂ → β) → α₁ → α₂ → β) : α₁ → α₂ → β := +bfix2 base rec usizeSz + +@[inline] def fix2 {α₁ α₂ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → β) → α₁ → α₂ → β) : α₁ → α₂ → β := +fixCore2 (λ _ _, default β) rec + +def bfix3 {α₁ α₂ α₃ β : Type u} (base : α₁ → α₂ → α₃ → β) (rec : (α₁ → α₂ → α₃ → β) → α₁ → α₂ → α₃ → β) : Nat → α₁ → α₂ → α₃ → β +| 0 a₁ a₂ a₃ := base a₁ a₂ a₃ +| (n+1) a₁ a₂ a₃ := rec (bfix3 n) a₁ a₂ a₃ + +@[extern cpp inline "lean::fixpoint3(#6, #7, #8, #9)"] +def fixCore3 {α₁ α₂ α₃ β : Type u} (base : α₁ → α₂ → α₃ → β) (rec : (α₁ → α₂ → α₃ → β) → α₁ → α₂ → α₃ → β) : α₁ → α₂ → α₃ → β := +bfix3 base rec usizeSz + +@[inline] def fix3 {α₁ α₂ α₃ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → β) → α₁ → α₂ → α₃ → β) : α₁ → α₂ → α₃ → β := +fixCore3 (λ _ _ _, default β) rec + +def bfix4 {α₁ α₂ α₃ α₄ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → β) (rec : (α₁ → α₂ → α₃ → α₄ → β) → α₁ → α₂ → α₃ → α₄ → β) : Nat → α₁ → α₂ → α₃ → α₄ → β +| 0 a₁ a₂ a₃ a₄ := base a₁ a₂ a₃ a₄ +| (n+1) a₁ a₂ a₃ a₄ := rec (bfix4 n) a₁ a₂ a₃ a₄ + +@[extern cpp inline "lean::fixpoint4(#7, #8, #9, #10, #11)"] +def fixCore4 {α₁ α₂ α₃ α₄ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → β) (rec : (α₁ → α₂ → α₃ → α₄ → β) → α₁ → α₂ → α₃ → α₄ → β) : α₁ → α₂ → α₃ → α₄ → β := +bfix4 base rec usizeSz + +@[inline] def fix4 {α₁ α₂ α₃ α₄ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → α₄ → β) → α₁ → α₂ → α₃ → α₄ → β) : α₁ → α₂ → α₃ → α₄ → β := +fixCore4 (λ _ _ _ _, default β) rec + +def bfix5 {α₁ α₂ α₃ α₄ α₅ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → β) : Nat → α₁ → α₂ → α₃ → α₄ → α₅ → β +| 0 a₁ a₂ a₃ a₄ a₅ := base a₁ a₂ a₃ a₄ a₅ +| (n+1) a₁ a₂ a₃ a₄ a₅ := rec (bfix5 n) a₁ a₂ a₃ a₄ a₅ + +@[extern cpp inline "lean::fixpoint5(#8, #9, #10, #11, #12, #13)"] +def fixCore5 {α₁ α₂ α₃ α₄ α₅ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → β := +bfix5 base rec usizeSz + +@[inline] def fix5 {α₁ α₂ α₃ α₄ α₅ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → β := +fixCore5 (λ _ _ _ _ _, default β) rec + +def bfix6 {α₁ α₂ α₃ α₄ α₅ α₆ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) : Nat → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β +| 0 a₁ a₂ a₃ a₄ a₅ a₆ := base a₁ a₂ a₃ a₄ a₅ a₆ +| (n+1) a₁ a₂ a₃ a₄ a₅ a₆ := rec (bfix6 n) a₁ a₂ a₃ a₄ a₅ a₆ + +@[extern cpp inline "lean::fixpoint6(#9, #10, #11, #12, #13, #14, #15)"] +def fixCore6 {α₁ α₂ α₃ α₄ α₅ α₆ β : Type u} (base : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β := +bfix6 base rec usizeSz + +@[inline] def fix6 {α₁ α₂ α₃ α₄ α₅ α₆ β : Type u} [Inhabited β] (rec : (α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) → α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β) : α₁ → α₂ → α₃ → α₄ → α₅ → α₆ → β := +fixCore6 (λ _ _ _ _ _ _, default β) rec diff --git a/library/init/lean/parser/rec.lean b/library/init/lean/parser/rec.lean index a3ae749bba..69f4e50d98 100644 --- a/library/init/lean/parser/rec.lean +++ b/library/init/lean/parser/rec.lean @@ -6,7 +6,7 @@ Author: Sebastian Ullrich Recursion monad transformer -/ prelude -import init.control.reader init.lean.parser.parsec +import init.control.reader init.lean.parser.parsec init.fix namespace Lean.Parser @@ -23,13 +23,10 @@ local attribute [reducible] RecT @[inline] def recurse (a : α) : RecT α δ m δ := λ f, f a -@[specialize] private partial def runAux : m δ → (α → RecT α δ m δ) → α → m δ -| b rec a := rec a (runAux b rec) - /-- Execute `x`, executing `rec a` whenever `recurse a` is called. After `maxRec` recursion steps, `base` is executed instead. -/ -@[inline] protected def run (x : RecT α δ m β) (base : Unit → m δ) (rec : α → RecT α δ m δ) : m β := -x.run (runAux (base ()) rec) +@[inline] protected def run (x : RecT α δ m β) (base : α → m δ) (rec : α → RecT α δ m δ) : m β := +x (fixCore base (λ a f, rec f a)) @[inline] protected def runParsec {γ : Type} [MonadParsec γ m] (x : RecT α δ m β) (rec : α → RecT α δ m δ) : m β := RecT.run x (λ _, MonadParsec.error "RecT.runParsec: no progress") rec diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index 967023d7b2..c74d7722ad 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -1377,6 +1377,46 @@ class csimp_fn { return mk_app(mk_constant(get_nat_add_name()), arg, mk_lit(literal(nat(1)))); } + /* + Replace `fixCore f a_1 ... a_m` + with `fixCore f a_1 ... a_m` whenever `n < m`. + This optimization is for writing reusable/generic code. For + example, we cannot write an efficient `rec_t` monad transformer + without it because we don't know the arity of `m A` when we write `rec_t`. + Remark: the runtime provides a small set of `fixCore` implementations (`i in [1, 6]`). + This methods does nothing if `m > 6`. */ + expr visit_fix_core(expr const & e, unsigned n) { + if (m_before_erasure) return visit_app_default(e); + buffer args; + expr fn = get_app_args(e, args); + lean_assert(is_constant(fn) && is_fix_core(const_name(fn))); + unsigned arity = + n + /* α_1 ... α_n Type arguments */ + 1 + /* β : Type */ + 1 + /* (base : α_1 → ... → α_n → β) */ + 1 + /* (rec : (α_1 → ... → α_n → β) → α_1 → ... → α_n → β) */ + n; /* α_1 → ... → α_n */ + if (args.size() <= arity) return visit_app_default(e); + /* This `fixCore` application is an overapplication. + The `fixCore` is implemented by the runtime, and the result + is a closure. This is bad for performance. We should + replace it with `fixCore` (if the runtime contains one) */ + unsigned num_extra = args.size() - arity; + unsigned m = n + num_extra; + optional fix_core_m = mk_enf_fix_core(m); + if (!fix_core_m) return visit_app_default(e); + buffer new_args; + /* Add α_1 ... α_n and β */ + for (unsigned i = 0; i < m+1; i++) { + new_args.push_back(mk_enf_neutral()); + } + /* `(base : α_1 → ... → α_n → β)` is not used in the runtime primitive. + So, we replace it with a neutral value :) */ + new_args.push_back(mk_enf_neutral()); + new_args.append(args.size() - n - 2, args.data() + n + 2); + return mk_app(*fix_core_m, new_args); + } + expr visit_app(expr const & e, bool is_let_val) { if (is_cases_on_app(env(), e)) { return visit_cases(e, is_let_val); @@ -1417,6 +1457,8 @@ class csimp_fn { return mk_lit(literal(nat(0))); } else if (optional r = try_inline(fn, e, is_let_val)) { return *r; + } else if (optional i = is_fix_core(n)) { + return visit_fix_core(e, *i); } else { return visit_app_default(e); } diff --git a/src/library/compiler/util.cpp b/src/library/compiler/util.cpp index 5641251a34..a04a8fa0cf 100644 --- a/src/library/compiler/util.cpp +++ b/src/library/compiler/util.cpp @@ -516,15 +516,17 @@ optional get_num_lit_ext(expr const & e) { optional is_fix_core(name const & n) { if (!n.is_atomic() || !n.is_string()) return optional(); string_ref const & r = n.get_string(); - if (r.length() != 10) return optional(); + if (r.length() != 8) return optional(); char const * s = r.data(); - if (std::strncmp(s, "fix_core_", 9) != 0 || !std::isdigit(s[9])) return optional(); - return optional(s[9] - '0'); + if (std::strncmp(s, "fixCore", 7) != 0 || !std::isdigit(s[7])) return optional(); + return optional(s[7] - '0'); } optional mk_enf_fix_core(unsigned n) { if (n == 0 || n > 6) return none_expr(); - return some_expr(mk_constant(name("fix_core").append_after(n))); + std::ostringstream s; + s << "fixCore" << n; + return some_expr(mk_constant(name(s.str()))); } void initialize_compiler_util() { diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index 979ceddccb..7d4343a61b 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -1756,6 +1756,107 @@ object * array_push(obj_arg a, obj_arg v) { return r; } +// ======================================= +// fixpoint + +static inline object * ptr_to_weak_ptr(object * p) { + return reinterpret_cast(reinterpret_cast(p) | 1); +} + +static inline object * weak_ptr_to_ptr(object * w) { + return reinterpret_cast((reinterpret_cast(w) >> 1) << 1); +} + +obj_res fixpoint_aux(obj_arg rec, obj_arg weak_k, obj_arg a) { + object * k = weak_ptr_to_ptr(weak_k); + inc(k); + return apply_2(rec, k, a); +} + +obj_res fixpoint(obj_arg rec, obj_arg a) { + object * k = alloc_closure(fixpoint_aux, 2); + inc(rec); + closure_set(k, 0, rec); + closure_set(k, 1, ptr_to_weak_ptr(k)); + object * r = apply_2(rec, k, a); + return r; +} + +obj_res fixpoint_aux2(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2) { + object * k = weak_ptr_to_ptr(weak_k); + inc(k); + return apply_3(rec, k, a1, a2); +} + +obj_res fixpoint2(obj_arg rec, obj_arg a1, obj_arg a2) { + object * k = alloc_closure(fixpoint_aux2, 2); + inc(rec); + closure_set(k, 0, rec); + closure_set(k, 1, ptr_to_weak_ptr(k)); + object * r = apply_3(rec, k, a1, a2); + return r; +} + +obj_res fixpoint_aux3(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3) { + object * k = weak_ptr_to_ptr(weak_k); + inc(k); + return apply_4(rec, k, a1, a2, a3); +} + +obj_res fixpoint3(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3) { + object * k = alloc_closure(fixpoint_aux3, 2); + inc(rec); + closure_set(k, 0, rec); + closure_set(k, 1, ptr_to_weak_ptr(k)); + object * r = apply_4(rec, k, a1, a2, a3); + return r; +} + +obj_res fixpoint_aux4(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4) { + object * k = weak_ptr_to_ptr(weak_k); + inc(k); + return apply_5(rec, k, a1, a2, a3, a4); +} + +obj_res fixpoint4(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4) { + object * k = alloc_closure(fixpoint_aux4, 2); + inc(rec); + closure_set(k, 0, rec); + closure_set(k, 1, ptr_to_weak_ptr(k)); + object * r = apply_5(rec, k, a1, a2, a3, a4); + return r; +} + +obj_res fixpoint_aux5(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5) { + object * k = weak_ptr_to_ptr(weak_k); + inc(k); + return apply_6(rec, k, a1, a2, a3, a4, a5); +} + +obj_res fixpoint5(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5) { + object * k = alloc_closure(fixpoint_aux5, 2); + inc(rec); + closure_set(k, 0, rec); + closure_set(k, 1, ptr_to_weak_ptr(k)); + object * r = apply_6(rec, k, a1, a2, a3, a4, a5); + return r; +} + +obj_res fixpoint_aux6(obj_arg rec, obj_arg weak_k, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5, obj_arg a6) { + object * k = weak_ptr_to_ptr(weak_k); + inc(k); + return apply_7(rec, k, a1, a2, a3, a4, a5, a6); +} + +obj_res fixpoint6(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5, obj_arg a6) { + object * k = alloc_closure(fixpoint_aux6, 2); + inc(rec); + closure_set(k, 0, rec); + closure_set(k, 1, ptr_to_weak_ptr(k)); + object * r = apply_7(rec, k, a1, a2, a3, a4, a5, a6); + return r; +} + // ======================================= // Debugging helper functions diff --git a/src/runtime/object.h b/src/runtime/object.h index 43e45c6143..004c6787d9 100644 --- a/src/runtime/object.h +++ b/src/runtime/object.h @@ -652,6 +652,16 @@ inline obj_res alloc_closure(object*(*fun)(object *, object *, object *, object return alloc_closure(reinterpret_cast(fun), 8, num_fixed); } +// ======================================= +// Fixpoint + +obj_res fixpoint(obj_arg rec, obj_arg a); +obj_res fixpoint2(obj_arg rec, obj_arg a1, obj_arg a2); +obj_res fixpoint3(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3); +obj_res fixpoint4(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4); +obj_res fixpoint5(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5); +obj_res fixpoint6(obj_arg rec, obj_arg a1, obj_arg a2, obj_arg a3, obj_arg a4, obj_arg a5, obj_arg a6); + // ======================================= // Array of objects diff --git a/tests/playground/fix1.lean b/tests/playground/fix1.lean new file mode 100644 index 0000000000..2079545e46 --- /dev/null +++ b/tests/playground/fix1.lean @@ -0,0 +1,21 @@ +def foo (rec : Nat → Nat → Nat) : Nat → Nat → Nat +| 0 a := a +| (n+1) a := rec n a + a + rec n (a+1) + +partial def fix' (f: (Nat → Nat → Nat) → (Nat → Nat → Nat)) : Nat → Nat → Nat +| a b := f fix' a b + +def prof {α : Type} (msg : String) (p : IO α) : IO α := +let msg := "Time for '" ++ msg ++ "':" in +timeit msg p + +def fix_test (n : Nat) : IO Unit := +IO.println (fix foo n 10) + +def fix'_test (n : Nat) : IO Unit := +IO.println (fix' foo n 10) + +def main (xs : List String) : IO Unit := +prof "native fix" (fix_test xs.head.toNat) *> +prof "fix in lean" (fix'_test xs.head.toNat) *> +pure ()