diff --git a/src/Lean/Environment.lean b/src/Lean/Environment.lean index 209c4ea06c..dc96222dc5 100644 --- a/src/Lean/Environment.lean +++ b/src/Lean/Environment.lean @@ -12,6 +12,13 @@ import Lean.Util.FindExpr import Lean.Util.Profile namespace Lean +builtin_initialize importingRef : IO.Ref Bool ← IO.mkRef false + +/- True while modules are being imported. We use this flag to test check whether environment extensions are registered only + during initialization (builtin ones), and importing (user defined ones). -/ +def importing : IO Bool := + importingRef.get + /- Opaque environment extension state. -/ constant EnvExtensionStateSpec : PointedType.{0} def EnvExtensionState : Type := EnvExtensionStateSpec.type @@ -145,18 +152,20 @@ structure EnvExtensionInterface where registerExt {σ} (mkInitial : IO σ) : IO (ext σ) setState {σ} (e : ext σ) (env : Environment) : σ → Environment modifyState {σ} (e : ext σ) (env : Environment) : (σ → σ) → Environment - getState {σ} (e : ext σ) (env : Environment) : σ + getState {σ} [Inhabited σ] (e : ext σ) (env : Environment) : σ mkInitialExtStates : IO (Array EnvExtensionState) + ensureExtensionsSize : Environment → IO Environment instance : Inhabited EnvExtensionInterface where default := { - ext := id, - inhabitedExt := id, - registerExt := fun mk => mk, - setState := fun _ env _ => env, - modifyState := fun _ env _ => env, - getState := fun ext _ => ext, - mkInitialExtStates := pure #[] + ext := id + inhabitedExt := id + ensureExtensionsSize := fun env => pure env + registerExt := fun mk => mk + setState := fun _ env _ => env + modifyState := fun _ env _ => env + getState := fun ext _ => ext + mkInitialExtStates := pure #[] } /- Unsafe implementation of `EnvExtensionInterface` -/ @@ -170,23 +179,53 @@ structure Ext (σ : Type) where private def mkEnvExtensionsRef : IO (IO.Ref (Array (Ext EnvExtensionState))) := IO.mkRef #[] @[builtinInit mkEnvExtensionsRef] private constant envExtensionsRef : IO.Ref (Array (Ext EnvExtensionState)) +/-- + User-defined environment extensions are declared using the `initialize` command. + This command is just syntax sugar for the `init` attribute. + When we `import` lean modules, the vector stored at `envExtensionsRef` may increase in size because of + user-defined environment extensions. When this happens, we must adjust the size of the `env.extensions`. + This method is invoked when processing `import`s. +-/ +partial def ensureExtensionsArraySize (env : Environment) : IO Environment := do + loop env.extensions.size env +where + loop (i : Nat) (env : Environment) : IO Environment := do + let envExtensions ← envExtensionsRef.get + if h : i < envExtensions.size then + let s ← envExtensions[i].mkInitial + let env := { env with extensions := env.extensions.push s } + loop (i + 1) env + else + return env + +private def invalidExtMsg := "invalid environment extension has been accessed" + unsafe def setState {σ} (ext : Ext σ) (env : Environment) (s : σ) : Environment := - { env with extensions := env.extensions.set! ext.idx (unsafeCast s) } + if h : ext.idx < env.extensions.size then + { env with extensions := env.extensions.set ⟨ext.idx, h⟩ (unsafeCast s) } + else + panic! invalidExtMsg @[inline] unsafe def modifyState {σ : Type} (ext : Ext σ) (env : Environment) (f : σ → σ) : Environment := - { env with - extensions := env.extensions.modify ext.idx fun s => - let s : σ := unsafeCast s; - let s : σ := f s; - unsafeCast s } + if ext.idx < env.extensions.size then + { env with + extensions := env.extensions.modify ext.idx fun s => + let s : σ := unsafeCast s + let s : σ := f s + unsafeCast s } + else + panic! invalidExtMsg -unsafe def getState {σ} (ext : Ext σ) (env : Environment) : σ := - let s : EnvExtensionState := env.extensions.get! ext.idx - unsafeCast s +unsafe def getState {σ} [Inhabited σ] (ext : Ext σ) (env : Environment) : σ := + if h : ext.idx < env.extensions.size then + let s : EnvExtensionState := env.extensions.get ⟨ext.idx, h⟩ + unsafeCast s + else + panic! invalidExtMsg unsafe def registerExt {σ} (mkInitial : IO σ) : IO (Ext σ) := do - let initializing ← IO.initializing - unless initializing do throw (IO.userError "failed to register environment, extensions can only be registered during initialization") + unless (← IO.initializing) || (← importing) do + throw (IO.userError "failed to register environment, extensions can only be registered during initialization") let exts ← envExtensionsRef.get let idx := exts.size let ext : Ext σ := { @@ -201,13 +240,14 @@ def mkInitialExtStates : IO (Array EnvExtensionState) := do exts.mapM fun ext => ext.mkInitial unsafe def imp : EnvExtensionInterface := { - ext := Ext, - inhabitedExt := fun _ => ⟨arbitrary⟩, - registerExt := registerExt, - setState := setState, - modifyState := modifyState, - getState := getState, - mkInitialExtStates := mkInitialExtStates + ext := Ext + ensureExtensionsSize := ensureExtensionsArraySize + inhabitedExt := fun _ => ⟨arbitrary⟩ + registerExt := registerExt + setState := setState + modifyState := modifyState + getState := getState + mkInitialExtStates := mkInitialExtStates } end EnvExtensionInterfaceUnsafe @@ -217,11 +257,14 @@ constant EnvExtensionInterfaceImp : EnvExtensionInterface def EnvExtension (σ : Type) : Type := EnvExtensionInterfaceImp.ext σ +private def ensureExtensionsArraySize (env : Environment) : IO Environment := + EnvExtensionInterfaceImp.ensureExtensionsSize env + namespace EnvExtension instance {σ} [s : Inhabited σ] : Inhabited (EnvExtension σ) := EnvExtensionInterfaceImp.inhabitedExt s def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment := EnvExtensionInterfaceImp.setState ext env s def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σ → σ) : Environment := EnvExtensionInterfaceImp.modifyState ext env f -def getState {σ : Type} (ext : EnvExtension σ) (env : Environment) : σ := EnvExtensionInterfaceImp.getState ext env +def getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ := EnvExtensionInterfaceImp.getState ext env end EnvExtension /- Environment extensions can only be registered during initialization. @@ -297,7 +340,7 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ) namespace PersistentEnvExtension -def getModuleEntries {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α := +def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α := (ext.toEnvExtension.getState env).importedEntries.get! m def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment := @@ -305,7 +348,7 @@ def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : En let state := ext.addEntryFn s.state b; { s with state := state } -def getState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) : σ := +def getState {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) : σ := (ext.toEnvExtension.getState env).state def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (s : σ) : Environment := @@ -379,10 +422,10 @@ namespace SimplePersistentEnvExtension instance {α σ : Type} [Inhabited σ] : Inhabited (SimplePersistentEnvExtension α σ) := inferInstanceAs (Inhabited (PersistentEnvExtension α α (List α × σ))) -def getEntries {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) : List α := +def getEntries {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment) : List α := (PersistentEnvExtension.getState ext env).1 -def getState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) : σ := +def getState {α σ : Type} [Inhabited σ] (ext : SimplePersistentEnvExtension α σ) (env : Environment) : σ := (PersistentEnvExtension.getState ext env).2 def setState {α σ : Type} (ext : SimplePersistentEnvExtension α σ) (env : Environment) (s : σ) : Environment := @@ -518,23 +561,36 @@ private partial def getEntriesFor (mod : ModuleData) (extId : Name) (i : Nat) : else #[] -private def setImportedEntries (env : Environment) (mods : Array ModuleData) : IO Environment := do +private def setImportedEntries (env : Environment) (mods : Array ModuleData) (startingAt : Nat := 0) : IO Environment := do let mut env := env let pExtDescrs ← persistentEnvExtensionsRef.get for mod in mods do - for extDescr in pExtDescrs do + for extDescr in pExtDescrs[startingAt:] do let entries := getEntriesFor mod extDescr.name 0 env ← extDescr.toEnvExtension.modifyState env fun s => { s with importedEntries := s.importedEntries.push entries } return env -private def finalizePersistentExtensions (env : Environment) (opts : Options) : IO Environment := do - let mut env := env - let pExtDescrs ← persistentEnvExtensionsRef.get - for extDescr in pExtDescrs do - let s := extDescr.toEnvExtension.getState env - let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts } - env ← extDescr.toEnvExtension.setState env { s with state := newState } - return env +private partial def finalizePersistentExtensions (env : Environment) (mods : Array ModuleData) (opts : Options) : IO Environment := do + loop 0 env +where + loop (i : Nat) (env : Environment) : IO Environment := do + -- Recall that the size of the array stored `persistentEnvExtensionRef` may increase when we import user-defined environment extensions. + let pExtDescrs ← persistentEnvExtensionsRef.get + if h : i < pExtDescrs.size then + let extDescr := pExtDescrs[i] + let s := extDescr.toEnvExtension.getState env + let prevSize := (← persistentEnvExtensionsRef.get).size + let newState ← extDescr.addImportedFn s.importedEntries { env := env, opts := opts } + let mut env ← extDescr.toEnvExtension.setState env { s with state := newState } + env ← ensureExtensionsArraySize env + if (← persistentEnvExtensionsRef.get).size > prevSize then + -- This branch is executed when `pExtDescrs[i]` is the extension associated with the `init` attribute, and + -- a user-defined persistent extension is imported. + -- Thus, we invoke `setImportedEntries` to update the array `importedEntries` with the entries for the new extensions. + env ← setImportedEntries env mods prevSize + loop (i + 1) env + else + return env structure ImportState where moduleNameSet : NameSet := {} @@ -544,34 +600,38 @@ structure ImportState where @[export lean_import_modules] partial def importModules (imports : List Import) (opts : Options) (trustLevel : UInt32 := 0) : IO Environment := profileitIO "import" opts do - let (_, s) ← importMods imports |>.run {} - -- (moduleNames, mods, regions) - let mut modIdx : Nat := 0 - let mut const2ModIdx : HashMap Name ModuleIdx := {} - let mut constants : ConstMap := SMap.empty - for mod in s.moduleData do - for cinfo in mod.constants do - const2ModIdx := const2ModIdx.insert cinfo.name modIdx - if constants.contains cinfo.name then throw (IO.userError s!"import failed, environment already contains '{cinfo.name}'") - constants := constants.insert cinfo.name cinfo - modIdx := modIdx + 1 - constants := constants.switch - let exts ← mkInitialExtensionStates - let env : Environment := { - const2ModIdx := const2ModIdx, - constants := constants, - extensions := exts, - header := { - quotInit := !imports.isEmpty, -- We assume `core.lean` initializes quotient module - trustLevel := trustLevel, - imports := imports.toArray, - regions := s.regions, - moduleNames := s.moduleNames + try + importingRef.set true + let (_, s) ← importMods imports |>.run {} + -- (moduleNames, mods, regions) + let mut modIdx : Nat := 0 + let mut const2ModIdx : HashMap Name ModuleIdx := {} + let mut constants : ConstMap := SMap.empty + for mod in s.moduleData do + for cinfo in mod.constants do + const2ModIdx := const2ModIdx.insert cinfo.name modIdx + if constants.contains cinfo.name then throw (IO.userError s!"import failed, environment already contains '{cinfo.name}'") + constants := constants.insert cinfo.name cinfo + modIdx := modIdx + 1 + constants := constants.switch + let exts ← mkInitialExtensionStates + let env : Environment := { + const2ModIdx := const2ModIdx, + constants := constants, + extensions := exts, + header := { + quotInit := !imports.isEmpty, -- We assume `core.lean` initializes quotient module + trustLevel := trustLevel, + imports := imports.toArray, + regions := s.regions, + moduleNames := s.moduleNames + } } - } - let env ← setImportedEntries env s.moduleData - let env ← finalizePersistentExtensions env opts - pure env + let env ← setImportedEntries env s.moduleData + let env ← finalizePersistentExtensions env s.moduleData opts + pure env + finally + importingRef.set false where importMods : List Import → StateRefT ImportState IO Unit | [] => pure () diff --git a/src/shell/CMakeLists.txt b/src/shell/CMakeLists.txt index d68e1513c3..4949a8729a 100644 --- a/src/shell/CMakeLists.txt +++ b/src/shell/CMakeLists.txt @@ -179,3 +179,11 @@ add_test(NAME leanpkgtest_cyclic set -eu export PATH=${LEAN_BIN}:$PATH leanpkg build 2>&1 | grep 'import cycle'") + +add_test(NAME leanpkgtest_user_ext + WORKING_DIRECTORY "${LEAN_SOURCE_DIR}/../tests/leanpkg/user_ext" + COMMAND bash -c " + set -eu + export PATH=${LEAN_BIN}:$PATH + find . -name '*.olean' -delete + leanpkg build | grep 'world, hello, test'") diff --git a/tests/leanpkg/user_ext/.gitignore b/tests/leanpkg/user_ext/.gitignore new file mode 100644 index 0000000000..796b96d1c4 --- /dev/null +++ b/tests/leanpkg/user_ext/.gitignore @@ -0,0 +1 @@ +/build diff --git a/tests/leanpkg/user_ext/UserExt.lean b/tests/leanpkg/user_ext/UserExt.lean new file mode 100644 index 0000000000..b8531c360c --- /dev/null +++ b/tests/leanpkg/user_ext/UserExt.lean @@ -0,0 +1,5 @@ +import UserExt.Tst1 +import UserExt.Tst2 + +show_foo_set +show_bla_set diff --git a/tests/leanpkg/user_ext/UserExt/BlaExt.lean b/tests/leanpkg/user_ext/UserExt/BlaExt.lean new file mode 100644 index 0000000000..5284117fe9 --- /dev/null +++ b/tests/leanpkg/user_ext/UserExt/BlaExt.lean @@ -0,0 +1,23 @@ +import Lean + +open Lean + +initialize blaExtension : SimplePersistentEnvExtension Name NameSet ← + registerSimplePersistentEnvExtension { + name := `blaExt + addEntryFn := NameSet.insert + addImportedFn := fun es => mkStateFromImportedEntries NameSet.insert {} es + } + +syntax (name := insertBla) "insert_bla " ident : command +syntax (name := showBla) "show_bla_set" : command + +open Lean.Elab +open Lean.Elab.Command + +@[commandElab insertBla] def elabInsertBla : CommandElab := fun stx => do + IO.println s!"inserting {stx[1].getId}" + modifyEnv fun env => blaExtension.addEntry env stx[1].getId + +@[commandElab showBla] def elabShowBla : CommandElab := fun stx => do + IO.println s!"bla set: {blaExtension.getState (← getEnv) |>.toList}" diff --git a/tests/leanpkg/user_ext/UserExt/FooExt.lean b/tests/leanpkg/user_ext/UserExt/FooExt.lean new file mode 100644 index 0000000000..6618dbd4a9 --- /dev/null +++ b/tests/leanpkg/user_ext/UserExt/FooExt.lean @@ -0,0 +1,23 @@ +import Lean + +open Lean + +initialize fooExtension : SimplePersistentEnvExtension Name NameSet ← + registerSimplePersistentEnvExtension { + name := `fooExt + addEntryFn := NameSet.insert + addImportedFn := fun es => mkStateFromImportedEntries NameSet.insert {} es + } + +syntax (name := insertFoo) "insert_foo " ident : command +syntax (name := showFoo) "show_foo_set" : command + +open Lean.Elab +open Lean.Elab.Command + +@[commandElab insertFoo] def elabInsertFoo : CommandElab := fun stx => do + IO.println s!"inserting {stx[1].getId}" + modifyEnv fun env => fooExtension.addEntry env stx[1].getId + +@[commandElab showFoo] def elabShowFoo : CommandElab := fun stx => do + IO.println s!"foo set: {fooExtension.getState (← getEnv) |>.toList}" diff --git a/tests/leanpkg/user_ext/UserExt/Tst1.lean b/tests/leanpkg/user_ext/UserExt/Tst1.lean new file mode 100644 index 0000000000..df3f4ff5ce --- /dev/null +++ b/tests/leanpkg/user_ext/UserExt/Tst1.lean @@ -0,0 +1,9 @@ +import UserExt.FooExt +import UserExt.BlaExt + +insert_foo hello +insert_foo world +show_foo_set + +insert_bla abc +show_bla_set diff --git a/tests/leanpkg/user_ext/UserExt/Tst2.lean b/tests/leanpkg/user_ext/UserExt/Tst2.lean new file mode 100644 index 0000000000..c7743d2044 --- /dev/null +++ b/tests/leanpkg/user_ext/UserExt/Tst2.lean @@ -0,0 +1,6 @@ +import UserExt.BlaExt +import UserExt.FooExt + +insert_foo test +insert_bla defg +insert_bla hij diff --git a/tests/leanpkg/user_ext/leanpkg.toml b/tests/leanpkg/user_ext/leanpkg.toml new file mode 100644 index 0000000000..2606dd26d3 --- /dev/null +++ b/tests/leanpkg/user_ext/leanpkg.toml @@ -0,0 +1,3 @@ +[package] +name = "UserExt" +version = "0.1"