diff --git a/src/Lean/Elab/Structure.lean b/src/Lean/Elab/Structure.lean index 1e04b4c59a..ac4f84c5da 100644 --- a/src/Lean/Elab/Structure.lean +++ b/src/Lean/Elab/Structure.lean @@ -160,6 +160,8 @@ structure StructFieldInfo where declName : Name /-- Binder info to use when making the constructor. Only applies to those fields that will appear in the constructor. -/ binfo : BinderInfo + /-- Overrides for the parameters' binder infos when making the projections. The first component is a ref for the binder. -/ + paramInfoOverrides : ExprMap (Syntax × BinderInfo) := {} /-- Structure names that are responsible for this field being here. - Empty if the field is a `newField`. @@ -184,7 +186,7 @@ structure StructFieldInfo where inheritedDefaults : Array (Name × StructFieldDefault) := #[] /-- The default that will be used for this structure. -/ resolvedDefault? : Option StructFieldDefault := none - deriving Inhabited, Repr + deriving Inhabited /-! ### View construction @@ -922,23 +924,58 @@ private def solveParentMVars (e : Expr) : StructElabM Expr := do discard <| MVarId.checkedAssign mvar parentInfo.fvar return e -private def elabFieldTypeValue (view : StructFieldView) : StructElabM (Option Expr × Option StructFieldDefault) := do +open Parser.Term in +private def typelessBinder? : Syntax → Option ((Array Ident) × BinderInfo) + | `(bracketedBinderF|($ids:ident*)) => some (ids, .default) + | `(bracketedBinderF|{$ids:ident*}) => some (ids, .implicit) + | `(bracketedBinderF|⦃$ids:ident*⦄) => some (ids, .strictImplicit) + | `(bracketedBinderF|[$id:ident]) => some (#[id], .instImplicit) + | _ => none + +/-- +Takes a binder list and interprets the prefix to see if any could be construed to be binder info updates. +Returns the binder list without these updates along with the new binder infos for these parameters. +-/ +private def elabParamInfoUpdates (structParams : Array Expr) (binders : Array Syntax) : StructElabM (Array Syntax × ExprMap (Syntax × BinderInfo)) := do + let mut overrides : ExprMap (Syntax × BinderInfo) := {} + for i in [0:binders.size] do + match typelessBinder? binders[i]! with + | none => return (binders.extract i, overrides) + | some (ids, bi) => + let lctx ← getLCtx + let decls := ids.filterMap fun id => lctx.findFromUserName? id.getId + -- Filter out all fields. We assume the remaining fvars are the possible parameters. + let decls ← decls.filterM fun decl => return (← findFieldInfoByFVarId? decl.fvarId).isNone + if decls.size != ids.size then + -- Then either these are for a new variables or the binder isn't only for parameters + return (binders.extract i, overrides) + for decl in decls, id in ids do + Term.addTermInfo' id decl.toExpr + unless structParams.contains decl.toExpr do + throwErrorAt id m!"only parameters appearing in the declaration header may have their binders kinds be overridden\n\n\ + If this is not intended to be an override, use a binder with a type, for example '(x : _)'." + overrides := overrides.insert decl.toExpr (id, bi) + return (#[], overrides) + +private def elabFieldTypeValue (structParams : Array Expr) (view : StructFieldView) : + StructElabM (Option Expr × ExprMap (Syntax × BinderInfo) × Option StructFieldDefault) := do let state ← get - Term.withAutoBoundImplicit <| Term.withAutoBoundImplicitForbiddenPred (fun n => view.name == n) <| Term.elabBinders view.binders.getArgs fun params => do + let binders := view.binders.getArgs + let (binders, paramInfoOverrides) ← elabParamInfoUpdates structParams binders + Term.withAutoBoundImplicit <| Term.withAutoBoundImplicitForbiddenPred (fun n => view.name == n) <| Term.elabBinders binders fun params => do match view.type? with - | none => + | none => match view.default? with - | none => return (none, none) + | none => return (none, paramInfoOverrides, none) | some (.optParam valStx) => Term.synthesizeSyntheticMVarsNoPostponing - -- TODO: add forbidden predicate using `shortDeclName` from `view` let params ← Term.addAutoBoundImplicits params (view.nameId.getTailPos? (canonicalOnly := true)) let value ← Term.withoutAutoBoundImplicit <| Term.elabTerm valStx none let value ← runStructElabM (init := state) <| solveParentMVars value registerFailedToInferFieldType view.name (← inferType value) view.nameId registerFailedToInferDefaultValue view.name value valStx let value ← mkLambdaFVars params value - return (none, StructFieldDefault.optParam value) + return (none, paramInfoOverrides, StructFieldDefault.optParam value) | some (.autoParam tacticStx) => throwErrorAt tacticStx "invalid field declaration, type must be provided when auto-param tactic is used" | some typeStx => @@ -948,9 +985,9 @@ private def elabFieldTypeValue (view : StructFieldView) : StructElabM (Option Ex Term.synthesizeSyntheticMVarsNoPostponing let params ← Term.addAutoBoundImplicits params (view.nameId.getTailPos? (canonicalOnly := true)) match view.default? with - | none => + | none => let type ← mkForallFVars params type - return (type, none) + return (type, paramInfoOverrides, none) | some (.optParam valStx) => let value ← Term.withoutAutoBoundImplicit <| Term.elabTermEnsuringType valStx type let value ← runStructElabM (init := state) <| solveParentMVars value @@ -958,14 +995,14 @@ private def elabFieldTypeValue (view : StructFieldView) : StructElabM (Option Ex Term.synthesizeSyntheticMVarsNoPostponing let type ← mkForallFVars params type let value ← mkLambdaFVars params value - return (type, StructFieldDefault.optParam value) + return (type, paramInfoOverrides, StructFieldDefault.optParam value) | some (.autoParam tacticStx) => let name := mkAutoParamFnOfProjFn view.declName discard <| Term.declareTacticSyntax tacticStx name let type ← mkForallFVars params type - return (type, StructFieldDefault.autoParam <| .const name []) + return (type, paramInfoOverrides, StructFieldDefault.autoParam <| .const name []) -private partial def withFields (views : Array StructFieldView) (k : StructElabM α) : StructElabM α := do +private partial def withFields (structParams : Array Expr) (views : Array StructFieldView) (k : StructElabM α) : StructElabM α := do go 0 where go (i : Nat) : StructElabM α := do @@ -976,14 +1013,14 @@ where throwError "field '{view.name}' has already been declared as a projection for parent '{.ofConstName parent.structName}'" match ← findFieldInfo? view.name with | none => - let (type?, default?) ← elabFieldTypeValue view + let (type?, paramInfoOverrides, default?) ← elabFieldTypeValue structParams view match type?, default? with | none, none => throwError "invalid field, type expected" | some type, _ => withLocalDecl view.rawName view.binderInfo type fun fieldFVar => do addFieldInfo { ref := view.nameId, sourceStructNames := [], name := view.name, declName := view.declName, fvar := fieldFVar, default? := default?, - binfo := view.binderInfo, + binfo := view.binderInfo, paramInfoOverrides, kind := StructFieldKind.newField } go (i+1) | none, some (.optParam value) => @@ -991,7 +1028,7 @@ where withLocalDecl view.rawName view.binderInfo type fun fieldFVar => do addFieldInfo { ref := view.nameId, sourceStructNames := [], name := view.name, declName := view.declName, fvar := fieldFVar, default? := default?, - binfo := view.binderInfo, + binfo := view.binderInfo, paramInfoOverrides, kind := StructFieldKind.newField } go (i+1) | none, some (.autoParam _) => @@ -1007,8 +1044,12 @@ where if info.default?.isSome then throwError "field '{view.name}' new default value has already been set" let mut valStx := valStx - if view.binders.getArgs.size > 0 then - valStx ← `(fun $(view.binders.getArgs)* => $valStx:term) + let (binders, paramInfoOverrides) ← elabParamInfoUpdates structParams view.binders.getArgs + unless paramInfoOverrides.isEmpty do + let params := MessageData.joinSep (paramInfoOverrides.toList.map (m!"{·.1}")) ", " + throwError "cannot override structure parameter binder kinds when overriding the default value: {params}" + if binders.size > 0 then + valStx ← `(fun $binders* => $valStx:term) let fvarType ← inferType info.fvar let value ← Term.elabTermEnsuringType valStx fvarType registerFailedToInferDefaultValue view.name value valStx @@ -1160,11 +1201,16 @@ private partial def checkResultingUniversesForFields (fieldInfos : Array StructF which is not less than or equal to the structure's resulting universe level{indentD u}" throwErrorAt info.ref msg -private def addProjections (r : ElabHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do - let projDecls : Array StructProjDecl := +private def addProjections (params : Array Expr) (r : ElabHeaderResult) (fieldInfos : Array StructFieldInfo) : TermElabM Unit := do + let projDecls : Array StructProjDecl ← fieldInfos |>.filter (·.kind.isInCtor) - |>.map (fun info => { ref := info.ref, projName := info.declName }) + |>.mapM (fun info => do + info.paramInfoOverrides.forM fun p (ref, _) => do + unless params.contains p do + throwErrorAt ref "invalid parameter binder update, not a parameter" + let paramInfoOverrides := params |>.map (fun param => info.paramInfoOverrides[param]?.map Prod.snd) |>.toList + return { ref := info.ref, projName := info.declName, paramInfoOverrides }) mkProjections r.view.declName projDecls r.view.isClass for fieldInfo in fieldInfos do if fieldInfo.kind.isSubobject then @@ -1412,7 +1458,7 @@ def elabStructureCommand : InductiveElabDescr where view := view.toInductiveView elabCtors := fun rs r params => runStructElabM do withParents view rs r.indFVar do - withFields view.fields do + withFields params view.fields do withRef view.ref do Term.synthesizeSyntheticMVarsNoPostponing resolveFieldDefaults view.declName @@ -1429,7 +1475,7 @@ def elabStructureCommand : InductiveElabDescr where finalizeTermElab := withLCtx lctx localInsts do checkDefaults fieldInfos prefinalize := fun levelParams params replaceIndFVars => do withLCtx lctx localInsts do - addProjections r fieldInfos + addProjections params r fieldInfos registerStructure view.declName fieldInfos runStructElabM (init := state) do mkFlatCtor levelParams params view.declName replaceIndFVars diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index 2258c3cb1e..301128cafa 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -1320,6 +1320,17 @@ def inferImplicit (e : Expr) (numParams : Nat) (considerRange : Bool) : Expr := mkForall n newInfo d b | e, _ => e +/-- +Uses `newBinderInfos` to update the binder infos of the first `numParams` foralls. +-/ +def updateForallBinderInfos (e : Expr) (binderInfos? : List (Option BinderInfo)) : Expr := + match e, binderInfos? with + | Expr.forallE n d b bi, newBi? :: binderInfos? => + let b := updateForallBinderInfos b binderInfos? + let bi := newBi?.getD bi + Expr.forallE n d b bi + | e, _ => e + /-- Instantiates the loose bound variables in `e` using the `subst` array, where a loose `Expr.bvar i` at "binding depth" `d` is instantiated with `subst[i - d]` if `0 <= i - d < subst.size`, diff --git a/src/Lean/Meta/Structure.lean b/src/Lean/Meta/Structure.lean index 348ffb785e..1fb876188f 100644 --- a/src/Lean/Meta/Structure.lean +++ b/src/Lean/Meta/Structure.lean @@ -32,6 +32,8 @@ Structure projection declaration for `mkProjections`. structure StructProjDecl where ref : Syntax projName : Name + /-- Overrides to param binders to apply after param binder info inference. -/ + paramInfoOverrides : List (Option BinderInfo) := [] /-- Adds projection functions to the environment for the one-constructor inductive type named `n`. @@ -73,11 +75,15 @@ def mkProjections (n : Name) (projDecls : Array StructProjDecl) (instImplicit : -- Construct the projection functions: let mut ctorType := ctorType for h : i in [0:projDecls.size] do - let {ref, projName} := projDecls[i] + let {ref, projName, paramInfoOverrides} := projDecls[i] unless ctorType.isForall do throwErrorAt ref "\ failed to generate projection '{projName}' for '{.ofConstName n}', \ not enough constructor fields" + unless paramInfoOverrides.length ≤ params.size do + throwErrorAt ref "\ + failed to generate projection '{projName}' for '{.ofConstName n}', \ + too many structure parameter overrides" let resultType := ctorType.bindingDomain!.consumeTypeAnnotations let isProp ← isProp resultType if isPredicate && !isProp then @@ -87,6 +93,7 @@ def mkProjections (n : Name) (projDecls : Array StructProjDecl) (instImplicit : {indentExpr resultType}" let projType := lctx.mkForall projArgs resultType let projType := projType.inferImplicit indVal.numParams (considerRange := true) + let projType := projType.updateForallBinderInfos paramInfoOverrides let projVal := lctx.mkLambda projArgs <| Expr.proj n i self let cval : ConstantVal := { name := projName, levelParams := indVal.levelParams, type := projType } withRef ref do diff --git a/tests/lean/run/structBinderUpdates.lean b/tests/lean/run/structBinderUpdates.lean new file mode 100644 index 0000000000..088f3e072d --- /dev/null +++ b/tests/lean/run/structBinderUpdates.lean @@ -0,0 +1,136 @@ +/-! +# Tests of structure parameter binder updates +-/ + +/-! +Motivating issue: https://github.com/leanprover/lean4/issues/3574 +Normally one defines a `cast_eq_zero_iff'` field and restates a `cast_eq_zero_iff` version. +-/ +namespace Issue3574 + +class AddMonoidWithOne (R : Type u) extends Add R, Zero R where + natCast : Nat → R + +instance [AddMonoidWithOne R] : Coe Nat R where + coe := AddMonoidWithOne.natCast +attribute [coe] AddMonoidWithOne.natCast + +class CharP [AddMonoidWithOne R] (p : Nat) : Prop where + cast_eq_zero_iff (R) (p) : ∀ x : Nat, (x : R) = 0 ↔ p ∣ x + +-- Both `R` and `p` are explicit now. +/-- +info: Issue3574.CharP.cast_eq_zero_iff.{u_1} (R : Type u_1) {inst✝ : AddMonoidWithOne R} (p : Nat) [self : CharP p] + (x : Nat) : ↑x = 0 ↔ p ∣ x +-/ +#guard_msgs in #check CharP.cast_eq_zero_iff + +-- Multiple parameters can be updated at once. +class CharP' [AddMonoidWithOne R] (p : Nat) : Prop where + cast_eq_zero_iff (R p) : ∀ x : Nat, (x : R) = 0 ↔ p ∣ x + +/-- +info: Issue3574.CharP'.cast_eq_zero_iff.{u_1} (R : Type u_1) {inst✝ : AddMonoidWithOne R} (p : Nat) [self : CharP' p] + (x : Nat) : ↑x = 0 ↔ p ∣ x +-/ +#guard_msgs in #check CharP'.cast_eq_zero_iff + +end Issue3574 + +/-! +Basic test for structures. +-/ +namespace Ex1 + +structure Inhabited (α : Type) where + default : α + +/-- info: Ex1.Inhabited.default {α : Type} (self : Inhabited α) : α -/ +#guard_msgs in #check Inhabited.default + +structure Inhabited' (α : Type) where + default (α) : α + +/-- info: Ex1.Inhabited'.default (α : Type) (self : Inhabited' α) : α -/ +#guard_msgs in #check Inhabited'.default + +end Ex1 + +/-! +Basic test for classes. +-/ +namespace Ex2 + +class Inhabited (α : Type) where + default : α + +/-- info: Ex2.Inhabited.default {α : Type} [self : Inhabited α] : α -/ +#guard_msgs in #check Inhabited.default + +class Inhabited' (α : Type) where + default (α) : α + +/-- info: Ex2.Inhabited'.default (α : Type) [self : Inhabited' α] : α -/ +#guard_msgs in #check Inhabited'.default + +end Ex2 + +/-! +Example with a parameter from a `variable` +-/ +namespace Ex3 + +class Inhabited (α : Type) where + default : α + +/-- info: Ex3.Inhabited.default {α : Type} [self : Inhabited α] : α -/ +#guard_msgs in #check Inhabited.default + +class Inhabited' (α : Type) where + default (α) : α + +/-- info: Ex3.Inhabited'.default (α : Type) [self : Inhabited' α] : α -/ +#guard_msgs in #check Inhabited'.default + +end Ex3 + +/-! +Trying to set a `variable` binder kind; only parameters in the declaration itself can be overridden. +Rationale: we found in mathlib that often users had large binder lists declared at the beginning of files, +and the structure fields accidentally were shadowing them. +-/ +namespace Ex4 + +variable (α : Type) + +/-- +error: only parameters appearing in the declaration header may have their binders kinds be overridden + +If this is not intended to be an override, use a binder with a type, for example '(x : _)'. +-/ +#guard_msgs in +class Inhabited where + default (α) : α + +end Ex4 + +/-! +Here, `(α β)` is not an override since `β` is not an existing parameter, so `α` is treated as a binder. +-/ +namespace Ex5 +/-- error: failed to infer binder type -/ +#guard_msgs in +class C (α : Type) where + f (α β) : β +end Ex5 + +/-! +Here, `(α β)` is not an override since `β` is a field. +-/ +namespace Ex6 +/-- error: failed to infer binder type -/ +#guard_msgs in +class C (α : Type) where + β : Type + f (α β) : β +end Ex6