feat: compute level parameters for mutually recursive definitions
This commit is contained in:
parent
0d39c00782
commit
9b788db91f
3 changed files with 91 additions and 20 deletions
|
|
@ -622,26 +622,6 @@ fun stx => do
|
|||
let cmd₂ := stx.getArg 2;
|
||||
`(section $cmd₁:command $cmd₂:command end)
|
||||
|
||||
/--
|
||||
Sort the given list of `usedParams` using the following order:
|
||||
- If it is an explicit level `allUserParams`, then use user given order.
|
||||
- Otherwise, use lexicographical.
|
||||
|
||||
Remark: `scopeParams` are the universe params introduced using the `universe` command. `allUserParams` contains
|
||||
the universe params introduced using the `universe` command *and* the `.{...}` notation.
|
||||
|
||||
Remark: this function return an exception if there is an `u` not in `usedParams`, that is in `allUserParams` but not in `scopeParams`.
|
||||
|
||||
Remark: `explicitParams` are in reverse declaration order. That is, the head is the last declared parameter. -/
|
||||
def sortDeclLevelParams (scopeParams : List Name) (allUserParams : List Name) (usedParams : Array Name) : Except String (List Name) :=
|
||||
match allUserParams.find? $ fun u => !usedParams.contains u && !scopeParams.elem u with
|
||||
| some u => throw ("unused universe parameter '" ++ toString u ++ "'")
|
||||
| none =>
|
||||
let result := allUserParams.foldl (fun result levelName => if usedParams.elem levelName then levelName :: result else result) [];
|
||||
let remaining := usedParams.filter (fun levelParam => !allUserParams.elem levelParam);
|
||||
let remaining := remaining.qsort Name.lt;
|
||||
pure $ result ++ remaining.toList
|
||||
|
||||
def expandDeclId (declId : Syntax) (modifiers : Modifiers) : CommandElabM ExpandDeclIdResult := do
|
||||
currNamespace ← getCurrNamespace;
|
||||
currLevelNames ← getLevelNames;
|
||||
|
|
|
|||
|
|
@ -60,5 +60,25 @@ match name with
|
|||
| Name.str _ s _ => "_instance".isPrefixOf s
|
||||
| _ => false
|
||||
|
||||
/--
|
||||
Sort the given list of `usedParams` using the following order:
|
||||
- If it is an explicit level `allUserParams`, then use user given order.
|
||||
- Otherwise, use lexicographical.
|
||||
|
||||
Remark: `scopeParams` are the universe params introduced using the `universe` command. `allUserParams` contains
|
||||
the universe params introduced using the `universe` command *and* the `.{...}` notation.
|
||||
|
||||
Remark: this function return an exception if there is an `u` not in `usedParams`, that is in `allUserParams` but not in `scopeParams`.
|
||||
|
||||
Remark: `explicitParams` are in reverse declaration order. That is, the head is the last declared parameter. -/
|
||||
def sortDeclLevelParams (scopeParams : List Name) (allUserParams : List Name) (usedParams : Array Name) : Except String (List Name) :=
|
||||
match allUserParams.find? $ fun u => !usedParams.contains u && !scopeParams.elem u with
|
||||
| some u => throw ("unused universe parameter '" ++ toString u ++ "'")
|
||||
| none =>
|
||||
let result := allUserParams.foldl (fun result levelName => if usedParams.elem levelName then levelName :: result else result) [];
|
||||
let remaining := usedParams.filter (fun levelParam => !allUserParams.elem levelParam);
|
||||
let remaining := remaining.qsort Name.lt;
|
||||
pure $ result ++ remaining.toList
|
||||
|
||||
end Elab
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -206,6 +206,7 @@ letRecsToLift.forM fun toLift => do
|
|||
|
||||
structure PreDeclaration :=
|
||||
(kind : DefKind)
|
||||
(lparams : List Name)
|
||||
(modifiers : Modifiers)
|
||||
(declName : Name)
|
||||
(type : Expr)
|
||||
|
|
@ -494,6 +495,7 @@ mainHeaders.size.foldM
|
|||
pure $ preDecls.push {
|
||||
kind := header.kind,
|
||||
declName := header.declName,
|
||||
lparams := [], -- we set it later
|
||||
modifiers := header.modifiers,
|
||||
type := type,
|
||||
val := val
|
||||
|
|
@ -508,6 +510,7 @@ letRecClosures.foldl
|
|||
preDecls.push {
|
||||
kind := kind,
|
||||
declName := c.toLift.declName,
|
||||
lparams := [], -- we set it later
|
||||
modifiers := { modifiers with attrs := c.toLift.attrs },
|
||||
type := type,
|
||||
val := val
|
||||
|
|
@ -559,9 +562,74 @@ pushMain (pushLetRecs #[] letRecClosures letRecKind letRecMods) sectionVars main
|
|||
|
||||
end MutualClosure
|
||||
|
||||
private def getAllUserLevelNames (headers : Array DefViewElabHeader) : List Name :=
|
||||
if h : 0 < headers.size then
|
||||
-- Recall that all top-level functions must have the same levels. See `check` method above
|
||||
(headers.get ⟨0, h⟩).levelNames
|
||||
else
|
||||
[]
|
||||
|
||||
private def instantiateMVarsAtPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) :=
|
||||
preDecls.mapM fun preDecl => do
|
||||
type ← instantiateMVars preDecl.type;
|
||||
val ← instantiateMVars preDecl.val;
|
||||
pure { preDecl with type := type, val := val }
|
||||
|
||||
private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do
|
||||
nextIdx ← get;
|
||||
(e, nextIdx) ← liftM $ levelMVarToParam e nextIdx;
|
||||
set nextIdx;
|
||||
pure e
|
||||
|
||||
private def levelMVarToParamPreDeclsAux (preDecls : Array PreDeclaration) : StateRefT Nat TermElabM (Array PreDeclaration) :=
|
||||
preDecls.mapM fun preDecl => do
|
||||
type ← levelMVarToParamExpr preDecl.type;
|
||||
val ← levelMVarToParamExpr preDecl.val;
|
||||
pure { preDecl with type := type, val := val }
|
||||
|
||||
private def levelMVarToParamPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) :=
|
||||
(levelMVarToParamPreDeclsAux preDecls).run' 1
|
||||
|
||||
private def collectLevelParamsExpr (e : Expr) : StateM CollectLevelParams.State Unit := do
|
||||
modify fun s => collectLevelParams s e
|
||||
|
||||
private def getLevelParamsPreDecls (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) :=
|
||||
let (_, s) := StateT.run
|
||||
(preDecls.forM fun preDecl => do {
|
||||
collectLevelParamsExpr preDecl.type;
|
||||
collectLevelParamsExpr preDecl.val })
|
||||
{};
|
||||
match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with
|
||||
| Except.error msg => throwError msg
|
||||
| Except.ok levelParams => pure levelParams
|
||||
|
||||
private def shareCommon (preDecls : Array PreDeclaration) : Array PreDeclaration :=
|
||||
let result : Std.ShareCommonM (Array PreDeclaration) :=
|
||||
preDecls.mapM fun preDecl => do {
|
||||
type ← Std.withShareCommon preDecl.type;
|
||||
val ← Std.withShareCommon preDecl.val;
|
||||
pure { preDecl with type := type, val := val }
|
||||
};
|
||||
result.run
|
||||
|
||||
private def fixLevelParams (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDeclaration) := do
|
||||
let preDecls := shareCommon preDecls;
|
||||
lparams ← getLevelParamsPreDecls preDecls scopeLevelNames allUserLevelNames;
|
||||
let us := lparams.map mkLevelParam;
|
||||
let fixExpr (e : Expr) : Expr :=
|
||||
e.replace fun c => match c with
|
||||
| Expr.const declName _ _ => if preDecls.any fun preDecl => preDecl.declName == declName then some $ Lean.mkConst declName us else none
|
||||
| _ => none;
|
||||
pure $ preDecls.map fun preDecl =>
|
||||
{ preDecl with
|
||||
type := fixExpr preDecl.type,
|
||||
val := fixExpr preDecl.val,
|
||||
lparams := lparams }
|
||||
|
||||
def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
|
||||
scopeLevelNames ← getLevelNames;
|
||||
headers ← elabHeaders views;
|
||||
let allUserLevelNames := getAllUserLevelNames headers;
|
||||
withFunLocalDecls headers fun funFVars => do
|
||||
values ← elabFunValues headers;
|
||||
Term.synthesizeSyntheticMVarsNoPostponing;
|
||||
|
|
@ -574,6 +642,9 @@ withFunLocalDecls headers fun funFVars => do
|
|||
checkLetRecsToLiftTypes funFVars letRecsToLift;
|
||||
withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do
|
||||
preDecls ← MutualClosure.main vars headers funFVars values letRecsToLift;
|
||||
preDecls ← levelMVarToParamPreDecls preDecls;
|
||||
preDecls ← instantiateMVarsAtPreDecls preDecls;
|
||||
preDecls ← fixLevelParams preDecls scopeLevelNames allUserLevelNames;
|
||||
-- TODO
|
||||
preDecls.forM fun preDecl => IO.println (toString preDecl.declName ++ " : " ++ toString preDecl.type ++ " :=\n" ++ toString preDecl.val ++ "\n");
|
||||
throwError "WIP mutual def"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue