feat: add attributeExtension

This commit is contained in:
Leonardo de Moura 2020-01-10 19:51:53 -08:00
parent 48600dbbfc
commit 3c8d8c7434
2 changed files with 84 additions and 12 deletions

View file

@ -60,26 +60,94 @@ initializing ← IO.initializing;
unless initializing $ throw (IO.userError ("failed to register attribute, attributes can only be registered during initialization"));
attributeMapRef.modify (fun m => m.insert attr.name attr)
structure AttributeExtensionState :=
(newEntries : List Name := [])
(map : PersistentHashMap Name AttributeImpl)
abbrev AttributeExtension := PersistentEnvExtension Name (Name × AttributeImpl) AttributeExtensionState
instance AttributeExtensionState.inhabited : Inhabited AttributeExtensionState := ⟨{ map := {} }⟩
private def AttributeExtension.mkInitial : IO AttributeExtensionState := do
map ← attributeMapRef.get;
pure { map := map }
unsafe def mkAttributeImplOfConstantUnsafe (env : Environment) (declName : Name) : Except String AttributeImpl :=
match env.find? declName with
| none => throw ("unknow constant '" ++ toString declName ++ "'")
| some info =>
match info.type with
| Expr.const `Lean.AttributeImpl _ _ => env.evalConst AttributeImpl declName
| _ => throw ("unexpected attribute implementation type at '" ++ toString declName ++ "' (`AttributeImpl` expected")
@[implementedBy mkAttributeImplOfConstantUnsafe]
constant mkAttributeImplOfConstant (env : Environment) (declName : Name) : Except String AttributeImpl := arbitrary _
private def AttributeExtension.addImported (env : Environment) (es : Array (Array Name)) : IO AttributeExtensionState := do
map ← attributeMapRef.get;
map ← es.foldlM
(fun map entries =>
entries.foldlM
(fun (map : PersistentHashMap Name AttributeImpl) declName => do
attrImpl ← IO.ofExcept $ mkAttributeImplOfConstant env declName;
pure $ map.insert attrImpl.name attrImpl)
map)
map;
pure { map := map }
private def AttributeExtension.addEntry (s : AttributeExtensionState) (e : Name × AttributeImpl) : AttributeExtensionState :=
{ map := s.map.insert e.2.name e.2, newEntries := e.1 :: s.newEntries, .. s }
def mkAttributeExtension : IO AttributeExtension :=
registerPersistentEnvExtension {
name := `attrExt,
mkInitial := AttributeExtension.mkInitial,
addImportedFn := AttributeExtension.addImported,
addEntryFn := AttributeExtension.addEntry,
exportEntriesFn := fun s => s.newEntries.reverse.toArray,
statsFn := fun s => format "number of local entries: " ++ format s.newEntries.length
}
@[init mkAttributeExtension]
def attributeExtension : AttributeExtension := arbitrary _
/- Return true iff `n` is the name of a registered attribute. -/
@[export lean_is_attribute]
def isAttribute (n : Name) : IO Bool := do
def isBuiltinAttribute (n : Name) : IO Bool := do
m ← attributeMapRef.get; pure (m.contains n)
/- Return the name of all registered attributes. -/
def getAttributeNames : IO (List Name) := do
def getBuiltinAttributeNames : IO (List Name) := do
m ← attributeMapRef.get; pure $ m.foldl (fun r n _ => n::r) []
def getAttributeImpl (attrName : Name) : IO AttributeImpl := do
def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
m ← attributeMapRef.get;
match m.find? attrName with
| some attr => pure attr
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))
@[export lean_attribute_application_time]
def attributeApplicationTime (n : Name) : IO AttributeApplicationTime := do
attr ← getAttributeImpl n;
def getBuiltinAttributeApplicationTime (n : Name) : IO AttributeApplicationTime := do
attr ← getBuiltinAttributeImpl n;
pure attr.applicationTime
def isAttribute (env : Environment) (attrName : Name) : Bool :=
(attributeExtension.getState env).map.contains attrName
def getAttributeNames (env : Environment) : List Name :=
let m := (attributeExtension.getState env).map;
m.foldl (fun r n _ => n::r) []
def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
let m := (attributeExtension.getState env).map;
match m.find? attrName with
| some attr => pure attr
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")
def registerAttribute (env : Environment) (attrDeclName : Name) : Except String Environment := do
attrImpl ← mkAttributeImplOfConstant env attrDeclName;
pure $ attributeExtension.addEntry env (attrDeclName, attrImpl)
namespace Environment
/- Add attribute `attr` to declaration `decl` with arguments `args`. If `persistent == true`, then attribute is saved on .olean file.
@ -89,7 +157,7 @@ namespace Environment
- `args` is not valid for `attr`. -/
@[export lean_add_attribute]
def addAttribute (env : Environment) (decl : Name) (attrName : Name) (args : Syntax := Syntax.missing) (persistent := true) : IO Environment := do
attr ← getAttributeImpl attrName;
attr ← IO.ofExcept $ getAttributeImpl env attrName;
attr.add env decl args persistent
/-

View file

@ -60,7 +60,8 @@ let nameStx := stx.getArg 0;
attrName ← match nameStx.isIdOrAtom? with
| none => throwError nameStx "identifier expected"
| some str => pure $ mkNameSimple str;
unlessM (liftIO stx (isAttribute attrName)) $
env ← getEnv;
unless (isAttribute env attrName) $
throwError stx ("unknown attribute [" ++ attrName ++ "]");
let args := stx.getArg 1;
pure { name := attrName, args := args }
@ -117,11 +118,14 @@ match modifiers.visibility with
def applyAttributes (ref : Syntax) (declName : Name) (attrs : Array Attribute) (applicationTime : AttributeApplicationTime) : CommandElabM Unit :=
attrs.forM $ fun attr => do
attrImpl ← liftIO ref $ getAttributeImpl attr.name;
when (attrImpl.applicationTime == applicationTime) $ do
env ← getEnv;
env ← liftIO ref $ attrImpl.add env declName attr.args true;
setEnv env
env ← getEnv;
match getAttributeImpl env attr.name with
| Except.error errMsg => throwError ref errMsg
| Except.ok attrImpl =>
when (attrImpl.applicationTime == applicationTime) $ do
env ← getEnv;
env ← liftIO ref $ attrImpl.add env declName attr.args true;
setEnv env
end Command
end Elab