feat: LCNF inc/dec instructions (#12550)

This PR adds `inc`/`dec` instructions to LCNF. It should be a functional
no-op.
This commit is contained in:
Henrik Böving 2026-02-18 11:55:16 +01:00 committed by GitHub
parent 6c671ffe6f
commit ad64f7c1ba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 172 additions and 31 deletions

View file

@ -109,6 +109,12 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do
let .var y ← getFVarValue y | unreachable!
let .var var ← getFVarValue var | unreachable!
return .uset var i y (← lowerCode k)
| .inc fvarId n check persistent k _ =>
let .var var ← getFVarValue fvarId | unreachable!
return .inc var n check persistent (← lowerCode k)
| .dec fvarId n check persistent k _ =>
let .var var ← getFVarValue fvarId | unreachable!
return .dec var n check persistent (← lowerCode k)
| .fun .. => panic! "all local functions should be λ-lifted"
partial def lowerLet (decl : LCNF.LetDecl .impure) (k : LCNF.Code .impure) : M FnBody := do

View file

@ -155,6 +155,18 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
eqvFVar var₁ var₂ <&&>
eqvFVar y₁ y₂ <&&>
eqv k₁ k₂
| .inc fvarId₁ n₁ c₁ p₁ k₁ _, .inc fvarId₂ n₂ c₂ p₂ k₂ _ =>
pure (n₁ == n₂) <&&>
pure (c₁ == c₂) <&&>
pure (p₁ == p₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| .dec fvarId₁ n₁ c₁ p₁ k₁ _, .dec fvarId₂ n₂ c₂ p₂ k₂ _ =>
pure (n₁ == n₂) <&&>
pure (c₁ == c₂) <&&>
pure (p₁ == p₂) <&&>
eqvFVar fvarId₁ fvarId₂ <&&>
eqv k₁ k₂
| _, _ => return false
end

View file

@ -363,6 +363,8 @@ inductive Code (pu : Purity) where
| unreach (type : Expr)
| uset (var : FVarId) (i : Nat) (y : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac)
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac)
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
deriving Inhabited
end
@ -440,11 +442,13 @@ inductive CodeDecl (pu : Purity) where
| jp (decl : FunDecl pu)
| uset (var : FVarId) (i : Nat) (y : FVarId) (h : pu = .impure := by purity_tac)
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac)
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
deriving Inhabited
def CodeDecl.fvarId : CodeDecl pu → FVarId
| .let decl | .fun decl _ | .jp decl => decl.fvarId
| .uset var .. | .sset var .. => var
| .uset fvarId .. | .sset fvarId .. | .inc fvarId .. | .dec fvarId .. => fvarId
def Code.toCodeDecl! : Code pu → CodeDecl pu
| .let decl _ => .let decl
@ -452,6 +456,8 @@ def Code.toCodeDecl! : Code pu → CodeDecl pu
| .jp decl _ => .jp decl
| .uset var i y _ _ => .uset var i y
| .sset var i offset ty y _ _ => .sset var i offset ty y
| .inc fvarId n check persistent _ _ => .inc fvarId n check persistent
| .dec fvarId n check persistent _ _ => .dec fvarId n check persistent
| _ => unreachable!
def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu :=
@ -465,6 +471,8 @@ where
| .jp decl => go (i-1) (.jp decl code)
| .uset var idx y _ => go (i-1) (.uset var idx y code)
| .sset var idx offset y ty _ => go (i-1) (.sset var idx offset y ty code)
| .inc fvarId n check persistent _ => go (i-1) (.inc fvarId n check persistent code)
| .dec fvarId n check persistent _ => go (i-1) (.dec fvarId n check persistent code)
else
code
@ -484,6 +492,10 @@ mutual
v₁ == v₂ && i₁ == i₂ && y₁ == y₂ && eqImp k₁ k₂
| .sset v₁ i₁ o₁ y₁ ty₁ k₁ _, .sset v₂ i₂ o₂ y₂ ty₂ k₂ _ =>
v₁ == v₂ && i₁ == i₂ && o₁ == o₂ && y₁ == y₂ && ty₁ == ty₂ && eqImp k₁ k₂
| .inc v₁ n₁ c₁ p₁ k₁ _, .inc v₂ n₂ c₂ p₂ k₂ _ =>
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
| .dec v₁ n₁ c₁ p₁ k₁ _, .dec v₂ n₂ c₂ p₂ k₂ _ =>
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
| _, _ => false
private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool :=
@ -578,6 +590,8 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
| .jp decl k => if ptrEq k k' then c else .jp decl k'
| .sset fvarId i offset y ty k _ => if ptrEq k k' then c else .sset fvarId i offset y ty k'
| .uset fvarId offset y k _ => if ptrEq k k' then c else .uset fvarId offset y k'
| .inc fvarId n check persistent k _ => if ptrEq k k' then c else .inc fvarId n check persistent k'
| .dec fvarId n check persistent k _ => if ptrEq k k' then c else .dec fvarId n check persistent k'
| _ => unreachable!
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu
@ -637,6 +651,40 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
@[implemented_by updateUsetImp] opaque Code.updateUset! (c : Code pu) (fvarId' : FVarId)
(i' : Nat) (y' : FVarId) (k' : Code pu) : Code pu
@[inline] private unsafe def updateIncImp (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu :=
match c with
| .inc fvarId n check persistent k _ =>
if ptrEq fvarId fvarId'
&& n == n'
&& check == check'
&& persistent == persistent'
&& ptrEq k k' then
c
else
.inc fvarId' n' check' persistent' k'
| _ => unreachable!
@[implemented_by updateIncImp] opaque Code.updateInc! (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu
@[inline] private unsafe def updateDecImp (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu :=
match c with
| .dec fvarId n check persistent k _ =>
if ptrEq fvarId fvarId'
&& n == n'
&& check == check'
&& persistent == persistent'
&& ptrEq k k' then
c
else
.dec fvarId' n' check' persistent' k'
| _ => unreachable!
@[implemented_by updateDecImp] opaque Code.updateDec! (c : Code pu) (fvarId' : FVarId) (n' : Nat)
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu
private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu :=
if ptrEq type p.type then
p
@ -705,7 +753,8 @@ partial def Code.size (c : Code pu) : Nat :=
where
go (c : Code pu) (n : Nat) : Nat :=
match c with
| .let _ k | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k (n + 1)
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. => go k (n + 1)
| .jp decl k | .fun decl k _ => go k <| go decl.value n
| .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1)
| .jmp .. => n+1
@ -723,7 +772,8 @@ where
go (c : Code pu) : EStateM Unit Nat Unit := do
match c with
| .let _ k | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => inc; go k
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. => inc; go k
| .jp decl k | .fun decl k _ => inc; go decl.value; go k
| .cases c => inc; c.alts.forM fun alt => go alt.getCode
| .jmp .. => inc
@ -735,7 +785,8 @@ where
go (c : Code pu) : m Unit := do
f c
match c with
| .let _ k | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
| .dec (k := k) .. => go k
| .fun decl k _ | .jp decl k => go decl.value; go k
| .cases c => c.alts.forM fun alt => go alt.getCode
| .unreach .. | .return .. | .jmp .. => return ()
@ -1031,6 +1082,8 @@ partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarId
let s := s.insert var
let s := s.insert y
k.collectUsed s
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
k.collectUsed <| s.insert fvarId
end
@[inline] def collectUsedAtExpr (s : FVarIdHashSet) (e : Expr) : FVarIdHashSet :=
@ -1040,8 +1093,8 @@ def CodeDecl.collectUsed (codeDecl : CodeDecl pu) (s : FVarIdHashSet := ∅) : F
match codeDecl with
| .let decl => collectLetValue decl.value <| collectType decl.type s
| .jp decl | .fun decl _ => decl.collectUsed s
| .sset var _ _ y ty _ => s.insert var |>.insert y |> collectType ty
| .uset var _ y _ => s.insert var |>.insert y
| .sset (var := var) (y := y) .. | .uset (var := var) (y := y) .. => s.insert var |>.insert y
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.insert fvarId
/--
Traverse the given block of potentially mutually recursive functions
@ -1071,7 +1124,7 @@ where
modify fun s => s.insert declName
| _ => pure ()
visit k
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => visit k
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => visit k
go : StateM NameSet Unit :=
decls.forM (·.value.forCodeM visit)

View file

@ -68,8 +68,8 @@ where
eraseCode k
eraseParam auxParam
return .unreach typeNew
| .sset fvarId i offset y ty k _ => return .sset fvarId i offset y ty (← go k)
| .uset fvarId offset y k _ => return .uset fvarId offset y (← go k)
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
return c.updateCont! (← go k)
instance : MonadCodeBind CompilerM where
codeBind := CompilerM.codeBind

View file

@ -149,7 +149,7 @@ def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do
match decl with
| .let decl => eraseLetDecl decl
| .jp decl | .fun decl _ => eraseFunDecl decl
| .sset .. | .uset .. => return ()
| .sset .. | .uset .. | .inc .. | .dec .. => return ()
/--
Erase all free variables occurring in `decls` from the local context.
@ -501,6 +501,12 @@ mutual
withNormFVarResult (← normFVar fvarId) fun fvarId => do
withNormFVarResult (← normFVar y) fun y => do
return code.updateUset! fvarId offset y (← normCodeImp k)
| .inc fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return code.updateInc! fvarId n check persistent (← normCodeImp k)
| .dec fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return code.updateDec! fvarId n check persistent (← normCodeImp k)
end
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do

View file

@ -41,6 +41,10 @@ partial def hashCode (code : Code pu) : UInt64 :=
mixHash (mixHash (hash fvarId) (hash i)) (mixHash (mixHash (hash offset) (hash y)) (mixHash (hash ty) (hashCode k)))
| .uset fvarId offset y k _ =>
mixHash (mixHash (hash fvarId) (hash offset)) (mixHash (hash y) (hashCode k))
| .inc fvarId n check persistent k _ =>
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
| .dec fvarId n check persistent k _ =>
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
end

View file

@ -47,6 +47,8 @@ private partial def depOn (c : Code pu) : M Bool :=
| .return fvarId => fvarDepOn fvarId
| .unreach _ => return false
| .sset fv1 _ _ fv2 _ k _ | .uset fv1 _ fv2 k _ => fvarDepOn fv1 <||> fvarDepOn fv2 <||> depOn k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
fvarDepOn fvarId <||> depOn k
@[inline] def Arg.dependsOn (arg : Arg pu) (s : FVarIdSet) : Bool :=
argDepOn arg s
@ -64,8 +66,8 @@ def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool :=
match decl with
| .let decl => decl.dependsOn s
| .jp decl | .fun decl _ => decl.dependsOn s
| .uset var _ y _ => s.contains var || s.contains y
| .sset var _ _ y ty _ => s.contains var || s.contains y || (typeDepOn ty s)
| .uset (var := var) (y := y) .. | .sset (var := var) (y := y) .. => s.contains var || s.contains y
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.contains fvarId
/--
Return `true` is `c` depends on a free variable in `s`.

View file

@ -102,6 +102,10 @@ partial def Code.elimDead (code : Code pu) : M (Code pu) := do
return code.updateCont! k
else
return k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
let k ← k.elimDead
collectFVarM fvarId
return code.updateCont! k
end

View file

@ -292,6 +292,7 @@ partial def Code.explicitBoxing (code : Code .impure) : BoxM (Code .impure) := d
let some jpDecl ← findFunDecl? fvarId | unreachable!
castArgsIfNeeded args jpDecl.params fun args => return code.updateJmp! fvarId args
| .unreach .. => return code.updateUnreach! (← getResultType)
| .inc .. | .dec .. => unreachable!
where
/--
Up to this point the type system of IR is quite loose so we can for example encounter situations
@ -368,7 +369,7 @@ def run (decls : Array (Decl .impure)) : CompilerM (Array (Decl .impure)) := do
public def explicitBoxing : Pass where
phase := .impure
phaseOut := .impure
name := `boxing
name := `explicitBoxing
run := run
builtin_initialize

View file

@ -143,6 +143,10 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F
return Code.updateSset! c (← f fvarId) i offset (← f y) (← Expr.mapFVarM f ty) (← mapFVarM f k)
| .uset fvarId offset y k _ =>
return Code.updateUset! c (← f fvarId) offset (← f y) (← mapFVarM f k)
| .inc fvarId n check persistent k _ =>
return Code.updateInc! c (← f fvarId) n check persistent (← mapFVarM f k)
| .dec fvarId n check persistent k _ =>
return Code.updateDec! c (← f fvarId) n check persistent (← mapFVarM f k)
partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Unit := do
match c with
@ -178,6 +182,9 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Un
f fvarId
f y
forFVarM f k
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
f fvarId
forFVarM f k
instance : TraverseFVar (Code pu) where
mapFVarM := Code.mapFVarM
@ -204,6 +211,8 @@ instance : TraverseFVar (CodeDecl pu) where
| .let decl => return .let (← mapFVarM f decl)
| .uset var i y _ => return .uset (← f var) i (← f y)
| .sset var i offset y ty _ => return .sset (← f var) i offset (← f y) (← mapFVarM f ty)
| .inc fvarId n check persistent _ => return .inc (← f fvarId) n check persistent
| .dec fvarId n check persistent _ => return .dec (← f fvarId) n check persistent
forFVarM f decl :=
match decl with
| .fun decl _ => forFVarM f decl
@ -211,6 +220,7 @@ instance : TraverseFVar (CodeDecl pu) where
| .let decl => forFVarM f decl
| .uset var i y _ => do f var; f y
| .sset var i offset y ty _ => do f var; f y; forFVarM f ty
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => f fvarId
instance : TraverseFVar (Alt pu) where
mapFVarM f alt := do

View file

@ -91,6 +91,7 @@ where
| .cases cs => cs.alts.forM (·.forCodeM (goCode declName))
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => goCode declName k
| .return .. | .jmp .. | .unreach .. => return ()
| .inc .. | .dec .. => unreachable!
/--
Apply the inferred borrow annotations from `map` to a SCC.
@ -120,6 +121,7 @@ where
| .cases cs => return code.updateAlts! <| ← cs.alts.mapM (·.mapCodeM (go declName))
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return code.updateCont! (← go declName k)
| .return .. | .jmp .. | .unreach .. => return code
| .inc .. | .dec .. => unreachable!
structure Ctx where
/--
@ -298,6 +300,7 @@ where
| .cases cs => cs.alts.forM (·.forCodeM collectCode)
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => collectCode k
| .return .. | .unreach .. => return ()
| .inc .. | .dec .. => unreachable!
public def inferBorrow : Pass where

View file

@ -166,6 +166,12 @@ partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do
withNormFVarResult (← normFVar fvarId) fun fvarId => do
withNormFVarResult (← normFVar y) fun y => do
return .uset fvarId offset y (← internalizeCode k)
| .inc fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return .inc fvarId n check persistent (← internalizeCode k)
| .dec fvarId n check persistent k _ =>
withNormFVarResult (← normFVar fvarId) fun fvarId => do
return .dec fvarId n check persistent (← internalizeCode k)
end
@ -184,6 +190,12 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl
let .fvar y ← normFVar y | unreachable!
let ty ← normExpr ty
return .sset var i offset y ty
| .inc fvarId n check offset _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
return .inc fvarId n check offset
| .dec fvarId n check offset _ =>
let .fvar fvarId ← normFVar fvarId | unreachable!
return .dec fvarId n check offset
end Internalize

View file

@ -77,8 +77,9 @@ mutual
| .let decl k => eraseCode k <| lctx.eraseLetDecl decl
| .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl
| .cases c => eraseAlts c.alts lctx
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => eraseCode k lctx
| _ => lctx
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
eraseCode k lctx
| .return .. | .jmp .. | .unreach .. => lctx
end
@[inline]

View file

@ -76,6 +76,8 @@ where
go decl.value
| .return var => visitVar var
| .unreach .. => return false
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
visitVar fvarId <||> go k
@[inline]
visitVar (x : FVarId) : LiveM Bool :=

View file

@ -149,6 +149,16 @@ mutual
return f!"sset {← ppFVar fvarId} [{i}, {offset}] := {← ppFVar y} " ++ ";" ++ .line ++ (← ppCode k)
| .uset fvarId i y k _ =>
return f!"uset {← ppFVar fvarId} [{i}] := {← ppFVar y} " ++ ";" ++ .line ++ (← ppCode k)
| .inc fvarId n _ _ k _ =>
if n != 1 then
return f!"inc[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
else
return f!"inc {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
| .dec fvarId n _ _ k _ =>
if n != 1 then
return f!"dec[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
else
return f!"dec {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
partial def ppDeclValue (b : DeclValue pu) : M Format := do

View file

@ -58,7 +58,7 @@ where
go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
@ -73,7 +73,7 @@ where
| .jp decl k => modify (·.push decl); go decl.value; go k
| .cases cs => cs.alts.forM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return ()
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit :=
decls.forM (·.value.forCodeM go)
@ -86,7 +86,7 @@ where
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
partial def filterByFun (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -96,7 +96,7 @@ where
| .fun decl k _ => do if (← f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
partial def filterByJp (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -107,7 +107,7 @@ where
| .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
partial def filterByFunDecl (pu : Purity) (f : FunDecl pu → CompilerM Bool) :
Probe (Decl pu) (Decl pu):=
@ -118,7 +118,7 @@ where
| .fun decl k _ | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
partial def filterByCases (pu : Purity) (f : Cases pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -128,7 +128,7 @@ where
| .fun decl k _ | .jp decl k => go decl.value <||> go k
| .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. | .unreach .. => return false
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
partial def filterByJmp (pu : Purity) (f : FVarId → Array (Arg pu) → CompilerM Bool) :
Probe (Decl pu) (Decl pu) :=
@ -140,7 +140,7 @@ where
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp fn var => f fn var
| .return .. | .unreach .. => return false
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
partial def filterByReturn (pu : Purity) (f : FVarId → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -151,7 +151,7 @@ where
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .unreach .. => return false
| .return var => f var
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
partial def filterByUnreach (pu : Purity) (f : Expr → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
filter (·.value.isCodeAndM go)
@ -162,7 +162,7 @@ where
| .cases cs => cs.alts.anyM (go ·.getCode)
| .jmp .. | .return .. => return false
| .unreach typ => f typ
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
@[inline]
def declNames (pu : Purity) : Probe (Decl pu) Name :=

View file

@ -137,6 +137,10 @@ where
go k (decls.push (.uset var i y))
| .sset var i offset y ty k _ =>
go k (decls.push (.sset var i offset y ty))
| .inc fvarId n check persistent k _ =>
go k (decls.push (.inc fvarId n check persistent))
| .dec fvarId n check persistent k _ =>
go k (decls.push (.dec fvarId n check persistent))
| .cases c => c.pushProjs decls
| .jmp .. | .return .. | .unreach .. =>
return attachCodeDecls decls c

View file

@ -53,10 +53,8 @@ partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code
| .ctorAlt _ k _ => return alt.updateCode (← k.applyRenaming r)
return code.updateAlts! alts
| .jmp .. | .unreach .. | .return .. => return code
| .sset fvarId i offset y ty k _ =>
return code.updateSset! fvarId i offset y ty (← k.applyRenaming r)
| .uset fvarId offset y k _ =>
return code.updateUset! fvarId offset y (← k.applyRenaming r)
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
return code.updateCont! (← k.applyRenaming r)
end
def Decl.applyRenaming (decl : Decl pu) (r : Renaming) : CompilerM (Decl pu) := do

View file

@ -120,6 +120,7 @@ where
| .return .. | .jmp .. | .unreach .. => return (c, false)
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ | .let _ k =>
goK k
| .inc .. | .dec .. => unreachable!
def isCtorUsing (instr : CodeDecl .impure) (x : FVarId) : Bool :=
match instr with
@ -241,6 +242,7 @@ where
return (c.updateCont! k, false)
| .return .. | .jmp .. | .unreach .. =>
return (c, ← c.isFVarLiveIn x)
| .inc .. | .dec .. => unreachable!
end
@ -273,6 +275,7 @@ partial def Code.insertResetReuse (c : Code .impure) : ReuseM (Code .impure) :=
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ =>
return c.updateCont! (← k.insertResetReuse)
| .return .. | .jmp .. | .unreach .. => return c
| .inc .. | .dec .. => unreachable!
partial def Decl.insertResetReuseCore (decl : Decl .impure) : ReuseM (Decl .impure) := do
let value ← decl.value.mapCodeM fun code => do
@ -295,6 +298,7 @@ where
| .jp decl k => collectResets decl.value; collectResets k
| .cases c => c.alts.forM (collectResets ·.getCode)
| .unreach .. | .return .. | .jmp .. => return ()
| .inc .. | .dec .. => unreachable!
def Decl.insertResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do

View file

@ -107,7 +107,7 @@ partial def Code.simpCase (code : Code .impure) : CompilerM (Code .impure) := do
let decl ← decl.updateValue (← decl.value.simpCase)
return code.updateFun! decl (← k.simpCase)
| .return .. | .jmp .. | .unreach .. => return code
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ =>
| .let _ k | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
return code.updateCont! (← k.simpCase)
def Decl.simpCase (decl : Decl .impure) : CompilerM (Decl .impure) := do

View file

@ -34,7 +34,8 @@ where
goCode k
| .cases cases => cases.alts.forM (·.forCodeM goCode)
| .jmp .. | .return .. | .unreach .. => return ()
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => goCode k
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
goCode k
end SplitScc

View file

@ -120,6 +120,14 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do
let value := mkApp3 (mkConst `uset) (.fvar fvarId) (toExpr offset) (.fvar y)
let body ← withFVar fvarId k.toExprM
return .letE `dummy (mkConst ``Unit) value body true
| .inc fvarId n check persistent k _ =>
let value := mkApp4 (mkConst `inc) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
let body ← withFVar fvarId k.toExprM
return .letE `dummy (mkConst ``Unit) value body true
| .dec fvarId n check persistent k _ =>
let body ← withFVar fvarId k.toExprM
let value := mkApp4 (mkConst `dec) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
return .letE `dummy (mkConst ``Unit) value body true
end
public def Code.toExpr (code : Code pu) (xs : Array FVarId := #[]) : Expr :=