From 9b788db91f10368fd5d438527bb14dff4e113aab Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 5 Sep 2020 07:34:26 -0700 Subject: [PATCH] feat: compute level parameters for mutually recursive definitions --- src/Lean/Elab/Command.lean | 20 ---------- src/Lean/Elab/DeclUtil.lean | 20 ++++++++++ src/Lean/Elab/MutualDef.lean | 71 ++++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 20 deletions(-) diff --git a/src/Lean/Elab/Command.lean b/src/Lean/Elab/Command.lean index 1b3d07df61..4921aada24 100644 --- a/src/Lean/Elab/Command.lean +++ b/src/Lean/Elab/Command.lean @@ -622,26 +622,6 @@ fun stx => do let cmd₂ := stx.getArg 2; `(section $cmd₁:command $cmd₂:command end) -/-- - Sort the given list of `usedParams` using the following order: - - If it is an explicit level `allUserParams`, then use user given order. - - Otherwise, use lexicographical. - - Remark: `scopeParams` are the universe params introduced using the `universe` command. `allUserParams` contains - the universe params introduced using the `universe` command *and* the `.{...}` notation. - - Remark: this function return an exception if there is an `u` not in `usedParams`, that is in `allUserParams` but not in `scopeParams`. - - Remark: `explicitParams` are in reverse declaration order. That is, the head is the last declared parameter. -/ -def sortDeclLevelParams (scopeParams : List Name) (allUserParams : List Name) (usedParams : Array Name) : Except String (List Name) := -match allUserParams.find? $ fun u => !usedParams.contains u && !scopeParams.elem u with -| some u => throw ("unused universe parameter '" ++ toString u ++ "'") -| none => - let result := allUserParams.foldl (fun result levelName => if usedParams.elem levelName then levelName :: result else result) []; - let remaining := usedParams.filter (fun levelParam => !allUserParams.elem levelParam); - let remaining := remaining.qsort Name.lt; - pure $ result ++ remaining.toList - def expandDeclId (declId : Syntax) (modifiers : Modifiers) : CommandElabM ExpandDeclIdResult := do currNamespace ← getCurrNamespace; currLevelNames ← getLevelNames; diff --git a/src/Lean/Elab/DeclUtil.lean b/src/Lean/Elab/DeclUtil.lean index 77a09b9b78..3163070b77 100644 --- a/src/Lean/Elab/DeclUtil.lean +++ b/src/Lean/Elab/DeclUtil.lean @@ -60,5 +60,25 @@ match name with | Name.str _ s _ => "_instance".isPrefixOf s | _ => false +/-- + Sort the given list of `usedParams` using the following order: + - If it is an explicit level `allUserParams`, then use user given order. + - Otherwise, use lexicographical. + + Remark: `scopeParams` are the universe params introduced using the `universe` command. `allUserParams` contains + the universe params introduced using the `universe` command *and* the `.{...}` notation. + + Remark: this function return an exception if there is an `u` not in `usedParams`, that is in `allUserParams` but not in `scopeParams`. + + Remark: `explicitParams` are in reverse declaration order. That is, the head is the last declared parameter. -/ +def sortDeclLevelParams (scopeParams : List Name) (allUserParams : List Name) (usedParams : Array Name) : Except String (List Name) := +match allUserParams.find? $ fun u => !usedParams.contains u && !scopeParams.elem u with +| some u => throw ("unused universe parameter '" ++ toString u ++ "'") +| none => + let result := allUserParams.foldl (fun result levelName => if usedParams.elem levelName then levelName :: result else result) []; + let remaining := usedParams.filter (fun levelParam => !allUserParams.elem levelParam); + let remaining := remaining.qsort Name.lt; + pure $ result ++ remaining.toList + end Elab end Lean diff --git a/src/Lean/Elab/MutualDef.lean b/src/Lean/Elab/MutualDef.lean index 7485f6dfdf..5373fc901a 100644 --- a/src/Lean/Elab/MutualDef.lean +++ b/src/Lean/Elab/MutualDef.lean @@ -206,6 +206,7 @@ letRecsToLift.forM fun toLift => do structure PreDeclaration := (kind : DefKind) +(lparams : List Name) (modifiers : Modifiers) (declName : Name) (type : Expr) @@ -494,6 +495,7 @@ mainHeaders.size.foldM pure $ preDecls.push { kind := header.kind, declName := header.declName, + lparams := [], -- we set it later modifiers := header.modifiers, type := type, val := val @@ -508,6 +510,7 @@ letRecClosures.foldl preDecls.push { kind := kind, declName := c.toLift.declName, + lparams := [], -- we set it later modifiers := { modifiers with attrs := c.toLift.attrs }, type := type, val := val @@ -559,9 +562,74 @@ pushMain (pushLetRecs #[] letRecClosures letRecKind letRecMods) sectionVars main end MutualClosure +private def getAllUserLevelNames (headers : Array DefViewElabHeader) : List Name := +if h : 0 < headers.size then + -- Recall that all top-level functions must have the same levels. See `check` method above + (headers.get ⟨0, h⟩).levelNames +else + [] + +private def instantiateMVarsAtPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) := +preDecls.mapM fun preDecl => do + type ← instantiateMVars preDecl.type; + val ← instantiateMVars preDecl.val; + pure { preDecl with type := type, val := val } + +private def levelMVarToParamExpr (e : Expr) : StateRefT Nat TermElabM Expr := do +nextIdx ← get; +(e, nextIdx) ← liftM $ levelMVarToParam e nextIdx; +set nextIdx; +pure e + +private def levelMVarToParamPreDeclsAux (preDecls : Array PreDeclaration) : StateRefT Nat TermElabM (Array PreDeclaration) := +preDecls.mapM fun preDecl => do + type ← levelMVarToParamExpr preDecl.type; + val ← levelMVarToParamExpr preDecl.val; + pure { preDecl with type := type, val := val } + +private def levelMVarToParamPreDecls (preDecls : Array PreDeclaration) : TermElabM (Array PreDeclaration) := +(levelMVarToParamPreDeclsAux preDecls).run' 1 + +private def collectLevelParamsExpr (e : Expr) : StateM CollectLevelParams.State Unit := do +modify fun s => collectLevelParams s e + +private def getLevelParamsPreDecls (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (List Name) := +let (_, s) := StateT.run + (preDecls.forM fun preDecl => do { + collectLevelParamsExpr preDecl.type; + collectLevelParamsExpr preDecl.val }) + {}; +match sortDeclLevelParams scopeLevelNames allUserLevelNames s.params with +| Except.error msg => throwError msg +| Except.ok levelParams => pure levelParams + +private def shareCommon (preDecls : Array PreDeclaration) : Array PreDeclaration := +let result : Std.ShareCommonM (Array PreDeclaration) := + preDecls.mapM fun preDecl => do { + type ← Std.withShareCommon preDecl.type; + val ← Std.withShareCommon preDecl.val; + pure { preDecl with type := type, val := val } + }; +result.run + +private def fixLevelParams (preDecls : Array PreDeclaration) (scopeLevelNames allUserLevelNames : List Name) : TermElabM (Array PreDeclaration) := do +let preDecls := shareCommon preDecls; +lparams ← getLevelParamsPreDecls preDecls scopeLevelNames allUserLevelNames; +let us := lparams.map mkLevelParam; +let fixExpr (e : Expr) : Expr := + e.replace fun c => match c with + | Expr.const declName _ _ => if preDecls.any fun preDecl => preDecl.declName == declName then some $ Lean.mkConst declName us else none + | _ => none; +pure $ preDecls.map fun preDecl => + { preDecl with + type := fixExpr preDecl.type, + val := fixExpr preDecl.val, + lparams := lparams } + def elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do scopeLevelNames ← getLevelNames; headers ← elabHeaders views; +let allUserLevelNames := getAllUserLevelNames headers; withFunLocalDecls headers fun funFVars => do values ← elabFunValues headers; Term.synthesizeSyntheticMVarsNoPostponing; @@ -574,6 +642,9 @@ withFunLocalDecls headers fun funFVars => do checkLetRecsToLiftTypes funFVars letRecsToLift; withUsedWhen vars headers values letRecsToLift (not $ isTheorem views) fun vars => do preDecls ← MutualClosure.main vars headers funFVars values letRecsToLift; + preDecls ← levelMVarToParamPreDecls preDecls; + preDecls ← instantiateMVarsAtPreDecls preDecls; + preDecls ← fixLevelParams preDecls scopeLevelNames allUserLevelNames; -- TODO preDecls.forM fun preDecl => IO.println (toString preDecl.declName ++ " : " ++ toString preDecl.type ++ " :=\n" ++ toString preDecl.val ++ "\n"); throwError "WIP mutual def"