feat: cache output universe parameter positions (#12285)

This PR implements a cache for the positions of class universe level
parameters that only appear in output parameter types.

During type class resolution, the cache key for a query like
`HAppend.{0, 0, ?u} (BitVec 8) (BitVec 8) ?m` should be independent of
the specific metavariable IDs in output parameter positions. To achieve
this, output parameter arguments are erased from the cache key. However,
universe levels that only appear in output parameter types (e.g., `?u`
corresponding to the result type's universe) must also be erased to
avoid cache misses when the same query is issued with different universe
metavariable IDs.

This function identifies which universe level parameter positions are
"output-only" by collecting all level param names that appear in
non-output parameter domains, then returning the positions of any level
params not in that set.

**Remark**: This PR requires a manual update stage0 because it changes
the structure of our .olean files.
This commit is contained in:
Leonardo de Moura 2026-02-02 19:56:33 -08:00 committed by GitHub
parent 3ad3bacd97
commit 3deba604bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 821 additions and 204 deletions

View file

@ -4,12 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Attributes
import Lean.Util.CollectLevelParams
public section
namespace Lean
/-- An entry for the persistent environment extension for declared type classes -/
@ -17,14 +15,21 @@ structure ClassEntry where
/-- Class name. -/
name : Name
/--
Position of the class `outParams`.
For example, for class
```
class GetElem (cont : Type u) (idx : Type v) (elem : outParam (Type w)) (dom : outParam (cont → idx → Prop)) where
```
`outParams := #[2, 3]`
Position of the class `outParams`.
For example, for class
```
class GetElem (cont : Type u) (idx : Type v) (elem : outParam (Type w)) (dom : outParam (cont → idx → Prop)) where
```
`outParams := #[2, 3]`
-/
outParams : Array Nat
/--
Positions of universe level parameters that only appear in output parameter types.
For example, for `HAdd (α : Type u) (β : Type v) (γ : outParam (Type w))`,
`outLevelParams := #[2]` since universe `w` only appears in the output parameter `γ`.
This is used to normalize TC resolution cache keys.
-/
outLevelParams : Array Nat
namespace ClassEntry
@ -36,12 +41,15 @@ end ClassEntry
/-- State of the type class environment extension. -/
structure ClassState where
outParamMap : SMap Name (Array Nat) := SMap.empty
outLevelParamMap : SMap Name (Array Nat) := SMap.empty
deriving Inhabited
namespace ClassState
def addEntry (s : ClassState) (entry : ClassEntry) : ClassState :=
{ s with outParamMap := s.outParamMap.insert entry.name entry.outParams }
{ s with
outParamMap := s.outParamMap.insert entry.name entry.outParams
outLevelParamMap := s.outLevelParamMap.insert entry.name entry.outLevelParams }
/--
Switch the state into persistent mode. We switch to this mode after
@ -49,7 +57,9 @@ we read all imported .olean files.
Recall that we use a `SMap` for implementing the state of the type class environment extension.
-/
def switch (s : ClassState) : ClassState :=
{ s with outParamMap := s.outParamMap.switch }
{ s with
outParamMap := s.outParamMap.switch
outLevelParamMap := s.outLevelParamMap.switch }
end ClassState
@ -79,6 +89,10 @@ def hasOutParams (env : Environment) (declName : Name) : Bool :=
| some outParams => !outParams.isEmpty
| none => false
/-- If `declName` is a class, return the positions of universe level parameters that only appear in output parameter types. -/
def getOutLevelParamPositions? (env : Environment) (declName : Name) : Option (Array Nat) :=
(classExtension.getState env).outLevelParamMap.find? declName
/--
Auxiliary function for collection the position class `outParams`, and
checking whether they are being correctly used.
@ -146,6 +160,39 @@ where
keepBinderInfo ()
| _ => type
/--
Compute positions of universe level parameters that only appear in output parameter types.
During type class resolution, the cache key for a query like
`HAppend.{0, 0, ?u} (BitVec 8) (BitVec 8) ?m` should be independent of the specific
metavariable IDs in output parameter positions. To achieve this, output parameter arguments
are erased from the cache key. However, universe levels that only appear in output parameter
types (e.g., `?u` corresponding to the result type's universe) must also be erased to avoid
cache misses when the same query is issued with different universe metavariable IDs.
This function identifies which universe level parameter positions are "output-only" by
collecting all level param names that appear in non-output parameter domains, then returning
the positions of any level params not in that set.
-/
private partial def computeOutLevelParams (type : Expr) (outParams : Array Nat) (levelParams : List Name) : Array Nat := Id.run do
let nonOutLevels := go type 0 {} |>.params
let mut result := #[]
let mut i := 0
for name in levelParams do
unless nonOutLevels.contains name do
result := result.push i
i := i + 1
result
where
go (type : Expr) (i : Nat) (s : CollectLevelParams.State) : CollectLevelParams.State :=
match type with
| .forallE _ d b _ =>
if outParams.contains i then
go b (i + 1) s
else
go b (i + 1) (collectLevelParams s d)
| _ => s
/--
Add a new type class with the given name to the environment.
`declName` must not be the name of an existing type class,
@ -161,7 +208,8 @@ def addClass (env : Environment) (clsName : Name) : Except MessageData Environme
unless decl matches .inductInfo .. | .axiomInfo .. do
throw m!"invalid 'class', declaration '{.ofConstName clsName}' must be inductive datatype, structure, or constant"
let outParams ← checkOutParam 0 #[] #[] decl.type
return classExtension.addEntry env { name := clsName, outParams }
let outLevelParams := computeOutLevelParams decl.type outParams decl.levelParams
return classExtension.addEntry env { name := clsName, outParams, outLevelParams }
/--
Registers an inductive type or structure as a type class. Using `class` or `class inductive` is

File diff suppressed because it is too large Load diff