lean4-htt/src/Lean/LibrarySuggestions/Basic.lean
Kim Morrison 833aaa823e
chore: tactics using library suggestions set the caller field (#11171)
This PR ensures that tactics using library suggestions set the caller
field, so the premise selection engine has access to this. We'll later
use this to filter out some modules for grind, which we know have
already been fully annotated.

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-14 04:50:55 +00:00

434 lines
17 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
module
prelude
public import Lean.Elab.Command
public import Lean.Meta.Eval
public import Lean.Meta.CompletionName
public import Lean.Linter.Deprecated
public import Init.Data.Random
/-!
# An API for library suggestion algorithms.
This module provides a basic API for library suggestion algorithms,
which are used to suggest relevant theorems from the library for the current goal.
In the literature this is usually known as "premise selection",
but we mostly avoid that term as most of our users will not be familiar with the term.
The core interface is the `Selector` type, which is a function from a metavariable (the goal)
and a configuration to an array of suggestions, computed in `MetaM`.
The `Selector` is registered as an environment extension,
and the trivial (no suggestions) implementation is `Lean.LibrarySuggestions.empty`.
Lean does not provide a default library suggestion engine, so this module is intended to be used in conjunction
with a downstream package which registers a library suggestion engine.
-/
namespace Lean.Expr.FoldRelevantConstantsImpl
open Lean Meta
/-- Traversal state for `fold`, recording what has already been seen so nothing is processed twice. -/
unsafe structure State where
  /-- The subexpressions that have already been visited. -/
  visited : PtrSet Expr := mkPtrSet
  /-- The names of the constants that have already been encountered. -/
  visitedConsts : NameHashSet := {}
/-- The monad in which `fold` runs: `MetaM` extended with the traversal `State`. -/
unsafe abbrev FoldM := StateT State MetaM
/--
Fold `f` over the constants occurring in `e`, threading an accumulator that starts at `acc`.

* Each subexpression (tracked in `State.visited`) and each constant name
  (tracked in `State.visitedConsts`) is processed at most once.
* Subterms for which `isProof` holds are not traversed.
* In an application, the argument is skipped when the function's first parameter is
  instance-implicit.
* Constants for which `isInstance` holds are recorded as visited, but `f` is not called on them.
-/
unsafe def fold {α : Type} (f : Name → α → MetaM α) (e : Expr) (acc : α) : FoldM α :=
  let rec visit (e : Expr) (acc : α) : FoldM α := do
    if (← get).visited.contains e then
      return acc
    -- Mark `e` as seen before doing anything else, including the proof check.
    modify fun st => { st with visited := st.visited.insert e }
    if ← isProof e then
      -- Proof terms are not traversed.
      return acc
    match e with
    | .forallE name dom body binfo | .lam name dom body binfo =>
      -- Binders: fold over the domain, then over the body opened with a fresh local.
      let acc' ← visit dom acc
      withLocalDecl name binfo dom fun fvar =>
        visit (body.instantiate1 fvar) acc'
    | .letE name type val body nondep =>
      -- `let`: fold over the type, then the value, then the body opened with a fresh local.
      let acc' ← visit val (← visit type acc)
      withLetDecl name type val (nondep := nondep) fun fvar =>
        visit (body.instantiate1 fvar) acc'
    | .mdata _ inner | .proj _ _ inner => visit inner acc
    | .app fn arg =>
      let funInfo ← getFunInfo fn (some 1)
      if funInfo.paramInfo[0]!.isInstImplicit then
        -- Instance-implicit argument: only the function part is traversed.
        visit fn acc
      else
        visit arg (← visit fn acc)
    | .const declName _ =>
      if (← get).visitedConsts.contains declName then
        return acc
      modify fun st => { st with visitedConsts := st.visitedConsts.insert declName }
      -- Instances are remembered as visited but never passed to `f`.
      if ← isInstance declName then
        return acc
      f declName acc
    | _ => return acc
  visit e acc
/-- Run `fold` starting from an empty traversal state, discarding the final state. -/
@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α :=
  StateT.run' (fold f e init) {}
end FoldRelevantConstantsImpl
/--
Apply `f` to every constant occurring in `e` once, skipping instance arguments and proofs,
threading an accumulator that starts at `init`.

The declared value `pure init` is only the reference definition of this `opaque` constant;
the `implemented_by` attribute points at `FoldRelevantConstantsImpl.foldUnsafe` as the implementation.
-/
@[implemented_by FoldRelevantConstantsImpl.foldUnsafe]
public opaque foldRelevantConstants {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := pure init
/-- The constants occurring in `e` (each listed once), skipping instance arguments and proofs. -/
public def relevantConstants (e : Expr) : MetaM (Array Name) :=
  foldRelevantConstants e #[] fun constName acc => pure (acc.push constName)
/-- The set of constants occurring in `e`, skipping instance arguments and proofs. -/
public def relevantConstantsAsSet (e : Expr) : MetaM NameSet :=
  foldRelevantConstants e ∅ fun constName acc => pure (acc.insert constName)
end Lean.Expr
open Lean Meta MVarId in
/--
Collect every constant used in the type of the goal `g` and in the types of its local hypotheses.
The computation runs in the local context of `g`.
-/
public def Lean.MVarId.getConstants (g : MVarId) : MetaM NameSet := withContext g do
  let mut c := (← g.getType).getUsedConstantsAsSet
  for t in (← getLocalHyps) do
    -- Accumulate by set union: the constants so far plus those of this hypothesis' type.
    c := c ∪ (← inferType t).getUsedConstantsAsSet
  return c
open Lean Meta MVarId in
/--
Collect the relevant constants (as computed by `Expr.relevantConstantsAsSet`) of the type of
the goal `g` and of the types of its local hypotheses.
The computation runs in the local context of `g`.
-/
public def Lean.MVarId.getRelevantConstants (g : MVarId) : MetaM NameSet := withContext g do
  let mut c ← (← g.getType).relevantConstantsAsSet
  for t in (← getLocalHyps) do
    -- Accumulate by set union: the constants so far plus those of this hypothesis' type.
    c := c ∪ (← (← inferType t).relevantConstantsAsSet)
  return c
@[expose] public section
namespace Lean.LibrarySuggestions
/--
A `Suggestion` is essentially just an identifier and a confidence score that the identifier is relevant.
If the premise selection request included information about the intended use (e.g. in the simplifier, in `grind`, etc.)
the score may be adjusted for that application.
-/
structure Suggestion where
  /-- The name of the suggested constant. -/
  name : Name
  /--
  The score of the suggestion, as a probability that this suggestion should be used.
  -/
  score : Float
  /--
  Optional flag associated with the suggestion, e.g. "←" or "=",
  if the premise selection algorithm is aware of the tactic consuming the results,
  and wants to suggest modifiers for this suggestion.
  E.g. this supports calling `simp` in the reverse direction,
  or telling `grind` or `aesop` to use the theorem in a particular way.
  Defaults to `none`.
  -/
  flag : Option String := none
/-- The options passed to a `Selector` along with the goal. Every field has a default value. -/
structure Config where
  /--
  The maximum number of suggestions to return (default: 100).
  -/
  maxSuggestions : Nat := 100
  /--
  The tactic that is calling the premise selection, e.g. `simp`, `grind`, or `aesop`.
  This may be used to adjust the score of the suggestions.
  -/
  caller : Option String := none
  /--
  A filter on suggestions; only suggestions for which the filter returns `true` should be returned.
  (It can be better to filter on the premise selection side, to ensure that enough suggestions are returned.)
  The default filter accepts every name.
  -/
  filter : Name → MetaM Bool := fun _ => pure true
  /--
  An optional arbitrary "hint" to the premise selection algorithm.
  There is no guarantee that the algorithm will make any use of the hint.
  Potential use cases include a natural language comment provided by the user
  (e.g. allowing use of the premise selector as a search engine)
  or including context from the current proof and/or file.
  We may later split these use cases into separate fields if necessary.
  -/
  hint : Option String := none
/-- A library suggestion engine: given a goal and a `Config`, produce an array of `Suggestion`s in `MetaM`. -/
abbrev Selector : Type := MVarId → Config → MetaM (Array Suggestion)
/--
Turn a function working on the pretty-printed goal (as a `String`)
into a `Selector`, which works on the goal's `MVarId`.
The configuration is passed through unchanged.
-/
def ppSelector (selector : String → Config → MetaM (Array Suggestion)) (g : MVarId) (c : Config) :
    MetaM (Array Suggestion) := do
  let goalFormat ← Meta.ppGoal g
  selector (toString goalFormat) c
namespace Selector
/--
Implement `Config.filter` as a post-filter.
Wrap a premise selection implementation with this function
if it does not handle the filter natively.
-/
def postFilter (selector : Selector) : Selector := fun g c => do
  -- Call the wrapped selector with an accept-everything filter; the real one is applied afterwards.
  let unfiltered ← selector g { c with filter := fun _ => pure true }
  unfiltered.filterM fun s => c.filter s.name
/--
Wrap a `Selector` so that the `maxSuggestions` field of `Config` is respected,
by truncating the results after the fact.
-/
def maxSuggestions (selector : Selector) : Selector := fun g c => do
  return (← selector g c).take c.maxSuggestions
/--
Combine two premise selectors, returning the best suggestions.
Both selectors are run with the same configuration. Results sharing a `(name, flag)` pair are
merged, keeping the higher score, and the merged results are sorted by decreasing score and
truncated to `c.maxSuggestions`.
-/
def combine (selector₁ selector₂ : Selector) : Selector := fun g c => do
  let fromFirst ← selector₁ g c
  let fromSecond ← selector₂ g c
  let noEntries : Std.HashMap (Name × Option String) Suggestion := {}
  -- For each `(name, flag)` pair keep one suggestion; a later one only replaces it
  -- when its score is strictly greater.
  let best := (fromFirst ++ fromSecond).foldl (init := noEntries) fun table s =>
    table.alter (s.name, s.flag) fun
      | none => some s
      | some prev => some (if s.score > prev.score then s else prev)
  -- Highest score first, then cut down to the requested number of suggestions.
  let ranked := best.valuesArray.qsort fun a b => a.score > b.score
  return ranked.take c.maxSuggestions
/--
Combine two premise selectors by alternating between their results (scores are ignored).
The parameter `ratio` (defaulting to 0.5) is the target fraction of results taken from `selector₁`
for as long as both selectors still have results available.
Once either runs out, the remaining results of the other are appended,
never exceeding `c.maxSuggestions` in total.
-/
def intersperse (selector₁ selector₂ : Selector) (ratio : Float := 0.5) : Selector := fun g c => do
  -- Split the overall budget between the two selectors according to `ratio`.
  let budget := c.maxSuggestions.toFloat
  let max₁ := (budget * ratio).toUInt32.toNat
  let max₂ := (budget * (1 - ratio)).toUInt32.toNat
  let fromFirst ← selector₁ g { c with maxSuggestions := max₁ }
  let fromSecond ← selector₂ g { c with maxSuggestions := max₂ }
  let mut merged := #[]
  let mut i₁ := 0
  let mut i₂ := 0
  -- Number of results taken so far from each selector, kept as floats for the ratio computation.
  let mut taken₁ := 0.0
  let mut taken₂ := 0.0
  -- Phase 1: both selectors still have results.
  while h : i₁ < fromFirst.size ∧ i₂ < fromSecond.size ∧ merged.size < c.maxSuggestions do
    -- Fraction of results taken from `selector₁` so far (0 when nothing has been taken yet).
    let total := taken₁ + taken₂
    let currentRatio := if total <= 0.0 then 0.0 else taken₁ / total
    if currentRatio < ratio then
      merged := merged.push fromFirst[i₁]
      i₁ := i₁ + 1
      taken₁ := taken₁ + 1
    else
      merged := merged.push fromSecond[i₂]
      i₂ := i₂ + 1
      taken₂ := taken₂ + 1
  -- Phase 2: whatever is left over from `selector₁`.
  while h : i₁ < fromFirst.size ∧ merged.size < c.maxSuggestions do
    merged := merged.push fromFirst[i₁]
    i₁ := i₁ + 1
  -- Phase 3: whatever is left over from `selector₂`.
  while h : i₂ < fromSecond.size ∧ merged.size < c.maxSuggestions do
    merged := merged.push fromSecond[i₂]
    i₂ := i₂ + 1
  return merged
end Selector
section DenyList
/--
Premises from a module whose name has one of the following components are not retrieved.
The list initially contains `"Lake"`, `"Lean"`, `"Internal"` and `"Tactic"`.
Use `run_cmd modifyEnv fun env => moduleDenyListExt.addEntry env module` to add a module to the deny list.
-/
builtin_initialize moduleDenyListExt : SimplePersistentEnvExtension String (List String) ←
  registerSimplePersistentEnvExtension {
    -- New entries are prepended to the list.
    addEntryFn := fun denied entry => entry :: denied
    addImportedFn :=
      mkStateFromImportedEntries (fun denied entry => entry :: denied) ["Lake", "Lean", "Internal", "Tactic"]
  }
/--
A premise whose name has one of the following components is not retrieved.
The list initially contains `"Lake"`, `"Lean"`, `"Internal"` and `"Tactic"`.
Use `run_cmd modifyEnv fun env => nameDenyListExt.addEntry env name` to add a name to the deny list.
-/
builtin_initialize nameDenyListExt : SimplePersistentEnvExtension String (List String) ←
  registerSimplePersistentEnvExtension {
    -- New entries are prepended to the list.
    addEntryFn := fun denied entry => entry :: denied
    addImportedFn :=
      mkStateFromImportedEntries (fun denied entry => entry :: denied) ["Lake", "Lean", "Internal", "Tactic"]
  }
/--
A premise whose `type.getForallBody.getAppFn` is a constant that has one of these prefixes is not retrieved.
The list initially contains `Lake` and `Lean`.
Use `run_cmd modifyEnv fun env => typePrefixDenyListExt.addEntry env typePrefix` to add a type prefix to the deny list.
-/
builtin_initialize typePrefixDenyListExt : SimplePersistentEnvExtension Name (List Name) ←
  registerSimplePersistentEnvExtension {
    -- New entries are prepended to the list.
    addEntryFn := fun denied entry => entry :: denied
    addImportedFn :=
      mkStateFromImportedEntries (fun denied entry => entry :: denied) [`Lake, `Lean]
  }
/--
Whether `moduleName` is on the module deny list, i.e. whether one of its string components
equals one of the entries of `moduleDenyListExt`.
-/
def isDeniedModule (env : Environment) (moduleName : Name) : Bool :=
  let denied := moduleDenyListExt.getState env
  denied.any fun entry => moduleName.anyS fun component => component == entry
/--
Whether the constant `name` must not be offered as a suggestion. This is the case when `name`

* is `sorryAx`, an internal detail, an instance, or deprecated;
* has a string component listed in `nameDenyListExt`;
* belongs to an imported module rejected by `isDeniedModule`;
* is not found in `env`; or
* has a type whose `getForallBody.getAppFn` is a constant with a prefix listed in
  `typePrefixDenyListExt`.
-/
def isDeniedPremise (env : Environment) (name : Name) : Bool := Id.run do
  -- Checks on the name and its attributes.
  if name == ``sorryAx || name.isInternalDetail then return true
  if Lean.Meta.isInstanceCore env name || Lean.Linter.isDeprecated env name then return true
  let deniedComponents := nameDenyListExt.getState env
  if deniedComponents.any (fun p => name.anyS (· == p)) then return true
  -- Check the module the constant was imported from, if any.
  if let some moduleIdx := env.getModuleIdxFor? name then
    if isDeniedModule env (env.header.moduleNames[moduleIdx.toNat]!) then return true
  -- Unknown constants are denied.
  let some ci := env.find? name | return true
  -- Finally, check the head constant of the conclusion of the type.
  match ci.type.getForallBody.getAppFn with
  | .const fnName _ => return (typePrefixDenyListExt.getState env).any (fun p => p.isPrefixOf fnName)
  | _ => return false
end DenyList
/--
The trivial premise selector: it ignores the goal and the configuration
and always returns the empty array of suggestions.
-/
def empty : Selector := fun _ _ => return #[]
/--
A random premise selection algorithm, provided solely for testing purposes.

The global generator `IO.stdGenRef` is reset to `gen`, then constants are drawn at random from
`env.const2ModIdx`, keeping those not rejected by `isDeniedPremise`, until
`cfg.maxSuggestions` suggestions have been collected. The same constant may be drawn more than once.
Every suggestion receives the same score, `1 / (number of candidate constants)`.
The goal and the other fields of the configuration are ignored.

Returns no suggestions when there are no candidate constants at all.
Note that the loop does not terminate if candidates exist but every one of them is denied.
-/
def random (gen : StdGen := ⟨37, 59⟩) : Selector := fun _ cfg => do
  IO.stdGenRef.set gen
  let env ← getEnv
  let max := cfg.maxSuggestions
  let consts := env.const2ModIdx.keysArray
  -- Without candidates there is nothing to draw from, and the loop below could never finish.
  if consts.isEmpty then return #[]
  let mut suggestions := #[]
  while suggestions.size < max do
    -- `IO.rand lo hi` includes both bounds, so the largest valid index is `size - 1`.
    let i ← IO.rand 0 (consts.size - 1)
    let name := consts[i]!
    unless isDeniedPremise env name do
      suggestions := suggestions.push { name := name, score := 1.0 / consts.size.toFloat }
  return suggestions
/--
A library suggestion engine that returns locally defined theorems (those in the current file).
Every returned suggestion has score `1.0`; at most `cfg.maxSuggestions` are returned,
and names rejected by `isDeniedPremise` are left out.
-/
def currentFile : Selector := fun _ cfg => do
  let env ← getEnv
  let limit := cfg.maxSuggestions
  let mut found := #[]
  -- `map₂` is the part of the staged constant map that holds the locally defined constants.
  for (declName, info) in env.constants.map₂.toList do
    if found.size >= limit then
      break
    if isDeniedPremise env declName then
      continue
    -- Only theorems are suggested.
    if let .thmInfo _ := info then
      found := found.push { name := declName, score := 1.0 }
  return found
/--
Environment extension holding the syntax of the registered library suggestions selector, if any.
Adding an entry replaces the stored syntax; on import, the last entry found across
the imported modules is the one that is kept.
-/
builtin_initialize librarySuggestionsExt : SimplePersistentEnvExtension Syntax (Option Syntax) ←
  registerSimplePersistentEnvExtension {
    -- The most recently added entry wins.
    addEntryFn := fun _ newest => some newest
    addImportedFn := fun importedEntries =>
      -- Walk through every module's entries in order, remembering only the last one seen.
      importedEntries.foldl (init := none) fun latest perModule =>
        perModule.foldl (init := latest) fun _ stx => some stx
  }
/--
Elaborate `stx` as a term of type `Selector` and evaluate the result to an actual `Selector`.
Shared by validation (`elabSetLibrarySuggestions`) and retrieval (`getSelector`).
-/
def elabAndEvalSelector (stx : Syntax) : MetaM Selector :=
  Elab.Term.TermElabM.run' do
    -- The expected type, as an expression: the constant `Selector`.
    let selectorType := Expr.const ``Selector []
    let selectorTerm ← Elab.Term.elabTermEnsuringType stx (some selectorType)
    unsafe Meta.evalExpr Selector selectorType selectorTerm
/--
Look up the currently registered library suggestions selector and evaluate its stored syntax.
The result is `none` when nothing is registered, and also when elaboration or evaluation
of the stored syntax throws.
The syntax is elaborated with `Term.elabTermEnsuringType`, so it may be an arbitrary term
(not just an identifier).
-/
def getSelector : MetaM (Option Selector) := do
  match librarySuggestionsExt.getState (← getEnv) with
  | none => return none
  | some stx =>
    try
      return some (← elabAndEvalSelector stx)
    catch _ =>
      -- Failures are deliberately reported as "no selector".
      return none
/--
Generate library suggestions for the metavariable `m` with configuration `c`,
using the currently registered library suggestions engine.
Throws an error if `getSelector` returns no engine.
-/
def select (m : MVarId) (c : Config := {}) : MetaM (Array Suggestion) := do
  match ← getSelector with
  | some selector => selector m c
  | none =>
    throwError "No library suggestions engine registered. \
      (Note that Lean does not provide a default library suggestions engine, \
      these must be provided by a downstream library, \
      and configured using `set_library_suggestions`.)"
/-!
Currently the registration mechanism is just global state.
This means that if multiple modules register library suggestions engines,
the behaviour will be dependent on the order of loading modules.
We should replace this with a mechanism so that
library suggestions engines are configured via options in the `lakefile`, and
commands are only used to override in a single declaration or file.
-/
open Lean Elab Command in
@[builtin_command_elab setLibrarySuggestionsCmd, inherit_doc setLibrarySuggestionsCmd]
def elabSetLibrarySuggestions : CommandElab
  | `(command| set_library_suggestions $selector) => do
    -- Warn when this module is not among the imported modules.
    let importedModules := (← getEnv).header.moduleNames
    if `Lean.LibrarySuggestions.Basic ∉ importedModules then
      logWarning "Add `import Lean.LibrarySuggestions.Basic` before using the `set_library_suggestions` command."
    -- Elaborate and evaluate once right away, so that mistakes are reported at the command itself;
    -- the evaluated selector is thrown away.
    liftTermElabM do
      try
        discard <| elabAndEvalSelector selector
      catch _ =>
        throwError "Failed to elaborate {selector} as a `MVarId → Config → MetaM (Array Suggestion)`."
    -- What gets stored is the syntax (not the evaluated `Selector`), so that it can be persisted.
    modifyEnv (librarySuggestionsExt.addEntry · selector)
  | _ => throwUnsupportedSyntax
open Lean.Elab.Tactic in
/--
Run `select` on the main goal and log the suggestions as an info message, one per line.
Scores are shown unless every score equals `1.0`; a suggestion's flag is shown when it has one.
The goal itself is returned unchanged.
-/
@[builtin_tactic Lean.Parser.Tactic.suggestions] def evalSuggestions : Tactic := fun _ =>
  liftMetaTactic1 fun goal => do
    let found ← select goal
    -- Scores are only worth displaying when at least one differs from 1.0.
    let showScores := !(found.all (·.score == 1.0))
    let mut report : MessageData := "Library suggestions:"
    for s in found do
      report := report ++ Format.line ++ " " ++ MessageData.ofConstName s.name
      if showScores then
        report := report ++ m!" (score: {s.score})"
      if let some flag := s.flag then
        report := report ++ m!" [{flag}]"
    logInfo report
    return goal
end Lean.LibrarySuggestions