From e96d969d5935c1b026c0e80d60626c8da88d0922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Wed, 25 Feb 2026 11:43:15 +0100 Subject: [PATCH] feat: support for del, isShared, oset and setTag (#12687) This PR implements the LCNF instructions required for the expand reset reuse pass. --- src/Lean/Compiler/IR/ToIR.lean | 13 +++ src/Lean/Compiler/LCNF/AlphaEqv.lean | 13 +++ src/Lean/Compiler/LCNF/Basic.lean | 119 +++++++++++++++++---- src/Lean/Compiler/LCNF/Bind.lean | 3 +- src/Lean/Compiler/LCNF/CompilerM.lean | 16 ++- src/Lean/Compiler/LCNF/DeclHash.lean | 6 ++ src/Lean/Compiler/LCNF/DependsOn.lean | 19 +++- src/Lean/Compiler/LCNF/ElimDead.lean | 15 ++- src/Lean/Compiler/LCNF/ExplicitBoxing.lean | 6 +- src/Lean/Compiler/LCNF/ExplicitRC.lean | 6 +- src/Lean/Compiler/LCNF/FVarUtil.lean | 23 +++- src/Lean/Compiler/LCNF/InferBorrow.lean | 6 +- src/Lean/Compiler/LCNF/Internalize.lean | 26 ++++- src/Lean/Compiler/LCNF/LCtx.lean | 3 +- src/Lean/Compiler/LCNF/LiveVars.lean | 5 +- src/Lean/Compiler/LCNF/PrettyPrinter.lean | 7 ++ src/Lean/Compiler/LCNF/Probing.lean | 30 ++++-- src/Lean/Compiler/LCNF/PushProj.lean | 6 ++ src/Lean/Compiler/LCNF/Renaming.lean | 3 +- src/Lean/Compiler/LCNF/ResetReuse.lean | 8 +- src/Lean/Compiler/LCNF/SimpCase.lean | 3 +- src/Lean/Compiler/LCNF/ToExpr.lean | 12 +++ 22 files changed, 288 insertions(+), 60 deletions(-) diff --git a/src/Lean/Compiler/IR/ToIR.lean b/src/Lean/Compiler/IR/ToIR.lean index 7fdab2bff3..5a24dc6a8c 100644 --- a/src/Lean/Compiler/IR/ToIR.lean +++ b/src/Lean/Compiler/IR/ToIR.lean @@ -101,6 +101,10 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do let ret ← getFVarValue fvarId return .ret ret | .unreach .. => return .unreachable + | .oset fvarId i y k _ => + let y ← lowerArg y + let .var fvarId ← getFVarValue fvarId | unreachable! + return .set fvarId i y (← lowerCode k) | .sset fvarId i offset y type k _ => let .var y ← getFVarValue y | unreachable! let .var fvarId ← getFVarValue fvarId | unreachable! @@ -109,12 +113,18 @@ partial def lowerCode (c : LCNF.Code .impure) : M FnBody := do let .var y ← getFVarValue y | unreachable! let .var fvarId ← getFVarValue fvarId | unreachable! return .uset fvarId i y (← lowerCode k) + | .setTag fvarId cidx k _ => + let .var var ← getFVarValue fvarId | unreachable! + return .setTag var cidx (← lowerCode k) | .inc fvarId n check persistent k _ => let .var var ← getFVarValue fvarId | unreachable! return .inc var n check persistent (← lowerCode k) | .dec fvarId n check persistent k _ => let .var var ← getFVarValue fvarId | unreachable! return .dec var n check persistent (← lowerCode k) + | .del fvarId k _ => + let .var var ← getFVarValue fvarId | unreachable! + return .del var (← lowerCode k) | .fun .. => panic! "all local functions should be λ-lifted" partial def lowerLet (decl : LCNF.LetDecl .impure) (k : LCNF.Code .impure) : M FnBody := do @@ -155,6 +165,9 @@ partial def lowerLet (decl : LCNF.LetDecl .impure) (k : LCNF.Code .impure) : M F | .unbox var => withGetFVarValue var fun var => do continueLet (.unbox var) + | .isShared var => + withGetFVarValue var fun var => do + continueLet (.isShared var) | .erased => mkErased () where mkErased (_ : Unit) : M FnBody := do diff --git a/src/Lean/Compiler/LCNF/AlphaEqv.lean b/src/Lean/Compiler/LCNF/AlphaEqv.lean index ddf835ad84..03c8798a5e 100644 --- a/src/Lean/Compiler/LCNF/AlphaEqv.lean +++ b/src/Lean/Compiler/LCNF/AlphaEqv.lean @@ -75,6 +75,7 @@ def eqvLetValue (e₁ e₂ : LetValue pu) : EqvM Bool := do pure (i₁ == i₂ && u₁ == u₂) <&&> eqvFVar v₁ v₂ <&&> eqvArgs as₁ as₂ | .box ty₁ v₁ _, .box ty₂ v₂ _ => eqvType ty₁ ty₂ <&&> eqvFVar v₁ v₂ | .unbox v₁ _, .unbox v₂ _ => eqvFVar v₁ v₂ + | .isShared v₁ _, .isShared v₂ _ => eqvFVar v₁ v₂ | _, _ => return false @[inline] def withFVar (fvarId₁ fvarId₂ : FVarId) (x : EqvM α) : EqvM α := @@ -143,6 +144,11 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do eqvFVar c₁.discr c₂.discr <&&> eqvType c₁.resultType c₂.resultType <&&> eqvAlts c₁.alts c₂.alts + | .oset fvarId₁ i₁ y₁ k₁ _, .oset fvarId₂ i₂ y₂ k₂ _ => + pure (i₁ == i₂) <&&> + eqvFVar fvarId₁ fvarId₂ <&&> + eqvArg y₁ y₂ <&&> + eqv k₁ k₂ | .sset fvarId₁ i₁ offset₁ y₁ ty₁ k₁ _, .sset fvarId₂ i₂ offset₂ y₂ ty₂ k₂ _ => pure (i₁ == i₂) <&&> pure (offset₁ == offset₂) <&&> @@ -155,6 +161,10 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do eqvFVar fvarId₁ fvarId₂ <&&> eqvFVar y₁ y₂ <&&> eqv k₁ k₂ + | .setTag fvarId₁ c₁ k₁ _, .setTag fvarId₂ c₂ k₂ _ => + pure (c₁ == c₂) <&&> + eqvFVar fvarId₁ fvarId₂ <&&> + eqv k₁ k₂ | .inc fvarId₁ n₁ c₁ p₁ k₁ _, .inc fvarId₂ n₂ c₂ p₂ k₂ _ => pure (n₁ == n₂) <&&> pure (c₁ == c₂) <&&> @@ -167,6 +177,9 @@ partial def eqv (code₁ code₂ : Code pu) : EqvM Bool := do pure (p₁ == p₂) <&&> eqvFVar fvarId₁ fvarId₂ <&&> eqv k₁ k₂ + | .del fvarId₁ k₁ _, .del fvarId₂ k₂ _ => + eqvFVar fvarId₁ fvarId₂ <&&> + eqv k₁ k₂ | _, _ => return false end diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index 77c085edf2..6cff07bec6 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -219,6 +219,10 @@ inductive LetValue (pu : Purity) where | box (ty : Expr) (fvarId : FVarId) (h : pu = .impure := by purity_tac) /-- Given `fvarId : [t]object`, obtain the underlying scalar value. -/ | unbox (fvarId : FVarId) (h : pu = .impure := by purity_tac) + /-- + Return whether the object stored behind `fvarId` is shared or not. The return type is a `UInt8`. + -/ + | isShared (fvarId : FVarId) (h : pu = .impure := by purity_tac) deriving Inhabited, BEq, Hashable def Arg.toLetValue (arg : Arg pu) : LetValue pu := @@ -298,7 +302,12 @@ private unsafe def LetValue.updateUnboxImp (e : LetValue pu) (fvarId' : FVarId) @[implemented_by LetValue.updateUnboxImp] opaque LetValue.updateUnbox! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu +private unsafe def LetValue.updateIsSharedImp (e : LetValue pu) (fvarId' : FVarId) : LetValue pu := + match e with + | .isShared fvarId _ => if fvarId == fvarId' then e else .isShared fvarId' + | _ => unreachable! +@[implemented_by LetValue.updateIsSharedImp] opaque LetValue.updateIsShared! (e : LetValue pu) (fvarId' : FVarId) : LetValue pu private unsafe def LetValue.updateArgsImp (e : LetValue pu) (args' : Array (Arg pu)) : LetValue pu := match e with @@ -331,6 +340,7 @@ def LetValue.toExpr (e : LetValue pu) : Expr := #[.fvar var, .const i.name [], ToExpr.toExpr updateHeader] ++ (args.map Arg.toExpr) | .box ty var _ => mkApp2 (.const `box []) ty (.fvar var) | .unbox var _ => mkApp (.const `unbox []) (.fvar var) + | .isShared fvarId _ => mkApp (.const `isShared []) (.fvar fvarId) structure LetDecl (pu : Purity) where fvarId : FVarId @@ -361,10 +371,13 @@ inductive Code (pu : Purity) where | cases (cases : Cases pu) | return (fvarId : FVarId) | unreach (type : Expr) + | oset (fvarId : FVarId) (i : Nat) (y : Arg pu) (k : Code pu) (h : pu = .impure := by purity_tac) | uset (fvarId : FVarId) (i : Nat) (y : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac) | sset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac) + | setTag (fvarId : FVarId) (cidx : Nat) (k : Code pu) (h : pu = .impure := by purity_tac) | inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac) | dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (k : Code pu) (h : pu = .impure := by purity_tac) + | del (fvarId : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac) deriving Inhabited end @@ -440,25 +453,32 @@ inductive CodeDecl (pu : Purity) where | let (decl : LetDecl pu) | fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac) | jp (decl : FunDecl pu) + | oset (fvarId : FVarId) (i : Nat) (y : Arg pu) (h : pu = .impure := by purity_tac) | uset (fvarId : FVarId) (i : Nat) (y : FVarId) (h : pu = .impure := by purity_tac) | sset (fvarId : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac) + | setTag (fvarId : FVarId) (cidx : Nat) (h : pu = .impure := by purity_tac) | inc (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac) | dec (fvarId : FVarId) (n : Nat) (check : Bool) (persistent : Bool) (h : pu = .impure := by purity_tac) + | del (fvarId : FVarId) (h : pu = .impure := by purity_tac) deriving Inhabited def CodeDecl.fvarId : CodeDecl pu → FVarId | .let decl | .fun decl _ | .jp decl => decl.fvarId - | .uset fvarId .. | .sset fvarId .. | .inc fvarId .. | .dec fvarId .. => fvarId + | .uset fvarId .. | .sset fvarId .. | .inc fvarId .. | .dec fvarId .. | .del fvarId .. + | .oset fvarId .. | .setTag fvarId .. => fvarId def Code.toCodeDecl! : Code pu → CodeDecl pu -| .let decl _ => .let decl -| .fun decl _ _ => .fun decl -| .jp decl _ => .jp decl -| .uset fvarId i y _ _ => .uset fvarId i y -| .sset fvarId i offset ty y _ _ => .sset fvarId i offset ty y -| .inc fvarId n check persistent _ _ => .inc fvarId n check persistent -| .dec fvarId n check persistent _ _ => .dec fvarId n check persistent -| _ => unreachable! + | .let decl _ => .let decl + | .fun decl _ _ => .fun decl + | .jp decl _ => .jp decl + | .oset fvarId i y _ _ => .oset fvarId i y + | .uset fvarId i y _ _ => .uset fvarId i y + | .sset fvarId i offset ty y _ _ => .sset fvarId i offset ty y + | .setTag fvarId cidx _ _ => .setTag fvarId cidx + | .inc fvarId n check persistent _ _ => .inc fvarId n check persistent + | .dec fvarId n check persistent _ _ => .dec fvarId n check persistent + | .del fvarId _ _ => .del fvarId + | _ => unreachable! def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu := go decls.size code @@ -469,10 +489,13 @@ where | .let decl => go (i-1) (.let decl code) | .fun decl _ => go (i-1) (.fun decl code) | .jp decl => go (i-1) (.jp decl code) + | .oset fvarId idx y _ => go (i-1) (.oset fvarId idx y code) | .uset fvarId idx y _ => go (i-1) (.uset fvarId idx y code) | .sset fvarId idx offset y ty _ => go (i-1) (.sset fvarId idx offset y ty code) + | .setTag fvarId cidx _ => go (i-1) (.setTag fvarId cidx code) | .inc fvarId n check persistent _ => go (i-1) (.inc fvarId n check persistent code) | .dec fvarId n check persistent _ => go (i-1) (.dec fvarId n check persistent code) + | .del fvarId _ => go (i-1) (.del fvarId code) else code @@ -488,14 +511,20 @@ mutual | .jmp j₁ as₁, .jmp j₂ as₂ => j₁ == j₂ && as₁ == as₂ | .return r₁, .return r₂ => r₁ == r₂ | .unreach t₁, .unreach t₂ => t₁ == t₂ + | .oset v₁ i₁ y₁ k₁ _, .oset v₂ i₂ y₂ k₂ _ => + v₁ == v₂ && i₁ == i₂ && y₁ == y₂ && eqImp k₁ k₂ | .uset v₁ i₁ y₁ k₁ _, .uset v₂ i₂ y₂ k₂ _ => v₁ == v₂ && i₁ == i₂ && y₁ == y₂ && eqImp k₁ k₂ | .sset v₁ i₁ o₁ y₁ ty₁ k₁ _, .sset v₂ i₂ o₂ y₂ ty₂ k₂ _ => v₁ == v₂ && i₁ == i₂ && o₁ == o₂ && y₁ == y₂ && ty₁ == ty₂ && eqImp k₁ k₂ + | .setTag v₁ c₁ k₁ _, .setTag v₂ c₂ k₂ _ => + v₁ == v₂ && c₁ == c₂ && eqImp k₁ k₂ | .inc v₁ n₁ c₁ p₁ k₁ _, .inc v₂ n₂ c₂ p₂ k₂ _ => v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂ | .dec v₁ n₁ c₁ p₁ k₁ _, .dec v₂ n₂ c₂ p₂ k₂ _ => v₁ == v₂ && n₁ == n₂ && c₁ == c₂ && p₁ == p₂ && eqImp k₁ k₂ + | .del v₁ k₁ _, .del v₂ k₂ _ => + v₁ == v₂ && eqImp k₁ k₂ | _, _ => false private unsafe def eqFunDecl (d₁ d₂ : FunDecl pu) : Bool := @@ -588,10 +617,13 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co | .let decl k => if ptrEq k k' then c else .let decl k' | .fun decl k _ => if ptrEq k k' then c else .fun decl k' | .jp decl k => if ptrEq k k' then c else .jp decl k' + | .oset fvarId offset y k _ => if ptrEq k k' then c else .oset fvarId offset y k' | .sset fvarId i offset y ty k _ => if ptrEq k k' then c else .sset fvarId i offset y ty k' | .uset fvarId offset y k _ => if ptrEq k k' then c else .uset fvarId offset y k' + | .setTag fvarId cidx k _ => if ptrEq k k' then c else .setTag fvarId cidx k' | .inc fvarId n check persistent k _ => if ptrEq k k' then c else .inc fvarId n check persistent k' | .dec fvarId n check persistent k _ => if ptrEq k k' then c else .dec fvarId n check persistent k' + | .del fvarId k _ => if ptrEq k k' then c else .del fvarId k' | _ => unreachable! @[implemented_by updateContImp] opaque Code.updateCont! (c : Code pu) (k' : Code pu) : Code pu @@ -635,6 +667,19 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co .sset fvarId' i' offset' y' ty' k' | _ => unreachable! +@[inline] private unsafe def updateOsetImp (c : Code pu) (fvarId' : FVarId) + (i' : Nat) (y' : Arg pu) (k' : Code pu) : Code pu := + match c with + | .oset fvarId i y k _ => + if ptrEq fvarId fvarId' && i == i' && ptrEq y y' && ptrEq k k' then + c + else + .oset fvarId' i' y' k' + | _ => unreachable! + +@[implemented_by updateOsetImp] opaque Code.updateOset! (c : Code pu) (fvarId' : FVarId) + (i' : Nat) (y' : Arg pu) (k' : Code pu) : Code pu + @[implemented_by updateSsetImp] opaque Code.updateSset! (c : Code pu) (fvarId' : FVarId) (i' : Nat) (offset' : Nat) (y' : FVarId) (ty' : Expr) (k' : Code pu) : Code pu @@ -651,6 +696,19 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co @[implemented_by updateUsetImp] opaque Code.updateUset! (c : Code pu) (fvarId' : FVarId) (i' : Nat) (y' : FVarId) (k' : Code pu) : Code pu +@[inline] private unsafe def updateSetTagImp (c : Code pu) (fvarId' : FVarId) (cidx' : Nat) + (k' : Code pu) : Code pu := + match c with + | .setTag fvarId cidx k _ => + if ptrEq fvarId fvarId' && cidx == cidx' && ptrEq k k' then + c + else + .setTag fvarId' cidx' k' + | _ => unreachable! + +@[implemented_by updateSetTagImp] opaque Code.updateSetTag! (c : Code pu) (fvarId' : FVarId) + (cidx' : Nat) (k' : Code pu) : Code pu + @[inline] private unsafe def updateIncImp (c : Code pu) (fvarId' : FVarId) (n' : Nat) (check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu := match c with @@ -685,6 +743,19 @@ private unsafe def updateAltImp (alt : Alt pu) (ps' : Array (Param pu)) (k' : Co @[implemented_by updateDecImp] opaque Code.updateDec! (c : Code pu) (fvarId' : FVarId) (n' : Nat) (check' : Bool) (persistent' : Bool) (k' : Code pu) : Code pu +@[inline] private unsafe def updateDelImp (c : Code pu) (fvarId' : FVarId) (k' : Code pu) : + Code pu := + match c with + | .del fvarId k _ => + if ptrEq fvarId fvarId' && ptrEq k k' then + c + else + .del fvarId' k' + | _ => unreachable! + +@[implemented_by updateDelImp] opaque Code.updateDel! (c : Code pu) (fvarId' : FVarId) + (k' : Code pu) : Code pu + private unsafe def updateParamCoreImp (p : Param pu) (type : Expr) : Param pu := if ptrEq type p.type then p @@ -753,8 +824,8 @@ partial def Code.size (c : Code pu) : Nat := where go (c : Code pu) (n : Nat) : Nat := match c with - | .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. - | .dec (k := k) .. => go k (n + 1) + | .let (k := k) .. | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. + | .dec (k := k) .. | .setTag (k := k) .. | .del (k := k) .. => go k (n + 1) | .jp decl k | .fun decl k _ => go k <| go decl.value n | .cases c => c.alts.foldl (init := n+1) fun n alt => go alt.getCode (n+1) | .jmp .. => n+1 @@ -772,8 +843,8 @@ where go (c : Code pu) : EStateM Unit Nat Unit := do match c with - | .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. - | .dec (k := k) .. => inc; go k + | .let (k := k) .. | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. + | .dec (k := k) .. | .setTag (k := k) .. | .del (k := k) .. => inc; go k | .jp decl k | .fun decl k _ => inc; go decl.value; go k | .cases c => inc; c.alts.forM fun alt => go alt.getCode | .jmp .. => inc @@ -785,8 +856,8 @@ where go (c : Code pu) : m Unit := do f c match c with - | .let (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. - | .dec (k := k) .. => go k + | .let (k := k) .. | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. + | .dec (k := k) .. | .setTag (k := k) .. | .del (k := k) .. => go k | .fun decl k _ | .jp decl k => go decl.value; go k | .cases c => c.alts.forM fun alt => go alt.getCode | .unreach .. | .return .. | .jmp .. => return () @@ -1053,7 +1124,7 @@ private def collectLetValue (e : LetValue pu) (s : FVarIdHashSet) : FVarIdHashSe | .fvar fvarId args => collectArgs args <| s.insert fvarId | .const _ _ args _ | .pap _ args _ | .fap _ args _ | .ctor _ args _ => collectArgs args s | .proj _ _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _ | .oproj _ fvarId _ - | .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => s.insert fvarId + | .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => s.insert fvarId | .lit .. | .erased => s | .reuse fvarId _ _ args _ => collectArgs args <| s.insert fvarId @@ -1082,7 +1153,12 @@ partial def Code.collectUsed (code : Code pu) (s : FVarIdHashSet := {}) : FVarId let s := s.insert fvarId let s := s.insert y k.collectUsed s - | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + | .oset fvarId _ y k _ => + let s := s.insert fvarId + let s := if let .fvar y := y then s.insert y else s + k.collectUsed s + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. + | .del (fvarId := fvarId) (k := k) .. | .setTag (fvarId := fvarId) (k := k) .. => k.collectUsed <| s.insert fvarId end @@ -1095,7 +1171,11 @@ def CodeDecl.collectUsed (codeDecl : CodeDecl pu) (s : FVarIdHashSet := ∅) : F | .jp decl | .fun decl _ => decl.collectUsed s | .sset (fvarId := fvarId) (y := y) .. | .uset (fvarId := fvarId) (y := y) .. => s.insert fvarId |>.insert y - | .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.insert fvarId + | .oset (fvarId := fvarId) (y := y) .. => + let s := s.insert fvarId + if let .fvar y := y then s.insert y else s + | .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. | .setTag (fvarId := fvarId) .. + | .del (fvarId := fvarId) .. => s.insert fvarId /-- Traverse the given block of potentially mutually recursive functions @@ -1125,7 +1205,8 @@ where modify fun s => s.insert declName | _ => pure () visit k - | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => visit k + | .oset (k := k) .. | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. + | .dec (k := k) .. | .del (k := k) .. | .setTag (k := k) .. => visit k go : StateM NameSet Unit := decls.forM (·.value.forCodeM visit) diff --git a/src/Lean/Compiler/LCNF/Bind.lean b/src/Lean/Compiler/LCNF/Bind.lean index 6e8a323dbf..874c8de9c3 100644 --- a/src/Lean/Compiler/LCNF/Bind.lean +++ b/src/Lean/Compiler/LCNF/Bind.lean @@ -68,7 +68,8 @@ where eraseCode k eraseParam auxParam return .unreach typeNew - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => + | .oset (k := k) ..| .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. + | .del (k := k) .. | .setTag (k := k) .. => return c.updateCont! (← go k) instance : MonadCodeBind CompilerM where diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index f12eb74c64..3cb60f944a 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -149,7 +149,7 @@ def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do match decl with | .let decl => eraseLetDecl decl | .jp decl | .fun decl _ => eraseFunDecl decl - | .sset .. | .uset .. | .inc .. | .dec .. => return () + | .sset .. | .uset .. | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => return () /-- Erase all free variables occurring in `decls` from the local context. @@ -300,6 +300,10 @@ private partial def normLetValueImp (s : FVarSubst pu) (e : LetValue pu) (transl match normFVarImp s fvarId translator with | .fvar fvarId' => e.updateUnbox! fvarId' | .erased => .erased + | .isShared fvarId _ => + match normFVarImp s fvarId translator with + | .fvar fvarId' => e.updateIsShared! fvarId' + | .erased => .erased /-- Interface for monads that have a free substitutions. @@ -497,16 +501,26 @@ mutual withNormFVarResult (← normFVar fvarId) fun fvarId => do withNormFVarResult (← normFVar y) fun y => do return code.updateSset! fvarId i offset y (← normExpr ty) (← normCodeImp k) + | .oset fvarId offset y k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + let y ← normArg y + return code.updateOset! fvarId offset y (← normCodeImp k) | .uset fvarId offset y k _ => withNormFVarResult (← normFVar fvarId) fun fvarId => do withNormFVarResult (← normFVar y) fun y => do return code.updateUset! fvarId offset y (← normCodeImp k) + | .setTag fvarId cidx k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + return code.updateSetTag! fvarId cidx (← normCodeImp k) | .inc fvarId n check persistent k _ => withNormFVarResult (← normFVar fvarId) fun fvarId => do return code.updateInc! fvarId n check persistent (← normCodeImp k) | .dec fvarId n check persistent k _ => withNormFVarResult (← normFVar fvarId) fun fvarId => do return code.updateDec! fvarId n check persistent (← normCodeImp k) + | .del fvarId k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + return code.updateDel! fvarId (← normCodeImp k) end @[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m pu t] (decl : FunDecl pu) : m (FunDecl pu) := do diff --git a/src/Lean/Compiler/LCNF/DeclHash.lean b/src/Lean/Compiler/LCNF/DeclHash.lean index e47e601117..dc5e9dc33b 100644 --- a/src/Lean/Compiler/LCNF/DeclHash.lean +++ b/src/Lean/Compiler/LCNF/DeclHash.lean @@ -39,12 +39,18 @@ partial def hashCode (code : Code pu) : UInt64 := | .cases c => mixHash (mixHash (hash c.discr) (hash c.resultType)) (hashAlts c.alts) | .sset fvarId i offset y ty k _ => mixHash (mixHash (hash fvarId) (hash i)) (mixHash (mixHash (hash offset) (hash y)) (mixHash (hash ty) (hashCode k))) + | .oset fvarId offset y k _ => + mixHash (mixHash (hash fvarId) (hash offset)) (mixHash (hash y) (hashCode k)) | .uset fvarId offset y k _ => mixHash (mixHash (hash fvarId) (hash offset)) (mixHash (hash y) (hashCode k)) + | .setTag fvarId cidx k _ => + mixHash (hash fvarId) (mixHash (hash cidx) (hashCode k)) | .inc fvarId n check persistent k _ => mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k)) | .dec fvarId n check persistent k _ => mixHash (mixHash (hash fvarId) (hash n)) (mixHash (mixHash (hash persistent) (hash check)) (hashCode k)) + | .del fvarId k _ => + mixHash (hash fvarId) (hashCode k) end diff --git a/src/Lean/Compiler/LCNF/DependsOn.lean b/src/Lean/Compiler/LCNF/DependsOn.lean index 2f2fc4a022..cc711514d8 100644 --- a/src/Lean/Compiler/LCNF/DependsOn.lean +++ b/src/Lean/Compiler/LCNF/DependsOn.lean @@ -31,7 +31,7 @@ private def letValueDepOn (e : LetValue pu) : M Bool := match e with | .erased | .lit .. => return false | .proj _ _ fvarId _ | .oproj _ fvarId _ | .uproj _ fvarId _ | .sproj _ _ fvarId _ - | .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => fvarDepOn fvarId + | .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => fvarDepOn fvarId | .fvar fvarId args | .reuse fvarId _ _ args _ => fvarDepOn fvarId <||> args.anyM argDepOn | .const _ _ args _ | .ctor _ args _ | .fap _ args _ | .pap _ args _ => args.anyM argDepOn @@ -46,8 +46,12 @@ private partial def depOn (c : Code pu) : M Bool := | .jmp fvarId args => fvarDepOn fvarId <||> args.anyM argDepOn | .return fvarId => fvarDepOn fvarId | .unreach _ => return false - | .sset fv1 _ _ fv2 _ k _ | .uset fv1 _ fv2 k _ => fvarDepOn fv1 <||> fvarDepOn fv2 <||> depOn k - | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + | .oset fv1 _ fv2 k _ => + fvarDepOn fv1 <||> argDepOn fv2 <||> depOn k + | .sset fv1 _ _ fv2 _ k _ | .uset fv1 _ fv2 k _ => + fvarDepOn fv1 <||> fvarDepOn fv2 <||> depOn k + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. + | .del (fvarId := fvarId) (k := k) .. | .setTag (fvarId := fvarId) (k := k) .. => fvarDepOn fvarId <||> depOn k @[inline] def Arg.dependsOn (arg : Arg pu) (s : FVarIdSet) : Bool := @@ -66,9 +70,14 @@ def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool := match decl with | .let decl => decl.dependsOn s | .jp decl | .fun decl _ => decl.dependsOn s - | .uset (fvarId := fvarId) (y := y) .. | .sset (fvarId := fvarId) (y := y) .. => + | .oset (fvarId := fvarId) (y := y) .. => + s.contains fvarId || y.dependsOn s + | .uset (fvarId := fvarId) (y := y) .. + | .sset (fvarId := fvarId) (y := y) .. => s.contains fvarId || s.contains y - | .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => s.contains fvarId + | .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. | .del (fvarId := fvarId) .. + | .setTag (fvarId := fvarId) .. => + s.contains fvarId /-- Return `true` is `c` depends on a free variable in `s`. diff --git a/src/Lean/Compiler/LCNF/ElimDead.lean b/src/Lean/Compiler/LCNF/ElimDead.lean index 43343f3089..5baa270297 100644 --- a/src/Lean/Compiler/LCNF/ElimDead.lean +++ b/src/Lean/Compiler/LCNF/ElimDead.lean @@ -35,7 +35,7 @@ def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue pu) : UsedLocal match e with | .erased | .lit .. => s | .proj _ _ fvarId _ | .reset _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _ - | .oproj _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => s.insert fvarId + | .oproj _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => s.insert fvarId | .const _ _ args _ => collectLocalDeclsArgs s args | .fvar fvarId args | .reuse fvarId _ _ args _ => collectLocalDeclsArgs (s.insert fvarId) args | .fap _ args _ | .pap _ args _ | .ctor _ args _ => collectLocalDeclsArgs s args @@ -56,9 +56,8 @@ def LetValue.safeToElim (val : LetValue pu) : Bool := | .pure => true | .impure => match val with - -- TODO | .isShared .. | .ctor .. | .reset .. | .reuse .. | .oproj .. | .uproj .. | .sproj .. | .lit .. | .pap .. - | .box .. | .unbox .. | .erased .. => true + | .box .. | .unbox .. | .erased .. | .isShared .. => true -- 0-ary full applications are considered constants | .fap _ args => args.isEmpty | .fvar .. => false @@ -95,6 +94,13 @@ partial def Code.elimDead (code : Code pu) : M (Code pu) := do | .return fvarId => collectFVarM fvarId; return code | .jmp fvarId args => collectFVarM fvarId; args.forM collectArgM; return code | .unreach .. => return code + | .oset fvarId _ y k _ => + let k ← k.elimDead + if (← get).contains fvarId then + collectArgM y + return code.updateCont! k + else + return k | .uset fvarId _ y k _ | .sset fvarId _ _ y _ k _ => let k ← k.elimDead if (← get).contains fvarId then @@ -102,7 +108,8 @@ partial def Code.elimDead (code : Code pu) : M (Code pu) := do return code.updateCont! k else return k - | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. + | .setTag (fvarId := fvarId) (k := k) .. | .del (fvarId := fvarId) (k := k) .. => let k ← k.elimDead collectFVarM fvarId return code.updateCont! k diff --git a/src/Lean/Compiler/LCNF/ExplicitBoxing.lean b/src/Lean/Compiler/LCNF/ExplicitBoxing.lean index c2aaf40aa8..7d36b3223b 100644 --- a/src/Lean/Compiler/LCNF/ExplicitBoxing.lean +++ b/src/Lean/Compiler/LCNF/ExplicitBoxing.lean @@ -284,7 +284,7 @@ partial def Code.explicitBoxing (code : Code .impure) : BoxM (Code .impure) := d let some jpDecl ← findFunDecl? fvarId | unreachable! castArgsIfNeeded args jpDecl.params fun args => return code.updateJmp! fvarId args | .unreach .. => return code.updateUnreach! (← getResultType) - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .oset .. | .setTag .. | .del .. => unreachable! where /-- Up to this point the type system of IR is quite loose so we can for example encounter situations @@ -313,7 +313,7 @@ where | .ctor i _ => return i.type | .fvar .. | .lit .. | .sproj .. | .oproj .. | .reset .. | .reuse .. => return currentType - | .box .. | .unbox .. => unreachable! + | .box .. | .unbox .. | .isShared .. => unreachable! visitLet (code : Code .impure) (decl : LetDecl .impure) (k : Code .impure) : BoxM (Code .impure) := do let type ← tryCorrectLetDeclType decl.type decl.value @@ -350,7 +350,7 @@ where | .erased | .reset .. | .sproj .. | .uproj .. | .oproj .. | .lit .. => let decl ← decl.update type decl.value return code.updateLet! decl k - | .box .. | .unbox .. => unreachable! + | .box .. | .unbox .. | .isShared .. => unreachable! def run (decls : Array (Decl .impure)) : CompilerM (Array (Decl .impure)) := do let decls ← decls.foldlM (init := #[]) fun newDecls decl => do diff --git a/src/Lean/Compiler/LCNF/ExplicitRC.lean b/src/Lean/Compiler/LCNF/ExplicitRC.lean index d2881062d6..ff01955d58 100644 --- a/src/Lean/Compiler/LCNF/ExplicitRC.lean +++ b/src/Lean/Compiler/LCNF/ExplicitRC.lean @@ -117,7 +117,7 @@ partial def collectCode (code : Code .impure) : M Unit := do | .cases cases => cases.alts.forM (·.forCodeM collectCode) | .sset (k := k) .. | .uset (k := k) .. => collectCode k | .return .. | .jmp .. | .unreach .. => return () - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable! /-- Collect the derived value tree as well as the set of parameters that take objects and are borrowed. @@ -334,6 +334,7 @@ def useLetValue (value : LetValue .impure) : RcM Unit := do useVar fvarId useArgs args | .lit .. | .erased => return () + | .isShared .. => unreachable! @[inline] def bindVar (fvarId : FVarId) : RcM Unit := @@ -547,6 +548,7 @@ def LetDecl.explicitRc (code : Code .impure) (decl : LetDecl .impure) (k : Code addIncBeforeConsumeAll allArgs (code.updateLet! decl k) | .lit .. | .box .. | .reset .. | .erased .. => pure <| code.updateLet! decl k + | .isShared .. => unreachable! useLetValue decl.value bindVar decl.fvarId return k @@ -622,7 +624,7 @@ partial def Code.explicitRc (code : Code .impure) : RcM (Code .impure) := do | .unreach .. => setRetLiveVars return code - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .setTag .. | .del .. | .oset .. => unreachable! def Decl.explicitRc (decl : Decl .impure) : CompilerM (Decl .impure) := do diff --git a/src/Lean/Compiler/LCNF/FVarUtil.lean b/src/Lean/Compiler/LCNF/FVarUtil.lean index dab681e5e1..7dd9dfb5e3 100644 --- a/src/Lean/Compiler/LCNF/FVarUtil.lean +++ b/src/Lean/Compiler/LCNF/FVarUtil.lean @@ -83,12 +83,13 @@ def LetValue.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m FVarI return e.updateReuse! (← f fvarId) i updateHeader (← args.mapM (TraverseFVar.mapFVarM f)) | .box ty fvarId _ => return e.updateBox! ty (← f fvarId) | .unbox fvarId _ => return e.updateUnbox! (← f fvarId) + | .isShared fvarId _ => return e.updateIsShared! (← f fvarId) def LetValue.forFVarM [Monad m] (f : FVarId → m Unit) (e : LetValue pu) : m Unit := do match e with | .lit .. | .erased => return () | .proj _ _ fvarId _ | .oproj _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _ - | .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ => f fvarId + | .reset _ fvarId _ | .box _ fvarId _ | .unbox fvarId _ | .isShared fvarId _ => f fvarId | .const _ _ args _ | .pap _ args _ | .fap _ args _ | .ctor _ args _ => args.forM (TraverseFVar.forFVarM f) | .fvar fvarId args | .reuse fvarId _ _ args _ => f fvarId; args.forM (TraverseFVar.forFVarM f) @@ -139,14 +140,20 @@ partial def Code.mapFVarM [MonadLiftT CompilerM m] [Monad m] (f : FVarId → m F return Code.updateReturn! c (← f var) | .unreach typ => return Code.updateUnreach! c (← Expr.mapFVarM f typ) + | .oset fvarId offset y k _ => + return Code.updateOset! c (← f fvarId) offset (← y.mapFVarM f) (← mapFVarM f k) | .sset fvarId i offset y ty k _ => return Code.updateSset! c (← f fvarId) i offset (← f y) (← Expr.mapFVarM f ty) (← mapFVarM f k) | .uset fvarId offset y k _ => return Code.updateUset! c (← f fvarId) offset (← f y) (← mapFVarM f k) + | .setTag fvarId cidx k _ => + return Code.updateSetTag! c (← f fvarId) cidx (← mapFVarM f k) | .inc fvarId n check persistent k _ => return Code.updateInc! c (← f fvarId) n check persistent (← mapFVarM f k) | .dec fvarId n check persistent k _ => return Code.updateDec! c (← f fvarId) n check persistent (← mapFVarM f k) + | .del fvarId k _ => + return Code.updateDel! c (← f fvarId) (← mapFVarM f k) partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Unit := do match c with @@ -182,7 +189,12 @@ partial def Code.forFVarM [Monad m] (f : FVarId → m Unit) (c : Code pu) : m Un f fvarId f y forFVarM f k - | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + | .oset fvarId _ y k _ => + f fvarId + y.forFVarM f + forFVarM f k + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. + | .del (fvarId := fvarId) (k := k) .. | .setTag (fvarId := fvarId) (k := k) .. => f fvarId forFVarM f k @@ -210,17 +222,22 @@ instance : TraverseFVar (CodeDecl pu) where | .jp decl => return .jp (← mapFVarM f decl) | .let decl => return .let (← mapFVarM f decl) | .uset fvarId i y _ => return .uset (← f fvarId) i (← f y) + | .oset fvarId i y _ => return .oset (← f fvarId) i (← y.mapFVarM f) | .sset fvarId i offset y ty _ => return .sset (← f fvarId) i offset (← f y) (← mapFVarM f ty) + | .setTag fvarId cidx _ => return .setTag (← f fvarId) cidx | .inc fvarId n check persistent _ => return .inc (← f fvarId) n check persistent | .dec fvarId n check persistent _ => return .dec (← f fvarId) n check persistent + | .del fvarId _ => return .del (← f fvarId) forFVarM f decl := match decl with | .fun decl _ => forFVarM f decl | .jp decl => forFVarM f decl | .let decl => forFVarM f decl | .uset fvarId i y _ => do f fvarId; f y + | .oset fvarId i y _ => do f fvarId; y.forFVarM f | .sset fvarId i offset y ty _ => do f fvarId; f y; forFVarM f ty - | .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. => f fvarId + | .inc (fvarId := fvarId) .. | .dec (fvarId := fvarId) .. | .del (fvarId := fvarId) .. + | .setTag (fvarId := fvarId) .. => f fvarId instance : TraverseFVar (Alt pu) where mapFVarM f alt := do diff --git a/src/Lean/Compiler/LCNF/InferBorrow.lean b/src/Lean/Compiler/LCNF/InferBorrow.lean index fffe524124..062815f265 100644 --- a/src/Lean/Compiler/LCNF/InferBorrow.lean +++ b/src/Lean/Compiler/LCNF/InferBorrow.lean @@ -91,7 +91,7 @@ where | .cases cs => cs.alts.forM (·.forCodeM (goCode declName)) | .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => goCode declName k | .return .. | .jmp .. | .unreach .. => return () - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .setTag .. | .del .. | .oset .. => unreachable! /-- Apply the inferred borrow annotations from `map` to a SCC. @@ -121,7 +121,7 @@ where | .cases cs => return code.updateAlts! <| ← cs.alts.mapMonoM (·.mapCodeM (go declName)) | .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return code.updateCont! (← go declName k) | .return .. | .jmp .. | .unreach .. => return code - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable! structure Ctx where /-- @@ -300,7 +300,7 @@ where | .cases cs => cs.alts.forM (·.forCodeM collectCode) | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => collectCode k | .return .. | .unreach .. => return () - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable! public def inferBorrow : Pass where diff --git a/src/Lean/Compiler/LCNF/Internalize.lean b/src/Lean/Compiler/LCNF/Internalize.lean index 6c6d9766d4..92b03bcaf9 100644 --- a/src/Lean/Compiler/LCNF/Internalize.lean +++ b/src/Lean/Compiler/LCNF/Internalize.lean @@ -120,6 +120,10 @@ private partial def internalizeLetValue (e : LetValue pu) : InternalizeM pu (Let match (← normFVar fvarId) with | .fvar fvarId' => return e.updateBox! ty fvarId' | .erased => return .erased + | .isShared fvarId _ => + match (← normFVar fvarId) with + | .fvar fvarId' => return e.updateIsShared! fvarId' + | .erased => return .erased def internalizeLetDecl (decl : LetDecl pu) : InternalizeM pu (LetDecl pu) := do let binderName ← refreshBinderName decl.binderName @@ -166,12 +170,22 @@ partial def internalizeCode (code : Code pu) : InternalizeM pu (Code pu) := do withNormFVarResult (← normFVar fvarId) fun fvarId => do withNormFVarResult (← normFVar y) fun y => do return .uset fvarId offset y (← internalizeCode k) + | .oset fvarId offset y k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + let y ← normArg y + return .oset fvarId offset y (← internalizeCode k) + | .setTag fvarId cidx k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + return .setTag fvarId cidx (← internalizeCode k) | .inc fvarId n check persistent k _ => withNormFVarResult (← normFVar fvarId) fun fvarId => do return .inc fvarId n check persistent (← internalizeCode k) | .dec fvarId n check persistent k _ => withNormFVarResult (← normFVar fvarId) fun fvarId => do return .dec fvarId n check persistent (← internalizeCode k) + | .del fvarId k _ => + withNormFVarResult (← normFVar fvarId) fun fvarId => do + return .del fvarId (← internalizeCode k) end @@ -180,8 +194,12 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl | .let decl => return .let (← internalizeLetDecl decl) | .fun decl _ => return .fun (← internalizeFunDecl decl) | .jp decl => return .jp (← internalizeFunDecl decl) - | .uset fvarId i y _ => + | .oset fvarId i y _ => -- Something weird should be happening if these become erased... + let .fvar fvarId ← normFVar fvarId | unreachable! + let y ← normArg y + return .oset fvarId i y + | .uset fvarId i y _ => let .fvar fvarId ← normFVar fvarId | unreachable! let .fvar y ← normFVar y | unreachable! return .uset fvarId i y @@ -190,12 +208,18 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl let .fvar y ← normFVar y | unreachable! let ty ← normExpr ty return .sset fvarId i offset y ty + | .setTag fvarId cidx _ => + let .fvar fvarId ← normFVar fvarId | unreachable! + return .setTag fvarId cidx | .inc fvarId n check offset _ => let .fvar fvarId ← normFVar fvarId | unreachable! return .inc fvarId n check offset | .dec fvarId n check offset _ => let .fvar fvarId ← normFVar fvarId | unreachable! return .dec fvarId n check offset + | .del fvarId _ => + let .fvar fvarId ← normFVar fvarId | unreachable! + return .del fvarId end Internalize diff --git a/src/Lean/Compiler/LCNF/LCtx.lean b/src/Lean/Compiler/LCNF/LCtx.lean index 0a8bc28f24..91952cdb92 100644 --- a/src/Lean/Compiler/LCNF/LCtx.lean +++ b/src/Lean/Compiler/LCNF/LCtx.lean @@ -77,7 +77,8 @@ mutual | .let decl k => eraseCode k <| lctx.eraseLetDecl decl | .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl | .cases c => eraseAlts c.alts lctx - | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => + | .oset (k := k) .. | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. + | .dec (k := k) .. | .del (k := k) .. | .setTag (k := k) .. => eraseCode k lctx | .return .. | .jmp .. | .unreach .. => lctx end diff --git a/src/Lean/Compiler/LCNF/LiveVars.lean b/src/Lean/Compiler/LCNF/LiveVars.lean index 0db5e6158e..050fdca292 100644 --- a/src/Lean/Compiler/LCNF/LiveVars.lean +++ b/src/Lean/Compiler/LCNF/LiveVars.lean @@ -65,6 +65,8 @@ where | .jp decl k => go decl.value <||> (do markJpVisited decl.fvarId; go k) | .uset fvarId _ y k _ | .sset fvarId _ _ y _ k _ => visitVar fvarId <||> visitVar y <||> go k + | .oset fvarId _ y k _ => + visitVar fvarId <||> pure (y.dependsOn (← read).targetSet) <||> go k | .cases c => visitVar c.discr <||> c.alts.anyM (go ·.getCode) | .jmp fvarId args => (pure <| args.any (·.dependsOn (← read).targetSet)) <||> do @@ -76,7 +78,8 @@ where go decl.value | .return var => visitVar var | .unreach .. => return false - | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. => + | .inc (fvarId := fvarId) (k := k) .. | .dec (fvarId := fvarId) (k := k) .. + | .setTag (fvarId := fvarId) (k := k) .. | .del (fvarId := fvarId) (k := k) => visitVar fvarId <||> go k @[inline] diff --git a/src/Lean/Compiler/LCNF/PrettyPrinter.lean b/src/Lean/Compiler/LCNF/PrettyPrinter.lean index f924e7ed97..27ba8f3f12 100644 --- a/src/Lean/Compiler/LCNF/PrettyPrinter.lean +++ b/src/Lean/Compiler/LCNF/PrettyPrinter.lean @@ -94,6 +94,7 @@ def ppLetValue (e : LetValue pu) : M Format := do return f!"reuse" ++ (if updateHeader then f!"!" else f!"") ++ f!" {← ppFVar fvarId} in {info}{← ppArgs args}" | .box _ fvarId _ => return f!"box {← ppFVar fvarId}" | .unbox fvarId _ => return f!"unbox {← ppFVar fvarId}" + | .isShared fvarId _ => return f!"isShared {← ppFVar fvarId}" def ppParam (param : Param pu) : M Format := do let borrow := if param.borrow then "@&" else "" @@ -149,6 +150,10 @@ mutual return f!"sset {← ppFVar fvarId}[{i}, {offset}] := {← ppFVar y};" ++ .line ++ (← ppCode k) | .uset fvarId i y k _ => return f!"uset {← ppFVar fvarId}[{i}] := {← ppFVar y};" ++ .line ++ (← ppCode k) + | .oset fvarId i y k _ => + return f!"oset {← ppFVar fvarId} [{i}] := {← ppArg y};" ++ .line ++ (← ppCode k) + | .setTag fvarId cidx k _ => + return f!"setTag {← ppFVar fvarId} := {cidx};" ++ .line ++ (← ppCode k) | .inc fvarId n _ _ k _ => if n != 1 then return f!"inc[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k) @@ -159,6 +164,8 @@ mutual return f!"dec[{n}] {← ppFVar fvarId};" ++ .line ++ (← ppCode k) else return f!"dec {← ppFVar fvarId};" ++ .line ++ (← ppCode k) + | .del fvarId k _ => + return f!"del {← ppFVar fvarId};" ++ .line ++ (← ppCode k) partial def ppDeclValue (b : DeclValue pu) : M Format := do diff --git a/src/Lean/Compiler/LCNF/Probing.lean b/src/Lean/Compiler/LCNF/Probing.lean index be5bdbf661..17c2bdb008 100644 --- a/src/Lean/Compiler/LCNF/Probing.lean +++ b/src/Lean/Compiler/LCNF/Probing.lean @@ -58,7 +58,8 @@ where go k | .cases cs => cs.alts.forM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return () - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k start (decls : Array (Decl pu)) : StateRefT (Array (LetValue pu)) CompilerM Unit := decls.forM (·.value.forCodeM go) @@ -73,7 +74,8 @@ where | .jp decl k => modify (·.push decl); go decl.value; go k | .cases cs => cs.alts.forM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return () - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k start (decls : Array (Decl pu)) : StateRefT (Array (FunDecl pu)) CompilerM Unit := decls.forM (·.value.forCodeM go) @@ -86,7 +88,8 @@ where | .fun decl k _ | .jp decl k => go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k partial def filterByFun (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -96,7 +99,8 @@ where | .fun decl k _ => do if (← f decl) then return true else go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k partial def filterByJp (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -107,7 +111,8 @@ where | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k partial def filterByFunDecl (pu : Purity) (f : FunDecl pu → CompilerM Bool) : Probe (Decl pu) (Decl pu):= @@ -118,7 +123,8 @@ where | .fun decl k _ | .jp decl k => do if (← f decl) then return true else go decl.value <||> go k | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k partial def filterByCases (pu : Purity) (f : Cases pu → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -128,7 +134,8 @@ where | .fun decl k _ | .jp decl k => go decl.value <||> go k | .cases cs => do if (← f cs) then return true else cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. | .unreach .. => return false - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k partial def filterByJmp (pu : Purity) (f : FVarId → Array (Arg pu) → CompilerM Bool) : Probe (Decl pu) (Decl pu) := @@ -140,7 +147,8 @@ where | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp fn var => f fn var | .return .. | .unreach .. => return false - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k partial def filterByReturn (pu : Purity) (f : FVarId → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -151,7 +159,8 @@ where | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .unreach .. => return false | .return var => f var - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k partial def filterByUnreach (pu : Purity) (f : Expr → CompilerM Bool) : Probe (Decl pu) (Decl pu) := filter (·.value.isCodeAndM go) @@ -162,7 +171,8 @@ where | .cases cs => cs.alts.anyM (go ·.getCode) | .jmp .. | .return .. => return false | .unreach typ => f typ - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => go k + | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. | .del (k := k) .. + | .setTag (k := k) .. | .oset (k := k) .. => go k @[inline] def declNames (pu : Purity) : Probe (Decl pu) Name := diff --git a/src/Lean/Compiler/LCNF/PushProj.lean b/src/Lean/Compiler/LCNF/PushProj.lean index e131220296..d964232bec 100644 --- a/src/Lean/Compiler/LCNF/PushProj.lean +++ b/src/Lean/Compiler/LCNF/PushProj.lean @@ -133,6 +133,8 @@ where | .jp decl k => let decl ← decl.updateValue (← decl.value.pushProj) go k (decls.push (.jp decl)) + | .oset fvarId i y k _ => + go k (decls.push (.oset fvarId i y)) | .uset fvarId i y k _ => go k (decls.push (.uset fvarId i y)) | .sset fvarId i offset y ty k _ => @@ -141,6 +143,10 @@ where go k (decls.push (.inc fvarId n check persistent)) | .dec fvarId n check persistent k _ => go k (decls.push (.dec fvarId n check persistent)) + | .del fvarId k _ => + go k (decls.push (.del fvarId)) + | .setTag fvarId cidx k _ => + go k (decls.push (.setTag fvarId cidx)) | .cases c => c.pushProjs decls | .jmp .. | .return .. | .unreach .. => return attachCodeDecls decls c diff --git a/src/Lean/Compiler/LCNF/Renaming.lean b/src/Lean/Compiler/LCNF/Renaming.lean index 91c25df0cb..c2782f9433 100644 --- a/src/Lean/Compiler/LCNF/Renaming.lean +++ b/src/Lean/Compiler/LCNF/Renaming.lean @@ -53,7 +53,8 @@ partial def Code.applyRenaming (code : Code pu) (r : Renaming) : CompilerM (Code | .ctorAlt _ k _ => return alt.updateCode (← k.applyRenaming r) return code.updateAlts! alts | .jmp .. | .unreach .. | .return .. => return code - | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => + | .oset (k := k) .. | .sset (k := k) .. | .uset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. + | .del (k := k) .. | .setTag (k := k) .. => return code.updateCont! (← k.applyRenaming r) end diff --git a/src/Lean/Compiler/LCNF/ResetReuse.lean b/src/Lean/Compiler/LCNF/ResetReuse.lean index 1f870928c9..becf97a30a 100644 --- a/src/Lean/Compiler/LCNF/ResetReuse.lean +++ b/src/Lean/Compiler/LCNF/ResetReuse.lean @@ -120,7 +120,7 @@ where | .return .. | .jmp .. | .unreach .. => return (c, false) | .sset _ _ _ _ _ k _ | .uset _ _ _ k _ | .let _ k => goK k - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable! def isCtorUsing (instr : CodeDecl .impure) (x : FVarId) : Bool := match instr with @@ -242,7 +242,7 @@ where return (c.updateCont! k, false) | .return .. | .jmp .. | .unreach .. => return (c, ← c.isFVarLiveIn x) - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable! end @@ -275,7 +275,7 @@ partial def Code.insertResetReuse (c : Code .impure) : ReuseM (Code .impure) := | .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return c.updateCont! (← k.insertResetReuse) | .return .. | .jmp .. | .unreach .. => return c - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable! partial def Decl.insertResetReuseCore (decl : Decl .impure) : ReuseM (Decl .impure) := do let value ← decl.value.mapCodeM fun code => do @@ -298,7 +298,7 @@ where | .jp decl k => collectResets decl.value; collectResets k | .cases c => c.alts.forM (collectResets ·.getCode) | .unreach .. | .return .. | .jmp .. => return () - | .inc .. | .dec .. => unreachable! + | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable! def Decl.insertResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do diff --git a/src/Lean/Compiler/LCNF/SimpCase.lean b/src/Lean/Compiler/LCNF/SimpCase.lean index a213ce3636..abe521bcbd 100644 --- a/src/Lean/Compiler/LCNF/SimpCase.lean +++ b/src/Lean/Compiler/LCNF/SimpCase.lean @@ -107,7 +107,8 @@ partial def Code.simpCase (code : Code .impure) : CompilerM (Code .impure) := do let decl ← decl.updateValue (← decl.value.simpCase) return code.updateFun! decl (← k.simpCase) | .return .. | .jmp .. | .unreach .. => return code - | .let _ k | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. => + | .let _ k | .uset (k := k) .. | .sset (k := k) .. | .inc (k := k) .. | .dec (k := k) .. + | .setTag (k := k) .. | .del (k := k) .. | .oset (k := k) .. => return code.updateCont! (← k.simpCase) def Decl.simpCase (decl : Decl .impure) : CompilerM (Decl .impure) := do diff --git a/src/Lean/Compiler/LCNF/ToExpr.lean b/src/Lean/Compiler/LCNF/ToExpr.lean index 4fe111bf8e..ee66a49990 100644 --- a/src/Lean/Compiler/LCNF/ToExpr.lean +++ b/src/Lean/Compiler/LCNF/ToExpr.lean @@ -116,10 +116,18 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do let value := mkApp5 (mkConst `sset) (.fvar fvarId) (toExpr i) (toExpr offset) (.fvar y) ty let body ← withFVar fvarId k.toExprM return .letE `dummy (mkConst ``Unit) value body true + | .oset fvarId offset y k _ => + let value := mkApp3 (mkConst `oset) (.fvar fvarId) (toExpr offset) (← y.toExprM) + let body ← withFVar fvarId k.toExprM + return .letE `dummy (mkConst ``Unit) value body true | .uset fvarId offset y k _ => let value := mkApp3 (mkConst `uset) (.fvar fvarId) (toExpr offset) (.fvar y) let body ← withFVar fvarId k.toExprM return .letE `dummy (mkConst ``Unit) value body true + | .setTag fvarId cidx k _ => + let body ← withFVar fvarId k.toExprM + let value := mkApp2 (mkConst `setTag) (.fvar fvarId) (toExpr cidx) + return .letE `dummy (mkConst ``Unit) value body true | .inc fvarId n check persistent k _ => let value := mkApp4 (mkConst `inc) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent) let body ← withFVar fvarId k.toExprM @@ -128,6 +136,10 @@ partial def Code.toExprM (code : Code pu) : ToExprM Expr := do let body ← withFVar fvarId k.toExprM let value := mkApp4 (mkConst `dec) (.fvar fvarId) (toExpr n) (toExpr check) (toExpr persistent) return .letE `dummy (mkConst ``Unit) value body true + | .del fvarId k _ => + let body ← withFVar fvarId k.toExprM + let value := mkApp (mkConst `del) (.fvar fvarId) + return .letE `dummy (mkConst ``Unit) value body true end public def Code.toExpr (code : Code pu) (xs : Array FVarId := #[]) : Expr :=