feat: LCNF inc/dec instructions (#12550)
This PR adds `inc`/`dec` instructions to LCNF. It should be a functional no-op.
This commit is contained in:
parent
6c671ffe6f
commit
ad64f7c1ba
22 changed files with 172 additions and 31 deletions
|
|
@ -109,6 +109,12 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do
|
|||
let .var y ← getFVarValue y | unreachable!
|
||||
let .var var ← getFVarValue var | unreachable!
|
||||
return .uset var i y (← lowerCode k)
|
||||
| .inc fvarId n check persistent k _ =>
|
||||
let .var var ← getFVarValue fvarId | unreachable!
|
||||
return .inc var n check persistent (← lowerCode k)
|
||||
| .dec fvarId n check persistent k _ =>
|
||||
let .var var ← getFVarValue fvarId | unreachable!
|
||||
return .dec var n check persistent (← lowerCode k)
|
||||
| .fun .. => panic! "all local functions should be λ-lifted"
|
||||
|
||||
partial def lowerLet (decl : LCNF.LetDecl .impure) (k : LCNF.Code .impure) : M FnBody := do
|
||||
|
|
|
|||
|
|
@ -155,6 +155,18 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do
|
|||
eqvFVar var₁ var₂ <&&>
|
||||
eqvFVar y₁ y₂ <&&>
|
||||
eqv k₁ k₂
|
||||
| .inc fvarId₁ n₁ c₁ p₁ k₁ _, .inc fvarId₂ n₂ c₂ p₂ k₂ _ =>
|
||||
pure (n₁ == n₂) <&&>
|
||||
pure (c₁ == c₂) <&&>
|
||||
pure (p₁ == p₂) <&&>
|
||||
eqvFVar fvarId₁ fvarId₂ <&&>
|
||||
eqv k₁ k₂
|
||||
| .dec fvarId₁ n₁ c₁ p₁ k₁ _, .dec fvarId₂ n₂ c₂ p₂ k₂ _ =>
|
||||
pure (n₁ == n₂) <&&>
|
||||
pure (c₁ == c₂) <&&>
|
||||
pure (p₁ == p₂) <&&>
|
||||
eqvFVar fvarId₁ fvarId₂ <&&>
|
||||
eqv k₁ k₂
|
||||
| _, _ => return false
|
||||
|
||||
end
|
||||
|
|
|
|||
|
|
@ -363,6 +363,8 @@ inductive Code (pu : Purity) where
|
|||
| unreach (type : Expr)
|
||||
| uset (var : FVarId) (i : Nat) (y : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac)
|
||||
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac)
|
||||
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
|
||||
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac)
|
||||
deriving Inhabited
|
||||
|
||||
end
|
||||
|
|
@ -440,11 +442,13 @@ inductive CodeDecl (pu : Purity) where
|
|||
| jp (decl : FunDecl pu)
|
||||
| uset (var : FVarId) (i : Nat) (y : FVarId) (h : pu = .impure := by purity_tac)
|
||||
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac)
|
||||
| inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
|
||||
| dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac)
|
||||
deriving Inhabited
|
||||
|
||||
def CodeDecl.fvarId : CodeDecl pu → FVarId
|
||||
| .let decl | .fun decl _ | .jp decl => decl.fvarId
|
||||
| .uset var .. | .sset var .. => var
|
||||
| .uset fvarId .. | .sset fvarId .. | .inc fvarId .. | .dec fvarId .. => fvarId
|
||||
|
||||
def Code.toCodeDecl! : Code pu → CodeDecl pu
|
||||
| .let decl _ => .let decl
|
||||
|
|
@ -452,6 +456,8 @@ def Code.toCodeDecl! : Code pu → CodeDecl pu
|
|||
| .jp decl _ => .jp decl
|
||||
| .uset var i y _ _ => .uset var i y
|
||||
| .sset var i offset ty y _ _ => .sset var i offset ty y
|
||||
| .inc fvarId n check persistent _ _ => .inc fvarId n check persistent
|
||||
| .dec fvarId n check persistent _ _ => .dec fvarId n check persistent
|
||||
| _ => unreachable!
|
||||
|
||||
def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu :=
|
||||
|
|
@ -465,6 +471,8 @@ where
|
|||
| .jp decl => go (i-1) (.jp decl code)
|
||||
| .uset var idx y _ => go (i-1) (.uset var idx y code)
|
||||
| .sset var idx offset y ty _ => go (i-1) (.sset var idx offset y ty code)
|
||||
| .inc fvarId n check persistent _ => go (i-1) (.inc fvarId n check persistent code)
|
||||
| .dec fvarId n check persistent _ => go (i-1) (.dec fvarId n check persistent code)
|
||||
else
|
||||
code
|
||||
|
||||
|
|
@ -484,6 +492,10 @@ mutual
|
|||
v₁ == v₂ && i₁ == i₂ && y₁ == y₂ && eqImp k₁ k₂
|
||||
| .sset v₁ i₁ o₁ y₁ ty₁ k₁ _, .sset v₂ i₂ o₂ y₂ ty₂ k₂ _ =>
|
||||
v₁ == v₂ && i₁ == i₂ && o₁ == o₂ && y₁ == y₂ && ty₁ == ty₂ && eqImp k₁ k₂
|
||||
| .inc v₁ n₁ c₁ p₁ k₁ _, .inc v₂ n₂ c₂ p₂ k₂ _ =>
|
||||
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
|
||||
| .dec v₁ n₁ c₁ p₁ k₁ _, .dec v₂ n₂ c₂ p₂ k₂ _ =>
|
||||
v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂
|
||||
| _, _ => false
|
||||
|
||||
private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool :=
|
||||
|
|
@ -578,6 +590,8 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
|
|||
| .jp decl k => if ptrEq k k' then c else .jp decl k'
|
||||
| .sset fvarId i offset y ty k _ => if ptrEq k k' then c else .sset fvarId i offset y ty k'
|
||||
| .uset fvarId offset y k _ => if ptrEq k k' then c else .uset fvarId offset y k'
|
||||
| .inc fvarId n check persistent k _ => if ptrEq k k' then c else .inc fvarId n check persistent k'
|
||||
| .dec fvarId n check persistent k _ => if ptrEq k k' then c else .dec fvarId n check persistent k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu
|
||||
|
|
@ -637,6 +651,40 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co
|
|||
@[implemented_by updateUsetImp] opaque Code.updateUset! (c : Code pu) (fvarId' : FVarId)
|
||||
(i' : Nat) (y' : FVarId) (k' : Code pu) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateIncImp (c : Code pu) (fvarId' : FVarId) (n' : Nat)
|
||||
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu :=
|
||||
match c with
|
||||
| .inc fvarId n check persistent k _ =>
|
||||
if ptrEq fvarId fvarId'
|
||||
&& n == n'
|
||||
&& check == check'
|
||||
&& persistent == persistent'
|
||||
&& ptrEq k k' then
|
||||
c
|
||||
else
|
||||
.inc fvarId' n' check' persistent' k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateIncImp] opaque Code.updateInc! (c : Code pu) (fvarId' : FVarId) (n' : Nat)
|
||||
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu
|
||||
|
||||
@[inline] private unsafe def updateDecImp (c : Code pu) (fvarId' : FVarId) (n' : Nat)
|
||||
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu :=
|
||||
match c with
|
||||
| .dec fvarId n check persistent k _ =>
|
||||
if ptrEq fvarId fvarId'
|
||||
&& n == n'
|
||||
&& check == check'
|
||||
&& persistent == persistent'
|
||||
&& ptrEq k k' then
|
||||
c
|
||||
else
|
||||
.dec fvarId' n' check' persistent' k'
|
||||
| _ => unreachable!
|
||||
|
||||
@[implemented_by updateDecImp] opaque Code.updateDec! (c : Code pu) (fvarId' : FVarId) (n' : Nat)
|
||||
(check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu
|
||||
|
||||
private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu :=
|
||||
if ptrEq type p.type then
|
||||
p
|
||||
|
|
@ -705,7 +753,8 @@ partial def Code.size (c : Code pu) : Nat :=
|
|||
where
|
||||
go (c : Code pu) (n : Nat) : Nat :=
|
||||
match c with
|
||||
| .let _ k | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k (n + 1)
|
||||
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
|
||||
| .dec (k := k) .. => go k (n + 1)
|
||||
| .jp decl k | .fun decl k _ => go k <| go decl.value n
|
||||
| .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1)
|
||||
| .jmp .. => n+1
|
||||
|
|
@ -723,7 +772,8 @@ where
|
|||
|
||||
go (c : Code pu) : EStateM Unit Nat Unit := do
|
||||
match c with
|
||||
| .let _ k | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => inc; go k
|
||||
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
|
||||
| .dec (k := k) .. => inc; go k
|
||||
| .jp decl k | .fun decl k _ => inc; go decl.value; go k
|
||||
| .cases c => inc; c.alts.forM fun alt => go alt.getCode
|
||||
| .jmp .. => inc
|
||||
|
|
@ -735,7 +785,8 @@ where
|
|||
go (c : Code pu) : m Unit := do
|
||||
f c
|
||||
match c with
|
||||
| .let _ k | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) ..
|
||||
| .dec (k := k) .. => go k
|
||||
| .fun decl k _ | .jp decl k => go decl.value; go k
|
||||
| .cases c => c.alts.forM fun alt => go alt.getCode
|
||||
| .unreach .. | .return .. | .jmp .. => return ()
|
||||
|
|
@ -1031,6 +1082,8 @@ partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarId
|
|||
let s := s.insert var
|
||||
let s := s.insert y
|
||||
k.collectUsed s
|
||||
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
|
||||
k.collectUsed <| s.insert fvarId
|
||||
end
|
||||
|
||||
@[inline] def collectUsedAtExpr (s : FVarIdHashSet) (e : Expr) : FVarIdHashSet :=
|
||||
|
|
@ -1040,8 +1093,8 @@ def CodeDecl.collectUsed (codeDecl : CodeDecl pu) (s : FVarIdHashSet := ∅) : F
|
|||
match codeDecl with
|
||||
| .let decl => collectLetValue decl.value <| collectType decl.type s
|
||||
| .jp decl | .fun decl _ => decl.collectUsed s
|
||||
| .sset var _ _ y ty _ => s.insert var |>.insert y |> collectType ty
|
||||
| .uset var _ y _ => s.insert var |>.insert y
|
||||
| .sset (var := var) (y := y) .. | .uset (var := var) (y := y) .. => s.insert var |>.insert y
|
||||
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.insert fvarId
|
||||
|
||||
/--
|
||||
Traverse the given block of potentially mutually recursive functions
|
||||
|
|
@ -1071,7 +1124,7 @@ where
|
|||
modify fun s => s.insert declName
|
||||
| _ => pure ()
|
||||
visit k
|
||||
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => visit k
|
||||
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => visit k
|
||||
|
||||
go : StateM NameSet Unit :=
|
||||
decls.forM (·.value.forCodeM visit)
|
||||
|
|
|
|||
|
|
@ -68,8 +68,8 @@ where
|
|||
eraseCode k
|
||||
eraseParam auxParam
|
||||
return .unreach typeNew
|
||||
| .sset fvarId i offset y ty k _ => return .sset fvarId i offset y ty (← go k)
|
||||
| .uset fvarId offset y k _ => return .uset fvarId offset y (← go k)
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
|
||||
return c.updateCont! (← go k)
|
||||
|
||||
instance : MonadCodeBind CompilerM where
|
||||
codeBind := CompilerM.codeBind
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do
|
|||
match decl with
|
||||
| .let decl => eraseLetDecl decl
|
||||
| .jp decl | .fun decl _ => eraseFunDecl decl
|
||||
| .sset .. | .uset .. => return ()
|
||||
| .sset .. | .uset .. | .inc .. | .dec .. => return ()
|
||||
|
||||
/--
|
||||
Erase all free variables occurring in `decls` from the local context.
|
||||
|
|
@ -501,6 +501,12 @@ mutual
|
|||
withNormFVarResult (← normFVar fvarId) fun fvarId => do
|
||||
withNormFVarResult (← normFVar y) fun y => do
|
||||
return code.updateUset! fvarId offset y (← normCodeImp k)
|
||||
| .inc fvarId n check persistent k _ =>
|
||||
withNormFVarResult (← normFVar fvarId) fun fvarId => do
|
||||
return code.updateInc! fvarId n check persistent (← normCodeImp k)
|
||||
| .dec fvarId n check persistent k _ =>
|
||||
withNormFVarResult (← normFVar fvarId) fun fvarId => do
|
||||
return code.updateDec! fvarId n check persistent (← normCodeImp k)
|
||||
end
|
||||
|
||||
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ partial def hashCode (code : Code pu) : UInt64 :=
|
|||
mixHash (mixHash (hash fvarId) (hash i)) (mixHash (mixHash (hash offset) (hash y)) (mixHash (hash ty) (hashCode k)))
|
||||
| .uset fvarId offset y k _ =>
|
||||
mixHash (mixHash (hash fvarId) (hash offset)) (mixHash (hash y) (hashCode k))
|
||||
| .inc fvarId n check persistent k _ =>
|
||||
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
|
||||
| .dec fvarId n check persistent k _ =>
|
||||
mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k))
|
||||
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,8 @@ private partial def depOn (c : Code pu) : M Bool :=
|
|||
| .return fvarId => fvarDepOn fvarId
|
||||
| .unreach _ => return false
|
||||
| .sset fv1 _ _ fv2 _ k _ | .uset fv1 _ fv2 k _ => fvarDepOn fv1 <||> fvarDepOn fv2 <||> depOn k
|
||||
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
|
||||
fvarDepOn fvarId <||> depOn k
|
||||
|
||||
@[inline] def Arg.dependsOn (arg : Arg pu) (s : FVarIdSet) : Bool :=
|
||||
argDepOn arg s
|
||||
|
|
@ -64,8 +66,8 @@ def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool :=
|
|||
match decl with
|
||||
| .let decl => decl.dependsOn s
|
||||
| .jp decl | .fun decl _ => decl.dependsOn s
|
||||
| .uset var _ y _ => s.contains var || s.contains y
|
||||
| .sset var _ _ y ty _ => s.contains var || s.contains y || (typeDepOn ty s)
|
||||
| .uset (var := var) (y := y) .. | .sset (var := var) (y := y) .. => s.contains var || s.contains y
|
||||
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.contains fvarId
|
||||
|
||||
/--
|
||||
Return `true` is `c` depends on a free variable in `s`.
|
||||
|
|
|
|||
|
|
@ -102,6 +102,10 @@ partial def Code.elimDead (code : Code pu) : M (Code pu) := do
|
|||
return code.updateCont! k
|
||||
else
|
||||
return k
|
||||
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
|
||||
let k ← k.elimDead
|
||||
collectFVarM fvarId
|
||||
return code.updateCont! k
|
||||
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -292,6 +292,7 @@ partial def Code.explicitBoxing (code : Code .impure) : BoxM (Code .impure) := d
|
|||
let some jpDecl ← findFunDecl? fvarId | unreachable!
|
||||
castArgsIfNeeded args jpDecl.params fun args => return code.updateJmp! fvarId args
|
||||
| .unreach .. => return code.updateUnreach! (← getResultType)
|
||||
| .inc .. | .dec .. => unreachable!
|
||||
where
|
||||
/--
|
||||
Up to this point the type system of IR is quite loose so we can for example encounter situations
|
||||
|
|
@ -368,7 +369,7 @@ def run (decls : Array (Decl .impure)) : CompilerM (Array (Decl .impure)) := do
|
|||
public def explicitBoxing : Pass where
|
||||
phase := .impure
|
||||
phaseOut := .impure
|
||||
name := `boxing
|
||||
name := `explicitBoxing
|
||||
run := run
|
||||
|
||||
builtin_initialize
|
||||
|
|
|
|||
|
|
@ -143,6 +143,10 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F
|
|||
return Code.updateSset! c (← f fvarId) i offset (← f y) (← Expr.mapFVarM f ty) (← mapFVarM f k)
|
||||
| .uset fvarId offset y k _ =>
|
||||
return Code.updateUset! c (← f fvarId) offset (← f y) (← mapFVarM f k)
|
||||
| .inc fvarId n check persistent k _ =>
|
||||
return Code.updateInc! c (← f fvarId) n check persistent (← mapFVarM f k)
|
||||
| .dec fvarId n check persistent k _ =>
|
||||
return Code.updateDec! c (← f fvarId) n check persistent (← mapFVarM f k)
|
||||
|
||||
partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Unit := do
|
||||
match c with
|
||||
|
|
@ -178,6 +182,9 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Un
|
|||
f fvarId
|
||||
f y
|
||||
forFVarM f k
|
||||
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
|
||||
f fvarId
|
||||
forFVarM f k
|
||||
|
||||
instance : TraverseFVar (Code pu) where
|
||||
mapFVarM := Code.mapFVarM
|
||||
|
|
@ -204,6 +211,8 @@ instance : TraverseFVar (CodeDecl pu) where
|
|||
| .let decl => return .let (← mapFVarM f decl)
|
||||
| .uset var i y _ => return .uset (← f var) i (← f y)
|
||||
| .sset var i offset y ty _ => return .sset (← f var) i offset (← f y) (← mapFVarM f ty)
|
||||
| .inc fvarId n check persistent _ => return .inc (← f fvarId) n check persistent
|
||||
| .dec fvarId n check persistent _ => return .dec (← f fvarId) n check persistent
|
||||
forFVarM f decl :=
|
||||
match decl with
|
||||
| .fun decl _ => forFVarM f decl
|
||||
|
|
@ -211,6 +220,7 @@ instance : TraverseFVar (CodeDecl pu) where
|
|||
| .let decl => forFVarM f decl
|
||||
| .uset var i y _ => do f var; f y
|
||||
| .sset var i offset y ty _ => do f var; f y; forFVarM f ty
|
||||
| .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => f fvarId
|
||||
|
||||
instance : TraverseFVar (Alt pu) where
|
||||
mapFVarM f alt := do
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ where
|
|||
| .cases cs => cs.alts.forM (·.forCodeM (goCode declName))
|
||||
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => goCode declName k
|
||||
| .return .. | .jmp .. | .unreach .. => return ()
|
||||
| .inc .. | .dec .. => unreachable!
|
||||
|
||||
/--
|
||||
Apply the inferred borrow annotations from `map` to a SCC.
|
||||
|
|
@ -120,6 +121,7 @@ where
|
|||
| .cases cs => return code.updateAlts! <| ← cs.alts.mapM (·.mapCodeM (go declName))
|
||||
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return code.updateCont! (← go declName k)
|
||||
| .return .. | .jmp .. | .unreach .. => return code
|
||||
| .inc .. | .dec .. => unreachable!
|
||||
|
||||
structure Ctx where
|
||||
/--
|
||||
|
|
@ -298,6 +300,7 @@ where
|
|||
| .cases cs => cs.alts.forM (·.forCodeM collectCode)
|
||||
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => collectCode k
|
||||
| .return .. | .unreach .. => return ()
|
||||
| .inc .. | .dec .. => unreachable!
|
||||
|
||||
|
||||
public def inferBorrow : Pass where
|
||||
|
|
|
|||
|
|
@ -166,6 +166,12 @@ partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do
|
|||
withNormFVarResult (← normFVar fvarId) fun fvarId => do
|
||||
withNormFVarResult (← normFVar y) fun y => do
|
||||
return .uset fvarId offset y (← internalizeCode k)
|
||||
| .inc fvarId n check persistent k _ =>
|
||||
withNormFVarResult (← normFVar fvarId) fun fvarId => do
|
||||
return .inc fvarId n check persistent (← internalizeCode k)
|
||||
| .dec fvarId n check persistent k _ =>
|
||||
withNormFVarResult (← normFVar fvarId) fun fvarId => do
|
||||
return .dec fvarId n check persistent (← internalizeCode k)
|
||||
|
||||
end
|
||||
|
||||
|
|
@ -184,6 +190,12 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl
|
|||
let .fvar y ← normFVar y | unreachable!
|
||||
let ty ← normExpr ty
|
||||
return .sset var i offset y ty
|
||||
| .inc fvarId n check offset _ =>
|
||||
let .fvar fvarId ← normFVar fvarId | unreachable!
|
||||
return .inc fvarId n check offset
|
||||
| .dec fvarId n check offset _ =>
|
||||
let .fvar fvarId ← normFVar fvarId | unreachable!
|
||||
return .dec fvarId n check offset
|
||||
|
||||
|
||||
end Internalize
|
||||
|
|
|
|||
|
|
@ -77,8 +77,9 @@ mutual
|
|||
| .let decl k => eraseCode k <| lctx.eraseLetDecl decl
|
||||
| .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl
|
||||
| .cases c => eraseAlts c.alts lctx
|
||||
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => eraseCode k lctx
|
||||
| _ => lctx
|
||||
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
|
||||
eraseCode k lctx
|
||||
| .return .. | .jmp .. | .unreach .. => lctx
|
||||
end
|
||||
|
||||
@[inline]
|
||||
|
|
|
|||
|
|
@ -76,6 +76,8 @@ where
|
|||
go decl.value
|
||||
| .return var => visitVar var
|
||||
| .unreach .. => return false
|
||||
| .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. =>
|
||||
visitVar fvarId <||> go k
|
||||
|
||||
@[inline]
|
||||
visitVar (x : FVarId) : LiveM Bool :=
|
||||
|
|
|
|||
|
|
@ -149,6 +149,16 @@ mutual
|
|||
return f!"sset {← ppFVar fvarId} [{i}, {offset}] := {← ppFVar y} " ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .uset fvarId i y k _ =>
|
||||
return f!"uset {← ppFVar fvarId} [{i}] := {← ppFVar y} " ++ ";" ++ .line ++ (← ppCode k)
|
||||
| .inc fvarId n _ _ k _ =>
|
||||
if n != 1 then
|
||||
return f!"inc[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
|
||||
else
|
||||
return f!"inc {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
|
||||
| .dec fvarId n _ _ k _ =>
|
||||
if n != 1 then
|
||||
return f!"dec[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
|
||||
else
|
||||
return f!"dec {← ppFVar fvarId};" ++ .line ++ (← ppCode k)
|
||||
|
||||
|
||||
partial def ppDeclValue (b : DeclValue pu) : M Format := do
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ where
|
|||
go k
|
||||
| .cases cs => cs.alts.forM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return ()
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit :=
|
||||
decls.forM (·.value.forCodeM go)
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ where
|
|||
| .jp decl k => modify (·.push decl); go decl.value; go k
|
||||
| .cases cs => cs.alts.forM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return ()
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
|
||||
start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit :=
|
||||
decls.forM (·.value.forCodeM go)
|
||||
|
|
@ -86,7 +86,7 @@ where
|
|||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
|
||||
partial def filterByFun (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
|
|
@ -96,7 +96,7 @@ where
|
|||
| .fun decl k _ => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
|
||||
partial def filterByJp (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
|
|
@ -107,7 +107,7 @@ where
|
|||
| .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
|
||||
partial def filterByFunDecl (pu : Purity) (f : FunDecl pu → CompilerM Bool) :
|
||||
Probe (Decl pu) (Decl pu):=
|
||||
|
|
@ -118,7 +118,7 @@ where
|
|||
| .fun decl k _ | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k
|
||||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
|
||||
partial def filterByCases (pu : Purity) (f : Cases pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
|
|
@ -128,7 +128,7 @@ where
|
|||
| .fun decl k _ | .jp decl k => go decl.value <||> go k
|
||||
| .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return false
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
|
||||
partial def filterByJmp (pu : Purity) (f : FVarId → Array (Arg pu) → CompilerM Bool) :
|
||||
Probe (Decl pu) (Decl pu) :=
|
||||
|
|
@ -140,7 +140,7 @@ where
|
|||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp fn var => f fn var
|
||||
| .return .. | .unreach .. => return false
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
|
||||
partial def filterByReturn (pu : Purity) (f : FVarId → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
|
|
@ -151,7 +151,7 @@ where
|
|||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .unreach .. => return false
|
||||
| .return var => f var
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
|
||||
partial def filterByUnreach (pu : Purity) (f : Expr → CompilerM Bool) : Probe (Decl pu) (Decl pu) :=
|
||||
filter (·.value.isCodeAndM go)
|
||||
|
|
@ -162,7 +162,7 @@ where
|
|||
| .cases cs => cs.alts.anyM (go ·.getCode)
|
||||
| .jmp .. | .return .. => return false
|
||||
| .unreach typ => f typ
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k
|
||||
|
||||
@[inline]
|
||||
def declNames (pu : Purity) : Probe (Decl pu) Name :=
|
||||
|
|
|
|||
|
|
@ -137,6 +137,10 @@ where
|
|||
go k (decls.push (.uset var i y))
|
||||
| .sset var i offset y ty k _ =>
|
||||
go k (decls.push (.sset var i offset y ty))
|
||||
| .inc fvarId n check persistent k _ =>
|
||||
go k (decls.push (.inc fvarId n check persistent))
|
||||
| .dec fvarId n check persistent k _ =>
|
||||
go k (decls.push (.dec fvarId n check persistent))
|
||||
| .cases c => c.pushProjs decls
|
||||
| .jmp .. | .return .. | .unreach .. =>
|
||||
return attachCodeDecls decls c
|
||||
|
|
|
|||
|
|
@ -53,10 +53,8 @@ partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code
|
|||
| .ctorAlt _ k _ => return alt.updateCode (← k.applyRenaming r)
|
||||
return code.updateAlts! alts
|
||||
| .jmp .. | .unreach .. | .return .. => return code
|
||||
| .sset fvarId i offset y ty k _ =>
|
||||
return code.updateSset! fvarId i offset y ty (← k.applyRenaming r)
|
||||
| .uset fvarId offset y k _ =>
|
||||
return code.updateUset! fvarId offset y (← k.applyRenaming r)
|
||||
| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
|
||||
return code.updateCont! (← k.applyRenaming r)
|
||||
end
|
||||
|
||||
def Decl.applyRenaming (decl : Decl pu) (r : Renaming) : CompilerM (Decl pu) := do
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ where
|
|||
| .return .. | .jmp .. | .unreach .. => return (c, false)
|
||||
| .sset _ _ _ _ _ k _ | .uset _ _ _ k _ | .let _ k =>
|
||||
goK k
|
||||
| .inc .. | .dec .. => unreachable!
|
||||
|
||||
def isCtorUsing (instr : CodeDecl .impure) (x : FVarId) : Bool :=
|
||||
match instr with
|
||||
|
|
@ -241,6 +242,7 @@ where
|
|||
return (c.updateCont! k, false)
|
||||
| .return .. | .jmp .. | .unreach .. =>
|
||||
return (c, ← c.isFVarLiveIn x)
|
||||
| .inc .. | .dec .. => unreachable!
|
||||
|
||||
end
|
||||
|
||||
|
|
@ -273,6 +275,7 @@ partial def Code.insertResetReuse (c : Code .impure) : ReuseM (Code .impure) :=
|
|||
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ =>
|
||||
return c.updateCont! (← k.insertResetReuse)
|
||||
| .return .. | .jmp .. | .unreach .. => return c
|
||||
| .inc .. | .dec .. => unreachable!
|
||||
|
||||
partial def Decl.insertResetReuseCore (decl : Decl .impure) : ReuseM (Decl .impure) := do
|
||||
let value ← decl.value.mapCodeM fun code => do
|
||||
|
|
@ -295,6 +298,7 @@ where
|
|||
| .jp decl k => collectResets decl.value; collectResets k
|
||||
| .cases c => c.alts.forM (collectResets ·.getCode)
|
||||
| .unreach .. | .return .. | .jmp .. => return ()
|
||||
| .inc .. | .dec .. => unreachable!
|
||||
|
||||
|
||||
def Decl.insertResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ partial def Code.simpCase (code : Code .impure) : CompilerM (Code .impure) := do
|
|||
let decl ← decl.updateValue (← decl.value.simpCase)
|
||||
return code.updateFun! decl (← k.simpCase)
|
||||
| .return .. | .jmp .. | .unreach .. => return code
|
||||
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ =>
|
||||
| .let _ k | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
|
||||
return code.updateCont! (← k.simpCase)
|
||||
|
||||
def Decl.simpCase (decl : Decl .impure) : CompilerM (Decl .impure) := do
|
||||
|
|
|
|||
|
|
@ -34,7 +34,8 @@ where
|
|||
goCode k
|
||||
| .cases cases => cases.alts.forM (·.forCodeM goCode)
|
||||
| .jmp .. | .return .. | .unreach .. => return ()
|
||||
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => goCode k
|
||||
| .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. =>
|
||||
goCode k
|
||||
|
||||
end SplitScc
|
||||
|
||||
|
|
|
|||
|
|
@ -120,6 +120,14 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do
|
|||
let value := mkApp3 (mkConst `uset) (.fvar fvarId) (toExpr offset) (.fvar y)
|
||||
let body ← withFVar fvarId k.toExprM
|
||||
return .letE `dummy (mkConst ``Unit) value body true
|
||||
| .inc fvarId n check persistent k _ =>
|
||||
let value := mkApp4 (mkConst `inc) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
|
||||
let body ← withFVar fvarId k.toExprM
|
||||
return .letE `dummy (mkConst ``Unit) value body true
|
||||
| .dec fvarId n check persistent k _ =>
|
||||
let body ← withFVar fvarId k.toExprM
|
||||
let value := mkApp4 (mkConst `dec) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent)
|
||||
return .letE `dummy (mkConst ``Unit) value body true
|
||||
end
|
||||
|
||||
public def Code.toExpr (code : Code pu) (xs : Array FVarId := #[]) : Expr :=
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue