feat: cache output universe parameter positions (#12285)
This PR implements a cache for the positions of class universe level
parameters that only appear in output parameter types.
During type class resolution, the cache key for a query like
`HAppend.{0, 0, ?u} (BitVec 8) (BitVec 8) ?m` should be independent of
the specific metavariable IDs in output parameter positions. To achieve
this, output parameter arguments are erased from the cache key. However,
universe levels that only appear in output parameter types (e.g., `?u`
corresponding to the result type's universe) must also be erased to
avoid cache misses when the same query is issued with different universe
metavariable IDs.
This function identifies which universe level parameter positions are
"output-only" by collecting all level param names that appear in
non-output parameter domains, then returning the positions of any level
params not in that set.
**Remark**: This PR requires a manual update stage0 because it changes
the structure of our .olean files.
This commit is contained in:
parent
3ad3bacd97
commit
3deba604bf
2 changed files with 821 additions and 204 deletions
|
|
@ -4,12 +4,10 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Attributes
|
||||
|
||||
import Lean.Util.CollectLevelParams
|
||||
public section
|
||||
|
||||
namespace Lean
|
||||
|
||||
/-- An entry for the persistent environment extension for declared type classes -/
|
||||
|
|
@ -17,14 +15,21 @@ structure ClassEntry where
|
|||
/-- Class name. -/
|
||||
name : Name
|
||||
/--
|
||||
Position of the class `outParams`.
|
||||
For example, for class
|
||||
```
|
||||
class GetElem (cont : Type u) (idx : Type v) (elem : outParam (Type w)) (dom : outParam (cont → idx → Prop)) where
|
||||
```
|
||||
`outParams := #[2, 3]`
|
||||
Position of the class `outParams`.
|
||||
For example, for class
|
||||
```
|
||||
class GetElem (cont : Type u) (idx : Type v) (elem : outParam (Type w)) (dom : outParam (cont → idx → Prop)) where
|
||||
```
|
||||
`outParams := #[2, 3]`
|
||||
-/
|
||||
outParams : Array Nat
|
||||
/--
|
||||
Positions of universe level parameters that only appear in output parameter types.
|
||||
For example, for `HAdd (α : Type u) (β : Type v) (γ : outParam (Type w))`,
|
||||
`outLevelParams := #[2]` since universe `w` only appears in the output parameter `γ`.
|
||||
This is used to normalize TC resolution cache keys.
|
||||
-/
|
||||
outLevelParams : Array Nat
|
||||
|
||||
namespace ClassEntry
|
||||
|
||||
|
|
@ -36,12 +41,15 @@ end ClassEntry
|
|||
/-- State of the type class environment extension. -/
|
||||
structure ClassState where
|
||||
outParamMap : SMap Name (Array Nat) := SMap.empty
|
||||
outLevelParamMap : SMap Name (Array Nat) := SMap.empty
|
||||
deriving Inhabited
|
||||
|
||||
namespace ClassState
|
||||
|
||||
def addEntry (s : ClassState) (entry : ClassEntry) : ClassState :=
|
||||
{ s with outParamMap := s.outParamMap.insert entry.name entry.outParams }
|
||||
{ s with
|
||||
outParamMap := s.outParamMap.insert entry.name entry.outParams
|
||||
outLevelParamMap := s.outLevelParamMap.insert entry.name entry.outLevelParams }
|
||||
|
||||
/--
|
||||
Switch the state into persistent mode. We switch to this mode after
|
||||
|
|
@ -49,7 +57,9 @@ we read all imported .olean files.
|
|||
Recall that we use a `SMap` for implementing the state of the type class environment extension.
|
||||
-/
|
||||
def switch (s : ClassState) : ClassState :=
|
||||
{ s with outParamMap := s.outParamMap.switch }
|
||||
{ s with
|
||||
outParamMap := s.outParamMap.switch
|
||||
outLevelParamMap := s.outLevelParamMap.switch }
|
||||
|
||||
end ClassState
|
||||
|
||||
|
|
@ -79,6 +89,10 @@ def hasOutParams (env : Environment) (declName : Name) : Bool :=
|
|||
| some outParams => !outParams.isEmpty
|
||||
| none => false
|
||||
|
||||
/-- If `declName` is a class, return the positions of universe level parameters that only appear in output parameter types. -/
|
||||
def getOutLevelParamPositions? (env : Environment) (declName : Name) : Option (Array Nat) :=
|
||||
(classExtension.getState env).outLevelParamMap.find? declName
|
||||
|
||||
/--
|
||||
Auxiliary function for collection the position class `outParams`, and
|
||||
checking whether they are being correctly used.
|
||||
|
|
@ -146,6 +160,39 @@ where
|
|||
keepBinderInfo ()
|
||||
| _ => type
|
||||
|
||||
/--
|
||||
Compute positions of universe level parameters that only appear in output parameter types.
|
||||
|
||||
During type class resolution, the cache key for a query like
|
||||
`HAppend.{0, 0, ?u} (BitVec 8) (BitVec 8) ?m` should be independent of the specific
|
||||
metavariable IDs in output parameter positions. To achieve this, output parameter arguments
|
||||
are erased from the cache key. However, universe levels that only appear in output parameter
|
||||
types (e.g., `?u` corresponding to the result type's universe) must also be erased to avoid
|
||||
cache misses when the same query is issued with different universe metavariable IDs.
|
||||
|
||||
This function identifies which universe level parameter positions are "output-only" by
|
||||
collecting all level param names that appear in non-output parameter domains, then returning
|
||||
the positions of any level params not in that set.
|
||||
-/
|
||||
private partial def computeOutLevelParams (type : Expr) (outParams : Array Nat) (levelParams : List Name) : Array Nat := Id.run do
|
||||
let nonOutLevels := go type 0 {} |>.params
|
||||
let mut result := #[]
|
||||
let mut i := 0
|
||||
for name in levelParams do
|
||||
unless nonOutLevels.contains name do
|
||||
result := result.push i
|
||||
i := i + 1
|
||||
result
|
||||
where
|
||||
go (type : Expr) (i : Nat) (s : CollectLevelParams.State) : CollectLevelParams.State :=
|
||||
match type with
|
||||
| .forallE _ d b _ =>
|
||||
if outParams.contains i then
|
||||
go b (i + 1) s
|
||||
else
|
||||
go b (i + 1) (collectLevelParams s d)
|
||||
| _ => s
|
||||
|
||||
/--
|
||||
Add a new type class with the given name to the environment.
|
||||
`declName` must not be the name of an existing type class,
|
||||
|
|
@ -161,7 +208,8 @@ def addClass (env : Environment) (clsName : Name) : Except MessageData Environme
|
|||
unless decl matches .inductInfo .. | .axiomInfo .. do
|
||||
throw m!"invalid 'class', declaration '{.ofConstName clsName}' must be inductive datatype, structure, or constant"
|
||||
let outParams ← checkOutParam 0 #[] #[] decl.type
|
||||
return classExtension.addEntry env { name := clsName, outParams }
|
||||
let outLevelParams := computeOutLevelParams decl.type outParams decl.levelParams
|
||||
return classExtension.addEntry env { name := clsName, outParams, outLevelParams }
|
||||
|
||||
/--
|
||||
Registers an inductive type or structure as a type class. Using `class` or `class inductive` is
|
||||
|
|
|
|||
953
stage0/stdlib/Lean/Class.c
generated
953
stage0/stdlib/Lean/Class.c
generated
File diff suppressed because it is too large
Load diff
Loading…
Add table
Reference in a new issue