From 5d50433e6ac6a69ec422fa342fcf559dd47c4da4 Mon Sep 17 00:00:00 2001 From: Parth Shastri <31370288+cppio@users.noreply.github.com> Date: Fri, 13 Jun 2025 17:51:09 -0400 Subject: [PATCH] fix: allow arbitrary sorts in structural recursion over reflexive inductive types (#7639) This PR changes the generated `below` and `brecOn` implementations for reflexive inductive types to support motives in `Sort u` rather than `Type u`. Closes #7638 --- .../Elab/PreDefinition/Structural/BRecOn.lean | 9 +-- .../PreDefinition/Structural/FindRecArg.lean | 63 +++++++++---------- src/Lean/Meta/Constructions/BRecOn.lean | 37 +++-------- src/library/suffixes.h | 2 - tests/lean/run/7638.lean | 24 +++++++ tests/lean/run/issue4650.lean | 6 +- 6 files changed, 66 insertions(+), 75 deletions(-) create mode 100644 tests/lean/run/7638.lean diff --git a/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean b/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean index f801d4336e..8201aa4b7d 100644 --- a/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean +++ b/src/Lean/Elab/PreDefinition/Structural/BRecOn.lean @@ -240,14 +240,7 @@ def mkBRecOnConst (recArgInfos : Array RecArgInfo) (positions : Positions) let indGroup := recArgInfos[0]!.indGroupInst let motive := motives[0]! let brecOnUniv ← lambdaTelescope motive fun _ type => getLevel type - let indInfo ← getConstInfoInduct indGroup.all[0]! - let useBInductionOn := indInfo.isReflexive && brecOnUniv == levelZero - let brecOnUniv ← - if indInfo.isReflexive && brecOnUniv != levelZero then - decLevel brecOnUniv - else - pure brecOnUniv - let brecOnCons := fun idx => indGroup.brecOn useBInductionOn brecOnUniv idx + let brecOnCons := fun idx => indGroup.brecOn false brecOnUniv idx -- Pick one as a prototype let brecOnAux := brecOnCons 0 -- Infer the type of the packed motive arguments diff --git a/src/Lean/Elab/PreDefinition/Structural/FindRecArg.lean b/src/Lean/Elab/PreDefinition/Structural/FindRecArg.lean index aa3af8fbcf..927d8e077c 100644 --- a/src/Lean/Elab/PreDefinition/Structural/FindRecArg.lean +++ b/src/Lean/Elab/PreDefinition/Structural/FindRecArg.lean @@ -70,41 +70,38 @@ def getRecArgInfo (fnName : Name) (fixedParamPerm : FixedParamPerm) (xs : Array throwError "it is a let-binding" let xType ← whnfD localDecl.type matchConstInduct xType.getAppFn (fun _ => throwError "its type is not an inductive") fun indInfo us => do - if indInfo.isReflexive && !(← hasConst (mkBInductionOnName indInfo.name)) && !(← isInductivePredicate indInfo.name) then - throwError "its type {indInfo.name} is a reflexive inductive, but {mkBInductionOnName indInfo.name} does not exist and it is not an inductive predicate" + let indArgs : Array Expr := xType.getAppArgs + let indParams : Array Expr := indArgs[0:indInfo.numParams] + let indIndices : Array Expr := indArgs[indInfo.numParams:] + if !indIndices.all Expr.isFVar then + throwError "its type {indInfo.name} is an inductive family and indices are not variables{indentExpr xType}" + else if !indIndices.allDiff then + throwError "its type {indInfo.name} is an inductive family and indices are not pairwise distinct{indentExpr xType}" else - let indArgs : Array Expr := xType.getAppArgs - let indParams : Array Expr := indArgs[0:indInfo.numParams] - let indIndices : Array Expr := indArgs[indInfo.numParams:] - if !indIndices.all Expr.isFVar then - throwError "its type {indInfo.name} is an inductive family and indices are not variables{indentExpr xType}" - else if !indIndices.allDiff then - throwError "its type {indInfo.name} is an inductive family and indices are not pairwise distinct{indentExpr xType}" - else - let ys := fixedParamPerm.pickVarying xs - match (← hasBadIndexDep? ys indIndices) with - | some (index, y) => - throwError "its type {indInfo.name} is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}" + let ys := fixedParamPerm.pickVarying xs + match (← hasBadIndexDep? ys indIndices) with + | some (index, y) => + throwError "its type {indInfo.name} is an inductive family{indentExpr xType}\nand index{indentExpr index}\ndepends on the non index{indentExpr y}" + | none => + match (← hasBadParamDep? ys indParams) with + | some (indParam, y) => + throwError "its type is an inductive datatype{indentExpr xType}\nand the datatype parameter{indentExpr indParam}\ndepends on the function parameter{indentExpr y}\nwhich is not fixed." | none => - match (← hasBadParamDep? ys indParams) with - | some (indParam, y) => - throwError "its type is an inductive datatype{indentExpr xType}\nand the datatype parameter{indentExpr indParam}\ndepends on the function parameter{indentExpr y}\nwhich is not fixed." - | none => - let indAll := indInfo.all.toArray - let .some indIdx := indAll.idxOf? indInfo.name | panic! "{indInfo.name} not in {indInfo.all}" - let indicesPos := indIndices.map fun index => match xs.idxOf? index with | some i => i | none => unreachable! - let indGroupInst := { - IndGroupInfo.ofInductiveVal indInfo with - levels := us - params := indParams } - return { fnName := fnName - fixedParamPerm := fixedParamPerm - recArgPos := i - indicesPos := indicesPos - indGroupInst := indGroupInst - indIdx := indIdx } - else - throwError "the index #{i+1} exceeds {xs.size}, the number of parameters" + let indAll := indInfo.all.toArray + let .some indIdx := indAll.idxOf? indInfo.name | panic! "{indInfo.name} not in {indInfo.all}" + let indicesPos := indIndices.map fun index => match xs.idxOf? index with | some i => i | none => unreachable! + let indGroupInst := { + IndGroupInfo.ofInductiveVal indInfo with + levels := us + params := indParams } + return { fnName := fnName + fixedParamPerm := fixedParamPerm + recArgPos := i + indicesPos := indicesPos + indGroupInst := indGroupInst + indIdx := indIdx } + else + throwError "the index #{i+1} exceeds {xs.size}, the number of parameters" /-- Collects the `RecArgInfos` for one function, and returns a report for why the others were not diff --git a/src/Lean/Meta/Constructions/BRecOn.lean b/src/Lean/Meta/Constructions/BRecOn.lean index 7f5715994e..1ec88e2fb6 100644 --- a/src/Lean/Meta/Constructions/BRecOn.lean +++ b/src/Lean/Meta/Constructions/BRecOn.lean @@ -13,14 +13,6 @@ import Lean.Meta.PProdN namespace Lean open Meta -/-- Transforms `e : xᵢ → (t₁ ×' t₂)` into `(xᵢ → t₁) ×' (xᵢ → t₂) -/ -private def etaPProd (xs : Array Expr) (e : Expr) : MetaM Expr := do - if xs.isEmpty then return e - let r := mkAppN e xs - let r₁ ← mkLambdaFVars xs (← mkPProdFstM r) - let r₂ ← mkLambdaFVars xs (← mkPProdSndM r) - mkPProdMk r₁ r₂ - /-- If `minorType` is the type of a minor premies of a recursor, such as ``` @@ -40,7 +32,6 @@ of type private def buildBelowMinorPremise (rlvl : Level) (motives : Array Expr) (minorType : Expr) : MetaM Expr := forallTelescope minorType fun minor_args _ => do go #[] minor_args.toList where - ibelow := rlvl matches .zero go (prods : Array Expr) : List Expr → MetaM Expr | [] => PProdN.pack rlvl prods | arg::args => do @@ -50,8 +41,7 @@ where let name ← arg.fvarId!.getUserName let type' ← forallTelescope argType fun args _ => mkForallFVars args (.sort rlvl) withLocalDeclD name type' fun arg' => do - let snd ← mkForallFVars arg_args (mkAppN arg' arg_args) - let e' ← mkPProd argType snd + let e' ← mkForallFVars arg_args <| ← mkPProd arg_type (mkAppN arg' arg_args) mkLambdaFVars #[arg'] (← go (prods.push e') args) else mkLambdaFVars #[arg] (← go prods args) @@ -86,8 +76,6 @@ private def mkBelowFromRec (recName : Name) (ibelow reflexive : Bool) (nParams : let refType := if ibelow then recVal.type.instantiateLevelParams [lvlParam] [0] - else if reflexive then - recVal.type.instantiateLevelParams [lvlParam] [lvl.succ] else recVal.type @@ -116,12 +104,9 @@ private def mkBelowFromRec (recName : Name) (ibelow reflexive : Bool) (nParams : if ibelow then 0 else if reflexive then - if let .max 1 ilvl' := ilvl then - mkLevelMax' (.succ lvl) ilvl' - else - mkLevelMax' (.succ lvl) ilvl + mkLevelMax ilvl lvl else - mkLevelMax' 1 lvl + mkLevelMax 1 lvl let mut val := .const recName (rlvl.succ :: lvls) -- add parameters @@ -168,8 +153,8 @@ private def mkBelowOrIBelow (indName : Name) (ibelow : Bool) : MetaM Unit := do let belowName := belowName.appendIndexAfter (i + 1) mkBelowFromRec recName ibelow indVal.isReflexive indVal.numParams belowName -def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true -def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false +def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false +def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true /-- If `minorType` is the type of a minor premies of a recursor, such as @@ -207,8 +192,7 @@ private def buildBRecOnMinorPremise (rlvl : Level) (motives : Array Expr) let type' ← mkForallFVars arg_args (← mkPProd arg_type (mkAppN belows[idx]! arg_type_args) ) withLocalDeclD name type' fun arg' => do - let r ← etaPProd arg_args arg' - mkLambdaFVars #[arg'] (← go (prods.push r) args) + mkLambdaFVars #[arg'] (← go (prods.push arg') args) else mkLambdaFVars #[arg] (← go prods args) go #[] minor_args.toList @@ -251,8 +235,6 @@ private def mkBRecOnFromRec (recName : Name) (ind reflexive : Bool) (nParams : N let refType := if ind then recVal.type.instantiateLevelParams [lvlParam] [0] - else if reflexive then - recVal.type.instantiateLevelParams [lvlParam] [lvl.succ] else recVal.type @@ -279,12 +261,9 @@ private def mkBRecOnFromRec (recName : Name) (ind reflexive : Bool) (nParams : N if ind then 0 else if reflexive then - if let .max 1 ilvl' := ilvl then - mkLevelMax' (.succ lvl) ilvl' - else - mkLevelMax' (.succ lvl) ilvl + mkLevelMax ilvl lvl else - mkLevelMax' 1 lvl + mkLevelMax 1 lvl -- One `below` for each motive, with the same motive parameters let blvls := if ind then lvls else lvl::lvls diff --git a/src/library/suffixes.h b/src/library/suffixes.h index 3d39166092..9fdd6e8f8d 100644 --- a/src/library/suffixes.h +++ b/src/library/suffixes.h @@ -8,8 +8,6 @@ Author: Leonardo de Moura namespace lean { constexpr char const * g_rec = "rec"; -constexpr char const * g_brec_on = "brecOn"; -constexpr char const * g_binduction_on = "binductionOn"; constexpr char const * g_cases_on = "casesOn"; constexpr char const * g_no_confusion = "noConfusion"; constexpr char const * g_no_confusion_type = "noConfusionType"; diff --git a/tests/lean/run/7638.lean b/tests/lean/run/7638.lean new file mode 100644 index 0000000000..293ad19248 --- /dev/null +++ b/tests/lean/run/7638.lean @@ -0,0 +1,24 @@ +inductive Foo : Type + | mk : Foo → Foo + +inductive Bar : Type + | mk : (Unit → Bar) → Bar + +def Foo.elim {α : Sort u} : Foo → α + | ⟨foo⟩ => elim foo + termination_by structural foo => foo + +def Bar.elim {α : Sort u} : Bar → α + | ⟨bar⟩ => elim (bar ()) + termination_by structural bar => bar + +inductive StressTest : Type 5 + | f (x : Type 4 → StressTest) + | g (x : Type 3 → StressTest) + | h (x : Type 4 → StressTest) (y : Type 3 → StressTest) + +def StressTest.elim {α : Sort u} : StressTest → α + | f x => elim (x (Type 3)) + | g x => elim (x (Type 2)) + | h x _y => elim (x (Type 3)) + termination_by structural t => t diff --git a/tests/lean/run/issue4650.lean b/tests/lean/run/issue4650.lean index b6659e361a..afa84e2e3c 100644 --- a/tests/lean/run/issue4650.lean +++ b/tests/lean/run/issue4650.lean @@ -4,7 +4,7 @@ inductive Foo1 : Sort (max 1 u) where | intro: (h : Nat → Foo1) → Foo1 /-- -info: Foo1.below.{u_1, u} {motive : Foo1.{u} → Type u_1} (t : Foo1.{u}) : Sort (max (u_1 + 1) u) +info: Foo1.below.{u_1, u} {motive : Foo1.{u} → Sort u_1} (t : Foo1.{u}) : Sort (max (max 1 u) u_1) -/ #guard_msgs in #check Foo1.below @@ -13,7 +13,7 @@ inductive Foo2 : Sort (max u 1) where | intro: (h : Nat → Foo2) → Foo2 /-- -info: Foo2.below.{u_1, u} {motive : Foo2.{u} → Type u_1} (t : Foo2.{u}) : Sort (max (u_1 + 1) u 1) +info: Foo2.below.{u_1, u} {motive : Foo2.{u} → Sort u_1} (t : Foo2.{u}) : Sort (max (max u 1) u_1) -/ #guard_msgs in #check Foo2.below @@ -21,7 +21,7 @@ info: Foo2.below.{u_1, u} {motive : Foo2.{u} → Type u_1} (t : Foo2.{u}) : Sort inductive Foo3 : Sort (u+1) where | intro: (h : Nat → Foo3) → Foo3 -/-- info: Foo3.below.{u_1, u} {motive : Foo3.{u} → Type u_1} (t : Foo3.{u}) : Type (max u_1 u) -/ +/-- info: Foo3.below.{u_1, u} {motive : Foo3.{u} → Sort u_1} (t : Foo3.{u}) : Sort (max (u + 1) u_1) -/ #guard_msgs in #check Foo3.below