feat: user-defined environment extensions

New test demonstrates how to use them.
The user-defined extensions cannot be used in the same file where they
were declared because the `initialize` commands are only executed when
we import the modules containing them.

TODO: user-defined attributes.
This commit is contained in:
Leonardo de Moura 2021-07-26 16:15:37 -07:00
parent 42561bb93f
commit cdd1dbbb36
9 changed files with 206 additions and 68 deletions

View file

@ -12,6 +12,13 @@ import Lean.Util.FindExpr
import Lean.Util.Profile
namespace Lean
builtin_initialize importingRef : IO.Ref Bool ← IO.mkRef false
/- True while modules are being imported. We use this flag to test check whether environment extensions are registered only
during initialization (builtin ones), and importing (user defined ones). -/
def importing : IO Bool :=
importingRef.get
/- Opaque environment extension state. -/
constant EnvExtensionStateSpec : PointedType.{0}
def EnvExtensionState : Type := EnvExtensionStateSpec.type
@ -145,18 +152,20 @@ structure EnvExtensionInterface where
registerExt {σ} (mkInitial : IO σ) : IO (ext σ)
setState {σ} (e : ext σ) (env : Environment) : σ → Environment
modifyState {σ} (e : ext σ) (env : Environment) : (σσ) → Environment
getState {σ} (e : ext σ) (env : Environment) : σ
getState {σ} [Inhabited σ] (e : ext σ) (env : Environment) : σ
mkInitialExtStates : IO (Array EnvExtensionState)
ensureExtensionsSize : Environment → IO Environment
instance : Inhabited EnvExtensionInterface where
default := {
ext := id,
inhabitedExt := id,
registerExt := fun mk => mk,
setState := fun _ env _ => env,
modifyState := fun _ env _ => env,
getState := fun ext _ => ext,
mkInitialExtStates := pure #[]
ext := id
inhabitedExt := id
ensureExtensionsSize := fun env => pure env
registerExt := fun mk => mk
setState := fun _ env _ => env
modifyState := fun _ env _ => env
getState := fun ext _ => ext
mkInitialExtStates := pure #[]
}
/- Unsafe implementation of `EnvExtensionInterface` -/
@ -170,23 +179,53 @@ structure Ext (σ : Type) where
private def mkEnvExtensionsRef : IO (IO.Ref (Array (Ext EnvExtensionState))) := IO.mkRef #[]
@[builtinInit mkEnvExtensionsRef] private constant envExtensionsRef : IO.Ref (Array (Ext EnvExtensionState))
/--
User-defined environment extensions are declared using the `initialize` command.
This command is just syntax sugar for the `init` attribute.
When we `import` lean modules, the vector stored at `envExtensionsRef` may increase in size because of
user-defined environment extensions. When this happens, we must adjust the size of the `env.extensions`.
This method is invoked when processing `import`s.
-/
partial def ensureExtensionsArraySize (env : Environment) : IO Environment := do
loop env.extensions.size env
where
loop (i : Nat) (env : Environment) : IO Environment := do
let envExtensions ← envExtensionsRef.get
if h : i < envExtensions.size then
let s ← envExtensions[i].mkInitial
let env := { env with extensions := env.extensions.push s }
loop (i + 1) env
else
return env
private def invalidExtMsg := "invalid environment extension has been accessed"
unsafe def setState {σ} (ext : Ext σ) (env : Environment) (s : σ) : Environment :=
{ env with extensions := env.extensions.set! ext.idx (unsafeCast s) }
if h : ext.idx < env.extensions.size then
{ env with extensions := env.extensions.set ⟨ext.idx, h⟩ (unsafeCast s) }
else
panic! invalidExtMsg
@[inline] unsafe def modifyState {σ : Type} (ext : Ext σ) (env : Environment) (f : σσ) : Environment :=
{ env with
extensions := env.extensions.modify ext.idx fun s =>
let s : σ := unsafeCast s;
let s : σ := f s;
unsafeCast s }
if ext.idx < env.extensions.size then
{ env with
extensions := env.extensions.modify ext.idx fun s =>
let s : σ := unsafeCast s
let s : σ := f s
unsafeCast s }
else
panic! invalidExtMsg
unsafe def getState {σ} (ext : Ext σ) (env : Environment) : σ :=
let s : EnvExtensionState := env.extensions.get! ext.idx
unsafeCast s
unsafe def getState {σ} [Inhabited σ] (ext : Ext σ) (env : Environment) : σ :=
if h : ext.idx < env.extensions.size then
let s : EnvExtensionState := env.extensions.get ⟨ext.idx, h⟩
unsafeCast s
else
panic! invalidExtMsg
unsafe def registerExt {σ} (mkInitial : IO σ) : IO (Ext σ) := do
let initializing ← IO.initializing
unless initializing do throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
unless (← IO.initializing) || (← importing) do
throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
let exts ← envExtensionsRef.get
let idx := exts.size
let ext : Ext σ := {
@ -201,13 +240,14 @@ def mkInitialExtStates : IO (Array EnvExtensionState) := do
exts.mapM fun ext => ext.mkInitial
unsafe def imp : EnvExtensionInterface := {
ext := Ext,
inhabitedExt := fun _ => ⟨arbitrary⟩,
registerExt := registerExt,
setState := setState,
modifyState := modifyState,
getState := getState,
mkInitialExtStates := mkInitialExtStates
ext := Ext
ensureExtensionsSize := ensureExtensionsArraySize
inhabitedExt := fun _ => ⟨arbitrary⟩
registerExt := registerExt
setState := setState
modifyState := modifyState
getState := getState
mkInitialExtStates := mkInitialExtStates
}
end EnvExtensionInterfaceUnsafe
@ -217,11 +257,14 @@ constant EnvExtensionInterfaceImp : EnvExtensionInterface
def EnvExtension (σ : Type) : Type := EnvExtensionInterfaceImp.ext σ
private def ensureExtensionsArraySize (env : Environment) : IO Environment :=
EnvExtensionInterfaceImp.ensureExtensionsSize env
namespace EnvExtension
instance {σ} [s : Inhabited σ] : Inhabited (EnvExtension σ) := EnvExtensionInterfaceImp.inhabitedExt s
def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment := EnvExtensionInterfaceImp.setState ext env s
def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σσ) : Environment := EnvExtensionInterfaceImp.modifyState ext env f
def getState {σ : Type} (ext : EnvExtension σ) (env : Environment) : σ := EnvExtensionInterfaceImp.getState ext env
def getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ := EnvExtensionInterfaceImp.getState ext env
end EnvExtension
/- Environment extensions can only be registered during initialization.
@ -297,7 +340,7 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ)
namespace PersistentEnvExtension
def getModuleEntries {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
(ext.toEnvExtension.getState env).importedEntries.get! m
def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
@ -305,7 +348,7 @@ def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : En
let state := ext.addEntryFn s.state b;
{ s with state := state }
def getState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) : σ :=
def getState {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) : σ :=
(ext.toEnvExtension.getState env).state
def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (s : σ) : Environment :=
@ -379,10 +422,10 @@ namespace SimplePersistentEnvExtension
instance {α σ : Type} [Inhabited σ] : Inhabited (SimplePersistentEnvExtension α σ) :=
inferInstanceAs (Inhabited (PersistentEnvExtension α α (List α × σ)))
def getEntries {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) : List α :=
def getEntries {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment) : List α :=
(PersistentEnvExtension.getState ext env).1
def getState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) : σ :=
def getState {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment) : σ :=
(PersistentEnvExtension.getState ext env).2
def setState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) (s : σ) : Environment :=
@ -518,23 +561,36 @@ private partial def getEntriesFor (mod : ModuleData) (extId : Name) (i : Nat) :
else
#[]
private def setImportedEntries (env : Environment) (mods : Array ModuleData) : IO Environment := do
private def setImportedEntries (env : Environment) (mods : Array ModuleData) (startingAt : Nat := 0) : IO Environment := do
let mut env := env
let pExtDescrs ← persistentEnvExtensionsRef.get
for mod in mods do
for extDescr in pExtDescrs do
for extDescr in pExtDescrs[startingAt:] do
let entries := getEntriesFor mod extDescr.name 0
env ← extDescr.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.push entries }
return env
private def finalizePersistentExtensions (env : Environment) (opts : Options) : IO Environment := do
let mut env := env
let pExtDescrs ← persistentEnvExtensionsRef.get
for extDescr in pExtDescrs do
let s := extDescr.toEnvExtension.getState env
let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts }
env ← extDescr.toEnvExtension.setState env { s with state := newState }
return env
private partial def finalizePersistentExtensions (env : Environment) (mods : Array ModuleData) (opts : Options) : IO Environment := do
loop 0 env
where
loop (i : Nat) (env : Environment) : IO Environment := do
-- Recall that the size of the array stored `persistentEnvExtensionRef` may increase when we import user-defined environment extensions.
let pExtDescrs ← persistentEnvExtensionsRef.get
if h : i < pExtDescrs.size then
let extDescr := pExtDescrs[i]
let s := extDescr.toEnvExtension.getState env
let prevSize := (← persistentEnvExtensionsRef.get).size
let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts }
let mut env ← extDescr.toEnvExtension.setState env { s with state := newState }
env ← ensureExtensionsArraySize env
if (← persistentEnvExtensionsRef.get).size > prevSize then
-- This branch is executed when `pExtDescrs[i]` is the extension associated with the `init` attribute, and
-- a user-defined persistent extension is imported.
-- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions.
env ← setImportedEntries env mods prevSize
loop (i + 1) env
else
return env
structure ImportState where
moduleNameSet : NameSet := {}
@ -544,34 +600,38 @@ structure ImportState where
@[export lean_import_modules]
partial def importModules (imports : List Import) (opts : Options) (trustLevel : UInt32 := 0) : IO Environment := profileitIO "import" opts do
let (_, s) ← importMods imports |>.run {}
-- (moduleNames, mods, regions)
let mut modIdx : Nat := 0
let mut const2ModIdx : HashMap Name ModuleIdx := {}
let mut constants : ConstMap := SMap.empty
for mod in s.moduleData do
for cinfo in mod.constants do
const2ModIdx := const2ModIdx.insert cinfo.name modIdx
if constants.contains cinfo.name then throw (IO.userError s!"import failed, environment already contains '{cinfo.name}'")
constants := constants.insert cinfo.name cinfo
modIdx := modIdx + 1
constants := constants.switch
let exts ← mkInitialExtensionStates
let env : Environment := {
const2ModIdx := const2ModIdx,
constants := constants,
extensions := exts,
header := {
quotInit := !imports.isEmpty, -- We assume `core.lean` initializes quotient module
trustLevel := trustLevel,
imports := imports.toArray,
regions := s.regions,
moduleNames := s.moduleNames
try
importingRef.set true
let (_, s) ← importMods imports |>.run {}
-- (moduleNames, mods, regions)
let mut modIdx : Nat := 0
let mut const2ModIdx : HashMap Name ModuleIdx := {}
let mut constants : ConstMap := SMap.empty
for mod in s.moduleData do
for cinfo in mod.constants do
const2ModIdx := const2ModIdx.insert cinfo.name modIdx
if constants.contains cinfo.name then throw (IO.userError s!"import failed, environment already contains '{cinfo.name}'")
constants := constants.insert cinfo.name cinfo
modIdx := modIdx + 1
constants := constants.switch
let exts ← mkInitialExtensionStates
let env : Environment := {
const2ModIdx := const2ModIdx,
constants := constants,
extensions := exts,
header := {
quotInit := !imports.isEmpty, -- We assume `core.lean` initializes quotient module
trustLevel := trustLevel,
imports := imports.toArray,
regions := s.regions,
moduleNames := s.moduleNames
}
}
}
let env ← setImportedEntries env s.moduleData
let env ← finalizePersistentExtensions env opts
pure env
let env ← setImportedEntries env s.moduleData
let env ← finalizePersistentExtensions env s.moduleData opts
pure env
finally
importingRef.set false
where
importMods : List Import → StateRefT ImportState IO Unit
| [] => pure ()

View file

@ -179,3 +179,11 @@ add_test(NAME leanpkgtest_cyclic
set -eu
export PATH=${LEAN_BIN}:$PATH
leanpkg build 2>&1 | grep 'import cycle'")
add_test(NAME leanpkgtest_user_ext
WORKING_DIRECTORY "${LEAN_SOURCE_DIR}/../tests/leanpkg/user_ext"
COMMAND bash -c "
set -eu
export PATH=${LEAN_BIN}:$PATH
find . -name '*.olean' -delete
leanpkg build | grep 'world, hello, test'")

1
tests/leanpkg/user_ext/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/build

View file

@ -0,0 +1,5 @@
import UserExt.Tst1
import UserExt.Tst2
show_foo_set
show_bla_set

View file

@ -0,0 +1,23 @@
import Lean
open Lean
initialize blaExtension : SimplePersistentEnvExtension Name NameSet ←
registerSimplePersistentEnvExtension {
name := `blaExt
addEntryFn := NameSet.insert
addImportedFn := fun es => mkStateFromImportedEntries NameSet.insert {} es
}
syntax (name := insertBla) "insert_bla " ident : command
syntax (name := showBla) "show_bla_set" : command
open Lean.Elab
open Lean.Elab.Command
@[commandElab insertBla] def elabInsertBla : CommandElab := fun stx => do
IO.println s!"inserting {stx[1].getId}"
modifyEnv fun env => blaExtension.addEntry env stx[1].getId
@[commandElab showBla] def elabShowBla : CommandElab := fun stx => do
IO.println s!"bla set: {blaExtension.getState (← getEnv) |>.toList}"

View file

@ -0,0 +1,23 @@
import Lean
open Lean
initialize fooExtension : SimplePersistentEnvExtension Name NameSet ←
registerSimplePersistentEnvExtension {
name := `fooExt
addEntryFn := NameSet.insert
addImportedFn := fun es => mkStateFromImportedEntries NameSet.insert {} es
}
syntax (name := insertFoo) "insert_foo " ident : command
syntax (name := showFoo) "show_foo_set" : command
open Lean.Elab
open Lean.Elab.Command
@[commandElab insertFoo] def elabInsertFoo : CommandElab := fun stx => do
IO.println s!"inserting {stx[1].getId}"
modifyEnv fun env => fooExtension.addEntry env stx[1].getId
@[commandElab showFoo] def elabShowFoo : CommandElab := fun stx => do
IO.println s!"foo set: {fooExtension.getState (← getEnv) |>.toList}"

View file

@ -0,0 +1,9 @@
import UserExt.FooExt
import UserExt.BlaExt
insert_foo hello
insert_foo world
show_foo_set
insert_bla abc
show_bla_set

View file

@ -0,0 +1,6 @@
import UserExt.BlaExt
import UserExt.FooExt
insert_foo test
insert_bla defg
insert_bla hij

View file

@ -0,0 +1,3 @@
[package]
name = "UserExt"
version = "0.1"