From b2a8d890c13ebf4c1f69f1ba4162929cfab21120 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Wed, 25 Jun 2025 11:05:03 +0200 Subject: [PATCH] =?UTF-8?q?refactor:=20linearNoConfusionType:=20use=20PULi?= =?UTF-8?q?ft,=20not=20`PUnit=20=E2=86=92`=20(#8973)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR refactors the juggling of universes in the linear `noConfusionType` construction: Instead of using `PUnit.{…} → ` in the to get the branches of `withCtorType` to the same universe level, we use `PULift`. This fixes https://github.com/leanprover/lean4/issues/8962, although probably doesn’t solve all issues of that kind while level equality checking is incomplete. --- src/Lean/Meta/Constructions/NoConfusion.lean | 10 +- .../Meta/Constructions/NoConfusionLinear.lean | 98 +++++++++++++++---- tests/lean/run/ind_whnf.lean | 2 +- tests/lean/run/issue8962.lean | 40 ++++++++ tests/lean/run/linearNoConfusion.lean | 41 +++++--- 5 files changed, 146 insertions(+), 45 deletions(-) create mode 100644 tests/lean/run/issue8962.lean diff --git a/src/Lean/Meta/Constructions/NoConfusion.lean b/src/Lean/Meta/Constructions/NoConfusion.lean index 27e11eedc4..be85c4b506 100644 --- a/src/Lean/Meta/Constructions/NoConfusion.lean +++ b/src/Lean/Meta/Constructions/NoConfusion.lean @@ -10,10 +10,6 @@ import Lean.Meta.CompletionName import Lean.Meta.Constructions.NoConfusionLinear -register_builtin_option backwards.linearNoConfusionType : Bool := { - defValue := true - descr := "use the linear-size construction for the `noConfusionType` declaration of an inductive type. Set to false to use the previous, simpler but quadratic-size construction. " -} namespace Lean @@ -28,11 +24,7 @@ def mkNoConfusionCore (declName : Name) : MetaM Unit := do let recInfo ← getConstInfo (mkRecName declName) unless recInfo.levelParams.length > indVal.levelParams.length do return - let useLinear ← - if backwards.linearNoConfusionType.get (← getOptions) then - NoConfusionLinear.deps.allM (hasConst · (skipRealize := true)) - else - pure false + let useLinear ← NoConfusionLinear.canUse if useLinear then NoConfusionLinear.mkWithCtorType declName diff --git a/src/Lean/Meta/Constructions/NoConfusionLinear.lean b/src/Lean/Meta/Constructions/NoConfusionLinear.lean index 16e47d3115..463fb5da17 100644 --- a/src/Lean/Meta/Constructions/NoConfusionLinear.lean +++ b/src/Lean/Meta/Constructions/NoConfusionLinear.lean @@ -37,14 +37,24 @@ namespace Lean.NoConfusionLinear open Meta + +register_builtin_option backwards.linearNoConfusionType : Bool := { + defValue := true + descr := "use the linear-size construction for the `noConfusionType` declaration of an inductive type. Set to false to use the previous, simpler but quadratic-size construction. " +} + /-- List of constants that the linear `noConfusionType` construction depends on. -/ -def deps : Array Lean.Name := - #[ ``Nat.lt, ``cond, ``Nat, ``PUnit, ``Eq, ``Not, ``dite, ``Nat.decEq, ``Nat.blt ] +private def deps : Array Lean.Name := + #[ ``cond, ``ULift, ``Eq.ndrec, ``Not, ``dite, ``Nat.decEq, ``Nat.blt ] -def mkNatLookupTable (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do - let type ← inferType default +def canUse : MetaM Bool := do + unless backwards.linearNoConfusionType.get (← getOptions) do return false + unless (← NoConfusionLinear.deps.allM (hasConst · (skipRealize := true))) do return false + return true + +def mkNatLookupTable (n : Expr) (type : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do let u ← getLevel type let rec go (start stop : Nat) (hstart : start < stop := by omega) (hstop : stop ≤ es.size := by omega) : MetaM Expr := do if h : start + 1 = stop then @@ -59,6 +69,55 @@ def mkNatLookupTable (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr else go 0 es.size +-- Right-associates the top-most `max`s to work around #5695 for prettier code +private def reassocMax (l : Level) : Level := + let lvls := maxArgs l #[] + let last := lvls.back! + lvls.pop.foldr mkLevelMax last +where + maxArgs (l : Level) (lvls : Array Level) : Array Level := + match l with + | .max l1 l2 => maxArgs l2 (maxArgs l1 lvls) + | _ => lvls.push l + +/-- +Takes the max of the levels of the given expressions. +-/ +def maxLevels (es : Array Expr) (default : Expr) : MetaM Level := do + let mut maxLevel ← getLevel default + for e in es do + let l ← getLevel e + maxLevel := mkLevelMax' maxLevel l + return reassocMax maxLevel.normalize + +private def mkPULift (r : Level) (t : Expr) : MetaM Expr := do + let s ← getLevel t + return mkApp (mkConst `PULift [r,s]) t + +private def withMkPULiftUp (t : Expr) (k : Expr → MetaM Expr) : MetaM Expr := do + let t ← whnf t + if t.isAppOfArity `PULift 1 then + let t' := t.appArg! + let e ← k t' + return mkApp2 (mkConst `PULift.up (t.appFn!.constLevels!)) t' e + else + throwError "withMkPULiftUp: expected PULift type, got {t}" + +private def mkPULiftDown (e : Expr) : MetaM Expr := do + let t ← whnf (← inferType e) + if t.isAppOfArity `PULift 1 then + let t' := t.appArg! + return mkApp2 (mkConst `PULift.down t.appFn!.constLevels!) t' e + else + throwError "mkULiftDown: expected ULift type, got {t}" + +def mkNatLookupTableLifting (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do + let u ← maxLevels es default + let default ← mkPULift u default + let u' := reassocMax (mkLevelMax' u 1).normalize + let es ← es.mapM (mkPULift u) + mkNatLookupTable n (.sort u') es default + def mkWithCtorTypeName (indName : Name) : Name := Name.str indName "noConfusionType" |>.str "withCtorType" @@ -75,18 +134,15 @@ def mkWithCtorType (indName : Name) : MetaM Unit := do let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`" let indTyCon := mkConst indName us let indTyKind ← inferType indTyCon - let indLevel ← getLevel indTyKind - let e ← forallBoundedTelescope indTyKind info.numParams fun xs _ => do + let e ← forallBoundedTelescope indTyKind info.numParams fun xs _ => do withLocalDeclD `P (mkSort v.succ) fun P => do withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do - let default ← mkArrow (mkConst ``PUnit [indLevel]) P let es ← info.ctors.toArray.mapM fun ctorName => do let ctor := mkAppN (mkConst ctorName us) xs let ctorType ← inferType ctor - let argType ← forallTelescope ctorType fun ys _ => + forallTelescope ctorType fun ys _ => mkForallFVars ys P - mkArrow (mkConst ``PUnit [indLevel]) argType - let e ← mkNatLookupTable ctorIdx es default + let e ← mkNatLookupTableLifting ctorIdx es P mkLambdaFVars ((xs.push P).push ctorIdx) e let declName := mkWithCtorTypeName indName @@ -109,7 +165,6 @@ def mkWithCtor (indName : Name) : MetaM Unit := do let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`" let indTyCon := mkConst indName us let indTyKind ← inferType indTyCon - let indLevel ← getLevel indTyKind let e ← forallBoundedTelescope indTyKind info.numParams fun xs t => do withLocalDeclD `P (mkSort v.succ) fun P => do withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do @@ -134,7 +189,7 @@ def mkWithCtor (indName : Name) : MetaM Unit := do let heq := mkApp3 (mkConst ``Eq [1]) (mkConst ``Nat) ctorIdx (mkRawNatLit i) let «then» ← withLocalDeclD `h heq fun h => do let e ← mkEqNDRec (motive := withCtorTypeNameApp) k h - let e := mkApp e (mkConst ``PUnit.unit [indLevel]) + let e ← mkPULiftDown e let e := mkAppN e zs -- ``Eq.ndrec mkLambdaFVars #[h] e @@ -191,15 +246,16 @@ def mkNoConfusionTypeLinear (indName : Name) : MetaM Unit := do let alt := mkAppN alt xs let alt := mkApp alt PType let alt := mkApp alt (mkRawNatLit i) - let k ← forallTelescopeReducing (← inferType alt).bindingDomain! fun zs2 _ => do - let eqs ← (Array.zip zs1 zs2[1:]).filterMapM fun (z1,z2) => do - if (← isProof z1) then - return none - else - return some (← mkEqHEq z1 z2) - let k ← mkArrowN eqs P - let k ← mkArrow k P - mkLambdaFVars zs2 k + let k ← withMkPULiftUp (← inferType alt).bindingDomain! fun t => + forallTelescopeReducing t fun zs2 _ => do + let eqs ← (Array.zip zs1 zs2).filterMapM fun (z1,z2) => do + if (← isProof z1) then + return none + else + return some (← mkEqHEq z1 z2) + let k ← mkArrowN eqs P + let k ← mkArrow k P + mkLambdaFVars zs2 k let alt := mkApp alt k let alt := mkApp alt P let alt := mkAppN alt ys diff --git a/tests/lean/run/ind_whnf.lean b/tests/lean/run/ind_whnf.lean index dbb3d7e50e..21ba1c5576 100644 --- a/tests/lean/run/ind_whnf.lean +++ b/tests/lean/run/ind_whnf.lean @@ -4,7 +4,7 @@ inductive Expr : id Type partial def Expr.fold (f : Nat → α → α) : Expr → α → α | var n, a => f n a - | app s as, a => as.foldl (init := a) fun a e => fold f e a + | app _ as, a => as.foldl (init := a) fun a e => fold f e a def Expr.isVar : Expr → Bool | var _ => true diff --git a/tests/lean/run/issue8962.lean b/tests/lean/run/issue8962.lean new file mode 100644 index 0000000000..99286ab319 --- /dev/null +++ b/tests/lean/run/issue8962.lean @@ -0,0 +1,40 @@ +-- This triggered a bug in the linear-size `noConfusionType` construction +-- which confused the kernel when producing the `noConfusion` lemma. + +set_option debug.skipKernelTC true +set_option pp.universes true + +-- Works + +inductive S where +| a {α : Sort u} {β : Type v} (f : α → β) +| b + +/-- +info: @[reducible] protected def S.noConfusionType.withCtorType.{u_1, u, v} : Type u_1 → Nat → Type (max u u_1 (v + 1)) := +fun P ctorIdx => + bif Nat.blt ctorIdx 1 then + PULift.{max (u + 1) (u_1 + 1) (v + 2), max (max (u + 1) (u_1 + 1)) (v + 2)} + ({α : Sort u} → {β : Type v} → (α → β) → P) + else PULift.{max (u + 1) (u_1 + 1) (v + 2), u_1 + 1} P +-/ +#guard_msgs in +#print S.noConfusionType.withCtorType + +-- Didn't work + +inductive T where +| a {α : Sort u} {β : Sort v} (f : α → β) +| b + +/-- +info: @[reducible] protected def T.noConfusionType.withCtorType.{u_1, u, v} : Type u_1 → + Nat → Sort (max (u + 1) (u_1 + 1) (v + 1) (imax u v)) := +fun P ctorIdx => + bif Nat.blt ctorIdx 1 then + PULift.{max (u + 1) (u_1 + 1) (v + 1) (imax u v), max (max (max (u + 1) (u_1 + 1)) (v + 1)) (imax u v)} + ({α : Sort u} → {β : Sort v} → (α → β) → P) + else PULift.{max (u + 1) (u_1 + 1) (v + 1) (imax u v), u_1 + 1} P +-/ +#guard_msgs in +#print T.noConfusionType.withCtorType diff --git a/tests/lean/run/linearNoConfusion.lean b/tests/lean/run/linearNoConfusion.lean index 59bb05a909..1de9bf9edb 100644 --- a/tests/lean/run/linearNoConfusion.lean +++ b/tests/lean/run/linearNoConfusion.lean @@ -12,53 +12,60 @@ inductive Vec.{u} (α : Type) : Nat → Type u where | nil : Vec α 0 | cons {n} : α → Vec α n → Vec α (n + 1) + @[reducible] protected def Vec.noConfusionType.withCtorType'.{u_1, u} : - Type → Type u_1 → Nat → Type (max (u + 1) u_1) := fun α P ctorIdx => - bif Nat.blt ctorIdx 1 - then PUnit.{u + 2} → P - else PUnit.{u + 2} → {n : Nat} → α → Vec.{u} α n → P + Type → Type u_1 → Nat → Type (max u u_1) := +fun α P ctorIdx => + bif Nat.blt ctorIdx 1 then PULift.{max (u+1) (u_1+1)} P + else PULift.{max (u+1) (u_1+1)} ({n : Nat} → α → Vec.{u} α n → P) /-- -info: @[reducible] protected def Vec.noConfusionType.withCtorType.{u_1, u} : Type → Type u_1 → Nat → Type (max (u + 1) u_1) := -fun α P ctorIdx => bif ctorIdx.blt 1 then PUnit → P else PUnit → {n : Nat} → α → Vec α n → P +info: @[reducible] protected def Vec.noConfusionType.withCtorType.{u_1, u} : Type → Type u_1 → Nat → Type (max u u_1) := +fun α P ctorIdx => + bif Nat.blt ctorIdx 1 then PULift.{max (u + 1) (u_1 + 1), u_1 + 1} P + else PULift.{max (u + 1) (u_1 + 1), max (u + 1) (u_1 + 1)} ({n : Nat} → α → Vec.{u} α n → P) -/ #guard_msgs in +set_option pp.universes true in #print Vec.noConfusionType.withCtorType example : @Vec.noConfusionType.withCtorType.{u_1,u} = @Vec.noConfusionType.withCtorType'.{u_1,u} := rfl + @[reducible] protected noncomputable def Vec.noConfusionType.withCtor'.{u_1, u} : (α : Type) → (P : Type u_1) → (ctorIdx : Nat) → Vec.noConfusionType.withCtorType' α P ctorIdx → P → (a : Nat) → Vec.{u} α a → P := fun _α _P ctorIdx k k' _a x => Vec.casesOn x - (if h : ctorIdx = 0 then Eq.ndrec k h PUnit.unit else k') - (fun a a_1 => if h : ctorIdx = 1 then Eq.ndrec k h PUnit.unit a a_1 else k') + (if h : ctorIdx = 0 then (Eq.ndrec k h).down else k') + (fun a a_1 => if h : ctorIdx = 1 then (Eq.ndrec k h).down a a_1 else k') /-- info: @[reducible] protected def Vec.noConfusionType.withCtor.{u_1, u} : (α : Type) → (P : Type u_1) → (ctorIdx : Nat) → Vec.noConfusionType.withCtorType α P ctorIdx → P → (a : Nat) → Vec α a → P := fun α P ctorIdx k k' a x => - Vec.casesOn x (if h : ctorIdx = 0 then (h ▸ k) PUnit.unit else k') fun {n} a a_1 => - if h : ctorIdx = 1 then (h ▸ k) PUnit.unit a a_1 else k' + Vec.casesOn x (if h : ctorIdx = 0 then (h ▸ k).down else k') fun {n} a a_1 => + if h : ctorIdx = 1 then (h ▸ k).down a a_1 else k' -/ #guard_msgs in #print Vec.noConfusionType.withCtor example : @Vec.noConfusionType.withCtor.{u_1,u} = @Vec.noConfusionType.withCtor'.{u_1,u} := rfl + @[reducible] protected def Vec.noConfusionType'.{u_1, u} : {α : Type} → {a : Nat} → Sort u_1 → Vec.{u} α a → Vec α a → Sort u_1 := fun {α} {a} P x1 x2 => Vec.casesOn x1 - (Vec.noConfusionType.withCtor' α (Sort u_1) 0 (fun _x => P → P) P a x2) - (fun {n} a_1 a_2 => Vec.noConfusionType.withCtor' α (Sort u_1) 1 (fun _x {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P) P a x2) + (Vec.noConfusionType.withCtor' α (Sort u_1) 0 ⟨P → P⟩ P a x2) + (fun {n} a_1 a_2 => Vec.noConfusionType.withCtor' α (Sort u_1) 1 ⟨fun {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P⟩ P a x2) /-- info: @[reducible] protected def Vec.noConfusionType.{u_1, u} : {α : Type} → {a : Nat} → Sort u_1 → Vec α a → Vec α a → Sort u_1 := fun {α} {a} P x1 x2 => - Vec.casesOn x1 (Vec.noConfusionType.withCtor α (Sort u_1) 0 (fun x => P → P) P a x2) fun {n} a_1 a_2 => - Vec.noConfusionType.withCtor α (Sort u_1) 1 (fun x {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P) P a x2 + Vec.casesOn x1 (Vec.noConfusionType.withCtor α (Sort u_1) 0 { down := P → P } P a x2) fun {n} a_1 a_2 => + Vec.noConfusionType.withCtor α (Sort u_1) 1 { down := fun {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P } P + a x2 -/ #guard_msgs in #print Vec.noConfusionType @@ -84,3 +91,9 @@ run_meta do -- inductive Enum.{u} : Type u where | a | b -- set_option pp.universes true in -- #print noConfusionTypeEnum + +-- A possibly tricky universes case (resulting universe cannot be decremented) + +inductive UnivTest.{u,v} (α : Sort v): Sort (max u v 1) where + | mk1 : UnivTest α + | mk2 : (x : α) → UnivTest α