feat: user-defined environment extensions
New test demonstrates how to use them. The user-defined extensions cannot be used in the same file where they were declared because the `initialize` commands are only executed when we import the modules containing them. TODO: user-defined attributes.
This commit is contained in:
parent
42561bb93f
commit
cdd1dbbb36
9 changed files with 206 additions and 68 deletions
|
|
@ -12,6 +12,13 @@ import Lean.Util.FindExpr
|
|||
import Lean.Util.Profile
|
||||
|
||||
namespace Lean
|
||||
builtin_initialize importingRef : IO.Ref Bool ← IO.mkRef false
|
||||
|
||||
/- True while modules are being imported. We use this flag to test check whether environment extensions are registered only
|
||||
during initialization (builtin ones), and importing (user defined ones). -/
|
||||
def importing : IO Bool :=
|
||||
importingRef.get
|
||||
|
||||
/- Opaque environment extension state. -/
|
||||
constant EnvExtensionStateSpec : PointedType.{0}
|
||||
def EnvExtensionState : Type := EnvExtensionStateSpec.type
|
||||
|
|
@ -145,18 +152,20 @@ structure EnvExtensionInterface where
|
|||
registerExt {σ} (mkInitial : IO σ) : IO (ext σ)
|
||||
setState {σ} (e : ext σ) (env : Environment) : σ → Environment
|
||||
modifyState {σ} (e : ext σ) (env : Environment) : (σ → σ) → Environment
|
||||
getState {σ} (e : ext σ) (env : Environment) : σ
|
||||
getState {σ} [Inhabited σ] (e : ext σ) (env : Environment) : σ
|
||||
mkInitialExtStates : IO (Array EnvExtensionState)
|
||||
ensureExtensionsSize : Environment → IO Environment
|
||||
|
||||
instance : Inhabited EnvExtensionInterface where
|
||||
default := {
|
||||
ext := id,
|
||||
inhabitedExt := id,
|
||||
registerExt := fun mk => mk,
|
||||
setState := fun _ env _ => env,
|
||||
modifyState := fun _ env _ => env,
|
||||
getState := fun ext _ => ext,
|
||||
mkInitialExtStates := pure #[]
|
||||
ext := id
|
||||
inhabitedExt := id
|
||||
ensureExtensionsSize := fun env => pure env
|
||||
registerExt := fun mk => mk
|
||||
setState := fun _ env _ => env
|
||||
modifyState := fun _ env _ => env
|
||||
getState := fun ext _ => ext
|
||||
mkInitialExtStates := pure #[]
|
||||
}
|
||||
|
||||
/- Unsafe implementation of `EnvExtensionInterface` -/
|
||||
|
|
@ -170,23 +179,53 @@ structure Ext (σ : Type) where
|
|||
private def mkEnvExtensionsRef : IO (IO.Ref (Array (Ext EnvExtensionState))) := IO.mkRef #[]
|
||||
@[builtinInit mkEnvExtensionsRef] private constant envExtensionsRef : IO.Ref (Array (Ext EnvExtensionState))
|
||||
|
||||
/--
|
||||
User-defined environment extensions are declared using the `initialize` command.
|
||||
This command is just syntax sugar for the `init` attribute.
|
||||
When we `import` lean modules, the vector stored at `envExtensionsRef` may increase in size because of
|
||||
user-defined environment extensions. When this happens, we must adjust the size of the `env.extensions`.
|
||||
This method is invoked when processing `import`s.
|
||||
-/
|
||||
partial def ensureExtensionsArraySize (env : Environment) : IO Environment := do
|
||||
loop env.extensions.size env
|
||||
where
|
||||
loop (i : Nat) (env : Environment) : IO Environment := do
|
||||
let envExtensions ← envExtensionsRef.get
|
||||
if h : i < envExtensions.size then
|
||||
let s ← envExtensions[i].mkInitial
|
||||
let env := { env with extensions := env.extensions.push s }
|
||||
loop (i + 1) env
|
||||
else
|
||||
return env
|
||||
|
||||
private def invalidExtMsg := "invalid environment extension has been accessed"
|
||||
|
||||
unsafe def setState {σ} (ext : Ext σ) (env : Environment) (s : σ) : Environment :=
|
||||
{ env with extensions := env.extensions.set! ext.idx (unsafeCast s) }
|
||||
if h : ext.idx < env.extensions.size then
|
||||
{ env with extensions := env.extensions.set ⟨ext.idx, h⟩ (unsafeCast s) }
|
||||
else
|
||||
panic! invalidExtMsg
|
||||
|
||||
@[inline] unsafe def modifyState {σ : Type} (ext : Ext σ) (env : Environment) (f : σ → σ) : Environment :=
|
||||
{ env with
|
||||
extensions := env.extensions.modify ext.idx fun s =>
|
||||
let s : σ := unsafeCast s;
|
||||
let s : σ := f s;
|
||||
unsafeCast s }
|
||||
if ext.idx < env.extensions.size then
|
||||
{ env with
|
||||
extensions := env.extensions.modify ext.idx fun s =>
|
||||
let s : σ := unsafeCast s
|
||||
let s : σ := f s
|
||||
unsafeCast s }
|
||||
else
|
||||
panic! invalidExtMsg
|
||||
|
||||
unsafe def getState {σ} (ext : Ext σ) (env : Environment) : σ :=
|
||||
let s : EnvExtensionState := env.extensions.get! ext.idx
|
||||
unsafeCast s
|
||||
unsafe def getState {σ} [Inhabited σ] (ext : Ext σ) (env : Environment) : σ :=
|
||||
if h : ext.idx < env.extensions.size then
|
||||
let s : EnvExtensionState := env.extensions.get ⟨ext.idx, h⟩
|
||||
unsafeCast s
|
||||
else
|
||||
panic! invalidExtMsg
|
||||
|
||||
unsafe def registerExt {σ} (mkInitial : IO σ) : IO (Ext σ) := do
|
||||
let initializing ← IO.initializing
|
||||
unless initializing do throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
|
||||
unless (← IO.initializing) || (← importing) do
|
||||
throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
|
||||
let exts ← envExtensionsRef.get
|
||||
let idx := exts.size
|
||||
let ext : Ext σ := {
|
||||
|
|
@ -201,13 +240,14 @@ def mkInitialExtStates : IO (Array EnvExtensionState) := do
|
|||
exts.mapM fun ext => ext.mkInitial
|
||||
|
||||
unsafe def imp : EnvExtensionInterface := {
|
||||
ext := Ext,
|
||||
inhabitedExt := fun _ => ⟨arbitrary⟩,
|
||||
registerExt := registerExt,
|
||||
setState := setState,
|
||||
modifyState := modifyState,
|
||||
getState := getState,
|
||||
mkInitialExtStates := mkInitialExtStates
|
||||
ext := Ext
|
||||
ensureExtensionsSize := ensureExtensionsArraySize
|
||||
inhabitedExt := fun _ => ⟨arbitrary⟩
|
||||
registerExt := registerExt
|
||||
setState := setState
|
||||
modifyState := modifyState
|
||||
getState := getState
|
||||
mkInitialExtStates := mkInitialExtStates
|
||||
}
|
||||
|
||||
end EnvExtensionInterfaceUnsafe
|
||||
|
|
@ -217,11 +257,14 @@ constant EnvExtensionInterfaceImp : EnvExtensionInterface
|
|||
|
||||
def EnvExtension (σ : Type) : Type := EnvExtensionInterfaceImp.ext σ
|
||||
|
||||
private def ensureExtensionsArraySize (env : Environment) : IO Environment :=
|
||||
EnvExtensionInterfaceImp.ensureExtensionsSize env
|
||||
|
||||
namespace EnvExtension
|
||||
instance {σ} [s : Inhabited σ] : Inhabited (EnvExtension σ) := EnvExtensionInterfaceImp.inhabitedExt s
|
||||
def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment := EnvExtensionInterfaceImp.setState ext env s
|
||||
def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σ → σ) : Environment := EnvExtensionInterfaceImp.modifyState ext env f
|
||||
def getState {σ : Type} (ext : EnvExtension σ) (env : Environment) : σ := EnvExtensionInterfaceImp.getState ext env
|
||||
def getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ := EnvExtensionInterfaceImp.getState ext env
|
||||
end EnvExtension
|
||||
|
||||
/- Environment extensions can only be registered during initialization.
|
||||
|
|
@ -297,7 +340,7 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ)
|
|||
|
||||
namespace PersistentEnvExtension
|
||||
|
||||
def getModuleEntries {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
|
||||
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
|
||||
(ext.toEnvExtension.getState env).importedEntries.get! m
|
||||
|
||||
def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
|
||||
|
|
@ -305,7 +348,7 @@ def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : En
|
|||
let state := ext.addEntryFn s.state b;
|
||||
{ s with state := state }
|
||||
|
||||
def getState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) : σ :=
|
||||
def getState {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) : σ :=
|
||||
(ext.toEnvExtension.getState env).state
|
||||
|
||||
def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (s : σ) : Environment :=
|
||||
|
|
@ -379,10 +422,10 @@ namespace SimplePersistentEnvExtension
|
|||
instance {α σ : Type} [Inhabited σ] : Inhabited (SimplePersistentEnvExtension α σ) :=
|
||||
inferInstanceAs (Inhabited (PersistentEnvExtension α α (List α × σ)))
|
||||
|
||||
def getEntries {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) : List α :=
|
||||
def getEntries {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment) : List α :=
|
||||
(PersistentEnvExtension.getState ext env).1
|
||||
|
||||
def getState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) : σ :=
|
||||
def getState {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment) : σ :=
|
||||
(PersistentEnvExtension.getState ext env).2
|
||||
|
||||
def setState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) (s : σ) : Environment :=
|
||||
|
|
@ -518,23 +561,36 @@ private partial def getEntriesFor (mod : ModuleData) (extId : Name) (i : Nat) :
|
|||
else
|
||||
#[]
|
||||
|
||||
private def setImportedEntries (env : Environment) (mods : Array ModuleData) : IO Environment := do
|
||||
private def setImportedEntries (env : Environment) (mods : Array ModuleData) (startingAt : Nat := 0) : IO Environment := do
|
||||
let mut env := env
|
||||
let pExtDescrs ← persistentEnvExtensionsRef.get
|
||||
for mod in mods do
|
||||
for extDescr in pExtDescrs do
|
||||
for extDescr in pExtDescrs[startingAt:] do
|
||||
let entries := getEntriesFor mod extDescr.name 0
|
||||
env ← extDescr.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.push entries }
|
||||
return env
|
||||
|
||||
private def finalizePersistentExtensions (env : Environment) (opts : Options) : IO Environment := do
|
||||
let mut env := env
|
||||
let pExtDescrs ← persistentEnvExtensionsRef.get
|
||||
for extDescr in pExtDescrs do
|
||||
let s := extDescr.toEnvExtension.getState env
|
||||
let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts }
|
||||
env ← extDescr.toEnvExtension.setState env { s with state := newState }
|
||||
return env
|
||||
private partial def finalizePersistentExtensions (env : Environment) (mods : Array ModuleData) (opts : Options) : IO Environment := do
|
||||
loop 0 env
|
||||
where
|
||||
loop (i : Nat) (env : Environment) : IO Environment := do
|
||||
-- Recall that the size of the array stored `persistentEnvExtensionRef` may increase when we import user-defined environment extensions.
|
||||
let pExtDescrs ← persistentEnvExtensionsRef.get
|
||||
if h : i < pExtDescrs.size then
|
||||
let extDescr := pExtDescrs[i]
|
||||
let s := extDescr.toEnvExtension.getState env
|
||||
let prevSize := (← persistentEnvExtensionsRef.get).size
|
||||
let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts }
|
||||
let mut env ← extDescr.toEnvExtension.setState env { s with state := newState }
|
||||
env ← ensureExtensionsArraySize env
|
||||
if (← persistentEnvExtensionsRef.get).size > prevSize then
|
||||
-- This branch is executed when `pExtDescrs[i]` is the extension associated with the `init` attribute, and
|
||||
-- a user-defined persistent extension is imported.
|
||||
-- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions.
|
||||
env ← setImportedEntries env mods prevSize
|
||||
loop (i + 1) env
|
||||
else
|
||||
return env
|
||||
|
||||
structure ImportState where
|
||||
moduleNameSet : NameSet := {}
|
||||
|
|
@ -544,34 +600,38 @@ structure ImportState where
|
|||
|
||||
@[export lean_import_modules]
|
||||
partial def importModules (imports : List Import) (opts : Options) (trustLevel : UInt32 := 0) : IO Environment := profileitIO "import" opts do
|
||||
let (_, s) ← importMods imports |>.run {}
|
||||
-- (moduleNames, mods, regions)
|
||||
let mut modIdx : Nat := 0
|
||||
let mut const2ModIdx : HashMap Name ModuleIdx := {}
|
||||
let mut constants : ConstMap := SMap.empty
|
||||
for mod in s.moduleData do
|
||||
for cinfo in mod.constants do
|
||||
const2ModIdx := const2ModIdx.insert cinfo.name modIdx
|
||||
if constants.contains cinfo.name then throw (IO.userError s!"import failed, environment already contains '{cinfo.name}'")
|
||||
constants := constants.insert cinfo.name cinfo
|
||||
modIdx := modIdx + 1
|
||||
constants := constants.switch
|
||||
let exts ← mkInitialExtensionStates
|
||||
let env : Environment := {
|
||||
const2ModIdx := const2ModIdx,
|
||||
constants := constants,
|
||||
extensions := exts,
|
||||
header := {
|
||||
quotInit := !imports.isEmpty, -- We assume `core.lean` initializes quotient module
|
||||
trustLevel := trustLevel,
|
||||
imports := imports.toArray,
|
||||
regions := s.regions,
|
||||
moduleNames := s.moduleNames
|
||||
try
|
||||
importingRef.set true
|
||||
let (_, s) ← importMods imports |>.run {}
|
||||
-- (moduleNames, mods, regions)
|
||||
let mut modIdx : Nat := 0
|
||||
let mut const2ModIdx : HashMap Name ModuleIdx := {}
|
||||
let mut constants : ConstMap := SMap.empty
|
||||
for mod in s.moduleData do
|
||||
for cinfo in mod.constants do
|
||||
const2ModIdx := const2ModIdx.insert cinfo.name modIdx
|
||||
if constants.contains cinfo.name then throw (IO.userError s!"import failed, environment already contains '{cinfo.name}'")
|
||||
constants := constants.insert cinfo.name cinfo
|
||||
modIdx := modIdx + 1
|
||||
constants := constants.switch
|
||||
let exts ← mkInitialExtensionStates
|
||||
let env : Environment := {
|
||||
const2ModIdx := const2ModIdx,
|
||||
constants := constants,
|
||||
extensions := exts,
|
||||
header := {
|
||||
quotInit := !imports.isEmpty, -- We assume `core.lean` initializes quotient module
|
||||
trustLevel := trustLevel,
|
||||
imports := imports.toArray,
|
||||
regions := s.regions,
|
||||
moduleNames := s.moduleNames
|
||||
}
|
||||
}
|
||||
}
|
||||
let env ← setImportedEntries env s.moduleData
|
||||
let env ← finalizePersistentExtensions env opts
|
||||
pure env
|
||||
let env ← setImportedEntries env s.moduleData
|
||||
let env ← finalizePersistentExtensions env s.moduleData opts
|
||||
pure env
|
||||
finally
|
||||
importingRef.set false
|
||||
where
|
||||
importMods : List Import → StateRefT ImportState IO Unit
|
||||
| [] => pure ()
|
||||
|
|
|
|||
|
|
@ -179,3 +179,11 @@ add_test(NAME leanpkgtest_cyclic
|
|||
set -eu
|
||||
export PATH=${LEAN_BIN}:$PATH
|
||||
leanpkg build 2>&1 | grep 'import cycle'")
|
||||
|
||||
add_test(NAME leanpkgtest_user_ext
|
||||
WORKING_DIRECTORY "${LEAN_SOURCE_DIR}/../tests/leanpkg/user_ext"
|
||||
COMMAND bash -c "
|
||||
set -eu
|
||||
export PATH=${LEAN_BIN}:$PATH
|
||||
find . -name '*.olean' -delete
|
||||
leanpkg build | grep 'world, hello, test'")
|
||||
|
|
|
|||
1
tests/leanpkg/user_ext/.gitignore
vendored
Normal file
1
tests/leanpkg/user_ext/.gitignore
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
/build
|
||||
5
tests/leanpkg/user_ext/UserExt.lean
Normal file
5
tests/leanpkg/user_ext/UserExt.lean
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
import UserExt.Tst1
|
||||
import UserExt.Tst2
|
||||
|
||||
show_foo_set
|
||||
show_bla_set
|
||||
23
tests/leanpkg/user_ext/UserExt/BlaExt.lean
Normal file
23
tests/leanpkg/user_ext/UserExt/BlaExt.lean
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import Lean
|
||||
|
||||
open Lean
|
||||
|
||||
initialize blaExtension : SimplePersistentEnvExtension Name NameSet ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
name := `blaExt
|
||||
addEntryFn := NameSet.insert
|
||||
addImportedFn := fun es => mkStateFromImportedEntries NameSet.insert {} es
|
||||
}
|
||||
|
||||
syntax (name := insertBla) "insert_bla " ident : command
|
||||
syntax (name := showBla) "show_bla_set" : command
|
||||
|
||||
open Lean.Elab
|
||||
open Lean.Elab.Command
|
||||
|
||||
@[commandElab insertBla] def elabInsertBla : CommandElab := fun stx => do
|
||||
IO.println s!"inserting {stx[1].getId}"
|
||||
modifyEnv fun env => blaExtension.addEntry env stx[1].getId
|
||||
|
||||
@[commandElab showBla] def elabShowBla : CommandElab := fun stx => do
|
||||
IO.println s!"bla set: {blaExtension.getState (← getEnv) |>.toList}"
|
||||
23
tests/leanpkg/user_ext/UserExt/FooExt.lean
Normal file
23
tests/leanpkg/user_ext/UserExt/FooExt.lean
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import Lean
|
||||
|
||||
open Lean
|
||||
|
||||
initialize fooExtension : SimplePersistentEnvExtension Name NameSet ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
name := `fooExt
|
||||
addEntryFn := NameSet.insert
|
||||
addImportedFn := fun es => mkStateFromImportedEntries NameSet.insert {} es
|
||||
}
|
||||
|
||||
syntax (name := insertFoo) "insert_foo " ident : command
|
||||
syntax (name := showFoo) "show_foo_set" : command
|
||||
|
||||
open Lean.Elab
|
||||
open Lean.Elab.Command
|
||||
|
||||
@[commandElab insertFoo] def elabInsertFoo : CommandElab := fun stx => do
|
||||
IO.println s!"inserting {stx[1].getId}"
|
||||
modifyEnv fun env => fooExtension.addEntry env stx[1].getId
|
||||
|
||||
@[commandElab showFoo] def elabShowFoo : CommandElab := fun stx => do
|
||||
IO.println s!"foo set: {fooExtension.getState (← getEnv) |>.toList}"
|
||||
9
tests/leanpkg/user_ext/UserExt/Tst1.lean
Normal file
9
tests/leanpkg/user_ext/UserExt/Tst1.lean
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
import UserExt.FooExt
|
||||
import UserExt.BlaExt
|
||||
|
||||
insert_foo hello
|
||||
insert_foo world
|
||||
show_foo_set
|
||||
|
||||
insert_bla abc
|
||||
show_bla_set
|
||||
6
tests/leanpkg/user_ext/UserExt/Tst2.lean
Normal file
6
tests/leanpkg/user_ext/UserExt/Tst2.lean
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
import UserExt.BlaExt
|
||||
import UserExt.FooExt
|
||||
|
||||
insert_foo test
|
||||
insert_bla defg
|
||||
insert_bla hij
|
||||
3
tests/leanpkg/user_ext/leanpkg.toml
Normal file
3
tests/leanpkg/user_ext/leanpkg.toml
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
[package]
|
||||
name = "UserExt"
|
||||
version = "0.1"
|
||||
Loading…
Add table
Reference in a new issue