From 2864efb2227dba1cb949c36c18dc4bb11b204ab4 Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Tue, 1 Jul 2025 15:35:50 -0700 Subject: [PATCH] feat: support enums modulo irrelevance (#9144) This PR adds support for representing more inductive as enums, summarized up as extending support to those that fail to be enums because of parameters or irrelevant fields. While this is nice to have, it is actually motivated by correctness of a future desired optimization. The existing type representation is unsound if we implement `object`/`tobject` distinction between values guaranteed to be an object pointer and those that may also be a tagged scalar. In particular, types like the ones added in this PR's tests would have all of their constructors encoded via tagged values, but under the natural extension of the existing rules of type representation they would be considered `object` rather than `tobject`. --- src/Lean/Compiler/IR/ToIR.lean | 1 - src/Lean/Compiler/IR/ToIRType.lean | 20 ++++++++-- tests/lean/run/enumsModuloIrrelevance.lean | 44 ++++++++++++++++++++++ 3 files changed, 60 insertions(+), 5 deletions(-) create mode 100644 tests/lean/run/enumsModuloIrrelevance.lean diff --git a/src/Lean/Compiler/IR/ToIR.lean b/src/Lean/Compiler/IR/ToIR.lean index 9a7e640966..13a049be88 100644 --- a/src/Lean/Compiler/IR/ToIR.lean +++ b/src/Lean/Compiler/IR/ToIR.lean @@ -211,7 +211,6 @@ partial def lowerLet (decl : LCNF.LetDecl) (k : LCNF.Code) : M FnBody := do else let type ← nameToIRType ctorVal.induct if type.isScalar then - assert! args.isEmpty let var ← bindVar decl.fvarId return .vdecl var type (.lit (.num ctorVal.cidx)) (← lowerCode k) else diff --git a/src/Lean/Compiler/IR/ToIRType.lean b/src/Lean/Compiler/IR/ToIRType.lean index acfb6b859a..ea82bd150c 100644 --- a/src/Lean/Compiler/IR/ToIRType.lean +++ b/src/Lean/Compiler/IR/ToIRType.lean @@ -41,8 +41,16 @@ where fillCache : CoreM IRType := do let ctorNames := inductiveVal.ctors let numCtors := ctorNames.length for ctorName in ctorNames do - let some (.ctorInfo ctorVal) := env.find? ctorName | unreachable! - if ctorVal.type.isForall then return .object + let some (.ctorInfo ctorInfo) := env.find? ctorName | unreachable! + let isRelevant ← Meta.MetaM.run' <| + Meta.forallTelescopeReducing ctorInfo.type fun params _ => do + for field in params[ctorInfo.numParams...*] do + let fieldType ← field.fvarId!.getType + let lcnfFieldType ← LCNF.toLCNFType fieldType + let monoFieldType ← LCNF.toMonoType lcnfFieldType + if !monoFieldType.isErased then return true + return false + if isRelevant then return .object return if numCtors == 1 then .object else if numCtors < Nat.pow 2 8 then @@ -56,8 +64,12 @@ where fillCache : CoreM IRType := do def toIRType (type : Lean.Expr) : CoreM IRType := do match type with - | .const name _ | .app (.const name _) _ => nameToIRType name - | .app .. | .forallE .. => return .object + | .const name _ => nameToIRType name + | .app .. => + -- All mono types are in headBeta form. + let .const name _ := type.getAppFn | unreachable! + nameToIRType name + | .forallE .. => return .object | _ => unreachable! inductive CtorFieldInfo where diff --git a/tests/lean/run/enumsModuloIrrelevance.lean b/tests/lean/run/enumsModuloIrrelevance.lean new file mode 100644 index 0000000000..4e8011d7c2 --- /dev/null +++ b/tests/lean/run/enumsModuloIrrelevance.lean @@ -0,0 +1,44 @@ +inductive E1 (n : Nat) where + | a + | b + | c + +/-- +trace: [Compiler.IR] [result] + def e1 : u8 := + let x_1 : u8 := 2; + ret x_1 +-/ +#guard_msgs in +set_option trace.compiler.ir.result true in +def e1 : E1 7 := .c + +inductive E2 where + | a (p : 0 = 0) + | b (p : 1 = 1) + | c (p : 0 = 1) + +/-- +trace: [Compiler.IR] [result] + def e2 : u8 := + let x_1 : u8 := 1; + ret x_1 +-/ +#guard_msgs in +set_option trace.compiler.ir.result true in +def e2 : E2 := .b rfl + +inductive E3 (m n : Nat) where + | a (p : 0 = 0) + | b (p : 1 = 1) + | c (p : 0 = 0) (q : 1 = 1) + +/-- +trace: [Compiler.IR] [result] + def e3 : u8 := + let x_1 : u8 := 2; + ret x_1 +-/ +#guard_msgs in +set_option trace.compiler.ir.result true in +def e3 : E3 7 11 := .c rfl rfl