fix: handle flattened inheritance in wrapInstance (#13302)

Instead of unconditionally wrapping value of fields that were copied
from flattened parent structures, try finding an existing instance and
projecting it first.
This commit is contained in:
Sebastian Ullrich 2026-04-09 17:53:21 +02:00 committed by GitHub
parent 1d1c5c6e30
commit 031bfa5989
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 112 additions and 36 deletions

View file

@ -9,8 +9,8 @@ prelude
public import Lean.Meta.Closure
public import Lean.Meta.SynthInstance
public import Lean.Meta.CtorRecognizer
public section
public import Lean.Meta.AppBuilder
import Lean.Structure
/-!
# Instance Wrapping
@ -25,22 +25,30 @@ Given an instance `i : I` and expected type `I'` (where `I'` must be mvar-free),
`wrapInstance` constructs a result instance as follows, executing all steps at
`instances` transparency:
1. If `I'` is not a class, return `i` unchanged.
1. If `I'` is not a class application, return `i` unchanged.
2. If `I'` is a proposition, wrap `i` in an auxiliary theorem of type `I'` and return it
(controlled by `backward.inferInstanceAs.wrap.instances`).
3. Reduce `i` to whnf.
4. If `i` is not a constructor application: if the type of `i` is already defeq to `I'`,
4. If `i` is not a constructor application: if `I` is already defeq to `I'`,
return `i`; otherwise wrap it in an auxiliary definition of type `I'` and return it
(controlled by `backward.inferInstanceAs.wrap.instances`).
5. Otherwise, for `i = ctor a₁ ... aₙ` with `ctor : C ?p₁ ... ?pₙ`:
- Unify `C ?p₁ ... ?pₙ` with `I'`.
- Return a new application `ctor a₁' ... aₙ' : I'` where each `aᵢ'` is constructed as:
5. Otherwise, if `i` is an application of `ctor` of class `C`:
- Unify the conclusion of the type of `ctor` with `I'` to obtain adjusted field type `Fᵢ'` for
each field.
- Return a new application `ctor ... : I'` where the fields are adjusted as follows:
- If the field type is a proposition: assign directly if types are defeq, otherwise
wrap in an auxiliary theorem.
- If the field type is a class: first try to reuse an existing synthesized instance
for the target type (controlled by `backward.inferInstanceAs.wrap.reuseSubInstances`);
if that fails, recurse with source instance `aᵢ` and expected type `?pᵢ`.
- Otherwise (data field): assign directly if types are defeq, otherwise wrap in an
- If the field is a parent field (subobject) `p : P`: first try to reuse an existing
instance that can be synthesized for `P` (controlled by
`backward.inferInstanceAs.wrap.reuseSubInstances`) in order to preserve defeqs; if that
fails, recurse.
- If it is a field of a flattened parent class `C'` and
`backward.inferInstanceAs.wrap.reuseSubInstances` is true, try synthesizing an instance of
`C'` for `I'` and if successful, use the corresponding projection of the found instance in
order to preserve defeqs; otherwise, continue.
- Specifically, construct the chain of base projections from `C` to `C'` applied to `_ : I'`
and infer its type to obtain an appropriate application of `C'` for the instance search.
- Otherwise (non-inherited data field): assign directly if types are defeq, otherwise wrap in an
auxiliary definition to fix the type (controlled by `backward.inferInstanceAs.wrap.data`).
## Options
@ -56,34 +64,36 @@ Given an instance `i : I` and expected type `I'` (where `I'` must be mvar-free),
namespace Lean.Meta
register_builtin_option backward.inferInstanceAs.wrap : Bool := {
public register_builtin_option backward.inferInstanceAs.wrap : Bool := {
defValue := true
descr := "wrap instance bodies in `inferInstanceAs` and the default `deriving` handler"
}
register_builtin_option backward.inferInstanceAs.wrap.reuseSubInstances : Bool := {
public register_builtin_option backward.inferInstanceAs.wrap.reuseSubInstances : Bool := {
defValue := true
descr := "when recursing into sub-instances, reuse existing instances for the target type instead of re-wrapping them, which can be important to avoid non-defeq instance diamonds"
}
register_builtin_option backward.inferInstanceAs.wrap.instances : Bool := {
public register_builtin_option backward.inferInstanceAs.wrap.instances : Bool := {
defValue := true
descr := "wrap non-reducible instances in auxiliary definitions to fix their types"
}
register_builtin_option backward.inferInstanceAs.wrap.data : Bool := {
public register_builtin_option backward.inferInstanceAs.wrap.data : Bool := {
defValue := true
descr := "wrap data fields in auxiliary definitions to fix their types"
}
builtin_initialize registerTraceClass `Meta.wrapInstance
open Meta
/--
Rebuild a type application with fresh synthetic metavariables for instance-implicit arguments.
Non-instance-implicit arguments are assigned from the original application's arguments.
If the function is over-applied, extra arguments are preserved.
-/
def abstractInstImplicitArgs (type : Expr) : MetaM Expr := do
public def abstractInstImplicitArgs (type : Expr) : MetaM Expr := do
let fn := type.getAppFn
let args := type.getAppArgs
let (mvars, bis, _) ← forallMetaTelescope (← inferType fn)
@ -93,11 +103,38 @@ def abstractInstImplicitArgs (type : Expr) : MetaM Expr := do
let args := mvars ++ args.drop mvars.size
instantiateMVars (mkAppN fn args)
partial def getFieldOrigin (structName field : Name) : MetaM (Name × StructureFieldInfo) := do
let env ← getEnv
for parent in getStructureParentInfo env structName do
if (findField? env parent.structName field).isSome then
return ← getFieldOrigin parent.structName field
let some fi := getFieldInfo? env structName field
| throwError "no such field {field} in {structName}"
return (structName, fi)
/-- Projects application of a structure type to corresponding application of a parent structure. -/
def getParentStructType? (structName parentStructName : Name) (structType : Expr) : MetaM (Option Expr) := OptionT.run do
let env ← getEnv
let some path := getPathToBaseStructure? env parentStructName structName | failure
withLocalDeclD `self structType fun self => do
let proj ← path.foldlM (init := self) fun e projFn => do
let ty ← whnf (← inferType e)
let .const _ us := ty.getAppFn
| trace[Meta.wrapInstance] "could not reduce type `{ty}`"
failure
let params := ty.getAppArgs
pure <| mkApp (mkAppN (.const projFn us) params) e
let projTy ← whnf <| ← inferType proj
if projTy.containsFVar self.fvarId! then
trace[Meta.wrapInstance] "parent type depends on instance fields{indentExpr projTy}"
failure
return projTy
/--
Wrap an instance value so its type matches the expected type exactly.
See the module docstring for the full algorithm specification.
-/
partial def wrapInstance (inst expectedType : Expr) (compile : Bool := true)
public partial def wrapInstance (inst expectedType : Expr) (compile : Bool := true)
(logCompileErrors : Bool := true) (isMeta : Bool := false) : MetaM Expr := withTransparency .instances do
withTraceNode `Meta.wrapInstance
(fun _ => return m!"type: {expectedType}") do
@ -155,8 +192,10 @@ partial def wrapInstance (inst expectedType : Expr) (compile : Bool := true)
else
trace[Meta.wrapInstance] "proof field {i} does not have expected type {argExpectedType} but {argType}, wrapping in auxiliary theorem: {arg}"
mvarId.assign (← mkAuxTheorem argExpectedType arg (zetaDelta := true))
continue
-- Recurse into instance arguments of the constructor
else if (← isClass? argExpectedType).isSome then
if (← isClass? argExpectedType).isSome then
if backward.inferInstanceAs.wrap.reuseSubInstances.get (← getOptions) then
-- Reuse existing instance for the target type if any. This is especially important when recursing
-- as it guarantees subinstances of overlapping instances are defeq under more than just
@ -170,22 +209,35 @@ partial def wrapInstance (inst expectedType : Expr) (compile : Bool := true)
mvarId.assign (← wrapInstance arg argExpectedType (compile := compile)
(logCompileErrors := logCompileErrors) (isMeta := isMeta))
else
-- For data fields, assign directly or wrap in aux def to fix types.
if backward.inferInstanceAs.wrap.data.get (← getOptions) then
let argType ← inferType arg
if ← isDefEq argExpectedType argType then
mvarId.assign arg
else
let name ← mkAuxDeclName
mvarId.assign (← mkAuxDefinition name argExpectedType arg (compile := false))
setInlineAttribute name
if isMeta then modifyEnv (markMeta · name)
if compile then
compileDecls (logErrors := logCompileErrors) #[name]
enableRealizationsForConst name
else
mvarId.assign arg
return mkAppN f (← mvars.mapM instantiateMVars)
continue
end Lean.Meta
if backward.inferInstanceAs.wrap.reuseSubInstances.get (← getOptions) then
let (baseClassName, fieldInfo) ← getFieldOrigin className mvarDecl.userName
if baseClassName != className then
trace[Meta.wrapInstance] "found inherited field `{mvarDecl.userName}` from parent `{baseClassName}`"
if let some baseClassType ← getParentStructType? className baseClassName expectedType then
try
if let .some existingBaseClassInst ← trySynthInstance baseClassType then
trace[Meta.wrapInstance] "using projection of existing instance `{existingBaseClassInst}`"
mvarId.assign (← mkProjection existingBaseClassInst fieldInfo.fieldName)
continue
trace[Meta.wrapInstance] "did not find existing instance for `{baseClassName}`"
catch e =>
trace[Meta.wrapInstance] "error when attempting to reuse existing instance for `{baseClassName}`: {e.toMessageData}"
-- For data fields, assign directly or wrap in aux def to fix types.
if backward.inferInstanceAs.wrap.data.get (← getOptions) then
let argType ← inferType arg
if ← isDefEq argExpectedType argType then
mvarId.assign arg
else
let name ← mkAuxDeclName
mvarId.assign (← mkAuxDefinition name argExpectedType arg (compile := false))
setInlineAttribute name
if isMeta then modifyEnv (markMeta · name)
if compile then
compileDecls (logErrors := logCompileErrors) #[name]
enableRealizationsForConst name
else
mvarId.assign arg
return mkAppN f (← mvars.mapM instantiateMVars)

View file

@ -52,6 +52,30 @@ instCI2
#guard_msgs in
#print instCD2._aux_1
/-! Flattened inheritance test. -/
class Base (α : Type) where
b : α
class Foo (α : Type) extends Base α where
a : α
class Bar (α : Type) extends Base α where
c : α
class FooBar (α : Type) extends Foo α, Bar α
instance : FooBar Nat where
a := 0
b := 1
c := 2
def MyNat := Nat
deriving Foo, Bar, FooBar
theorem zou : instFooBarMyNat.toBar = instBarMyNat := by
with_reducible_and_instances rfl
/-! Non-constructor instances should be used as is. -/
@[macro_inline, implicit_reducible]