diff --git a/src/Lean/Attributes.lean b/src/Lean/Attributes.lean index 26d36b60df..fb25418d10 100644 --- a/src/Lean/Attributes.lean +++ b/src/Lean/Attributes.lean @@ -47,137 +47,9 @@ builtin_initialize attributeMapRef : IO.Ref (PersistentHashMap Name AttributeImp def registerBuiltinAttribute (attr : AttributeImpl) : IO Unit := do let m ← attributeMapRef.get if m.contains attr.name then throw (IO.userError ("invalid builtin attribute declaration, '" ++ toString attr.name ++ "' has already been used")) - unless (← IO.initializing) do throw (IO.userError "failed to register attribute, attributes can only be registered during initialization") + unless (← IO.initializing) || (← importing) do throw (IO.userError "failed to register attribute, attributes can only be registered during initialization") attributeMapRef.modify fun m => m.insert attr.name attr -abbrev AttributeImplBuilder := List DataValue → Except String AttributeImpl -abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder - -builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable ← IO.mkRef {} - -def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuilder) : IO Unit := do - let table ← attributeImplBuilderTableRef.get - if table.contains builderId then throw (IO.userError ("attribute implementation builder '" ++ toString builderId ++ "' has already been declared")) - attributeImplBuilderTableRef.modify fun table => table.insert builderId builder - -def mkAttributeImplOfBuilder (builderId : Name) (args : List DataValue) : IO AttributeImpl := do - let table ← attributeImplBuilderTableRef.get - match table.find? builderId with - | none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'")) - | some builder => IO.ofExcept $ builder args - -inductive AttributeExtensionOLeanEntry where - | decl (declName : Name) -- `declName` has type `AttributeImpl` - | builder (builderId : Name) (args : List DataValue) - -structure AttributeExtensionState where - newEntries : List AttributeExtensionOLeanEntry := [] - map : PersistentHashMap Name AttributeImpl - deriving Inhabited - -abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState - -private def AttributeExtension.mkInitial : IO AttributeExtensionState := do - let map ← attributeMapRef.get - pure { map := map } - -unsafe def mkAttributeImplOfConstantUnsafe (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl := - match env.find? declName with - | none => throw ("unknow constant '" ++ toString declName ++ "'") - | some info => - match info.type with - | Expr.const `Lean.AttributeImpl _ _ => env.evalConst AttributeImpl opts declName - | _ => throw ("unexpected attribute implementation type at '" ++ toString declName ++ "' (`AttributeImpl` expected") - -@[implementedBy mkAttributeImplOfConstantUnsafe] -constant mkAttributeImplOfConstant (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl - -def mkAttributeImplOfEntry (env : Environment) (opts : Options) (e : AttributeExtensionOLeanEntry) : IO AttributeImpl := - match e with - | AttributeExtensionOLeanEntry.decl declName => IO.ofExcept $ mkAttributeImplOfConstant env opts declName - | AttributeExtensionOLeanEntry.builder builderId args => mkAttributeImplOfBuilder builderId args - -private def AttributeExtension.addImported (es : Array (Array AttributeExtensionOLeanEntry)) : ImportM AttributeExtensionState := do - let ctx ← read - let map ← attributeMapRef.get - let map ← es.foldlM - (fun map entries => - entries.foldlM - (fun (map : PersistentHashMap Name AttributeImpl) entry => do - let attrImpl ← liftM $ mkAttributeImplOfEntry ctx.env ctx.opts entry - pure $ map.insert attrImpl.name attrImpl) - map) - map - pure { map := map } - -private def addAttrEntry (s : AttributeExtensionState) (e : AttributeExtensionOLeanEntry × AttributeImpl) : AttributeExtensionState := - { s with map := s.map.insert e.2.name e.2, newEntries := e.1 :: s.newEntries } - -builtin_initialize attributeExtension : AttributeExtension ← - registerPersistentEnvExtension { - name := `attrExt, - mkInitial := AttributeExtension.mkInitial, - addImportedFn := AttributeExtension.addImported, - addEntryFn := addAttrEntry, - exportEntriesFn := fun s => s.newEntries.reverse.toArray, - statsFn := fun s => format "number of local entries: " ++ format s.newEntries.length - } - -/- Return true iff `n` is the name of a registered attribute. -/ -@[export lean_is_attribute] -def isBuiltinAttribute (n : Name) : IO Bool := do - let m ← attributeMapRef.get; pure (m.contains n) - -/- Return the name of all registered attributes. -/ -def getBuiltinAttributeNames : IO (List Name) := do - let m ← attributeMapRef.get; pure $ m.foldl (fun r n _ => n::r) [] - -def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do - let m ← attributeMapRef.get - match m.find? attrName with - | some attr => pure attr - | none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'")) - -@[export lean_attribute_application_time] -def getBuiltinAttributeApplicationTime (n : Name) : IO AttributeApplicationTime := do - let attr ← getBuiltinAttributeImpl n - pure attr.applicationTime - -def isAttribute (env : Environment) (attrName : Name) : Bool := - (attributeExtension.getState env).map.contains attrName - -def getAttributeNames (env : Environment) : List Name := - let m := (attributeExtension.getState env).map - m.foldl (fun r n _ => n::r) [] - -def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl := - let m := (attributeExtension.getState env).map - match m.find? attrName with - | some attr => pure attr - | none => throw ("unknown attribute '" ++ toString attrName ++ "'") - -def registerAttributeOfDecl (env : Environment) (opts : Options) (attrDeclName : Name) : Except String Environment := do - let attrImpl ← mkAttributeImplOfConstant env opts attrDeclName - if isAttribute env attrImpl.name then - throw ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used") - else - pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.decl attrDeclName, attrImpl) - -def registerAttributeOfBuilder (env : Environment) (builderId : Name) (args : List DataValue) : IO Environment := do - let attrImpl ← mkAttributeImplOfBuilder builderId args - if isAttribute env attrImpl.name then - throw (IO.userError ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used")) - else - pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.builder builderId args, attrImpl) - -def Attribute.add (declName : Name) (attrName : Name) (stx : Syntax) (kind := AttributeKind.global) : AttrM Unit := do - let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName - attr.add declName stx kind - -def Attribute.erase (declName : Name) (attrName : Name) : AttrM Unit := do - let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName - attr.erase declName - /- Helper methods for decoding the parameters of builtin attributes that are defined before `Lean.Parser`. We have the following ones: @@ -406,4 +278,148 @@ def setValue {α : Type} (attrs : EnumAttributes α) (env : Environment) (decl : end EnumAttributes +/- + Attribute extension and builders. We use builders to implement attribute factories for parser categories. +-/ + +abbrev AttributeImplBuilder := List DataValue → Except String AttributeImpl +abbrev AttributeImplBuilderTable := Std.HashMap Name AttributeImplBuilder + +builtin_initialize attributeImplBuilderTableRef : IO.Ref AttributeImplBuilderTable ← IO.mkRef {} + +def registerAttributeImplBuilder (builderId : Name) (builder : AttributeImplBuilder) : IO Unit := do + let table ← attributeImplBuilderTableRef.get + if table.contains builderId then throw (IO.userError ("attribute implementation builder '" ++ toString builderId ++ "' has already been declared")) + attributeImplBuilderTableRef.modify fun table => table.insert builderId builder + +def mkAttributeImplOfBuilder (builderId : Name) (args : List DataValue) : IO AttributeImpl := do + let table ← attributeImplBuilderTableRef.get + match table.find? builderId with + | none => throw (IO.userError ("unknown attribute implementation builder '" ++ toString builderId ++ "'")) + | some builder => IO.ofExcept $ builder args + +inductive AttributeExtensionOLeanEntry where + | decl (declName : Name) -- `declName` has type `AttributeImpl` + | builder (builderId : Name) (args : List DataValue) + +structure AttributeExtensionState where + newEntries : List AttributeExtensionOLeanEntry := [] + map : PersistentHashMap Name AttributeImpl + deriving Inhabited + +abbrev AttributeExtension := PersistentEnvExtension AttributeExtensionOLeanEntry (AttributeExtensionOLeanEntry × AttributeImpl) AttributeExtensionState + +private def AttributeExtension.mkInitial : IO AttributeExtensionState := do + let map ← attributeMapRef.get + pure { map := map } + +unsafe def mkAttributeImplOfConstantUnsafe (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl := + match env.find? declName with + | none => throw ("unknow constant '" ++ toString declName ++ "'") + | some info => + match info.type with + | Expr.const `Lean.AttributeImpl _ _ => env.evalConst AttributeImpl opts declName + | _ => throw ("unexpected attribute implementation type at '" ++ toString declName ++ "' (`AttributeImpl` expected") + +@[implementedBy mkAttributeImplOfConstantUnsafe] +constant mkAttributeImplOfConstant (env : Environment) (opts : Options) (declName : Name) : Except String AttributeImpl + +def mkAttributeImplOfEntry (env : Environment) (opts : Options) (e : AttributeExtensionOLeanEntry) : IO AttributeImpl := + match e with + | AttributeExtensionOLeanEntry.decl declName => IO.ofExcept $ mkAttributeImplOfConstant env opts declName + | AttributeExtensionOLeanEntry.builder builderId args => mkAttributeImplOfBuilder builderId args + +private def AttributeExtension.addImported (es : Array (Array AttributeExtensionOLeanEntry)) : ImportM AttributeExtensionState := do + let ctx ← read + let map ← attributeMapRef.get + let map ← es.foldlM + (fun map entries => + entries.foldlM + (fun (map : PersistentHashMap Name AttributeImpl) entry => do + let attrImpl ← liftM $ mkAttributeImplOfEntry ctx.env ctx.opts entry + pure $ map.insert attrImpl.name attrImpl) + map) + map + pure { map := map } + +private def addAttrEntry (s : AttributeExtensionState) (e : AttributeExtensionOLeanEntry × AttributeImpl) : AttributeExtensionState := + { s with map := s.map.insert e.2.name e.2, newEntries := e.1 :: s.newEntries } + +builtin_initialize attributeExtension : AttributeExtension ← + registerPersistentEnvExtension { + name := `attrExt, + mkInitial := AttributeExtension.mkInitial, + addImportedFn := AttributeExtension.addImported, + addEntryFn := addAttrEntry, + exportEntriesFn := fun s => s.newEntries.reverse.toArray, + statsFn := fun s => format "number of local entries: " ++ format s.newEntries.length + } + +/- Return true iff `n` is the name of a registered attribute. -/ +@[export lean_is_attribute] +def isBuiltinAttribute (n : Name) : IO Bool := do + let m ← attributeMapRef.get; pure (m.contains n) + +/- Return the name of all registered attributes. -/ +def getBuiltinAttributeNames : IO (List Name) := do + let m ← attributeMapRef.get; pure $ m.foldl (fun r n _ => n::r) [] + +def getBuiltinAttributeImpl (attrName : Name) : IO AttributeImpl := do + let m ← attributeMapRef.get + match m.find? attrName with + | some attr => pure attr + | none => throw (IO.userError ("unknown attribute '" ++ toString attrName ++ "'")) + +@[export lean_attribute_application_time] +def getBuiltinAttributeApplicationTime (n : Name) : IO AttributeApplicationTime := do + let attr ← getBuiltinAttributeImpl n + pure attr.applicationTime + +def isAttribute (env : Environment) (attrName : Name) : Bool := + (attributeExtension.getState env).map.contains attrName + +def getAttributeNames (env : Environment) : List Name := + let m := (attributeExtension.getState env).map + m.foldl (fun r n _ => n::r) [] + +def getAttributeImpl (env : Environment) (attrName : Name) : Except String AttributeImpl := + let m := (attributeExtension.getState env).map + match m.find? attrName with + | some attr => pure attr + | none => throw ("unknown attribute '" ++ toString attrName ++ "'") + +def registerAttributeOfDecl (env : Environment) (opts : Options) (attrDeclName : Name) : Except String Environment := do + let attrImpl ← mkAttributeImplOfConstant env opts attrDeclName + if isAttribute env attrImpl.name then + throw ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used") + else + pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.decl attrDeclName, attrImpl) + +def registerAttributeOfBuilder (env : Environment) (builderId : Name) (args : List DataValue) : IO Environment := do + let attrImpl ← mkAttributeImplOfBuilder builderId args + if isAttribute env attrImpl.name then + throw (IO.userError ("invalid builtin attribute declaration, '" ++ toString attrImpl.name ++ "' has already been used")) + else + pure $ attributeExtension.addEntry env (AttributeExtensionOLeanEntry.builder builderId args, attrImpl) + +def Attribute.add (declName : Name) (attrName : Name) (stx : Syntax) (kind := AttributeKind.global) : AttrM Unit := do + let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName + attr.add declName stx kind + +def Attribute.erase (declName : Name) (attrName : Name) : AttrM Unit := do + let attr ← ofExcept <| getAttributeImpl (← getEnv) attrName + attr.erase declName + +builtin_initialize + -- See comment at `updateEnvAttributesRef` + updateEnvAttributesRef.set fun env => do + let map ← attributeMapRef.get + let s ← attributeExtension.getState env + let s := map.foldl (init := s) fun s attrName attrImpl => + if s.map.contains attrName then + s + else + { s with map := s.map.insert attrName attrImpl } + return attributeExtension.setState env s + end Lean diff --git a/src/Lean/Compiler/InitAttr.lean b/src/Lean/Compiler/InitAttr.lean index b292abb6c8..30b4272f8c 100644 --- a/src/Lean/Compiler/InitAttr.lean +++ b/src/Lean/Compiler/InitAttr.lean @@ -49,8 +49,9 @@ unsafe def registerInitAttrUnsafe (attrName : Name) (runAfterImport : Bool) : IO for modEntries in entries do for (decl, initDecl) in modEntries do if initDecl.isAnonymous then - _ ← IO.ofExcept $ ctx.env.evalConst (IO Unit) ctx.opts decl - else runInit ctx.env ctx.opts decl initDecl + discard <| IO.ofExcept $ ctx.env.evalConst (IO Unit) ctx.opts decl + else + runInit ctx.env ctx.opts decl initDecl } @[implementedBy registerInitAttrUnsafe] diff --git a/src/Lean/Environment.lean b/src/Lean/Environment.lean index dc96222dc5..193ab5d8cd 100644 --- a/src/Lean/Environment.lean +++ b/src/Lean/Environment.lean @@ -570,6 +570,17 @@ private def setImportedEntries (env : Environment) (mods : Array ModuleData) (st env ← extDescr.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.push entries } return env +/- + "Forward declaration" needed for updating the attribute table with user-defined attributes. + User-defined attributes are declared using the `initialize` command. The `initialize` command is just syntax sugar for the `init` attribute. + The `init` attribute is initialized after the `attributeExtension` is initialized. We cannot change the order since the `init` attribute is an attribute, + and requires this extension. + The `attributeExtension` initializer uses `attributeMapRef` to initialize the attribute mapping. + When we a new user-defined attribute declaration is imported, `attributeMapRef` is updated. + Later, we set this method with code that adds the user-defined attributes that were imported after we initialized `attributeExtension`. +-/ +builtin_initialize updateEnvAttributesRef : IO.Ref (Environment → IO Environment) ← IO.mkRef (fun env => pure env) + private partial def finalizePersistentExtensions (env : Environment) (mods : Array ModuleData) (opts : Options) : IO Environment := do loop 0 env where @@ -588,6 +599,8 @@ where -- a user-defined persistent extension is imported. -- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions. env ← setImportedEntries env mods prevSize + -- See comment at `updateEnvAttributesRef` + env ← (← updateEnvAttributesRef.get) env loop (i + 1) env else return env diff --git a/src/shell/CMakeLists.txt b/src/shell/CMakeLists.txt index 4949a8729a..b3bdb5718a 100644 --- a/src/shell/CMakeLists.txt +++ b/src/shell/CMakeLists.txt @@ -187,3 +187,11 @@ add_test(NAME leanpkgtest_user_ext export PATH=${LEAN_BIN}:$PATH find . -name '*.olean' -delete leanpkg build | grep 'world, hello, test'") + +add_test(NAME leanpkgtest_user_attr + WORKING_DIRECTORY "${LEAN_SOURCE_DIR}/../tests/leanpkg/user_attr" + COMMAND bash -c " + set -eu + export PATH=${LEAN_BIN}:$PATH + find . -name '*.olean' -delete + leanpkg build") diff --git a/tests/leanpkg/user_attr/.gitignore b/tests/leanpkg/user_attr/.gitignore new file mode 100644 index 0000000000..796b96d1c4 --- /dev/null +++ b/tests/leanpkg/user_attr/.gitignore @@ -0,0 +1 @@ +/build diff --git a/tests/leanpkg/user_attr/UserAttr.lean b/tests/leanpkg/user_attr/UserAttr.lean new file mode 100644 index 0000000000..ec633fea1c --- /dev/null +++ b/tests/leanpkg/user_attr/UserAttr.lean @@ -0,0 +1,12 @@ +import UserAttr.Tst + +open Lean + +def tst : MetaM Unit := do + let env ← getEnv + assert! (blaAttr.hasTag env `f) + assert! (blaAttr.hasTag env `g) + assert! !(blaAttr.hasTag env `id) + pure () + +#eval tst diff --git a/tests/leanpkg/user_attr/UserAttr/BlaAttr.lean b/tests/leanpkg/user_attr/UserAttr/BlaAttr.lean new file mode 100644 index 0000000000..7314fc152b --- /dev/null +++ b/tests/leanpkg/user_attr/UserAttr/BlaAttr.lean @@ -0,0 +1,5 @@ +import Lean + +open Lean + +initialize blaAttr : TagAttribute ← registerTagAttribute `bla "simple user defined attribute" diff --git a/tests/leanpkg/user_attr/UserAttr/Tst.lean b/tests/leanpkg/user_attr/UserAttr/Tst.lean new file mode 100644 index 0000000000..f80db320f8 --- /dev/null +++ b/tests/leanpkg/user_attr/UserAttr/Tst.lean @@ -0,0 +1,4 @@ +import UserAttr.BlaAttr + +@[bla] def f (x : Nat) := x + 2 +@[bla] def g (x : Nat) := x + 1 diff --git a/tests/leanpkg/user_attr/leanpkg.toml b/tests/leanpkg/user_attr/leanpkg.toml new file mode 100644 index 0000000000..155f689776 --- /dev/null +++ b/tests/leanpkg/user_attr/leanpkg.toml @@ -0,0 +1,3 @@ +[package] +name = "UserAttr" +version = "0.1"