perf: implement Expr.update* in Lean
This commit is contained in:
parent
1611cf63c3
commit
eda3eae18e
5 changed files with 125 additions and 178 deletions
|
|
@ -42,8 +42,15 @@ unsafe def ptrAddrUnsafe {α : Type u} (a : @& α) : USize := 0
|
|||
@[inline] unsafe def withPtrAddrUnsafe {α : Type u} {β : Type v} (a : α) (k : USize → β) (h : ∀ u₁ u₂, k u₁ = k u₂) : β :=
|
||||
k (ptrAddrUnsafe a)
|
||||
|
||||
@[inline] unsafe def ptrEq (a b : α) : Bool := ptrAddrUnsafe a == ptrAddrUnsafe b
|
||||
|
||||
unsafe def ptrEqList : (as bs : List α) → Bool
|
||||
| [], [] => true
|
||||
| a::as, b::bs => if ptrEq a b then ptrEqList as bs else false
|
||||
| _, _ => false
|
||||
|
||||
@[inline] unsafe def withPtrEqUnsafe {α : Type u} (a b : α) (k : Unit → Bool) (h : a = b → k () = true) : Bool :=
|
||||
if ptrAddrUnsafe a == ptrAddrUnsafe b then true else k ()
|
||||
if ptrEq a b then true else k ()
|
||||
|
||||
@[implementedBy withPtrEqUnsafe]
|
||||
def withPtrEq {α : Type u} (a b : α) (k : Unit → Bool) (h : a = b → k () = true) : Bool := k ()
|
||||
|
|
|
|||
|
|
@ -732,9 +732,6 @@ opaque eqv (a : @& Expr) (b : @& Expr) : Bool
|
|||
instance : BEq Expr where
|
||||
beq := Expr.eqv
|
||||
|
||||
protected unsafe def ptrEq (a b : Expr) : Bool :=
|
||||
ptrAddrUnsafe a == ptrAddrUnsafe b
|
||||
|
||||
/--
|
||||
Return true iff `a` and `b` are equal.
|
||||
Binder names and annotations are taking into account.
|
||||
|
|
@ -1393,101 +1390,120 @@ def containsFVar (e : Expr) (fvarId : FVarId) : Bool :=
|
|||
e.hasAnyFVar (· == fvarId)
|
||||
|
||||
/-!
|
||||
The update functions here are defined using C code. They will try to avoid
|
||||
allocating new values using pointer equality.
|
||||
The hypotheses `(h : e.is...)` are used to ensure Lean will not crash
|
||||
at runtime.
|
||||
The `update*!` functions are inlined and provide a convenient way of using the
|
||||
update proofs without providing proofs.
|
||||
Note that if they are used under a match-expression, the compiler will eliminate
|
||||
the double-match.
|
||||
The update functions try to avoid allocating new values using pointer equality.
|
||||
Note that if the `update*!` functions are used under a match-expression,
|
||||
the compiler will eliminate the double-match.
|
||||
-/
|
||||
|
||||
@[extern "lean_expr_update_app"]
|
||||
def updateApp (e : Expr) (newFn : Expr) (newArg : Expr) (h : e.isApp) : Expr :=
|
||||
mkApp newFn newArg
|
||||
|
||||
@[inline] def updateApp! (e : Expr) (newFn : Expr) (newArg : Expr) : Expr :=
|
||||
match h : e with
|
||||
| app .. => updateApp e newFn newArg (h ▸ rfl)
|
||||
| _ => panic! "application expected"
|
||||
|
||||
@[extern "lean_expr_update_const"]
|
||||
def updateConst (e : Expr) (newLevels : List Level) (h : e.isConst) : Expr :=
|
||||
mkConst e.constName! newLevels
|
||||
|
||||
@[inline] def updateConst! (e : Expr) (newLevels : List Level) : Expr :=
|
||||
match h : e with
|
||||
| const .. => updateConst e newLevels (h ▸ rfl)
|
||||
| _ => panic! "constant expected"
|
||||
|
||||
@[extern "lean_expr_update_sort"]
|
||||
def updateSort (e : Expr) (newLevel : Level) (h : e.isSort) : Expr :=
|
||||
mkSort newLevel
|
||||
|
||||
@[inline] def updateSort! (e : Expr) (newLevel : Level) : Expr :=
|
||||
match h : e with
|
||||
| sort .. => updateSort e newLevel (h ▸ rfl)
|
||||
| _ => panic! "level expected"
|
||||
|
||||
@[extern "lean_expr_update_proj"]
|
||||
def updateProj (e : Expr) (newExpr : Expr) (h : e.isProj) : Expr :=
|
||||
@[inline] private unsafe def updateApp!Impl (e : Expr) (newFn : Expr) (newArg : Expr) : Expr :=
|
||||
match e with
|
||||
| proj s i .. => mkProj s i newExpr
|
||||
| _ => e -- unreachable because of `h`
|
||||
| app fn arg => if ptrEq fn newFn && ptrEq arg newArg then e else mkApp newFn newArg
|
||||
| _ => panic! "application expected"
|
||||
|
||||
@[extern "lean_expr_update_mdata"]
|
||||
def updateMData (e : Expr) (newExpr : Expr) (h : e.isMData) : Expr :=
|
||||
@[implementedBy updateApp!Impl]
|
||||
def updateApp! (e : Expr) (newFn : Expr) (newArg : Expr) : Expr :=
|
||||
match e with
|
||||
| mdata d .. => mkMData d newExpr
|
||||
| _ => e -- unreachable because of `h`
|
||||
| app _ _ => mkApp newFn newArg
|
||||
| _ => panic! "application expected"
|
||||
|
||||
@[inline] def updateMData! (e : Expr) (newExpr : Expr) : Expr :=
|
||||
match h : e with
|
||||
| mdata .. => updateMData e newExpr (h ▸ rfl)
|
||||
| _ => panic! "mdata expected"
|
||||
@[inline] private unsafe def updateConst!Impl (e : Expr) (newLevels : List Level) : Expr :=
|
||||
match e with
|
||||
| const n ls => if ptrEqList ls newLevels then e else mkConst n newLevels
|
||||
| _ => panic! "constant expected"
|
||||
|
||||
@[inline] def updateProj! (e : Expr) (newExpr : Expr) : Expr :=
|
||||
match h : e with
|
||||
| proj .. => updateProj e newExpr (h ▸ rfl)
|
||||
| _ => panic! "proj expected"
|
||||
@[implementedBy updateConst!Impl]
|
||||
def updateConst! (e : Expr) (newLevels : List Level) : Expr :=
|
||||
match e with
|
||||
| const n _ => mkConst n newLevels
|
||||
| _ => panic! "constant expected"
|
||||
|
||||
@[extern "lean_expr_update_forall"]
|
||||
def updateForall (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) (h : e.isForall) : Expr :=
|
||||
mkForall e.bindingName! newBinfo newDomain newBody
|
||||
@[inline] private unsafe def updateSort!Impl (e : Expr) (u' : Level) : Expr :=
|
||||
match e with
|
||||
| sort u => if ptrEq u u' then e else mkSort u'
|
||||
| _ => panic! "level expected"
|
||||
|
||||
@[inline] def updateForall! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match h : e with
|
||||
| forallE .. => updateForall e newBinfo newDomain newBody (h ▸ rfl)
|
||||
| _ => panic! "forall expected"
|
||||
@[implementedBy updateSort!Impl]
|
||||
def updateSort! (e : Expr) (newLevel : Level) : Expr :=
|
||||
match e with
|
||||
| sort _ => mkSort newLevel
|
||||
| _ => panic! "level expected"
|
||||
|
||||
@[inline] def updateForallE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match h : e with
|
||||
| forallE _ _ _ c => updateForall e c newDomain newBody (h ▸ rfl)
|
||||
@[inline] private unsafe def updateMData!Impl (e : Expr) (newExpr : Expr) : Expr :=
|
||||
match e with
|
||||
| mdata d a => if ptrEq a newExpr then e else mkMData d newExpr
|
||||
| _ => panic! "mdata expected"
|
||||
|
||||
@[implementedBy updateMData!Impl]
|
||||
def updateMData! (e : Expr) (newExpr : Expr) : Expr :=
|
||||
match e with
|
||||
| mdata d _ => mkMData d newExpr
|
||||
| _ => panic! "mdata expected"
|
||||
|
||||
@[inline] private unsafe def updateProj!Impl (e : Expr) (newExpr : Expr) : Expr :=
|
||||
match e with
|
||||
| proj s i a => if ptrEq a newExpr then e else mkProj s i newExpr
|
||||
| _ => panic! "proj expected"
|
||||
|
||||
@[implementedBy updateProj!Impl]
|
||||
def updateProj! (e : Expr) (newExpr : Expr) : Expr :=
|
||||
match e with
|
||||
| proj s i _ => mkProj s i newExpr
|
||||
| _ => panic! "proj expected"
|
||||
|
||||
@[inline] private unsafe def updateForall!Impl (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match e with
|
||||
| forallE n d b bi =>
|
||||
if ptrEq d newDomain && ptrEq b newBody && bi == newBinfo then
|
||||
e
|
||||
else
|
||||
mkForall n newBinfo newDomain newBody
|
||||
| _ => panic! "forall expected"
|
||||
|
||||
@[extern "lean_expr_update_lambda"]
|
||||
def updateLambda (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) (h : e.isLambda) : Expr :=
|
||||
mkLambda e.bindingName! newBinfo newDomain newBody
|
||||
@[implementedBy updateForall!Impl]
|
||||
def updateForall! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match e with
|
||||
| forallE n _ _ _ => mkForall n newBinfo newDomain newBody
|
||||
| _ => panic! "forall expected"
|
||||
|
||||
@[inline] def updateLambda! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match h : e with
|
||||
| lam .. => updateLambda e newBinfo newDomain newBody (h ▸ rfl)
|
||||
| _ => panic! "lambda expected"
|
||||
@[inline] def updateForallE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match e with
|
||||
| forallE n d b bi => updateForall! (forallE n d b bi) bi newDomain newBody
|
||||
| _ => panic! "forall expected"
|
||||
|
||||
@[inline] def updateLambdaE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match h : e with
|
||||
| lam _ _ _ c => updateLambda e c newDomain newBody (h ▸ rfl)
|
||||
@[inline] private unsafe def updateLambda!Impl (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match e with
|
||||
| lam n d b bi =>
|
||||
if ptrEq d newDomain && ptrEq b newBody && bi == newBinfo then
|
||||
e
|
||||
else
|
||||
mkLambda n newBinfo newDomain newBody
|
||||
| _ => panic! "lambda expected"
|
||||
|
||||
@[extern "lean_expr_update_let"]
|
||||
def updateLet (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) (h : e.isLet) : Expr :=
|
||||
mkLet e.letName! newType newVal newBody
|
||||
@[implementedBy updateLambda!Impl]
|
||||
def updateLambda! (e : Expr) (newBinfo : BinderInfo) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match e with
|
||||
| lam n _ _ _ => mkLambda n newBinfo newDomain newBody
|
||||
| _ => panic! "lambda expected"
|
||||
|
||||
@[inline] def updateLet! (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) : Expr :=
|
||||
match h : e with
|
||||
| letE .. => updateLet e newType newVal newBody (h ▸ rfl)
|
||||
| _ => panic! "let expression expected"
|
||||
@[inline] def updateLambdaE! (e : Expr) (newDomain : Expr) (newBody : Expr) : Expr :=
|
||||
match e with
|
||||
| lam n d b bi => updateLambda! (lam n d b bi) bi newDomain newBody
|
||||
| _ => panic! "lambda expected"
|
||||
|
||||
@[inline] private unsafe def updateLet!Impl (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) : Expr :=
|
||||
match e with
|
||||
| letE n t v b nonDep =>
|
||||
if ptrEq t newType && ptrEq v newVal && ptrEq b newBody then
|
||||
e
|
||||
else
|
||||
letE n newType newVal newBody nonDep
|
||||
| _ => panic! "let expression expected"
|
||||
|
||||
@[implementedBy updateLet!Impl]
|
||||
def updateLet! (e : Expr) (newType : Expr) (newVal : Expr) (newBody : Expr) : Expr :=
|
||||
match e with
|
||||
| letE n _ _ _ c => letE n newType newVal newBody c
|
||||
| _ => panic! "let expression expected"
|
||||
|
||||
def updateFn : Expr → Expr → Expr
|
||||
| e@(app f a), g => e.updateApp! (updateFn f g) a
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ structure HasConstCache (declName : Name) where
|
|||
cache : Std.HashMapImp Expr Bool := Std.mkHashMapImp
|
||||
|
||||
unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declName) Bool := do
|
||||
if let some r := (← get).cache.find? (beq := ⟨Expr.ptrEq⟩) e then
|
||||
if let some r := (← get).cache.find? (beq := ⟨ptrEq⟩) e then
|
||||
return r
|
||||
else
|
||||
match e with
|
||||
|
|
@ -25,7 +25,7 @@ unsafe def HasConstCache.containsUnsafe (e : Expr) : StateM (HasConstCache declN
|
|||
| _ => return false
|
||||
where
|
||||
cache (e : Expr) (r : Bool) : StateM (HasConstCache declName) Bool := do
|
||||
modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨Expr.ptrEq⟩) e r |>.1⟩
|
||||
modify fun ⟨cache⟩ => ⟨cache.insert (beq := ⟨ptrEq⟩) e r |>.1⟩
|
||||
return r
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -330,97 +330,6 @@ expr update_let(expr const & e, expr const & new_type, expr const & new_value, e
|
|||
return e;
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_expr_update_mdata(obj_arg e, obj_arg new_expr) {
|
||||
if (mdata_expr(TO_REF(expr, e)).raw() != new_expr) {
|
||||
object * r = lean_expr_mk_mdata(mdata_data(TO_REF(expr, e)).to_obj_arg(), new_expr);
|
||||
lean_dec_ref(e);
|
||||
return r;
|
||||
} else {
|
||||
lean_dec_ref(new_expr);
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_expr_update_const(obj_arg e, obj_arg new_levels) {
|
||||
if (const_levels(TO_REF(expr, e)).raw() != new_levels) {
|
||||
object * r = lean_expr_mk_const(const_name(TO_REF(expr, e)).to_obj_arg(), new_levels);
|
||||
lean_dec_ref(e);
|
||||
return r;
|
||||
} else {
|
||||
lean_dec(new_levels);
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_expr_update_sort(obj_arg e, obj_arg new_level) {
|
||||
if (sort_level(TO_REF(expr, e)).raw() != new_level) {
|
||||
object * r = lean_expr_mk_sort(new_level);
|
||||
lean_dec_ref(e);
|
||||
return r;
|
||||
} else {
|
||||
lean_dec(new_level);
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_expr_update_proj(obj_arg e, obj_arg new_expr) {
|
||||
if (proj_expr(TO_REF(expr, e)).raw() != new_expr) {
|
||||
object * r = lean_expr_mk_proj(proj_sname(TO_REF(expr, e)).to_obj_arg(), proj_idx(TO_REF(expr, e)).to_obj_arg(), new_expr);
|
||||
lean_dec_ref(e);
|
||||
return r;
|
||||
} else {
|
||||
lean_dec_ref(new_expr);
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_expr_update_app(obj_arg e, obj_arg new_fn, obj_arg new_arg) {
|
||||
if (app_fn(TO_REF(expr, e)).raw() != new_fn || app_arg(TO_REF(expr, e)).raw() != new_arg) {
|
||||
object * r = lean_expr_mk_app(new_fn, new_arg);
|
||||
lean_dec_ref(e);
|
||||
return r;
|
||||
} else {
|
||||
lean_dec_ref(new_fn); lean_dec_ref(new_arg);
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_expr_update_forall(obj_arg e, uint8 new_binfo, obj_arg new_domain, obj_arg new_body) {
|
||||
if (binding_domain(TO_REF(expr, e)).raw() != new_domain || binding_body(TO_REF(expr, e)).raw() != new_body ||
|
||||
binding_info(TO_REF(expr, e)) != static_cast<binder_info>(new_binfo)) {
|
||||
object * r = lean_expr_mk_forall(binding_name(TO_REF(expr, e)).to_obj_arg(), new_domain, new_body, new_binfo);
|
||||
lean_dec_ref(e);
|
||||
return r;
|
||||
} else {
|
||||
lean_dec_ref(new_domain); lean_dec_ref(new_body);
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_expr_update_lambda(obj_arg e, uint8 new_binfo, obj_arg new_domain, obj_arg new_body) {
|
||||
if (binding_domain(TO_REF(expr, e)).raw() != new_domain || binding_body(TO_REF(expr, e)).raw() != new_body ||
|
||||
binding_info(TO_REF(expr, e)) != static_cast<binder_info>(new_binfo)) {
|
||||
object * r = lean_expr_mk_lambda(binding_name(TO_REF(expr, e)).to_obj_arg(), new_domain, new_body, new_binfo);
|
||||
lean_dec_ref(e);
|
||||
return r;
|
||||
} else {
|
||||
lean_dec_ref(new_domain); lean_dec_ref(new_body);
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT object * lean_expr_update_let(obj_arg e, obj_arg new_type, obj_arg new_val, obj_arg new_body) {
|
||||
if (let_type(TO_REF(expr, e)).raw() != new_type || let_value(TO_REF(expr, e)).raw() != new_val ||
|
||||
let_body(TO_REF(expr, e)).raw() != new_body) {
|
||||
object * r = lean_expr_mk_let(let_name(TO_REF(expr, e)).to_obj_arg(), new_type, new_val, new_body);
|
||||
lean_dec_ref(e);
|
||||
return r;
|
||||
} else {
|
||||
lean_dec_ref(new_type); lean_dec_ref(new_val); lean_dec_ref(new_body);
|
||||
return e;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" object * lean_expr_consume_type_annotations(obj_arg e);
|
||||
|
||||
expr consume_type_annotations(expr const & e) { return expr(lean_expr_consume_type_annotations(e.to_obj_arg())); }
|
||||
|
|
|
|||
|
|
@ -13,9 +13,24 @@ def sefFn (x_1 : obj) (x_2 : obj) : obj :=
|
|||
Lean.Expr.const._impl →
|
||||
ret x_1
|
||||
Lean.Expr.app._impl →
|
||||
let x_3 : obj := proj[1] x_1;
|
||||
let x_4 : obj := Lean.Expr.updateApp x_1 x_2 x_3 ◾;
|
||||
ret x_4
|
||||
let x_3 : obj := proj[0] x_1;
|
||||
let x_4 : obj := proj[1] x_1;
|
||||
let x_5 : usize := ptrAddrUnsafe ◾ x_3;
|
||||
let x_6 : usize := ptrAddrUnsafe ◾ x_2;
|
||||
let x_7 : u8 := USize.decEq x_5 x_6;
|
||||
case x_7 : obj of
|
||||
Bool.false →
|
||||
let x_8 : obj := Lean.Expr.app._override x_2 x_4;
|
||||
ret x_8
|
||||
Bool.true →
|
||||
let x_9 : usize := ptrAddrUnsafe ◾ x_4;
|
||||
let x_10 : u8 := USize.decEq x_9 x_9;
|
||||
case x_10 : obj of
|
||||
Bool.false →
|
||||
let x_11 : obj := Lean.Expr.app._override x_2 x_4;
|
||||
ret x_11
|
||||
Bool.true →
|
||||
ret x_1
|
||||
Lean.Expr.lam._impl →
|
||||
ret x_1
|
||||
Lean.Expr.forallE._impl →
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue