From ad64f7c1bab7725a4be97bdb90056a52076358e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Wed, 18 Feb 2026 11:55:16 +0100 Subject: [PATCH] feat: LCNF inc/dec instructions (#12550) This PR adds `inc`/`dec` instructions to LCNF. It should be a functional no-op. --- src/Lean/Compiler/IR/ToIR.lean | 6 ++ src/Lean/Compiler/LCNF/AlphaEqv.lean | 12 ++++ src/Lean/Compiler/LCNF/Basic.lean | 67 +++++++++++++++++++--- src/Lean/Compiler/LCNF/Bind.lean | 4 +- src/Lean/Compiler/LCNF/CompilerM.lean | 8 ++- src/Lean/Compiler/LCNF/DeclHash.lean | 4 ++ src/Lean/Compiler/LCNF/DependsOn.lean | 6 +- src/Lean/Compiler/LCNF/ElimDead.lean | 4 ++ src/Lean/Compiler/LCNF/ExplicitBoxing.lean | 3 +- src/Lean/Compiler/LCNF/FVarUtil.lean | 10 ++++ src/Lean/Compiler/LCNF/InferBorrow.lean | 3 + src/Lean/Compiler/LCNF/Internalize.lean | 12 ++++ src/Lean/Compiler/LCNF/LCtx.lean | 5 +- src/Lean/Compiler/LCNF/LiveVars.lean | 2 + src/Lean/Compiler/LCNF/PrettyPrinter.lean | 10 ++++ src/Lean/Compiler/LCNF/Probing.lean | 20 +++---- src/Lean/Compiler/LCNF/PushProj.lean | 4 ++ src/Lean/Compiler/LCNF/Renaming.lean | 6 +- src/Lean/Compiler/LCNF/ResetReuse.lean | 4 ++ src/Lean/Compiler/LCNF/SimpCase.lean | 2 +- src/Lean/Compiler/LCNF/SplitSCC.lean | 3 +- src/Lean/Compiler/LCNF/ToExpr.lean | 8 +++ 22 files changed, 172 insertions(+), 31 deletions(-) diff --git a/src/Lean/Compiler/IR/ToIR.lean b/src/Lean/Compiler/IR/ToIR.lean index 4140624080..552814e4a8 100644 --- a/src/Lean/Compiler/IR/ToIR.lean +++ b/src/Lean/Compiler/IR/ToIR.lean @@ -109,6 +109,12 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do let .var y ← getFVarValue y | unreachable! let .var var ← getFVarValue var | unreachable! return .uset var i y (← lowerCode k) + | .inc fvarId n check persistent k _ => + let .var var ← getFVarValue fvarId | unreachable! + return .inc var n check persistent (← lowerCode k) + | .dec fvarId n check persistent k _ => + let .var var ← getFVarValue fvarId | unreachable! + return .dec var n check persistent (← lowerCode k) | .fun .. => panic! "all local functions should be λ-lifted" partial def lowerLet (decl : LCNF.LetDecl .impure) (k : LCNF.Code .impure) : M FnBody := do diff --git a/src/Lean/Compiler/LCNF/AlphaEqv.lean b/src/Lean/Compiler/LCNF/AlphaEqv.lean index e1571dbcda..16490b22b6 100644 --- a/src/Lean/Compiler/LCNF/AlphaEqv.lean +++ b/src/Lean/Compiler/LCNF/AlphaEqv.lean @@ -155,6 +155,18 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do eqvFVar var₁ var₂ <&&> eqvFVar y₁ y₂ <&&> eqv k₁ k₂ + | .inc fvarId₁ n₁ c₁ p₁ k₁ _, .inc fvarId₂ n₂ c₂ p₂ k₂ _ => + pure (n₁ == n₂) <&&> + pure (c₁ == c₂) <&&> + pure (p₁ == p₂) <&&> + eqvFVar fvarId₁ fvarId₂ <&&> + eqv k₁ k₂ + | .dec fvarId₁ n₁ c₁ p₁ k₁ _, .dec fvarId₂ n₂ c₂ p₂ k₂ _ => + pure (n₁ == n₂) <&&> + pure (c₁ == c₂) <&&> + pure (p₁ == p₂) <&&> + eqvFVar fvarId₁ fvarId₂ <&&> + eqv k₁ k₂ | _, _ => return false end diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index 7e86fd69a4..ca11a14a09 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -363,6 +363,8 @@ inductive Code (pu : Purity) where | unreach (type : Expr) | uset (var : FVarId) (i : Nat) (y : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac) | sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac) + | inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac) + | dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac) deriving Inhabited end @@ -440,11 +442,13 @@ inductive CodeDecl (pu : Purity) where | jp (decl : FunDecl pu) | uset (var : FVarId) (i : Nat) (y : FVarId) (h : pu = .impure := by purity_tac) | sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac) + | inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac) + | dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac) deriving Inhabited def CodeDecl.fvarId : CodeDecl pu → FVarId | .let decl | .fun decl _ | .jp decl => decl.fvarId - | .uset var .. | .sset var .. => var + | .uset fvarId .. | .sset fvarId .. | .inc fvarId .. | .dec fvarId .. => fvarId def Code.toCodeDecl! : Code pu → CodeDecl pu | .let decl _ => .let decl @@ -452,6 +456,8 @@ def Code.toCodeDecl! : Code pu → CodeDecl pu | .jp decl _ => .jp decl | .uset var i y _ _ => .uset var i y | .sset var i offset ty y _ _ => .sset var i offset ty y +| .inc fvarId n check persistent _ _ => .inc fvarId n check persistent +| .dec fvarId n check persistent _ _ => .dec fvarId n check persistent | _ => unreachable! def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu := @@ -465,6 +471,8 @@ where | .jp decl => go (i-1) (.jp decl code) | .uset var idx y _ => go (i-1) (.uset var idx y code) | .sset var idx offset y ty _ => go (i-1) (.sset var idx offset y ty code) + | .inc fvarId n check persistent _ => go (i-1) (.inc fvarId n check persistent code) + | .dec fvarId n check persistent _ => go (i-1) (.dec fvarId n check persistent code) else code @@ -484,6 +492,10 @@ mutual v₁ == v₂ && i₁ == i₂ && y₁ == y₂ && eqImp k₁ k₂ | .sset v₁ i₁ o₁ y₁ ty₁ k₁ _, .sset v₂ i₂ o₂ y₂ ty₂ k₂ _ => v₁ == v₂ && i₁ == i₂ && o₁ == o₂ && y₁ == y₂ && ty₁ == ty₂ && eqImp k₁ k₂ + | .inc v₁ n₁ c₁ p₁ k₁ _, .inc v₂ n₂ c₂ p₂ k₂ _ => + v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂ + | .dec v₁ n₁ c₁ p₁ k₁ _, .dec v₂ n₂ c₂ p₂ k₂ _ => + v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂ | _, _ => false private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool := @@ -578,6 +590,8 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co | .jp decl k => if ptrEq k k' then c else .jp decl k' | .sset fvarId i offset y ty k _ => if ptrEq k k' then c else .sset fvarId i offset y ty k' | .uset fvarId offset y k _ => if ptrEq k k' then c else .uset fvarId offset y k' + | .inc fvarId n check persistent k _ => if ptrEq k k' then c else .inc fvarId n check persistent k' + | .dec fvarId n check persistent k _ => if ptrEq k k' then c else .dec fvarId n check persistent k' | _ => unreachable! @[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu @@ -637,6 +651,40 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co @[implemented_by updateUsetImp] opaque Code.updateUset! (c : Code pu) (fvarId' : FVarId) (i' : Nat) (y' : FVarId) (k' : Code pu) : Code pu +@[inline] private unsafe def updateIncImp (c : Code pu) (fvarId' : FVarId) (n' : Nat) + (check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu := + match c with + | .inc fvarId n check persistent k _ => + if ptrEq fvarId fvarId' + && n == n' + && check == check' + && persistent == persistent' + && ptrEq k k' then + c + else + .inc fvarId' n' check' persistent' k' + | _ => unreachable! + +@[implemented_by updateIncImp] opaque Code.updateInc! (c : Code pu) (fvarId' : FVarId) (n' : Nat) + (check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu + +@[inline] private unsafe def updateDecImp (c : Code pu) (fvarId' : FVarId) (n' : Nat) + (check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu := + match c with + | .dec fvarId n check persistent k _ => + if ptrEq fvarId fvarId' + && n == n' + && check == check' + && persistent == persistent' + && ptrEq k k' then + c + else + .dec fvarId' n' check' persistent' k' + | _ => unreachable! + +@[implemented_by updateDecImp] opaque Code.updateDec! (c : Code pu) (fvarId' : FVarId) (n' : Nat) + (check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu + private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu := if ptrEq type p.type then p @@ -705,7 +753,8 @@ partial def Code.size (c : Code pu) : Nat := where go (c : Code pu) (n : Nat) : Nat := match c with - | .let _ k | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k (n + 1) + | .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. + | .dec (k := k) .. => go k (n + 1) | .jp decl k | .fun decl k _ => go k <| go decl.value n | .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1) | .jmp .. => n+1 @@ -723,7 +772,8 @@ where go (c : Code pu) : EStateM Unit Nat Unit := do match c with - | .let _ k | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => inc; go k + | .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. + | .dec (k := k) .. => inc; go k | .jp decl k | .fun decl k _ => inc; go decl.value; go k | .cases c => inc; c.alts.forM fun alt => go alt.getCode | .jmp .. => inc @@ -735,7 +785,8 @@ where go (c : Code pu) : m Unit := do f c match c with - | .let _ k | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. + | .dec (k := k) .. => go k | .fun decl k _ | .jp decl k => go decl.value; go k | .cases c => c.alts.forM fun alt => go alt.getCode | .unreach .. | .return .. | .jmp .. => return () @@ -1031,6 +1082,8 @@ partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarId let s := s.insert var let s := s.insert y k.collectUsed s + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + k.collectUsed <| s.insert fvarId end @[inline] def collectUsedAtExpr (s : FVarIdHashSet) (e : Expr) : FVarIdHashSet := @@ -1040,8 +1093,8 @@ def CodeDecl.collectUsed (codeDecl : CodeDecl pu) (s : FVarIdHashSet := ∅) : F match codeDecl with | .let decl => collectLetValue decl.value <| collectType decl.type s | .jp decl | .fun decl _ => decl.collectUsed s - | .sset var _ _ y ty _ => s.insert var |>.insert y |> collectType ty - | .uset var _ y _ => s.insert var |>.insert y + | .sset (var := var) (y := y) .. | .uset (var := var) (y := y) .. => s.insert var |>.insert y + | .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.insert fvarId /-- Traverse the given block of potentially mutually recursive functions @@ -1071,7 +1124,7 @@ where modify fun s => s.insert declName | _ => pure () visit k - | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => visit k + | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => visit k go : StateM NameSet Unit := decls.forM (·.value.forCodeM visit) diff --git a/src/Lean/Compiler/LCNF/Bind.lean b/src/Lean/Compiler/LCNF/Bind.lean index e24ea42a2f..6e8a323dbf 100644 --- a/src/Lean/Compiler/LCNF/Bind.lean +++ b/src/Lean/Compiler/LCNF/Bind.lean @@ -68,8 +68,8 @@ where eraseCode k eraseParam auxParam return .unreach typeNew - | .sset fvarId i offset y ty k _ => return .sset fvarId i offset y ty (← go k) - | .uset fvarId offset y k _ => return .uset fvarId offset y (← go k) + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => + return c.updateCont! (← go k) instance : MonadCodeBind CompilerM where codeBind := CompilerM.codeBind diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index 3615c758c8..f12eb74c64 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -149,7 +149,7 @@ def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do match decl with | .let decl => eraseLetDecl decl | .jp decl | .fun decl _ => eraseFunDecl decl - | .sset .. | .uset .. => return () + | .sset .. | .uset .. | .inc .. | .dec .. => return () /-- Erase all free variables occurring in `decls` from the local context. @@ -501,6 +501,12 @@ mutual withNormFVarResult (← normFVar fvarId) fun fvarId => do withNormFVarResult (← normFVar y) fun y => do return code.updateUset! fvarId offset y (← normCodeImp k) + | .inc fvarId n check persistent k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + return code.updateInc! fvarId n check persistent (← normCodeImp k) + | .dec fvarId n check persistent k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + return code.updateDec! fvarId n check persistent (← normCodeImp k) end @[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do diff --git a/src/Lean/Compiler/LCNF/DeclHash.lean b/src/Lean/Compiler/LCNF/DeclHash.lean index fc52d89a92..e47e601117 100644 --- a/src/Lean/Compiler/LCNF/DeclHash.lean +++ b/src/Lean/Compiler/LCNF/DeclHash.lean @@ -41,6 +41,10 @@ partial def hashCode (code : Code pu) : UInt64 := mixHash (mixHash (hash fvarId) (hash i)) (mixHash (mixHash (hash offset) (hash y)) (mixHash (hash ty) (hashCode k))) | .uset fvarId offset y k _ => mixHash (mixHash (hash fvarId) (hash offset)) (mixHash (hash y) (hashCode k)) + | .inc fvarId n check persistent k _ => + mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k)) + | .dec fvarId n check persistent k _ => + mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k)) end diff --git a/src/Lean/Compiler/LCNF/DependsOn.lean b/src/Lean/Compiler/LCNF/DependsOn.lean index 9875fb3bdf..f4ca834a5d 100644 --- a/src/Lean/Compiler/LCNF/DependsOn.lean +++ b/src/Lean/Compiler/LCNF/DependsOn.lean @@ -47,6 +47,8 @@ private partial def depOn (c : Code pu) : M Bool := | .return fvarId => fvarDepOn fvarId | .unreach _ => return false | .sset fv1 _ _ fv2 _ k _ | .uset fv1 _ fv2 k _ => fvarDepOn fv1 <||> fvarDepOn fv2 <||> depOn k + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + fvarDepOn fvarId <||> depOn k @[inline] def Arg.dependsOn (arg : Arg pu) (s : FVarIdSet) : Bool := argDepOn arg s @@ -64,8 +66,8 @@ def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool := match decl with | .let decl => decl.dependsOn s | .jp decl | .fun decl _ => decl.dependsOn s - | .uset var _ y _ => s.contains var || s.contains y - | .sset var _ _ y ty _ => s.contains var || s.contains y || (typeDepOn ty s) + | .uset (var := var) (y := y) .. | .sset (var := var) (y := y) .. => s.contains var || s.contains y + | .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.contains fvarId /-- Return `true` is `c` depends on a free variable in `s`. diff --git a/src/Lean/Compiler/LCNF/ElimDead.lean b/src/Lean/Compiler/LCNF/ElimDead.lean index 76403b2141..863624dde1 100644 --- a/src/Lean/Compiler/LCNF/ElimDead.lean +++ b/src/Lean/Compiler/LCNF/ElimDead.lean @@ -102,6 +102,10 @@ partial def Code.elimDead (code : Code pu) : M (Code pu) := do return code.updateCont! k else return k + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + let k ← k.elimDead + collectFVarM fvarId + return code.updateCont! k end diff --git a/src/Lean/Compiler/LCNF/ExplicitBoxing.lean b/src/Lean/Compiler/LCNF/ExplicitBoxing.lean index d6187dce15..d35a1db04b 100644 --- a/src/Lean/Compiler/LCNF/ExplicitBoxing.lean +++ b/src/Lean/Compiler/LCNF/ExplicitBoxing.lean @@ -292,6 +292,7 @@ partial def Code.explicitBoxing (code : Code .impure) : BoxM (Code .impure) := d let some jpDecl ← findFunDecl? fvarId | unreachable! castArgsIfNeeded args jpDecl.params fun args => return code.updateJmp! fvarId args | .unreach .. => return code.updateUnreach! (← getResultType) + | .inc .. | .dec .. => unreachable! where /-- Up to this point the type system of IR is quite loose so we can for example encounter situations @@ -368,7 +369,7 @@ def run (decls : Array (Decl .impure)) : CompilerM (Array (Decl .impure)) := do public def explicitBoxing : Pass where phase := .impure phaseOut := .impure - name := `boxing + name := `explicitBoxing run := run builtin_initialize diff --git a/src/Lean/Compiler/LCNF/FVarUtil.lean b/src/Lean/Compiler/LCNF/FVarUtil.lean index 7f0c1e67c6..27fe132b03 100644 --- a/src/Lean/Compiler/LCNF/FVarUtil.lean +++ b/src/Lean/Compiler/LCNF/FVarUtil.lean @@ -143,6 +143,10 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F return Code.updateSset! c (← f fvarId) i offset (← f y) (← Expr.mapFVarM f ty) (← mapFVarM f k) | .uset fvarId offset y k _ => return Code.updateUset! c (← f fvarId) offset (← f y) (← mapFVarM f k) + | .inc fvarId n check persistent k _ => + return Code.updateInc! c (← f fvarId) n check persistent (← mapFVarM f k) + | .dec fvarId n check persistent k _ => + return Code.updateDec! c (← f fvarId) n check persistent (← mapFVarM f k) partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Unit := do match c with @@ -178,6 +182,9 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Un f fvarId f y forFVarM f k + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + f fvarId + forFVarM f k instance : TraverseFVar (Code pu) where mapFVarM := Code.mapFVarM @@ -204,6 +211,8 @@ instance : TraverseFVar (CodeDecl pu) where | .let decl => return .let (← mapFVarM f decl) | .uset var i y _ => return .uset (← f var) i (← f y) | .sset var i offset y ty _ => return .sset (← f var) i offset (← f y) (← mapFVarM f ty) + | .inc fvarId n check persistent _ => return .inc (← f fvarId) n check persistent + | .dec fvarId n check persistent _ => return .dec (← f fvarId) n check persistent forFVarM f decl := match decl with | .fun decl _ => forFVarM f decl @@ -211,6 +220,7 @@ instance : TraverseFVar (CodeDecl pu) where | .let decl => forFVarM f decl | .uset var i y _ => do f var; f y | .sset var i offset y ty _ => do f var; f y; forFVarM f ty + | .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => f fvarId instance : TraverseFVar (Alt pu) where mapFVarM f alt := do diff --git a/src/Lean/Compiler/LCNF/InferBorrow.lean b/src/Lean/Compiler/LCNF/InferBorrow.lean index ec1a25eb7e..9e7ff0b5be 100644 --- a/src/Lean/Compiler/LCNF/InferBorrow.lean +++ b/src/Lean/Compiler/LCNF/InferBorrow.lean @@ -91,6 +91,7 @@ where | .cases cs => cs.alts.forM (·.forCodeM (goCode declName)) | .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => goCode declName k | .return .. | .jmp .. | .unreach .. => return () + | .inc .. | .dec .. => unreachable! /-- Apply the inferred borrow annotations from `map` to a SCC. @@ -120,6 +121,7 @@ where | .cases cs => return code.updateAlts! <| ← cs.alts.mapM (·.mapCodeM (go declName)) | .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return code.updateCont! (← go declName k) | .return .. | .jmp .. | .unreach .. => return code + | .inc .. | .dec .. => unreachable! structure Ctx where /-- @@ -298,6 +300,7 @@ where | .cases cs => cs.alts.forM (·.forCodeM collectCode) | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => collectCode k | .return .. | .unreach .. => return () + | .inc .. | .dec .. => unreachable! public def inferBorrow : Pass where diff --git a/src/Lean/Compiler/LCNF/Internalize.lean b/src/Lean/Compiler/LCNF/Internalize.lean index 67224bf04b..07048ffc59 100644 --- a/src/Lean/Compiler/LCNF/Internalize.lean +++ b/src/Lean/Compiler/LCNF/Internalize.lean @@ -166,6 +166,12 @@ partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do withNormFVarResult (← normFVar fvarId) fun fvarId => do withNormFVarResult (← normFVar y) fun y => do return .uset fvarId offset y (← internalizeCode k) + | .inc fvarId n check persistent k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + return .inc fvarId n check persistent (← internalizeCode k) + | .dec fvarId n check persistent k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + return .dec fvarId n check persistent (← internalizeCode k) end @@ -184,6 +190,12 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl let .fvar y ← normFVar y | unreachable! let ty ← normExpr ty return .sset var i offset y ty + | .inc fvarId n check offset _ => + let .fvar fvarId ← normFVar fvarId | unreachable! + return .inc fvarId n check offset + | .dec fvarId n check offset _ => + let .fvar fvarId ← normFVar fvarId | unreachable! + return .dec fvarId n check offset end Internalize diff --git a/src/Lean/Compiler/LCNF/LCtx.lean b/src/Lean/Compiler/LCNF/LCtx.lean index 86e6df70c2..0a8bc28f24 100644 --- a/src/Lean/Compiler/LCNF/LCtx.lean +++ b/src/Lean/Compiler/LCNF/LCtx.lean @@ -77,8 +77,9 @@ mutual | .let decl k => eraseCode k <| lctx.eraseLetDecl decl | .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl | .cases c => eraseAlts c.alts lctx - | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => eraseCode k lctx - | _ => lctx + | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => + eraseCode k lctx + | .return .. | .jmp .. | .unreach .. => lctx end @[inline] diff --git a/src/Lean/Compiler/LCNF/LiveVars.lean b/src/Lean/Compiler/LCNF/LiveVars.lean index 7d9e004d76..db2f95db72 100644 --- a/src/Lean/Compiler/LCNF/LiveVars.lean +++ b/src/Lean/Compiler/LCNF/LiveVars.lean @@ -76,6 +76,8 @@ where go decl.value | .return var => visitVar var | .unreach .. => return false + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + visitVar fvarId <||> go k @[inline] visitVar (x : FVarId) : LiveM Bool := diff --git a/src/Lean/Compiler/LCNF/PrettyPrinter.lean b/src/Lean/Compiler/LCNF/PrettyPrinter.lean index 49fd92bb76..b0786d34a4 100644 --- a/src/Lean/Compiler/LCNF/PrettyPrinter.lean +++ b/src/Lean/Compiler/LCNF/PrettyPrinter.lean @@ -149,6 +149,16 @@ mutual return f!"sset {← ppFVar fvarId} [{i}, {offset}] := {← ppFVar y} " ++ ";" ++ .line ++ (← ppCode k) | .uset fvarId i y k _ => return f!"uset {← ppFVar fvarId} [{i}] := {← ppFVar y} " ++ ";" ++ .line ++ (← ppCode k) + | .inc fvarId n _ _ k _ => + if n != 1 then + return f!"inc[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k) + else + return f!"inc {← ppFVar fvarId};" ++ .line ++ (← ppCode k) + | .dec fvarId n _ _ k _ => + if n != 1 then + return f!"dec[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k) + else + return f!"dec {← ppFVar fvarId};" ++ .line ++ (← ppCode k) partial def ppDeclValue (b : DeclValue pu) : M Format := do diff --git a/src/Lean/Compiler/LCNF/Probing.lean b/src/Lean/Compiler/LCNF/Probing.lean index ba943ffca6..be5bdbf661 100644 --- a/src/Lean/Compiler/LCNF/Probing.lean +++ b/src/Lean/Compiler/LCNF/Probing.lean @@ -58,7 +58,7 @@ where go k | .cases cs => cs.alts.forM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return () - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit := decls.forM (·.value.forCodeM go) @@ -73,7 +73,7 @@ where | .jp decl k => modify (·.push decl); go decl.value; go k | .cases cs => cs.alts.forM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return () - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit := decls.forM (·.value.forCodeM go) @@ -86,7 +86,7 @@ where | .fun decl k _ | .jp decl k => go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k partial def filterByFun (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -96,7 +96,7 @@ where | .fun decl k _ => do if (← f decl) then return true else go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k partial def filterByJp (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -107,7 +107,7 @@ where | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k partial def filterByFunDecl (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu):= @@ -118,7 +118,7 @@ where | .fun decl k _ | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k partial def filterByCases (pu : Purity) (f : Cases pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -128,7 +128,7 @@ where | .fun decl k _ | .jp decl k => go decl.value <||> go k | .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k partial def filterByJmp (pu : Purity) (f : FVarId → Array (Arg pu) → CompilerM Bool) : Probe (Decl pu) (Decl pu) := @@ -140,7 +140,7 @@ where | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp fn var => f fn var | .return .. | .unreach .. => return false - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k partial def filterByReturn (pu : Purity) (f : FVarId → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -151,7 +151,7 @@ where | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .unreach .. => return false | .return var => f var - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k partial def filterByUnreach (pu : Purity) (f : Expr → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -162,7 +162,7 @@ where | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. => return false | .unreach typ => f typ - | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k @[inline] def declNames (pu : Purity) : Probe (Decl pu) Name := diff --git a/src/Lean/Compiler/LCNF/PushProj.lean b/src/Lean/Compiler/LCNF/PushProj.lean index 4adb6fa90c..4c01cc3a79 100644 --- a/src/Lean/Compiler/LCNF/PushProj.lean +++ b/src/Lean/Compiler/LCNF/PushProj.lean @@ -137,6 +137,10 @@ where go k (decls.push (.uset var i y)) | .sset var i offset y ty k _ => go k (decls.push (.sset var i offset y ty)) + | .inc fvarId n check persistent k _ => + go k (decls.push (.inc fvarId n check persistent)) + | .dec fvarId n check persistent k _ => + go k (decls.push (.dec fvarId n check persistent)) | .cases c => c.pushProjs decls | .jmp .. | .return .. | .unreach .. => return attachCodeDecls decls c diff --git a/src/Lean/Compiler/LCNF/Renaming.lean b/src/Lean/Compiler/LCNF/Renaming.lean index 714d66c48f..91c25df0cb 100644 --- a/src/Lean/Compiler/LCNF/Renaming.lean +++ b/src/Lean/Compiler/LCNF/Renaming.lean @@ -53,10 +53,8 @@ partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code | .ctorAlt _ k _ => return alt.updateCode (← k.applyRenaming r) return code.updateAlts! alts | .jmp .. | .unreach .. | .return .. => return code - | .sset fvarId i offset y ty k _ => - return code.updateSset! fvarId i offset y ty (← k.applyRenaming r) - | .uset fvarId offset y k _ => - return code.updateUset! fvarId offset y (← k.applyRenaming r) + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => + return code.updateCont! (← k.applyRenaming r) end def Decl.applyRenaming (decl : Decl pu) (r : Renaming) : CompilerM (Decl pu) := do diff --git a/src/Lean/Compiler/LCNF/ResetReuse.lean b/src/Lean/Compiler/LCNF/ResetReuse.lean index c00613efac..562fe46d96 100644 --- a/src/Lean/Compiler/LCNF/ResetReuse.lean +++ b/src/Lean/Compiler/LCNF/ResetReuse.lean @@ -120,6 +120,7 @@ where | .return .. | .jmp .. | .unreach .. => return (c, false) | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ | .let _ k => goK k + | .inc .. | .dec .. => unreachable! def isCtorUsing (instr : CodeDecl .impure) (x : FVarId) : Bool := match instr with @@ -241,6 +242,7 @@ where return (c.updateCont! k, false) | .return .. | .jmp .. | .unreach .. => return (c, ← c.isFVarLiveIn x) + | .inc .. | .dec .. => unreachable! end @@ -273,6 +275,7 @@ partial def Code.insertResetReuse (c : Code .impure) : ReuseM (Code .impure) := | .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return c.updateCont! (← k.insertResetReuse) | .return .. | .jmp .. | .unreach .. => return c + | .inc .. | .dec .. => unreachable! partial def Decl.insertResetReuseCore (decl : Decl .impure) : ReuseM (Decl .impure) := do let value ← decl.value.mapCodeM fun code => do @@ -295,6 +298,7 @@ where | .jp decl k => collectResets decl.value; collectResets k | .cases c => c.alts.forM (collectResets ·.getCode) | .unreach .. | .return .. | .jmp .. => return () + | .inc .. | .dec .. => unreachable! def Decl.insertResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do diff --git a/src/Lean/Compiler/LCNF/SimpCase.lean b/src/Lean/Compiler/LCNF/SimpCase.lean index 23d7a0b3e6..a213ce3636 100644 --- a/src/Lean/Compiler/LCNF/SimpCase.lean +++ b/src/Lean/Compiler/LCNF/SimpCase.lean @@ -107,7 +107,7 @@ partial def Code.simpCase (code : Code .impure) : CompilerM (Code .impure) := do let decl ← decl.updateValue (← decl.value.simpCase) return code.updateFun! decl (← k.simpCase) | .return .. | .jmp .. | .unreach .. => return code - | .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => + | .let _ k | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => return code.updateCont! (← k.simpCase) def Decl.simpCase (decl : Decl .impure) : CompilerM (Decl .impure) := do diff --git a/src/Lean/Compiler/LCNF/SplitSCC.lean b/src/Lean/Compiler/LCNF/SplitSCC.lean index eaad3fa621..e0e4429e2c 100644 --- a/src/Lean/Compiler/LCNF/SplitSCC.lean +++ b/src/Lean/Compiler/LCNF/SplitSCC.lean @@ -34,7 +34,8 @@ where goCode k | .cases cases => cases.alts.forM (·.forCodeM goCode) | .jmp .. | .return .. | .unreach .. => return () - | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => goCode k + | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => + goCode k end SplitScc diff --git a/src/Lean/Compiler/LCNF/ToExpr.lean b/src/Lean/Compiler/LCNF/ToExpr.lean index 4651a41040..4fe111bf8e 100644 --- a/src/Lean/Compiler/LCNF/ToExpr.lean +++ b/src/Lean/Compiler/LCNF/ToExpr.lean @@ -120,6 +120,14 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do let value := mkApp3 (mkConst `uset) (.fvar fvarId) (toExpr offset) (.fvar y) let body ← withFVar fvarId k.toExprM return .letE `dummy (mkConst ``Unit) value body true + | .inc fvarId n check persistent k _ => + let value := mkApp4 (mkConst `inc) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent) + let body ← withFVar fvarId k.toExprM + return .letE `dummy (mkConst ``Unit) value body true + | .dec fvarId n check persistent k _ => + let body ← withFVar fvarId k.toExprM + let value := mkApp4 (mkConst `dec) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent) + return .letE `dummy (mkConst ``Unit) value body true end public def Code.toExpr (code : Code pu) (xs : Array FVarId := #[]) : Expr :=