diff --git a/src/Lean/Meta/Constructions/NoConfusion.lean b/src/Lean/Meta/Constructions/NoConfusion.lean index c7ecdbba93..167f9067e2 100644 --- a/src/Lean/Meta/Constructions/NoConfusion.lean +++ b/src/Lean/Meta/Constructions/NoConfusion.lean @@ -12,6 +12,7 @@ import Lean.Meta.AppBuilder import Lean.Meta.CompletionName import Lean.Meta.NatTable import Lean.Meta.Constructions.CtorIdx +import Lean.Meta.SameCtorUtils import Lean.Meta.Constructions.CtorIdx import Lean.Meta.Constructions.CtorElim @@ -28,7 +29,7 @@ fun params x1 x2 x3 x1' x2' x3' => (x1 = x1' → x2 = x2' → x3 = x3' → P) where `x1 x2 x3` and `x1' x2' x3'` are the fields of a constructor application of `ctorName`, omitting equalities between propositions and using `HEq` where needed. -/ -public def mkNoConfusionCtorArg (ctorName : Name) (P : Expr) : MetaM Expr := do +def mkNoConfusionCtorArg (ctorName : Name) (P : Expr) : MetaM Expr := do let ctorInfo ← getConstInfoCtor ctorName -- We bring the constructor's parameters into scope abstractly, this way -- we can check if we need to use HEq. (The concrete fields could allow Eq) @@ -68,7 +69,7 @@ def mkIfNatEq (P : Expr) (e1 e2 : Expr) («then» : Expr → MetaM Expr) («else let e := mkApp e (← withLocalDeclD `h (mkNot heq) (fun h => do mkLambdaFVars #[h] (← «else» h))) pure e -public def mkNoConfusionType (indName : Name) : MetaM Unit := do +def mkNoConfusionType (indName : Name) : MetaM Unit := do let declName := mkNoConfusionTypeName indName let ConstantInfo.inductInfo info ← getConstInfo indName | unreachable! let useLinearConstruction := @@ -194,6 +195,72 @@ def mkNoConfusionCoreImp (indName : Name) : MetaM Unit := do modifyEnv fun env => markNoConfusion env declName modifyEnv fun env => addProtected env declName +/-- +Creates per-constructor no-confusion definitions. These specialize the general noConfusion +declaration to equalities between two applications of the same constructor, to effectively cache +the computation of `noConfusionType` for that constructor: + +``` +def L.cons.noConfusion.{u_1, u} : {α : Type u} → (P : Sort u_1) → + (x : α) → (xs : L α) → (x' : α) → (xs' : L α) → + L.cons x xs = L.cons x' xs' → + (x = x' → xs = xs' → P) → + P +``` + +These definitions are less expressive than the general `noConfusion` principle when there are +complicated indices. In particular they assume that all fields of the constructor that appear +in its type are equal already. The `mkNoConfusion` app builder falls back to the general principle +if the per-constructor one does not apply. + +At some point I tried to be clever and remove hypotheses that are trivial (`n = n →`), but that +made it harder for, say, `injection` to know how often to `intro`. So we just keep them. +-/ +def mkNoConfusionCtors (declName : Name) : MetaM Unit := do + -- Do not do anything unless can_elim_to_type. + let .inductInfo indVal ← getConstInfo declName | return + let recInfo ← getConstInfo (mkRecName declName) + unless recInfo.levelParams.length > indVal.levelParams.length do return + if (← isPropFormerType indVal.type) then return + let noConfusionName := Name.mkStr declName "noConfusion" + + -- We take the level names from `.rec`, as that conveniently has an extra level parameter that + -- is distinct from the ones from the inductive + let (v::us) := recInfo.levelParams.map mkLevelParam | throwError "unexpected number of level parameters in {recInfo.name}" + + for ctor in indVal.ctors do + let ctorInfo ← getConstInfoCtor ctor + if ctorInfo.numFields > 0 then + let e ← withLocalDeclD `P (.sort v) fun P => + forallBoundedTelescope ctorInfo.type ctorInfo.numParams fun xs _ => do + let ctorApp := mkAppN (mkConst ctor us) xs + withSharedCtorIndices ctorApp fun ys indices fields1 fields2 => do + let ctor1 := mkAppN ctorApp fields1 + let ctor2 := mkAppN ctorApp fields2 + let heqType ← mkEq ctor1 ctor2 + withLocalDeclD `h heqType fun h => do + -- When the kernel checks this definitios, it will perform the potentially expensive + -- computation that `noConfusionType h` is equal to `$kType → P` + let kType ← mkNoConfusionCtorArg ctor P + let kType := kType.beta (xs ++ fields1 ++ fields2) + withLocalDeclD `k kType fun k => + let e := mkConst noConfusionName (v :: us) + let e := mkAppN e (xs ++ indices ++ #[P, ctor1, ctor2, h, k]) + mkLambdaFVars (xs ++ #[P] ++ ys ++ #[h, k]) e + let name := ctor.str "noConfusion" + addDecl (.defnDecl (← mkDefinitionValInferringUnsafe + (name := name) + (levelParams := recInfo.levelParams) + (type := (← inferType e)) + (value := e) + (hints := ReducibilityHints.abbrev) + )) + setReducibleAttribute name + -- The compiler has special support for `noConfusion`. So lets mark this as + -- macroInline to not generate code for all these extra definitions, and instead + -- let the compiler unfold this to then put the custom code there + setInlineAttribute name (kind := .macroInline) + def mkNoConfusionCore (declName : Name) : MetaM Unit := do -- Do not do anything unless can_elim_to_type. TODO: Extract to util @@ -204,7 +271,7 @@ def mkNoConfusionCore (declName : Name) : MetaM Unit := do mkNoConfusionType declName mkNoConfusionCoreImp declName - + mkNoConfusionCtors declName def mkNoConfusionEnum (enumName : Name) : MetaM Unit := do if (← getEnv).contains ``noConfusionEnum then @@ -278,6 +345,7 @@ public def mkNoConfusion (declName : Name) : MetaM Unit := do else mkNoConfusionCore declName + builtin_initialize registerTraceClass `Meta.mkNoConfusion diff --git a/tests/lean/run/noConfusionCtors.lean b/tests/lean/run/noConfusionCtors.lean new file mode 100644 index 0000000000..dc624b9843 --- /dev/null +++ b/tests/lean/run/noConfusionCtors.lean @@ -0,0 +1,205 @@ +inductive L (α : Type u) : Type u where + | nil : L α + | cons (x : α) (xs : L α) : L α + +/-- +info: @[reducible] def L.cons.noConfusion.{u_1, u} : {α : Type u} → + (P : Sort u_1) → + (x : α) → (xs : L α) → (x' : α) → (xs' : L α) → L.cons x xs = L.cons x' xs' → (x = x' → xs = xs' → P) → P +-/ +#guard_msgs in +#print sig L.cons.noConfusion + +inductive Vec (α : Type u) : Nat → Type u where + | nil : Vec α 0 + | cons : {n : Nat} → (x : α) → (xs : Vec α n) → Vec α (n + 1) + +/-- +info: @[reducible] def Vec.cons.noConfusion.{u_1, u} : {α : Type u} → + (P : Sort u_1) → + {n : Nat} → + (x : α) → + (xs : Vec α n) → + (x' : α) → (xs' : Vec α n) → Vec.cons x xs = Vec.cons x' xs' → (n = n → x = x' → xs ≍ xs' → P) → P +-/ +#guard_msgs in +#print sig Vec.cons.noConfusion + +inductive I : (n : Nat) → Type where + | mk n : (b : Bool) → I (n / 2) + +/-- +info: @[reducible] def I.mk.noConfusion.{u} : (P : Sort u) → + (n : Nat) → (b b' : Bool) → I.mk n b = I.mk n b' → (n = n → b = b' → P) → P +-/ +#guard_msgs in #print sig I.mk.noConfusion + + +inductive WithDep {α : Type u} (β : α → Type v) : Type (max u v) where + | intro (a : α) (b : β a) : WithDep β + +/-- +info: @[reducible] def WithDep.intro.noConfusion.{u_1, u, v} : {α : Type u} → + {β : α → Type v} → + (P : Sort u_1) → + (a : α) → (b : β a) → (a' : α) → (b' : β a') → WithDep.intro a b = WithDep.intro a' b' → (a = a' → b ≍ b' → P) → P +-/ +#guard_msgs in #print sig WithDep.intro.noConfusion + +-- Copy of 3386 +-- This is a tricky case: `Tmₛ {T1 A1} a1 arg1 = Tmₛ {T2 A2} a2 arg2` only type checks if +-- `A1 = A2` and `arg1 = arg1`. The latter requires `T1 = T2`, even though `T` does not seem to +-- appear in the result type of `Tmₐ.app`. + +inductive Tyₛ : Type (u+1) +| SPi : (T : Type u) -> (T -> Tyₛ) -> Tyₛ + +/-- +info: @[reducible] def Tyₛ.SPi.noConfusion.{u_1, u} : (P : Sort u_1) → + (T : Type u) → + (a : T → Tyₛ) → (T' : Type u) → (a' : T' → Tyₛ) → Tyₛ.SPi T a = Tyₛ.SPi T' a' → (T = T' → a ≍ a' → P) → P +-/ +#guard_msgs in #print sig Tyₛ.SPi.noConfusion + +inductive Tmₛ.{u} : Tyₛ.{u} -> Type (u+1) +| app : Tmₛ (.SPi T A) -> (arg : T) -> Tmₛ (A arg) + +set_option pp.explicit true in +/-- +info: constructor Tmₛ.app.{u} : {T : Type u} → {A : T → Tyₛ} → Tmₛ (Tyₛ.SPi T A) → (arg : T) → Tmₛ (A arg) +-/ +#guard_msgs in +#print sig Tmₛ.app + + +/-- +info: @[reducible] def Tmₛ.app.noConfusion.{u_1, u} : (P : Sort u_1) → + {T : Type u} → + {A : T → Tyₛ} → + (a : Tmₛ (Tyₛ.SPi T A)) → + (arg : T) → (a' : Tmₛ (Tyₛ.SPi T A)) → a.app arg = a'.app arg → (T = T → A ≍ A → a ≍ a' → arg ≍ arg → P) → P := +fun P {T} {A} a arg a' h k => Tmₛ.noConfusion h k +-/ +#guard_msgs in #print Tmₛ.app.noConfusion + + +unsafe inductive U : Type | mk : (U → U) → U + +/-- +info: @[reducible] unsafe def U.mk.noConfusion.{u} : (P : Sort u) → (a a' : U → U) → U.mk a = U.mk a' → (a = a' → P) → P +-/ +#guard_msgs in #print sig U.mk.noConfusion + +-- More tests suggested by Claude + +-- Test 2: Indexed family with complex indices +inductive Matrix (α : Type u) : Nat → Nat → Type u where + | empty : Matrix α 0 0 + | row (n m : Nat) (v : Vector α n) (rest : Matrix α m n) : Matrix α (m + 1) n + +/-- +info: @[reducible] def Matrix.row.noConfusion.{u_1, u} : {α : Type u} → + (P : Sort u_1) → + (n m : Nat) → + (v : Vector α n) → + (rest : Matrix α m n) → + (v' : Vector α n) → + (rest' : Matrix α m n) → + Matrix.row n m v rest = Matrix.row n m v' rest' → (n = n → m = m → v ≍ v' → rest ≍ rest' → P) → P +-/ +#guard_msgs in #print sig Matrix.row.noConfusion + +-- Test 3: Mutual inductive types +mutual + inductive Tree (α : Type u) : Type u where + | leaf (val : α) : Tree α + | node (forest : Forest α) : Tree α + + inductive Forest (α : Type u) : Type u where + | empty : Forest α + | cons (tree : Tree α) (rest : Forest α) : Forest α +end + +-- Test 4: Higher-order inductive with function types +inductive HigherOrder (α : Type) : Type 1 where + | base (x : α) : HigherOrder α + | func (f : α → HigherOrder α) : HigherOrder α + +-- Test noConfusion with function arguments +/-- +info: @[reducible] def HigherOrder.base.noConfusion.{u} : {α : Type} → + (P : Sort u) → (x x' : α) → HigherOrder.base x = HigherOrder.base x' → (x = x' → P) → P +-/ +#guard_msgs in #print sig HigherOrder.base.noConfusion +/-- +info: @[reducible] def HigherOrder.func.noConfusion.{u} : {α : Type} → + (P : Sort u) → (f f' : α → HigherOrder α) → HigherOrder.func f = HigherOrder.func f' → (f = f' → P) → P +-/ +#guard_msgs in #print sig HigherOrder.func.noConfusion + +-- Test 5: Nested inductive with complex dependency +inductive Nested : Type 1 where + | simple (n : Nat) : Nested + | complex (inner : List Nested) : Nested + +-- Test recursive nesting in noConfusion +/-- +info: @[reducible] def Nested.simple.noConfusion.{u} : (P : Sort u) → + (n n' : Nat) → Nested.simple n = Nested.simple n' → (n = n' → P) → P +-/ +#guard_msgs in #print sig Nested.simple.noConfusion +/-- +info: @[reducible] def Nested.complex.noConfusion.{u} : (P : Sort u) → + (inner inner' : List Nested) → Nested.complex inner = Nested.complex inner' → (inner = inner' → P) → P +-/ +#guard_msgs in #print sig Nested.complex.noConfusion + +-- Test 6: Inductive with universe polymorphism +inductive UnivPoly.{u, v} (α : Type u) (β : Type v) : Type (max u v) where + | left (a : α) : UnivPoly α β + | right (b : β) : UnivPoly α β + | both (a : α) (b : β) : UnivPoly α β + +-- Test universe-polymorphic noConfusion +/-- +info: @[reducible] def UnivPoly.left.noConfusion.{u_1, u, v} : {α : Type u} → + {β : Type v} → (P : Sort u_1) → (a a' : α) → UnivPoly.left a = UnivPoly.left a' → (a = a' → P) → P +-/ +#guard_msgs in #print sig UnivPoly.left.noConfusion +/-- +info: @[reducible] def UnivPoly.right.noConfusion.{u_1, u, v} : {α : Type u} → + {β : Type v} → (P : Sort u_1) → (b b' : β) → UnivPoly.right b = UnivPoly.right b' → (b = b' → P) → P +-/ +#guard_msgs in #print sig UnivPoly.right.noConfusion +/-- +info: @[reducible] def UnivPoly.both.noConfusion.{u_1, u, v} : {α : Type u} → + {β : Type v} → + (P : Sort u_1) → + (a : α) → (b : β) → (a' : α) → (b' : β) → UnivPoly.both a b = UnivPoly.both a' b' → (a = a' → b = b' → P) → P +-/ +#guard_msgs in #print sig UnivPoly.both.noConfusion + +-- Test 7: Inductive with implicit arguments and type classes +inductive WithTypeClass (α : Type u) [Inhabited α] : Type u where + | default : WithTypeClass α + | custom (val : α) : WithTypeClass α + +-- Test 8: Very complex indexed family with dependent types +inductive ComplexVec (α : Type u) : (n : Nat) → (valid : n > 0) → Type u where + | single (x : α) : ComplexVec α 1 (by simp) + | extend {n : Nat} {h : n > 0} (x : α) (rest : ComplexVec α n h) : + ComplexVec α (n + 1) (by simp) + +/-- +info: @[reducible] def ComplexVec.extend.noConfusion.{u_1, u} : {α : Type u} → + (P : Sort u_1) → + {n : Nat} → + {h : n > 0} → + (x : α) → + (rest : ComplexVec α n h) → + (h' : n > 0) → + (x' : α) → + (rest' : ComplexVec α n h') → + ComplexVec.extend x rest = ComplexVec.extend x' rest' → (n = n → x = x' → rest ≍ rest' → P) → P +-/ +#guard_msgs in #print sig ComplexVec.extend.noConfusion