chore: cleanup

This commit is contained in:
Leonardo de Moura 2022-04-03 06:56:27 -07:00
parent ca9b494e4d
commit 743f6dd3a2
2 changed files with 78 additions and 65 deletions

View file

@ -145,8 +145,8 @@ private def checkHeader (r : ElabHeaderResult) (numParams : Nat) (firstType? : O
match firstType? with
| none => pure type
| some firstType =>
withRef r.view.ref $ checkParamsAndResultType type firstType numParams
pure firstType
withRef r.view.ref <| checkParamsAndResultType type firstType numParams
return firstType
-- Auxiliary function for checking whether the types in mutually inductive declaration are compatible.
private partial def checkHeaders (rs : Array ElabHeaderResult) (numParams : Nat) (i : Nat) (firstType? : Option Expr) : TermElabM Unit := do
@ -173,7 +173,7 @@ private partial def withInductiveLocalDecls {α} (rs : Array ElabHeaderResult) (
pure (r.view.shortDeclName, type)
let r0 := rs[0]
let params := r0.params
withLCtx r0.lctx r0.localInsts $ withRef r0.view.ref do
withLCtx r0.lctx r0.localInsts <| withRef r0.view.ref do
let rec loop (i : Nat) (indFVars : Array Expr) := do
if h : i < namesAndTypes.size then
let (id, type) := namesAndTypes.get ⟨i, h⟩
@ -308,58 +308,11 @@ def shouldInferResultUniverse (u : Level) : TermElabM Bool := do
match u.getLevelOffset with
| Level.mvar mvarId _ => do
Term.assignLevelMVar mvarId tmpIndParam
pure true
return true
| _ =>
throwError "cannot infer resulting universe level of inductive datatype, given level contains metavariables {mkSort u}, provide universe explicitly"
else
pure false
/-
Auxiliary function for `updateResultingUniverse`
`accLevelAtCtor u r rOffset us` add `u` components to `us` if they are not already there and it is different from the resulting universe level `r+rOffset`.
If `u` is a `max`, then its components are recursively processed.
If `u` is a `succ` and `rOffset > 0`, we process the `u`s child using `rOffset-1`.
This method is used to infer the resulting universe level of an inductive datatype. -/
def accLevelAtCtor : Level → Level → Nat → Array Level → TermElabM (Array Level)
| Level.max u v _, r, rOffset, us => do let us ← accLevelAtCtor u r rOffset us; accLevelAtCtor v r rOffset us
| Level.imax u v _, r, rOffset, us => do let us ← accLevelAtCtor u r rOffset us; accLevelAtCtor v r rOffset us
| Level.zero _, _, _, us => pure us
| Level.succ u _, r, rOffset+1, us => accLevelAtCtor u r rOffset us
| u, r, rOffset, us =>
if rOffset == 0 && u == r then pure us
else if r.occurs u then throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly"
else if rOffset > 0 then throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly"
else if us.contains u then pure us
else pure (us.push u)
/- Auxiliary function for `updateResultingUniverse` -/
private partial def collectUniversesFromCtorTypeAux (r : Level) (rOffset : Nat) : Nat → Expr → Array Level → TermElabM (Array Level)
| 0, Expr.forallE n d b c, us => do
let u ← getLevel d
let u ← instantiateLevelMVars u
let us ← accLevelAtCtor u r rOffset us
withLocalDecl n c.binderInfo d fun x =>
let e := b.instantiate1 x
collectUniversesFromCtorTypeAux r rOffset 0 e us
| i+1, Expr.forallE n d b c, us => do
withLocalDecl n c.binderInfo d fun x =>
let e := b.instantiate1 x
collectUniversesFromCtorTypeAux r rOffset i e us
| _, _, us => pure us
/- Auxiliary function for `updateResultingUniverse` -/
private partial def collectUniversesFromCtorType
(r : Level) (rOffset : Nat) (ctorType : Expr) (numParams : Nat) (us : Array Level) : TermElabM (Array Level) :=
collectUniversesFromCtorTypeAux r rOffset numParams ctorType us
/- Auxiliary function for `updateResultingUniverse` -/
private partial def collectUniverses (r : Level) (rOffset : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (Array Level) := do
let mut us := #[]
for indType in indTypes do
for ctor in indType.ctors do
us ← collectUniversesFromCtorType r rOffset ctor.type numParams us
return us
return false
def mkResultUniverse (us : Array Level) (rOffset : Nat) : Level :=
if us.isEmpty && rOffset == 0 then
@ -367,10 +320,62 @@ def mkResultUniverse (us : Array Level) (rOffset : Nat) : Level :=
else
let r := Level.mkNaryMax us.toList
if rOffset == 0 && !r.isZero && !r.isNeverZero then
(mkLevelMax r levelOne).normalize
mkLevelMax r levelOne |>.normalize
else
r.normalize
/--
Auxiliary function for `updateResultingUniverse`
`accLevelAtCtor u r rOffset` add `u` to state if it is not already there and
it is different from the resulting universe level `r+rOffset`.
If `u` is a `max`, then its components are recursively processed.
If `u` is a `succ` and `rOffset > 0`, we process the `u`s child using `rOffset-1`.
This method is used to infer the resulting universe level of an inductive datatype.
-/
def accLevelAtCtor (u : Level) (r : Level) (rOffset : Nat) : StateRefT (Array Level) TermElabM Unit := do
match u, rOffset with
| Level.max u v _, rOffset => accLevelAtCtor u r rOffset; accLevelAtCtor v r rOffset
| Level.imax u v _, rOffset => accLevelAtCtor u r rOffset; accLevelAtCtor v r rOffset
| Level.zero _, _ => return ()
| Level.succ u _, rOffset+1 => accLevelAtCtor u r rOffset
| u, rOffset =>
if rOffset == 0 && u == r then
return ()
else if r.occurs u then
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly"
else if rOffset > 0 then
throwError "failed to compute resulting universe level of inductive datatype, provide universe explicitly"
else if (← get).contains u then
return ()
else
modify fun us => us.push u
/-- Auxiliary function for `updateResultingUniverse` -/
private partial def collectUniverses (r : Level) (rOffset : Nat) (numParams : Nat) (indTypes : List InductiveType) : TermElabM (Array Level) := do
let (_, us) ← go |>.run #[]
return us
where
go : StateRefT (Array Level) TermElabM Unit :=
indTypes.forM fun indType => indType.ctors.forM fun ctor =>
collectUniversesFromCtorType numParams ctor.type
collectUniversesFromCtorType (i : Nat) (type : Expr) : StateRefT (Array Level) TermElabM Unit := do
match i, type with
| 0, Expr.forallE n d b c =>
let u ← getLevel d
let u ← instantiateLevelMVars u
let us ← accLevelAtCtor u r rOffset
withLocalDecl n c.binderInfo d fun x =>
let e := b.instantiate1 x
collectUniversesFromCtorType 0 e
| i+1, Expr.forallE n d b c =>
withLocalDecl n c.binderInfo d fun x =>
let e := b.instantiate1 x
collectUniversesFromCtorType i e
| _, _ => return ()
private def updateResultingUniverse (numParams : Nat) (indTypes : List InductiveType) : TermElabM (List InductiveType) := do
let r ← getResultingUniverse indTypes
let rOffset : Nat := r.getOffset
@ -413,15 +418,15 @@ private def removeUnused (vars : Array Expr) (indTypes : List InductiveType) : T
private def withUsed {α} (vars : Array Expr) (indTypes : List InductiveType) (k : Array Expr → TermElabM α) : TermElabM α := do
let (lctx, localInsts, vars) ← removeUnused vars indTypes
withLCtx lctx localInsts $ k vars
withLCtx lctx localInsts <| k vars
private def updateParams (vars : Array Expr) (indTypes : List InductiveType) : TermElabM (List InductiveType) :=
indTypes.mapM fun indType => do
let type ← mkForallFVars vars indType.type
let ctors ← indType.ctors.mapM fun ctor => do
let ctorType ← mkForallFVars vars ctor.type
pure { ctor with type := ctorType }
pure { indType with type := type, ctors := ctors }
return { ctor with type := ctorType }
return { indType with type := type, ctors := ctors }
private def collectLevelParamsInInductive (indTypes : List InductiveType) : Array Name := Id.run <| do
let mut usedParams : CollectLevelParams.State := {}
@ -455,8 +460,8 @@ private def replaceIndFVarsWithConsts (views : Array InductiveView) (indFVars :
| none => none
| some c => mkAppN c (params.extract 0 numVars)
mkForallFVars params type
pure { ctor with type := type }
pure { indType with ctors := ctors }
return { ctor with type := type }
return { indType with ctors := ctors }
abbrev Ctor2InferMod := Std.HashMap Name Bool
@ -589,7 +594,6 @@ private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) :
indTypesArray := indTypesArray.push { name := r.view.declName, type := type, ctors := ctors : InductiveType }
Term.synthesizeSyntheticMVarsNoPostponing
let (numExplicitParams, indTypes) ← fixedIndicesToParams params.size indTypesArray indFVars
trace[Meta.debug] "numExplicitParams: {numExplicitParams}"
let u ← getResultingUniverse indTypes
let inferLevel ← shouldInferResultUniverse u
withUsed vars indTypes fun vars => do
@ -597,7 +601,12 @@ private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) :
let numParams := numVars + numExplicitParams
let indTypes ← updateParams vars indTypes
let indTypes ← levelMVarToParam indTypes
let indTypes ← if inferLevel then updateResultingUniverse numParams indTypes else checkResultingUniverses indTypes; pure indTypes
let indTypes ←
if inferLevel then
updateResultingUniverse numParams indTypes
else
checkResultingUniverses indTypes
pure indTypes
let usedLevelNames := collectLevelParamsInInductive indTypes
match sortDeclLevelParams scopeLevelNames allUserLevelNames usedLevelNames with
| Except.error msg => throwError msg

View file

@ -623,11 +623,15 @@ private def levelMVarToParam (scopeVars : Array Expr) (params : Array Expr) (fie
(levelMVarToParamAux scopeVars params fieldInfos).run' 1
private partial def collectUniversesFromFields (r : Level) (rOffset : Nat) (fieldInfos : Array StructFieldInfo) : TermElabM (Array Level) := do
fieldInfos.foldlM (init := #[]) fun (us : Array Level) (info : StructFieldInfo) => do
let type ← inferType info.fvar
let u ← getLevel type
let u ← instantiateLevelMVars u
accLevelAtCtor u r rOffset us
let (_, us) ← go |>.run #[]
return us
where
go : StateRefT (Array Level) TermElabM Unit :=
for info in fieldInfos do
let type ← inferType info.fvar
let u ← getLevel type
let u ← instantiateLevelMVars u
accLevelAtCtor u r rOffset
private def updateResultingUniverse (fieldInfos : Array StructFieldInfo) (type : Expr) : TermElabM Expr := do
let r ← getResultUniverse type