perf: dec specialization (#13788)

This PR generates specialized code for invoking `dec` on values whose
shape is known. This takes branch prediction pressure off
`lean_dec_ref_cold`, as the shape of the constructor should now be
compiled into the executable.
This commit is contained in:
Henrik Böving 2026-05-19 20:56:34 +01:00 committed by GitHub
parent 0a75e1d92f
commit 1f23107ad9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 166 additions and 51 deletions

View file

@ -119,7 +119,7 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do
| .inc fvarId n check persistent k _ =>
let .var var ← getFVarValue fvarId | unreachable!
return .inc var n check persistent (← lowerCode k)
| .dec fvarId n check persistent k _ =>
| .dec fvarId n check persistent _ k _ =>
let .var var ← getFVarValue fvarId | unreachable!
return .dec var n check persistent (← lowerCode k)
| .del fvarId k _ =>

View file

@ -171,10 +171,11 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
pure (p₁ == p₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| .dec fvarId₁ n₁ c₁ p₁ k₁ _, .dec fvarId₂ n₂ c₂ p₂ k₂ _ =>
| .dec fvarId₁ n₁ c₁ p₁ o₁ k₁ _, .dec fvarId₂ n₂ c₂ p₂ o₂ k₂ _ =>
pure (n₁ == n₂) <&&>
pure (c₁ == c₂) <&&>
pure (p₁ == p₂) <&&>
pure (o₁ == o₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| .del fvarId₁ k₁ _, .del fvarId₂ k₂ _ =>

View file

@ -381,7 +381,7 @@ inductive Code (pu : Purity) where
| sset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac)
| setTag (fvarId : FVarId) (cidx : Nat) (k : Code pu) (h : pu = .impure := by purity_tac)
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (objs? : Option Nat) (k : Code pu) (h : pu = .impure := by purity_tac)
| del (fvarId : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac)
deriving Inhabited
@ -463,7 +463,7 @@ inductive CodeDecl (pu : Purity) where
| sset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac)
| setTag (fvarId : FVarId) (cidx : Nat) (h : pu = .impure := by purity_tac)
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (objs? : Option Nat) (h : pu = .impure := by purity_tac)
| del (fvarId : FVarId) (h : pu = .impure := by purity_tac)
deriving Inhabited
@ -481,7 +481,7 @@ def Code.toCodeDecl! : Code pu → CodeDecl pu
| .sset fvarId i offset ty y _ _ => .sset fvarId i offset ty y
| .setTag fvarId cidx _ _ => .setTag fvarId cidx
| .inc fvarId n check persistent _ _ => .inc fvarId n check persistent
| .dec fvarId n check persistent _ _ => .dec fvarId n check persistent
| .dec fvarId n check persistent objs? _ _ => .dec fvarId n check persistent objs?
| .del fvarId _ _ => .del fvarId
| _ => unreachable!
@ -499,7 +499,7 @@ where
| .sset fvarId idx offset y ty _ => go (i-1) (.sset fvarId idx offset y ty code)
| .setTag fvarId cidx _ => go (i-1) (.setTag fvarId cidx code)
| .inc fvarId n check persistent _ => go (i-1) (.inc fvarId n check persistent code)
| .dec fvarId n check persistent _ => go (i-1) (.dec fvarId n check persistent code)
| .dec fvarId n check persistent objs? _ => go (i-1) (.dec fvarId n check persistent objs? code)
| .del fvarId _ => go (i-1) (.del fvarId code)
else
code
@ -526,8 +526,8 @@ mutual
v₁ == v₂ && c₁ == c₂ && eqImp k₁ k₂
| .inc v₁ n₁ c₁ p₁ k₁ _, .inc v₂ n₂ c₂ p₂ k₂ _ =>
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
| .dec v₁ n₁ c₁ p₁ k₁ _, .dec v₂ n₂ c₂ p₂ k₂ _ =>
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
| .dec v₁ n₁ c₁ p₁ o₁ k₁ _, .dec v₂ n₂ c₂ p₂ o₂ k₂ _ =>
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && o₁ == o₂ && eqImp k₁ k₂
| .del v₁ k₁ _, .del v₂ k₂ _ =>
v₁ == v₂ && eqImp k₁ k₂
| _, _ => false
@ -627,7 +627,8 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
| .uset fvarId offset y k _ => if ptrEq k k' then c else .uset fvarId offset y k'
| .setTag fvarId cidx k _ => if ptrEq k k' then c else .setTag fvarId cidx k'
| .inc fvarId n check persistent k _ => if ptrEq k k' then c else .inc fvarId n check persistent k'
| .dec fvarId n check persistent k _ => if ptrEq k k' then c else .dec fvarId n check persistent k'
| .dec fvarId n check persistent o k _ =>
if ptrEq k k' then c else .dec fvarId n check persistent o k'
| .del fvarId k _ => if ptrEq k k' then c else .del fvarId k'
| _ => unreachable!
@ -732,21 +733,22 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu
@[inline] private unsafe def updateDecImp (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu :=
(check' : Bool) (persistent' : Bool) (objs?' : Option Nat) (k' : Code pu) : Code pu :=
match c with
| .dec fvarId n check persistent k _ =>
| .dec fvarId n check persistent objs? k _ =>
if ptrEq fvarId fvarId'
&& n == n'
&& check == check'
&& persistent == persistent'
&& ptrEq objs? objs?'
&& ptrEq k k' then
c
else
.dec fvarId' n' check' persistent' k'
.dec fvarId' n' check' persistent' objs?' k'
| _ => unreachable!
@[implemented_by updateDecImp] opaque Code.updateDec! (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu
@[implemented_by updateDecImp] opaque Code.updateDec! (c : Code pu) (fvarId' : FVarId)
(n' : Nat) (check' : Bool) (persistent' : Bool) (objs? : Option Nat) (k' : Code pu) : Code pu
@[inline] private unsafe def updateDelImp (c : Code pu) (fvarId' : FVarId) (k' : Code pu) :
Code pu :=

View file

@ -63,13 +63,13 @@ where
return .inc fvarId s.incTotal[fvarId]! check persistent k
else
return k
| .dec fvarId n check persistent k _ =>
| .dec fvarId n check persistent objs? k _ =>
modify fun s => { s with decTotal := s.decTotal.alter fvarId (fun v? => some ((v?.getD 0) + n)) }
let k ← go k
let s ← get
if !s.decPlaced.contains fvarId then
modify fun s => { s with decPlaced := s.decPlaced.insert fvarId }
return .dec fvarId s.decTotal[fvarId]! check persistent k
return .dec fvarId s.decTotal[fvarId]! check persistent objs? k
else
return k
| .let _ k =>

View file

@ -515,9 +515,9 @@ mutual
| .inc fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return code.updateInc! fvarId n check persistent (← normCodeImp k)
| .dec fvarId n check persistent k _ =>
| .dec fvarId n check persistent objs? k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return code.updateDec! fvarId n check persistent (← normCodeImp k)
return code.updateDec! fvarId n check persistent objs? (← normCodeImp k)
| .del fvarId k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return code.updateDel! fvarId (← normCodeImp k)

View file

@ -47,8 +47,8 @@ partial def hashCode (code : Code pu) : UInt64 :=
mixHash (hash fvarId) (mixHash (hash cidx) (hashCode k))
| .inc fvarId n check persistent k _ =>
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
| .dec fvarId n check persistent k _ =>
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
| .dec fvarId n check persistent objs? k _ =>
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (mixHash (hash objs?) (hashCode k)))
| .del fvarId k _ =>
mixHash (hash fvarId) (hashCode k)

View file

@ -789,8 +789,8 @@ partial def emitBasicBlock (code : Code .impure) : EmitM Unit := do
| .inc fvarId n check persistent k =>
unless persistent do emitInc fvarId n check
emitBasicBlock k
| .dec fvarId n check persistent k =>
unless persistent do emitDec fvarId n check
| .dec fvarId n check persistent objs? k =>
unless persistent do emitDec fvarId n check objs?
emitBasicBlock k
| .del fvarId k =>
emitDel fvarId
@ -821,11 +821,15 @@ where
emitCApp2 incFn fvarId n
emitLn ";"
emitDec (fvarId : FVarId) (n : Nat) (check : Bool) : EmitM Unit := do
emitDec (fvarId : FVarId) (n : Nat) (check : Bool) (objs? : Option Nat) : EmitM Unit := do
-- Anything else is unsupported at the moment
assert! n == 1
let decFn := if check then "lean_dec" else "lean_dec_ref"
emitCApp1 decFn fvarId
match objs? with
| some objs =>
emitCApp2 "lean_dec_ref_known" fvarId objs
| none =>
let decFn := if check then "lean_dec" else "lean_dec_ref"
emitCApp1 decFn fvarId
emitLn ";"
emitDel (fvarId : FVarId) : EmitM Unit := do

View file

@ -203,7 +203,7 @@ Expand the matching `reuse`/`dec` for the allocation in `origAllocId` whose `res
partial def processResetCont (resetTokenId : FVarId) (code : Code .impure) (origAllocId : FVarId)
(isSharedId : FVarId) (currentRetType : Expr) : CompilerM (Code .impure) := do
match code with
| .dec y n _ _ k =>
| .dec y n _ _ _ k =>
if resetTokenId == y then
assert! n == 1 -- n must be one since `resetToken := reset ...`
return .del resetTokenId k
@ -344,7 +344,7 @@ where
mkSlowPath (origAllocId : FVarId) (mask : Mask) (resetJpId : FVarId) (isSharedId : FVarId) :
CompilerM (Code .impure) := do
let mut code := .jmp resetJpId #[.erased, .fvar isSharedId]
code := .dec origAllocId 1 true false code
code := .dec origAllocId 1 true false none code
for fvarId? in mask do
let some fvarId := fvarId? | continue
code := .inc fvarId 1 true false code
@ -363,7 +363,7 @@ where
if mask[idx].isSome then
continue
let fieldDecl ← mkLetDecl (← mkFreshBinderName `unused) tobject (.oproj idx origAllocId)
code := .let fieldDecl (.dec fieldDecl.fvarId 1 true false code)
code := .let fieldDecl (.dec fieldDecl.fvarId 1 true false none code)
return code
end

View file

@ -168,6 +168,7 @@ structure VarInfo where
isDefiniteRef : Bool
persistent : Bool
idx : Nat
ctorInfo : Option CtorInfo
deriving Inhabited
abbrev VarMap := FVarIdMap VarInfo
@ -268,18 +269,24 @@ def withParams (ps : Array (Param .impure)) (x : RcM α) : RcM α := do
isDefiniteRef := p.type.isDefiniteRef
persistent := false
idx := ctx.idx,
ctorInfo := none,
}
{ ctx with idx := ctx.idx + 1, varMap }
withReader update x
@[inline]
def withLetDecl (decl : LetDecl .impure) (x : RcM α) : RcM α := do
let ctorInfo :=
match decl.value with
| .ctor ctorInfo .. => some ctorInfo
| _ => none
let update := fun ctx =>
let varInfo := {
isPossibleRef := decl.type.isPossibleRef
isDefiniteRef := decl.type.isDefiniteRef
persistent := decl.value.isPersistent
idx := ctx.idx
idx := ctx.idx,
ctorInfo := ctorInfo
}
{ ctx with varMap := ctx.varMap.insert decl.fvarId varInfo, idx := ctx.idx + 1 }
withReader update do
@ -293,9 +300,10 @@ def withCtorAlt (discr : FVarId) (c : CtorInfo) (x : RcM α) : RcM α := do
varMap :=
match ctx.varMap.get? discr with
| some info =>
let isPossibleRef := c.type.isPossibleRef
let isDefiniteRef := c.type.isDefiniteRef
ctx.varMap.insert discr { info with isPossibleRef, isDefiniteRef, idx := ctx.idx + 1 }
let isPossibleRef := c.isRef
let isDefiniteRef := c.isRef
ctx.varMap.insert discr
{ info with isPossibleRef, isDefiniteRef, idx := ctx.idx + 1, ctorInfo := some c }
| none => ctx.varMap
idx := ctx.idx + 1
}) do x
@ -394,10 +402,16 @@ def addInc (fvarId : FVarId) (k : Code .impure) (n : Nat := 1) : RcM (Code .impu
let info ← getVarInfo fvarId
return .inc fvarId n (!info.isDefiniteRef) info.persistent k
@[inline]
def addDec (fvarId : FVarId) (k : Code .impure) : RcM (Code .impure) := do
let info ← getVarInfo fvarId
return .dec fvarId 1 (!info.isDefiniteRef) info.persistent k
match info.ctorInfo with
| some ctorInfo =>
if ctorInfo.isRef then
return .dec fvarId 1 false info.persistent (some ctorInfo.size) k
else
return k
| none =>
return .dec fvarId 1 (!info.isDefiniteRef) info.persistent none k
/--
Insert the alternative specific prolog for the alternative contained in `k`. `altLiveVars` is the

View file

@ -150,8 +150,8 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F
return Code.updateSetTag! c (← f fvarId) cidx (← mapFVarM f k)
| .inc fvarId n check persistent k _ =>
return Code.updateInc! c (← f fvarId) n check persistent (← mapFVarM f k)
| .dec fvarId n check persistent k _ =>
return Code.updateDec! c (← f fvarId) n check persistent (← mapFVarM f k)
| .dec fvarId n check persistent objs? k _ =>
return Code.updateDec! c (← f fvarId) n check persistent objs? (← mapFVarM f k)
| .del fvarId k _ =>
return Code.updateDel! c (← f fvarId) (← mapFVarM f k)
@ -226,7 +226,7 @@ instance : TraverseFVar (CodeDecl pu) where
| .sset fvarId i offset y ty _ => return .sset (← f fvarId) i offset (← f y) (← mapFVarM f ty)
| .setTag fvarId cidx _ => return .setTag (← f fvarId) cidx
| .inc fvarId n check persistent _ => return .inc (← f fvarId) n check persistent
| .dec fvarId n check persistent _ => return .dec (← f fvarId) n check persistent
| .dec fvarId n check persistent objs? _ => return .dec (← f fvarId) n check persistent objs?
| .del fvarId _ => return .del (← f fvarId)
forFVarM f decl :=
match decl with

View file

@ -198,9 +198,9 @@ partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do
| .inc fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return .inc fvarId n check persistent (← internalizeCode k)
| .dec fvarId n check persistent k _ =>
| .dec fvarId n check persistent objs? k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return .dec fvarId n check persistent (← internalizeCode k)
return .dec fvarId n check persistent objs? (← internalizeCode k)
| .del fvarId k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return .del fvarId (← internalizeCode k)
@ -232,9 +232,9 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl
| .inc fvarId n check offset _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
return .inc fvarId n check offset
| .dec fvarId n check offset _ =>
| .dec fvarId n check offset objs? _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
return .dec fvarId n check offset
return .dec fvarId n check offset objs?
| .del fvarId _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
return .del fvarId

View file

@ -160,8 +160,11 @@ mutual
return f!"inc[{n}]{ann} {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
else
return f!"inc{ann} {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
| .dec fvarId n check persistent k _ =>
let ann := (if persistent then "[persistent]" else "") ++ (if !check then "[ref]" else "")
| .dec fvarId n check persistent objs? k _ =>
let mut ann := ""
if persistent then ann := ann ++ "[persistent]"
if !check then ann := ann ++ "[ref]"
if let some objs := objs? then ann := ann ++ s!"[{objs} objs]"
if n != 1 then
return f!"dec[{n}]{ann} {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
else

View file

@ -141,8 +141,8 @@ where
go k (decls.push (.sset fvarId i offset y ty))
| .inc fvarId n check persistent k _ =>
go k (decls.push (.inc fvarId n check persistent))
| .dec fvarId n check persistent k _ =>
go k (decls.push (.dec fvarId n check persistent))
| .dec fvarId n check persistent objs? k _ =>
go k (decls.push (.dec fvarId n check persistent objs?))
| .del fvarId k _ =>
go k (decls.push (.del fvarId))
| .setTag fvarId cidx k _ =>

View file

@ -132,9 +132,10 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do
let value := mkApp4 (mkConst `inc) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
let body ← withFVar fvarId k.toExprM
return .letE `dummy (mkConst ``Unit) value body true
| .dec fvarId n check persistent k _ =>
| .dec fvarId n check persistent o k _ =>
let body ← withFVar fvarId k.toExprM
let value := mkApp4 (mkConst `dec) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
let value :=
mkApp5 (mkConst `dec) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent) (toExpr o)
return .letE `dummy (mkConst ``Unit) value body true
| .del fvarId k _ =>
let body ← withFVar fvarId k.toExprM

View file

@ -676,6 +676,18 @@ static inline b_lean_obj_res lean_ctor_get(b_lean_obj_arg o, unsigned i) {
return lean_ctor_obj_cptr(o)[i];
}
/*
 * Decrement the reference count of `o` when the caller statically knows how
 * many object fields it carries.
 *
 * - `o` must be a reference (checked with `assert`).
 * - `objs` is the number of leading constructor fields to release.
 *   NOTE(review): presumably this must equal the constructor's object-field
 *   count; confirm against the code generator that emits the call.
 *
 * When `o` is exclusively owned, each of the first `objs` fields is released
 * with `lean_dec` and the cell itself is handed to `lean_del_object`.
 * Otherwise the generic `lean_dec_ref` path is taken.
 */
static inline void lean_dec_ref_known(lean_object * o, unsigned objs) {
    assert(lean_is_ref(o));
    /* Not the sole owner: fall back to the generic decrement. */
    if (!lean_is_exclusive(o)) {
        lean_dec_ref(o);
        return;
    }
    /* Sole owner: release every known field, then the object itself. */
    unsigned field = 0;
    while (field < objs) {
        lean_dec(lean_ctor_get(o, field));
        field++;
    }
    lean_del_object(o);
}
static inline void lean_ctor_set(b_lean_obj_arg o, unsigned i, lean_obj_arg v) {
assert(i < lean_ctor_num_objs(o));
lean_ctor_obj_cptr(o)[i] = v;

View file

@ -0,0 +1,78 @@
/-!
This test checks specialization of `dec` when the shape of the object is known.
-/
-- Test inductive whose constructors carry a different number of fields, so that
-- each alternative of a `cases` on an `A` sees a different constructor shape.
inductive A where
-- One pair-valued field.
| ctor1 (x : Nat × Nat)
-- Two pair-valued fields.
| ctor2 (y : Nat × Nat) (z : Nat × Nat)
-- No fields.
| ctor3
-- Force lookAtA to own the A
-- `foo` is an opaque extern function taking the `A` as an argument; the expected
-- trace in this file shows `inc x` immediately before the call `foo x`.
@[extern "foo"]
opaque foo (x : A) : Nat
/--
trace: [Compiler.explicitRc] size: 47
def lookAtA x : tobj :=
inc x;
let v1 := foo x;
cases x : tobj
| A.ctor1 =>
let x.1 := oproj[0] x;
inc[ref] x.1;
dec[ref][1 objs] x;
let fst.2 := oproj[0] x.1;
inc fst.2;
let snd.3 := oproj[1] x.1;
inc snd.3;
dec[ref] x.1;
let _x.4 := Nat.add v1 fst.2;
dec fst.2;
dec v1;
let _x.5 := Nat.add _x.4 snd.3;
dec snd.3;
dec _x.4;
return _x.5
| A.ctor2 =>
let y.6 := oproj[0] x;
inc[ref] y.6;
let z.7 := oproj[1] x;
inc[ref] z.7;
dec[ref][2 objs] x;
let fst.8 := oproj[0] y.6;
inc fst.8;
let snd.9 := oproj[1] y.6;
inc snd.9;
dec[ref] y.6;
let fst.10 := oproj[0] z.7;
inc fst.10;
let snd.11 := oproj[1] z.7;
inc snd.11;
dec[ref] z.7;
let _x.12 := Nat.add v1 fst.8;
dec fst.8;
dec v1;
let _x.13 := Nat.add _x.12 snd.9;
dec snd.9;
dec _x.12;
let _x.14 := Nat.add _x.13 fst.10;
dec fst.10;
dec _x.13;
let _x.15 := Nat.add _x.14 snd.11;
dec snd.11;
dec _x.14;
return _x.15
| A.ctor3 =>
return v1
-/
#guard_msgs in
set_option trace.Compiler.explicitRc true in
-- Function under test: calls `foo x` first, then matches on `x` and adds the
-- components of every pair field to the result of `foo`.
-- The `#guard_msgs` expectation for this definition pins the exact
-- `Compiler.explicitRc` trace, so the body must not be restructured.
def lookAtA (x : A) : Nat :=
let v1 := foo x
match x with
-- One pair field: sum of `foo x` and both components.
| .ctor1 (x1, x2) => v1 + x1 + x2
-- Two pair fields: sum of `foo x` and all four components.
| .ctor2 (x1, x2) (y1, y2) => v1 + x1 + x2 + y1 + y2
-- No fields: just the result of `foo x`.
| .ctor3 => v1

View file

@ -157,7 +157,7 @@ trace: [Compiler.pushProj] size: 14
| Option.some =>
let val.11 : tobj := oproj[0] a;
inc val.11;
dec[ref] a;
dec[ref][1 objs] a;
let val.12 : tobj := oproj[0] b;
jp resetjp.13 _x.14 isShared.15 : tobj :=
let _x.16 : tobj := Nat.add val.11 val.12;
@ -251,14 +251,14 @@ trace: [Compiler.pushProj] size: 18
| Option.some =>
cases c : tobj
| Bool.false =>
dec[ref] b;
dec[ref] a;
dec[ref][1 objs] b;
dec[ref][1 objs] a;
let _x.11 : tagged := ctor_0[Option.none];
return _x.11
| Bool.true =>
let val.12 : tobj := oproj[0] a;
inc val.12;
dec[ref] a;
dec[ref][1 objs] a;
let val.13 : tobj := oproj[0] b;
jp resetjp.14 _x.15 isShared.16 : tobj :=
let _x.17 : tobj := Nat.add val.12 val.13;