chore: further shake improvements (#10947)

This commit is contained in:
Sebastian Ullrich 2025-10-26 12:27:19 +01:00 committed by GitHub
parent 4887eeb77c
commit 77ddfd49e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 89 additions and 92 deletions

View file

@ -131,10 +131,11 @@ structure State where
`transDeps[i]` is the (non-reflexive) transitive closure of `mods[i].imports`. More specifically,
* `j ∈ transDeps[i].pub` if `i -(public import)->+ j`
* `j ∈ transDeps[i].priv` if `i -(import ...)-> _ -(public import)->* j`
* `j ∈ transDeps[i].priv` if `i -(import all)->+ -(public import ...)-> _ -(public import)->* j`
* `j ∈ transDeps[i].metaPub` if `i -(public (meta)? import)->* _ -(public meta import)-> _ -(public (meta)? import ...)->* j`
* `j ∈ transDeps[i].metaPriv` if `i -(meta import ...)-> _ -(public (meta)? import ...)->* j`
* `j ∈ transDeps[i].metaPriv` if `i -(import all)->+ -(public meta import ...)-> _ -(public (meta)? import ...)->* j`
* `j ∈ transDeps[i].priv` if `i -(import all)->+ i'` and `j ∈ transDeps[i'].pub/priv`
* `j ∈ transDeps[i].metaPub` if `i -(public (meta)? import)->* _ -(public meta import)-> _ -(public (meta)? import)->* j`
* `j ∈ transDeps[i].metaPriv` if `i -(meta import ...)-> _ -(public (meta)? import)->* j`
* `j ∈ transDeps[i].metaPriv` if `i -(import ...)-> i'` and `j ∈ transDeps[i'].metaPub`
* `j ∈ transDeps[i].metaPriv` if `i -(import all)->+ i'` and `j ∈ transDeps[i'].metaPub/metaPriv`
-/
transDeps : Array Needs := #[]
/--
@ -162,10 +163,10 @@ def addTransitiveImps (transImps : Needs) (imp : Import) (j : Nat) (impTransImps
-- `j ∈ transDeps[i].priv` if `i -(import ...)-> _ -(public import)->* j`
transImps := transImps.union .priv {j} |>.union .priv (impTransImps.get .pub)
if imp.importAll then
-- `j ∈ transDeps[i].priv` if `i -(import all)->+ -(public import ...)-> _ -(public import)->* j`
transImps := transImps.union .priv (impTransImps.get .pub)
-- `j ∈ transDeps[i].priv` if `i -(import all)->+ i'` and `j ∈ transDeps[i'].pub/priv`
transImps := transImps.union .priv (impTransImps.get .pub impTransImps.get .priv)
-- `j ∈ transDeps[i].metaPub` if `i -(public (meta)? import)->* _ -(public meta import)-> _ -(public (meta)? import ...)->* j`
-- `j ∈ transDeps[i].metaPub` if `i -(public (meta)? import)->* _ -(public meta import)-> _ -(public (meta)? import)->* j`
if imp.isExported then
transImps := transImps.union .metaPub (impTransImps.get .metaPub)
if imp.isMeta then
@ -173,10 +174,13 @@ def addTransitiveImps (transImps : Needs) (imp : Import) (j : Nat) (impTransImps
if !imp.isExported then
if imp.isMeta then
-- `j ∈ transDeps[i].metaPriv` if `i -(meta import ...)-> _ -(public (meta)? import ...)->* j`
-- `j ∈ transDeps[i].metaPriv` if `i -(meta import ...)-> _ -(public (meta)? import)->* j`
transImps := transImps.union .metaPriv {j} |>.union .metaPriv (impTransImps.get .pub impTransImps.get .metaPub)
if imp.importAll then
-- `j ∈ transDeps[i].metaPriv` if `i -(import all)->+ -(public meta import ...)-> _ -(public (meta)? import ...)->* j`
-- `j ∈ transDeps[i].metaPriv` if `i -(import all)->+ i'` and `j ∈ transDeps[i'].metaPub/metaPriv`
transImps := transImps.union .metaPriv (impTransImps.get .metaPub impTransImps.get .metaPriv)
else
-- `j ∈ transDeps[i].metaPriv` if `i -(import ...)-> i'` and `j ∈ transDeps[i'].metaPub`
transImps := transImps.union .metaPriv (impTransImps.get .metaPub)
transImps
@ -185,7 +189,8 @@ def addTransitiveImps (transImps : Needs) (imp : Import) (j : Nat) (impTransImps
def calcNeeds (env : Environment) (i : ModuleIdx) : Needs := Id.run do
let mut needs := default
for ci in env.header.moduleData[i]!.constants do
let pubCI? := env.setExporting true |>.find? ci.name
-- Added guard for cases like `structure` that are still exported even if private
let pubCI? := guard (!isPrivateName ci.name) *> (env.setExporting true).find? ci.name
let k := { isExported := pubCI?.isSome, isMeta := isMeta env ci.name }
needs := visitExpr k ci.type needs
if let some e := ci.value? (allowOpaque := true) then
@ -216,7 +221,8 @@ def getExplanations (env : Environment) (i : ModuleIdx) :
Std.HashMap (ModuleIdx × NeedsKind) (Option (Name × Name)) := Id.run do
let mut deps := default
for ci in env.header.moduleData[i]!.constants do
let pubCI? := env.setExporting true |>.find? ci.name
-- Added guard for cases like `structure` that are still exported even if private
let pubCI? := guard (!isPrivateName ci.name) *> (env.setExporting true).find? ci.name
let k := { isExported := pubCI?.isSome, isMeta := isMeta env ci.name }
deps := visitExpr k ci.name ci.type deps
if let some e := ci.value? (allowOpaque := true) then
@ -286,7 +292,7 @@ and `endPos` is the position of the end of the header.
-/
def parseHeaderFromString (text path : String) :
IO (System.FilePath × Parser.InputContext ×
TSyntaxArray ``Parser.Module.import × String.Pos) := do
TSyntax ``Parser.Module.header × String.Pos.Raw) := do
let inputCtx := Parser.mkInputContext text path
let (header, parserState, msgs) ← Parser.parseHeader inputCtx
if !msgs.toList.isEmpty then -- skip this file if there are parse errors
@ -294,8 +300,8 @@ def parseHeaderFromString (text path : String) :
throw <| .userError "parse errors in file"
-- the insertion point for `add` is the first newline after the imports
let insertion := header.raw.getTailPos?.getD parserState.pos
let insertion := text.findAux (· == '\n') text.endPos insertion + ⟨1⟩
pure (path, inputCtx, .mk header.raw[2].getArgs, insertion)
let insertion := text.findAux (· == '\n') text.endPos insertion + '\n'
pure (path, inputCtx, header, insertion)
/-- Parse a source file to extract the location of the import lines, for edits and error messages.
@ -304,13 +310,18 @@ and `endPos` is the position of the end of the header.
-/
def parseHeader (srcSearchPath : SearchPath) (mod : Name) :
IO (System.FilePath × Parser.InputContext ×
TSyntaxArray ``Parser.Module.import × String.Pos) := do
TSyntax ``Parser.Module.header × String.Pos.Raw) := do
-- Parse the input file
let some path ← srcSearchPath.findModuleWithExt "lean" mod
| throw <| .userError s!"error: failed to find source file for {mod}"
let text ← IO.FS.readFile path
parseHeaderFromString text path.toString
def decodeHeader : TSyntax ``Parser.Module.header → Option (TSyntax `module) × Option (TSyntax `prelude) × TSyntaxArray ``Parser.Module.import
| `(Parser.Module.header| $[module%$moduleTk?]? $[prelude%$preludeTk?]? $imports*) =>
(moduleTk?.map .mk, preludeTk?.map .mk, imports)
| _ => unreachable!
def decodeImport : TSyntax ``Parser.Module.import → Import
| `(Parser.Module.import| $[public%$pubTk?]? $[meta%$metaTk?]? import $[all%$allTk?]? $id) =>
{ module := id.getId, isExported := pubTk?.isSome, isMeta := metaTk?.isSome, importAll := allTk?.isSome }
@ -326,11 +337,20 @@ def decodeImport : TSyntax ``Parser.Module.import → Import
* `addOnly`: if true, only add missing imports, do not remove unused ones
-/
def visitModule (srcSearchPath : SearchPath)
(i : Nat) (needs : Needs) (preserve : Needs) (edits : Edits)
(i : Nat) (needs : Needs) (preserve : Needs) (edits : Edits) (headerStx : TSyntax ``Parser.Module.header)
(addOnly := false) (githubStyle := false) (explain := false) : StateT State IO Edits := do
let s ← get
-- Do transitive reduction of `needs` in `deps`.
let mut deps := needs
let (_, prelude?, imports) := decodeHeader headerStx
if prelude?.isNone then
deps := deps.union .pub {s.env.getModuleIdx? `Init |>.get!}
for imp in imports do
if addOnly || imp.raw.getTrailing?.any (·.toString.toSlice.contains "shake: keep") then
let imp := decodeImport imp
let j := s.env.getModuleIdx? imp.module |>.get!
let k := NeedsKind.ofImport imp
deps := deps.union k {j}
for j in [0:s.mods.size] do
let transDeps := s.transDeps[j]!
for k in NeedsKind.all do
@ -354,7 +374,8 @@ def visitModule (srcSearchPath : SearchPath)
newDeps := addTransitiveImps newDeps imp j s.transDeps[j]!
else
let k := NeedsKind.ofImport imp
if !addOnly && !deps.has k j && !deps.has { k with isExported := false } j then
-- A private import should also be removed if the public version is needed
if !deps.has k j || !k.isExported && deps.has { k with isExported := true } j then
toRemove := toRemove.push imp
else
newDeps := addTransitiveImps newDeps imp j s.transDeps[j]!
@ -385,7 +406,8 @@ def visitModule (srcSearchPath : SearchPath)
if githubStyle then
try
let (path, inputCtx, imports, endHeader) ← parseHeader srcSearchPath s.modNames[i]!
let (path, inputCtx, stx, endHeader) ← parseHeader srcSearchPath s.modNames[i]!
let (_, _, imports) := decodeHeader stx
for stx in imports do
if toRemove.any fun imp => imp == decodeImport stx then
let pos := inputCtx.fileMap.toPosition stx.raw.getPos?.get!
@ -529,33 +551,43 @@ def main (args : List String) : IO UInt32 := do
let needs := s.mods.mapIdx fun i _ =>
Task.spawn fun _ => calcNeeds s.env i
-- Parse headers in parallel
let headers ← s.mods.mapIdxM fun i _ =>
BaseIO.asTask (parseHeader srcSearchPath s.modNames[i]! |>.toBaseIO)
if args.fix then
println! "The following changes will be made automatically:"
-- Check all selected modules
let mut edits : Edits := ∅
let mut revNeeds : Needs := default
for i in [0:s.mods.size], t in needs do
edits ← visitModule (addOnly := !pkg.isPrefixOf s.modNames[i]!) srcSearchPath i t.get revNeeds edits args.githubStyle args.explain
if isExtraRevModUse s.env i then
revNeeds := revNeeds.union .priv {i}
for i in [0:s.mods.size], t in needs, header in headers do
match header.get with
| .ok (_, _, stx, _) =>
edits ← visitModule (addOnly := !pkg.isPrefixOf s.modNames[i]!)
srcSearchPath i t.get revNeeds edits stx args.githubStyle args.explain
if isExtraRevModUse s.env i then
revNeeds := revNeeds.union .priv {i}
| .error e =>
println! e.toString
if !args.fix then
-- return error if any issues were found
return if edits.isEmpty then 0 else 1
-- Apply the edits to existing files
let count ← edits.foldM (init := 0) fun count mod (remove, add) => do
let mut count := 0
for mod in s.modNames, header? in headers do
let some (remove, add) := edits[mod]? | continue
let add : Array Import := add.qsortOrd
-- Parse the input file
let (path, inputCtx, imports, insertion) ←
try parseHeader srcSearchPath mod
catch e => println! e.toString; return count
let .ok (path, inputCtx, stx, insertion) := header?.get | continue
let (_, _, imports) := decodeHeader stx
let text := inputCtx.fileMap.source
-- Calculate the edit result
let mut pos : String.Pos := 0
let mut pos : String.Pos.Raw := 0
let mut out : String := ""
let mut seen : Std.HashSet Import := {}
for stx in imports do
@ -563,7 +595,7 @@ def main (args : List String) : IO UInt32 := do
if remove.contains mod || seen.contains mod then
out := out ++ text.extract pos stx.raw.getPos?.get!
-- We use the end position of the syntax, but include whitespace up to the first newline
pos := text.findAux (· == '\n') text.rawEndPos stx.raw.getTailPos?.get! + ⟨1⟩
pos := text.findAux (· == '\n') text.rawEndPos stx.raw.getTailPos?.get! + '\n'
seen := seen.insert mod
out := out ++ text.extract pos insertion
for mod in add do
@ -573,7 +605,7 @@ def main (args : List String) : IO UInt32 := do
out := out ++ text.extract insertion text.rawEndPos
IO.FS.writeFile path out
return count + 1
count := count + 1
-- Since we throw an error upon encountering issues, we can be sure that everything worked
-- if we reach this point of the script.

View file

@ -48,5 +48,26 @@ partial def inferMeta (decls : Array Decl) : CompilerM Unit := do
modifyEnv (setDeclMeta · decl.name)
setClosureMeta decl
/--
Checks meta availability just before `evalConst`. This is a "last line of defense" as accesses
should have been checked at declaration time in case of attributes. We do not solely want to rely on
errors from the interpreter itself as those depend on whether we are running in the server.
-/
@[export lean_eval_check_meta]
private partial def evalCheckMeta (env : Environment) (declName : Name) : Except String Unit := do
if !env.header.isModule then
return
go declName |>.run' {}
where go (ref : Name) : StateT NameSet (Except String) Unit := do
if (← get).contains ref then
return
modify (·.insert ref)
if let some localDecl := declMapExt.getState env |>.find? ref then
for ref in collectUsedFDecls localDecl do
go ref
else
if getIRPhases env ref == .runtime then
throw s!"Cannot evaluate constant `{declName}` as it uses `{ref}` which is neither marked nor imported as `meta`"
builtin_initialize
registerTraceClass `compiler.ir.inferMeta

View file

@ -113,30 +113,6 @@ where go (isMeta isPublic : Bool) (decl : Decl) : StateT NameSet CompilerM Unit
if let some refDecl ← getLocalDecl? ref then
go isMeta isPublic refDecl
/--
Checks meta availability just before `evalConst`. This is a "last line of defense" as accesses
should have been checked at declaration time in case of attributes. We do not solely want to rely on
errors from the interpreter itself as those depend on whether we are running in the server.
-/
@[export lean_eval_check_meta]
private partial def evalCheckMeta (env : Environment) (declName : Name) : Except String Unit := do
if !env.header.isModule then
return
let some decl := getDeclCore? env baseExt declName
| return -- We might not have the LCNF available, in which case there's nothing we can do
go decl |>.run' {}
where go (decl : Decl) : StateT NameSet (Except String) Unit :=
decl.value.forCodeM fun code =>
for ref in collectUsedDecls code do
if (← get).contains ref then
continue
modify (·.insert ref)
if let some localDecl := baseExt.getState env |>.find? ref then
go localDecl
else
if getIRPhases env ref == .runtime then
throw s!"Cannot evaluate constant `{declName}` as it uses `{ref}` which is neither marked nor imported as `meta`"
/--
Checks that imports necessary for inlining/specialization are public as otherwise we may run into
unknown declarations at the point of inlining/specializing. This is a limitation that we want to

View file

@ -240,9 +240,10 @@ def elabScientificLit : TermElab := fun stx expectedType? => do
@[builtin_term_elab doubleQuotedName] def elabDoubleQuotedName : TermElab := fun stx _ => do
-- Always allow quoting private names.
let n ← withoutExporting <| realizeGlobalConstNoOverloadWithInfo stx[2]
recordExtraModUseFromDecl (isMeta := false) n
return toExpr n
withoutExporting do
let n ← realizeGlobalConstNoOverloadWithInfo stx[2]
recordExtraModUseFromDecl (isMeta := false) n
return toExpr n
@[builtin_term_elab declName] def elabDeclName : TermElab := adaptExpander fun _ => do
let some declName ← getDeclName?

View file

@ -330,11 +330,6 @@ def elabMutual : CommandElab := fun stx => do
Term.applyAttributes declName attrs
for attrName in toErase do
Attribute.erase declName attrName
if (← getEnv).isImportedConst declName && attrs.any (·.kind == .global) then
-- If an imported declaration is marked with a global attribute, there is no good way to track
-- its use generally and so Shake should conservatively preserve imports of the current
-- module.
recordExtraRevUseOfCurrentModule
@[builtin_command_elab Lean.Parser.Command.«initialize»] def elabInitialize : CommandElab
| stx@`($declModifiers:declModifiers $kw:initializeKeyword $[$id? : $type? ←]? $doSeq) => do

View file

@ -130,8 +130,6 @@ private partial def quoteSyntax : Syntax → TermElabM Term
-- Add global scopes at compilation time (now), add macro scope at runtime (in the quotation).
-- See the paper for details.
let consts ← resolveGlobalName val
-- Record all constants to make sure they can still be resolved after shaking imports
consts.forM fun (n, _) => recordExtraModUseFromDecl (isMeta := false) n
-- extension of the paper algorithm: also store unique section variable names as top-level scopes
-- so they can be captured and used inside the section, but not outside
let sectionVars := resolveSectionVariable (← read).sectionVars val

View file

@ -357,6 +357,8 @@ private def elabGrindConfig' (config : TSyntax ``Lean.Parser.Tactic.optConfig) (
elabGrindConfig config
@[builtin_tactic Lean.Parser.Tactic.grind] def evalGrind : Tactic := fun stx => do
-- Preserve this import in core; all others import `Init` anyway
recordExtraModUse (isMeta := false) `Init.Grind.Tactics
match stx with
| `(tactic| grind $config:optConfig $[only%$only]? $[ [$params:grindParam,*] ]? $[=> $seq:grindSeq]?) =>
let interactive := seq.isSome

View file

@ -60,20 +60,6 @@ Is rev mod use: false
-/
#guard_msgs in #eval showExtraModUses
/-!
Running `attribute` with declarations from an imported module causes a rev use.
-/
#eval resetExtraModUses
attribute [builtin_doc] Int.natCast_add
/--
info: Entries: []
Is rev mod use: true
-/
#guard_msgs in #eval showExtraModUses
/-!
`recommended_spelling` records a dependency.
-/
@ -188,7 +174,7 @@ attribute [grind =] List.append
/--
info: Entries: [import Init.Grind.Attr, public import Init.Prelude]
Is rev mod use: true
Is rev mod use: false
-/
#guard_msgs in #eval showExtraModUses
@ -266,20 +252,6 @@ Is rev mod use: false
-/
#guard_msgs in #eval showExtraModUses
/-!
Resolved constants from syntax quotations get added (here `List.sum` from Init.Data.List.Basic).
-/
#eval resetExtraModUses
def test9 : Lean.MacroM Lean.Syntax := `(List.sum)
/--
info: Entries: [import Init.Notation, import Init.Coe, import Init.Data.List.Basic]
Is rev mod use: false
-/
#guard_msgs in #eval showExtraModUses
/-!
Elaboration attributes add dependency on the syntax node kind
(here `Lean.Parser.Tactic.done` from Init.Tactics).