From eda3eae18e6e5dd7dcb0cee6fa32d73cfa42e6aa Mon Sep 17 00:00:00 2001 From: Gabriel Ebner Date: Sun, 10 Jul 2022 16:11:01 +0200 Subject: [PATCH] perf: implement Expr.update* in Lean --- src/Init/Util.lean | 9 +- src/Lean/Expr.lean | 178 ++++++++++--------- src/Lean/Util/HasConstCache.lean | 4 +- src/kernel/expr.cpp | 91 ---------- tests/lean/updateExprIssue.lean.expected.out | 21 ++- 5 files changed, 125 insertions(+), 178 deletions(-) diff --git a/src/Init/Util.lean b/src/Init/Util.lean index 2902ce9586..b092ac9afb 100644 --- a/src/Init/Util.lean +++ b/src/Init/Util.lean @@ -42,8 +42,15 @@ unsafe def ptrAddrUnsafe {α : Type u} (a : @& α) : USize := 0 @[inline] unsafe def withPtrAddrUnsafe {α : Type u} {β : Type v} (a : α) (k : USize → β) (h : ∀ u₁ u₂, k u₁ = k u₂) : β := k (ptrAddrUnsafe a) +@[inline] unsafe def ptrEq (a b : α) : Bool := ptrAddrUnsafe a == ptrAddrUnsafe b + +unsafe def ptrEqList : (as bs : List α) → Bool + | [], [] => true + | a::as, b::bs => if ptrEq a b then ptrEqList as bs else false + | _, _ => false + @[inline] unsafe def withPtrEqUnsafe {α : Type u} (a b : α) (k : Unit → Bool) (h : a = b → k () = true) : Bool := - if ptrAddrUnsafe a == ptrAddrUnsafe b then true else k () + if ptrEq a b then true else k () @[implementedBy withPtrEqUnsafe] def withPtrEq {α : Type u} (a b : α) (k : Unit → Bool) (h : a = b → k () = true) : Bool := k () diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 2947ebbdc0..82f09a7c32 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -732,9 +732,6 @@ opaque eqv (a : @& Expr) (b : @& Expr) : Bool instance : BEq Expr where beq := Expr.eqv -protected unsafe def ptrEq (a b : Expr) : Bool := - ptrAddrUnsafe a == ptrAddrUnsafe b - /-- Return true iff `a` and `b` are equal. Binder names and annotations are taking into account. @@ -1393,101 +1390,120 @@ def containsFVar (e : Expr) (fvarId : FVarId) : Bool := e.hasAnyFVar (· == fvarId) /-! -The update functions here are defined using C code. They will try to avoid -allocating new values using pointer equality. -The hypotheses `(h : e.is...)` are used to ensure Lean will not crash -at runtime. -The `update*!` functions are inlined and provide a convenient way of using the -update proofs without providing proofs. -Note that if they are used under a match-expression, the compiler will eliminate -the double-match. +The update functions try to avoid allocating new values using pointer equality. +Note that if the `update*!` functions are used under a match-expression, +the compiler will eliminate the double-match. -/ -@[extern "lean_expr_update_app"] -def updateApp (e : Expr) (newFn : Expr) (newArg : Expr) (h : e.isApp) : Expr := - mkApp newFn newArg - -@[inline] def updateApp! (e : Expr) (newFn : Expr) (newArg : Expr) : Expr := - match h : e with - | app .. => updateApp e newFn newArg (h ▸ rfl) - | _ => panic! "application expected" - -@[extern "lean_expr_update_const"] -def updateConst (e : Expr) (newLevels : List Level) (h : e.isConst) : Expr := - mkConst e.constName! newLevels - -@[inline] def updateConst! (e : Expr) (newLevels : List Level) : Expr := - match h : e with - | const .. => updateConst e newLevels (h ▸ rfl) - | _ => panic! "constant expected" - -@[extern "lean_expr_update_sort"] -def updateSort (e : Expr) (newLevel : Level) (h : e.isSort) : Expr := - mkSort newLevel - -@[inline] def updateSort! (e : Expr) (newLevel : Level) : Expr := - match h : e with - | sort .. => updateSort e newLevel (h ▸ rfl) - | _ => panic! "level expected" - -@[extern "lean_expr_update_proj"] -def updateProj (e : Expr) (newExpr : Expr) (h : e.isProj) : Expr := +@[inline] private unsafe def updateApp!Impl (e : Expr) (newFn : Expr) (newArg : Expr) : Expr := match e with - | proj s i .. => mkProj s i newExpr - | _ => e -- unreachable because of `h` + | app fn arg => if ptrEq fn newFn && ptrEq arg newArg then e else mkApp newFn newArg + | _ => panic! "application expected" -@[extern "lean_expr_update_mdata"] -def updateMData (e : Expr) (newExpr : Expr) (h : e.isMData) : Expr := +@[implementedBy updateApp!Impl] +def updateApp! (e : Expr) (newFn : Expr) (newArg : Expr) : Expr := match e with - | mdata d .. => mkMData d newExpr - | _ => e -- unreachable because of `h` + | app _ _ => mkApp newFn newArg + | _ => panic! "application expected" -@[inline] def updateMData! (e : Expr) (newExpr : Expr) : Expr := - match h : e with - | mdata .. => updateMData e newExpr (h ▸ rfl) - | _ => panic! "mdata expected" +@[inline] private unsafe def updateConst!Impl (e : Expr) (newLevels : List Level) : Expr := + match e with + | const n ls => if ptrEqList ls newLevels then e else mkConst n newLevels + | _ => panic! "constant expected" -@[inline] def updateProj! (e : Expr) (newExpr : Expr) : Expr := - match h : e with - | proj .. => updateProj e newExpr (h ▸ rfl) - | _ => panic! "proj expected" +@[implementedBy updateConst!Impl] +def updateConst! (e : Expr) (newLevels : List Level) : Expr := + match e with + | const n _ => mkConst n newLevels + | _ => panic! "constant expected" -@[extern "lean_expr_update_forall"] -def updateForall (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) (h : e.isForall) : Expr := - mkForall e.bindingName! newBinfo newDomain newBody +@[inline] private unsafe def updateSort!Impl (e : Expr) (u' : Level) : Expr := + match e with + | sort u => if ptrEq u u' then e else mkSort u' + | _ => panic! "level expected" -@[inline] def updateForall! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr := - match h : e with - | forallE .. => updateForall e newBinfo newDomain newBody (h ▸ rfl) - | _ => panic! "forall expected" +@[implementedBy updateSort!Impl] +def updateSort! (e : Expr) (newLevel : Level) : Expr := + match e with + | sort _ => mkSort newLevel + | _ => panic! "level expected" -@[inline] def updateForallE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr := - match h : e with - | forallE _ _ _ c => updateForall e c newDomain newBody (h ▸ rfl) +@[inline] private unsafe def updateMData!Impl (e : Expr) (newExpr : Expr) : Expr := + match e with + | mdata d a => if ptrEq a newExpr then e else mkMData d newExpr + | _ => panic! "mdata expected" + +@[implementedBy updateMData!Impl] +def updateMData! (e : Expr) (newExpr : Expr) : Expr := + match e with + | mdata d _ => mkMData d newExpr + | _ => panic! "mdata expected" + +@[inline] private unsafe def updateProj!Impl (e : Expr) (newExpr : Expr) : Expr := + match e with + | proj s i a => if ptrEq a newExpr then e else mkProj s i newExpr + | _ => panic! "proj expected" + +@[implementedBy updateProj!Impl] +def updateProj! (e : Expr) (newExpr : Expr) : Expr := + match e with + | proj s i _ => mkProj s i newExpr + | _ => panic! "proj expected" + +@[inline] private unsafe def updateForall!Impl (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr := + match e with + | forallE n d b bi => + if ptrEq d newDomain && ptrEq b newBody && bi == newBinfo then + e + else + mkForall n newBinfo newDomain newBody | _ => panic! "forall expected" -@[extern "lean_expr_update_lambda"] -def updateLambda (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) (h : e.isLambda) : Expr := - mkLambda e.bindingName! newBinfo newDomain newBody +@[implementedBy updateForall!Impl] +def updateForall! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr := + match e with + | forallE n _ _ _ => mkForall n newBinfo newDomain newBody + | _ => panic! "forall expected" -@[inline] def updateLambda! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr := - match h : e with - | lam .. => updateLambda e newBinfo newDomain newBody (h ▸ rfl) - | _ => panic! "lambda expected" +@[inline] def updateForallE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr := + match e with + | forallE n d b bi => updateForall! (forallE n d b bi) bi newDomain newBody + | _ => panic! "forall expected" -@[inline] def updateLambdaE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr := - match h : e with - | lam _ _ _ c => updateLambda e c newDomain newBody (h ▸ rfl) +@[inline] private unsafe def updateLambda!Impl (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr := + match e with + | lam n d b bi => + if ptrEq d newDomain && ptrEq b newBody && bi == newBinfo then + e + else + mkLambda n newBinfo newDomain newBody | _ => panic! "lambda expected" -@[extern "lean_expr_update_let"] -def updateLet (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) (h : e.isLet) : Expr := - mkLet e.letName! newType newVal newBody +@[implementedBy updateLambda!Impl] +def updateLambda! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr := + match e with + | lam n _ _ _ => mkLambda n newBinfo newDomain newBody + | _ => panic! "lambda expected" -@[inline] def updateLet! (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) : Expr := - match h : e with - | letE .. => updateLet e newType newVal newBody (h ▸ rfl) - | _ => panic! "let expression expected" +@[inline] def updateLambdaE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr := + match e with + | lam n d b bi => updateLambda! (lam n d b bi) bi newDomain newBody + | _ => panic! "lambda expected" + +@[inline] private unsafe def updateLet!Impl (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) : Expr := + match e with + | letE n t v b nonDep => + if ptrEq t newType && ptrEq v newVal && ptrEq b newBody then + e + else + letE n newType newVal newBody nonDep + | _ => panic! "let expression expected" + +@[implementedBy updateLet!Impl] +def updateLet! (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) : Expr := + match e with + | letE n _ _ _ c => letE n newType newVal newBody c + | _ => panic! "let expression expected" def updateFn : Expr → Expr → Expr | e@(app f a), g => e.updateApp! (updateFn f g) a diff --git a/src/Lean/Util/HasConstCache.lean b/src/Lean/Util/HasConstCache.lean index a131d71728..bad4046f9b 100644 --- a/src/Lean/Util/HasConstCache.lean +++ b/src/Lean/Util/HasConstCache.lean @@ -11,7 +11,7 @@ structure HasConstCache (declName : Name) where cache : Std.HashMapImp Expr Bool := Std.mkHashMapImp unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declName) Bool := do - if let some r := (← get).cache.find? (beq := ⟨Expr.ptrEq⟩) e then + if let some r := (← get).cache.find? (beq := ⟨ptrEq⟩) e then return r else match e with @@ -25,7 +25,7 @@ unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declN | _ => return false where cache (e : Expr) (r : Bool) : StateM (HasConstCache declName) Bool := do - modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨Expr.ptrEq⟩) e r |>.1⟩ + modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨ptrEq⟩) e r |>.1⟩ return r /-- diff --git a/src/kernel/expr.cpp b/src/kernel/expr.cpp index 4904d97f24..ebce32cbf1 100644 --- a/src/kernel/expr.cpp +++ b/src/kernel/expr.cpp @@ -330,97 +330,6 @@ expr update_let(expr const & e, expr const & new_type, expr const & new_value, e return e; } -extern "C" LEAN_EXPORT object * lean_expr_update_mdata(obj_arg e, obj_arg new_expr) { - if (mdata_expr(TO_REF(expr, e)).raw() != new_expr) { - object * r = lean_expr_mk_mdata(mdata_data(TO_REF(expr, e)).to_obj_arg(), new_expr); - lean_dec_ref(e); - return r; - } else { - lean_dec_ref(new_expr); - return e; - } -} - -extern "C" LEAN_EXPORT object * lean_expr_update_const(obj_arg e, obj_arg new_levels) { - if (const_levels(TO_REF(expr, e)).raw() != new_levels) { - object * r = lean_expr_mk_const(const_name(TO_REF(expr, e)).to_obj_arg(), new_levels); - lean_dec_ref(e); - return r; - } else { - lean_dec(new_levels); - return e; - } -} - -extern "C" LEAN_EXPORT object * lean_expr_update_sort(obj_arg e, obj_arg new_level) { - if (sort_level(TO_REF(expr, e)).raw() != new_level) { - object * r = lean_expr_mk_sort(new_level); - lean_dec_ref(e); - return r; - } else { - lean_dec(new_level); - return e; - } -} - -extern "C" LEAN_EXPORT object * lean_expr_update_proj(obj_arg e, obj_arg new_expr) { - if (proj_expr(TO_REF(expr, e)).raw() != new_expr) { - object * r = lean_expr_mk_proj(proj_sname(TO_REF(expr, e)).to_obj_arg(), proj_idx(TO_REF(expr, e)).to_obj_arg(), new_expr); - lean_dec_ref(e); - return r; - } else { - lean_dec_ref(new_expr); - return e; - } -} - -extern "C" LEAN_EXPORT object * lean_expr_update_app(obj_arg e, obj_arg new_fn, obj_arg new_arg) { - if (app_fn(TO_REF(expr, e)).raw() != new_fn || app_arg(TO_REF(expr, e)).raw() != new_arg) { - object * r = lean_expr_mk_app(new_fn, new_arg); - lean_dec_ref(e); - return r; - } else { - lean_dec_ref(new_fn); lean_dec_ref(new_arg); - return e; - } -} - -extern "C" LEAN_EXPORT object * lean_expr_update_forall(obj_arg e, uint8 new_binfo, obj_arg new_domain, obj_arg new_body) { - if (binding_domain(TO_REF(expr, e)).raw() != new_domain || binding_body(TO_REF(expr, e)).raw() != new_body || - binding_info(TO_REF(expr, e)) != static_cast(new_binfo)) { - object * r = lean_expr_mk_forall(binding_name(TO_REF(expr, e)).to_obj_arg(), new_domain, new_body, new_binfo); - lean_dec_ref(e); - return r; - } else { - lean_dec_ref(new_domain); lean_dec_ref(new_body); - return e; - } -} - -extern "C" LEAN_EXPORT object * lean_expr_update_lambda(obj_arg e, uint8 new_binfo, obj_arg new_domain, obj_arg new_body) { - if (binding_domain(TO_REF(expr, e)).raw() != new_domain || binding_body(TO_REF(expr, e)).raw() != new_body || - binding_info(TO_REF(expr, e)) != static_cast(new_binfo)) { - object * r = lean_expr_mk_lambda(binding_name(TO_REF(expr, e)).to_obj_arg(), new_domain, new_body, new_binfo); - lean_dec_ref(e); - return r; - } else { - lean_dec_ref(new_domain); lean_dec_ref(new_body); - return e; - } -} - -extern "C" LEAN_EXPORT object * lean_expr_update_let(obj_arg e, obj_arg new_type, obj_arg new_val, obj_arg new_body) { - if (let_type(TO_REF(expr, e)).raw() != new_type || let_value(TO_REF(expr, e)).raw() != new_val || - let_body(TO_REF(expr, e)).raw() != new_body) { - object * r = lean_expr_mk_let(let_name(TO_REF(expr, e)).to_obj_arg(), new_type, new_val, new_body); - lean_dec_ref(e); - return r; - } else { - lean_dec_ref(new_type); lean_dec_ref(new_val); lean_dec_ref(new_body); - return e; - } -} - extern "C" object * lean_expr_consume_type_annotations(obj_arg e); expr consume_type_annotations(expr const & e) { return expr(lean_expr_consume_type_annotations(e.to_obj_arg())); } diff --git a/tests/lean/updateExprIssue.lean.expected.out b/tests/lean/updateExprIssue.lean.expected.out index c15dbe0275..10537d1e32 100644 --- a/tests/lean/updateExprIssue.lean.expected.out +++ b/tests/lean/updateExprIssue.lean.expected.out @@ -13,9 +13,24 @@ def sefFn (x_1 : obj) (x_2 : obj) : obj := Lean.Expr.const._impl → ret x_1 Lean.Expr.app._impl → - let x_3 : obj := proj[1] x_1; - let x_4 : obj := Lean.Expr.updateApp x_1 x_2 x_3 ◾; - ret x_4 + let x_3 : obj := proj[0] x_1; + let x_4 : obj := proj[1] x_1; + let x_5 : usize := ptrAddrUnsafe ◾ x_3; + let x_6 : usize := ptrAddrUnsafe ◾ x_2; + let x_7 : u8 := USize.decEq x_5 x_6; + case x_7 : obj of + Bool.false → + let x_8 : obj := Lean.Expr.app._override x_2 x_4; + ret x_8 + Bool.true → + let x_9 : usize := ptrAddrUnsafe ◾ x_4; + let x_10 : u8 := USize.decEq x_9 x_9; + case x_10 : obj of + Bool.false → + let x_11 : obj := Lean.Expr.app._override x_2 x_4; + ret x_11 + Bool.true → + ret x_1 Lean.Expr.lam._impl → ret x_1 Lean.Expr.forallE._impl →