feat: user-defined attributes

See new test for an example.

closes #513
This commit is contained in:
Leonardo de Moura 2021-07-26 18:20:13 -07:00
parent 0bea52d1b5
commit a77598f7cf
9 changed files with 194 additions and 131 deletions

View file

@ -47,137 +47,9 @@ builtin_initialize attributeMapRef : IO.Ref (PersistentHashMap Name AttributeImp
def registerBuiltinAttribute (attr : AttributeImpl) : IO Unit := do
let m ← attributeMapRef.get
if m.contains attr.name then throw (IO.userError ("invalid builtin attribute declaration, '" ++ toString attr.name ++ "' has already been used"))
unless (← IO.initializing) do throw (IO.userError "failed to register attribute, attributes can only be registered during initialization")
unless (← IO.initializing) || (← importing) do throw (IO.userError "failed to register attribute, attributes can only be registered during initialization")
attributeMapRef.modify fun m => m.insert attr.name attr
abbrev AttributeImplBuilder := List DataValue → Except String AttributeImpl
abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder
builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable ← IO.mkRef {}
def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuilder) : IO Unit := do
let table ← attributeImplBuilderTableRef.get
if table.contains builderId then throw (IO.userError ("attribute implementation builder '" ++ toString builderId ++ "' has already been declared"))
attributeImplBuilderTableRef.modify fun table => table.insert builderId builder
def mkAttributeImplOfBuilder (builderId : Name) (args : List DataValue) : IO AttributeImpl := do
let table ← attributeImplBuilderTableRef.get
match table.find? builderId with
| none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'"))
| some builder => IO.ofExcept $ builder args
inductive AttributeExtensionOLeanEntry where
| decl (declName : Name) -- `declName` has type `AttributeImpl`
| builder (builderId : Name) (args : List DataValue)
structure AttributeExtensionState where
newEntries : List AttributeExtensionOLeanEntry := []
map : PersistentHashMap Name AttributeImpl
deriving Inhabited
abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState
private def AttributeExtension.mkInitial : IO AttributeExtensionState := do
let map ← attributeMapRef.get
pure { map := map }
unsafe def mkAttributeImplOfConstantUnsafe (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl :=
match env.find? declName with
| none => throw ("unknow constant '" ++ toString declName ++ "'")
| some info =>
match info.type with
| Expr.const `Lean.AttributeImpl _ _ => env.evalConst AttributeImpl opts declName
| _ => throw ("unexpected attribute implementation type at '" ++ toString declName ++ "' (`AttributeImpl` expected")
@[implementedBy mkAttributeImplOfConstantUnsafe]
constant mkAttributeImplOfConstant (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl
def mkAttributeImplOfEntry (env : Environment) (opts : Options) (e : AttributeExtensionOLeanEntry) : IO AttributeImpl :=
match e with
| AttributeExtensionOLeanEntry.decl declName => IO.ofExcept $ mkAttributeImplOfConstant env opts declName
| AttributeExtensionOLeanEntry.builder builderId args => mkAttributeImplOfBuilder builderId args
private def AttributeExtension.addImported (es : Array (Array AttributeExtensionOLeanEntry)) : ImportM AttributeExtensionState := do
let ctx ← read
let map ← attributeMapRef.get
let map ← es.foldlM
(fun map entries =>
entries.foldlM
(fun (map : PersistentHashMap Name AttributeImpl) entry => do
let attrImpl ← liftM $ mkAttributeImplOfEntry ctx.env ctx.opts entry
pure $ map.insert attrImpl.name attrImpl)
map)
map
pure { map := map }
private def addAttrEntry (s : AttributeExtensionState) (e : AttributeExtensionOLeanEntry × AttributeImpl) : AttributeExtensionState :=
{ s with map := s.map.insert e.2.name e.2, newEntries := e.1 :: s.newEntries }
builtin_initialize attributeExtension : AttributeExtension ←
registerPersistentEnvExtension {
name := `attrExt,
mkInitial := AttributeExtension.mkInitial,
addImportedFn := AttributeExtension.addImported,
addEntryFn := addAttrEntry,
exportEntriesFn := fun s => s.newEntries.reverse.toArray,
statsFn := fun s => format "number of local entries: " ++ format s.newEntries.length
}
/- Return true iff `n` is the name of a registered attribute. -/
@[export lean_is_attribute]
def isBuiltinAttribute (n : Name) : IO Bool := do
let m ← attributeMapRef.get; pure (m.contains n)
/- Return the name of all registered attributes. -/
def getBuiltinAttributeNames : IO (List Name) := do
let m ← attributeMapRef.get; pure $ m.foldl (fun r n _ => n::r) []
def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
let m ← attributeMapRef.get
match m.find? attrName with
| some attr => pure attr
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))
@[export lean_attribute_application_time]
def getBuiltinAttributeApplicationTime (n : Name) : IO AttributeApplicationTime := do
let attr ← getBuiltinAttributeImpl n
pure attr.applicationTime
def isAttribute (env : Environment) (attrName : Name) : Bool :=
(attributeExtension.getState env).map.contains attrName
def getAttributeNames (env : Environment) : List Name :=
let m := (attributeExtension.getState env).map
m.foldl (fun r n _ => n::r) []
def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
let m := (attributeExtension.getState env).map
match m.find? attrName with
| some attr => pure attr
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")
def registerAttributeOfDecl (env : Environment) (opts : Options) (attrDeclName : Name) : Except String Environment := do
let attrImpl ← mkAttributeImplOfConstant env opts attrDeclName
if isAttribute env attrImpl.name then
throw ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used")
else
pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.decl attrDeclName, attrImpl)
def registerAttributeOfBuilder (env : Environment) (builderId : Name) (args : List DataValue) : IO Environment := do
let attrImpl ← mkAttributeImplOfBuilder builderId args
if isAttribute env attrImpl.name then
throw (IO.userError ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used"))
else
pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.builder builderId args, attrImpl)
def Attribute.add (declName : Name) (attrName : Name) (stx : Syntax) (kind := AttributeKind.global) : AttrM Unit := do
let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName
attr.add declName stx kind
def Attribute.erase (declName : Name) (attrName : Name) : AttrM Unit := do
let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName
attr.erase declName
/-
Helper methods for decoding the parameters of builtin attributes that are defined before `Lean.Parser`.
We have the following ones:
@ -406,4 +278,148 @@ def setValue {α : Type} (attrs : EnumAttributes α) (env : Environment) (decl :
end EnumAttributes
/-
Attribute extension and builders. We use builders to implement attribute factories for parser categories.
-/
abbrev AttributeImplBuilder := List DataValue → Except String AttributeImpl
abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder
builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable ← IO.mkRef {}
def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuilder) : IO Unit := do
let table ← attributeImplBuilderTableRef.get
if table.contains builderId then throw (IO.userError ("attribute implementation builder '" ++ toString builderId ++ "' has already been declared"))
attributeImplBuilderTableRef.modify fun table => table.insert builderId builder
def mkAttributeImplOfBuilder (builderId : Name) (args : List DataValue) : IO AttributeImpl := do
let table ← attributeImplBuilderTableRef.get
match table.find? builderId with
| none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'"))
| some builder => IO.ofExcept $ builder args
inductive AttributeExtensionOLeanEntry where
| decl (declName : Name) -- `declName` has type `AttributeImpl`
| builder (builderId : Name) (args : List DataValue)
structure AttributeExtensionState where
newEntries : List AttributeExtensionOLeanEntry := []
map : PersistentHashMap Name AttributeImpl
deriving Inhabited
abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState
private def AttributeExtension.mkInitial : IO AttributeExtensionState := do
let map ← attributeMapRef.get
pure { map := map }
unsafe def mkAttributeImplOfConstantUnsafe (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl :=
match env.find? declName with
| none => throw ("unknow constant '" ++ toString declName ++ "'")
| some info =>
match info.type with
| Expr.const `Lean.AttributeImpl _ _ => env.evalConst AttributeImpl opts declName
| _ => throw ("unexpected attribute implementation type at '" ++ toString declName ++ "' (`AttributeImpl` expected")
@[implementedBy mkAttributeImplOfConstantUnsafe]
constant mkAttributeImplOfConstant (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl
def mkAttributeImplOfEntry (env : Environment) (opts : Options) (e : AttributeExtensionOLeanEntry) : IO AttributeImpl :=
match e with
| AttributeExtensionOLeanEntry.decl declName => IO.ofExcept $ mkAttributeImplOfConstant env opts declName
| AttributeExtensionOLeanEntry.builder builderId args => mkAttributeImplOfBuilder builderId args
private def AttributeExtension.addImported (es : Array (Array AttributeExtensionOLeanEntry)) : ImportM AttributeExtensionState := do
let ctx ← read
let map ← attributeMapRef.get
let map ← es.foldlM
(fun map entries =>
entries.foldlM
(fun (map : PersistentHashMap Name AttributeImpl) entry => do
let attrImpl ← liftM $ mkAttributeImplOfEntry ctx.env ctx.opts entry
pure $ map.insert attrImpl.name attrImpl)
map)
map
pure { map := map }
private def addAttrEntry (s : AttributeExtensionState) (e : AttributeExtensionOLeanEntry × AttributeImpl) : AttributeExtensionState :=
{ s with map := s.map.insert e.2.name e.2, newEntries := e.1 :: s.newEntries }
builtin_initialize attributeExtension : AttributeExtension ←
registerPersistentEnvExtension {
name := `attrExt,
mkInitial := AttributeExtension.mkInitial,
addImportedFn := AttributeExtension.addImported,
addEntryFn := addAttrEntry,
exportEntriesFn := fun s => s.newEntries.reverse.toArray,
statsFn := fun s => format "number of local entries: " ++ format s.newEntries.length
}
/- Return true iff `n` is the name of a registered attribute. -/
@[export lean_is_attribute]
def isBuiltinAttribute (n : Name) : IO Bool := do
let m ← attributeMapRef.get; pure (m.contains n)
/- Return the name of all registered attributes. -/
def getBuiltinAttributeNames : IO (List Name) := do
let m ← attributeMapRef.get; pure $ m.foldl (fun r n _ => n::r) []
def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do
let m ← attributeMapRef.get
match m.find? attrName with
| some attr => pure attr
| none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'"))
@[export lean_attribute_application_time]
def getBuiltinAttributeApplicationTime (n : Name) : IO AttributeApplicationTime := do
let attr ← getBuiltinAttributeImpl n
pure attr.applicationTime
def isAttribute (env : Environment) (attrName : Name) : Bool :=
(attributeExtension.getState env).map.contains attrName
def getAttributeNames (env : Environment) : List Name :=
let m := (attributeExtension.getState env).map
m.foldl (fun r n _ => n::r) []
def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl :=
let m := (attributeExtension.getState env).map
match m.find? attrName with
| some attr => pure attr
| none => throw ("unknown attribute '" ++ toString attrName ++ "'")
def registerAttributeOfDecl (env : Environment) (opts : Options) (attrDeclName : Name) : Except String Environment := do
let attrImpl ← mkAttributeImplOfConstant env opts attrDeclName
if isAttribute env attrImpl.name then
throw ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used")
else
pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.decl attrDeclName, attrImpl)
def registerAttributeOfBuilder (env : Environment) (builderId : Name) (args : List DataValue) : IO Environment := do
let attrImpl ← mkAttributeImplOfBuilder builderId args
if isAttribute env attrImpl.name then
throw (IO.userError ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used"))
else
pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.builder builderId args, attrImpl)
def Attribute.add (declName : Name) (attrName : Name) (stx : Syntax) (kind := AttributeKind.global) : AttrM Unit := do
let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName
attr.add declName stx kind
def Attribute.erase (declName : Name) (attrName : Name) : AttrM Unit := do
let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName
attr.erase declName
builtin_initialize
-- See comment at `updateEnvAttributesRef`
updateEnvAttributesRef.set fun env => do
let map ← attributeMapRef.get
let s ← attributeExtension.getState env
let s := map.foldl (init := s) fun s attrName attrImpl =>
if s.map.contains attrName then
s
else
{ s with map := s.map.insert attrName attrImpl }
return attributeExtension.setState env s
end Lean

View file

@ -49,8 +49,9 @@ unsafe def registerInitAttrUnsafe (attrName : Name) (runAfterImport : Bool) : IO
for modEntries in entries do
for (decl, initDecl) in modEntries do
if initDecl.isAnonymous then
_ ← IO.ofExcept $ ctx.env.evalConst (IO Unit) ctx.opts decl
else runInit ctx.env ctx.opts decl initDecl
discard <| IO.ofExcept $ ctx.env.evalConst (IO Unit) ctx.opts decl
else
runInit ctx.env ctx.opts decl initDecl
}
@[implementedBy registerInitAttrUnsafe]

View file

@ -570,6 +570,17 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st
env ← extDescr.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.push entries }
return env
/-
"Forward declaration" needed for updating the attribute table with user-defined attributes.
User-defined attributes are declared using the `initialize` command. The `initialize` command is just syntax sugar for the `init` attribute.
The `init` attribute is initialized after the `attributeExtension` is initialized. We cannot change the order since the `init` attribute is an attribute,
and requires this extension.
The `attributeExtension` initializer uses `attributeMapRef` to initialize the attribute mapping.
When we a new user-defined attribute declaration is imported, `attributeMapRef` is updated.
Later, we set this method with code that adds the user-defined attributes that were imported after we initialized `attributeExtension`.
-/
builtin_initialize updateEnvAttributesRef : IO.Ref (Environment → IO Environment) ← IO.mkRef (fun env => pure env)
private partial def finalizePersistentExtensions (env : Environment) (mods : Array ModuleData) (opts : Options) : IO Environment := do
loop 0 env
where
@ -588,6 +599,8 @@ where
-- a user-defined persistent extension is imported.
-- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions.
env ← setImportedEntries env mods prevSize
-- See comment at `updateEnvAttributesRef`
env ← (← updateEnvAttributesRef.get) env
loop (i + 1) env
else
return env

View file

@ -187,3 +187,11 @@ add_test(NAME leanpkgtest_user_ext
export PATH=${LEAN_BIN}:$PATH
find . -name '*.olean' -delete
leanpkg build | grep 'world, hello, test'")
add_test(NAME leanpkgtest_user_attr
WORKING_DIRECTORY "${LEAN_SOURCE_DIR}/../tests/leanpkg/user_attr"
COMMAND bash -c "
set -eu
export PATH=${LEAN_BIN}:$PATH
find . -name '*.olean' -delete
leanpkg build")

1
tests/leanpkg/user_attr/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/build

View file

@ -0,0 +1,12 @@
import UserAttr.Tst
open Lean
def tst : MetaM Unit := do
let env ← getEnv
assert! (blaAttr.hasTag env `f)
assert! (blaAttr.hasTag env `g)
assert! !(blaAttr.hasTag env `id)
pure ()
#eval tst

View file

@ -0,0 +1,5 @@
import Lean
open Lean
initialize blaAttr : TagAttribute ← registerTagAttribute `bla "simple user defined attribute"

View file

@ -0,0 +1,4 @@
import UserAttr.BlaAttr
@[bla] def f (x : Nat) := x + 2
@[bla] def g (x : Nat) := x + 1

View file

@ -0,0 +1,3 @@
[package]
name = "UserAttr"
version = "0.1"