feat: structure parameter binder kind overrides (#7742)

This PR adds a feature to `structure`/`class` where binders without
types on a field definition are interpreted as overriding the type's
parameters binder kinds in that field's projection function. The rules
are (1) only a prefix of the binders are interpreted this way, (2)
multi-identifier binders are allowed but they must all be for
parameters, (3) only parameters that appear in the declaration itself
(not from `variables`) can be overridden and (4) the updates will be
applied after parameter binder kind inference is done. Binder updates
are not allowed in default value redefinitions. Example application: In
the following, `(R p)` causes the `R` and `p` parameters to be explicit,
where normally they would be implicit.
```
class CharP (R : Type u) [AddMonoidWithOne R] (p : Nat) : Prop where
  cast_eq_zero_iff (R p) : ∀ x : Nat, (x : R) = 0 ↔ p ∣ x


#guard_msgs in #check CharP.cast_eq_zero_iff
/-
info: CharP.cast_eq_zero_iff.{u} (R : Type u) {inst✝ : AddMonoidWithOne R} (p : Nat) [self : CharP R p] (x : Nat) :
  ↑x = 0 ↔ p ∣ x
-/
```
The rationale for (3) is that there are cases where a module starts with
a large `variables` list and a field only incidentally uses the binder.
Without the restriction, the field ends up depending on that variable,
counterintuitively causing it to be introduced as an additional
parameter for the type. Instead, there is an explicit error. The easy
fix is to add `: _`, which is the bare minimum to make the binder have a
type.

We should consider warning when binders shadow parameters.

Closes #3574

[Zulip
discussion](https://leanprover.zulipchat.com/#narrow/channel/270676-lean4/topic/RFC.3A.20adjust.20argument.20explicitness.20on.20typeclass.20projections/near/508584627)

Mathlib fixes:
https://github.com/leanprover-community/mathlib4/pull/23469
This commit is contained in:
Kyle Miller 2025-03-30 20:54:03 -07:00 committed by GitHub
parent e00dd3b25a
commit 5a50a8d278
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 223 additions and 23 deletions

View file

@ -160,6 +160,8 @@ structure StructFieldInfo where
declName : Name
/-- Binder info to use when making the constructor. Only applies to those fields that will appear in the constructor. -/
binfo : BinderInfo
/-- Overrides for the parameters' binder infos when making the projections. The first component is a ref for the binder. -/
paramInfoOverrides : ExprMap (Syntax × BinderInfo) := {}
/--
Structure names that are responsible for this field being here.
- Empty if the field is a `newField`.
@ -184,7 +186,7 @@ structure StructFieldInfo where
inheritedDefaults : Array (Name × StructFieldDefault) := #[]
/-- The default that will be used for this structure. -/
resolvedDefault? : Option StructFieldDefault := none
deriving Inhabited, Repr
deriving Inhabited
/-!
### View construction
@ -922,23 +924,58 @@ private def solveParentMVars (e : Expr) : StructElabM Expr := do
discard <| MVarId.checkedAssign mvar parentInfo.fvar
return e
private def elabFieldTypeValue (view : StructFieldView) : StructElabM (Option Expr × Option StructFieldDefault) := do
open Parser.Term in
private def typelessBinder? : Syntax → Option ((Array Ident) × BinderInfo)
| `(bracketedBinderF|($ids:ident*)) => some (ids, .default)
| `(bracketedBinderF|{$ids:ident*}) => some (ids, .implicit)
| `(bracketedBinderF|⦃$ids:ident*⦄) => some (ids, .strictImplicit)
| `(bracketedBinderF|[$id:ident]) => some (#[id], .instImplicit)
| _ => none
/--
Takes a binder list and interprets the prefix to see if any could be construed to be binder info updates.
Returns the binder list without these updates along with the new binder infos for these parameters.
-/
private def elabParamInfoUpdates (structParams : Array Expr) (binders : Array Syntax) : StructElabM (Array Syntax × ExprMap (Syntax × BinderInfo)) := do
let mut overrides : ExprMap (Syntax × BinderInfo) := {}
for i in [0:binders.size] do
match typelessBinder? binders[i]! with
| none => return (binders.extract i, overrides)
| some (ids, bi) =>
let lctx ← getLCtx
let decls := ids.filterMap fun id => lctx.findFromUserName? id.getId
-- Filter out all fields. We assume the remaining fvars are the possible parameters.
let decls ← decls.filterM fun decl => return (← findFieldInfoByFVarId? decl.fvarId).isNone
if decls.size != ids.size then
-- Then either these are for a new variables or the binder isn't only for parameters
return (binders.extract i, overrides)
for decl in decls, id in ids do
Term.addTermInfo' id decl.toExpr
unless structParams.contains decl.toExpr do
throwErrorAt id m!"only parameters appearing in the declaration header may have their binders kinds be overridden\n\n\
If this is not intended to be an override, use a binder with a type, for example '(x : _)'."
overrides := overrides.insert decl.toExpr (id, bi)
return (#[], overrides)
private def elabFieldTypeValue (structParams : Array Expr) (view : StructFieldView) :
StructElabM (Option Expr × ExprMap (Syntax × BinderInfo) × Option StructFieldDefault) := do
let state ← get
Term.withAutoBoundImplicit <| Term.withAutoBoundImplicitForbiddenPred (fun n => view.name == n) <| Term.elabBinders view.binders.getArgs fun params => do
let binders := view.binders.getArgs
let (binders, paramInfoOverrides) ← elabParamInfoUpdates structParams binders
Term.withAutoBoundImplicit <| Term.withAutoBoundImplicitForbiddenPred (fun n => view.name == n) <| Term.elabBinders binders fun params => do
match view.type? with
| none =>
| none =>
match view.default? with
| none => return (none, none)
| none => return (none, paramInfoOverrides, none)
| some (.optParam valStx) =>
Term.synthesizeSyntheticMVarsNoPostponing
-- TODO: add forbidden predicate using `shortDeclName` from `view`
let params ← Term.addAutoBoundImplicits params (view.nameId.getTailPos? (canonicalOnly := true))
let value ← Term.withoutAutoBoundImplicit <| Term.elabTerm valStx none
let value ← runStructElabM (init := state) <| solveParentMVars value
registerFailedToInferFieldType view.name (← inferType value) view.nameId
registerFailedToInferDefaultValue view.name value valStx
let value ← mkLambdaFVars params value
return (none, StructFieldDefault.optParam value)
return (none, paramInfoOverrides, StructFieldDefault.optParam value)
| some (.autoParam tacticStx) =>
throwErrorAt tacticStx "invalid field declaration, type must be provided when auto-param tactic is used"
| some typeStx =>
@ -948,9 +985,9 @@ private def elabFieldTypeValue (view : StructFieldView) : StructElabM (Option Ex
Term.synthesizeSyntheticMVarsNoPostponing
let params ← Term.addAutoBoundImplicits params (view.nameId.getTailPos? (canonicalOnly := true))
match view.default? with
| none =>
| none =>
let type ← mkForallFVars params type
return (type, none)
return (type, paramInfoOverrides, none)
| some (.optParam valStx) =>
let value ← Term.withoutAutoBoundImplicit <| Term.elabTermEnsuringType valStx type
let value ← runStructElabM (init := state) <| solveParentMVars value
@ -958,14 +995,14 @@ private def elabFieldTypeValue (view : StructFieldView) : StructElabM (Option Ex
Term.synthesizeSyntheticMVarsNoPostponing
let type ← mkForallFVars params type
let value ← mkLambdaFVars params value
return (type, StructFieldDefault.optParam value)
return (type, paramInfoOverrides, StructFieldDefault.optParam value)
| some (.autoParam tacticStx) =>
let name := mkAutoParamFnOfProjFn view.declName
discard <| Term.declareTacticSyntax tacticStx name
let type ← mkForallFVars params type
return (type, StructFieldDefault.autoParam <| .const name [])
return (type, paramInfoOverrides, StructFieldDefault.autoParam <| .const name [])
private partial def withFields (views : Array StructFieldView) (k : StructElabM α) : StructElabM α := do
private partial def withFields (structParams : Array Expr) (views : Array StructFieldView) (k : StructElabM α) : StructElabM α := do
go 0
where
go (i : Nat) : StructElabM α := do
@ -976,14 +1013,14 @@ where
throwError "field '{view.name}' has already been declared as a projection for parent '{.ofConstName parent.structName}'"
match ← findFieldInfo? view.name with
| none =>
let (type?, default?) ← elabFieldTypeValue view
let (type?, paramInfoOverrides, default?) ← elabFieldTypeValue structParams view
match type?, default? with
| none, none => throwError "invalid field, type expected"
| some type, _ =>
withLocalDecl view.rawName view.binderInfo type fun fieldFVar => do
addFieldInfo { ref := view.nameId, sourceStructNames := [],
name := view.name, declName := view.declName, fvar := fieldFVar, default? := default?,
binfo := view.binderInfo,
binfo := view.binderInfo, paramInfoOverrides,
kind := StructFieldKind.newField }
go (i+1)
| none, some (.optParam value) =>
@ -991,7 +1028,7 @@ where
withLocalDecl view.rawName view.binderInfo type fun fieldFVar => do
addFieldInfo { ref := view.nameId, sourceStructNames := [],
name := view.name, declName := view.declName, fvar := fieldFVar, default? := default?,
binfo := view.binderInfo,
binfo := view.binderInfo, paramInfoOverrides,
kind := StructFieldKind.newField }
go (i+1)
| none, some (.autoParam _) =>
@ -1007,8 +1044,12 @@ where
if info.default?.isSome then
throwError "field '{view.name}' new default value has already been set"
let mut valStx := valStx
if view.binders.getArgs.size > 0 then
valStx ← `(fun $(view.binders.getArgs)* => $valStx:term)
let (binders, paramInfoOverrides) ← elabParamInfoUpdates structParams view.binders.getArgs
unless paramInfoOverrides.isEmpty do
let params := MessageData.joinSep (paramInfoOverrides.toList.map (m!"{·.1}")) ", "
throwError "cannot override structure parameter binder kinds when overriding the default value: {params}"
if binders.size > 0 then
valStx ← `(fun $binders* => $valStx:term)
let fvarType ← inferType info.fvar
let value ← Term.elabTermEnsuringType valStx fvarType
registerFailedToInferDefaultValue view.name value valStx
@ -1160,11 +1201,16 @@ private partial def checkResultingUniversesForFields (fieldInfos : Array StructF
which is not less than or equal to the structure's resulting universe level{indentD u}"
throwErrorAt info.ref msg
private def addProjections (r : ElabHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
let projDecls : Array StructProjDecl :=
private def addProjections (params : Array Expr) (r : ElabHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do
let projDecls : Array StructProjDecl
fieldInfos
|>.filter (·.kind.isInCtor)
|>.map (fun info => { ref := info.ref, projName := info.declName })
|>.mapM (fun info => do
info.paramInfoOverrides.forM fun p (ref, _) => do
unless params.contains p do
throwErrorAt ref "invalid parameter binder update, not a parameter"
let paramInfoOverrides := params |>.map (fun param => info.paramInfoOverrides[param]?.map Prod.snd) |>.toList
return { ref := info.ref, projName := info.declName, paramInfoOverrides })
mkProjections r.view.declName projDecls r.view.isClass
for fieldInfo in fieldInfos do
if fieldInfo.kind.isSubobject then
@ -1412,7 +1458,7 @@ def elabStructureCommand : InductiveElabDescr where
view := view.toInductiveView
elabCtors := fun rs r params => runStructElabM do
withParents view rs r.indFVar do
withFields view.fields do
withFields params view.fields do
withRef view.ref do
Term.synthesizeSyntheticMVarsNoPostponing
resolveFieldDefaults view.declName
@ -1429,7 +1475,7 @@ def elabStructureCommand : InductiveElabDescr where
finalizeTermElab := withLCtx lctx localInsts do checkDefaults fieldInfos
prefinalize := fun levelParams params replaceIndFVars => do
withLCtx lctx localInsts do
addProjections r fieldInfos
addProjections params r fieldInfos
registerStructure view.declName fieldInfos
runStructElabM (init := state) do
mkFlatCtor levelParams params view.declName replaceIndFVars

View file

@ -1320,6 +1320,17 @@ def inferImplicit (e : Expr) (numParams : Nat) (considerRange : Bool) : Expr :=
mkForall n newInfo d b
| e, _ => e
/--
Uses `newBinderInfos` to update the binder infos of the first `numParams` foralls.
-/
def updateForallBinderInfos (e : Expr) (binderInfos? : List (Option BinderInfo)) : Expr :=
match e, binderInfos? with
| Expr.forallE n d b bi, newBi? :: binderInfos? =>
let b := updateForallBinderInfos b binderInfos?
let bi := newBi?.getD bi
Expr.forallE n d b bi
| e, _ => e
/--
Instantiates the loose bound variables in `e` using the `subst` array,
where a loose `Expr.bvar i` at "binding depth" `d` is instantiated with `subst[i - d]` if `0 <= i - d < subst.size`,

View file

@ -32,6 +32,8 @@ Structure projection declaration for `mkProjections`.
structure StructProjDecl where
ref : Syntax
projName : Name
/-- Overrides to param binders to apply after param binder info inference. -/
paramInfoOverrides : List (Option BinderInfo) := []
/--
Adds projection functions to the environment for the one-constructor inductive type named `n`.
@ -73,11 +75,15 @@ def mkProjections (n : Name) (projDecls : Array StructProjDecl) (instImplicit :
-- Construct the projection functions:
let mut ctorType := ctorType
for h : i in [0:projDecls.size] do
let {ref, projName} := projDecls[i]
let {ref, projName, paramInfoOverrides} := projDecls[i]
unless ctorType.isForall do
throwErrorAt ref "\
failed to generate projection '{projName}' for '{.ofConstName n}', \
not enough constructor fields"
unless paramInfoOverrides.length ≤ params.size do
throwErrorAt ref "\
failed to generate projection '{projName}' for '{.ofConstName n}', \
too many structure parameter overrides"
let resultType := ctorType.bindingDomain!.consumeTypeAnnotations
let isProp ← isProp resultType
if isPredicate && !isProp then
@ -87,6 +93,7 @@ def mkProjections (n : Name) (projDecls : Array StructProjDecl) (instImplicit :
{indentExpr resultType}"
let projType := lctx.mkForall projArgs resultType
let projType := projType.inferImplicit indVal.numParams (considerRange := true)
let projType := projType.updateForallBinderInfos paramInfoOverrides
let projVal := lctx.mkLambda projArgs <| Expr.proj n i self
let cval : ConstantVal := { name := projName, levelParams := indVal.levelParams, type := projType }
withRef ref do

View file

@ -0,0 +1,136 @@
/-!
# Tests of structure parameter binder updates
-/
/-!
Motivating issue: https://github.com/leanprover/lean4/issues/3574
Normally one defines a `cast_eq_zero_iff'` field and restates a `cast_eq_zero_iff` version.
-/
namespace Issue3574
class AddMonoidWithOne (R : Type u) extends Add R, Zero R where
natCast : Nat → R
instance [AddMonoidWithOne R] : Coe Nat R where
coe := AddMonoidWithOne.natCast
attribute [coe] AddMonoidWithOne.natCast
class CharP [AddMonoidWithOne R] (p : Nat) : Prop where
cast_eq_zero_iff (R) (p) : ∀ x : Nat, (x : R) = 0 ↔ p x
-- Both `R` and `p` are explicit now.
/--
info: Issue3574.CharP.cast_eq_zero_iff.{u_1} (R : Type u_1) {inst✝ : AddMonoidWithOne R} (p : Nat) [self : CharP p]
(x : Nat) : ↑x = 0 ↔ p x
-/
#guard_msgs in #check CharP.cast_eq_zero_iff
-- Multiple parameters can be updated at once.
class CharP' [AddMonoidWithOne R] (p : Nat) : Prop where
cast_eq_zero_iff (R p) : ∀ x : Nat, (x : R) = 0 ↔ p x
/--
info: Issue3574.CharP'.cast_eq_zero_iff.{u_1} (R : Type u_1) {inst✝ : AddMonoidWithOne R} (p : Nat) [self : CharP' p]
(x : Nat) : ↑x = 0 ↔ p x
-/
#guard_msgs in #check CharP'.cast_eq_zero_iff
end Issue3574
/-!
Basic test for structures.
-/
namespace Ex1
structure Inhabited (α : Type) where
default : α
/-- info: Ex1.Inhabited.default {α : Type} (self : Inhabited α) : α -/
#guard_msgs in #check Inhabited.default
structure Inhabited' (α : Type) where
default (α) : α
/-- info: Ex1.Inhabited'.default (α : Type) (self : Inhabited' α) : α -/
#guard_msgs in #check Inhabited'.default
end Ex1
/-!
Basic test for classes.
-/
namespace Ex2
class Inhabited (α : Type) where
default : α
/-- info: Ex2.Inhabited.default {α : Type} [self : Inhabited α] : α -/
#guard_msgs in #check Inhabited.default
class Inhabited' (α : Type) where
default (α) : α
/-- info: Ex2.Inhabited'.default (α : Type) [self : Inhabited' α] : α -/
#guard_msgs in #check Inhabited'.default
end Ex2
/-!
Example with a parameter from a `variable`
-/
namespace Ex3
class Inhabited (α : Type) where
default : α
/-- info: Ex3.Inhabited.default {α : Type} [self : Inhabited α] : α -/
#guard_msgs in #check Inhabited.default
class Inhabited' (α : Type) where
default (α) : α
/-- info: Ex3.Inhabited'.default (α : Type) [self : Inhabited' α] : α -/
#guard_msgs in #check Inhabited'.default
end Ex3
/-!
Trying to set a `variable` binder kind; only parameters in the declaration itself can be overridden.
Rationale: we found in mathlib that often users had large binder lists declared at the beginning of files,
and the structure fields accidentally were shadowing them.
-/
namespace Ex4
variable (α : Type)
/--
error: only parameters appearing in the declaration header may have their binders kinds be overridden
If this is not intended to be an override, use a binder with a type, for example '(x : _)'.
-/
#guard_msgs in
class Inhabited where
default (α) : α
end Ex4
/-!
Here, `(α β)` is not an override since `β` is not an existing parameter, so `α` is treated as a binder.
-/
namespace Ex5
/-- error: failed to infer binder type -/
#guard_msgs in
class C (α : Type) where
f (α β) : β
end Ex5
/-!
Here, `(α β)` is not an override since `β` is a field.
-/
namespace Ex6
/-- error: failed to infer binder type -/
#guard_msgs in
class C (α : Type) where
β : Type
f (α β) : β
end Ex6