From 3cf7fdcbe02e8065c9c9bef61d7d334d05f134f6 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Fri, 12 Sep 2025 10:00:12 +0200 Subject: [PATCH] feat: per-constructor noConfusion constructions (#10315) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `T.ctor.noConfusion` declarations, which are specializations of `T.noConfusion` to equalities between `T.ctor`. The point is to avoid reducing the `T.noConfusionType` construction every time we use `injection` or a similar tactic. ```lean Vec.cons.noConfusion.{u_1, u} {α : Type u} (P : Sort u_1) {n : Nat} (x : α) (xs : Vec α n) (x' : α) (xs' : Vec α n) (h : Vec.cons x xs = Vec.cons x' xs') (k : n = n → x = x' → xs ≍ xs' → P) : P ``` The constructions are not as powerful as `T.noConfusion` when the indices of the inductive type are not just constructor parameters (or constructor applications of these parameters), so the full `T.noConfusion` construction is still needed as a fallback. It may seem costly to generate these eagerly, but given that we eagerly generate injectivity theorems already, and we will use them there, it seems reasonable for now. To further reduce the cost, we only generate them for constructors with fields (for others, the `T.noConfusion` theorem doesn't provide any information), and we use `macro_inline` to prevent the compiler from creating code for these, given that the compiler has special support for `T.noConfusion` that we want it to use). An earlier version of this PR also removed trivial equations and un-HEq-ed others, leading to ``` (k : x = x' → xs = xs' → P) ``` in the example above. I backed out of that change, as it makes it harder for tactics like `injectivity` to know how often to `intro`, so better to keep things uniform. --- src/Lean/Meta/Constructions/NoConfusion.lean | 74 ++++++- tests/lean/run/noConfusionCtors.lean | 205 +++++++++++++++++++ 2 files changed, 276 insertions(+), 3 deletions(-) create mode 100644 tests/lean/run/noConfusionCtors.lean diff --git a/src/Lean/Meta/Constructions/NoConfusion.lean b/src/Lean/Meta/Constructions/NoConfusion.lean index c7ecdbba93..167f9067e2 100644 --- a/src/Lean/Meta/Constructions/NoConfusion.lean +++ b/src/Lean/Meta/Constructions/NoConfusion.lean @@ -12,6 +12,7 @@ import Lean.Meta.AppBuilder import Lean.Meta.CompletionName import Lean.Meta.NatTable import Lean.Meta.Constructions.CtorIdx +import Lean.Meta.SameCtorUtils import Lean.Meta.Constructions.CtorIdx import Lean.Meta.Constructions.CtorElim @@ -28,7 +29,7 @@ fun params x1 x2 x3 x1' x2' x3' => (x1 = x1' → x2 = x2' → x3 = x3' → P) where `x1 x2 x3` and `x1' x2' x3'` are the fields of a constructor application of `ctorName`, omitting equalities between propositions and using `HEq` where needed. -/ -public def mkNoConfusionCtorArg (ctorName : Name) (P : Expr) : MetaM Expr := do +def mkNoConfusionCtorArg (ctorName : Name) (P : Expr) : MetaM Expr := do let ctorInfo ← getConstInfoCtor ctorName -- We bring the constructor's parameters into scope abstractly, this way -- we can check if we need to use HEq. (The concrete fields could allow Eq) @@ -68,7 +69,7 @@ def mkIfNatEq (P : Expr) (e1 e2 : Expr) («then» : Expr → MetaM Expr) («else let e := mkApp e (← withLocalDeclD `h (mkNot heq) (fun h => do mkLambdaFVars #[h] (← «else» h))) pure e -public def mkNoConfusionType (indName : Name) : MetaM Unit := do +def mkNoConfusionType (indName : Name) : MetaM Unit := do let declName := mkNoConfusionTypeName indName let ConstantInfo.inductInfo info ← getConstInfo indName | unreachable! let useLinearConstruction := @@ -194,6 +195,72 @@ def mkNoConfusionCoreImp (indName : Name) : MetaM Unit := do modifyEnv fun env => markNoConfusion env declName modifyEnv fun env => addProtected env declName +/-- +Creates per-constructor no-confusion definitions. These specialize the general noConfusion +declaration to equalities between two applications of the same constructor, to effectively cache +the computation of `noConfusionType` for that constructor: + +``` +def L.cons.noConfusion.{u_1, u} : {α : Type u} → (P : Sort u_1) → + (x : α) → (xs : L α) → (x' : α) → (xs' : L α) → + L.cons x xs = L.cons x' xs' → + (x = x' → xs = xs' → P) → + P +``` + +These definitions are less expressive than the general `noConfusion` principle when there are +complicated indices. In particular they assume that all fields of the constructor that appear +in its type are equal already. The `mkNoConfusion` app builder falls back to the general principle +if the per-constructor one does not apply. + +At some point I tried to be clever and remove hypotheses that are trivial (`n = n →`), but that +made it harder for, say, `injection` to know how often to `intro`. So we just keep them. +-/ +def mkNoConfusionCtors (declName : Name) : MetaM Unit := do + -- Do not do anything unless can_elim_to_type. + let .inductInfo indVal ← getConstInfo declName | return + let recInfo ← getConstInfo (mkRecName declName) + unless recInfo.levelParams.length > indVal.levelParams.length do return + if (← isPropFormerType indVal.type) then return + let noConfusionName := Name.mkStr declName "noConfusion" + + -- We take the level names from `.rec`, as that conveniently has an extra level parameter that + -- is distinct from the ones from the inductive + let (v::us) := recInfo.levelParams.map mkLevelParam | throwError "unexpected number of level parameters in {recInfo.name}" + + for ctor in indVal.ctors do + let ctorInfo ← getConstInfoCtor ctor + if ctorInfo.numFields > 0 then + let e ← withLocalDeclD `P (.sort v) fun P => + forallBoundedTelescope ctorInfo.type ctorInfo.numParams fun xs _ => do + let ctorApp := mkAppN (mkConst ctor us) xs + withSharedCtorIndices ctorApp fun ys indices fields1 fields2 => do + let ctor1 := mkAppN ctorApp fields1 + let ctor2 := mkAppN ctorApp fields2 + let heqType ← mkEq ctor1 ctor2 + withLocalDeclD `h heqType fun h => do + -- When the kernel checks this definitios, it will perform the potentially expensive + -- computation that `noConfusionType h` is equal to `$kType → P` + let kType ← mkNoConfusionCtorArg ctor P + let kType := kType.beta (xs ++ fields1 ++ fields2) + withLocalDeclD `k kType fun k => + let e := mkConst noConfusionName (v :: us) + let e := mkAppN e (xs ++ indices ++ #[P, ctor1, ctor2, h, k]) + mkLambdaFVars (xs ++ #[P] ++ ys ++ #[h, k]) e + let name := ctor.str "noConfusion" + addDecl (.defnDecl (← mkDefinitionValInferringUnsafe + (name := name) + (levelParams := recInfo.levelParams) + (type := (← inferType e)) + (value := e) + (hints := ReducibilityHints.abbrev) + )) + setReducibleAttribute name + -- The compiler has special support for `noConfusion`. So lets mark this as + -- macroInline to not generate code for all these extra definitions, and instead + -- let the compiler unfold this to then put the custom code there + setInlineAttribute name (kind := .macroInline) + def mkNoConfusionCore (declName : Name) : MetaM Unit := do -- Do not do anything unless can_elim_to_type. TODO: Extract to util @@ -204,7 +271,7 @@ def mkNoConfusionCore (declName : Name) : MetaM Unit := do mkNoConfusionType declName mkNoConfusionCoreImp declName - + mkNoConfusionCtors declName def mkNoConfusionEnum (enumName : Name) : MetaM Unit := do if (← getEnv).contains ``noConfusionEnum then @@ -278,6 +345,7 @@ public def mkNoConfusion (declName : Name) : MetaM Unit := do else mkNoConfusionCore declName + builtin_initialize registerTraceClass `Meta.mkNoConfusion diff --git a/tests/lean/run/noConfusionCtors.lean b/tests/lean/run/noConfusionCtors.lean new file mode 100644 index 0000000000..dc624b9843 --- /dev/null +++ b/tests/lean/run/noConfusionCtors.lean @@ -0,0 +1,205 @@ +inductive L (α : Type u) : Type u where + | nil : L α + | cons (x : α) (xs : L α) : L α + +/-- +info: @[reducible] def L.cons.noConfusion.{u_1, u} : {α : Type u} → + (P : Sort u_1) → + (x : α) → (xs : L α) → (x' : α) → (xs' : L α) → L.cons x xs = L.cons x' xs' → (x = x' → xs = xs' → P) → P +-/ +#guard_msgs in +#print sig L.cons.noConfusion + +inductive Vec (α : Type u) : Nat → Type u where + | nil : Vec α 0 + | cons : {n : Nat} → (x : α) → (xs : Vec α n) → Vec α (n + 1) + +/-- +info: @[reducible] def Vec.cons.noConfusion.{u_1, u} : {α : Type u} → + (P : Sort u_1) → + {n : Nat} → + (x : α) → + (xs : Vec α n) → + (x' : α) → (xs' : Vec α n) → Vec.cons x xs = Vec.cons x' xs' → (n = n → x = x' → xs ≍ xs' → P) → P +-/ +#guard_msgs in +#print sig Vec.cons.noConfusion + +inductive I : (n : Nat) → Type where + | mk n : (b : Bool) → I (n / 2) + +/-- +info: @[reducible] def I.mk.noConfusion.{u} : (P : Sort u) → + (n : Nat) → (b b' : Bool) → I.mk n b = I.mk n b' → (n = n → b = b' → P) → P +-/ +#guard_msgs in #print sig I.mk.noConfusion + + +inductive WithDep {α : Type u} (β : α → Type v) : Type (max u v) where + | intro (a : α) (b : β a) : WithDep β + +/-- +info: @[reducible] def WithDep.intro.noConfusion.{u_1, u, v} : {α : Type u} → + {β : α → Type v} → + (P : Sort u_1) → + (a : α) → (b : β a) → (a' : α) → (b' : β a') → WithDep.intro a b = WithDep.intro a' b' → (a = a' → b ≍ b' → P) → P +-/ +#guard_msgs in #print sig WithDep.intro.noConfusion + +-- Copy of 3386 +-- This is a tricky case: `Tmₛ {T1 A1} a1 arg1 = Tmₛ {T2 A2} a2 arg2` only type checks if +-- `A1 = A2` and `arg1 = arg1`. The latter requires `T1 = T2`, even though `T` does not seem to +-- appear in the result type of `Tmₐ.app`. + +inductive Tyₛ : Type (u+1) +| SPi : (T : Type u) -> (T -> Tyₛ) -> Tyₛ + +/-- +info: @[reducible] def Tyₛ.SPi.noConfusion.{u_1, u} : (P : Sort u_1) → + (T : Type u) → + (a : T → Tyₛ) → (T' : Type u) → (a' : T' → Tyₛ) → Tyₛ.SPi T a = Tyₛ.SPi T' a' → (T = T' → a ≍ a' → P) → P +-/ +#guard_msgs in #print sig Tyₛ.SPi.noConfusion + +inductive Tmₛ.{u} : Tyₛ.{u} -> Type (u+1) +| app : Tmₛ (.SPi T A) -> (arg : T) -> Tmₛ (A arg) + +set_option pp.explicit true in +/-- +info: constructor Tmₛ.app.{u} : {T : Type u} → {A : T → Tyₛ} → Tmₛ (Tyₛ.SPi T A) → (arg : T) → Tmₛ (A arg) +-/ +#guard_msgs in +#print sig Tmₛ.app + + +/-- +info: @[reducible] def Tmₛ.app.noConfusion.{u_1, u} : (P : Sort u_1) → + {T : Type u} → + {A : T → Tyₛ} → + (a : Tmₛ (Tyₛ.SPi T A)) → + (arg : T) → (a' : Tmₛ (Tyₛ.SPi T A)) → a.app arg = a'.app arg → (T = T → A ≍ A → a ≍ a' → arg ≍ arg → P) → P := +fun P {T} {A} a arg a' h k => Tmₛ.noConfusion h k +-/ +#guard_msgs in #print Tmₛ.app.noConfusion + + +unsafe inductive U : Type | mk : (U → U) → U + +/-- +info: @[reducible] unsafe def U.mk.noConfusion.{u} : (P : Sort u) → (a a' : U → U) → U.mk a = U.mk a' → (a = a' → P) → P +-/ +#guard_msgs in #print sig U.mk.noConfusion + +-- More tests suggested by Claude + +-- Test 2: Indexed family with complex indices +inductive Matrix (α : Type u) : Nat → Nat → Type u where + | empty : Matrix α 0 0 + | row (n m : Nat) (v : Vector α n) (rest : Matrix α m n) : Matrix α (m + 1) n + +/-- +info: @[reducible] def Matrix.row.noConfusion.{u_1, u} : {α : Type u} → + (P : Sort u_1) → + (n m : Nat) → + (v : Vector α n) → + (rest : Matrix α m n) → + (v' : Vector α n) → + (rest' : Matrix α m n) → + Matrix.row n m v rest = Matrix.row n m v' rest' → (n = n → m = m → v ≍ v' → rest ≍ rest' → P) → P +-/ +#guard_msgs in #print sig Matrix.row.noConfusion + +-- Test 3: Mutual inductive types +mutual + inductive Tree (α : Type u) : Type u where + | leaf (val : α) : Tree α + | node (forest : Forest α) : Tree α + + inductive Forest (α : Type u) : Type u where + | empty : Forest α + | cons (tree : Tree α) (rest : Forest α) : Forest α +end + +-- Test 4: Higher-order inductive with function types +inductive HigherOrder (α : Type) : Type 1 where + | base (x : α) : HigherOrder α + | func (f : α → HigherOrder α) : HigherOrder α + +-- Test noConfusion with function arguments +/-- +info: @[reducible] def HigherOrder.base.noConfusion.{u} : {α : Type} → + (P : Sort u) → (x x' : α) → HigherOrder.base x = HigherOrder.base x' → (x = x' → P) → P +-/ +#guard_msgs in #print sig HigherOrder.base.noConfusion +/-- +info: @[reducible] def HigherOrder.func.noConfusion.{u} : {α : Type} → + (P : Sort u) → (f f' : α → HigherOrder α) → HigherOrder.func f = HigherOrder.func f' → (f = f' → P) → P +-/ +#guard_msgs in #print sig HigherOrder.func.noConfusion + +-- Test 5: Nested inductive with complex dependency +inductive Nested : Type 1 where + | simple (n : Nat) : Nested + | complex (inner : List Nested) : Nested + +-- Test recursive nesting in noConfusion +/-- +info: @[reducible] def Nested.simple.noConfusion.{u} : (P : Sort u) → + (n n' : Nat) → Nested.simple n = Nested.simple n' → (n = n' → P) → P +-/ +#guard_msgs in #print sig Nested.simple.noConfusion +/-- +info: @[reducible] def Nested.complex.noConfusion.{u} : (P : Sort u) → + (inner inner' : List Nested) → Nested.complex inner = Nested.complex inner' → (inner = inner' → P) → P +-/ +#guard_msgs in #print sig Nested.complex.noConfusion + +-- Test 6: Inductive with universe polymorphism +inductive UnivPoly.{u, v} (α : Type u) (β : Type v) : Type (max u v) where + | left (a : α) : UnivPoly α β + | right (b : β) : UnivPoly α β + | both (a : α) (b : β) : UnivPoly α β + +-- Test universe-polymorphic noConfusion +/-- +info: @[reducible] def UnivPoly.left.noConfusion.{u_1, u, v} : {α : Type u} → + {β : Type v} → (P : Sort u_1) → (a a' : α) → UnivPoly.left a = UnivPoly.left a' → (a = a' → P) → P +-/ +#guard_msgs in #print sig UnivPoly.left.noConfusion +/-- +info: @[reducible] def UnivPoly.right.noConfusion.{u_1, u, v} : {α : Type u} → + {β : Type v} → (P : Sort u_1) → (b b' : β) → UnivPoly.right b = UnivPoly.right b' → (b = b' → P) → P +-/ +#guard_msgs in #print sig UnivPoly.right.noConfusion +/-- +info: @[reducible] def UnivPoly.both.noConfusion.{u_1, u, v} : {α : Type u} → + {β : Type v} → + (P : Sort u_1) → + (a : α) → (b : β) → (a' : α) → (b' : β) → UnivPoly.both a b = UnivPoly.both a' b' → (a = a' → b = b' → P) → P +-/ +#guard_msgs in #print sig UnivPoly.both.noConfusion + +-- Test 7: Inductive with implicit arguments and type classes +inductive WithTypeClass (α : Type u) [Inhabited α] : Type u where + | default : WithTypeClass α + | custom (val : α) : WithTypeClass α + +-- Test 8: Very complex indexed family with dependent types +inductive ComplexVec (α : Type u) : (n : Nat) → (valid : n > 0) → Type u where + | single (x : α) : ComplexVec α 1 (by simp) + | extend {n : Nat} {h : n > 0} (x : α) (rest : ComplexVec α n h) : + ComplexVec α (n + 1) (by simp) + +/-- +info: @[reducible] def ComplexVec.extend.noConfusion.{u_1, u} : {α : Type u} → + (P : Sort u_1) → + {n : Nat} → + {h : n > 0} → + (x : α) → + (rest : ComplexVec α n h) → + (h' : n > 0) → + (x' : α) → + (rest' : ComplexVec α n h') → + ComplexVec.extend x rest = ComplexVec.extend x' rest' → (n = n → x = x' → rest ≍ rest' → P) → P +-/ +#guard_msgs in #print sig ComplexVec.extend.noConfusion