lean4-htt/src/Lean/Class.lean
Leonardo de Moura 3deba604bf
feat: cache output universe parameter positions (#12285)
This PR implements a cache for the positions of class universe level
parameters that only appear in output parameter types.

During type class resolution, the cache key for a query like
`HAppend.{0, 0, ?u} (BitVec 8) (BitVec 8) ?m` should be independent of
the specific metavariable IDs in output parameter positions. To achieve
this, output parameter arguments are erased from the cache key. However,
universe levels that only appear in output parameter types (e.g., `?u`
corresponding to the result type's universe) must also be erased to
avoid cache misses when the same query is issued with different universe
metavariable IDs.

The new function `computeOutLevelParams` identifies which universe level parameter positions are
"output-only" by collecting all level param names that appear in
non-output parameter domains, then returning the positions of any level
params not in that set.

**Remark**: This PR requires a manual update of stage0 because it changes
the structure of our .olean files.
2026-02-02 19:56:33 -08:00

231 lines
9 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Attributes
import Lean.Util.CollectLevelParams
public section
namespace Lean
/--
An entry for the persistent environment extension for declared type classes.
NOTE: adding or reordering fields changes the structure of the .olean files.
-/
structure ClassEntry where
  /-- Class name. -/
  name : Name
  /--
  Positions (0-based) of the class `outParams`.
  For example, for class
  ```
  class GetElem (cont : Type u) (idx : Type v) (elem : outParam (Type w)) (dom : outParam (cont → idx → Prop)) where
  ```
  `outParams := #[2, 3]`
  -/
  outParams : Array Nat
  /--
  Positions (0-based) of universe level parameters that only appear in output parameter types.
  For example, for `HAdd (α : Type u) (β : Type v) (γ : outParam (Type w))`,
  `outLevelParams := #[2]` since universe `w` only appears in the output parameter `γ`.
  This is used to normalize TC resolution cache keys.
  -/
  outLevelParams : Array Nat
namespace ClassEntry

/-- Compare two entries by their class names using `Name.quickLt`. -/
def lt (a b : ClassEntry) : Bool :=
  a.name.quickLt b.name

end ClassEntry
/-- State of the type class environment extension. -/
structure ClassState where
  /-- Maps each declared class to the positions of its `outParams`. Membership in this map is what `isClass` tests. -/
  outParamMap : SMap Name (Array Nat) := SMap.empty
  /-- Maps each declared class to the positions of its universe level parameters that only appear in output parameter types. -/
  outLevelParamMap : SMap Name (Array Nat) := SMap.empty
  deriving Inhabited
namespace ClassState

/-- Record `entry` in both maps, keyed by its class name. -/
def addEntry (s : ClassState) (entry : ClassEntry) : ClassState :=
  let ⟨clsName, outPositions, outLevelPositions⟩ := entry
  { s with
    outParamMap      := s.outParamMap.insert clsName outPositions
    outLevelParamMap := s.outLevelParamMap.insert clsName outLevelPositions }

/--
Move both maps of the state into persistent mode.
This happens once all imported .olean files have been read; both maps are `SMap`s,
which is why an explicit mode switch is needed.
-/
def switch (s : ClassState) : ClassState :=
  { s with outParamMap := s.outParamMap.switch, outLevelParamMap := s.outLevelParamMap.switch }

end ClassState
/--
Type class environment extension.
Imported entries are folded into a fresh `ClassState` with `ClassState.addEntry`,
and the resulting state is then switched to persistent mode.
-/
-- TODO: add support for scoped instances
builtin_initialize classExtension : SimplePersistentEnvExtension ClassEntry ClassState ←
  registerSimplePersistentEnvExtension {
    addEntryFn    := ClassState.addEntry
    addImportedFn := fun importedEntries =>
      ClassState.switch (mkStateFromImportedEntries ClassState.addEntry {} importedEntries)
  }
/--
Return `true` if `n` is the name of a type class in the given environment,
i.e. it has an entry in the class extension's `outParamMap`.
-/
@[export lean_is_class]
def isClass (env : Environment) (n : Name) : Bool :=
  let st := classExtension.getState env
  st.outParamMap.contains n
/-- If `declName` is a class, return the positions of its `outParams`; otherwise return `none`. -/
def getOutParamPositions? (env : Environment) (declName : Name) : Option (Array Nat) :=
  let st := classExtension.getState env
  st.outParamMap.find? declName
/--
Return `true` if the given `declName` is a type class with at least one output parameter.
Non-classes yield `false`.
-/
@[export lean_has_out_params]
def hasOutParams (env : Environment) (declName : Name) : Bool :=
  if let some positions := getOutParamPositions? env declName then
    !positions.isEmpty
  else
    false
/--
If `declName` is a class, return the positions of the universe level parameters that only
appear in output parameter types; otherwise return `none`.
-/
def getOutLevelParamPositions? (env : Environment) (declName : Name) : Option (Array Nat) :=
  let st := classExtension.getState env
  st.outLevelParamMap.find? declName
/--
Auxiliary function for collecting the positions of the class `outParams`, and
checking whether they are being correctly used.
A regular (i.e., non-`outParam`) parameter must not depend on an `outParam`.
Reason for this restriction:
When performing type class resolution, we replace arguments that
are `outParam`s with fresh metavariables. If regular parameters could
depend on `outParam`s, then we would also have to replace them with
fresh metavariables. Otherwise, the resulting expression could be type
incorrect. This transformation would be counterintuitive to users since
we would implicitly treat these regular parameters as `outParam`s.

Exception: an instance-implicit parameter whose type depends on an `outParam` is accepted
and is itself recorded as an `outParam` (see issue #1852).

`i` is the 0-based position of the binder being inspected, `outParamFVarIds` are the
placeholder free variables substituted for the `outParam` binders seen so far, and
`outParams` accumulates the positions found.
-/
private partial def checkOutParam (i : Nat) (outParamFVarIds : Array FVarId) (outParams : Array Nat) (type : Expr) : Except MessageData (Array Nat) :=
  match type with
  | .forallE _ dom body binfo =>
    -- Instance-implicit binders that depend on an `outParam` are promoted to `outParam`s (issue #1852).
    let treatAsOut :=
      dom.isOutParam || (binfo.isInstImplicit && dom.hasAnyFVar (outParamFVarIds.contains ·))
    if treatAsOut then
      -- Substitute a placeholder fvar so later binders can be checked for dependence on this one.
      let fvarId : FVarId := { name := Name.mkNum `_fvar outParamFVarIds.size }
      let body := body.instantiate1 (mkFVar fvarId)
      checkOutParam (i+1) (outParamFVarIds.push fvarId) (outParams.push i) body
    else if dom.hasAnyFVar (outParamFVarIds.contains ·) then
      Except.error m!"invalid class, parameter #{i+1} depends on `outParam`, but it is not an `outParam`"
    else
      checkOutParam (i+1) outParamFVarIds outParams body
  | _ => pure outParams
/--
Mark `outParam`s in `type` as implicit. Note that it also marks instance implicit arguments that depend on `outParam`s as implicit.
Remark: this function consumes the `outParam` annotations.
This function uses the same logic as `checkOutParam`, except that a non-instance-implicit
parameter depending on an `outParam` is left unchanged here instead of being reported as an error.
See issue #1901
-/
@[export lean_mk_outparam_args_implicit]
partial def mkOutParamArgsImplicit (type : Expr) : Expr :=
  go type type #[]
where
  /-
  Walks two copies of the same telescope in lockstep:
  * `type` is the expression being rebuilt; its binder bodies keep their loose bound variables.
  * `typeAux` has each `outParam`-like binder instantiated with a placeholder fvar, so that
    dependence on an `outParam` can be detected with `hasAnyFVar`.
  `outParamFVarIds` are the placeholder fvars introduced so far.
  -/
  go (type : Expr) (typeAux : Expr) (outParamFVarIds : Array FVarId) : Expr :=
    match typeAux with
    | .forallE _ d b bi =>
      -- Rebuild the current binder as `.implicit` with domain `dNew`, and track it as an `outParam`.
      let mkOutParamImplicit (dNew : Expr) :=
        let fvarId := { name := Name.mkNum `_fvar outParamFVarIds.size }
        let fvar := mkFVar fvarId
        let b := b.instantiate1 fvar
        let bNew := go type.bindingBody! b (outParamFVarIds.push fvarId)
        type.updateForall! .implicit dNew bNew
      -- Rebuild the current binder with its original domain, recursing into the body only.
      let keepBinderInfo (_ : Unit) :=
        let bNew := go type.bindingBody! b outParamFVarIds
        type.updateForallE! type.bindingDomain! bNew
      if d.isOutParam then
        mkOutParamImplicit type.bindingDomain!.appArg! -- consume `outParam` annotation
      else if d.hasAnyFVar fun fvarId => outParamFVarIds.contains fvarId then
        if bi.isInstImplicit then
          mkOutParamImplicit type.bindingDomain!
        else
          keepBinderInfo ()
      else
        keepBinderInfo ()
    | _ => type
/--
Compute positions of universe level parameters that only appear in output parameter types.

During type class resolution, the cache key for a query like
`HAppend.{0, 0, ?u} (BitVec 8) (BitVec 8) ?m` should be independent of the specific
metavariable IDs in output parameter positions. To achieve this, output parameter arguments
are erased from the cache key. However, universe levels that only appear in output parameter
types (e.g., `?u` corresponding to the result type's universe) must also be erased to avoid
cache misses when the same query is issued with different universe metavariable IDs.

This function identifies which universe level parameter positions are "output-only" by
collecting all level param names that appear in non-output parameter domains. It then returns,
in ascending order, the 0-based positions in `levelParams` of every level param not in that set.

Only binder domains are inspected. A level param that occurs solely in the class's result
sort (after the last binder) is therefore also reported as output-only.
NOTE(review): confirm this is intended for classes whose result sort mentions a level that
no parameter domain does.
-/
private partial def computeOutLevelParams (type : Expr) (outParams : Array Nat) (levelParams : List Name) : Array Nat :=
  let inputLevels := (collectInputLevels type 0 {}).params
  let (_, positions) := levelParams.foldl (init := ((0 : Nat), (#[] : Array Nat)))
    fun (pos, acc) lvl => (pos + 1, if inputLevels.contains lvl then acc else acc.push pos)
  positions
where
  /-- Accumulate the level params of every binder domain whose position is not in `outParams`. -/
  collectInputLevels (type : Expr) (pos : Nat) (s : CollectLevelParams.State) : CollectLevelParams.State :=
    match type with
    | .forallE _ dom body _ =>
      let s := if outParams.contains pos then s else collectLevelParams s dom
      collectInputLevels body (pos + 1) s
    | _ => s
/--
Add a new type class with the given name to the environment.
`clsName` must not be the name of an existing type class,
and it must be the name of a constant in `env`.
`clsName` must be an inductive datatype or an axiom.
Recall that all structures are inductive datatypes.
The new entry records both the `outParams` positions (validated by `checkOutParam`) and
the positions of the universe level parameters that only appear in output parameter types.
-/
def addClass (env : Environment) (clsName : Name) : Except MessageData Environment := do
  if isClass env clsName then
    throw m!"class has already been declared '{.ofConstName clsName true}'"
  let decl ← match env.find? clsName with
    | some decl => pure decl
    | none      => throw m!"unknown declaration '{clsName}'"
  match decl with
  | .inductInfo .. | .axiomInfo .. => pure ()
  | _ => throw m!"invalid 'class', declaration '{.ofConstName clsName}' must be inductive datatype, structure, or constant"
  let outParams ← checkOutParam 0 #[] #[] decl.type
  let entry : ClassEntry :=
    { name := clsName, outParams, outLevelParams := computeOutLevelParams decl.type outParams decl.levelParams }
  return classExtension.addEntry env entry
/--
Registers an inductive type or structure as a type class. Using `class` or `class inductive` is
generally preferred over using `@[class] structure` or `@[class] inductive` directly.
-/
@[builtin_init, builtin_doc]
private def init :=
  registerBuiltinAttribute {
    name  := `class
    descr := "type class"
    add   := fun decl stx kind => do
      let envBefore ← getEnv
      -- `@[class]` takes no arguments and may only be applied globally.
      Attribute.Builtin.ensureNoArgs stx
      if kind != AttributeKind.global then
        throwAttrMustBeGlobal `class kind
      setEnv (← ofExcept (addClass envBefore decl))
  }
end Lean