refactor: linearNoConfusionType: use PULift, not PUnit → (#8973)
This PR refactors the juggling of universes in the linear
`noConfusionType` construction: Instead of using `PUnit.{…} → ` in the
to get the branches of `withCtorType` to the same universe level, we use
`PULift`.
This fixes https://github.com/leanprover/lean4/issues/8962, although
probably doesn’t solve all issues of that kind while level equality
checking is incomplete.
This commit is contained in:
parent
9641a9ac6c
commit
b2a8d890c1
5 changed files with 146 additions and 45 deletions
|
|
@ -10,10 +10,6 @@ import Lean.Meta.CompletionName
|
|||
import Lean.Meta.Constructions.NoConfusionLinear
|
||||
|
||||
|
||||
register_builtin_option backwards.linearNoConfusionType : Bool := {
|
||||
defValue := true
|
||||
descr := "use the linear-size construction for the `noConfusionType` declaration of an inductive type. Set to false to use the previous, simpler but quadratic-size construction. "
|
||||
}
|
||||
|
||||
namespace Lean
|
||||
|
||||
|
|
@ -28,11 +24,7 @@ def mkNoConfusionCore (declName : Name) : MetaM Unit := do
|
|||
let recInfo ← getConstInfo (mkRecName declName)
|
||||
unless recInfo.levelParams.length > indVal.levelParams.length do return
|
||||
|
||||
let useLinear ←
|
||||
if backwards.linearNoConfusionType.get (← getOptions) then
|
||||
NoConfusionLinear.deps.allM (hasConst · (skipRealize := true))
|
||||
else
|
||||
pure false
|
||||
let useLinear ← NoConfusionLinear.canUse
|
||||
|
||||
if useLinear then
|
||||
NoConfusionLinear.mkWithCtorType declName
|
||||
|
|
|
|||
|
|
@ -37,14 +37,24 @@ namespace Lean.NoConfusionLinear
|
|||
|
||||
open Meta
|
||||
|
||||
|
||||
register_builtin_option backwards.linearNoConfusionType : Bool := {
|
||||
defValue := true
|
||||
descr := "use the linear-size construction for the `noConfusionType` declaration of an inductive type. Set to false to use the previous, simpler but quadratic-size construction. "
|
||||
}
|
||||
|
||||
/--
|
||||
List of constants that the linear `noConfusionType` construction depends on.
|
||||
-/
|
||||
def deps : Array Lean.Name :=
|
||||
#[ ``Nat.lt, ``cond, ``Nat, ``PUnit, ``Eq, ``Not, ``dite, ``Nat.decEq, ``Nat.blt ]
|
||||
private def deps : Array Lean.Name :=
|
||||
#[ ``cond, ``ULift, ``Eq.ndrec, ``Not, ``dite, ``Nat.decEq, ``Nat.blt ]
|
||||
|
||||
def mkNatLookupTable (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do
|
||||
let type ← inferType default
|
||||
def canUse : MetaM Bool := do
|
||||
unless backwards.linearNoConfusionType.get (← getOptions) do return false
|
||||
unless (← NoConfusionLinear.deps.allM (hasConst · (skipRealize := true))) do return false
|
||||
return true
|
||||
|
||||
def mkNatLookupTable (n : Expr) (type : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do
|
||||
let u ← getLevel type
|
||||
let rec go (start stop : Nat) (hstart : start < stop := by omega) (hstop : stop ≤ es.size := by omega) : MetaM Expr := do
|
||||
if h : start + 1 = stop then
|
||||
|
|
@ -59,6 +69,55 @@ def mkNatLookupTable (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr
|
|||
else
|
||||
go 0 es.size
|
||||
|
||||
-- Right-associates the top-most `max`s to work around #5695 for prettier code
|
||||
private def reassocMax (l : Level) : Level :=
|
||||
let lvls := maxArgs l #[]
|
||||
let last := lvls.back!
|
||||
lvls.pop.foldr mkLevelMax last
|
||||
where
|
||||
maxArgs (l : Level) (lvls : Array Level) : Array Level :=
|
||||
match l with
|
||||
| .max l1 l2 => maxArgs l2 (maxArgs l1 lvls)
|
||||
| _ => lvls.push l
|
||||
|
||||
/--
|
||||
Takes the max of the levels of the given expressions.
|
||||
-/
|
||||
def maxLevels (es : Array Expr) (default : Expr) : MetaM Level := do
|
||||
let mut maxLevel ← getLevel default
|
||||
for e in es do
|
||||
let l ← getLevel e
|
||||
maxLevel := mkLevelMax' maxLevel l
|
||||
return reassocMax maxLevel.normalize
|
||||
|
||||
private def mkPULift (r : Level) (t : Expr) : MetaM Expr := do
|
||||
let s ← getLevel t
|
||||
return mkApp (mkConst `PULift [r,s]) t
|
||||
|
||||
private def withMkPULiftUp (t : Expr) (k : Expr → MetaM Expr) : MetaM Expr := do
|
||||
let t ← whnf t
|
||||
if t.isAppOfArity `PULift 1 then
|
||||
let t' := t.appArg!
|
||||
let e ← k t'
|
||||
return mkApp2 (mkConst `PULift.up (t.appFn!.constLevels!)) t' e
|
||||
else
|
||||
throwError "withMkPULiftUp: expected PULift type, got {t}"
|
||||
|
||||
private def mkPULiftDown (e : Expr) : MetaM Expr := do
|
||||
let t ← whnf (← inferType e)
|
||||
if t.isAppOfArity `PULift 1 then
|
||||
let t' := t.appArg!
|
||||
return mkApp2 (mkConst `PULift.down t.appFn!.constLevels!) t' e
|
||||
else
|
||||
throwError "mkULiftDown: expected ULift type, got {t}"
|
||||
|
||||
def mkNatLookupTableLifting (n : Expr) (es : Array Expr) (default : Expr) : MetaM Expr := do
|
||||
let u ← maxLevels es default
|
||||
let default ← mkPULift u default
|
||||
let u' := reassocMax (mkLevelMax' u 1).normalize
|
||||
let es ← es.mapM (mkPULift u)
|
||||
mkNatLookupTable n (.sort u') es default
|
||||
|
||||
def mkWithCtorTypeName (indName : Name) : Name :=
|
||||
Name.str indName "noConfusionType" |>.str "withCtorType"
|
||||
|
||||
|
|
@ -75,18 +134,15 @@ def mkWithCtorType (indName : Name) : MetaM Unit := do
|
|||
let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`"
|
||||
let indTyCon := mkConst indName us
|
||||
let indTyKind ← inferType indTyCon
|
||||
let indLevel ← getLevel indTyKind
|
||||
let e ← forallBoundedTelescope indTyKind info.numParams fun xs _ => do
|
||||
let e ← forallBoundedTelescope indTyKind info.numParams fun xs _ => do
|
||||
withLocalDeclD `P (mkSort v.succ) fun P => do
|
||||
withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do
|
||||
let default ← mkArrow (mkConst ``PUnit [indLevel]) P
|
||||
let es ← info.ctors.toArray.mapM fun ctorName => do
|
||||
let ctor := mkAppN (mkConst ctorName us) xs
|
||||
let ctorType ← inferType ctor
|
||||
let argType ← forallTelescope ctorType fun ys _ =>
|
||||
forallTelescope ctorType fun ys _ =>
|
||||
mkForallFVars ys P
|
||||
mkArrow (mkConst ``PUnit [indLevel]) argType
|
||||
let e ← mkNatLookupTable ctorIdx es default
|
||||
let e ← mkNatLookupTableLifting ctorIdx es P
|
||||
mkLambdaFVars ((xs.push P).push ctorIdx) e
|
||||
|
||||
let declName := mkWithCtorTypeName indName
|
||||
|
|
@ -109,7 +165,6 @@ def mkWithCtor (indName : Name) : MetaM Unit := do
|
|||
let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`"
|
||||
let indTyCon := mkConst indName us
|
||||
let indTyKind ← inferType indTyCon
|
||||
let indLevel ← getLevel indTyKind
|
||||
let e ← forallBoundedTelescope indTyKind info.numParams fun xs t => do
|
||||
withLocalDeclD `P (mkSort v.succ) fun P => do
|
||||
withLocalDeclD `ctorIdx (mkConst ``Nat) fun ctorIdx => do
|
||||
|
|
@ -134,7 +189,7 @@ def mkWithCtor (indName : Name) : MetaM Unit := do
|
|||
let heq := mkApp3 (mkConst ``Eq [1]) (mkConst ``Nat) ctorIdx (mkRawNatLit i)
|
||||
let «then» ← withLocalDeclD `h heq fun h => do
|
||||
let e ← mkEqNDRec (motive := withCtorTypeNameApp) k h
|
||||
let e := mkApp e (mkConst ``PUnit.unit [indLevel])
|
||||
let e ← mkPULiftDown e
|
||||
let e := mkAppN e zs
|
||||
-- ``Eq.ndrec
|
||||
mkLambdaFVars #[h] e
|
||||
|
|
@ -191,15 +246,16 @@ def mkNoConfusionTypeLinear (indName : Name) : MetaM Unit := do
|
|||
let alt := mkAppN alt xs
|
||||
let alt := mkApp alt PType
|
||||
let alt := mkApp alt (mkRawNatLit i)
|
||||
let k ← forallTelescopeReducing (← inferType alt).bindingDomain! fun zs2 _ => do
|
||||
let eqs ← (Array.zip zs1 zs2[1:]).filterMapM fun (z1,z2) => do
|
||||
if (← isProof z1) then
|
||||
return none
|
||||
else
|
||||
return some (← mkEqHEq z1 z2)
|
||||
let k ← mkArrowN eqs P
|
||||
let k ← mkArrow k P
|
||||
mkLambdaFVars zs2 k
|
||||
let k ← withMkPULiftUp (← inferType alt).bindingDomain! fun t =>
|
||||
forallTelescopeReducing t fun zs2 _ => do
|
||||
let eqs ← (Array.zip zs1 zs2).filterMapM fun (z1,z2) => do
|
||||
if (← isProof z1) then
|
||||
return none
|
||||
else
|
||||
return some (← mkEqHEq z1 z2)
|
||||
let k ← mkArrowN eqs P
|
||||
let k ← mkArrow k P
|
||||
mkLambdaFVars zs2 k
|
||||
let alt := mkApp alt k
|
||||
let alt := mkApp alt P
|
||||
let alt := mkAppN alt ys
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ inductive Expr : id Type
|
|||
|
||||
partial def Expr.fold (f : Nat → α → α) : Expr → α → α
|
||||
| var n, a => f n a
|
||||
| app s as, a => as.foldl (init := a) fun a e => fold f e a
|
||||
| app _ as, a => as.foldl (init := a) fun a e => fold f e a
|
||||
|
||||
def Expr.isVar : Expr → Bool
|
||||
| var _ => true
|
||||
|
|
|
|||
40
tests/lean/run/issue8962.lean
Normal file
40
tests/lean/run/issue8962.lean
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
-- This triggered a bug in the linear-size `noConfusionType` construction
|
||||
-- which confused the kernel when producing the `noConfusion` lemma.
|
||||
|
||||
set_option debug.skipKernelTC true
|
||||
set_option pp.universes true
|
||||
|
||||
-- Works
|
||||
|
||||
inductive S where
|
||||
| a {α : Sort u} {β : Type v} (f : α → β)
|
||||
| b
|
||||
|
||||
/--
|
||||
info: @[reducible] protected def S.noConfusionType.withCtorType.{u_1, u, v} : Type u_1 → Nat → Type (max u u_1 (v + 1)) :=
|
||||
fun P ctorIdx =>
|
||||
bif Nat.blt ctorIdx 1 then
|
||||
PULift.{max (u + 1) (u_1 + 1) (v + 2), max (max (u + 1) (u_1 + 1)) (v + 2)}
|
||||
({α : Sort u} → {β : Type v} → (α → β) → P)
|
||||
else PULift.{max (u + 1) (u_1 + 1) (v + 2), u_1 + 1} P
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print S.noConfusionType.withCtorType
|
||||
|
||||
-- Didn't work
|
||||
|
||||
inductive T where
|
||||
| a {α : Sort u} {β : Sort v} (f : α → β)
|
||||
| b
|
||||
|
||||
/--
|
||||
info: @[reducible] protected def T.noConfusionType.withCtorType.{u_1, u, v} : Type u_1 →
|
||||
Nat → Sort (max (u + 1) (u_1 + 1) (v + 1) (imax u v)) :=
|
||||
fun P ctorIdx =>
|
||||
bif Nat.blt ctorIdx 1 then
|
||||
PULift.{max (u + 1) (u_1 + 1) (v + 1) (imax u v), max (max (max (u + 1) (u_1 + 1)) (v + 1)) (imax u v)}
|
||||
({α : Sort u} → {β : Sort v} → (α → β) → P)
|
||||
else PULift.{max (u + 1) (u_1 + 1) (v + 1) (imax u v), u_1 + 1} P
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print T.noConfusionType.withCtorType
|
||||
|
|
@ -12,53 +12,60 @@ inductive Vec.{u} (α : Type) : Nat → Type u where
|
|||
| nil : Vec α 0
|
||||
| cons {n} : α → Vec α n → Vec α (n + 1)
|
||||
|
||||
|
||||
@[reducible] protected def Vec.noConfusionType.withCtorType'.{u_1, u} :
|
||||
Type → Type u_1 → Nat → Type (max (u + 1) u_1) := fun α P ctorIdx =>
|
||||
bif Nat.blt ctorIdx 1
|
||||
then PUnit.{u + 2} → P
|
||||
else PUnit.{u + 2} → {n : Nat} → α → Vec.{u} α n → P
|
||||
Type → Type u_1 → Nat → Type (max u u_1) :=
|
||||
fun α P ctorIdx =>
|
||||
bif Nat.blt ctorIdx 1 then PULift.{max (u+1) (u_1+1)} P
|
||||
else PULift.{max (u+1) (u_1+1)} ({n : Nat} → α → Vec.{u} α n → P)
|
||||
|
||||
/--
|
||||
info: @[reducible] protected def Vec.noConfusionType.withCtorType.{u_1, u} : Type → Type u_1 → Nat → Type (max (u + 1) u_1) :=
|
||||
fun α P ctorIdx => bif ctorIdx.blt 1 then PUnit → P else PUnit → {n : Nat} → α → Vec α n → P
|
||||
info: @[reducible] protected def Vec.noConfusionType.withCtorType.{u_1, u} : Type → Type u_1 → Nat → Type (max u u_1) :=
|
||||
fun α P ctorIdx =>
|
||||
bif Nat.blt ctorIdx 1 then PULift.{max (u + 1) (u_1 + 1), u_1 + 1} P
|
||||
else PULift.{max (u + 1) (u_1 + 1), max (u + 1) (u_1 + 1)} ({n : Nat} → α → Vec.{u} α n → P)
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option pp.universes true in
|
||||
#print Vec.noConfusionType.withCtorType
|
||||
|
||||
example : @Vec.noConfusionType.withCtorType.{u_1,u} = @Vec.noConfusionType.withCtorType'.{u_1,u} := rfl
|
||||
|
||||
|
||||
@[reducible] protected noncomputable def Vec.noConfusionType.withCtor'.{u_1, u} : (α : Type) →
|
||||
(P : Type u_1) → (ctorIdx : Nat) → Vec.noConfusionType.withCtorType' α P ctorIdx → P → (a : Nat) → Vec.{u} α a → P :=
|
||||
fun _α _P ctorIdx k k' _a x =>
|
||||
Vec.casesOn x
|
||||
(if h : ctorIdx = 0 then Eq.ndrec k h PUnit.unit else k')
|
||||
(fun a a_1 => if h : ctorIdx = 1 then Eq.ndrec k h PUnit.unit a a_1 else k')
|
||||
(if h : ctorIdx = 0 then (Eq.ndrec k h).down else k')
|
||||
(fun a a_1 => if h : ctorIdx = 1 then (Eq.ndrec k h).down a a_1 else k')
|
||||
|
||||
/--
|
||||
info: @[reducible] protected def Vec.noConfusionType.withCtor.{u_1, u} : (α : Type) →
|
||||
(P : Type u_1) → (ctorIdx : Nat) → Vec.noConfusionType.withCtorType α P ctorIdx → P → (a : Nat) → Vec α a → P :=
|
||||
fun α P ctorIdx k k' a x =>
|
||||
Vec.casesOn x (if h : ctorIdx = 0 then (h ▸ k) PUnit.unit else k') fun {n} a a_1 =>
|
||||
if h : ctorIdx = 1 then (h ▸ k) PUnit.unit a a_1 else k'
|
||||
Vec.casesOn x (if h : ctorIdx = 0 then (h ▸ k).down else k') fun {n} a a_1 =>
|
||||
if h : ctorIdx = 1 then (h ▸ k).down a a_1 else k'
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print Vec.noConfusionType.withCtor
|
||||
|
||||
example : @Vec.noConfusionType.withCtor.{u_1,u} = @Vec.noConfusionType.withCtor'.{u_1,u} := rfl
|
||||
|
||||
|
||||
@[reducible] protected def Vec.noConfusionType'.{u_1, u} : {α : Type} →
|
||||
{a : Nat} → Sort u_1 → Vec.{u} α a → Vec α a → Sort u_1 :=
|
||||
fun {α} {a} P x1 x2 =>
|
||||
Vec.casesOn x1
|
||||
(Vec.noConfusionType.withCtor' α (Sort u_1) 0 (fun _x => P → P) P a x2)
|
||||
(fun {n} a_1 a_2 => Vec.noConfusionType.withCtor' α (Sort u_1) 1 (fun _x {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P) P a x2)
|
||||
(Vec.noConfusionType.withCtor' α (Sort u_1) 0 ⟨P → P⟩ P a x2)
|
||||
(fun {n} a_1 a_2 => Vec.noConfusionType.withCtor' α (Sort u_1) 1 ⟨fun {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P⟩ P a x2)
|
||||
|
||||
/--
|
||||
info: @[reducible] protected def Vec.noConfusionType.{u_1, u} : {α : Type} →
|
||||
{a : Nat} → Sort u_1 → Vec α a → Vec α a → Sort u_1 :=
|
||||
fun {α} {a} P x1 x2 =>
|
||||
Vec.casesOn x1 (Vec.noConfusionType.withCtor α (Sort u_1) 0 (fun x => P → P) P a x2) fun {n} a_1 a_2 =>
|
||||
Vec.noConfusionType.withCtor α (Sort u_1) 1 (fun x {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P) P a x2
|
||||
Vec.casesOn x1 (Vec.noConfusionType.withCtor α (Sort u_1) 0 { down := P → P } P a x2) fun {n} a_1 a_2 =>
|
||||
Vec.noConfusionType.withCtor α (Sort u_1) 1 { down := fun {n_1} a a_3 => (n = n_1 → a_1 = a → a_2 ≍ a_3 → P) → P } P
|
||||
a x2
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print Vec.noConfusionType
|
||||
|
|
@ -84,3 +91,9 @@ run_meta do
|
|||
-- inductive Enum.{u} : Type u where | a | b
|
||||
-- set_option pp.universes true in
|
||||
-- #print noConfusionTypeEnum
|
||||
|
||||
-- A possibly tricky universes case (resulting universe cannot be decremented)
|
||||
|
||||
inductive UnivTest.{u,v} (α : Sort v): Sort (max u v 1) where
|
||||
| mk1 : UnivTest α
|
||||
| mk2 : (x : α) → UnivTest α
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue