diff --git a/src/Lean/Compiler/IR/Boxing.lean b/src/Lean/Compiler/IR/Boxing.lean index 7048b2e5a9..c8a1e4bd47 100644 --- a/src/Lean/Compiler/IR/Boxing.lean +++ b/src/Lean/Compiler/IR/Boxing.lean @@ -267,8 +267,29 @@ def visitVDeclExpr (x : VarId) (ty : IRType) (e : Expr) (b : FnBody) : M FnBody | _ => return .vdecl x ty e b +/-- +Up to this point the type system of IR is quite loose so we can for example encounter situations +such as +``` +let y : obj := f x +``` +where `f : obj -> uint8`. It is the job of the boxing pass to enforce a stricter obj/scalar +separation and as such it needs to correct situations like this. +-/ +def tryCorrectVDeclType (ty : IRType) (e : Expr) : M IRType := + match e with + | .fap f _ => do + let decl ← getDecl f + return decl.resultType + | .pap .. => return .object + | .uproj .. => return .usize + | .ctor .. | .reuse .. | .ap .. | .lit .. | .sproj .. | .proj .. | .reset .. => + return ty + | .unbox .. | .box .. | .isShared .. => unreachable! + partial def visitFnBody : FnBody → M FnBody | .vdecl x t v b => do + let t ← tryCorrectVDeclType t v let b ← withVDecl x t v (visitFnBody b) visitVDeclExpr x t v b | .jdecl j xs v b => do diff --git a/tests/lean/run/boxing_bug.lean b/tests/lean/run/boxing_bug.lean new file mode 100644 index 0000000000..6490074f13 --- /dev/null +++ b/tests/lean/run/boxing_bug.lean @@ -0,0 +1,28 @@ +def myCast : NatCast UInt8 where + natCast := UInt8.ofNat + +class Semiring (α : Type u) where + [nsmul : SMul Nat α] + +/-- +trace: [Compiler.IR] [result] + def instSemiringUInt8._lam_0 (x_1 : @& tobj) (x_2 : u8) : u8 := + let x_3 : u8 := UInt8.ofNat x_1; + let x_4 : u8 := UInt8.mul x_3 x_2; + ret x_4 + def instSemiringUInt8 : obj := + let x_1 : obj := pap instSemiringUInt8._lam_0._boxed; + ret x_1 + def instSemiringUInt8._lam_0._boxed (x_1 : tobj) (x_2 : tagged) : tagged := + let x_3 : u8 := unbox x_2; + let x_4 : u8 := instSemiringUInt8._lam_0 x_1 x_3; + dec x_1; + let x_5 : tagged := box x_4; + ret x_5 +-/ +#guard_msgs in +set_option trace.compiler.ir.result true in +attribute [local instance] myCast UInt8.intCast in +instance : Semiring UInt8 where + nsmul := ⟨(· * ·)⟩ +