refactor: polymorphic applyAttributes

This commit is contained in:
Leonardo de Moura 2020-08-27 10:46:33 -07:00
parent d84078283c
commit bb3c8a2105
8 changed files with 123 additions and 127 deletions

View file

@ -39,5 +39,17 @@ def elabAttrs {m} [Monad m] [MonadEnv m] [MonadError m] (stx : Syntax) : m (Arra
pure $ attrs.push attr)
#[]
def applyAttributesImp (declName : Name) (attrs : Array Attribute) (applicationTime : AttributeApplicationTime) : CoreM Unit :=
attrs.forM $ fun attr => do
env ← getEnv;
match getAttributeImpl env attr.name with
| Except.error errMsg => throwError errMsg
| Except.ok attrImpl =>
when (attrImpl.applicationTime == applicationTime) do
attrImpl.add declName attr.args true
def applyAttributes {m} [MonadLiftT CoreM m] (declName : Name) (attrs : Array Attribute) (applicationTime : AttributeApplicationTime) : m Unit :=
liftM $ applyAttributesImp declName attrs applicationTime
end Elab
end Lean

View file

@ -540,11 +540,6 @@ when succeeded $
@[builtinCommandElab «check_failure»] def elabCheckFailure : CommandElab :=
fun stx => failIfSucceeds $ elabCheck stx
def addInstance (declName : Name) : CommandElabM Unit := do
env ← getEnv;
env ← liftIO $ Meta.addGlobalInstance env declName;
setEnv env
unsafe def elabEvalUnsafe : CommandElab :=
fun stx => withoutModifyingEnv do
let ref := stx;

View file

@ -124,15 +124,6 @@ currNamespace ← getCurrNamespace;
let declName := currNamespace ++ atomicName;
applyVisibility modifiers.visibility declName
def applyAttributes (declName : Name) (attrs : Array Attribute) (applicationTime : AttributeApplicationTime) : CommandElabM Unit :=
attrs.forM $ fun attr => do
env ← getEnv;
match getAttributeImpl env attr.name with
| Except.error errMsg => throwError errMsg
| Except.ok attrImpl =>
when (attrImpl.applicationTime == applicationTime) do
liftCoreM (attrImpl.add declName attr.args true)
end Command
end Elab
end Lean

View file

@ -84,9 +84,9 @@ let (binders, typeStx) := expandDeclSig (stx.getArg 2);
scopeLevelNames ← getLevelNames;
withDeclId declId $ fun name => do
declName ← mkDeclName modifiers name;
applyAttributes declName modifiers.attrs AttributeApplicationTime.beforeElaboration;
allUserLevelNames ← getLevelNames;
decl ← runTermElabM declName $ fun vars => Term.elabBinders binders.getArgs $ fun xs => do {
runTermElabM declName $ fun vars => Term.elabBinders binders.getArgs $ fun xs => do
applyAttributes declName modifiers.attrs AttributeApplicationTime.beforeElaboration;
type ← Term.elabType typeStx;
Term.synthesizeSyntheticMVars false;
type ← instantiateMVars type;
@ -96,17 +96,17 @@ withDeclId declId $ fun name => do
let usedParams := (collectLevelParams {} type).params;
match sortDeclLevelParams scopeLevelNames allUserLevelNames usedParams with
| Except.error msg => throwErrorAt stx msg
| Except.ok levelParams =>
pure $ Declaration.axiomDecl {
| Except.ok levelParams => do
let decl := Declaration.axiomDecl {
name := declName,
lparams := levelParams,
type := type,
isUnsafe := modifiers.isUnsafe
}
};
addDecl decl;
applyAttributes declName modifiers.attrs AttributeApplicationTime.afterTypeChecking;
applyAttributes declName modifiers.attrs AttributeApplicationTime.afterCompilation
};
-- ensureNoUnassignedMVars decl; -- TODO
addDecl decl;
applyAttributes declName modifiers.attrs AttributeApplicationTime.afterTypeChecking;
applyAttributes declName modifiers.attrs AttributeApplicationTime.afterCompilation
/-
parser! "inductive " >> declId >> optDeclSig >> many ctor

View file

@ -61,7 +61,7 @@ private def withUsedWhen' {α} (vars : Array Expr) (xs : Array Expr) (e : Expr)
let dummyExpr := mkSort levelOne;
withUsedWhen vars xs e dummyExpr cond k
def mkDef (view : DefView) (declName : Name) (scopeLevelNames allUserLevelNames : List Name) (vars : Array Expr) (xs : Array Expr) (type : Expr) (val : Expr)
def mkDef? (view : DefView) (declName : Name) (scopeLevelNames allUserLevelNames : List Name) (vars : Array Expr) (xs : Array Expr) (type : Expr) (val : Expr)
: TermElabM (Option Declaration) := do
withRef view.ref do
Term.synthesizeSyntheticMVars;
@ -126,29 +126,30 @@ withRef view.ref do
scopeLevelNames ← getLevelNames;
withDeclId view.declId $ fun name => do
declName ← withRef view.declId $ mkDeclName view.modifiers name;
applyAttributes declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration;
allUserLevelNames ← getLevelNames;
decl? ← runTermElabM declName $ fun vars => Term.elabBinders view.binders.getArgs $ fun xs =>
match view.type? with
| some typeStx => do
type ← Term.elabType typeStx;
Term.synthesizeSyntheticMVars false;
type ← instantiateMVars type;
withUsedWhen' vars xs type view.kind.isTheorem $ fun vars => do
val ← elabDefVal view.val type;
mkDef view declName scopeLevelNames allUserLevelNames vars xs type val
| none => do {
type ← withRef view.binders $ mkFreshTypeMVar;
val ← elabDefVal view.val type;
mkDef view declName scopeLevelNames allUserLevelNames vars xs type val
};
match decl? with
| none => pure ()
| some decl => do
addDecl decl;
applyAttributes declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking;
compileDecl decl;
applyAttributes declName view.modifiers.attrs AttributeApplicationTime.afterCompilation
runTermElabM declName $ fun vars => Term.elabBinders view.binders.getArgs $ fun xs => do
applyAttributes declName view.modifiers.attrs AttributeApplicationTime.beforeElaboration;
decl? ← match view.type? with
| some typeStx => do
type ← Term.elabType typeStx;
Term.synthesizeSyntheticMVars false;
type ← instantiateMVars type;
withUsedWhen' vars xs type view.kind.isTheorem $ fun vars => do
val ← elabDefVal view.val type;
mkDef? view declName scopeLevelNames allUserLevelNames vars xs type val
| none => do {
type ← withRef view.binders $ mkFreshTypeMVar;
val ← elabDefVal view.val type;
mkDef? view declName scopeLevelNames allUserLevelNames vars xs type val
};
match decl? with
| none => pure ()
| some decl => do
-- ensureNoUnassignedMVars decl; -- TODO
addDecl decl;
applyAttributes declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking;
compileDecl decl;
applyAttributes declName view.modifiers.attrs AttributeApplicationTime.afterCompilation
@[init] private def regTraceClasses : IO Unit := do
registerTraceClass `Elab.definition;

View file

@ -426,7 +426,29 @@ indTypes.map fun indType =>
{ ctor with type := ctorType };
{ indType with ctors := ctors }
private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : TermElabM Declaration := do
private def mkAuxConstructions (views : Array InductiveView) : TermElabM Unit := do
env ← getEnv;
let hasEq := env.contains `Eq;
let hasHEq := env.contains `HEq;
let hasUnit := env.contains `PUnit;
let hasProd := env.contains `Prod;
views.forM fun view => do {
let n := view.declName;
modifyEnv fun env => mkRecOn env n;
when hasUnit $ modifyEnv fun env => mkCasesOn env n;
when (hasUnit && hasEq && hasHEq) $ modifyEnv fun env => mkNoConfusion env n;
when (hasUnit && hasProd) $ modifyEnv fun env => mkBelow env n;
when (hasUnit && hasProd) $ modifyEnv fun env => mkIBelow env n;
pure ()
};
views.forM fun view => do {
let n := view.declName;
when (hasUnit && hasProd) $ modifyEnv fun env => mkBRecOn env n;
when (hasUnit && hasProd) $ modifyEnv fun env => mkBInductionOn env n;
pure ()
}
private def mkInductiveDecl (vars : Array Expr) (views : Array InductiveView) : TermElabM Unit := do
let view0 := views.get! 0;
scopeLevelNames ← Term.getLevelNames;
checkLevelNames views;
@ -461,39 +483,17 @@ adaptReader (fun (ctx : Term.Context) => { ctx with levelNames := allUserLevelNa
| Except.ok levelParams => do
indTypes ← replaceIndFVarsWithConsts views indFVars levelParams numParams indTypes;
let indTypes := applyInferMod views numParams indTypes;
pure $ Declaration.inductDecl levelParams numParams indTypes isUnsafe
private def mkAuxConstructions (views : Array InductiveView) : CommandElabM Unit := do
env ← getEnv;
let hasEq := env.contains `Eq;
let hasHEq := env.contains `HEq;
let hasUnit := env.contains `PUnit;
let hasProd := env.contains `Prod;
views.forM fun view => do {
let n := view.declName;
modifyEnv fun env => mkRecOn env n;
when hasUnit $ modifyEnv fun env => mkCasesOn env n;
when (hasUnit && hasEq && hasHEq) $ modifyEnv fun env => mkNoConfusion env n;
when (hasUnit && hasProd) $ modifyEnv fun env => mkBelow env n;
when (hasUnit && hasProd) $ modifyEnv fun env => mkIBelow env n;
pure ()
};
views.forM fun view => do {
let n := view.declName;
when (hasUnit && hasProd) $ modifyEnv fun env => mkBRecOn env n;
when (hasUnit && hasProd) $ modifyEnv fun env => mkBInductionOn env n;
pure ()
}
let decl := Declaration.inductDecl levelParams numParams indTypes isUnsafe;
-- ensureNoUnassignedMVars decl -- TODO
addDecl decl;
mkAuxConstructions views;
-- We need to invoke `applyAttributes` because `class` is implemented as an attribute.
views.forM fun view => applyAttributes view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking
def elabInductiveViews (views : Array InductiveView) : CommandElabM Unit := do
let view0 := views.get! 0;
let ref := view0.ref;
decl ← runTermElabM view0.declName fun vars => withRef ref $ mkInductiveDecl vars views;
addDecl decl;
mkAuxConstructions views;
-- We need to invoke `applyAttributes` because `class` is implemented as an attribute.
views.forM fun view => applyAttributes view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking;
pure ()
runTermElabM view0.declName fun vars => withRef ref $ mkInductiveDecl vars views
end Command
end Elab

View file

@ -418,7 +418,36 @@ type ← instantiateMVars type;
let type := type.inferImplicit params.size !view.ctor.inferMod;
pure { name := view.ctor.declName, type := type }
private def elabStructureView (view : StructView) : TermElabM ElabStructResult := do
@[extern "lean_mk_projections"]
private constant mkProjections (env : Environment) (structName : @& Name) (projs : @& List ProjectionInfo) (isClass : Bool) : Except String Environment := arbitrary _
private def addProjections (structName : Name) (projs : List ProjectionInfo) (isClass : Bool) : TermElabM Unit := do
env ← getEnv;
match mkProjections env structName projs isClass with
| Except.ok env => setEnv env
| Except.error msg => throwError msg
private def mkAuxConstructions (declName : Name) : TermElabM Unit := do
env ← getEnv;
let hasUnit := env.contains `PUnit;
let hasEq := env.contains `Eq;
let hasHEq := env.contains `HEq;
modifyEnv fun env => mkRecOn env declName;
when hasUnit $ modifyEnv fun env => mkCasesOn env declName;
when (hasUnit && hasEq && hasHEq) $ modifyEnv fun env => mkNoConfusion env declName
private def addDefaults (lctx : LocalContext) (defaultAuxDecls : Array (Name × Expr × Expr)) : TermElabM Unit := do
localInsts ← getLocalInstances;
withLCtx lctx localInsts do
defaultAuxDecls.forM fun ⟨declName, type, value⟩ => do
/- The identity function is used as "marker". -/
value ← mkId value;
let zeta := true; -- expand `let-declarations`
_ ← mkAuxDefinition declName type value zeta;
modifyEnv fun env => setReducibilityStatus env declName ReducibilityStatus.reducible;
pure ()
private def elabStructureView (view : StructView) : TermElabM Unit := do
let numExplicitParams := view.params.size;
type ← Term.elabType view.type;
unless (validStructType type) $ throwErrorAt view.type "expected Type";
@ -442,22 +471,26 @@ withFields view.fields 0 fieldInfos fun fieldInfos => do
type ← instantiateMVars type;
let indType := { name := view.declName, type := type, ctors := [ctor] : InductiveType };
let decl := Declaration.inductDecl levelParams params.size [indType] view.modifiers.isUnsafe;
-- ensureNoUnassignedMVars decl -- TODO
addDecl decl;
let projInfos := (fieldInfos.filter fun (info : StructFieldInfo) => !info.isFromParent).toList.map fun (info : StructFieldInfo) =>
{ declName := info.declName, inferMod := info.inferMod : ProjectionInfo };
addProjections view.declName projInfos view.isClass;
mkAuxConstructions view.declName;
instParents ← fieldInfos.filterM fun info => do {
decl ← Term.getFVarLocalDecl! info.fvar;
pure (info.isSubobject && decl.binderInfo.isInstImplicit)
};
let projInstances := instParents.toList.map fun info => info.declName;
mctx ← getMCtx;
applyAttributes view.declName view.modifiers.attrs AttributeApplicationTime.afterTypeChecking;
projInstances.forM addGlobalInstance;
lctx ← getLCtx;
localInsts ← getLocalInstances;
let fieldsWithDefault := fieldInfos.filter fun info => info.value?.isSome;
defaultAuxDecls ← fieldsWithDefault.mapM fun info => do {
type ← inferType info.fvar;
pure (info.declName ++ `_default, type, info.value?.get!)
};
/- The `mctx`, `lctx`, `localInsts` and `defaultAuxDecls` are used to create the auxiliary `_default` declarations *after* the structure has been declarated.
/- The `lctx` and `defaultAuxDecls` are used to create the auxiliary `_default` declarations
The parameters `params` for these definitions must be marked as implicit, and all others as explicit. -/
let lctx := params.foldl
(fun (lctx : LocalContext) (p : Expr) =>
@ -468,38 +501,7 @@ withFields view.fields 0 fieldInfos fun fieldInfos => do
if info.isFromParent then lctx -- `fromParent` fields are elaborated as let-decls, and are zeta-expanded when creating `_default`.
else lctx.updateBinderInfo info.fvar.fvarId! BinderInfo.default)
lctx;
pure { decl := decl, projInfos := projInfos, projInstances := projInstances,
mctx := mctx, lctx := lctx, localInsts := localInsts, defaultAuxDecls := defaultAuxDecls }
@[extern "lean_mk_projections"]
private constant mkProjections (env : Environment) (structName : @& Name) (projs : @& List ProjectionInfo) (isClass : Bool) : Except String Environment := arbitrary _
private def addProjections (structName : Name) (projs : List ProjectionInfo) (isClass : Bool) : CommandElabM Unit := do
env ← getEnv;
match mkProjections env structName projs isClass with
| Except.ok env => setEnv env
| Except.error msg => throwError msg
private def mkAuxConstructions (declName : Name) : CommandElabM Unit := do
env ← getEnv;
let hasUnit := env.contains `PUnit;
let hasEq := env.contains `Eq;
let hasHEq := env.contains `HEq;
modifyEnv fun env => mkRecOn env declName;
when hasUnit $ modifyEnv fun env => mkCasesOn env declName;
when (hasUnit && hasEq && hasHEq) $ modifyEnv fun env => mkNoConfusion env declName
private def addDefaults (mctx : MetavarContext) (lctx : LocalContext) (localInsts : LocalInstances)
(defaultAuxDecls : Array (Name × Expr × Expr)) : CommandElabM Unit :=
liftTermElabM none $ withLCtx lctx localInsts do
setMCtx mctx;
defaultAuxDecls.forM fun ⟨declName, type, value⟩ => do
/- The identity function is used as "marker". -/
value ← mkId value;
let zeta := true; -- expand `let-declarations`
_ ← mkAuxDefinition declName type value zeta;
modifyEnv fun env => setReducibilityStatus env declName ReducibilityStatus.reducible;
pure ()
addDefaults lctx defaultAuxDecls
/-
parser! (structureTk <|> classTk) >> declId >> many Term.bracketedBinder >> optional «extends» >> Term.optType >> " := " >> optional structCtor >> structFields
@ -529,7 +531,7 @@ withDeclId declId $ fun name => do
allUserLevelNames ← getLevelNames;
ctor ← expandCtor stx modifiers declName;
fields ← expandFields stx modifiers declName;
r ← runTermElabM declName $ fun scopeVars => Term.elabBinders params $ fun params => elabStructureView {
runTermElabM declName $ fun scopeVars => Term.elabBinders params $ fun params => elabStructureView {
ref := stx,
modifiers := modifiers,
scopeLevelNames := scopeLevelNames,
@ -542,15 +544,7 @@ withDeclId declId $ fun name => do
type := type,
ctor := ctor,
fields := fields
};
let ref := declId;
addDecl r.decl;
addProjections declName r.projInfos isClass;
mkAuxConstructions declName;
applyAttributes declName modifiers.attrs AttributeApplicationTime.afterTypeChecking;
r.projInstances.forM addInstance;
addDefaults r.mctx r.lctx r.localInsts r.defaultAuxDecls;
pure ()
}
end Command
end Elab

View file

@ -34,7 +34,7 @@ withNewMCtxDepth $ do
DiscrTree.mkPath type
@[export lean_add_instance]
def addGlobalInstance (env : Environment) (constName : Name) : IO Environment :=
def addGlobalInstanceImp (env : Environment) (constName : Name) : IO Environment :=
match env.find? constName with
| none => throw $ IO.userError "unknown constant"
| some cinfo => do
@ -42,6 +42,11 @@ match env.find? constName with
(keys, s, _) ← (mkInstanceKey c).toIO {} { env := env } {} {};
pure $ instanceExtension.addEntry s.env { keys := keys, val := c }
def addGlobalInstance {m} [Monad m] [MonadEnv m] [MonadIO m] (declName : Name) : m Unit := do
env ← getEnv;
env ← liftIO $ Meta.addGlobalInstanceImp env declName;
setEnv env
@[init] def registerInstanceAttr : IO Unit :=
registerBuiltinAttribute {
name := `instance,
@ -49,10 +54,8 @@ registerBuiltinAttribute {
add := fun declName args persistent => do
when args.hasArgs $ throwError "invalid attribute 'instance', unexpected argument";
unless persistent $ throwError "invalid attribute 'instance', must be persistent";
env ← getEnv;
env ← ofExcept (addGlobalInstanceOld env declName); -- TODO: delete
env ← liftIO $ addGlobalInstance env declName;
setEnv env
env ← getEnv; env ← ofExcept (addGlobalInstanceOld env declName); setEnv env; -- TODO: delete after we remove old frontend
addGlobalInstance declName
}
end Meta