diff --git a/src/Lean/Elab/Inductive.lean b/src/Lean/Elab/Inductive.lean index 80de954edc..bd18f7fc27 100644 --- a/src/Lean/Elab/Inductive.lean +++ b/src/Lean/Elab/Inductive.lean @@ -145,8 +145,8 @@ private def checkHeader (r : ElabHeaderResult) (numParams : Nat) (firstType? : O match firstType? with | none => pure type | some firstType => - withRef r.view.ref $ checkParamsAndResultType type firstType numParams - pure firstType + withRef r.view.ref <| checkParamsAndResultType type firstType numParams + return firstType -- Auxiliary function for checking whether the types in mutually inductive declaration are compatible. private partial def checkHeaders (rs : Array ElabHeaderResult) (numParams : Nat) (i : Nat) (firstType? : Option Expr) : TermElabM Unit := do @@ -173,7 +173,7 @@ private partial def withInductiveLocalDecls {α} (rs : Array ElabHeaderResult) ( pure (r.view.shortDeclName, type) let r0 := rs[0] let params := r0.params - withLCtx r0.lctx r0.localInsts $ withRef r0.view.ref do + withLCtx r0.lctx r0.localInsts <| withRef r0.view.ref do let rec loop (i : Nat) (indFVars : Array Expr) := do if h : i < namesAndTypes.size then let (id, type) := namesAndTypes.get ⟨i, h⟩ @@ -308,58 +308,11 @@ def shouldInferResultUniverse (u : Level) : TermElabM Bool := do match u.getLevelOffset with | Level.mvar mvarId _ => do Term.assignLevelMVar mvarId tmpIndParam - pure true + return true | _ => throwError "cannot infer resulting universe level of inductive datatype, given level contains metavariables {mkSort u}, provide universe explicitly" else - pure false - -/- - Auxiliary function for `updateResultingUniverse` - `accLevelAtCtor u r rOffset us` add `u` components to `us` if they are not already there and it is different from the resulting universe level `r+rOffset`. - If `u` is a `max`, then its components are recursively processed. - If `u` is a `succ` and `rOffset > 0`, we process the `u`s child using `rOffset-1`. - - This method is used to infer the resulting universe level of an inductive datatype. -/ -def accLevelAtCtor : Level → Level → Nat → Array Level → TermElabM (Array Level) - | Level.max u v _, r, rOffset, us => do let us ← accLevelAtCtor u r rOffset us; accLevelAtCtor v r rOffset us - | Level.imax u v _, r, rOffset, us => do let us ← accLevelAtCtor u r rOffset us; accLevelAtCtor v r rOffset us - | Level.zero _, _, _, us => pure us - | Level.succ u _, r, rOffset+1, us => accLevelAtCtor u r rOffset us - | u, r, rOffset, us => - if rOffset == 0 && u == r then pure us - else if r.occurs u then throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly" - else if rOffset > 0 then throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly" - else if us.contains u then pure us - else pure (us.push u) - -/- Auxiliary function for `updateResultingUniverse` -/ -private partial def collectUniversesFromCtorTypeAux (r : Level) (rOffset : Nat) : Nat → Expr → Array Level → TermElabM (Array Level) - | 0, Expr.forallE n d b c, us => do - let u ← getLevel d - let u ← instantiateLevelMVars u - let us ← accLevelAtCtor u r rOffset us - withLocalDecl n c.binderInfo d fun x => - let e := b.instantiate1 x - collectUniversesFromCtorTypeAux r rOffset 0 e us - | i+1, Expr.forallE n d b c, us => do - withLocalDecl n c.binderInfo d fun x => - let e := b.instantiate1 x - collectUniversesFromCtorTypeAux r rOffset i e us - | _, _, us => pure us - -/- Auxiliary function for `updateResultingUniverse` -/ -private partial def collectUniversesFromCtorType - (r : Level) (rOffset : Nat) (ctorType : Expr) (numParams : Nat) (us : Array Level) : TermElabM (Array Level) := - collectUniversesFromCtorTypeAux r rOffset numParams ctorType us - -/- Auxiliary function for `updateResultingUniverse` -/ -private partial def collectUniverses (r : Level) (rOffset : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (Array Level) := do - let mut us := #[] - for indType in indTypes do - for ctor in indType.ctors do - us ← collectUniversesFromCtorType r rOffset ctor.type numParams us - return us + return false def mkResultUniverse (us : Array Level) (rOffset : Nat) : Level := if us.isEmpty && rOffset == 0 then @@ -367,10 +320,62 @@ def mkResultUniverse (us : Array Level) (rOffset : Nat) : Level := else let r := Level.mkNaryMax us.toList if rOffset == 0 && !r.isZero && !r.isNeverZero then - (mkLevelMax r levelOne).normalize + mkLevelMax r levelOne |>.normalize else r.normalize +/-- + Auxiliary function for `updateResultingUniverse` + `accLevelAtCtor u r rOffset` add `u` to state if it is not already there and + it is different from the resulting universe level `r+rOffset`. + + If `u` is a `max`, then its components are recursively processed. + If `u` is a `succ` and `rOffset > 0`, we process the `u`s child using `rOffset-1`. + + This method is used to infer the resulting universe level of an inductive datatype. +-/ +def accLevelAtCtor (u : Level) (r : Level) (rOffset : Nat) : StateRefT (Array Level) TermElabM Unit := do + match u, rOffset with + | Level.max u v _, rOffset => accLevelAtCtor u r rOffset; accLevelAtCtor v r rOffset + | Level.imax u v _, rOffset => accLevelAtCtor u r rOffset; accLevelAtCtor v r rOffset + | Level.zero _, _ => return () + | Level.succ u _, rOffset+1 => accLevelAtCtor u r rOffset + | u, rOffset => + if rOffset == 0 && u == r then + return () + else if r.occurs u then + throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly" + else if rOffset > 0 then + throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly" + else if (← get).contains u then + return () + else + modify fun us => us.push u + +/-- Auxiliary function for `updateResultingUniverse` -/ +private partial def collectUniverses (r : Level) (rOffset : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (Array Level) := do + let (_, us) ← go |>.run #[] + return us +where + go : StateRefT (Array Level) TermElabM Unit := + indTypes.forM fun indType => indType.ctors.forM fun ctor => + collectUniversesFromCtorType numParams ctor.type + + collectUniversesFromCtorType (i : Nat) (type : Expr) : StateRefT (Array Level) TermElabM Unit := do + match i, type with + | 0, Expr.forallE n d b c => + let u ← getLevel d + let u ← instantiateLevelMVars u + let us ← accLevelAtCtor u r rOffset + withLocalDecl n c.binderInfo d fun x => + let e := b.instantiate1 x + collectUniversesFromCtorType 0 e + | i+1, Expr.forallE n d b c => + withLocalDecl n c.binderInfo d fun x => + let e := b.instantiate1 x + collectUniversesFromCtorType i e + | _, _ => return () + private def updateResultingUniverse (numParams : Nat) (indTypes : List InductiveType) : TermElabM (List InductiveType) := do let r ← getResultingUniverse indTypes let rOffset : Nat := r.getOffset @@ -413,15 +418,15 @@ private def removeUnused (vars : Array Expr) (indTypes : List InductiveType) : T private def withUsed {α} (vars : Array Expr) (indTypes : List InductiveType) (k : Array Expr → TermElabM α) : TermElabM α := do let (lctx, localInsts, vars) ← removeUnused vars indTypes - withLCtx lctx localInsts $ k vars + withLCtx lctx localInsts <| k vars private def updateParams (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (List InductiveType) := indTypes.mapM fun indType => do let type ← mkForallFVars vars indType.type let ctors ← indType.ctors.mapM fun ctor => do let ctorType ← mkForallFVars vars ctor.type - pure { ctor with type := ctorType } - pure { indType with type := type, ctors := ctors } + return { ctor with type := ctorType } + return { indType with type := type, ctors := ctors } private def collectLevelParamsInInductive (indTypes : List InductiveType) : Array Name := Id.run <| do let mut usedParams : CollectLevelParams.State := {} @@ -455,8 +460,8 @@ private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars : | none => none | some c => mkAppN c (params.extract 0 numVars) mkForallFVars params type - pure { ctor with type := type } - pure { indType with ctors := ctors } + return { ctor with type := type } + return { indType with ctors := ctors } abbrev Ctor2InferMod := Std.HashMap Name Bool @@ -589,7 +594,6 @@ private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : indTypesArray := indTypesArray.push { name := r.view.declName, type := type, ctors := ctors : InductiveType } Term.synthesizeSyntheticMVarsNoPostponing let (numExplicitParams, indTypes) ← fixedIndicesToParams params.size indTypesArray indFVars - trace[Meta.debug] "numExplicitParams: {numExplicitParams}" let u ← getResultingUniverse indTypes let inferLevel ← shouldInferResultUniverse u withUsed vars indTypes fun vars => do @@ -597,7 +601,12 @@ private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : let numParams := numVars + numExplicitParams let indTypes ← updateParams vars indTypes let indTypes ← levelMVarToParam indTypes - let indTypes ← if inferLevel then updateResultingUniverse numParams indTypes else checkResultingUniverses indTypes; pure indTypes + let indTypes ← + if inferLevel then + updateResultingUniverse numParams indTypes + else + checkResultingUniverses indTypes + pure indTypes let usedLevelNames := collectLevelParamsInInductive indTypes match sortDeclLevelParams scopeLevelNames allUserLevelNames usedLevelNames with | Except.error msg => throwError msg diff --git a/src/Lean/Elab/Structure.lean b/src/Lean/Elab/Structure.lean index 0ef256961e..cec52a5efe 100644 --- a/src/Lean/Elab/Structure.lean +++ b/src/Lean/Elab/Structure.lean @@ -623,11 +623,15 @@ private def levelMVarToParam (scopeVars : Array Expr) (params : Array Expr) (fie (levelMVarToParamAux scopeVars params fieldInfos).run' 1 private partial def collectUniversesFromFields (r : Level) (rOffset : Nat) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Level) := do - fieldInfos.foldlM (init := #[]) fun (us : Array Level) (info : StructFieldInfo) => do - let type ← inferType info.fvar - let u ← getLevel type - let u ← instantiateLevelMVars u - accLevelAtCtor u r rOffset us + let (_, us) ← go |>.run #[] + return us +where + go : StateRefT (Array Level) TermElabM Unit := + for info in fieldInfos do + let type ← inferType info.fvar + let u ← getLevel type + let u ← instantiateLevelMVars u + accLevelAtCtor u r rOffset private def updateResultingUniverse (fieldInfos : Array StructFieldInfo) (type : Expr) : TermElabM Expr := do let r ← getResultUniverse type