diff --git a/src/Lean/Compiler/LCNF/SpecInfo.lean b/src/Lean/Compiler/LCNF/SpecInfo.lean index 9399a45aaf..d0553efe66 100644 --- a/src/Lean/Compiler/LCNF/SpecInfo.lean +++ b/src/Lean/Compiler/LCNF/SpecInfo.lean @@ -50,7 +50,7 @@ instance : ToMessageData SpecParamInfo where | .other => "O" structure SpecState where - specInfo : SMap Name (Array SpecParamInfo) := {} + specInfo : PHashMap Name (Array SpecParamInfo) := {} deriving Inhabited structure SpecEntry where @@ -64,11 +64,17 @@ def addEntry (s : SpecState) (e : SpecEntry) : SpecState := match s with | { specInfo } => { specInfo := specInfo.insert e.declName e.paramsInfo } -def switch : SpecState → SpecState - | { specInfo } => { specInfo := specInfo.switch } - end SpecState +private abbrev declLt (a b : SpecEntry) := + Name.quickLt a.declName b.declName + +private abbrev sortEntries (entries : Array SpecEntry) : Array SpecEntry := + entries.qsort declLt + +private abbrev findAtSorted? (entries : Array SpecEntry) (declName : Name) : Option SpecEntry := + entries.binSearch { declName, paramsInfo := #[] } declLt + /-- Extension for storing `SpecParamInfo` for declarations being compiled. Remark: we only store information for declarations that will be specialized. @@ -76,7 +82,8 @@ Remark: we only store information for declarations that will be specialized. builtin_initialize specExtension : SimplePersistentEnvExtension SpecEntry SpecState ← registerSimplePersistentEnvExtension { addEntryFn := SpecState.addEntry - addImportedFn := fun es => mkStateFromImportedEntries SpecState.addEntry {} es |>.switch + addImportedFn := fun _ => {} + toArrayFn := fun s => sortEntries s.toArray } /-- @@ -188,13 +195,19 @@ def saveSpecParamInfo (decls : Array Decl) : CompilerM Unit := do modifyEnv fun env => specExtension.addEntry env { declName := decl.name, paramsInfo } def getSpecParamInfoCore? (env : Environment) (declName : Name) : Option (Array SpecParamInfo) := - (specExtension.getState env).specInfo.find? declName + match env.getModuleIdxFor? declName with + | some modIdx => + if let some entry := findAtSorted? (specExtension.getModuleEntries env modIdx) declName then + some entry.paramsInfo + else + none + | none => (specExtension.getState env).specInfo.find? declName def getSpecParamInfo? [Monad m] [MonadEnv m] (declName : Name) : m (Option (Array SpecParamInfo)) := - return (specExtension.getState (← getEnv)).specInfo.find? declName + return getSpecParamInfoCore? (← getEnv) declName -def isSpecCandidate [Monad m] [MonadEnv m] (declName : Name) : m Bool := - return (specExtension.getState (← getEnv)).specInfo.contains declName +def isSpecCandidate [Monad m] [MonadEnv m] (declName : Name) : m Bool := do + return getSpecParamInfoCore? (← getEnv) declName |>.isSome builtin_initialize registerTraceClass `Compiler.specialize.info