diff --git a/library/init/lean/environment.lean b/library/init/lean/environment.lean index f04e2824b0..2fcec98fcf 100644 --- a/library/init/lean/environment.lean +++ b/library/init/lean/environment.lean @@ -29,13 +29,14 @@ instance moduleidHasBeq : HasBeq ModuleId := ⟨ModuleId.beq⟩ /- Environment Extension Data -/ structure EnvExtensionData := (importedEntries : Array (Array ExtensionEntry)) -(initState : Thunk ExtensionState) +(importedState : Thunk ExtensionState) (entries : List ExtensionEntry := []) (state : Option ExtensionState := none) instance envExtensionDataInh : Inhabited EnvExtensionData := -⟨{ importedEntries := Array.empty, initState := Thunk.mk (default _) }⟩ +⟨{ importedEntries := Array.empty, importedState := Thunk.mk (default _) }⟩ +/- TODO: mark opaque. -/ structure Environment := (const2ModId : HashMap Name ModuleId) (constants : SMap Name ConstantInfo Name.quickLt) @@ -46,20 +47,67 @@ structure Environment := instance envInh : Inhabited Environment := ⟨{ const2ModId := {}, constants := {}, extensions := Array.empty }⟩ +/- TODO: mark opaque. -/ structure EnvExtension (α : Type) (σ : Type) := (name : Name) (idx : Nat) (initState : σ) -(addEntry : Bool → Environment → σ → α → σ) +(someVal : α) +(addEntry : Bool → σ → α → σ) (toArray : List α → Array α) +instance envExtDefInh : Inhabited (EnvExtension ExtensionEntry ExtensionState) := +⟨{ name := default _, idx := 0, initState := default _, + someVal := default _, addEntry := λ _ s _, s, + toArray := λ l, l.toArray }⟩ + +private def mkEnvExtensionsRef : IO (IO.Ref (Array (EnvExtension ExtensionEntry ExtensionState))) := +IO.mkRef Array.empty + +@[init mkEnvExtensionsRef] +private constant envExtensionsRef : IO.Ref (Array (EnvExtension ExtensionEntry ExtensionState)) := default _ + +/- TODO: replace export/extern trick with (the to be implemented) [implementedBy ...] attribute. + The extport/extern trick allows us to implement an opaque constant using an unsafe definition. -/ + +@[export leanRegisterEnvExtensionUnsafe] +unsafe def registerEnvExtensionUnsafe {α σ : Type} (name : Name) (initState : σ) (someVal : α) (addEntry : Bool → σ → α → σ) (toArray : List α → Array α) : IO (EnvExtension α σ) := +do +exts ← envExtensionsRef.get, +when (exts.any (λ ext, ext.name == name)) $ throw (IO.userError ("invalid environment extension, '" ++ toString name ++ "' has already been used")), +let idx := exts.size, +let ext : EnvExtension α σ := { + name := name, + idx := idx, + initState := initState, + someVal := someVal, + addEntry := addEntry, + toArray := toArray +}, +envExtensionsRef.modify (λ exts, exts.push (unsafeCast ext)), +pure ext + +@[extern "leanRegisterEnvExtensionUnsafe"] +constant registerEnvExtension {α σ : Type} (name : Name) (initState : σ) (someVal : α) (addEntry : Bool → σ → α → σ) (toArray : List α → Array α) : IO (EnvExtension α σ) := default _ + +def mkEmptyEnvironment (trustLevel : UInt32 := 0) : IO Environment := +do exts ← envExtensionsRef.get, +pure { const2ModId := {}, + constants := {}, + trustLevel := trustLevel, + extensions := exts.map $ λ ext, { + importedEntries := Array.empty, + importedState := Thunk.pure ext.initState + } +} + @[export leanGetModuleEntriesUnsafe] -unsafe def getModuleEntriesUnsafe {α σ : Type} (ext : EnvExtension α σ) (env : Environment) (m : ModuleId) : Array α := +unsafe def getModuleEntriesUnsafe {α σ : Type} (env : Environment) (ext : EnvExtension α σ) (m : ModuleId) : Array α := let entries := (env.extensions.get ext.idx).importedEntries.get m.id in unsafeCast entries @[extern "leanGetModuleEntriesUnsafe"] -constant getModuleEntries {α σ : Type} (ext : EnvExtension α σ) (env : Environment) (m : ModuleId) : Array α := default _ +constant getModuleEntries {α σ : Type} (env : Environment) (ext : EnvExtension α σ) (m : ModuleId) : Array α := default _ private def releaseExtensionData (env : Environment) (extIdx : Nat) : Environment := { extensions := env.extensions.set extIdx (default _), .. env } @@ -68,7 +116,7 @@ private def setExtensionData (env : Environment) (extIdx : Nat) (d : EnvExtensio { extensions := env.extensions.set extIdx d, .. env } @[export leanAddEntryUnsafe] -unsafe def addEntryUnsafe {α σ : Type} (ext : EnvExtension α σ) (env : Environment) (a : α) : Environment := +unsafe def addEntryUnsafe {α σ : Type} (env : Environment) (ext : EnvExtension α σ) (a : α) : Environment := let extIdx := ext.idx in let extData := env.extensions.get extIdx in let env := releaseExtensionData env extIdx in @@ -79,14 +127,44 @@ match extData.state with setExtensionData env extIdx extData | some s := let extData := { state := none, .. extData } in - let s : σ := @unsafeCast _ _ ⟨ext.initState⟩ s in - let s := ext.addEntry false env s a in - let extData := { state := unsafeCast s, .. extData } in + let s := ext.addEntry false (@unsafeCast _ _ ⟨ext.initState⟩ s) a in + let extData := { state := some (unsafeCast s), .. extData } in setExtensionData env extIdx extData @[extern "leanAddEntryUnsafe"] -constant addEntry {α σ : Type} (ext : EnvExtension α σ) (env : Environment) (a : α) : Environment := default _ +constant addEntry {α σ : Type} (env : Environment) (ext : EnvExtension α σ) (a : α) : Environment := default _ --- unsafe def getStateUnsafe {α σ : Type} (ext : EnvExtension α σ) (env : Environment) +unsafe def mkExtensionState {α σ : Type} (extData : EnvExtensionData) (ext : EnvExtension α σ) : ExtensionState := +let importedState := extData.importedState.get in +extData.entries.foldl + (λ s e, unsafeCast (ext.addEntry false (@unsafeCast _ _ ⟨ext.initState⟩ s) (@unsafeCast _ _ ⟨ext.someVal⟩ e))) + importedState + +@[export leanInitExtensionStateUnsafe] +unsafe def initExtensionStateUnsafe {α σ : Type} (env : Environment) (ext : EnvExtension α σ) : Environment := +let extIdx := ext.idx in +let extData := env.extensions.get extIdx in +if extData.state.isSome then env +else + let env := releaseExtensionData env extIdx in + let s := mkExtensionState extData ext in + let extData := { state := some s, .. extData } in + setExtensionData env extIdx extData + +@[extern "leanInitExtensionStateUnsafe"] +constant initExtensionState {α σ : Type} (env : Environment) (ext : EnvExtension α σ) : Environment := default _ + +@[export leanGetExtensionStateUnsafe] +unsafe def getExtensionStateUnsafe {α σ : Type} (env : Environment) (ext : EnvExtension α σ) : σ := +let extIdx := ext.idx in +let extData := env.extensions.get extIdx in +match extData.state with +| some s := @unsafeCast _ _ ⟨ext.initState⟩ s +| none := + let s := mkExtensionState extData ext in + @unsafeCast _ _ ⟨ext.initState⟩ s + +@[extern "leanGetExtensionStateUnsafe"] +constant getExtensionState {α σ : Type} (env : Environment) (ext : EnvExtension α σ) : σ := ext.initState end Lean