From 62ded77e810a87fa5de5ca7cf13b1675aaf3c153 Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Tue, 15 Jul 2025 17:42:56 -0700 Subject: [PATCH] chore: add a new tagged IRType for inline tagged scalars (#9394) --- src/Lean/Compiler/IR/Basic.lean | 12 +++++++++-- src/Lean/Compiler/IR/EmitC.lean | 1 + src/Lean/Compiler/IR/EmitLLVM.lean | 1 + src/Lean/Compiler/IR/Format.lean | 1 + src/Lean/Compiler/IR/ToIR.lean | 6 ++++-- src/Lean/Compiler/IR/ToIRType.lean | 6 +++--- src/library/ir_interpreter.cpp | 8 +++++++- src/library/ir_types.h | 3 ++- tests/lean/4089.lean.expected.out | 4 ++-- tests/lean/4240.lean.expected.out | 4 ++-- .../lean/computedFieldsCode.lean.expected.out | 20 +++++++++---------- tests/lean/doubleReset.lean.expected.out | 2 +- tests/lean/listLength.lean.expected.out | 2 +- tests/lean/run/extractClosed.lean | 4 ++-- tests/lean/sint_basic.lean.expected.out | 4 ++-- tests/lean/unboxStruct.lean.expected.out | 4 ++-- 16 files changed, 51 insertions(+), 31 deletions(-) diff --git a/src/Lean/Compiler/IR/Basic.lean b/src/Lean/Compiler/IR/Basic.lean index a8429cdae3..b7480b2b00 100644 --- a/src/Lean/Compiler/IR/Basic.lean +++ b/src/Lean/Compiler/IR/Basic.lean @@ -49,8 +49,9 @@ abbrev MData.empty : MData := {} - `object` a pointer to a value in the heap. - - `tobject` a pointer to a value in the heap or tagged pointer - (i.e., the least significant bit is 1) storing a scalar value. + - `tagged` a tagged pointer (i.e., the least significant bit is 1) storing a scalar value. + + - `tobject` an `object` or a `tagged` pointer - `struct` and `union` are used to return small values (e.g., `Option`, `Prod`, `Except`) on the stack. @@ -77,6 +78,8 @@ inductive IRType where | float32 | struct (leanTypeName : Option Name) (types : Array IRType) : IRType | union (leanTypeName : Name) (types : Array IRType) : IRType + -- TODO: Move this upwards after a stage0 update. + | tagged deriving Inhabited, BEq, Repr namespace IRType @@ -93,6 +96,7 @@ def isScalar : IRType → Bool def isObj : IRType → Bool | object => true + | tagged => true | tobject => true | _ => false @@ -102,6 +106,7 @@ def isErased : IRType → Bool def boxed : IRType → IRType | object | erased | float | float32 => object + | tagged | uint8 | uint16 => tagged | _ => tobject end IRType @@ -150,6 +155,9 @@ def CtorInfo.isRef (info : CtorInfo) : Bool := def CtorInfo.isScalar (info : CtorInfo) : Bool := !info.isRef +def CtorInfo.type (info : CtorInfo) : IRType := + if info.isRef then .object else .tagged + inductive Expr where /-- We use `ctor` mainly for constructing Lean object/tobject values `lean_ctor_object` in the runtime. This instruction is also used to creat `struct` and `union` return values. diff --git a/src/Lean/Compiler/IR/EmitC.lean b/src/Lean/Compiler/IR/EmitC.lean index 54e7f3a55f..052efe8c6c 100644 --- a/src/Lean/Compiler/IR/EmitC.lean +++ b/src/Lean/Compiler/IR/EmitC.lean @@ -62,6 +62,7 @@ def toCType : IRType → String | IRType.uint64 => "uint64_t" | IRType.usize => "size_t" | IRType.object => "lean_object*" + | IRType.tagged => "lean_object*" | IRType.tobject => "lean_object*" | IRType.erased => "lean_object*" | IRType.struct _ _ => panic! "not implemented yet" diff --git a/src/Lean/Compiler/IR/EmitLLVM.lean b/src/Lean/Compiler/IR/EmitLLVM.lean index 63b8f05798..862b1b81f6 100644 --- a/src/Lean/Compiler/IR/EmitLLVM.lean +++ b/src/Lean/Compiler/IR/EmitLLVM.lean @@ -322,6 +322,7 @@ def toLLVMType (t : IRType) : M llvmctx (LLVM.LLVMType llvmctx) := do -- TODO: how to cleanly size_t in LLVM? We can do eg. instantiate the current target and query for size. | IRType.usize => LLVM.size_tType llvmctx | IRType.object => do LLVM.pointerType (← LLVM.i8Type llvmctx) + | IRType.tagged => do LLVM.pointerType (← LLVM.i8Type llvmctx) | IRType.tobject => do LLVM.pointerType (← LLVM.i8Type llvmctx) | IRType.erased => do LLVM.pointerType (← LLVM.i8Type llvmctx) | IRType.struct _ _ => panic! "not implemented yet" diff --git a/src/Lean/Compiler/IR/Format.lean b/src/Lean/Compiler/IR/Format.lean index 197bd546d2..d53e19bc1e 100644 --- a/src/Lean/Compiler/IR/Format.lean +++ b/src/Lean/Compiler/IR/Format.lean @@ -63,6 +63,7 @@ private partial def formatIRType : IRType → Format | IRType.usize => "usize" | IRType.erased => "◾" | IRType.object => "obj" + | IRType.tagged => "tagged" | IRType.tobject => "tobj" | IRType.struct _ tys => let _ : ToFormat IRType := ⟨formatIRType⟩ diff --git a/src/Lean/Compiler/IR/ToIR.lean b/src/Lean/Compiler/IR/ToIR.lean index 54f38e2641..c044b4090c 100644 --- a/src/Lean/Compiler/IR/ToIR.lean +++ b/src/Lean/Compiler/IR/ToIR.lean @@ -65,7 +65,9 @@ def addDecl (d : Decl) : M Unit := def lowerLitValue (v : LCNF.LitValue) : LitVal × IRType := match v with - | .nat n => ⟨.num n, .tobject⟩ + | .nat n => + let type := if n < UInt32.size then .tagged else .tobject + ⟨.num n, type⟩ | .str s => ⟨.str s, .object⟩ | .uint8 v => ⟨.num (UInt8.toNat v), .uint8⟩ | .uint16 v => ⟨.num (UInt16.toNat v), .uint16⟩ @@ -196,7 +198,7 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do | some .erased => loop (i + 1) | none => lowerCode k loop 0 - return .vdecl objVar type (.ctor ctorInfo objArgs) (← lowerNonObjectFields ()) + return .vdecl objVar ctorInfo.type (.ctor ctorInfo objArgs) (← lowerNonObjectFields ()) | some (.defnInfo ..) | some (.opaqueInfo ..) => mkFap name irArgs | some (.axiomInfo ..) | .some (.quotInfo ..) | .some (.inductInfo ..) | .some (.thmInfo ..) => diff --git a/src/Lean/Compiler/IR/ToIRType.lean b/src/Lean/Compiler/IR/ToIRType.lean index fca105b59f..4c1849e106 100644 --- a/src/Lean/Compiler/IR/ToIRType.lean +++ b/src/Lean/Compiler/IR/ToIRType.lean @@ -17,7 +17,7 @@ open Lean.Compiler (LCNF.CacheExtension LCNF.isTypeFormerType LCNF.toLCNFType LC def irTypeForEnum (numCtors : Nat) : IRType := if numCtors == 1 then - .tobject + .tagged else if numCtors < Nat.pow 2 8 then .uint8 else if numCtors < Nat.pow 2 16 then @@ -25,7 +25,7 @@ def irTypeForEnum (numCtors : Nat) : IRType := else if numCtors < Nat.pow 2 32 then .uint32 else - .tobject + .tagged builtin_initialize irTypeExt : LCNF.CacheExtension Name IRType ← LCNF.CacheExtension.register @@ -135,7 +135,7 @@ where fillCache := do let monoFieldType ← LCNF.toMonoType lcnfFieldType let irFieldType ← toIRType monoFieldType let ctorField ← match irFieldType with - | .object | .tobject => do + | .object | .tagged | .tobject => do let i := nextIdx nextIdx := nextIdx + 1 pure <| .object i irFieldType diff --git a/src/library/ir_interpreter.cpp b/src/library/ir_interpreter.cpp index 8bad97b780..a3d35afad2 100644 --- a/src/library/ir_interpreter.cpp +++ b/src/library/ir_interpreter.cpp @@ -211,7 +211,7 @@ std::string format_fn_body_head(fn_body const & b) { } static bool type_is_scalar(type t) { - return t != type::Object && t != type::TObject && t != type::Irrelevant; + return t != type::Object && t != type::Tagged && t != type::TObject && t != type::Irrelevant; } extern "C" object* lean_get_regular_init_fn_name_for(object* env, object* fn); @@ -257,6 +257,7 @@ object * box_t(value v, type t) { case type::UInt64: return box_uint64(v.m_num); case type::USize: return box_size_t(v.m_num); case type::Object: + case type::Tagged: case type::TObject: case type::Irrelevant: return v.m_obj; @@ -278,6 +279,7 @@ value unbox_t(object * o, type t) { case type::USize: return unbox_size_t(o); case type::Irrelevant: case type::Object: + case type::Tagged: case type::TObject: break; case type::Struct: @@ -507,6 +509,7 @@ private: case type::USize: case type::Irrelevant: case type::Object: + case type::Tagged: case type::TObject: case type::Struct: case type::Union: @@ -572,6 +575,7 @@ private: return lean_uint64_of_nat(n.raw()); // `nat` literal case type::Object: + case type::Tagged: case type::TObject: return n.to_obj_arg(); case type::Irrelevant: @@ -698,6 +702,7 @@ private: case type::USize: case type::Irrelevant: case type::Object: + case type::Tagged: case type::TObject: case type::Struct: case type::Union: @@ -866,6 +871,7 @@ private: case type::UInt64: return *static_cast(e.m_native.m_addr); case type::USize: return *static_cast(e.m_native.m_addr); case type::Object: + case type::Tagged: case type::TObject: case type::Irrelevant: return *static_cast(e.m_native.m_addr); diff --git a/src/library/ir_types.h b/src/library/ir_types.h index 4e08cb61cd..f165882d43 100644 --- a/src/library/ir_types.h +++ b/src/library/ir_types.h @@ -15,10 +15,11 @@ inductive IRType | float32 | struct (leanTypeName : Option Name) (types : Array IRType) : IRType | union (leanTypeName : Name) (types : Array IRType) : IRType +| tagged Remark: we don't create struct/union types from C++. */ -enum class type { Float, UInt8, UInt16, UInt32, UInt64, USize, Irrelevant, Object, TObject, Float32, Struct, Union }; +enum class type { Float, UInt8, UInt16, UInt32, UInt64, USize, Irrelevant, Object, TObject, Float32, Struct, Union, Tagged }; typedef nat var_id; typedef nat jp_id; diff --git a/tests/lean/4089.lean.expected.out b/tests/lean/4089.lean.expected.out index ca9f55e859..3be7e86b28 100644 --- a/tests/lean/4089.lean.expected.out +++ b/tests/lean/4089.lean.expected.out @@ -23,7 +23,7 @@ def foo (x_1 : tobj) : tobj := case x_1 : tobj of List.nil → - let x_2 : tobj := ctor_0[List.nil]; + let x_2 : tagged := ctor_0[List.nil]; ret x_2 List.cons → let x_3 : tobj := proj[0] x_1; @@ -34,5 +34,5 @@ let x_5 : tobj := proj[0] x_3; let x_6 : tobj := proj[1] x_3; let x_7 : tobj := foo x_4; - let x_8 : tobj := reuse x_10 in ctor_1[List.cons] x_5 x_7; + let x_8 : obj := reuse x_10 in ctor_1[List.cons] x_5 x_7; ret x_8 diff --git a/tests/lean/4240.lean.expected.out b/tests/lean/4240.lean.expected.out index 4ce44be94c..14ba062692 100644 --- a/tests/lean/4240.lean.expected.out +++ b/tests/lean/4240.lean.expected.out @@ -13,12 +13,12 @@ let x_4 : u8 := MyOption.isSomeWithInstance._at_.isSomeWithInstanceNat.spec_0 x_3; dec x_3; ret x_4 - def MyOption.isSomeWithInstance._at_.isSomeWithInstanceNat.spec_0._boxed (x_1 : tobj) : tobj := + def MyOption.isSomeWithInstance._at_.isSomeWithInstanceNat.spec_0._boxed (x_1 : tobj) : tagged := let x_2 : u8 := MyOption.isSomeWithInstance._at_.isSomeWithInstanceNat.spec_0 x_1; dec x_1; let x_3 : tobj := box x_2; ret x_3 - def isSomeWithInstanceNat._boxed (x_1 : obj) : tobj := + def isSomeWithInstanceNat._boxed (x_1 : obj) : tagged := let x_2 : u8 := isSomeWithInstanceNat x_1; dec x_1; let x_3 : tobj := box x_2; diff --git a/tests/lean/computedFieldsCode.lean.expected.out b/tests/lean/computedFieldsCode.lean.expected.out index 7418fee433..33f1f1799e 100644 --- a/tests/lean/computedFieldsCode.lean.expected.out +++ b/tests/lean/computedFieldsCode.lean.expected.out @@ -106,7 +106,7 @@ let x_2 : u64 := UInt32.toUInt64 x_1; let x_3 : u64 := 1000; let x_4 : u64 := UInt64.add x_2 x_3; - let x_5 : tobj := ctor_0.0.12[Exp.var._impl]; + let x_5 : obj := ctor_0.0.12[Exp.var._impl]; sset x_5[0, 0] : u64 := x_4; sset x_5[0, 8] : u32 := x_1; ret x_5 @@ -114,23 +114,23 @@ let x_3 : u64 := Exp.hash._override x_1; let x_4 : u64 := Exp.hash._override x_2; let x_5 : u64 := mixHash x_3 x_4; - let x_6 : tobj := ctor_1.0.8[Exp.app._impl] x_1 x_2; + let x_6 : obj := ctor_1.0.8[Exp.app._impl] x_1 x_2; sset x_6[2, 0] : u64 := x_5; ret x_6 def Exp.a1._override : tobj := - let x_1 : tobj := ctor_2[Exp.a1._impl]; + let x_1 : tagged := ctor_2[Exp.a1._impl]; ret x_1 def Exp.a2._override : tobj := - let x_1 : tobj := ctor_3[Exp.a2._impl]; + let x_1 : tagged := ctor_3[Exp.a2._impl]; ret x_1 def Exp.a3._override : tobj := - let x_1 : tobj := ctor_4[Exp.a3._impl]; + let x_1 : tagged := ctor_4[Exp.a3._impl]; ret x_1 def Exp.a4._override : tobj := - let x_1 : tobj := ctor_5[Exp.a4._impl]; + let x_1 : tagged := ctor_5[Exp.a4._impl]; ret x_1 def Exp.a5._override : tobj := - let x_1 : tobj := ctor_6[Exp.a5._impl]; + let x_1 : tagged := ctor_6[Exp.a5._impl]; ret x_1 def Exp.hash._override (x_1 : @& tobj) : u64 := case x_1 : tobj of @@ -159,7 +159,7 @@ let x_2 : tobj := Exp.var._override x_1; ret x_2 def f._closed_1 : tobj := - let x_1 : tobj := ctor_5[Exp.a4._impl]; + let x_1 : tagged := ctor_5[Exp.a4._impl]; let x_2 : tobj := f._closed_0; let x_3 : tobj := Exp.app._override x_2 x_1; ret x_3 @@ -180,7 +180,7 @@ default → let x_3 : u8 := 0; ret x_3 - def g._boxed (x_1 : tobj) : tobj := + def g._boxed (x_1 : tobj) : tagged := let x_2 : u8 := g x_1; dec x_1; let x_3 : tobj := box x_2; @@ -202,7 +202,7 @@ dec x_6; ret x_8 default → - let x_9 : tobj := 42; + let x_9 : tagged := 42; ret x_9 def hash'._boxed (x_1 : tobj) : tobj := let x_2 : tobj := hash' x_1; diff --git a/tests/lean/doubleReset.lean.expected.out b/tests/lean/doubleReset.lean.expected.out index d16cd3e0ad..8332ab81e1 100644 --- a/tests/lean/doubleReset.lean.expected.out +++ b/tests/lean/doubleReset.lean.expected.out @@ -15,7 +15,7 @@ let x_8 : tobj := proj[0] x_7; let x_9 : tobj := proj[1] x_7; let x_18 : tobj := reset[2] x_7; - let x_10 : tobj := 0; + let x_10 : tagged := 0; let x_11 : obj := Array.uset ◾ x_4 x_3 x_10 ◾; let x_12 : obj := reuse x_18 in ctor_0[Prod.mk] x_8 x_9; let x_13 : obj := reuse x_19 in ctor_0[Prod.mk] x_12 x_1; diff --git a/tests/lean/listLength.lean.expected.out b/tests/lean/listLength.lean.expected.out index d3ea17388e..2c1da39c75 100644 --- a/tests/lean/listLength.lean.expected.out +++ b/tests/lean/listLength.lean.expected.out @@ -1,6 +1,6 @@ [Compiler.IR] [init] def f (x_1 : tobj) : tobj := - let x_2 : tobj := 2; + let x_2 : tagged := 2; let x_3 : tobj := List.lengthTR._redArg x_1; let x_4 : tobj := Nat.mul x_2 x_3; ret x_4 diff --git a/tests/lean/run/extractClosed.lean b/tests/lean/run/extractClosed.lean index 15c377ccc0..2a1a10737b 100644 --- a/tests/lean/run/extractClosed.lean +++ b/tests/lean/run/extractClosed.lean @@ -1,7 +1,7 @@ /-- trace: [Compiler.IR] [result] def f._closed_0 : obj := - let x_1 : tobj := 1; + let x_1 : tagged := 1; let x_2 : obj := Array.mkEmpty ◾ x_1; ret x_2 def f (x_1 : tobj) : obj := @@ -16,7 +16,7 @@ def f (a : Nat) : Array Nat := #[a] /-- trace: [Compiler.IR] [result] def g (x_1 : tobj) : obj := - let x_2 : tobj := 1; + let x_2 : tagged := 1; let x_3 : obj := Array.mkEmpty ◾ x_2; let x_4 : obj := Array.push ◾ x_3 x_1; ret x_4 diff --git a/tests/lean/sint_basic.lean.expected.out b/tests/lean/sint_basic.lean.expected.out index 2d5bb36385..070f918b3d 100644 --- a/tests/lean/sint_basic.lean.expected.out +++ b/tests/lean/sint_basic.lean.expected.out @@ -75,7 +75,7 @@ true [Compiler.IR] [result] def myId8 (x_1 : u8) : u8 := ret x_1 - def myId8._boxed (x_1 : tobj) : tobj := + def myId8._boxed (x_1 : tagged) : tagged := let x_2 : u8 := unbox x_1; dec x_1; let x_3 : u8 := myId8 x_2; @@ -158,7 +158,7 @@ true [Compiler.IR] [result] def myId16 (x_1 : u16) : u16 := ret x_1 - def myId16._boxed (x_1 : tobj) : tobj := + def myId16._boxed (x_1 : tagged) : tagged := let x_2 : u16 := unbox x_1; dec x_1; let x_3 : u16 := myId16 x_2; diff --git a/tests/lean/unboxStruct.lean.expected.out b/tests/lean/unboxStruct.lean.expected.out index 8fe65e262e..3185ae2afb 100644 --- a/tests/lean/unboxStruct.lean.expected.out +++ b/tests/lean/unboxStruct.lean.expected.out @@ -1,8 +1,8 @@ [Compiler.IR] [result] - def test2 (x_1 : u32) (x_2 : tobj) : obj := + def test2 (x_1 : u32) (x_2 : tagged) : obj := let x_3 : obj := foo x_1 x_2; ret x_3 - def test2._boxed (x_1 : tobj) (x_2 : tobj) : obj := + def test2._boxed (x_1 : tobj) (x_2 : tagged) : obj := let x_3 : u32 := unbox x_1; dec x_1; let x_4 : obj := test2 x_3 x_2;