chore: cleanup
This commit is contained in:
parent
ca9b494e4d
commit
743f6dd3a2
2 changed files with 78 additions and 65 deletions
|
|
@ -145,8 +145,8 @@ private def checkHeader (r : ElabHeaderResult) (numParams : Nat) (firstType? : O
|
|||
match firstType? with
|
||||
| none => pure type
|
||||
| some firstType =>
|
||||
withRef r.view.ref $ checkParamsAndResultType type firstType numParams
|
||||
pure firstType
|
||||
withRef r.view.ref <| checkParamsAndResultType type firstType numParams
|
||||
return firstType
|
||||
|
||||
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
|
||||
private partial def checkHeaders (rs : Array ElabHeaderResult) (numParams : Nat) (i : Nat) (firstType? : Option Expr) : TermElabM Unit := do
|
||||
|
|
@ -173,7 +173,7 @@ private partial def withInductiveLocalDecls {α} (rs : Array ElabHeaderResult) (
|
|||
pure (r.view.shortDeclName, type)
|
||||
let r0 := rs[0]
|
||||
let params := r0.params
|
||||
withLCtx r0.lctx r0.localInsts $ withRef r0.view.ref do
|
||||
withLCtx r0.lctx r0.localInsts <| withRef r0.view.ref do
|
||||
let rec loop (i : Nat) (indFVars : Array Expr) := do
|
||||
if h : i < namesAndTypes.size then
|
||||
let (id, type) := namesAndTypes.get ⟨i, h⟩
|
||||
|
|
@ -308,58 +308,11 @@ def shouldInferResultUniverse (u : Level) : TermElabM Bool := do
|
|||
match u.getLevelOffset with
|
||||
| Level.mvar mvarId _ => do
|
||||
Term.assignLevelMVar mvarId tmpIndParam
|
||||
pure true
|
||||
return true
|
||||
| _ =>
|
||||
throwError "cannot infer resulting universe level of inductive datatype, given level contains metavariables {mkSort u}, provide universe explicitly"
|
||||
else
|
||||
pure false
|
||||
|
||||
/-
|
||||
Auxiliary function for `updateResultingUniverse`
|
||||
`accLevelAtCtor u r rOffset us` add `u` components to `us` if they are not already there and it is different from the resulting universe level `r+rOffset`.
|
||||
If `u` is a `max`, then its components are recursively processed.
|
||||
If `u` is a `succ` and `rOffset > 0`, we process the `u`s child using `rOffset-1`.
|
||||
|
||||
This method is used to infer the resulting universe level of an inductive datatype. -/
|
||||
def accLevelAtCtor : Level → Level → Nat → Array Level → TermElabM (Array Level)
|
||||
| Level.max u v _, r, rOffset, us => do let us ← accLevelAtCtor u r rOffset us; accLevelAtCtor v r rOffset us
|
||||
| Level.imax u v _, r, rOffset, us => do let us ← accLevelAtCtor u r rOffset us; accLevelAtCtor v r rOffset us
|
||||
| Level.zero _, _, _, us => pure us
|
||||
| Level.succ u _, r, rOffset+1, us => accLevelAtCtor u r rOffset us
|
||||
| u, r, rOffset, us =>
|
||||
if rOffset == 0 && u == r then pure us
|
||||
else if r.occurs u then throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly"
|
||||
else if rOffset > 0 then throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly"
|
||||
else if us.contains u then pure us
|
||||
else pure (us.push u)
|
||||
|
||||
/- Auxiliary function for `updateResultingUniverse` -/
|
||||
private partial def collectUniversesFromCtorTypeAux (r : Level) (rOffset : Nat) : Nat → Expr → Array Level → TermElabM (Array Level)
|
||||
| 0, Expr.forallE n d b c, us => do
|
||||
let u ← getLevel d
|
||||
let u ← instantiateLevelMVars u
|
||||
let us ← accLevelAtCtor u r rOffset us
|
||||
withLocalDecl n c.binderInfo d fun x =>
|
||||
let e := b.instantiate1 x
|
||||
collectUniversesFromCtorTypeAux r rOffset 0 e us
|
||||
| i+1, Expr.forallE n d b c, us => do
|
||||
withLocalDecl n c.binderInfo d fun x =>
|
||||
let e := b.instantiate1 x
|
||||
collectUniversesFromCtorTypeAux r rOffset i e us
|
||||
| _, _, us => pure us
|
||||
|
||||
/- Auxiliary function for `updateResultingUniverse` -/
|
||||
private partial def collectUniversesFromCtorType
|
||||
(r : Level) (rOffset : Nat) (ctorType : Expr) (numParams : Nat) (us : Array Level) : TermElabM (Array Level) :=
|
||||
collectUniversesFromCtorTypeAux r rOffset numParams ctorType us
|
||||
|
||||
/- Auxiliary function for `updateResultingUniverse` -/
|
||||
private partial def collectUniverses (r : Level) (rOffset : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (Array Level) := do
|
||||
let mut us := #[]
|
||||
for indType in indTypes do
|
||||
for ctor in indType.ctors do
|
||||
us ← collectUniversesFromCtorType r rOffset ctor.type numParams us
|
||||
return us
|
||||
return false
|
||||
|
||||
def mkResultUniverse (us : Array Level) (rOffset : Nat) : Level :=
|
||||
if us.isEmpty && rOffset == 0 then
|
||||
|
|
@ -367,10 +320,62 @@ def mkResultUniverse (us : Array Level) (rOffset : Nat) : Level :=
|
|||
else
|
||||
let r := Level.mkNaryMax us.toList
|
||||
if rOffset == 0 && !r.isZero && !r.isNeverZero then
|
||||
(mkLevelMax r levelOne).normalize
|
||||
mkLevelMax r levelOne |>.normalize
|
||||
else
|
||||
r.normalize
|
||||
|
||||
/--
|
||||
Auxiliary function for `updateResultingUniverse`
|
||||
`accLevelAtCtor u r rOffset` add `u` to state if it is not already there and
|
||||
it is different from the resulting universe level `r+rOffset`.
|
||||
|
||||
If `u` is a `max`, then its components are recursively processed.
|
||||
If `u` is a `succ` and `rOffset > 0`, we process the `u`s child using `rOffset-1`.
|
||||
|
||||
This method is used to infer the resulting universe level of an inductive datatype.
|
||||
-/
|
||||
def accLevelAtCtor (u : Level) (r : Level) (rOffset : Nat) : StateRefT (Array Level) TermElabM Unit := do
|
||||
match u, rOffset with
|
||||
| Level.max u v _, rOffset => accLevelAtCtor u r rOffset; accLevelAtCtor v r rOffset
|
||||
| Level.imax u v _, rOffset => accLevelAtCtor u r rOffset; accLevelAtCtor v r rOffset
|
||||
| Level.zero _, _ => return ()
|
||||
| Level.succ u _, rOffset+1 => accLevelAtCtor u r rOffset
|
||||
| u, rOffset =>
|
||||
if rOffset == 0 && u == r then
|
||||
return ()
|
||||
else if r.occurs u then
|
||||
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly"
|
||||
else if rOffset > 0 then
|
||||
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly"
|
||||
else if (← get).contains u then
|
||||
return ()
|
||||
else
|
||||
modify fun us => us.push u
|
||||
|
||||
/-- Auxiliary function for `updateResultingUniverse` -/
|
||||
private partial def collectUniverses (r : Level) (rOffset : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (Array Level) := do
|
||||
let (_, us) ← go |>.run #[]
|
||||
return us
|
||||
where
|
||||
go : StateRefT (Array Level) TermElabM Unit :=
|
||||
indTypes.forM fun indType => indType.ctors.forM fun ctor =>
|
||||
collectUniversesFromCtorType numParams ctor.type
|
||||
|
||||
collectUniversesFromCtorType (i : Nat) (type : Expr) : StateRefT (Array Level) TermElabM Unit := do
|
||||
match i, type with
|
||||
| 0, Expr.forallE n d b c =>
|
||||
let u ← getLevel d
|
||||
let u ← instantiateLevelMVars u
|
||||
let us ← accLevelAtCtor u r rOffset
|
||||
withLocalDecl n c.binderInfo d fun x =>
|
||||
let e := b.instantiate1 x
|
||||
collectUniversesFromCtorType 0 e
|
||||
| i+1, Expr.forallE n d b c =>
|
||||
withLocalDecl n c.binderInfo d fun x =>
|
||||
let e := b.instantiate1 x
|
||||
collectUniversesFromCtorType i e
|
||||
| _, _ => return ()
|
||||
|
||||
private def updateResultingUniverse (numParams : Nat) (indTypes : List InductiveType) : TermElabM (List InductiveType) := do
|
||||
let r ← getResultingUniverse indTypes
|
||||
let rOffset : Nat := r.getOffset
|
||||
|
|
@ -413,15 +418,15 @@ private def removeUnused (vars : Array Expr) (indTypes : List InductiveType) : T
|
|||
|
||||
private def withUsed {α} (vars : Array Expr) (indTypes : List InductiveType) (k : Array Expr → TermElabM α) : TermElabM α := do
|
||||
let (lctx, localInsts, vars) ← removeUnused vars indTypes
|
||||
withLCtx lctx localInsts $ k vars
|
||||
withLCtx lctx localInsts <| k vars
|
||||
|
||||
private def updateParams (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (List InductiveType) :=
|
||||
indTypes.mapM fun indType => do
|
||||
let type ← mkForallFVars vars indType.type
|
||||
let ctors ← indType.ctors.mapM fun ctor => do
|
||||
let ctorType ← mkForallFVars vars ctor.type
|
||||
pure { ctor with type := ctorType }
|
||||
pure { indType with type := type, ctors := ctors }
|
||||
return { ctor with type := ctorType }
|
||||
return { indType with type := type, ctors := ctors }
|
||||
|
||||
private def collectLevelParamsInInductive (indTypes : List InductiveType) : Array Name := Id.run <| do
|
||||
let mut usedParams : CollectLevelParams.State := {}
|
||||
|
|
@ -455,8 +460,8 @@ private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars :
|
|||
| none => none
|
||||
| some c => mkAppN c (params.extract 0 numVars)
|
||||
mkForallFVars params type
|
||||
pure { ctor with type := type }
|
||||
pure { indType with ctors := ctors }
|
||||
return { ctor with type := type }
|
||||
return { indType with ctors := ctors }
|
||||
|
||||
abbrev Ctor2InferMod := Std.HashMap Name Bool
|
||||
|
||||
|
|
@ -589,7 +594,6 @@ private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) :
|
|||
indTypesArray := indTypesArray.push { name := r.view.declName, type := type, ctors := ctors : InductiveType }
|
||||
Term.synthesizeSyntheticMVarsNoPostponing
|
||||
let (numExplicitParams, indTypes) ← fixedIndicesToParams params.size indTypesArray indFVars
|
||||
trace[Meta.debug] "numExplicitParams: {numExplicitParams}"
|
||||
let u ← getResultingUniverse indTypes
|
||||
let inferLevel ← shouldInferResultUniverse u
|
||||
withUsed vars indTypes fun vars => do
|
||||
|
|
@ -597,7 +601,12 @@ private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) :
|
|||
let numParams := numVars + numExplicitParams
|
||||
let indTypes ← updateParams vars indTypes
|
||||
let indTypes ← levelMVarToParam indTypes
|
||||
let indTypes ← if inferLevel then updateResultingUniverse numParams indTypes else checkResultingUniverses indTypes; pure indTypes
|
||||
let indTypes ←
|
||||
if inferLevel then
|
||||
updateResultingUniverse numParams indTypes
|
||||
else
|
||||
checkResultingUniverses indTypes
|
||||
pure indTypes
|
||||
let usedLevelNames := collectLevelParamsInInductive indTypes
|
||||
match sortDeclLevelParams scopeLevelNames allUserLevelNames usedLevelNames with
|
||||
| Except.error msg => throwError msg
|
||||
|
|
|
|||
|
|
@ -623,11 +623,15 @@ private def levelMVarToParam (scopeVars : Array Expr) (params : Array Expr) (fie
|
|||
(levelMVarToParamAux scopeVars params fieldInfos).run' 1
|
||||
|
||||
private partial def collectUniversesFromFields (r : Level) (rOffset : Nat) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Level) := do
|
||||
fieldInfos.foldlM (init := #[]) fun (us : Array Level) (info : StructFieldInfo) => do
|
||||
let type ← inferType info.fvar
|
||||
let u ← getLevel type
|
||||
let u ← instantiateLevelMVars u
|
||||
accLevelAtCtor u r rOffset us
|
||||
let (_, us) ← go |>.run #[]
|
||||
return us
|
||||
where
|
||||
go : StateRefT (Array Level) TermElabM Unit :=
|
||||
for info in fieldInfos do
|
||||
let type ← inferType info.fvar
|
||||
let u ← getLevel type
|
||||
let u ← instantiateLevelMVars u
|
||||
accLevelAtCtor u r rOffset
|
||||
|
||||
private def updateResultingUniverse (fieldInfos : Array StructFieldInfo) (type : Expr) : TermElabM Expr := do
|
||||
let r ← getResultUniverse type
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue