feat: support for del, isShared, oset and setTag (#12687)

This PR implements the LCNF instructions required for the expand reset
reuse pass.
This commit is contained in:
Henrik Böving 2026-02-25 11:43:15 +01:00 committed by GitHub
parent 532310313f
commit e96d969d59
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 288 additions and 60 deletions

View file

@ -101,6 +101,10 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do
let ret ← getFVarValue fvarId
return .ret ret
| .unreach .. => return .unreachable
| .oset fvarId i y k _ =>
let y ← lowerArg y
let .var fvarId ← getFVarValue fvarId | unreachable!
return .set fvarId i y (← lowerCode k)
| .sset fvarId i offset y type k _ =>
let .var y ← getFVarValue y | unreachable!
let .var fvarId ← getFVarValue fvarId | unreachable!
@ -109,12 +113,18 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do
let .var y ← getFVarValue y | unreachable!
let .var fvarId ← getFVarValue fvarId | unreachable!
return .uset fvarId i y (← lowerCode k)
| .setTag fvarId cidx k _ =>
let .var var ← getFVarValue fvarId | unreachable!
return .setTag var cidx (← lowerCode k)
| .inc fvarId n check persistent k _ =>
let .var var ← getFVarValue fvarId | unreachable!
return .inc var n check persistent (← lowerCode k)
| .dec fvarId n check persistent k _ =>
let .var var ← getFVarValue fvarId | unreachable!
return .dec var n check persistent (← lowerCode k)
| .del fvarId k _ =>
let .var var ← getFVarValue fvarId | unreachable!
return .del var (← lowerCode k)
| .fun .. => panic! "all local functions should be λ-lifted"
partial def lowerLet (decl : LCNF.LetDecl .impure) (k : LCNF.Code .impure) : M FnBody := do
@ -155,6 +165,9 @@ partial def lowerLet (decl : LCNF.LetDecl .impure) (k : LCNF.Code .impure) : M F
| .unbox var =>
withGetFVarValue var fun var => do
continueLet (.unbox var)
| .isShared var =>
withGetFVarValue var fun var => do
continueLet (.isShared var)
| .erased => mkErased ()
where
mkErased (_ : Unit) : M FnBody := do

View file

@ -75,6 +75,7 @@ def eqvLetValue (e₁ e₂ : LetValue pu) : EqvM Bool := do
pure (i₁ == i₂ && u₁ == u₂) <&&> eqvFVar v₁ v₂ <&&> eqvArgs as₁ as₂
| .box ty₁ v₁ _, .box ty₂ v₂ _ => eqvType ty₁ ty₂ <&&> eqvFVar v₁ v₂
| .unbox v₁ _, .unbox v₂ _ => eqvFVar v₁ v₂
| .isShared v₁ _, .isShared v₂ _ => eqvFVar v₁ v₂
| _, _ => return false
@[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α :=
@ -143,6 +144,11 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
eqvFVar c₁.discr c₂.discr <&&>
eqvType c₁.resultType c₂.resultType <&&>
eqvAlts c₁.alts c₂.alts
| .oset fvarId₁ i₁ y₁ k₁ _, .oset fvarId₂ i₂ y₂ k₂ _ =>
pure (i₁ == i₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqvArg y₁ y₂ <&&>
eqv k₁ k₂
| .sset fvarId₁ i₁ offset₁ y₁ ty₁ k₁ _, .sset fvarId₂ i₂ offset₂ y₂ ty₂ k₂ _ =>
pure (i₁ == i₂) <&&>
pure (offset₁ == offset₂) <&&>
@ -155,6 +161,10 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
eqvFVar fvarId₁ fvarId₂ <&&>
eqvFVar y₁ y₂ <&&>
eqv k₁ k₂
| .setTag fvarId₁ c₁ k₁ _, .setTag fvarId₂ c₂ k₂ _ =>
pure (c₁ == c₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| .inc fvarId₁ n₁ c₁ p₁ k₁ _, .inc fvarId₂ n₂ c₂ p₂ k₂ _ =>
pure (n₁ == n₂) <&&>
pure (c₁ == c₂) <&&>
@ -167,6 +177,9 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
pure (p₁ == p₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| .del fvarId₁ k₁ _, .del fvarId₂ k₂ _ =>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| _, _ => return false
end

View file

@ -219,6 +219,10 @@ inductive LetValue (pu : Purity) where
| box (ty : Expr) (fvarId : FVarId) (h : pu = .impure := by purity_tac)
/-- Given `fvarId : [t]object`, obtain the underlying scalar value. -/
| unbox (fvarId : FVarId) (h : pu = .impure := by purity_tac)
/--
Return whether the object stored behind `fvarId` is shared or not. The return type is a `UInt8`.
-/
| isShared (fvarId : FVarId) (h : pu = .impure := by purity_tac)
deriving Inhabited, BEq, Hashable
def Arg.toLetValue (arg : Arg pu) : LetValue pu :=
@ -298,7 +302,12 @@ private unsafe def LetValue.updateUnboxImp (e : LetValue pu) (fvarId' : FVarId)
@[implemented_by LetValue.updateUnboxImp] opaque LetValue.updateUnbox! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu
private unsafe def LetValue.updateIsSharedImp (e : LetValue pu) (fvarId' : FVarId) : LetValue pu :=
match e with
| .isShared fvarId _ => if fvarId == fvarId' then e else .isShared fvarId'
| _ => unreachable!
@[implemented_by LetValue.updateIsSharedImp] opaque LetValue.updateIsShared! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu
private unsafe def LetValue.updateArgsImp (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu :=
match e with
@ -331,6 +340,7 @@ def LetValue.toExpr (e : LetValue pu) : Expr :=
#[.fvar var, .const i.name [], ToExpr.toExpr updateHeader] ++ (args.map Arg.toExpr)
| .box ty var _ => mkApp2 (.const `box []) ty (.fvar var)
| .unbox var _ => mkApp (.const `unbox []) (.fvar var)
| .isShared fvarId _ => mkApp (.const `isShared []) (.fvar fvarId)
structure LetDecl (pu : Purity) where
fvarId : FVarId
@ -361,10 +371,13 @@ inductive Code (pu : Purity) where
| cases (cases : Cases pu)
| return (fvarId : FVarId)
| unreach (type : Expr)
| oset (fvarId : FVarId) (i : Nat) (y : Arg pu) (k : Code pu) (h : pu = .impure := by purity_tac)
| uset (fvarId : FVarId) (i : Nat) (y : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac)
| sset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac)
| setTag (fvarId : FVarId) (cidx : Nat) (k : Code pu) (h : pu = .impure := by purity_tac)
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
| del (fvarId : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac)
deriving Inhabited
end
@ -440,25 +453,32 @@ inductive CodeDecl (pu : Purity) where
| let (decl : LetDecl pu)
| fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac)
| jp (decl : FunDecl pu)
| oset (fvarId : FVarId) (i : Nat) (y : Arg pu) (h : pu = .impure := by purity_tac)
| uset (fvarId : FVarId) (i : Nat) (y : FVarId) (h : pu = .impure := by purity_tac)
| sset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac)
| setTag (fvarId : FVarId) (cidx : Nat) (h : pu = .impure := by purity_tac)
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
| del (fvarId : FVarId) (h : pu = .impure := by purity_tac)
deriving Inhabited
def CodeDecl.fvarId : CodeDecl pu → FVarId
| .let decl | .fun decl _ | .jp decl => decl.fvarId
| .uset fvarId .. | .sset fvarId .. | .inc fvarId .. | .dec fvarId .. => fvarId
| .uset fvarId .. | .sset fvarId .. | .inc fvarId .. | .dec fvarId .. | .del fvarId ..
| .oset fvarId .. | .setTag fvarId .. => fvarId
def Code.toCodeDecl! : Code pu → CodeDecl pu
| .let decl _ => .let decl
| .fun decl _ _ => .fun decl
| .jp decl _ => .jp decl
| .uset fvarId i y _ _ => .uset fvarId i y
| .sset fvarId i offset ty y _ _ => .sset fvarId i offset ty y
| .inc fvarId n check persistent _ _ => .inc fvarId n check persistent
| .dec fvarId n check persistent _ _ => .dec fvarId n check persistent
| _ => unreachable!
| .let decl _ => .let decl
| .fun decl _ _ => .fun decl
| .jp decl _ => .jp decl
| .oset fvarId i y _ _ => .oset fvarId i y
| .uset fvarId i y _ _ => .uset fvarId i y
| .sset fvarId i offset ty y _ _ => .sset fvarId i offset ty y
| .setTag fvarId cidx _ _ => .setTag fvarId cidx
| .inc fvarId n check persistent _ _ => .inc fvarId n check persistent
| .dec fvarId n check persistent _ _ => .dec fvarId n check persistent
| .del fvarId _ _ => .del fvarId
| _ => unreachable!
def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu :=
go decls.size code
@ -469,10 +489,13 @@ where
| .let decl => go (i-1) (.let decl code)
| .fun decl _ => go (i-1) (.fun decl code)
| .jp decl => go (i-1) (.jp decl code)
| .oset fvarId idx y _ => go (i-1) (.oset fvarId idx y code)
| .uset fvarId idx y _ => go (i-1) (.uset fvarId idx y code)
| .sset fvarId idx offset y ty _ => go (i-1) (.sset fvarId idx offset y ty code)
| .setTag fvarId cidx _ => go (i-1) (.setTag fvarId cidx code)
| .inc fvarId n check persistent _ => go (i-1) (.inc fvarId n check persistent code)
| .dec fvarId n check persistent _ => go (i-1) (.dec fvarId n check persistent code)
| .del fvarId _ => go (i-1) (.del fvarId code)
else
code
@ -488,14 +511,20 @@ mutual
| .jmp j₁ as₁, .jmp j₂ as₂ => j₁ == j₂ && as₁ == as₂
| .return r₁, .return r₂ => r₁ == r₂
| .unreach t₁, .unreach t₂ => t₁ == t₂
| .oset v₁ i₁ y₁ k₁ _, .oset v₂ i₂ y₂ k₂ _ =>
v₁ == v₂ && i₁ == i₂ && y₁ == y₂ && eqImp k₁ k₂
| .uset v₁ i₁ y₁ k₁ _, .uset v₂ i₂ y₂ k₂ _ =>
v₁ == v₂ && i₁ == i₂ && y₁ == y₂ && eqImp k₁ k₂
| .sset v₁ i₁ o₁ y₁ ty₁ k₁ _, .sset v₂ i₂ o₂ y₂ ty₂ k₂ _ =>
v₁ == v₂ && i₁ == i₂ && o₁ == o₂ && y₁ == y₂ && ty₁ == ty₂ && eqImp k₁ k₂
| .setTag v₁ c₁ k₁ _, .setTag v₂ c₂ k₂ _ =>
v₁ == v₂ && c₁ == c₂ && eqImp k₁ k₂
| .inc v₁ n₁ c₁ p₁ k₁ _, .inc v₂ n₂ c₂ p₂ k₂ _ =>
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
| .dec v₁ n₁ c₁ p₁ k₁ _, .dec v₂ n₂ c₂ p₂ k₂ _ =>
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
| .del v₁ k₁ _, .del v₂ k₂ _ =>
v₁ == v₂ && eqImp k₁ k₂
| _, _ => false
private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool :=
@ -588,10 +617,13 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
| .let decl k => if ptrEq k k' then c else .let decl k'
| .fun decl k _ => if ptrEq k k' then c else .fun decl k'
| .jp decl k => if ptrEq k k' then c else .jp decl k'
| .oset fvarId offset y k _ => if ptrEq k k' then c else .oset fvarId offset y k'
| .sset fvarId i offset y ty k _ => if ptrEq k k' then c else .sset fvarId i offset y ty k'
| .uset fvarId offset y k _ => if ptrEq k k' then c else .uset fvarId offset y k'
| .setTag fvarId cidx k _ => if ptrEq k k' then c else .setTag fvarId cidx k'
| .inc fvarId n check persistent k _ => if ptrEq k k' then c else .inc fvarId n check persistent k'
| .dec fvarId n check persistent k _ => if ptrEq k k' then c else .dec fvarId n check persistent k'
| .del fvarId k _ => if ptrEq k k' then c else .del fvarId k'
| _ => unreachable!
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu
@ -635,6 +667,19 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
.sset fvarId' i' offset' y' ty' k'
| _ => unreachable!
@[inline] private unsafe def updateOsetImp (c : Code pu) (fvarId' : FVarId)
(i' : Nat) (y' : Arg pu) (k' : Code pu) : Code pu :=
match c with
| .oset fvarId i y k _ =>
if ptrEq fvarId fvarId' && i == i' && ptrEq y y' && ptrEq k k' then
c
else
.oset fvarId' i' y' k'
| _ => unreachable!
@[implemented_by updateOsetImp] opaque Code.updateOset! (c : Code pu) (fvarId' : FVarId)
(i' : Nat) (y' : Arg pu) (k' : Code pu) : Code pu
@[implemented_by updateSsetImp] opaque Code.updateSset! (c : Code pu) (fvarId' : FVarId) (i' : Nat)
(offset' : Nat) (y' : FVarId) (ty' : Expr) (k' : Code pu) : Code pu
@ -651,6 +696,19 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
@[implemented_by updateUsetImp] opaque Code.updateUset! (c : Code pu) (fvarId' : FVarId)
(i' : Nat) (y' : FVarId) (k' : Code pu) : Code pu
@[inline] private unsafe def updateSetTagImp (c : Code pu) (fvarId' : FVarId) (cidx' : Nat)
(k' : Code pu) : Code pu :=
match c with
| .setTag fvarId cidx k _ =>
if ptrEq fvarId fvarId' && cidx == cidx' && ptrEq k k' then
c
else
.setTag fvarId' cidx' k'
| _ => unreachable!
@[implemented_by updateSetTagImp] opaque Code.updateSetTag! (c : Code pu) (fvarId' : FVarId)
(cidx' : Nat) (k' : Code pu) : Code pu
@[inline] private unsafe def updateIncImp (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu :=
match c with
@ -685,6 +743,19 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
@[implemented_by updateDecImp] opaque Code.updateDec! (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu
@[inline] private unsafe def updateDelImp (c : Code pu) (fvarId' : FVarId) (k' : Code pu) :
Code pu :=
match c with
| .del fvarId k _ =>
if ptrEq fvarId fvarId' && ptrEq k k' then
c
else
.del fvarId' k'
| _ => unreachable!
@[implemented_by updateDelImp] opaque Code.updateDel! (c : Code pu) (fvarId' : FVarId)
(k' : Code pu) : Code pu
private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu :=
if ptrEq type p.type then
p
@ -753,8 +824,8 @@ partial def Code.size (c : Code pu) : Nat :=
where
go (c : Code pu) (n : Nat) : Nat :=
match c with
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. => go k (n + 1)
| .let (k := k) .. | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .setTag (k := k) .. | .del (k := k) .. => go k (n + 1)
| .jp decl k | .fun decl k _ => go k <| go decl.value n
| .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1)
| .jmp .. => n+1
@ -772,8 +843,8 @@ where
go (c : Code pu) : EStateM Unit Nat Unit := do
match c with
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. => inc; go k
| .let (k := k) .. | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .setTag (k := k) .. | .del (k := k) .. => inc; go k
| .jp decl k | .fun decl k _ => inc; go decl.value; go k
| .cases c => inc; c.alts.forM fun alt => go alt.getCode
| .jmp .. => inc
@ -785,8 +856,8 @@ where
go (c : Code pu) : m Unit := do
f c
match c with
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. => go k
| .let (k := k) .. | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .setTag (k := k) .. | .del (k := k) .. => go k
| .fun decl k _ | .jp decl k => go decl.value; go k
| .cases c => c.alts.forM fun alt => go alt.getCode
| .unreach .. | .return .. | .jmp .. => return ()
@ -1053,7 +1124,7 @@ private def collectLetValue (e : LetValue pu) (s : FVarIdHashSet) : FVarIdHashSe
| .fvar fvarId args => collectArgs args <| s.insert fvarId
| .const _ _ args _ | .pap _ args _ | .fap _ args _ | .ctor _ args _ => collectArgs args s
| .proj _ _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _ | .oproj _ fvarId _
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => s.insert fvarId
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => s.insert fvarId
| .lit .. | .erased => s
| .reuse fvarId _ _ args _ => collectArgs args <| s.insert fvarId
@ -1082,7 +1153,12 @@ partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarId
let s := s.insert fvarId
let s := s.insert y
k.collectUsed s
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .oset fvarId _ y k _ =>
let s := s.insert fvarId
let s := if let .fvar y := y then s.insert y else s
k.collectUsed s
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .del (fvarId := fvarId) (k := k) .. | .setTag (fvarId := fvarId) (k := k) .. =>
k.collectUsed <| s.insert fvarId
end
@ -1095,7 +1171,11 @@ def CodeDecl.collectUsed (codeDecl : CodeDecl pu) (s : FVarIdHashSet := ∅) : F
| .jp decl | .fun decl _ => decl.collectUsed s
| .sset (fvarId := fvarId) (y := y) .. | .uset (fvarId := fvarId) (y := y) .. =>
s.insert fvarId |>.insert y
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.insert fvarId
| .oset (fvarId := fvarId) (y := y) .. =>
let s := s.insert fvarId
if let .fvar y := y then s.insert y else s
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. | .setTag (fvarId := fvarId) ..
| .del (fvarId := fvarId) .. => s.insert fvarId
/--
Traverse the given block of potentially mutually recursive functions
@ -1125,7 +1205,8 @@ where
modify fun s => s.insert declName
| _ => pure ()
visit k
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => visit k
| .oset (k := k) .. | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .del (k := k) .. | .setTag (k := k) .. => visit k
go : StateM NameSet Unit :=
decls.forM (·.value.forCodeM visit)

View file

@ -68,7 +68,8 @@ where
eraseCode k
eraseParam auxParam
return .unreach typeNew
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
| .oset (k := k) ..| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) ..
| .del (k := k) .. | .setTag (k := k) .. =>
return c.updateCont! (← go k)
instance : MonadCodeBind CompilerM where

View file

@ -149,7 +149,7 @@ def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do
match decl with
| .let decl => eraseLetDecl decl
| .jp decl | .fun decl _ => eraseFunDecl decl
| .sset .. | .uset .. | .inc .. | .dec .. => return ()
| .sset .. | .uset .. | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => return ()
/--
Erase all free variables occurring in `decls` from the local context.
@ -300,6 +300,10 @@ private partial def normLetValueImp (s : FVarSubst pu) (e : LetValue pu) (transl
match normFVarImp s fvarId translator with
| .fvar fvarId' => e.updateUnbox! fvarId'
| .erased => .erased
| .isShared fvarId _ =>
match normFVarImp s fvarId translator with
| .fvar fvarId' => e.updateIsShared! fvarId'
| .erased => .erased
/--
Interface for monads that have a free substitutions.
@ -497,16 +501,26 @@ mutual
withNormFVarResult (← normFVar fvarId) fun fvarId => do
withNormFVarResult (← normFVar y) fun y => do
return code.updateSset! fvarId i offset y (← normExpr ty) (← normCodeImp k)
| .oset fvarId offset y k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
let y ← normArg y
return code.updateOset! fvarId offset y (← normCodeImp k)
| .uset fvarId offset y k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
withNormFVarResult (← normFVar y) fun y => do
return code.updateUset! fvarId offset y (← normCodeImp k)
| .setTag fvarId cidx k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return code.updateSetTag! fvarId cidx (← normCodeImp k)
| .inc fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return code.updateInc! fvarId n check persistent (← normCodeImp k)
| .dec fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return code.updateDec! fvarId n check persistent (← normCodeImp k)
| .del fvarId k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return code.updateDel! fvarId (← normCodeImp k)
end
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do

View file

@ -39,12 +39,18 @@ partial def hashCode (code : Code pu) : UInt64 :=
| .cases c => mixHash (mixHash (hash c.discr) (hash c.resultType)) (hashAlts c.alts)
| .sset fvarId i offset y ty k _ =>
mixHash (mixHash (hash fvarId) (hash i)) (mixHash (mixHash (hash offset) (hash y)) (mixHash (hash ty) (hashCode k)))
| .oset fvarId offset y k _ =>
mixHash (mixHash (hash fvarId) (hash offset)) (mixHash (hash y) (hashCode k))
| .uset fvarId offset y k _ =>
mixHash (mixHash (hash fvarId) (hash offset)) (mixHash (hash y) (hashCode k))
| .setTag fvarId cidx k _ =>
mixHash (hash fvarId) (mixHash (hash cidx) (hashCode k))
| .inc fvarId n check persistent k _ =>
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
| .dec fvarId n check persistent k _ =>
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
| .del fvarId k _ =>
mixHash (hash fvarId) (hashCode k)
end

View file

@ -31,7 +31,7 @@ private def letValueDepOn (e : LetValue pu) : M Bool :=
match e with
| .erased | .lit .. => return false
| .proj _ _ fvarId _ | .oproj _ fvarId _ | .uproj _ fvarId _ | .sproj _ _ fvarId _
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => fvarDepOn fvarId
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => fvarDepOn fvarId
| .fvar fvarId args | .reuse fvarId _ _ args _ => fvarDepOn fvarId <||> args.anyM argDepOn
| .const _ _ args _ | .ctor _ args _ | .fap _ args _ | .pap _ args _ => args.anyM argDepOn
@ -46,8 +46,12 @@ private partial def depOn (c : Code pu) : M Bool :=
| .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn
| .return fvarId => fvarDepOn fvarId
| .unreach _ => return false
| .sset fv1 _ _ fv2 _ k _ | .uset fv1 _ fv2 k _ => fvarDepOn fv1 <||> fvarDepOn fv2 <||> depOn k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .oset fv1 _ fv2 k _ =>
fvarDepOn fv1 <||> argDepOn fv2 <||> depOn k
| .sset fv1 _ _ fv2 _ k _ | .uset fv1 _ fv2 k _ =>
fvarDepOn fv1 <||> fvarDepOn fv2 <||> depOn k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .del (fvarId := fvarId) (k := k) .. | .setTag (fvarId := fvarId) (k := k) .. =>
fvarDepOn fvarId <||> depOn k
@[inline] def Arg.dependsOn (arg : Arg pu) (s : FVarIdSet) : Bool :=
@ -66,9 +70,14 @@ def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool :=
match decl with
| .let decl => decl.dependsOn s
| .jp decl | .fun decl _ => decl.dependsOn s
| .uset (fvarId := fvarId) (y := y) .. | .sset (fvarId := fvarId) (y := y) .. =>
| .oset (fvarId := fvarId) (y := y) .. =>
s.contains fvarId || y.dependsOn s
| .uset (fvarId := fvarId) (y := y) ..
| .sset (fvarId := fvarId) (y := y) .. =>
s.contains fvarId || s.contains y
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.contains fvarId
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. | .del (fvarId := fvarId) ..
| .setTag (fvarId := fvarId) .. =>
s.contains fvarId
/--
Return `true` is `c` depends on a free variable in `s`.

View file

@ -35,7 +35,7 @@ def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue pu) : UsedLocal
match e with
| .erased | .lit .. => s
| .proj _ _ fvarId _ | .reset _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _
| .oproj _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => s.insert fvarId
| .oproj _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => s.insert fvarId
| .const _ _ args _ => collectLocalDeclsArgs s args
| .fvar fvarId args | .reuse fvarId _ _ args _ => collectLocalDeclsArgs (s.insert fvarId) args
| .fap _ args _ | .pap _ args _ | .ctor _ args _ => collectLocalDeclsArgs s args
@ -56,9 +56,8 @@ def LetValue.safeToElim (val : LetValue pu) : Bool :=
| .pure => true
| .impure =>
match val with
-- TODO | .isShared ..
| .ctor .. | .reset .. | .reuse .. | .oproj .. | .uproj .. | .sproj .. | .lit .. | .pap ..
| .box .. | .unbox .. | .erased .. => true
| .box .. | .unbox .. | .erased .. | .isShared .. => true
-- 0-ary full applications are considered constants
| .fap _ args => args.isEmpty
| .fvar .. => false
@ -95,6 +94,13 @@ partial def Code.elimDead (code : Code pu) : M (Code pu) := do
| .return fvarId => collectFVarM fvarId; return code
| .jmp fvarId args => collectFVarM fvarId; args.forM collectArgM; return code
| .unreach .. => return code
| .oset fvarId _ y k _ =>
let k ← k.elimDead
if (← get).contains fvarId then
collectArgM y
return code.updateCont! k
else
return k
| .uset fvarId _ y k _ | .sset fvarId _ _ y _ k _ =>
let k ← k.elimDead
if (← get).contains fvarId then
@ -102,7 +108,8 @@ partial def Code.elimDead (code : Code pu) : M (Code pu) := do
return code.updateCont! k
else
return k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .setTag (fvarId := fvarId) (k := k) .. | .del (fvarId := fvarId) (k := k) .. =>
let k ← k.elimDead
collectFVarM fvarId
return code.updateCont! k

View file

@ -284,7 +284,7 @@ partial def Code.explicitBoxing (code : Code .impure) : BoxM (Code .impure) := d
let some jpDecl ← findFunDecl? fvarId | unreachable!
castArgsIfNeeded args jpDecl.params fun args => return code.updateJmp! fvarId args
| .unreach .. => return code.updateUnreach! (← getResultType)
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .oset .. | .setTag .. | .del .. => unreachable!
where
/--
Up to this point the type system of IR is quite loose so we can for example encounter situations
@ -313,7 +313,7 @@ where
| .ctor i _ => return i.type
| .fvar .. | .lit .. | .sproj .. | .oproj .. | .reset .. | .reuse .. =>
return currentType
| .box .. | .unbox .. => unreachable!
| .box .. | .unbox .. | .isShared .. => unreachable!
visitLet (code : Code .impure) (decl : LetDecl .impure) (k : Code .impure) : BoxM (Code .impure) := do
let type ← tryCorrectLetDeclType decl.type decl.value
@ -350,7 +350,7 @@ where
| .erased | .reset .. | .sproj .. | .uproj .. | .oproj .. | .lit .. =>
let decl ← decl.update type decl.value
return code.updateLet! decl k
| .box .. | .unbox .. => unreachable!
| .box .. | .unbox .. | .isShared .. => unreachable!
def run (decls : Array (Decl .impure)) : CompilerM (Array (Decl .impure)) := do
let decls ← decls.foldlM (init := #[]) fun newDecls decl => do

View file

@ -117,7 +117,7 @@ partial def collectCode (code : Code .impure) : M Unit := do
| .cases cases => cases.alts.forM (·.forCodeM collectCode)
| .sset (k := k) .. | .uset (k := k) .. => collectCode k
| .return .. | .jmp .. | .unreach .. => return ()
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
/--
Collect the derived value tree as well as the set of parameters that take objects and are borrowed.
@ -334,6 +334,7 @@ def useLetValue (value : LetValue .impure) : RcM Unit := do
useVar fvarId
useArgs args
| .lit .. | .erased => return ()
| .isShared .. => unreachable!
@[inline]
def bindVar (fvarId : FVarId) : RcM Unit :=
@ -547,6 +548,7 @@ def LetDecl.explicitRc (code : Code .impure) (decl : LetDecl .impure) (k : Code
addIncBeforeConsumeAll allArgs (code.updateLet! decl k)
| .lit .. | .box .. | .reset .. | .erased .. =>
pure <| code.updateLet! decl k
| .isShared .. => unreachable!
useLetValue decl.value
bindVar decl.fvarId
return k
@ -622,7 +624,7 @@ partial def Code.explicitRc (code : Code .impure) : RcM (Code .impure) := do
| .unreach .. =>
setRetLiveVars
return code
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .del .. | .oset .. => unreachable!
def Decl.explicitRc (decl : Decl .impure) :
CompilerM (Decl .impure) := do

View file

@ -83,12 +83,13 @@ def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarI
return e.updateReuse! (← f fvarId) i updateHeader (← args.mapM (TraverseFVar.mapFVarM f))
| .box ty fvarId _ => return e.updateBox! ty (← f fvarId)
| .unbox fvarId _ => return e.updateUnbox! (← f fvarId)
| .isShared fvarId _ => return e.updateIsShared! (← f fvarId)
def LetValue.forFVarM [Monad m] (f : FVarId → m Unit) (e : LetValue pu) : m Unit := do
match e with
| .lit .. | .erased => return ()
| .proj _ _ fvarId _ | .oproj _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => f fvarId
| .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => f fvarId
| .const _ _ args _ | .pap _ args _ | .fap _ args _ | .ctor _ args _ =>
args.forM (TraverseFVar.forFVarM f)
| .fvar fvarId args | .reuse fvarId _ _ args _ => f fvarId; args.forM (TraverseFVar.forFVarM f)
@ -139,14 +140,20 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F
return Code.updateReturn! c (← f var)
| .unreach typ =>
return Code.updateUnreach! c (← Expr.mapFVarM f typ)
| .oset fvarId offset y k _ =>
return Code.updateOset! c (← f fvarId) offset (← y.mapFVarM f) (← mapFVarM f k)
| .sset fvarId i offset y ty k _ =>
return Code.updateSset! c (← f fvarId) i offset (← f y) (← Expr.mapFVarM f ty) (← mapFVarM f k)
| .uset fvarId offset y k _ =>
return Code.updateUset! c (← f fvarId) offset (← f y) (← mapFVarM f k)
| .setTag fvarId cidx k _ =>
return Code.updateSetTag! c (← f fvarId) cidx (← mapFVarM f k)
| .inc fvarId n check persistent k _ =>
return Code.updateInc! c (← f fvarId) n check persistent (← mapFVarM f k)
| .dec fvarId n check persistent k _ =>
return Code.updateDec! c (← f fvarId) n check persistent (← mapFVarM f k)
| .del fvarId k _ =>
return Code.updateDel! c (← f fvarId) (← mapFVarM f k)
partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Unit := do
match c with
@ -182,7 +189,12 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Un
f fvarId
f y
forFVarM f k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .oset fvarId _ y k _ =>
f fvarId
y.forFVarM f
forFVarM f k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .del (fvarId := fvarId) (k := k) .. | .setTag (fvarId := fvarId) (k := k) .. =>
f fvarId
forFVarM f k
@ -210,17 +222,22 @@ instance : TraverseFVar (CodeDecl pu) where
| .jp decl => return .jp (← mapFVarM f decl)
| .let decl => return .let (← mapFVarM f decl)
| .uset fvarId i y _ => return .uset (← f fvarId) i (← f y)
| .oset fvarId i y _ => return .oset (← f fvarId) i (← y.mapFVarM f)
| .sset fvarId i offset y ty _ => return .sset (← f fvarId) i offset (← f y) (← mapFVarM f ty)
| .setTag fvarId cidx _ => return .setTag (← f fvarId) cidx
| .inc fvarId n check persistent _ => return .inc (← f fvarId) n check persistent
| .dec fvarId n check persistent _ => return .dec (← f fvarId) n check persistent
| .del fvarId _ => return .del (← f fvarId)
forFVarM f decl :=
match decl with
| .fun decl _ => forFVarM f decl
| .jp decl => forFVarM f decl
| .let decl => forFVarM f decl
| .uset fvarId i y _ => do f fvarId; f y
| .oset fvarId i y _ => do f fvarId; y.forFVarM f
| .sset fvarId i offset y ty _ => do f fvarId; f y; forFVarM f ty
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => f fvarId
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. | .del (fvarId := fvarId) ..
| .setTag (fvarId := fvarId) .. => f fvarId
instance : TraverseFVar (Alt pu) where
mapFVarM f alt := do

View file

@ -91,7 +91,7 @@ where
| .cases cs => cs.alts.forM (·.forCodeM (goCode declName))
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => goCode declName k
| .return .. | .jmp .. | .unreach .. => return ()
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .del .. | .oset .. => unreachable!
/--
Apply the inferred borrow annotations from `map` to a SCC.
@ -121,7 +121,7 @@ where
| .cases cs => return code.updateAlts! <| ← cs.alts.mapMonoM (·.mapCodeM (go declName))
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return code.updateCont! (← go declName k)
| .return .. | .jmp .. | .unreach .. => return code
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
structure Ctx where
/--
@ -300,7 +300,7 @@ where
| .cases cs => cs.alts.forM (·.forCodeM collectCode)
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => collectCode k
| .return .. | .unreach .. => return ()
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
public def inferBorrow : Pass where

View file

@ -120,6 +120,10 @@ private partial def internalizeLetValue (e : LetValue pu) : InternalizeM pu (Let
match (← normFVar fvarId) with
| .fvar fvarId' => return e.updateBox! ty fvarId'
| .erased => return .erased
| .isShared fvarId _ =>
match (← normFVar fvarId) with
| .fvar fvarId' => return e.updateIsShared! fvarId'
| .erased => return .erased
def internalizeLetDecl (decl : LetDecl pu) : InternalizeM pu (LetDecl pu) := do
let binderName ← refreshBinderName decl.binderName
@ -166,12 +170,22 @@ partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do
withNormFVarResult (← normFVar fvarId) fun fvarId => do
withNormFVarResult (← normFVar y) fun y => do
return .uset fvarId offset y (← internalizeCode k)
| .oset fvarId offset y k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
let y ← normArg y
return .oset fvarId offset y (← internalizeCode k)
| .setTag fvarId cidx k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return .setTag fvarId cidx (← internalizeCode k)
| .inc fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return .inc fvarId n check persistent (← internalizeCode k)
| .dec fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return .dec fvarId n check persistent (← internalizeCode k)
| .del fvarId k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return .del fvarId (← internalizeCode k)
end
@ -180,8 +194,12 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl
| .let decl => return .let (← internalizeLetDecl decl)
| .fun decl _ => return .fun (← internalizeFunDecl decl)
| .jp decl => return .jp (← internalizeFunDecl decl)
| .uset fvarId i y _ =>
| .oset fvarId i y _ =>
-- Something weird should be happening if these become erased...
let .fvar fvarId ← normFVar fvarId | unreachable!
let y ← normArg y
return .oset fvarId i y
| .uset fvarId i y _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
let .fvar y ← normFVar y | unreachable!
return .uset fvarId i y
@ -190,12 +208,18 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl
let .fvar y ← normFVar y | unreachable!
let ty ← normExpr ty
return .sset fvarId i offset y ty
| .setTag fvarId cidx _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
return .setTag fvarId cidx
| .inc fvarId n check offset _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
return .inc fvarId n check offset
| .dec fvarId n check offset _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
return .dec fvarId n check offset
| .del fvarId _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
return .del fvarId
end Internalize

View file

@ -77,7 +77,8 @@ mutual
| .let decl k => eraseCode k <| lctx.eraseLetDecl decl
| .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl
| .cases c => eraseAlts c.alts lctx
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
| .oset (k := k) .. | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. | .del (k := k) .. | .setTag (k := k) .. =>
eraseCode k lctx
| .return .. | .jmp .. | .unreach .. => lctx
end

View file

@ -65,6 +65,8 @@ where
| .jp decl k => go decl.value <||> (do markJpVisited decl.fvarId; go k)
| .uset fvarId _ y k _ | .sset fvarId _ _ y _ k _ =>
visitVar fvarId <||> visitVar y <||> go k
| .oset fvarId _ y k _ =>
visitVar fvarId <||> pure (y.dependsOn (← read).targetSet) <||> go k
| .cases c => visitVar c.discr <||> c.alts.anyM (go ·.getCode)
| .jmp fvarId args =>
(pure <| args.any (·.dependsOn (← read).targetSet)) <||> do
@ -76,7 +78,8 @@ where
go decl.value
| .return var => visitVar var
| .unreach .. => return false
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) ..
| .setTag (fvarId := fvarId) (k := k) .. | .del (fvarId := fvarId) (k := k) =>
visitVar fvarId <||> go k
@[inline]

View file

@ -94,6 +94,7 @@ def ppLetValue (e : LetValue pu) : M Format := do
return f!"reuse" ++ (if updateHeader then f!"!" else f!"") ++ f!" {← ppFVar fvarId} in {info}{← ppArgs args}"
| .box _ fvarId _ => return f!"box {← ppFVar fvarId}"
| .unbox fvarId _ => return f!"unbox {← ppFVar fvarId}"
| .isShared fvarId _ => return f!"isShared {← ppFVar fvarId}"
def ppParam (param : Param pu) : M Format := do
let borrow := if param.borrow then "@&" else ""
@ -149,6 +150,10 @@ mutual
return f!"sset {← ppFVar fvarId}[{i}, {offset}] := {← ppFVar y};" ++ .line ++ (← ppCode k)
| .uset fvarId i y k _ =>
return f!"uset {← ppFVar fvarId}[{i}] := {← ppFVar y};" ++ .line ++ (← ppCode k)
| .oset fvarId i y k _ =>
return f!"oset {← ppFVar fvarId} [{i}] := {← ppArg y};" ++ .line ++ (← ppCode k)
| .setTag fvarId cidx k _ =>
return f!"setTag {← ppFVar fvarId} := {cidx};" ++ .line ++ (← ppCode k)
| .inc fvarId n _ _ k _ =>
if n != 1 then
return f!"inc[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
@ -159,6 +164,8 @@ mutual
return f!"dec[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
else
return f!"dec {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
| .del fvarId k _ =>
return f!"del {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
partial def ppDeclValue (b : DeclValue pu) : M Format := do

View file

@ -58,7 +58,8 @@ where
go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
@ -73,7 +74,8 @@ where
| .jp decl k => modify (·.push decl); go decl.value; go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
@ -86,7 +88,8 @@ where
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByFun (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -96,7 +99,8 @@ where
| .fun decl k _ => do if (← f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByJp (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -107,7 +111,8 @@ where
| .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByFunDecl (pu : Purity) (f : FunDecl pu → CompilerM Bool) :
Probe (Decl pu) (Decl pu):=
@ -118,7 +123,8 @@ where
| .fun decl k _ | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByCases (pu : Purity) (f : Cases pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -128,7 +134,8 @@ where
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByJmp (pu : Purity) (f : FVarId → Array (Arg pu) → CompilerM Bool) :
Probe (Decl pu) (Decl pu) :=
@ -140,7 +147,8 @@ where
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp fn var => f fn var
| .return .. | .unreach .. => return false
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByReturn (pu : Purity) (f : FVarId → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -151,7 +159,8 @@ where
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .unreach .. => return false
| .return var => f var
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
partial def filterByUnreach (pu : Purity) (f : Expr → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -162,7 +171,8 @@ where
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. => return false
| .unreach typ => f typ
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) ..
| .setTag (k := k) .. | .oset (k := k) .. => go k
@[inline]
def declNames (pu : Purity) : Probe (Decl pu) Name :=

View file

@ -133,6 +133,8 @@ where
| .jp decl k =>
let decl ← decl.updateValue (← decl.value.pushProj)
go k (decls.push (.jp decl))
| .oset fvarId i y k _ =>
go k (decls.push (.oset fvarId i y))
| .uset fvarId i y k _ =>
go k (decls.push (.uset fvarId i y))
| .sset fvarId i offset y ty k _ =>
@ -141,6 +143,10 @@ where
go k (decls.push (.inc fvarId n check persistent))
| .dec fvarId n check persistent k _ =>
go k (decls.push (.dec fvarId n check persistent))
| .del fvarId k _ =>
go k (decls.push (.del fvarId))
| .setTag fvarId cidx k _ =>
go k (decls.push (.setTag fvarId cidx))
| .cases c => c.pushProjs decls
| .jmp .. | .return .. | .unreach .. =>
return attachCodeDecls decls c

View file

@ -53,7 +53,8 @@ partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code
| .ctorAlt _ k _ => return alt.updateCode (← k.applyRenaming r)
return code.updateAlts! alts
| .jmp .. | .unreach .. | .return .. => return code
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
| .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) ..
| .del (k := k) .. | .setTag (k := k) .. =>
return code.updateCont! (← k.applyRenaming r)
end

View file

@ -120,7 +120,7 @@ where
| .return .. | .jmp .. | .unreach .. => return (c, false)
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ | .let _ k =>
goK k
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
def isCtorUsing (instr : CodeDecl .impure) (x : FVarId) : Bool :=
match instr with
@ -242,7 +242,7 @@ where
return (c.updateCont! k, false)
| .return .. | .jmp .. | .unreach .. =>
return (c, ← c.isFVarLiveIn x)
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
end
@ -275,7 +275,7 @@ partial def Code.insertResetReuse (c : Code .impure) : ReuseM (Code .impure) :=
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ =>
return c.updateCont! (← k.insertResetReuse)
| .return .. | .jmp .. | .unreach .. => return c
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
partial def Decl.insertResetReuseCore (decl : Decl .impure) : ReuseM (Decl .impure) := do
let value ← decl.value.mapCodeM fun code => do
@ -298,7 +298,7 @@ where
| .jp decl k => collectResets decl.value; collectResets k
| .cases c => c.alts.forM (collectResets ·.getCode)
| .unreach .. | .return .. | .jmp .. => return ()
| .inc .. | .dec .. => unreachable!
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
def Decl.insertResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do

View file

@ -107,7 +107,8 @@ partial def Code.simpCase (code : Code .impure) : CompilerM (Code .impure) := do
let decl ← decl.updateValue (← decl.value.simpCase)
return code.updateFun! decl (← k.simpCase)
| .return .. | .jmp .. | .unreach .. => return code
| .let _ k | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
| .let _ k | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) ..
| .setTag (k := k) .. | .del (k := k) .. | .oset (k := k) .. =>
return code.updateCont! (← k.simpCase)
def Decl.simpCase (decl : Decl .impure) : CompilerM (Decl .impure) := do

View file

@ -116,10 +116,18 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do
let value := mkApp5 (mkConst `sset) (.fvar fvarId) (toExpr i) (toExpr offset) (.fvar y) ty
let body ← withFVar fvarId k.toExprM
return .letE `dummy (mkConst ``Unit) value body true
| .oset fvarId offset y k _ =>
let value := mkApp3 (mkConst `oset) (.fvar fvarId) (toExpr offset) (← y.toExprM)
let body ← withFVar fvarId k.toExprM
return .letE `dummy (mkConst ``Unit) value body true
| .uset fvarId offset y k _ =>
let value := mkApp3 (mkConst `uset) (.fvar fvarId) (toExpr offset) (.fvar y)
let body ← withFVar fvarId k.toExprM
return .letE `dummy (mkConst ``Unit) value body true
| .setTag fvarId cidx k _ =>
let body ← withFVar fvarId k.toExprM
let value := mkApp2 (mkConst `setTag) (.fvar fvarId) (toExpr cidx)
return .letE `dummy (mkConst ``Unit) value body true
| .inc fvarId n check persistent k _ =>
let value := mkApp4 (mkConst `inc) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
let body ← withFVar fvarId k.toExprM
@ -128,6 +136,10 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do
let body ← withFVar fvarId k.toExprM
let value := mkApp4 (mkConst `dec) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
return .letE `dummy (mkConst ``Unit) value body true
| .del fvarId k _ =>
let body ← withFVar fvarId k.toExprM
let value := mkApp (mkConst `del) (.fvar fvarId)
return .letE `dummy (mkConst ``Unit) value body true
end
public def Code.toExpr (code : Code pu) (xs : Array FVarId := #[]) : Expr :=