From bb3c8a2105987a5ef691e3d12bbcd793bbe042fa Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 27 Aug 2020 10:46:33 -0700 Subject: [PATCH] refactor: polymorphic `applyAttributes` --- src/Lean/Elab/Attributes.lean | 12 +++++ src/Lean/Elab/Command.lean | 5 -- src/Lean/Elab/DeclModifiers.lean | 9 ---- src/Lean/Elab/Declaration.lean | 18 +++---- src/Lean/Elab/Definition.lean | 47 ++++++++--------- src/Lean/Elab/Inductive.lean | 60 +++++++++++----------- src/Lean/Elab/Structure.lean | 86 +++++++++++++++----------------- src/Lean/Meta/Instances.lean | 13 +++-- 8 files changed, 123 insertions(+), 127 deletions(-) diff --git a/src/Lean/Elab/Attributes.lean b/src/Lean/Elab/Attributes.lean index 36d1bd894b..aaf8b61cb3 100644 --- a/src/Lean/Elab/Attributes.lean +++ b/src/Lean/Elab/Attributes.lean @@ -39,5 +39,17 @@ def elabAttrs {m} [Monad m] [MonadEnv m] [MonadError m] (stx : Syntax) : m (Arra pure $ attrs.push attr) #[] +def applyAttributesImp (declName : Name) (attrs : Array Attribute) (applicationTime : AttributeApplicationTime) : CoreM Unit := +attrs.forM $ fun attr => do + env ← getEnv; + match getAttributeImpl env attr.name with + | Except.error errMsg => throwError errMsg + | Except.ok attrImpl => + when (attrImpl.applicationTime == applicationTime) do + attrImpl.add declName attr.args true + +def applyAttributes {m} [MonadLiftT CoreM m] (declName : Name) (attrs : Array Attribute) (applicationTime : AttributeApplicationTime) : m Unit := +liftM $ applyAttributesImp declName attrs applicationTime + end Elab end Lean diff --git a/src/Lean/Elab/Command.lean b/src/Lean/Elab/Command.lean index d672de8162..e0def8d1ec 100644 --- a/src/Lean/Elab/Command.lean +++ b/src/Lean/Elab/Command.lean @@ -540,11 +540,6 @@ when succeeded $ @[builtinCommandElab «check_failure»] def elabCheckFailure : CommandElab := fun stx => failIfSucceeds $ elabCheck stx -def addInstance (declName : Name) : CommandElabM Unit := do -env ← getEnv; -env ← liftIO $ Meta.addGlobalInstance env declName; -setEnv env - unsafe def elabEvalUnsafe : CommandElab := fun stx => withoutModifyingEnv do let ref := stx; diff --git a/src/Lean/Elab/DeclModifiers.lean b/src/Lean/Elab/DeclModifiers.lean index 537a225a66..722e329841 100644 --- a/src/Lean/Elab/DeclModifiers.lean +++ b/src/Lean/Elab/DeclModifiers.lean @@ -124,15 +124,6 @@ currNamespace ← getCurrNamespace; let declName := currNamespace ++ atomicName; applyVisibility modifiers.visibility declName -def applyAttributes (declName : Name) (attrs : Array Attribute) (applicationTime : AttributeApplicationTime) : CommandElabM Unit := -attrs.forM $ fun attr => do - env ← getEnv; - match getAttributeImpl env attr.name with - | Except.error errMsg => throwError errMsg - | Except.ok attrImpl => - when (attrImpl.applicationTime == applicationTime) do - liftCoreM (attrImpl.add declName attr.args true) - end Command end Elab end Lean diff --git a/src/Lean/Elab/Declaration.lean b/src/Lean/Elab/Declaration.lean index 7a98a4f513..21b5bfde61 100644 --- a/src/Lean/Elab/Declaration.lean +++ b/src/Lean/Elab/Declaration.lean @@ -84,9 +84,9 @@ let (binders, typeStx) := expandDeclSig (stx.getArg 2); scopeLevelNames ← getLevelNames; withDeclId declId $ fun name => do declName ← mkDeclName modifiers name; - applyAttributes declName modifiers.attrs AttributeApplicationTime.beforeElaboration; allUserLevelNames ← getLevelNames; - decl ← runTermElabM declName $ fun vars => Term.elabBinders binders.getArgs $ fun xs => do { + runTermElabM declName $ fun vars => Term.elabBinders binders.getArgs $ fun xs => do + applyAttributes declName modifiers.attrs AttributeApplicationTime.beforeElaboration; type ← Term.elabType typeStx; Term.synthesizeSyntheticMVars false; type ← instantiateMVars type; @@ -96,17 +96,17 @@ withDeclId declId $ fun name => do let usedParams := (collectLevelParams {} type).params; match sortDeclLevelParams scopeLevelNames allUserLevelNames usedParams with | Except.error msg => throwErrorAt stx msg - | Except.ok levelParams => - pure $ Declaration.axiomDecl { + | Except.ok levelParams => do + let decl := Declaration.axiomDecl { name := declName, lparams := levelParams, type := type, isUnsafe := modifiers.isUnsafe - } - }; - addDecl decl; - applyAttributes declName modifiers.attrs AttributeApplicationTime.afterTypeChecking; - applyAttributes declName modifiers.attrs AttributeApplicationTime.afterCompilation + }; + -- ensureNoUnassignedMVars decl; -- TODO + addDecl decl; + applyAttributes declName modifiers.attrs AttributeApplicationTime.afterTypeChecking; + applyAttributes declName modifiers.attrs AttributeApplicationTime.afterCompilation /- parser! "inductive " >> declId >> optDeclSig >> many ctor diff --git a/src/Lean/Elab/Definition.lean b/src/Lean/Elab/Definition.lean index 41d2983c67..060a0410d6 100644 --- a/src/Lean/Elab/Definition.lean +++ b/src/Lean/Elab/Definition.lean @@ -61,7 +61,7 @@ private def withUsedWhen' {α} (vars : Array Expr) (xs : Array Expr) (e : Expr) let dummyExpr := mkSort levelOne; withUsedWhen vars xs e dummyExpr cond k -def mkDef (view : DefView) (declName : Name) (scopeLevelNames allUserLevelNames : List Name) (vars : Array Expr) (xs : Array Expr) (type : Expr) (val : Expr) +def mkDef? (view : DefView) (declName : Name) (scopeLevelNames allUserLevelNames : List Name) (vars : Array Expr) (xs : Array Expr) (type : Expr) (val : Expr) : TermElabM (Option Declaration) := do withRef view.ref do Term.synthesizeSyntheticMVars; @@ -126,29 +126,30 @@ withRef view.ref do scopeLevelNames ← getLevelNames; withDeclId view.declId $ fun name => do declName ← withRef view.declId $ mkDeclName view.modifiers name; - applyAttributes declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration; allUserLevelNames ← getLevelNames; - decl? ← runTermElabM declName $ fun vars => Term.elabBinders view.binders.getArgs $ fun xs => - match view.type? with - | some typeStx => do - type ← Term.elabType typeStx; - Term.synthesizeSyntheticMVars false; - type ← instantiateMVars type; - withUsedWhen' vars xs type view.kind.isTheorem $ fun vars => do - val ← elabDefVal view.val type; - mkDef view declName scopeLevelNames allUserLevelNames vars xs type val - | none => do { - type ← withRef view.binders $ mkFreshTypeMVar; - val ← elabDefVal view.val type; - mkDef view declName scopeLevelNames allUserLevelNames vars xs type val - }; - match decl? with - | none => pure () - | some decl => do - addDecl decl; - applyAttributes declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking; - compileDecl decl; - applyAttributes declName view.modifiers.attrs AttributeApplicationTime.afterCompilation + runTermElabM declName $ fun vars => Term.elabBinders view.binders.getArgs $ fun xs => do + applyAttributes declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration; + decl? ← match view.type? with + | some typeStx => do + type ← Term.elabType typeStx; + Term.synthesizeSyntheticMVars false; + type ← instantiateMVars type; + withUsedWhen' vars xs type view.kind.isTheorem $ fun vars => do + val ← elabDefVal view.val type; + mkDef? view declName scopeLevelNames allUserLevelNames vars xs type val + | none => do { + type ← withRef view.binders $ mkFreshTypeMVar; + val ← elabDefVal view.val type; + mkDef? view declName scopeLevelNames allUserLevelNames vars xs type val + }; + match decl? with + | none => pure () + | some decl => do + -- ensureNoUnassignedMVars decl; -- TODO + addDecl decl; + applyAttributes declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking; + compileDecl decl; + applyAttributes declName view.modifiers.attrs AttributeApplicationTime.afterCompilation @[init] private def regTraceClasses : IO Unit := do registerTraceClass `Elab.definition; diff --git a/src/Lean/Elab/Inductive.lean b/src/Lean/Elab/Inductive.lean index 925e0dea88..2c798636c1 100644 --- a/src/Lean/Elab/Inductive.lean +++ b/src/Lean/Elab/Inductive.lean @@ -426,7 +426,29 @@ indTypes.map fun indType => { ctor with type := ctorType }; { indType with ctors := ctors } -private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : TermElabM Declaration := do +private def mkAuxConstructions (views : Array InductiveView) : TermElabM Unit := do +env ← getEnv; +let hasEq := env.contains `Eq; +let hasHEq := env.contains `HEq; +let hasUnit := env.contains `PUnit; +let hasProd := env.contains `Prod; +views.forM fun view => do { + let n := view.declName; + modifyEnv fun env => mkRecOn env n; + when hasUnit $ modifyEnv fun env => mkCasesOn env n; + when (hasUnit && hasEq && hasHEq) $ modifyEnv fun env => mkNoConfusion env n; + when (hasUnit && hasProd) $ modifyEnv fun env => mkBelow env n; + when (hasUnit && hasProd) $ modifyEnv fun env => mkIBelow env n; + pure () +}; +views.forM fun view => do { + let n := view.declName; + when (hasUnit && hasProd) $ modifyEnv fun env => mkBRecOn env n; + when (hasUnit && hasProd) $ modifyEnv fun env => mkBInductionOn env n; + pure () +} + +private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : TermElabM Unit := do let view0 := views.get! 0; scopeLevelNames ← Term.getLevelNames; checkLevelNames views; @@ -461,39 +483,17 @@ adaptReader (fun (ctx : Term.Context) => { ctx with levelNames := allUserLevelNa | Except.ok levelParams => do indTypes ← replaceIndFVarsWithConsts views indFVars levelParams numParams indTypes; let indTypes := applyInferMod views numParams indTypes; - pure $ Declaration.inductDecl levelParams numParams indTypes isUnsafe - -private def mkAuxConstructions (views : Array InductiveView) : CommandElabM Unit := do -env ← getEnv; -let hasEq := env.contains `Eq; -let hasHEq := env.contains `HEq; -let hasUnit := env.contains `PUnit; -let hasProd := env.contains `Prod; -views.forM fun view => do { - let n := view.declName; - modifyEnv fun env => mkRecOn env n; - when hasUnit $ modifyEnv fun env => mkCasesOn env n; - when (hasUnit && hasEq && hasHEq) $ modifyEnv fun env => mkNoConfusion env n; - when (hasUnit && hasProd) $ modifyEnv fun env => mkBelow env n; - when (hasUnit && hasProd) $ modifyEnv fun env => mkIBelow env n; - pure () -}; -views.forM fun view => do { - let n := view.declName; - when (hasUnit && hasProd) $ modifyEnv fun env => mkBRecOn env n; - when (hasUnit && hasProd) $ modifyEnv fun env => mkBInductionOn env n; - pure () -} + let decl := Declaration.inductDecl levelParams numParams indTypes isUnsafe; + -- ensureNoUnassignedMVars decl -- TODO + addDecl decl; + mkAuxConstructions views; + -- We need to invoke `applyAttributes` because `class` is implemented as an attribute. + views.forM fun view => applyAttributes view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking def elabInductiveViews (views : Array InductiveView) : CommandElabM Unit := do let view0 := views.get! 0; let ref := view0.ref; -decl ← runTermElabM view0.declName fun vars => withRef ref $ mkInductiveDecl vars views; -addDecl decl; -mkAuxConstructions views; --- We need to invoke `applyAttributes` because `class` is implemented as an attribute. -views.forM fun view => applyAttributes view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking; -pure () +runTermElabM view0.declName fun vars => withRef ref $ mkInductiveDecl vars views end Command end Elab diff --git a/src/Lean/Elab/Structure.lean b/src/Lean/Elab/Structure.lean index c51356d183..8241550565 100644 --- a/src/Lean/Elab/Structure.lean +++ b/src/Lean/Elab/Structure.lean @@ -418,7 +418,36 @@ type ← instantiateMVars type; let type := type.inferImplicit params.size !view.ctor.inferMod; pure { name := view.ctor.declName, type := type } -private def elabStructureView (view : StructView) : TermElabM ElabStructResult := do +@[extern "lean_mk_projections"] +private constant mkProjections (env : Environment) (structName : @& Name) (projs : @& List ProjectionInfo) (isClass : Bool) : Except String Environment := arbitrary _ + +private def addProjections (structName : Name) (projs : List ProjectionInfo) (isClass : Bool) : TermElabM Unit := do +env ← getEnv; +match mkProjections env structName projs isClass with +| Except.ok env => setEnv env +| Except.error msg => throwError msg + +private def mkAuxConstructions (declName : Name) : TermElabM Unit := do +env ← getEnv; +let hasUnit := env.contains `PUnit; +let hasEq := env.contains `Eq; +let hasHEq := env.contains `HEq; +modifyEnv fun env => mkRecOn env declName; +when hasUnit $ modifyEnv fun env => mkCasesOn env declName; +when (hasUnit && hasEq && hasHEq) $ modifyEnv fun env => mkNoConfusion env declName + +private def addDefaults (lctx : LocalContext) (defaultAuxDecls : Array (Name × Expr × Expr)) : TermElabM Unit := do +localInsts ← getLocalInstances; +withLCtx lctx localInsts do + defaultAuxDecls.forM fun ⟨declName, type, value⟩ => do + /- The identity function is used as "marker". -/ + value ← mkId value; + let zeta := true; -- expand `let-declarations` + _ ← mkAuxDefinition declName type value zeta; + modifyEnv fun env => setReducibilityStatus env declName ReducibilityStatus.reducible; + pure () + +private def elabStructureView (view : StructView) : TermElabM Unit := do let numExplicitParams := view.params.size; type ← Term.elabType view.type; unless (validStructType type) $ throwErrorAt view.type "expected Type"; @@ -442,22 +471,26 @@ withFields view.fields 0 fieldInfos fun fieldInfos => do type ← instantiateMVars type; let indType := { name := view.declName, type := type, ctors := [ctor] : InductiveType }; let decl := Declaration.inductDecl levelParams params.size [indType] view.modifiers.isUnsafe; + -- ensureNoUnassignedMVars decl -- TODO + addDecl decl; let projInfos := (fieldInfos.filter fun (info : StructFieldInfo) => !info.isFromParent).toList.map fun (info : StructFieldInfo) => { declName := info.declName, inferMod := info.inferMod : ProjectionInfo }; + addProjections view.declName projInfos view.isClass; + mkAuxConstructions view.declName; instParents ← fieldInfos.filterM fun info => do { decl ← Term.getFVarLocalDecl! info.fvar; pure (info.isSubobject && decl.binderInfo.isInstImplicit) }; let projInstances := instParents.toList.map fun info => info.declName; - mctx ← getMCtx; + applyAttributes view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking; + projInstances.forM addGlobalInstance; lctx ← getLCtx; - localInsts ← getLocalInstances; let fieldsWithDefault := fieldInfos.filter fun info => info.value?.isSome; defaultAuxDecls ← fieldsWithDefault.mapM fun info => do { type ← inferType info.fvar; pure (info.declName ++ `_default, type, info.value?.get!) }; - /- The `mctx`, `lctx`, `localInsts` and `defaultAuxDecls` are used to create the auxiliary `_default` declarations *after* the structure has been declarated. + /- The `lctx` and `defaultAuxDecls` are used to create the auxiliary `_default` declarations The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/ let lctx := params.foldl (fun (lctx : LocalContext) (p : Expr) => @@ -468,38 +501,7 @@ withFields view.fields 0 fieldInfos fun fieldInfos => do if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating `_default`. else lctx.updateBinderInfo info.fvar.fvarId! BinderInfo.default) lctx; - pure { decl := decl, projInfos := projInfos, projInstances := projInstances, - mctx := mctx, lctx := lctx, localInsts := localInsts, defaultAuxDecls := defaultAuxDecls } - -@[extern "lean_mk_projections"] -private constant mkProjections (env : Environment) (structName : @& Name) (projs : @& List ProjectionInfo) (isClass : Bool) : Except String Environment := arbitrary _ - -private def addProjections (structName : Name) (projs : List ProjectionInfo) (isClass : Bool) : CommandElabM Unit := do -env ← getEnv; -match mkProjections env structName projs isClass with -| Except.ok env => setEnv env -| Except.error msg => throwError msg - -private def mkAuxConstructions (declName : Name) : CommandElabM Unit := do -env ← getEnv; -let hasUnit := env.contains `PUnit; -let hasEq := env.contains `Eq; -let hasHEq := env.contains `HEq; -modifyEnv fun env => mkRecOn env declName; -when hasUnit $ modifyEnv fun env => mkCasesOn env declName; -when (hasUnit && hasEq && hasHEq) $ modifyEnv fun env => mkNoConfusion env declName - -private def addDefaults (mctx : MetavarContext) (lctx : LocalContext) (localInsts : LocalInstances) - (defaultAuxDecls : Array (Name × Expr × Expr)) : CommandElabM Unit := -liftTermElabM none $ withLCtx lctx localInsts do - setMCtx mctx; - defaultAuxDecls.forM fun ⟨declName, type, value⟩ => do - /- The identity function is used as "marker". -/ - value ← mkId value; - let zeta := true; -- expand `let-declarations` - _ ← mkAuxDefinition declName type value zeta; - modifyEnv fun env => setReducibilityStatus env declName ReducibilityStatus.reducible; - pure () + addDefaults lctx defaultAuxDecls /- parser! (structureTk <|> classTk) >> declId >> many Term.bracketedBinder >> optional «extends» >> Term.optType >> " := " >> optional structCtor >> structFields @@ -529,7 +531,7 @@ withDeclId declId $ fun name => do allUserLevelNames ← getLevelNames; ctor ← expandCtor stx modifiers declName; fields ← expandFields stx modifiers declName; - r ← runTermElabM declName $ fun scopeVars => Term.elabBinders params $ fun params => elabStructureView { + runTermElabM declName $ fun scopeVars => Term.elabBinders params $ fun params => elabStructureView { ref := stx, modifiers := modifiers, scopeLevelNames := scopeLevelNames, @@ -542,15 +544,7 @@ withDeclId declId $ fun name => do type := type, ctor := ctor, fields := fields - }; - let ref := declId; - addDecl r.decl; - addProjections declName r.projInfos isClass; - mkAuxConstructions declName; - applyAttributes declName modifiers.attrs AttributeApplicationTime.afterTypeChecking; - r.projInstances.forM addInstance; - addDefaults r.mctx r.lctx r.localInsts r.defaultAuxDecls; - pure () + } end Command end Elab diff --git a/src/Lean/Meta/Instances.lean b/src/Lean/Meta/Instances.lean index a3111545fb..03c2ec6cf7 100644 --- a/src/Lean/Meta/Instances.lean +++ b/src/Lean/Meta/Instances.lean @@ -34,7 +34,7 @@ withNewMCtxDepth $ do DiscrTree.mkPath type @[export lean_add_instance] -def addGlobalInstance (env : Environment) (constName : Name) : IO Environment := +def addGlobalInstanceImp (env : Environment) (constName : Name) : IO Environment := match env.find? constName with | none => throw $ IO.userError "unknown constant" | some cinfo => do @@ -42,6 +42,11 @@ match env.find? constName with (keys, s, _) ← (mkInstanceKey c).toIO {} { env := env } {} {}; pure $ instanceExtension.addEntry s.env { keys := keys, val := c } +def addGlobalInstance {m} [Monad m] [MonadEnv m] [MonadIO m] (declName : Name) : m Unit := do +env ← getEnv; +env ← liftIO $ Meta.addGlobalInstanceImp env declName; +setEnv env + @[init] def registerInstanceAttr : IO Unit := registerBuiltinAttribute { name := `instance, @@ -49,10 +54,8 @@ registerBuiltinAttribute { add := fun declName args persistent => do when args.hasArgs $ throwError "invalid attribute 'instance', unexpected argument"; unless persistent $ throwError "invalid attribute 'instance', must be persistent"; - env ← getEnv; - env ← ofExcept (addGlobalInstanceOld env declName); -- TODO: delete - env ← liftIO $ addGlobalInstance env declName; - setEnv env + env ← getEnv; env ← ofExcept (addGlobalInstanceOld env declName); setEnv env; -- TODO: delete after we remove old frontend + addGlobalInstance declName } end Meta