feat: compute level parameters for mutually recursive definitions

This commit is contained in:
Leonardo de Moura 2020-09-05 07:34:26 -07:00
parent 0d39c00782
commit 9b788db91f
3 changed files with 91 additions and 20 deletions

View file

@ -622,26 +622,6 @@ fun stx => do
let cmd₂ := stx.getArg 2;
`(section $cmd₁:command $cmd₂:command end)
/--
Sort the given list of `usedParams` using the following order:
- If it is an explicit level `allUserParams`, then use user given order.
- Otherwise, use lexicographical.
Remark: `scopeParams` are the universe params introduced using the `universe` command. `allUserParams` contains
the universe params introduced using the `universe` command *and* the `.{...}` notation.
Remark: this function return an exception if there is an `u` not in `usedParams`, that is in `allUserParams` but not in `scopeParams`.
Remark: `explicitParams` are in reverse declaration order. That is, the head is the last declared parameter. -/
def sortDeclLevelParams (scopeParams : List Name) (allUserParams : List Name) (usedParams : Array Name) : Except String (List Name) :=
match allUserParams.find? $ fun u => !usedParams.contains u && !scopeParams.elem u with
| some u => throw ("unused universe parameter '" ++ toString u ++ "'")
| none =>
let result := allUserParams.foldl (fun result levelName => if usedParams.elem levelName then levelName :: result else result) [];
let remaining := usedParams.filter (fun levelParam => !allUserParams.elem levelParam);
let remaining := remaining.qsort Name.lt;
pure $ result ++ remaining.toList
def expandDeclId (declId : Syntax) (modifiers : Modifiers) : CommandElabM ExpandDeclIdResult := do
currNamespace ← getCurrNamespace;
currLevelNames ← getLevelNames;

View file

@ -60,5 +60,25 @@ match name with
| Name.str _ s _ => "_instance".isPrefixOf s
| _ => false
/--
Sort the given list of `usedParams` using the following order:
- If it is an explicit level `allUserParams`, then use user given order.
- Otherwise, use lexicographical.
Remark: `scopeParams` are the universe params introduced using the `universe` command. `allUserParams` contains
the universe params introduced using the `universe` command *and* the `.{...}` notation.
Remark: this function return an exception if there is an `u` not in `usedParams`, that is in `allUserParams` but not in `scopeParams`.
Remark: `explicitParams` are in reverse declaration order. That is, the head is the last declared parameter. -/
def sortDeclLevelParams (scopeParams : List Name) (allUserParams : List Name) (usedParams : Array Name) : Except String (List Name) :=
match allUserParams.find? $ fun u => !usedParams.contains u && !scopeParams.elem u with
| some u => throw ("unused universe parameter '" ++ toString u ++ "'")
| none =>
let result := allUserParams.foldl (fun result levelName => if usedParams.elem levelName then levelName :: result else result) [];
let remaining := usedParams.filter (fun levelParam => !allUserParams.elem levelParam);
let remaining := remaining.qsort Name.lt;
pure $ result ++ remaining.toList
end Elab
end Lean

View file

@ -206,6 +206,7 @@ letRecsToLift.forM fun toLift => do
structure PreDeclaration :=
(kind : DefKind)
(lparams : List Name)
(modifiers : Modifiers)
(declName : Name)
(type : Expr)
@ -494,6 +495,7 @@ mainHeaders.size.foldM
pure $ preDecls.push {
kind := header.kind,
declName := header.declName,
lparams := [], -- we set it later
modifiers := header.modifiers,
type := type,
val := val
@ -508,6 +510,7 @@ letRecClosures.foldl
preDecls.push {
kind := kind,
declName := c.toLift.declName,
lparams := [], -- we set it later
modifiers := { modifiers with attrs := c.toLift.attrs },
type := type,
val := val
@ -559,9 +562,74 @@ pushMain (pushLetRecs #[] letRecClosures letRecKind letRecMods) sectionVars main
end MutualClosure
private def getAllUserLevelNames (headers : Array DefViewElabHeader) : List Name :=
if h : 0 < headers.size then
-- Recall that all top-level functions must have the same levels. See `check` method above
(headers.get ⟨0, h⟩).levelNames
else
[]
private def instantiateMVarsAtPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) :=
preDecls.mapM fun preDecl => do
type ← instantiateMVars preDecl.type;
val ← instantiateMVars preDecl.val;
pure { preDecl with type := type, val := val }
private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do
nextIdx ← get;
(e, nextIdx) ← liftM $ levelMVarToParam e nextIdx;
set nextIdx;
pure e
private def levelMVarToParamPreDeclsAux (preDecls : Array PreDeclaration) : StateRefT Nat TermElabM (Array PreDeclaration) :=
preDecls.mapM fun preDecl => do
type ← levelMVarToParamExpr preDecl.type;
val ← levelMVarToParamExpr preDecl.val;
pure { preDecl with type := type, val := val }
private def levelMVarToParamPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) :=
(levelMVarToParamPreDeclsAux preDecls).run' 1
private def collectLevelParamsExpr (e : Expr) : StateM CollectLevelParams.State Unit := do
modify fun s => collectLevelParams s e
private def getLevelParamsPreDecls (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) :=
let (_, s) := StateT.run
(preDecls.forM fun preDecl => do {
collectLevelParamsExpr preDecl.type;
collectLevelParamsExpr preDecl.val })
{};
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
| Except.error msg => throwError msg
| Except.ok levelParams => pure levelParams
private def shareCommon (preDecls : Array PreDeclaration) : Array PreDeclaration :=
let result : Std.ShareCommonM (Array PreDeclaration) :=
preDecls.mapM fun preDecl => do {
type ← Std.withShareCommon preDecl.type;
val ← Std.withShareCommon preDecl.val;
pure { preDecl with type := type, val := val }
};
result.run
private def fixLevelParams (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDeclaration) := do
let preDecls := shareCommon preDecls;
lparams ← getLevelParamsPreDecls preDecls scopeLevelNames allUserLevelNames;
let us := lparams.map mkLevelParam;
let fixExpr (e : Expr) : Expr :=
e.replace fun c => match c with
| Expr.const declName _ _ => if preDecls.any fun preDecl => preDecl.declName == declName then some $ Lean.mkConst declName us else none
| _ => none;
pure $ preDecls.map fun preDecl =>
{ preDecl with
type := fixExpr preDecl.type,
val := fixExpr preDecl.val,
lparams := lparams }
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
scopeLevelNames ← getLevelNames;
headers ← elabHeaders views;
let allUserLevelNames := getAllUserLevelNames headers;
withFunLocalDecls headers fun funFVars => do
values ← elabFunValues headers;
Term.synthesizeSyntheticMVarsNoPostponing;
@ -574,6 +642,9 @@ withFunLocalDecls headers fun funFVars => do
checkLetRecsToLiftTypes funFVars letRecsToLift;
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
preDecls ← MutualClosure.main vars headers funFVars values letRecsToLift;
preDecls ← levelMVarToParamPreDecls preDecls;
preDecls ← instantiateMVarsAtPreDecls preDecls;
preDecls ← fixLevelParams preDecls scopeLevelNames allUserLevelNames;
-- TODO
preDecls.forM fun preDecl => IO.println (toString preDecl.declName ++ " : " ++ toString preDecl.type ++ " :=\n" ++ toString preDecl.val ++ "\n");
throwError "WIP mutual def"