feat(library/init/lean/environment): environment extension registration and API

This commit is contained in:
Leonardo de Moura 2019-05-09 16:53:04 -07:00
parent 3b80ec89b7
commit ecb50b553d

View file

@ -29,13 +29,14 @@ instance moduleidHasBeq : HasBeq ModuleId := ⟨ModuleId.beq⟩
/- Environment Extension Data -/
structure EnvExtensionData :=
(importedEntries : Array (Array ExtensionEntry))
(initState : Thunk ExtensionState)
(importedState : Thunk ExtensionState)
(entries : List ExtensionEntry := [])
(state : Option ExtensionState := none)
instance envExtensionDataInh : Inhabited EnvExtensionData :=
⟨{ importedEntries := Array.empty, initState := Thunk.mk (default _) }⟩
⟨{ importedEntries := Array.empty, importedState := Thunk.mk (default _) }⟩
/- TODO: mark opaque. -/
structure Environment :=
(const2ModId : HashMap Name ModuleId)
(constants : SMap Name ConstantInfo Name.quickLt)
@ -46,20 +47,67 @@ structure Environment :=
instance envInh : Inhabited Environment :=
⟨{ const2ModId := {}, constants := {}, extensions := Array.empty }⟩
/- TODO: mark opaque. -/
structure EnvExtension (α : Type) (σ : Type) :=
(name : Name)
(idx : Nat)
(initState : σ)
(addEntry : Bool → Environment → σασ)
(someVal : α)
(addEntry : Bool → σασ)
(toArray : List α → Array α)
instance envExtDefInh : Inhabited (EnvExtension ExtensionEntry ExtensionState) :=
⟨{ name := default _, idx := 0, initState := default _,
someVal := default _, addEntry := λ _ s _, s,
toArray := λ l, l.toArray }⟩
private def mkEnvExtensionsRef : IO (IO.Ref (Array (EnvExtension ExtensionEntry ExtensionState))) :=
IO.mkRef Array.empty
@[init mkEnvExtensionsRef]
private constant envExtensionsRef : IO.Ref (Array (EnvExtension ExtensionEntry ExtensionState)) := default _
/- TODO: replace export/extern trick with (the to be implemented) [implementedBy ...] attribute.
The extport/extern trick allows us to implement an opaque constant using an unsafe definition. -/
@[export leanRegisterEnvExtensionUnsafe]
unsafe def registerEnvExtensionUnsafe {α σ : Type} (name : Name) (initState : σ) (someVal : α) (addEntry : Bool → σασ) (toArray : List α → Array α) : IO (EnvExtension α σ) :=
do
exts ← envExtensionsRef.get,
when (exts.any (λ ext, ext.name == name)) $ throw (IO.userError ("invalid environment extension, '" ++ toString name ++ "' has already been used")),
let idx := exts.size,
let ext : EnvExtension α σ := {
name := name,
idx := idx,
initState := initState,
someVal := someVal,
addEntry := addEntry,
toArray := toArray
},
envExtensionsRef.modify (λ exts, exts.push (unsafeCast ext)),
pure ext
@[extern "leanRegisterEnvExtensionUnsafe"]
constant registerEnvExtension {α σ : Type} (name : Name) (initState : σ) (someVal : α) (addEntry : Bool → σασ) (toArray : List α → Array α) : IO (EnvExtension α σ) := default _
def mkEmptyEnvironment (trustLevel : UInt32 := 0) : IO Environment :=
do exts ← envExtensionsRef.get,
pure { const2ModId := {},
constants := {},
trustLevel := trustLevel,
extensions := exts.map $ λ ext, {
importedEntries := Array.empty,
importedState := Thunk.pure ext.initState
}
}
@[export leanGetModuleEntriesUnsafe]
unsafe def getModuleEntriesUnsafe {α σ : Type} (ext : EnvExtension α σ) (env : Environment) (m : ModuleId) : Array α :=
unsafe def getModuleEntriesUnsafe {α σ : Type} (env : Environment) (ext : EnvExtension α σ) (m : ModuleId) : Array α :=
let entries := (env.extensions.get ext.idx).importedEntries.get m.id in
unsafeCast entries
@[extern "leanGetModuleEntriesUnsafe"]
constant getModuleEntries {α σ : Type} (ext : EnvExtension α σ) (env : Environment) (m : ModuleId) : Array α := default _
constant getModuleEntries {α σ : Type} (env : Environment) (ext : EnvExtension α σ) (m : ModuleId) : Array α := default _
private def releaseExtensionData (env : Environment) (extIdx : Nat) : Environment :=
{ extensions := env.extensions.set extIdx (default _), .. env }
@ -68,7 +116,7 @@ private def setExtensionData (env : Environment) (extIdx : Nat) (d : EnvExtensio
{ extensions := env.extensions.set extIdx d, .. env }
@[export leanAddEntryUnsafe]
unsafe def addEntryUnsafe {α σ : Type} (ext : EnvExtension α σ) (env : Environment) (a : α) : Environment :=
unsafe def addEntryUnsafe {α σ : Type} (env : Environment) (ext : EnvExtension α σ) (a : α) : Environment :=
let extIdx := ext.idx in
let extData := env.extensions.get extIdx in
let env := releaseExtensionData env extIdx in
@ -79,14 +127,44 @@ match extData.state with
setExtensionData env extIdx extData
| some s :=
let extData := { state := none, .. extData } in
let s : σ := @unsafeCast _ _ ⟨ext.initState⟩ s in
let s := ext.addEntry false env s a in
let extData := { state := unsafeCast s, .. extData } in
let s := ext.addEntry false (@unsafeCast _ _ ⟨ext.initState⟩ s) a in
let extData := { state := some (unsafeCast s), .. extData } in
setExtensionData env extIdx extData
@[extern "leanAddEntryUnsafe"]
constant addEntry {α σ : Type} (ext : EnvExtension α σ) (env : Environment) (a : α) : Environment := default _
constant addEntry {α σ : Type} (env : Environment) (ext : EnvExtension α σ) (a : α) : Environment := default _
-- unsafe def getStateUnsafe {α σ : Type} (ext : EnvExtension α σ) (env : Environment)
unsafe def mkExtensionState {α σ : Type} (extData : EnvExtensionData) (ext : EnvExtension α σ) : ExtensionState :=
let importedState := extData.importedState.get in
extData.entries.foldl
(λ s e, unsafeCast (ext.addEntry false (@unsafeCast _ _ ⟨ext.initState⟩ s) (@unsafeCast _ _ ⟨ext.someVal⟩ e)))
importedState
@[export leanInitExtensionStateUnsafe]
unsafe def initExtensionStateUnsafe {α σ : Type} (env : Environment) (ext : EnvExtension α σ) : Environment :=
let extIdx := ext.idx in
let extData := env.extensions.get extIdx in
if extData.state.isSome then env
else
let env := releaseExtensionData env extIdx in
let s := mkExtensionState extData ext in
let extData := { state := some s, .. extData } in
setExtensionData env extIdx extData
@[extern "leanInitExtensionStateUnsafe"]
constant initExtensionState {α σ : Type} (env : Environment) (ext : EnvExtension α σ) : Environment := default _
@[export leanGetExtensionStateUnsafe]
unsafe def getExtensionStateUnsafe {α σ : Type} (env : Environment) (ext : EnvExtension α σ) : σ :=
let extIdx := ext.idx in
let extData := env.extensions.get extIdx in
match extData.state with
| some s := @unsafeCast _ _ ⟨ext.initState⟩ s
| none :=
let s := mkExtensionState extData ext in
@unsafeCast _ _ ⟨ext.initState⟩ s
@[extern "leanGetExtensionStateUnsafe"]
constant getExtensionState {α σ : Type} (env : Environment) (ext : EnvExtension α σ) : σ := ext.initState
end Lean