feat: per-constructor noConfusion constructions (#10315)

This PR adds `T.ctor.noConfusion` declarations, which are
specializations of `T.noConfusion` to equalities between `T.ctor`. The
point is to avoid reducing the `T.noConfusionType` construction every
time we use `injection` or a similar tactic.

```lean
Vec.cons.noConfusion.{u_1, u} {α : Type u} (P : Sort u_1) {n : Nat}
  (x : α) (xs : Vec α n) (x' : α) (xs' : Vec α n)
  (h : Vec.cons x xs = Vec.cons x' xs')
  (k : n = n → x = x' → xs ≍ xs' → P) : P
```

The constructions are not as powerful as `T.noConfusion` when the
indices of the inductive type are not just constructor parameters (or
constructor applications of these parameters), so the full
`T.noConfusion` construction is still needed as a fallback.

It may seem costly to generate these eagerly, but given that we eagerly
generate injectivity theorems already, and we will use them there, it
seems reasonable for now.

To further reduce the cost, we only generate them for constructors with
fields (for others, the `T.noConfusion` theorem doesn't provide any
information), and we use `macro_inline` to prevent the compiler from
creating code for these, given that the compiler has special support for
`T.noConfusion` that we want it to use).

An earlier version of this PR also removed trivial equations and
un-HEq-ed others, leading to
```
 (k : x = x' → xs = xs' → P) 
```
in the example above. I backed out of that change, as it makes it harder
for tactics like `injectivity` to know how often to `intro`, so better
to keep things uniform.
This commit is contained in:
Joachim Breitner 2025-09-12 10:00:12 +02:00 committed by GitHub
parent caa0eacea8
commit 3cf7fdcbe0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 276 additions and 3 deletions

View file

@ -12,6 +12,7 @@ import Lean.Meta.AppBuilder
import Lean.Meta.CompletionName
import Lean.Meta.NatTable
import Lean.Meta.Constructions.CtorIdx
import Lean.Meta.SameCtorUtils
import Lean.Meta.Constructions.CtorIdx
import Lean.Meta.Constructions.CtorElim
@ -28,7 +29,7 @@ fun params x1 x2 x3 x1' x2' x3' => (x1 = x1' → x2 = x2' → x3 = x3' → P)
where `x1 x2 x3` and `x1' x2' x3'` are the fields of a constructor application of `ctorName`,
omitting equalities between propositions and using `HEq` where needed.
-/
public def mkNoConfusionCtorArg (ctorName : Name) (P : Expr) : MetaM Expr := do
def mkNoConfusionCtorArg (ctorName : Name) (P : Expr) : MetaM Expr := do
let ctorInfo ← getConstInfoCtor ctorName
-- We bring the constructor's parameters into scope abstractly, this way
-- we can check if we need to use HEq. (The concrete fields could allow Eq)
@ -68,7 +69,7 @@ def mkIfNatEq (P : Expr) (e1 e2 : Expr) («then» : Expr → MetaM Expr) («else
let e := mkApp e (← withLocalDeclD `h (mkNot heq) (fun h => do mkLambdaFVars #[h] (← «else» h)))
pure e
public def mkNoConfusionType (indName : Name) : MetaM Unit := do
def mkNoConfusionType (indName : Name) : MetaM Unit := do
let declName := mkNoConfusionTypeName indName
let ConstantInfo.inductInfo info ← getConstInfo indName | unreachable!
let useLinearConstruction :=
@ -194,6 +195,72 @@ def mkNoConfusionCoreImp (indName : Name) : MetaM Unit := do
modifyEnv fun env => markNoConfusion env declName
modifyEnv fun env => addProtected env declName
/--
Creates per-constructor no-confusion definitions. These specialize the general noConfusion
declaration to equalities between two applications of the same constructor, to effectively cache
the computation of `noConfusionType` for that constructor:
```
def L.cons.noConfusion.{u_1, u} : {α : Type u} → (P : Sort u_1) →
(x : α) → (xs : L α) → (x' : α) → (xs' : L α) →
L.cons x xs = L.cons x' xs' →
(x = x' → xs = xs' → P) →
P
```
These definitions are less expressive than the general `noConfusion` principle when there are
complicated indices. In particular they assume that all fields of the constructor that appear
in its type are equal already. The `mkNoConfusion` app builder falls back to the general principle
if the per-constructor one does not apply.
At some point I tried to be clever and remove hypotheses that are trivial (`n = n →`), but that
made it harder for, say, `injection` to know how often to `intro`. So we just keep them.
-/
def mkNoConfusionCtors (declName : Name) : MetaM Unit := do
-- Do not do anything unless can_elim_to_type.
let .inductInfo indVal ← getConstInfo declName | return
let recInfo ← getConstInfo (mkRecName declName)
unless recInfo.levelParams.length > indVal.levelParams.length do return
if (← isPropFormerType indVal.type) then return
let noConfusionName := Name.mkStr declName "noConfusion"
-- We take the level names from `.rec`, as that conveniently has an extra level parameter that
-- is distinct from the ones from the inductive
let (v::us) := recInfo.levelParams.map mkLevelParam | throwError "unexpected number of level parameters in {recInfo.name}"
for ctor in indVal.ctors do
let ctorInfo ← getConstInfoCtor ctor
if ctorInfo.numFields > 0 then
let e ← withLocalDeclD `P (.sort v) fun P =>
forallBoundedTelescope ctorInfo.type ctorInfo.numParams fun xs _ => do
let ctorApp := mkAppN (mkConst ctor us) xs
withSharedCtorIndices ctorApp fun ys indices fields1 fields2 => do
let ctor1 := mkAppN ctorApp fields1
let ctor2 := mkAppN ctorApp fields2
let heqType ← mkEq ctor1 ctor2
withLocalDeclD `h heqType fun h => do
-- When the kernel checks this definitios, it will perform the potentially expensive
-- computation that `noConfusionType h` is equal to `$kType → P`
let kType ← mkNoConfusionCtorArg ctor P
let kType := kType.beta (xs ++ fields1 ++ fields2)
withLocalDeclD `k kType fun k =>
let e := mkConst noConfusionName (v :: us)
let e := mkAppN e (xs ++ indices ++ #[P, ctor1, ctor2, h, k])
mkLambdaFVars (xs ++ #[P] ++ ys ++ #[h, k]) e
let name := ctor.str "noConfusion"
addDecl (.defnDecl (← mkDefinitionValInferringUnsafe
(name := name)
(levelParams := recInfo.levelParams)
(type := (← inferType e))
(value := e)
(hints := ReducibilityHints.abbrev)
))
setReducibleAttribute name
-- The compiler has special support for `noConfusion`. So lets mark this as
-- macroInline to not generate code for all these extra definitions, and instead
-- let the compiler unfold this to then put the custom code there
setInlineAttribute name (kind := .macroInline)
def mkNoConfusionCore (declName : Name) : MetaM Unit := do
-- Do not do anything unless can_elim_to_type. TODO: Extract to util
@ -204,7 +271,7 @@ def mkNoConfusionCore (declName : Name) : MetaM Unit := do
mkNoConfusionType declName
mkNoConfusionCoreImp declName
mkNoConfusionCtors declName
def mkNoConfusionEnum (enumName : Name) : MetaM Unit := do
if (← getEnv).contains ``noConfusionEnum then
@ -278,6 +345,7 @@ public def mkNoConfusion (declName : Name) : MetaM Unit := do
else
mkNoConfusionCore declName
builtin_initialize
registerTraceClass `Meta.mkNoConfusion

View file

@ -0,0 +1,205 @@
inductive L (α : Type u) : Type u where
| nil : L α
| cons (x : α) (xs : L α) : L α
/--
info: @[reducible] def L.cons.noConfusion.{u_1, u} : {α : Type u} →
(P : Sort u_1) →
(x : α) → (xs : L α) → (x' : α) → (xs' : L α) → L.cons x xs = L.cons x' xs' → (x = x' → xs = xs' → P) → P
-/
#guard_msgs in
#print sig L.cons.noConfusion
inductive Vec (α : Type u) : Nat → Type u where
| nil : Vec α 0
| cons : {n : Nat} → (x : α) → (xs : Vec α n) → Vec α (n + 1)
/--
info: @[reducible] def Vec.cons.noConfusion.{u_1, u} : {α : Type u} →
(P : Sort u_1) →
{n : Nat} →
(x : α) →
(xs : Vec α n) →
(x' : α) → (xs' : Vec α n) → Vec.cons x xs = Vec.cons x' xs' → (n = n → x = x' → xs ≍ xs' → P) → P
-/
#guard_msgs in
#print sig Vec.cons.noConfusion
inductive I : (n : Nat) → Type where
| mk n : (b : Bool) → I (n / 2)
/--
info: @[reducible] def I.mk.noConfusion.{u} : (P : Sort u) →
(n : Nat) → (b b' : Bool) → I.mk n b = I.mk n b' → (n = n → b = b' → P) → P
-/
#guard_msgs in #print sig I.mk.noConfusion
inductive WithDep {α : Type u} (β : α → Type v) : Type (max u v) where
| intro (a : α) (b : β a) : WithDep β
/--
info: @[reducible] def WithDep.intro.noConfusion.{u_1, u, v} : {α : Type u} →
{β : α → Type v} →
(P : Sort u_1) →
(a : α) → (b : β a) → (a' : α) → (b' : β a') → WithDep.intro a b = WithDep.intro a' b' → (a = a' → b ≍ b' → P) → P
-/
#guard_msgs in #print sig WithDep.intro.noConfusion
-- Copy of 3386
-- This is a tricky case: `Tmₛ {T1 A1} a1 arg1 = Tmₛ {T2 A2} a2 arg2` only type checks if
-- `A1 = A2` and `arg1 = arg1`. The latter requires `T1 = T2`, even though `T` does not seem to
-- appear in the result type of `Tmₐ.app`.
inductive Tyₛ : Type (u+1)
| SPi : (T : Type u) -> (T -> Tyₛ) -> Tyₛ
/--
info: @[reducible] def Tyₛ.SPi.noConfusion.{u_1, u} : (P : Sort u_1) →
(T : Type u) →
(a : T → Tyₛ) → (T' : Type u) → (a' : T' → Tyₛ) → Tyₛ.SPi T a = Tyₛ.SPi T' a' → (T = T' → a ≍ a' → P) → P
-/
#guard_msgs in #print sig Tyₛ.SPi.noConfusion
inductive Tmₛ.{u} : Tyₛ.{u} -> Type (u+1)
| app : Tmₛ (.SPi T A) -> (arg : T) -> Tmₛ (A arg)
set_option pp.explicit true in
/--
info: constructor Tmₛ.app.{u} : {T : Type u} → {A : T → Tyₛ} → Tmₛ (Tyₛ.SPi T A) → (arg : T) → Tmₛ (A arg)
-/
#guard_msgs in
#print sig Tmₛ.app
/--
info: @[reducible] def Tmₛ.app.noConfusion.{u_1, u} : (P : Sort u_1) →
{T : Type u} →
{A : T → Tyₛ} →
(a : Tmₛ (Tyₛ.SPi T A)) →
(arg : T) → (a' : Tmₛ (Tyₛ.SPi T A)) → a.app arg = a'.app arg → (T = T → A ≍ A → a ≍ a' → arg ≍ arg → P) → P :=
fun P {T} {A} a arg a' h k => Tmₛ.noConfusion h k
-/
#guard_msgs in #print Tmₛ.app.noConfusion
unsafe inductive U : Type | mk : (U → U) → U
/--
info: @[reducible] unsafe def U.mk.noConfusion.{u} : (P : Sort u) → (a a' : U → U) → U.mk a = U.mk a' → (a = a' → P) → P
-/
#guard_msgs in #print sig U.mk.noConfusion
-- More tests suggested by Claude
-- Test 2: Indexed family with complex indices
inductive Matrix (α : Type u) : Nat → Nat → Type u where
| empty : Matrix α 0 0
| row (n m : Nat) (v : Vector α n) (rest : Matrix α m n) : Matrix α (m + 1) n
/--
info: @[reducible] def Matrix.row.noConfusion.{u_1, u} : {α : Type u} →
(P : Sort u_1) →
(n m : Nat) →
(v : Vector α n) →
(rest : Matrix α m n) →
(v' : Vector α n) →
(rest' : Matrix α m n) →
Matrix.row n m v rest = Matrix.row n m v' rest' → (n = n → m = m → v ≍ v' → rest ≍ rest' → P) → P
-/
#guard_msgs in #print sig Matrix.row.noConfusion
-- Test 3: Mutual inductive types
mutual
inductive Tree (α : Type u) : Type u where
| leaf (val : α) : Tree α
| node (forest : Forest α) : Tree α
inductive Forest (α : Type u) : Type u where
| empty : Forest α
| cons (tree : Tree α) (rest : Forest α) : Forest α
end
-- Test 4: Higher-order inductive with function types
inductive HigherOrder (α : Type) : Type 1 where
| base (x : α) : HigherOrder α
| func (f : α → HigherOrder α) : HigherOrder α
-- Test noConfusion with function arguments
/--
info: @[reducible] def HigherOrder.base.noConfusion.{u} : {α : Type} →
(P : Sort u) → (x x' : α) → HigherOrder.base x = HigherOrder.base x' → (x = x' → P) → P
-/
#guard_msgs in #print sig HigherOrder.base.noConfusion
/--
info: @[reducible] def HigherOrder.func.noConfusion.{u} : {α : Type} →
(P : Sort u) → (f f' : α → HigherOrder α) → HigherOrder.func f = HigherOrder.func f' → (f = f' → P) → P
-/
#guard_msgs in #print sig HigherOrder.func.noConfusion
-- Test 5: Nested inductive with complex dependency
inductive Nested : Type 1 where
| simple (n : Nat) : Nested
| complex (inner : List Nested) : Nested
-- Test recursive nesting in noConfusion
/--
info: @[reducible] def Nested.simple.noConfusion.{u} : (P : Sort u) →
(n n' : Nat) → Nested.simple n = Nested.simple n' → (n = n' → P) → P
-/
#guard_msgs in #print sig Nested.simple.noConfusion
/--
info: @[reducible] def Nested.complex.noConfusion.{u} : (P : Sort u) →
(inner inner' : List Nested) → Nested.complex inner = Nested.complex inner' → (inner = inner' → P) → P
-/
#guard_msgs in #print sig Nested.complex.noConfusion
-- Test 6: Inductive with universe polymorphism
inductive UnivPoly.{u, v} (α : Type u) (β : Type v) : Type (max u v) where
| left (a : α) : UnivPoly α β
| right (b : β) : UnivPoly α β
| both (a : α) (b : β) : UnivPoly α β
-- Test universe-polymorphic noConfusion
/--
info: @[reducible] def UnivPoly.left.noConfusion.{u_1, u, v} : {α : Type u} →
{β : Type v} → (P : Sort u_1) → (a a' : α) → UnivPoly.left a = UnivPoly.left a' → (a = a' → P) → P
-/
#guard_msgs in #print sig UnivPoly.left.noConfusion
/--
info: @[reducible] def UnivPoly.right.noConfusion.{u_1, u, v} : {α : Type u} →
{β : Type v} → (P : Sort u_1) → (b b' : β) → UnivPoly.right b = UnivPoly.right b' → (b = b' → P) → P
-/
#guard_msgs in #print sig UnivPoly.right.noConfusion
/--
info: @[reducible] def UnivPoly.both.noConfusion.{u_1, u, v} : {α : Type u} →
{β : Type v} →
(P : Sort u_1) →
(a : α) → (b : β) → (a' : α) → (b' : β) → UnivPoly.both a b = UnivPoly.both a' b' → (a = a' → b = b' → P) → P
-/
#guard_msgs in #print sig UnivPoly.both.noConfusion
-- Test 7: Inductive with implicit arguments and type classes
inductive WithTypeClass (α : Type u) [Inhabited α] : Type u where
| default : WithTypeClass α
| custom (val : α) : WithTypeClass α
-- Test 8: Very complex indexed family with dependent types
inductive ComplexVec (α : Type u) : (n : Nat) → (valid : n > 0) → Type u where
| single (x : α) : ComplexVec α 1 (by simp)
| extend {n : Nat} {h : n > 0} (x : α) (rest : ComplexVec α n h) :
ComplexVec α (n + 1) (by simp)
/--
info: @[reducible] def ComplexVec.extend.noConfusion.{u_1, u} : {α : Type u} →
(P : Sort u_1) →
{n : Nat} →
{h : n > 0} →
(x : α) →
(rest : ComplexVec α n h) →
(h' : n > 0) →
(x' : α) →
(rest' : ComplexVec α n h') →
ComplexVec.extend x rest = ComplexVec.extend x' rest' → (n = n → x = x' → rest ≍ rest' → P) → P
-/
#guard_msgs in #print sig ComplexVec.extend.noConfusion