perf: implement Expr.update* in Lean

This commit is contained in:
Gabriel Ebner 2022-07-10 16:11:01 +02:00 committed by Leonardo de Moura
parent 1611cf63c3
commit eda3eae18e
5 changed files with 125 additions and 178 deletions

View file

@ -42,8 +42,15 @@ unsafe def ptrAddrUnsafe {α : Type u} (a : @& α) : USize := 0
@[inline] unsafe def withPtrAddrUnsafe {α : Type u} {β : Type v} (a : α) (k : USize → β) (h : ∀ u₁ u₂, k u₁ = k u₂) : β :=
k (ptrAddrUnsafe a)
@[inline] unsafe def ptrEq (a b : α) : Bool := ptrAddrUnsafe a == ptrAddrUnsafe b
unsafe def ptrEqList : (as bs : List α) → Bool
| [], [] => true
| a::as, b::bs => if ptrEq a b then ptrEqList as bs else false
| _, _ => false
@[inline] unsafe def withPtrEqUnsafe {α : Type u} (a b : α) (k : Unit → Bool) (h : a = b → k () = true) : Bool :=
if ptrAddrUnsafe a == ptrAddrUnsafe b then true else k ()
if ptrEq a b then true else k ()
@[implementedBy withPtrEqUnsafe]
def withPtrEq {α : Type u} (a b : α) (k : Unit → Bool) (h : a = b → k () = true) : Bool := k ()

View file

@ -732,9 +732,6 @@ opaque eqv (a : @& Expr) (b : @& Expr) : Bool
instance : BEq Expr where
beq := Expr.eqv
protected unsafe def ptrEq (a b : Expr) : Bool :=
ptrAddrUnsafe a == ptrAddrUnsafe b
/--
Return true iff `a` and `b` are equal.
Binder names and annotations are taking into account.
@ -1393,101 +1390,120 @@ def containsFVar (e : Expr) (fvarId : FVarId) : Bool :=
e.hasAnyFVar (· == fvarId)
/-!
The update functions here are defined using C code. They will try to avoid
allocating new values using pointer equality.
The hypotheses `(h : e.is...)` are used to ensure Lean will not crash
at runtime.
The `update*!` functions are inlined and provide a convenient way of using the
update proofs without providing proofs.
Note that if they are used under a match-expression, the compiler will eliminate
the double-match.
The update functions try to avoid allocating new values using pointer equality.
Note that if the `update*!` functions are used under a match-expression,
the compiler will eliminate the double-match.
-/
@[extern "lean_expr_update_app"]
def updateApp (e : Expr) (newFn : Expr) (newArg : Expr) (h : e.isApp) : Expr :=
mkApp newFn newArg
@[inline] def updateApp! (e : Expr) (newFn : Expr) (newArg : Expr) : Expr :=
match h : e with
| app .. => updateApp e newFn newArg (h ▸ rfl)
| _ => panic! "application expected"
@[extern "lean_expr_update_const"]
def updateConst (e : Expr) (newLevels : List Level) (h : e.isConst) : Expr :=
mkConst e.constName! newLevels
@[inline] def updateConst! (e : Expr) (newLevels : List Level) : Expr :=
match h : e with
| const .. => updateConst e newLevels (h ▸ rfl)
| _ => panic! "constant expected"
@[extern "lean_expr_update_sort"]
def updateSort (e : Expr) (newLevel : Level) (h : e.isSort) : Expr :=
mkSort newLevel
@[inline] def updateSort! (e : Expr) (newLevel : Level) : Expr :=
match h : e with
| sort .. => updateSort e newLevel (h ▸ rfl)
| _ => panic! "level expected"
@[extern "lean_expr_update_proj"]
def updateProj (e : Expr) (newExpr : Expr) (h : e.isProj) : Expr :=
@[inline] private unsafe def updateApp!Impl (e : Expr) (newFn : Expr) (newArg : Expr) : Expr :=
match e with
| proj s i .. => mkProj s i newExpr
| _ => e -- unreachable because of `h`
| app fn arg => if ptrEq fn newFn && ptrEq arg newArg then e else mkApp newFn newArg
| _ => panic! "application expected"
@[extern "lean_expr_update_mdata"]
def updateMData (e : Expr) (newExpr : Expr) (h : e.isMData) : Expr :=
@[implementedBy updateApp!Impl]
def updateApp! (e : Expr) (newFn : Expr) (newArg : Expr) : Expr :=
match e with
| mdata d .. => mkMData d newExpr
| _ => e -- unreachable because of `h`
| app _ _ => mkApp newFn newArg
| _ => panic! "application expected"
@[inline] def updateMData! (e : Expr) (newExpr : Expr) : Expr :=
match h : e with
| mdata .. => updateMData e newExpr (h ▸ rfl)
| _ => panic! "mdata expected"
@[inline] private unsafe def updateConst!Impl (e : Expr) (newLevels : List Level) : Expr :=
match e with
| const n ls => if ptrEqList ls newLevels then e else mkConst n newLevels
| _ => panic! "constant expected"
@[inline] def updateProj! (e : Expr) (newExpr : Expr) : Expr :=
match h : e with
| proj .. => updateProj e newExpr (h ▸ rfl)
| _ => panic! "proj expected"
@[implementedBy updateConst!Impl]
def updateConst! (e : Expr) (newLevels : List Level) : Expr :=
match e with
| const n _ => mkConst n newLevels
| _ => panic! "constant expected"
@[extern "lean_expr_update_forall"]
def updateForall (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) (h : e.isForall) : Expr :=
mkForall e.bindingName! newBinfo newDomain newBody
@[inline] private unsafe def updateSort!Impl (e : Expr) (u' : Level) : Expr :=
match e with
| sort u => if ptrEq u u' then e else mkSort u'
| _ => panic! "level expected"
@[inline] def updateForall! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
match h : e with
| forallE .. => updateForall e newBinfo newDomain newBody (h ▸ rfl)
| _ => panic! "forall expected"
@[implementedBy updateSort!Impl]
def updateSort! (e : Expr) (newLevel : Level) : Expr :=
match e with
| sort _ => mkSort newLevel
| _ => panic! "level expected"
@[inline] def updateForallE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr :=
match h : e with
| forallE _ _ _ c => updateForall e c newDomain newBody (h ▸ rfl)
@[inline] private unsafe def updateMData!Impl (e : Expr) (newExpr : Expr) : Expr :=
match e with
| mdata d a => if ptrEq a newExpr then e else mkMData d newExpr
| _ => panic! "mdata expected"
@[implementedBy updateMData!Impl]
def updateMData! (e : Expr) (newExpr : Expr) : Expr :=
match e with
| mdata d _ => mkMData d newExpr
| _ => panic! "mdata expected"
@[inline] private unsafe def updateProj!Impl (e : Expr) (newExpr : Expr) : Expr :=
match e with
| proj s i a => if ptrEq a newExpr then e else mkProj s i newExpr
| _ => panic! "proj expected"
@[implementedBy updateProj!Impl]
def updateProj! (e : Expr) (newExpr : Expr) : Expr :=
match e with
| proj s i _ => mkProj s i newExpr
| _ => panic! "proj expected"
@[inline] private unsafe def updateForall!Impl (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
match e with
| forallE n d b bi =>
if ptrEq d newDomain && ptrEq b newBody && bi == newBinfo then
e
else
mkForall n newBinfo newDomain newBody
| _ => panic! "forall expected"
@[extern "lean_expr_update_lambda"]
def updateLambda (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) (h : e.isLambda) : Expr :=
mkLambda e.bindingName! newBinfo newDomain newBody
@[implementedBy updateForall!Impl]
def updateForall! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
match e with
| forallE n _ _ _ => mkForall n newBinfo newDomain newBody
| _ => panic! "forall expected"
@[inline] def updateLambda! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
match h : e with
| lam .. => updateLambda e newBinfo newDomain newBody (h ▸ rfl)
| _ => panic! "lambda expected"
@[inline] def updateForallE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr :=
match e with
| forallE n d b bi => updateForall! (forallE n d b bi) bi newDomain newBody
| _ => panic! "forall expected"
@[inline] def updateLambdaE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr :=
match h : e with
| lam _ _ _ c => updateLambda e c newDomain newBody (h ▸ rfl)
@[inline] private unsafe def updateLambda!Impl (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
match e with
| lam n d b bi =>
if ptrEq d newDomain && ptrEq b newBody && bi == newBinfo then
e
else
mkLambda n newBinfo newDomain newBody
| _ => panic! "lambda expected"
@[extern "lean_expr_update_let"]
def updateLet (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) (h : e.isLet) : Expr :=
mkLet e.letName! newType newVal newBody
@[implementedBy updateLambda!Impl]
def updateLambda! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
match e with
| lam n _ _ _ => mkLambda n newBinfo newDomain newBody
| _ => panic! "lambda expected"
@[inline] def updateLet! (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) : Expr :=
match h : e with
| letE .. => updateLet e newType newVal newBody (h ▸ rfl)
| _ => panic! "let expression expected"
@[inline] def updateLambdaE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr :=
match e with
| lam n d b bi => updateLambda! (lam n d b bi) bi newDomain newBody
| _ => panic! "lambda expected"
@[inline] private unsafe def updateLet!Impl (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) : Expr :=
match e with
| letE n t v b nonDep =>
if ptrEq t newType && ptrEq v newVal && ptrEq b newBody then
e
else
letE n newType newVal newBody nonDep
| _ => panic! "let expression expected"
@[implementedBy updateLet!Impl]
def updateLet! (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) : Expr :=
match e with
| letE n _ _ _ c => letE n newType newVal newBody c
| _ => panic! "let expression expected"
def updateFn : Expr → Expr → Expr
| e@(app f a), g => e.updateApp! (updateFn f g) a

View file

@ -11,7 +11,7 @@ structure HasConstCache (declName : Name) where
cache : Std.HashMapImp Expr Bool := Std.mkHashMapImp
unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declName) Bool := do
if let some r := (← get).cache.find? (beq := ⟨Expr.ptrEq⟩) e then
if let some r := (← get).cache.find? (beq := ⟨ptrEq⟩) e then
return r
else
match e with
@ -25,7 +25,7 @@ unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declN
| _ => return false
where
cache (e : Expr) (r : Bool) : StateM (HasConstCache declName) Bool := do
modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨Expr.ptrEq⟩) e r |>.1⟩
modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨ptrEq⟩) e r |>.1⟩
return r
/--

View file

@ -330,97 +330,6 @@ expr update_let(expr const & e, expr const & new_type, expr const & new_value, e
return e;
}
extern "C" LEAN_EXPORT object * lean_expr_update_mdata(obj_arg e, obj_arg new_expr) {
if (mdata_expr(TO_REF(expr, e)).raw() != new_expr) {
object * r = lean_expr_mk_mdata(mdata_data(TO_REF(expr, e)).to_obj_arg(), new_expr);
lean_dec_ref(e);
return r;
} else {
lean_dec_ref(new_expr);
return e;
}
}
extern "C" LEAN_EXPORT object * lean_expr_update_const(obj_arg e, obj_arg new_levels) {
if (const_levels(TO_REF(expr, e)).raw() != new_levels) {
object * r = lean_expr_mk_const(const_name(TO_REF(expr, e)).to_obj_arg(), new_levels);
lean_dec_ref(e);
return r;
} else {
lean_dec(new_levels);
return e;
}
}
extern "C" LEAN_EXPORT object * lean_expr_update_sort(obj_arg e, obj_arg new_level) {
if (sort_level(TO_REF(expr, e)).raw() != new_level) {
object * r = lean_expr_mk_sort(new_level);
lean_dec_ref(e);
return r;
} else {
lean_dec(new_level);
return e;
}
}
extern "C" LEAN_EXPORT object * lean_expr_update_proj(obj_arg e, obj_arg new_expr) {
if (proj_expr(TO_REF(expr, e)).raw() != new_expr) {
object * r = lean_expr_mk_proj(proj_sname(TO_REF(expr, e)).to_obj_arg(), proj_idx(TO_REF(expr, e)).to_obj_arg(), new_expr);
lean_dec_ref(e);
return r;
} else {
lean_dec_ref(new_expr);
return e;
}
}
extern "C" LEAN_EXPORT object * lean_expr_update_app(obj_arg e, obj_arg new_fn, obj_arg new_arg) {
if (app_fn(TO_REF(expr, e)).raw() != new_fn || app_arg(TO_REF(expr, e)).raw() != new_arg) {
object * r = lean_expr_mk_app(new_fn, new_arg);
lean_dec_ref(e);
return r;
} else {
lean_dec_ref(new_fn); lean_dec_ref(new_arg);
return e;
}
}
extern "C" LEAN_EXPORT object * lean_expr_update_forall(obj_arg e, uint8 new_binfo, obj_arg new_domain, obj_arg new_body) {
if (binding_domain(TO_REF(expr, e)).raw() != new_domain || binding_body(TO_REF(expr, e)).raw() != new_body ||
binding_info(TO_REF(expr, e)) != static_cast<binder_info>(new_binfo)) {
object * r = lean_expr_mk_forall(binding_name(TO_REF(expr, e)).to_obj_arg(), new_domain, new_body, new_binfo);
lean_dec_ref(e);
return r;
} else {
lean_dec_ref(new_domain); lean_dec_ref(new_body);
return e;
}
}
extern "C" LEAN_EXPORT object * lean_expr_update_lambda(obj_arg e, uint8 new_binfo, obj_arg new_domain, obj_arg new_body) {
if (binding_domain(TO_REF(expr, e)).raw() != new_domain || binding_body(TO_REF(expr, e)).raw() != new_body ||
binding_info(TO_REF(expr, e)) != static_cast<binder_info>(new_binfo)) {
object * r = lean_expr_mk_lambda(binding_name(TO_REF(expr, e)).to_obj_arg(), new_domain, new_body, new_binfo);
lean_dec_ref(e);
return r;
} else {
lean_dec_ref(new_domain); lean_dec_ref(new_body);
return e;
}
}
extern "C" LEAN_EXPORT object * lean_expr_update_let(obj_arg e, obj_arg new_type, obj_arg new_val, obj_arg new_body) {
if (let_type(TO_REF(expr, e)).raw() != new_type || let_value(TO_REF(expr, e)).raw() != new_val ||
let_body(TO_REF(expr, e)).raw() != new_body) {
object * r = lean_expr_mk_let(let_name(TO_REF(expr, e)).to_obj_arg(), new_type, new_val, new_body);
lean_dec_ref(e);
return r;
} else {
lean_dec_ref(new_type); lean_dec_ref(new_val); lean_dec_ref(new_body);
return e;
}
}
extern "C" object * lean_expr_consume_type_annotations(obj_arg e);
expr consume_type_annotations(expr const & e) { return expr(lean_expr_consume_type_annotations(e.to_obj_arg())); }

View file

@ -13,9 +13,24 @@ def sefFn (x_1 : obj) (x_2 : obj) : obj :=
Lean.Expr.const._impl →
ret x_1
Lean.Expr.app._impl →
let x_3 : obj := proj[1] x_1;
let x_4 : obj := Lean.Expr.updateApp x_1 x_2 x_3 ◾;
ret x_4
let x_3 : obj := proj[0] x_1;
let x_4 : obj := proj[1] x_1;
let x_5 : usize := ptrAddrUnsafe ◾ x_3;
let x_6 : usize := ptrAddrUnsafe ◾ x_2;
let x_7 : u8 := USize.decEq x_5 x_6;
case x_7 : obj of
Bool.false →
let x_8 : obj := Lean.Expr.app._override x_2 x_4;
ret x_8
Bool.true →
let x_9 : usize := ptrAddrUnsafe ◾ x_4;
let x_10 : u8 := USize.decEq x_9 x_9;
case x_10 : obj of
Bool.false →
let x_11 : obj := Lean.Expr.app._override x_2 x_4;
ret x_11
Bool.true →
ret x_1
Lean.Expr.lam._impl →
ret x_1
Lean.Expr.forallE._impl →