refactor: linearNoConfusionType: use PULift, not PUnit → (#8973)

This PR refactors the juggling of universes in the linear
`noConfusionType` construction: Instead of using `PUnit.{…} → ` in the
to get the branches of `withCtorType` to the same universe level, we use
`PULift`.

This fixes https://github.com/leanprover/lean4/issues/8962, although
probably doesn’t solve all issues of that kind while level equality
checking is incomplete.
This commit is contained in:
Joachim Breitner 2025-06-25 11:05:03 +02:00 committed by GitHub
parent 9641a9ac6c
commit b2a8d890c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 146 additions and 45 deletions

View file

@ -10,10 +10,6 @@ import Lean.Meta.CompletionName
import Lean.Meta.Constructions.NoConfusionLinear
register_builtin_option backwards.linearNoConfusionType : Bool := {
defValue := true
descr := "use the linear-size construction for the `noConfusionType` declaration of an inductive type. Set to false to use the previous, simpler but quadratic-size construction. "
}
namespace Lean
@ -28,11 +24,7 @@ def mkNoConfusionCore (declName : Name) : MetaM Unit := do
let recInfo ← getConstInfo (mkRecName declName)
unless recInfo.levelParams.length > indVal.levelParams.length do return
let useLinear ←
if backwards.linearNoConfusionType.get (← getOptions) then
NoConfusionLinear.deps.allM (hasConst · (skipRealize := true))
else
pure false
let useLinear ← NoConfusionLinear.canUse
if useLinear then
NoConfusionLinear.mkWithCtorType declName

View file

@ -37,14 +37,24 @@ namespace Lean.NoConfusionLinear
open Meta
register_builtin_option backwards.linearNoConfusionType : Bool := {
defValue := true
descr := "use the linear-size construction for the `noConfusionType` declaration of an inductive type. Set to false to use the previous, simpler but quadratic-size construction. "
}
/--
List of constants that the linear `noConfusionType` construction depends on.
-/
def deps : Array Lean.Name :=
#[ ``Nat.lt, ``cond, ``Nat, ``PUnit, ``Eq, ``Not, ``dite, ``Nat.decEq, ``Nat.blt ]
private def deps : Array Lean.Name :=
#[ ``cond, ``ULift, ``Eq.ndrec, ``Not, ``dite, ``Nat.decEq, ``Nat.blt ]
def mkNatLookupTable (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do
let type ← inferType default
def canUse : MetaM Bool := do
unless backwards.linearNoConfusionType.get (← getOptions) do return false
unless (← NoConfusionLinear.deps.allM (hasConst · (skipRealize := true))) do return false
return true
def mkNatLookupTable (n : Expr) (type : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do
let u ← getLevel type
let rec go (start stop : Nat) (hstart : start < stop := by omega) (hstop : stop ≤ es.size := by omega) : MetaM Expr := do
if h : start + 1 = stop then
@ -59,6 +69,55 @@ def mkNatLookupTable (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr
else
go 0 es.size
-- Right-associates the top-most `max`s to work around #5695 for prettier code
private def reassocMax (l : Level) : Level :=
let lvls := maxArgs l #[]
let last := lvls.back!
lvls.pop.foldr mkLevelMax last
where
maxArgs (l : Level) (lvls : Array Level) : Array Level :=
match l with
| .max l1 l2 => maxArgs l2 (maxArgs l1 lvls)
| _ => lvls.push l
/--
Takes the max of the levels of the given expressions.
-/
def maxLevels (es : Array Expr) (default : Expr) : MetaM Level := do
let mut maxLevel ← getLevel default
for e in es do
let l ← getLevel e
maxLevel := mkLevelMax' maxLevel l
return reassocMax maxLevel.normalize
private def mkPULift (r : Level) (t : Expr) : MetaM Expr := do
let s ← getLevel t
return mkApp (mkConst `PULift [r,s]) t
private def withMkPULiftUp (t : Expr) (k : Expr → MetaM Expr) : MetaM Expr := do
let t ← whnf t
if t.isAppOfArity `PULift 1 then
let t' := t.appArg!
let e ← k t'
return mkApp2 (mkConst `PULift.up (t.appFn!.constLevels!)) t' e
else
throwError "withMkPULiftUp: expected PULift type, got {t}"
private def mkPULiftDown (e : Expr) : MetaM Expr := do
let t ← whnf (← inferType e)
if t.isAppOfArity `PULift 1 then
let t' := t.appArg!
return mkApp2 (mkConst `PULift.down t.appFn!.constLevels!) t' e
else
throwError "mkULiftDown: expected ULift type, got {t}"
def mkNatLookupTableLifting (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do
let u ← maxLevels es default
let default ← mkPULift u default
let u' := reassocMax (mkLevelMax' u 1).normalize
let es ← es.mapM (mkPULift u)
mkNatLookupTable n (.sort u') es default
def mkWithCtorTypeName (indName : Name) : Name :=
Name.str indName "noConfusionType" |>.str "withCtorType"
@ -75,18 +134,15 @@ def mkWithCtorType (indName : Name) : MetaM Unit := do
let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`"
let indTyCon := mkConst indName us
let indTyKind ← inferType indTyCon
let indLevel ← getLevel indTyKind
let e ← forallBoundedTelescope indTyKind info.numParams fun xs _ => do
let e ← forallBoundedTelescope indTyKind info.numParams fun xs _ => do
withLocalDeclD `P (mkSort v.succ) fun P => do
withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do
let default ← mkArrow (mkConst ``PUnit [indLevel]) P
let es ← info.ctors.toArray.mapM fun ctorName => do
let ctor := mkAppN (mkConst ctorName us) xs
let ctorType ← inferType ctor
let argType ← forallTelescope ctorType fun ys _ =>
forallTelescope ctorType fun ys _ =>
mkForallFVars ys P
mkArrow (mkConst ``PUnit [indLevel]) argType
let e ← mkNatLookupTable ctorIdx es default
let e ← mkNatLookupTableLifting ctorIdx es P
mkLambdaFVars ((xs.push P).push ctorIdx) e
let declName := mkWithCtorTypeName indName
@ -109,7 +165,6 @@ def mkWithCtor (indName : Name) : MetaM Unit := do
let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`"
let indTyCon := mkConst indName us
let indTyKind ← inferType indTyCon
let indLevel ← getLevel indTyKind
let e ← forallBoundedTelescope indTyKind info.numParams fun xs t => do
withLocalDeclD `P (mkSort v.succ) fun P => do
withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do
@ -134,7 +189,7 @@ def mkWithCtor (indName : Name) : MetaM Unit := do
let heq := mkApp3 (mkConst ``Eq [1]) (mkConst ``Nat) ctorIdx (mkRawNatLit i)
let «then» ← withLocalDeclD `h heq fun h => do
let e ← mkEqNDRec (motive := withCtorTypeNameApp) k h
let e := mkApp e (mkConst ``PUnit.unit [indLevel])
let e ← mkPULiftDown e
let e := mkAppN e zs
-- ``Eq.ndrec
mkLambdaFVars #[h] e
@ -191,15 +246,16 @@ def mkNoConfusionTypeLinear (indName : Name) : MetaM Unit := do
let alt := mkAppN alt xs
let alt := mkApp alt PType
let alt := mkApp alt (mkRawNatLit i)
let k ← forallTelescopeReducing (← inferType alt).bindingDomain! fun zs2 _ => do
let eqs ← (Array.zip zs1 zs2[1:]).filterMapM fun (z1,z2) => do
if (← isProof z1) then
return none
else
return some (← mkEqHEq z1 z2)
let k ← mkArrowN eqs P
let k ← mkArrow k P
mkLambdaFVars zs2 k
let k ← withMkPULiftUp (← inferType alt).bindingDomain! fun t =>
forallTelescopeReducing t fun zs2 _ => do
let eqs ← (Array.zip zs1 zs2).filterMapM fun (z1,z2) => do
if (← isProof z1) then
return none
else
return some (← mkEqHEq z1 z2)
let k ← mkArrowN eqs P
let k ← mkArrow k P
mkLambdaFVars zs2 k
let alt := mkApp alt k
let alt := mkApp alt P
let alt := mkAppN alt ys

View file

@ -4,7 +4,7 @@ inductive Expr : id Type
partial def Expr.fold (f : Nat → αα) : Expr → αα
| var n, a => f n a
| app s as, a => as.foldl (init := a) fun a e => fold f e a
| app _ as, a => as.foldl (init := a) fun a e => fold f e a
def Expr.isVar : Expr → Bool
| var _ => true

View file

@ -0,0 +1,40 @@
-- This triggered a bug in the linear-size `noConfusionType` construction
-- which confused the kernel when producing the `noConfusion` lemma.
set_option debug.skipKernelTC true
set_option pp.universes true
-- Works
inductive S where
| a {α : Sort u} {β : Type v} (f : α → β)
| b
/--
info: @[reducible] protected def S.noConfusionType.withCtorType.{u_1, u, v} : Type u_1 → Nat → Type (max u u_1 (v + 1)) :=
fun P ctorIdx =>
bif Nat.blt ctorIdx 1 then
PULift.{max (u + 1) (u_1 + 1) (v + 2), max (max (u + 1) (u_1 + 1)) (v + 2)}
({α : Sort u} → {β : Type v} → (α → β) → P)
else PULift.{max (u + 1) (u_1 + 1) (v + 2), u_1 + 1} P
-/
#guard_msgs in
#print S.noConfusionType.withCtorType
-- Didn't work
inductive T where
| a {α : Sort u} {β : Sort v} (f : α → β)
| b
/--
info: @[reducible] protected def T.noConfusionType.withCtorType.{u_1, u, v} : Type u_1 →
Nat → Sort (max (u + 1) (u_1 + 1) (v + 1) (imax u v)) :=
fun P ctorIdx =>
bif Nat.blt ctorIdx 1 then
PULift.{max (u + 1) (u_1 + 1) (v + 1) (imax u v), max (max (max (u + 1) (u_1 + 1)) (v + 1)) (imax u v)}
({α : Sort u} → {β : Sort v} → (α → β) → P)
else PULift.{max (u + 1) (u_1 + 1) (v + 1) (imax u v), u_1 + 1} P
-/
#guard_msgs in
#print T.noConfusionType.withCtorType

View file

@ -12,53 +12,60 @@ inductive Vec.{u} (α : Type) : Nat → Type u where
| nil : Vec α 0
| cons {n} : α → Vec α n → Vec α (n + 1)
@[reducible] protected def Vec.noConfusionType.withCtorType'.{u_1, u} :
Type → Type u_1 → Nat → Type (max (u + 1) u_1) := fun α P ctorIdx =>
bif Nat.blt ctorIdx 1
then PUnit.{u + 2} → P
else PUnit.{u + 2} → {n : Nat} → α → Vec.{u} α n → P
Type → Type u_1 → Nat → Type (max u u_1) :=
fun α P ctorIdx =>
bif Nat.blt ctorIdx 1 then PULift.{max (u+1) (u_1+1)} P
else PULift.{max (u+1) (u_1+1)} ({n : Nat} → α → Vec.{u} α n → P)
/--
info: @[reducible] protected def Vec.noConfusionType.withCtorType.{u_1, u} : Type → Type u_1 → Nat → Type (max (u + 1) u_1) :=
fun α P ctorIdx => bif ctorIdx.blt 1 then PUnit → P else PUnit → {n : Nat} → α → Vec α n → P
info: @[reducible] protected def Vec.noConfusionType.withCtorType.{u_1, u} : Type → Type u_1 → Nat → Type (max u u_1) :=
fun α P ctorIdx =>
bif Nat.blt ctorIdx 1 then PULift.{max (u + 1) (u_1 + 1), u_1 + 1} P
else PULift.{max (u + 1) (u_1 + 1), max (u + 1) (u_1 + 1)} ({n : Nat} → α → Vec.{u} α n → P)
-/
#guard_msgs in
set_option pp.universes true in
#print Vec.noConfusionType.withCtorType
example : @Vec.noConfusionType.withCtorType.{u_1,u} = @Vec.noConfusionType.withCtorType'.{u_1,u} := rfl
@[reducible] protected noncomputable def Vec.noConfusionType.withCtor'.{u_1, u} : (α : Type) →
(P : Type u_1) → (ctorIdx : Nat) → Vec.noConfusionType.withCtorType' α P ctorIdx → P → (a : Nat) → Vec.{u} α a → P :=
fun _α _P ctorIdx k k' _a x =>
Vec.casesOn x
(if h : ctorIdx = 0 then Eq.ndrec k h PUnit.unit else k')
(fun a a_1 => if h : ctorIdx = 1 then Eq.ndrec k h PUnit.unit a a_1 else k')
(if h : ctorIdx = 0 then (Eq.ndrec k h).down else k')
(fun a a_1 => if h : ctorIdx = 1 then (Eq.ndrec k h).down a a_1 else k')
/--
info: @[reducible] protected def Vec.noConfusionType.withCtor.{u_1, u} : (α : Type) →
(P : Type u_1) → (ctorIdx : Nat) → Vec.noConfusionType.withCtorType α P ctorIdx → P → (a : Nat) → Vec α a → P :=
fun α P ctorIdx k k' a x =>
Vec.casesOn x (if h : ctorIdx = 0 then (h ▸ k) PUnit.unit else k') fun {n} a a_1 =>
if h : ctorIdx = 1 then (h ▸ k) PUnit.unit a a_1 else k'
Vec.casesOn x (if h : ctorIdx = 0 then (h ▸ k).down else k') fun {n} a a_1 =>
if h : ctorIdx = 1 then (h ▸ k).down a a_1 else k'
-/
#guard_msgs in
#print Vec.noConfusionType.withCtor
example : @Vec.noConfusionType.withCtor.{u_1,u} = @Vec.noConfusionType.withCtor'.{u_1,u} := rfl
@[reducible] protected def Vec.noConfusionType'.{u_1, u} : {α : Type} →
{a : Nat} → Sort u_1 → Vec.{u} α a → Vec α a → Sort u_1 :=
fun {α} {a} P x1 x2 =>
Vec.casesOn x1
(Vec.noConfusionType.withCtor' α (Sort u_1) 0 (fun _x => P → P) P a x2)
(fun {n} a_1 a_2 => Vec.noConfusionType.withCtor' α (Sort u_1) 1 (fun _x {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P) P a x2)
(Vec.noConfusionType.withCtor' α (Sort u_1) 0 ⟨P → P⟩ P a x2)
(fun {n} a_1 a_2 => Vec.noConfusionType.withCtor' α (Sort u_1) 1 ⟨fun {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P⟩ P a x2)
/--
info: @[reducible] protected def Vec.noConfusionType.{u_1, u} : {α : Type} →
{a : Nat} → Sort u_1 → Vec α a → Vec α a → Sort u_1 :=
fun {α} {a} P x1 x2 =>
Vec.casesOn x1 (Vec.noConfusionType.withCtor α (Sort u_1) 0 (fun x => P → P) P a x2) fun {n} a_1 a_2 =>
Vec.noConfusionType.withCtor α (Sort u_1) 1 (fun x {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P) P a x2
Vec.casesOn x1 (Vec.noConfusionType.withCtor α (Sort u_1) 0 { down := P → P } P a x2) fun {n} a_1 a_2 =>
Vec.noConfusionType.withCtor α (Sort u_1) 1 { down := fun {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P } P
a x2
-/
#guard_msgs in
#print Vec.noConfusionType
@ -84,3 +91,9 @@ run_meta do
-- inductive Enum.{u} : Type u where | a | b
-- set_option pp.universes true in
-- #print noConfusionTypeEnum
-- A possibly tricky universes case (resulting universe cannot be decremented)
inductive UnivTest.{u,v} (α : Sort v): Sort (max u v 1) where
| mk1 : UnivTest α
| mk2 : (x : α) → UnivTest α