fix: allow arbitrary sorts in structural recursion over reflexive inductive types (#7639)

This PR changes the generated `below` and `brecOn` implementations for
reflexive inductive types to support motives in `Sort u` rather than
`Type u`.

Closes #7638
This commit is contained in:
Parth Shastri 2025-06-13 17:51:09 -04:00 committed by GitHub
parent 812bab6910
commit 5d50433e6a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 66 additions and 75 deletions

View file

@ -240,14 +240,7 @@ def mkBRecOnConst (recArgInfos : Array RecArgInfo) (positions : Positions)
let indGroup := recArgInfos[0]!.indGroupInst
let motive := motives[0]!
let brecOnUniv ← lambdaTelescope motive fun _ type => getLevel type
let indInfo ← getConstInfoInduct indGroup.all[0]!
let useBInductionOn := indInfo.isReflexive && brecOnUniv == levelZero
let brecOnUniv ←
if indInfo.isReflexive && brecOnUniv != levelZero then
decLevel brecOnUniv
else
pure brecOnUniv
let brecOnCons := fun idx => indGroup.brecOn useBInductionOn brecOnUniv idx
let brecOnCons := fun idx => indGroup.brecOn false brecOnUniv idx
-- Pick one as a prototype
let brecOnAux := brecOnCons 0
-- Infer the type of the packed motive arguments

View file

@ -70,41 +70,38 @@ def getRecArgInfo (fnName : Name) (fixedParamPerm : FixedParamPerm) (xs : Array
throwError "it is a let-binding"
let xType ← whnfD localDecl.type
matchConstInduct xType.getAppFn (fun _ => throwError "its type is not an inductive") fun indInfo us => do
if indInfo.isReflexive && !(← hasConst (mkBInductionOnName indInfo.name)) && !(← isInductivePredicate indInfo.name) then
throwError "its type {indInfo.name} is a reflexive inductive, but {mkBInductionOnName indInfo.name} does not exist and it is not an inductive predicate"
let indArgs : Array Expr := xType.getAppArgs
let indParams : Array Expr := indArgs[0:indInfo.numParams]
let indIndices : Array Expr := indArgs[indInfo.numParams:]
if !indIndices.all Expr.isFVar then
throwError "its type {indInfo.name} is an inductive family and indices are not variables{indentExpr xType}"
else if !indIndices.allDiff then
throwError "its type {indInfo.name} is an inductive family and indices are not pairwise distinct{indentExpr xType}"
else
let indArgs : Array Expr := xType.getAppArgs
let indParams : Array Expr := indArgs[0:indInfo.numParams]
let indIndices : Array Expr := indArgs[indInfo.numParams:]
if !indIndices.all Expr.isFVar then
throwError "its type {indInfo.name} is an inductive family and indices are not variables{indentExpr xType}"
else if !indIndices.allDiff then
throwError "its type {indInfo.name} is an inductive family and indices are not pairwise distinct{indentExpr xType}"
else
let ys := fixedParamPerm.pickVarying xs
match (← hasBadIndexDep? ys indIndices) with
| some (index, y) =>
throwError "its type {indInfo.name} is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}"
let ys := fixedParamPerm.pickVarying xs
match (← hasBadIndexDep? ys indIndices) with
| some (index, y) =>
throwError "its type {indInfo.name} is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}"
| none =>
match (← hasBadParamDep? ys indParams) with
| some (indParam, y) =>
throwError "its type is an inductive datatype{indentExpr xType}\nand the datatype parameter{indentExpr indParam}\ndepends on the function parameter{indentExpr y}\nwhich is not fixed."
| none =>
match (← hasBadParamDep? ys indParams) with
| some (indParam, y) =>
throwError "its type is an inductive datatype{indentExpr xType}\nand the datatype parameter{indentExpr indParam}\ndepends on the function parameter{indentExpr y}\nwhich is not fixed."
| none =>
let indAll := indInfo.all.toArray
let .some indIdx := indAll.idxOf? indInfo.name | panic! "{indInfo.name} not in {indInfo.all}"
let indicesPos := indIndices.map fun index => match xs.idxOf? index with | some i => i | none => unreachable!
let indGroupInst := {
IndGroupInfo.ofInductiveVal indInfo with
levels := us
params := indParams }
return { fnName := fnName
fixedParamPerm := fixedParamPerm
recArgPos := i
indicesPos := indicesPos
indGroupInst := indGroupInst
indIdx := indIdx }
else
throwError "the index #{i+1} exceeds {xs.size}, the number of parameters"
let indAll := indInfo.all.toArray
let .some indIdx := indAll.idxOf? indInfo.name | panic! "{indInfo.name} not in {indInfo.all}"
let indicesPos := indIndices.map fun index => match xs.idxOf? index with | some i => i | none => unreachable!
let indGroupInst := {
IndGroupInfo.ofInductiveVal indInfo with
levels := us
params := indParams }
return { fnName := fnName
fixedParamPerm := fixedParamPerm
recArgPos := i
indicesPos := indicesPos
indGroupInst := indGroupInst
indIdx := indIdx }
else
throwError "the index #{i+1} exceeds {xs.size}, the number of parameters"
/--
Collects the `RecArgInfos` for one function, and returns a report for why the others were not

View file

@ -13,14 +13,6 @@ import Lean.Meta.PProdN
namespace Lean
open Meta
/-- Transforms `e : xᵢ → (t₁ ×' t₂)` into `(xᵢ → t₁) ×' (xᵢ → t₂) -/
private def etaPProd (xs : Array Expr) (e : Expr) : MetaM Expr := do
if xs.isEmpty then return e
let r := mkAppN e xs
let r₁ ← mkLambdaFVars xs (← mkPProdFstM r)
let r₂ ← mkLambdaFVars xs (← mkPProdSndM r)
mkPProdMk r₁ r₂
/--
If `minorType` is the type of a minor premies of a recursor, such as
```
@ -40,7 +32,6 @@ of type
private def buildBelowMinorPremise (rlvl : Level) (motives : Array Expr) (minorType : Expr) : MetaM Expr :=
forallTelescope minorType fun minor_args _ => do go #[] minor_args.toList
where
ibelow := rlvl matches .zero
go (prods : Array Expr) : List Expr → MetaM Expr
| [] => PProdN.pack rlvl prods
| arg::args => do
@ -50,8 +41,7 @@ where
let name ← arg.fvarId!.getUserName
let type' ← forallTelescope argType fun args _ => mkForallFVars args (.sort rlvl)
withLocalDeclD name type' fun arg' => do
let snd ← mkForallFVars arg_args (mkAppN arg' arg_args)
let e' ← mkPProd argType snd
let e' ← mkForallFVars arg_args <| ← mkPProd arg_type (mkAppN arg' arg_args)
mkLambdaFVars #[arg'] (← go (prods.push e') args)
else
mkLambdaFVars #[arg] (← go prods args)
@ -86,8 +76,6 @@ private def mkBelowFromRec (recName : Name) (ibelow reflexive : Bool) (nParams :
let refType :=
if ibelow then
recVal.type.instantiateLevelParams [lvlParam] [0]
else if reflexive then
recVal.type.instantiateLevelParams [lvlParam] [lvl.succ]
else
recVal.type
@ -116,12 +104,9 @@ private def mkBelowFromRec (recName : Name) (ibelow reflexive : Bool) (nParams :
if ibelow then
0
else if reflexive then
if let .max 1 ilvl' := ilvl then
mkLevelMax' (.succ lvl) ilvl'
else
mkLevelMax' (.succ lvl) ilvl
mkLevelMax ilvl lvl
else
mkLevelMax' 1 lvl
mkLevelMax 1 lvl
let mut val := .const recName (rlvl.succ :: lvls)
-- add parameters
@ -168,8 +153,8 @@ private def mkBelowOrIBelow (indName : Name) (ibelow : Bool) : MetaM Unit := do
let belowName := belowName.appendIndexAfter (i + 1)
mkBelowFromRec recName ibelow indVal.isReflexive indVal.numParams belowName
def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true
def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false
def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false
def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true
/--
If `minorType` is the type of a minor premies of a recursor, such as
@ -207,8 +192,7 @@ private def buildBRecOnMinorPremise (rlvl : Level) (motives : Array Expr)
let type' ← mkForallFVars arg_args
(← mkPProd arg_type (mkAppN belows[idx]! arg_type_args) )
withLocalDeclD name type' fun arg' => do
let r ← etaPProd arg_args arg'
mkLambdaFVars #[arg'] (← go (prods.push r) args)
mkLambdaFVars #[arg'] (← go (prods.push arg') args)
else
mkLambdaFVars #[arg] (← go prods args)
go #[] minor_args.toList
@ -251,8 +235,6 @@ private def mkBRecOnFromRec (recName : Name) (ind reflexive : Bool) (nParams : N
let refType :=
if ind then
recVal.type.instantiateLevelParams [lvlParam] [0]
else if reflexive then
recVal.type.instantiateLevelParams [lvlParam] [lvl.succ]
else
recVal.type
@ -279,12 +261,9 @@ private def mkBRecOnFromRec (recName : Name) (ind reflexive : Bool) (nParams : N
if ind then
0
else if reflexive then
if let .max 1 ilvl' := ilvl then
mkLevelMax' (.succ lvl) ilvl'
else
mkLevelMax' (.succ lvl) ilvl
mkLevelMax ilvl lvl
else
mkLevelMax' 1 lvl
mkLevelMax 1 lvl
-- One `below` for each motive, with the same motive parameters
let blvls := if ind then lvls else lvl::lvls

View file

@ -8,8 +8,6 @@ Author: Leonardo de Moura
namespace lean {
constexpr char const * g_rec = "rec";
constexpr char const * g_brec_on = "brecOn";
constexpr char const * g_binduction_on = "binductionOn";
constexpr char const * g_cases_on = "casesOn";
constexpr char const * g_no_confusion = "noConfusion";
constexpr char const * g_no_confusion_type = "noConfusionType";

24
tests/lean/run/7638.lean Normal file
View file

@ -0,0 +1,24 @@
inductive Foo : Type
| mk : Foo → Foo
inductive Bar : Type
| mk : (Unit → Bar) → Bar
def Foo.elim {α : Sort u} : Foo → α
| ⟨foo⟩ => elim foo
termination_by structural foo => foo
def Bar.elim {α : Sort u} : Bar → α
| ⟨bar⟩ => elim (bar ())
termination_by structural bar => bar
inductive StressTest : Type 5
| f (x : Type 4 → StressTest)
| g (x : Type 3 → StressTest)
| h (x : Type 4 → StressTest) (y : Type 3 → StressTest)
def StressTest.elim {α : Sort u} : StressTest → α
| f x => elim (x (Type 3))
| g x => elim (x (Type 2))
| h x _y => elim (x (Type 3))
termination_by structural t => t

View file

@ -4,7 +4,7 @@ inductive Foo1 : Sort (max 1 u) where
| intro: (h : Nat → Foo1) → Foo1
/--
info: Foo1.below.{u_1, u} {motive : Foo1.{u} → Type u_1} (t : Foo1.{u}) : Sort (max (u_1 + 1) u)
info: Foo1.below.{u_1, u} {motive : Foo1.{u} → Sort u_1} (t : Foo1.{u}) : Sort (max (max 1 u) u_1)
-/
#guard_msgs in
#check Foo1.below
@ -13,7 +13,7 @@ inductive Foo2 : Sort (max u 1) where
| intro: (h : Nat → Foo2) → Foo2
/--
info: Foo2.below.{u_1, u} {motive : Foo2.{u} → Type u_1} (t : Foo2.{u}) : Sort (max (u_1 + 1) u 1)
info: Foo2.below.{u_1, u} {motive : Foo2.{u} → Sort u_1} (t : Foo2.{u}) : Sort (max (max u 1) u_1)
-/
#guard_msgs in
#check Foo2.below
@ -21,7 +21,7 @@ info: Foo2.below.{u_1, u} {motive : Foo2.{u} → Type u_1} (t : Foo2.{u}) : Sort
inductive Foo3 : Sort (u+1) where
| intro: (h : Nat → Foo3) → Foo3
/-- info: Foo3.below.{u_1, u} {motive : Foo3.{u} → Type u_1} (t : Foo3.{u}) : Type (max u_1 u) -/
/-- info: Foo3.below.{u_1, u} {motive : Foo3.{u} → Sort u_1} (t : Foo3.{u}) : Sort (max (u + 1) u_1) -/
#guard_msgs in
#check Foo3.below