feat: sine qua non premise selection
This commit is contained in:
parent
33e92677ba
commit
7a8c2daf96
10 changed files with 412 additions and 105 deletions
|
|
@ -32,7 +32,9 @@ namespace Lean
|
|||
-/
|
||||
structure SMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
|
||||
stage₁ : Bool := true
|
||||
/-- Imported constants. -/
|
||||
map₁ : Std.HashMap α β := {}
|
||||
/-- Local constants defined in the current module. -/
|
||||
map₂ : PHashMap α β := {}
|
||||
|
||||
namespace SMap
|
||||
|
|
|
|||
|
|
@ -9,3 +9,4 @@ prelude
|
|||
import Lean.PremiseSelection.Basic
|
||||
import Lean.PremiseSelection.SymbolFrequency
|
||||
import Lean.PremiseSelection.MePo
|
||||
import Lean.PremiseSelection.SineQuaNon
|
||||
|
|
|
|||
|
|
@ -27,6 +27,90 @@ Lean does not provide a default premise selector, so this module is intended to
|
|||
with a downstream package which registers a premise selector.
|
||||
-/
|
||||
|
||||
namespace Lean.Expr.FoldRelevantConstantsImpl
|
||||
|
||||
open Lean Meta
|
||||
|
||||
unsafe structure State where
|
||||
visited : PtrSet Expr := mkPtrSet
|
||||
visitedConsts : NameHashSet := {}
|
||||
|
||||
unsafe abbrev FoldM := StateT State MetaM
|
||||
|
||||
/--
Fold `f` over the constants occurring in `e`, starting from the accumulator `acc`.

Subexpressions are deduplicated by pointer (`State.visited`) and constants by name
(`State.visitedConsts`), so `f` is applied to each constant at most once.
Proofs, instance-implicit arguments, and constants which are instances are skipped.
-/
unsafe def fold {α : Type} (f : Name → α → MetaM α) (e : Expr) (acc : α) : FoldM α :=
  let rec visit (e : Expr) (acc : α) : FoldM α := do
    if (← get).visited.contains e then
      return acc
    modify fun s => { s with visited := s.visited.insert e }
    -- Proof terms are never traversed.
    if ← isProof e then
      return acc
    match e with
    | .forallE n d b bi | .lam n d b bi =>
      -- Binders: visit the domain, then the body with the bound variable instantiated.
      let r ← visit d acc
      withLocalDecl n bi d fun x =>
        visit (b.instantiate1 x) r
    | .letE n t v b nondep =>
      -- Visit the type, then the value, then the body under the let-declaration.
      let r ← visit v (← visit t acc)
      withLetDecl n t v (nondep := nondep) fun x =>
        visit (b.instantiate1 x) r
    | .app fn arg =>
      let info ← getFunInfo fn (some 1)
      let r ← visit fn acc
      -- Instance-implicit arguments are not traversed; only the function part is.
      if info.paramInfo[0]!.isInstImplicit then
        return r
      else
        visit arg r
    | .mdata _ b | .proj _ _ b => visit b acc
    | .const c _ =>
      if (← get).visitedConsts.contains c then
        return acc
      modify fun s => { s with visitedConsts := s.visitedConsts.insert c }
      -- Constants which are themselves instances are recorded as seen, but not passed to `f`.
      if ← isInstance c then
        return acc
      f c acc
    | _ => return acc
  visit e acc
|
||||
|
||||
@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α :=
|
||||
(fold f e init).run' {}
|
||||
|
||||
end FoldRelevantConstantsImpl
|
||||
|
||||
/-- Apply `f` to every constant occurring in `e` once, skipping instance arguments and proofs. -/
|
||||
@[implemented_by FoldRelevantConstantsImpl.foldUnsafe]
|
||||
public opaque foldRelevantConstants {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := pure init
|
||||
|
||||
/-- Collect, as an array, the constants occurring in `e` (once each), skipping instance arguments and proofs. -/
public def relevantConstants (e : Expr) : MetaM (Array Name) :=
  foldRelevantConstants e #[] fun c acc => pure (acc.push c)
|
||||
|
||||
/-- Collect, as a `NameSet`, the constants occurring in `e`, skipping instance arguments and proofs. -/
public def relevantConstantsAsSet (e : Expr) : MetaM NameSet :=
  foldRelevantConstants e (∅ : NameSet) fun c acc => pure (acc.insert c)
|
||||
|
||||
end Lean.Expr
|
||||
|
||||
open Lean Meta MVarId in
/--
Collect every constant used in the type of the goal `g` and in the types of its local hypotheses.
-/
public def Lean.MVarId.getConstants (g : MVarId) : MetaM NameSet := withContext g do
  let goalConsts := (← g.getType).getUsedConstantsAsSet
  (← getLocalHyps).foldlM (init := goalConsts) fun acc hyp => do
    return acc ∪ (← inferType hyp).getUsedConstantsAsSet
|
||||
|
||||
open Lean Meta MVarId in
/--
Collect the relevant constants (as computed by `Expr.relevantConstantsAsSet`)
in the type of the goal `g` and in the types of its local hypotheses.
-/
public def Lean.MVarId.getRelevantConstants (g : MVarId) : MetaM NameSet := withContext g do
  let goalConsts ← (← g.getType).relevantConstantsAsSet
  (← getLocalHyps).foldlM (init := goalConsts) fun acc hyp => do
    let hypConsts ← (← inferType hyp).relevantConstantsAsSet
    return acc ∪ hypConsts
|
||||
|
||||
@[expose] public section
|
||||
|
||||
namespace Lean.PremiseSelection
|
||||
|
|
@ -130,25 +214,37 @@ end Selector
|
|||
|
||||
section DenyList
|
||||
|
||||
/-- Premises from a module whose name has one of the following components are not retrieved. -/
|
||||
/--
|
||||
Premises from a module whose name has one of the following components are not retrieved.
|
||||
|
||||
Use `run_cmd modifyEnv fun env => moduleDenyListExt.addEntry env module` to add a module to the deny list.
|
||||
-/
|
||||
builtin_initialize moduleDenyListExt : SimplePersistentEnvExtension String (List String) ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
addEntryFn := (·.cons)
|
||||
addImportedFn := mkStateFromImportedEntries (·.cons) []
|
||||
addImportedFn := mkStateFromImportedEntries (·.cons) ["Lake", "Lean", "Internal", "Tactic"]
|
||||
}
|
||||
|
||||
/-- A premise whose name has one of the following components is not retrieved. -/
|
||||
/--
|
||||
A premise whose name has one of the following components is not retrieved.
|
||||
|
||||
Use `run_cmd modifyEnv fun env => nameDenyListExt.addEntry env name` to add a name to the deny list.
|
||||
-/
|
||||
builtin_initialize nameDenyListExt : SimplePersistentEnvExtension String (List String) ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
addEntryFn := (·.cons)
|
||||
addImportedFn := mkStateFromImportedEntries (·.cons) []
|
||||
addImportedFn := mkStateFromImportedEntries (·.cons) ["Lake", "Lean", "Internal", "Tactic"]
|
||||
}
|
||||
|
||||
/-- A premise whose `type.getForallBody.getAppFn` is a constant that has one of these prefixes is not retrieved. -/
|
||||
/--
|
||||
A premise whose `type.getForallBody.getAppFn` is a constant that has one of these prefixes is not retrieved.
|
||||
|
||||
Use `run_cmd modifyEnv fun env => typePrefixDenyListExt.addEntry env typePrefix` to add a type prefix to the deny list.
|
||||
-/
|
||||
builtin_initialize typePrefixDenyListExt : SimplePersistentEnvExtension Name (List Name) ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
addEntryFn := (·.cons)
|
||||
addImportedFn := mkStateFromImportedEntries (·.cons) []
|
||||
addImportedFn := mkStateFromImportedEntries (·.cons) [`Lake, `Lean]
|
||||
}
|
||||
|
||||
def isDeniedModule (env : Environment) (moduleName : Name) : Bool :=
|
||||
|
|
@ -157,6 +253,7 @@ def isDeniedModule (env : Environment) (moduleName : Name) : Bool :=
|
|||
def isDeniedPremise (env : Environment) (name : Name) : Bool := Id.run do
|
||||
if name == ``sorryAx then return true
|
||||
if name.isInternalDetail then return true
|
||||
if Lean.Meta.isInstanceCore env name then return true
|
||||
if (nameDenyListExt.getState env).any (fun p => name.anyS (· == p)) then return true
|
||||
if let some moduleIdx := env.getModuleIdxFor? name then
|
||||
let moduleName := env.header.moduleNames[moduleIdx.toNat]!
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ module
|
|||
|
||||
prelude
|
||||
public import Lean.PremiseSelection.Basic
|
||||
import Lean.PremiseSelection.SymbolFrequency
|
||||
import Lean.Meta.Basic
|
||||
|
||||
/-!
|
||||
|
|
@ -24,14 +25,6 @@ namespace Lean.PremiseSelection.MePo
|
|||
|
||||
builtin_initialize registerTraceClass `mepo
|
||||
|
||||
def symbolFrequency (env : Environment) : NameMap Nat := Id.run do
|
||||
-- TODO: ideally this would use a precomputed frequency map, as this is too slow.
|
||||
let mut map := {}
|
||||
for (_, ci) in env.constants do
|
||||
for n' in ci.type.getUsedConstantsAsSet do
|
||||
map := map.alter n' fun i? => some (i?.getD 0 + 1)
|
||||
return map
|
||||
|
||||
def weightedScore (weight : Name → Float) (relevant candidate : NameSet) : Float :=
|
||||
let S := candidate
|
||||
let R := relevant ∩ S
|
||||
|
|
@ -71,26 +64,19 @@ def mepo (initialRelevant : NameSet) (score : NameSet → NameSet → Float) (ac
|
|||
p := p + (1 - p) / c
|
||||
return accepted.qsort (fun a b => a.score > b.score)
|
||||
|
||||
open Lean Meta MVarId in
|
||||
def _root_.Lean.MVarId.getConstants (g : MVarId) : MetaM NameSet := withContext g do
|
||||
let mut c := (← g.getType).getUsedConstantsAsSet
|
||||
for t in (← getLocalHyps) do
|
||||
c := c ∪ (← inferType t).getUsedConstantsAsSet
|
||||
return c
|
||||
|
||||
end MePo
|
||||
|
||||
open MePo
|
||||
|
||||
-- The values of p := 0.6 and c := 2.4 are taken from the MePo paper, and need to be tuned.
|
||||
public def mepoSelector (useRarity : Bool) (p : Float := 0.6) (c : Float := 2.4) : Selector := fun g config => do
|
||||
let constants ← g.getConstants
|
||||
let constants ← g.getRelevantConstants
|
||||
let env ← getEnv
|
||||
let score := if useRarity then
|
||||
let frequency := symbolFrequency env
|
||||
frequencyScore (frequency.getD · 0)
|
||||
let score ← if useRarity then do
|
||||
let frequency ← symbolFrequencyMap
|
||||
pure <| frequencyScore (fun n => frequency.getD n 0)
|
||||
else
|
||||
unweightedScore
|
||||
pure <| unweightedScore
|
||||
let accept := fun ci => return !isDeniedPremise env ci.name
|
||||
let suggestions ← mepo constants score accept config.maxSuggestions p c
|
||||
let suggestions := suggestions
|
||||
|
|
|
|||
196
src/Lean/PremiseSelection/SineQuaNon.lean
Normal file
196
src/Lean/PremiseSelection/SineQuaNon.lean
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Kim Morrison
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.CoreM
|
||||
public import Lean.Meta.Basic
|
||||
import Lean.Meta.Instances
|
||||
import Lean.PremiseSelection.SymbolFrequency
|
||||
public import Lean.PremiseSelection.Basic
|
||||
|
||||
/-!
|
||||
# Sine Qua Non premise selection
|
||||
|
||||
This is an implementation of the "Sine Qua Non" premise selection algorithm, from
|
||||
"Sine Qua Non for Large Theory Reasoning" by Hoder and Voronkov.
|
||||
|
||||
It needs to be tuned and evaluated for Lean.
|
||||
-/
|
||||
|
||||
namespace Lean.PremiseSelection.SineQuaNon
|
||||
|
||||
builtin_initialize registerTraceClass `sineQuaNon
|
||||
|
||||
/--
|
||||
Constants which should not be used as triggers.
|
||||
|
||||
Use `run_cmd modifyEnv fun env => triggerDenyListExt.addEntry env trigger` to add a trigger to the deny list.
|
||||
-/
|
||||
builtin_initialize triggerDenyListExt : SimplePersistentEnvExtension Name NameSet ←
|
||||
registerSimplePersistentEnvExtension {
|
||||
addEntryFn := (·.insert)
|
||||
addImportedFn := mkStateFromImportedEntries (·.insert)
|
||||
(NameSet.ofList [`Eq, `BEq, `BEq.beq, `LE.le, `LT.lt, `GE.ge, `GT.gt,
|
||||
`Bool.not, `Bool.and, `Bool.or, `Bool.xor, `Bool.true, `Bool.false,
|
||||
`Not, `And, `Or, `Xor,
|
||||
`ite, `dite, `Exists, `OfNat, `OfNat.ofNat, `SizeOf, `SizeOf.sizeOf])
|
||||
}
|
||||
|
||||
/--
Compute the trigger symbols for the declaration `ci`.

These are the relevant constants (i.e. ignoring instances and proofs) in the type of `ci`
which are approximately least frequent in the library, relative to the other constants
appearing in that type. Constants on the trigger deny list, and constants whose combined
imported and local frequency is zero, are ignored.

A constant is kept when its frequency is at most `maxTolerance` times the minimum frequency;
it is returned paired with the ratio of its frequency to that minimum.
-/
def triggerSymbols (ci : ConstantInfo) (maxTolerance : Float := 3.0) : MetaM (Array (Name × Float)) := do
  let denied := triggerDenyListExt.getState (← getEnv)
  let consts ← ci.type.relevantConstants
  let mut frequencies : Array (Name × Float) := #[]
  for n in consts do
    if denied.contains n then
      continue
    let count := (← symbolFrequency n) + (← localSymbolFrequency n)
    if count ≠ 0 then
      frequencies := frequencies.push (n, count.toFloat)
  match frequencies[0]? with
  | none => return #[]
  | some (_, first) =>
    let minFrequency := frequencies.foldl (fun acc (_, f) => min acc f) first
    return frequencies.filterMap fun (n, f) =>
      if f ≤ minFrequency * maxTolerance then some (n, f / minFrequency) else none
|
||||
|
||||
/--
Insert `a` into a list immediately before the first element `b` for which `r a b` holds,
or at the end if there is no such element.
-/
def _root_.List.orderedInsert (r : α → α → Bool := by exact (· ≤ ·)) (a : α) : List α → List α
  | [] => [a]
  | b :: l =>
    match r a b with
    | true => a :: b :: l
    | false => b :: orderedInsert r a l
|
||||
|
||||
/--
Record that `trigger` triggers the declaration `decl` with the given `tolerance`,
inserting the pair into the list stored for `trigger` in order of tolerance.
-/
def insertTrigger (map : NameMap (List (Name × Float))) (trigger decl : Name) (tolerance : Float) :
    NameMap (List (Name × Float)) :=
  let existing := map.getD trigger []
  map.insert trigger (existing.orderedInsert (fun x y => x.2 ≤ y.2) (decl, tolerance))
|
||||
|
||||
/--
Build the trigger map for the declarations `names`.

Names which are denied premises, or which were not originally theorems, are ignored.
For every remaining declaration, each of its trigger symbols (see `triggerSymbols`)
is mapped to the declaration together with its tolerance.
-/
def prepareTriggers (names : Array Name) (maxTolerance : Float := 3.0) : MetaM (NameMap (List (Name × Float))) := do
  let env ← getEnv
  let mut result : NameMap (List (Name × Float)) := {}
  for name in names do
    if isDeniedPremise env name || !Lean.wasOriginallyTheorem env name then
      continue
    let ci ← getConstInfo name
    result := (← triggerSymbols ci maxTolerance).foldl (init := result)
      fun acc (trigger, tolerance) => insertTrigger acc trigger name tolerance
  return result
|
||||
|
||||
/--
Merge two trigger maps. For a symbol present in both, the result holds the sorted union of
the triggered theorems; for a symbol present in only one map, its list is kept unchanged.

The second map is the one that is traversed, so if one map is much larger than the other
it should be passed as the first argument.
-/
def combineTriggers (map₁ map₂ : NameMap (List (Name × Float))) : NameMap (List (Name × Float)) :=
  map₂.foldl (init := map₁) fun acc trigger decls₂ =>
    match map₁.find? trigger with
    | none => acc.insert trigger decls₂
    | some decls₁ =>
      let merged := decls₂.foldl (init := decls₁) fun sorted entry =>
        sorted.orderedInsert (fun x y => x.2 ≤ y.2) entry
      acc.insert trigger merged
|
||||
|
||||
/--
The state is just an array of arrays of maps.
We don't assemble these on import for efficiency reasons: most modules will not query this extension.

Instead, we use an `IO.Ref` below so that within each module we can assemble the global `NameMap (List (Name × Float))` once.

Since we never modify the extension state except on export, the `IO.Ref` does not need updating after first access.
-/
builtin_initialize sineQuaNonExt : PersistentEnvExtension (NameMap (List (Name × Float))) Empty (Array (Array (NameMap (List (Name × Float))))) ←
  registerPersistentEnvExtension {
    -- Spelled consistently with the `sineQuaNon` trace class and the name of this declaration.
    name := `sineQuaNon
    mkInitial := pure ∅
    -- Keep the per-module maps as imported; they are only merged on first query.
    addImportedFn := fun mapss _ => pure mapss
    -- The entry type is `Empty`, so there are never any entries to add.
    addEntryFn := nofun
    -- On export, compute the trigger map for the constants defined in the current module.
    -- TODO: it would be nice to avoid the `toArray` here, e.g. via iterators.
    exportEntriesFnEx := fun env _ _ => env.unsafeRunMetaM do return #[← prepareTriggers (env.constants.map₂.toArray.map (·.1))]
    statsFn := fun _ => "sine qua non premise selection extension"
  }
|
||||
|
||||
/-- A global `IO.Ref` containing the "sine qua non" triggers. This is initialized on first use. -/
|
||||
builtin_initialize sineQuaNonTriggersRef : IO.Ref (Option (NameMap (List (Name × Float)))) ← IO.mkRef none
|
||||
|
||||
/--
The "sine qua non" triggers for imported constants.

The per-module maps stored in `sineQuaNonExt` are merged on first use,
and the result is cached in `sineQuaNonTriggersRef` for subsequent calls.
-/
def sineQuaNonTriggerMap : CoreM (NameMap (List (Name × Float))) := do
  if let some cached ← sineQuaNonTriggersRef.get then
    return cached
  let mapss := sineQuaNonExt.getState (← getEnv)
  let mut merged : NameMap (List (Name × Float)) := {}
  for maps in mapss do
    for map in maps do
      merged := combineTriggers merged map
  sineQuaNonTriggersRef.set (some merged)
  return merged
|
||||
|
||||
/--
The imported theorems triggered by the symbol `trigger`, paired with their tolerances.
Returns the empty list if `trigger` does not trigger any theorem.
-/
public def sineQuaNonTheorems (trigger : Name) : CoreM (List (Name × Float)) := do
  return (← sineQuaNonTriggerMap).getD trigger []
|
||||
|
||||
/--
The symbols which trigger the imported declaration `decl`, paired with the corresponding tolerances.

This scans the whole trigger map, so it is only intended for inspection and testing.
-/
def sineQuaNonTriggersFor (decl : Name) : CoreM (List (Name × Float)) := do
  let triggers ← sineQuaNonTriggerMap
  return triggers.toList.filterMap fun (trigger, decls) =>
    match decls.find? (fun (n, _) => n == decl) with
    | some (_, tolerance) => some (trigger, tolerance)
    | none => none
|
||||
|
||||
local instance : Ord (Float × Name) where
|
||||
compare x y := if x.1 < y.1 then .lt else if x.1 > y.1 then .gt else Name.cmp x.2 y.2
|
||||
|
||||
/--
A multiplicative penalty for using `n` as a trigger: `1 + frequencyWeight * log₂ (f + 1)`,
where `f` is the frequency of `n` reported by `symbolFrequency`.
-/
def frequencyScore (n : Name) (frequencyWeight : Float := 0.01) : MetaM Float := do
  let occurrences ← symbolFrequency n
  let logFrequency := (occurrences + 1).toFloat.log2
  return 1.0 + frequencyWeight * logFrequency
|
||||
|
||||
/--
Select up to `maxSuggestions` premises relevant to the constants `names`.

This isn't exactly what's described in the paper.

We select theorems in a priority order, where the priority is
`depthFactor ^ (trigger depth) * Π (tolerances)`, additionally multiplied by the
`frequencyScore` of each trigger along the way. Smaller priorities are selected first,
and the `score` of a returned suggestion is the reciprocal of its priority.

* `depthFactor` (default `1.5`) penalises theorems reached via longer chains of triggers.
  It could be tuned.
* `frequencyWeight` (default `0.01`) is passed to `frequencyScore` for every trigger,
  including the initial triggers coming from `names`.

Constants on the trigger deny list are never used as triggers.
-/
public partial def sineQuaNon (names : NameSet) (maxSuggestions : Nat) (depthFactor := 1.5) (frequencyWeight : Float := 0.01) :
    MetaM (Array Suggestion) := do
  let denyList := triggerDenyListExt.getState (← getEnv)
  let targets := names \ denyList
  -- Seed the trigger queue with the target constants, scored with the caller's `frequencyWeight`
  -- (the same weight that is used for triggers discovered later in `go`).
  let initialQueue := Std.TreeSet.ofList
    (← targets.toList.mapM (fun n => return (← frequencyScore n frequencyWeight, n)))
  let r ← go denyList targets initialQueue #[] {}
  return r.map (fun (n, f) => { name := n, score := 1 / f })
where
  /--
  The main loop. `pastTriggers` are the symbols which have already been queued as triggers,
  `triggerQueue` and `queuedTheorems` are priority queues (smallest priority first),
  and `acceptedTheorems` are the theorems selected so far, with their priorities.
  -/
  go (denyList : NameSet) (pastTriggers : NameSet) (triggerQueue : Std.TreeSet (Float × Name) compare)
      (acceptedTheorems : Array (Name × Float)) (queuedTheorems : Std.TreeSet (Float × Name) compare) :
      MetaM (Array (Name × Float)) := do
    if acceptedTheorems.size ≥ maxSuggestions then
      return acceptedTheorems
    -- Is there a companion to `min?` that gives the minimum element along with the rest of the set?
    match triggerQueue.min? with
    | some (tf, t) => do
      let qf? := queuedTheorems.min?.map (·.1)
      -- If the best trigger has a smaller priority than the best queued theorem
      -- (or there are no queued theorems), fire it and queue everything it triggers.
      if match qf? with | none => true | some qf => tf < qf then
        trace[sineQuaNon] m!"\
          acceptedTheorems: {acceptedTheorems}\n\
          pastTriggers: {pastTriggers.toList}\n\
          triggerQueue: {triggerQueue.toList}\n\
          queuedTheorems: {queuedTheorems.toList}"
        let theorems ← sineQuaNonTheorems t
        return ← go denyList pastTriggers (triggerQueue.erase (tf, t)) acceptedTheorems
          (theorems.foldl (init := queuedTheorems) fun acc (p, pf) => acc.insert (pf * tf, p))
    | none => pure ()
    -- Otherwise accept the best queued theorem, and queue the new triggers appearing in its type.
    match queuedTheorems.min? with
    | none => return acceptedTheorems
    | some (qf, q) =>
      let ci ← getConstInfo q
      let (pastTriggers', triggersQueue') ← (← ci.type.relevantConstants).foldlM (init := (pastTriggers, triggerQueue))
        fun ⟨pastTriggers', triggersQueue'⟩ n => do
          if pastTriggers'.contains n || denyList.contains n then
            pure ⟨pastTriggers', triggersQueue'⟩
          else
            pure <| ⟨pastTriggers'.insert n, triggersQueue'.insert (qf * depthFactor * (← frequencyScore n frequencyWeight), n)⟩
      go denyList pastTriggers' triggersQueue' (acceptedTheorems.push (q, qf)) (queuedTheorems.erase (qf, q))
|
||||
|
||||
end SineQuaNon
|
||||
|
||||
open SineQuaNon
|
||||
|
||||
/--
A premise selector using the "sine qua non" algorithm.

The relevant constants of the goal and its local hypotheses are used as the initial triggers.
`depthFactor` and `frequencyWeight` are passed on to `sineQuaNon`; both are optional,
so existing uses such as `sineQuaNonSelector` or `sineQuaNonSelector (depthFactor := 2.0)`
are unaffected.
-/
public def sineQuaNonSelector (depthFactor : Float := 1.5) (frequencyWeight : Float := 0.01) : Selector := fun g config => do
  let constants ← g.getRelevantConstants
  let suggestions ← sineQuaNon constants config.maxSuggestions depthFactor frequencyWeight
  -- `sineQuaNon` stops once it has accepted `maxSuggestions` theorems; `take` is kept as a safeguard.
  return suggestions.take config.maxSuggestions
|
||||
|
||||
end Lean.PremiseSelection
|
||||
|
|
@ -7,9 +7,11 @@ module
|
|||
|
||||
prelude
|
||||
public import Lean.CoreM
|
||||
public import Lean.Meta.Basic
|
||||
import Lean.Meta.InferType
|
||||
import Lean.Meta.FunInfo
|
||||
import Lean.AddDecl
|
||||
import Lean.PremiseSelection.Basic
|
||||
|
||||
/-!
|
||||
# Symbol frequency
|
||||
|
|
@ -19,67 +21,55 @@ This module provides a persistent environment extension for computing the freque
|
|||
|
||||
namespace Lean.PremiseSelection
|
||||
|
||||
namespace FoldRelevantConstsImpl
|
||||
/--
|
||||
Collect the frequencies for constants occurring in declarations defined in the current module,
|
||||
skipping instance arguments and proofs.
|
||||
-/
|
||||
public def localSymbolFrequencyMap : MetaM (NameMap Nat) := do
|
||||
let env := (← getEnv)
|
||||
env.constants.map₂.foldlM (init := ∅) (fun acc m ci => do
|
||||
if isDeniedPremise env m || !Lean.wasOriginallyTheorem env m then
|
||||
pure acc
|
||||
else
|
||||
ci.type.foldRelevantConstants (init := acc) fun n' acc => return acc.alter n' fun i? => some (i?.getD 0 + 1))
|
||||
|
||||
open Lean Meta
|
||||
/--
|
||||
A global `IO.Ref` containing the local symbol frequency map. This is initialized on first use.
|
||||
-/
|
||||
builtin_initialize localSymbolFrequencyMapRef : IO.Ref (Option (NameMap Nat)) ← IO.mkRef none
|
||||
|
||||
unsafe structure State where
|
||||
visited : PtrSet Expr := mkPtrSet
|
||||
visitedConsts : NameHashSet := {}
|
||||
/--
|
||||
A cached version of the local symbol frequency map.
|
||||
|
||||
unsafe abbrev FoldM := StateT State MetaM
|
||||
Note that the local symbol frequency map changes during elaboration of a file,
|
||||
so if this is called at different times it may give the wrong result.
|
||||
The intended use case is that it is only called by environment extension export functions,
|
||||
i.e. after all declarations have been elaborated.
|
||||
-/
|
||||
def cachedLocalSymbolFrequencyMap : MetaM (NameMap Nat) := do
|
||||
match ← localSymbolFrequencyMapRef.get with
|
||||
| some map => return map
|
||||
| none =>
|
||||
let map ← localSymbolFrequencyMap
|
||||
localSymbolFrequencyMapRef.set (some map)
|
||||
return map
|
||||
|
||||
unsafe def fold {α : Type} (f : Name → α → MetaM α) (e : Expr) (acc : α) : FoldM α :=
|
||||
let rec visit (e : Expr) (acc : α) : FoldM α := do
|
||||
if (← get).visited.contains e then
|
||||
return acc
|
||||
modify fun s => { s with visited := s.visited.insert e }
|
||||
if ← isProof e then
|
||||
-- Don't visit proofs.
|
||||
return acc
|
||||
match e with
|
||||
| .forallE n d b bi =>
|
||||
let r ← visit d acc
|
||||
withLocalDecl n bi d fun x =>
|
||||
visit (b.instantiate1 x) r
|
||||
| .lam n d b bi =>
|
||||
let r ← visit d acc
|
||||
withLocalDecl n bi d fun x =>
|
||||
visit (b.instantiate1 x) r
|
||||
| .mdata _ b => visit b acc
|
||||
| .letE n t v b nondep =>
|
||||
let r₁ ← visit t acc
|
||||
let r₂ ← visit v r₁
|
||||
withLetDecl n t v (nondep := nondep) fun x =>
|
||||
visit (b.instantiate1 x) r₂
|
||||
| .app f a =>
|
||||
let fi ← getFunInfo f (some 1)
|
||||
if fi.paramInfo[0]!.isInstImplicit then
|
||||
-- Don't visit instance-implicit arguments.
|
||||
visit f acc
|
||||
else
|
||||
visit a (← visit f acc)
|
||||
| .proj _ _ b => visit b acc
|
||||
| .const c _ =>
|
||||
if (← get).visitedConsts.contains c then
|
||||
return acc
|
||||
else
|
||||
modify fun s => { s with visitedConsts := s.visitedConsts.insert c };
|
||||
f c acc
|
||||
| _ => return acc
|
||||
visit e acc
|
||||
/--
|
||||
Return the number of times a `Name` appears
|
||||
in the signatures of (non-internal) theorems in the current module,
|
||||
skipping instance arguments and proofs.
|
||||
|
||||
@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α :=
|
||||
(fold f e init).run' {}
|
||||
Note that this is cached, and so returns the frequency within theorems that had been elaborated
|
||||
when the function is first called (with any argument).
|
||||
-/
|
||||
public def localSymbolFrequency (n : Name) : MetaM Nat := do
|
||||
return (← cachedLocalSymbolFrequencyMap) |>.getD n 0
|
||||
|
||||
end FoldRelevantConstsImpl
|
||||
|
||||
/-- Apply `f` to every constant occurring in `e` once, skipping instance arguments and proofs. -/
|
||||
@[implemented_by FoldRelevantConstsImpl.foldUnsafe]
|
||||
opaque foldRelevantConsts {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := pure init
|
||||
|
||||
/-- Helper function for running `MetaM` code during module export. We have nothing but an `Environment` available. -/
|
||||
private def runMetaM [Inhabited α] (env : Environment) (x : MetaM α) : α :=
|
||||
/--
|
||||
Helper function for running `MetaM` code during module export, when there is nothing but an `Environment` available.
|
||||
Panics on errors.
|
||||
-/
|
||||
public def _root_.Lean.Environment.unsafeRunMetaM [Inhabited α] (env : Environment) (x : MetaM α) : α :=
|
||||
match unsafe unsafeEIO ((((withoutExporting x).run' {} {}).run' { fileName := "symbolFrequency", fileMap := default } { env })) with
|
||||
| Except.ok a => a
|
||||
| Except.error ex => panic! match unsafe unsafeIO ex.toMessageData.toString with
|
||||
|
|
@ -100,13 +90,7 @@ builtin_initialize symbolFrequencyExt : PersistentEnvExtension (NameMap Nat) Emp
|
|||
mkInitial := pure ∅
|
||||
addImportedFn := fun mapss _ => pure mapss
|
||||
addEntryFn := nofun
|
||||
exportEntriesFnEx := fun env _ _ => runMetaM env do
|
||||
let r ← env.constants.map₂.foldlM (init := (∅ : NameMap Nat)) (fun acc n ci => do
|
||||
if n.isInternalDetail || !Lean.wasOriginallyTheorem env n then
|
||||
pure acc
|
||||
else
|
||||
foldRelevantConsts ci.type (init := acc) fun n' acc => pure (acc.alter n' fun i? => some (i?.getD 0 + 1)))
|
||||
return #[r]
|
||||
exportEntriesFnEx := fun env _ _ => env.unsafeRunMetaM do return #[← cachedLocalSymbolFrequencyMap]
|
||||
statsFn := fun _ => "symbol frequency extension"
|
||||
}
|
||||
|
||||
|
|
@ -118,7 +102,7 @@ private local instance : Add (NameMap Nat) where
|
|||
add x y := y.foldl (init := x) fun x' n c => x'.insert n (x'.getD n 0 + c)
|
||||
|
||||
/-- The symbol frequency map for imported constants. This is initialized on first use. -/
|
||||
def symbolFrequencyMap : CoreM (NameMap Nat) := do
|
||||
public def symbolFrequencyMap : CoreM (NameMap Nat) := do
|
||||
match ← symbolFrequencyMapRef.get with
|
||||
| some map => return map
|
||||
| none =>
|
||||
|
|
|
|||
|
|
@ -6,21 +6,16 @@ example (a b : Int) : a + b = b + a := by
|
|||
suggest_premises
|
||||
sorry
|
||||
|
||||
-- #time
|
||||
example (x y z : List Int) : x ++ y ++ z = x ++ (y ++ z) := by
|
||||
suggest_premises
|
||||
sorry
|
||||
|
||||
-- `useRarity` is too slow in practice: it requires analyzing all the types in the environment.
|
||||
-- It would need to be cached.
|
||||
set_premise_selector Lean.PremiseSelection.mepoSelector (useRarity := true)
|
||||
|
||||
-- set_premise_selector Lean.PremiseSelection.mepoSelector (useRarity := true)
|
||||
example (a b : Int) : a + b = b + a := by
|
||||
suggest_premises
|
||||
sorry
|
||||
|
||||
-- example (a b : Int) : a + b = b + a := by
|
||||
-- suggest_premises
|
||||
-- sorry
|
||||
|
||||
-- #time
|
||||
-- example (x y z : List Int) : x ++ y ++ z = x ++ (y ++ z) := by
|
||||
-- suggest_premises
|
||||
-- sorry
|
||||
example (x y z : List Int) : x ++ y ++ z = x ++ (y ++ z) := by
|
||||
suggest_premises
|
||||
sorry
|
||||
|
|
|
|||
46
tests/lean/run/premise_selection_sine_qua_non.lean
Normal file
46
tests/lean/run/premise_selection_sine_qua_non.lean
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
module
|
||||
import all Lean.PremiseSelection.SineQuaNon
|
||||
import Lean.Meta.Basic
|
||||
import Std.Data.ExtHashMap
|
||||
|
||||
open Lean PremiseSelection SineQuaNon
|
||||
|
||||
set_premise_selector Lean.PremiseSelection.sineQuaNonSelector
|
||||
|
||||
example {x : Dyadic} {prec : Int} : x.roundDown prec ≤ x := by
|
||||
fail_if_success grind
|
||||
grind +premises
|
||||
|
||||
example {x : Dyadic} {prec : Int} : (x.roundUp prec).precision ≤ some prec := by
|
||||
fail_if_success grind
|
||||
grind +premises
|
||||
|
||||
/-- info: [(HAppend.hAppend, 1.000000)] -/
|
||||
#guard_msgs in
|
||||
run_meta do
|
||||
let r ← triggerSymbols (← getConstInfo `List.append_assoc)
|
||||
logInfo m!"{r}"
|
||||
|
||||
/-- info: [(HAppend.hAppend, 1.000000)] -/
|
||||
#guard_msgs in
|
||||
run_meta do
|
||||
let r ← sineQuaNonTriggersFor `List.append_assoc
|
||||
logInfo m!"{r}"
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
run_meta do
|
||||
let r ← sineQuaNonTheorems `Std.ExtHashMap.erase
|
||||
logInfo m!"{r.contains (`Std.ExtHashMap.getElem_erase, 1.00)}"
|
||||
|
||||
/-- info: [Std.ExtHashMap.contains, Std.ExtHashMap.erase] -/
|
||||
#guard_msgs in
|
||||
run_meta do
|
||||
let r ← triggerSymbols (← getConstInfo `Std.ExtHashMap.contains_erase)
|
||||
logInfo m!"{r.map (·.1)}"
|
||||
|
||||
/-- info: [Std.ExtHashMap.contains, Std.ExtHashMap.erase] -/
|
||||
#guard_msgs in
|
||||
run_meta do
|
||||
let r ← sineQuaNonTriggersFor `Std.ExtHashMap.contains_erase
|
||||
logInfo m!"{r.map (·.1)}"
|
||||
|
|
@ -7,4 +7,4 @@ open Lean PremiseSelection
|
|||
#guard_msgs in
|
||||
run_meta do
|
||||
let f ← symbolFrequency `Nat
|
||||
logInfo m!"{decide (10000 < f)}"
|
||||
logInfo m!"{decide (5000 < f)}"
|
||||
|
|
|
|||
|
|
@ -9,19 +9,19 @@ open Lean PremiseSelection
|
|||
#guard_msgs in
|
||||
run_meta do
|
||||
let ci ← getConstInfo `List.append_assoc
|
||||
let consts ← foldRelevantConsts ci.type (init := #[]) (fun n ns => return ns.push n)
|
||||
let consts ← ci.type.foldRelevantConstants (init := #[]) (fun n ns => return ns.push n)
|
||||
logInfo m!"{consts}"
|
||||
|
||||
/-- info: [List, Ne, HAppend.hAppend, List.nil, Eq, List.head] -/
|
||||
#guard_msgs in
|
||||
run_meta do
|
||||
let ci ← getConstInfo `List.head_append_right
|
||||
let consts ← foldRelevantConsts ci.type (init := #[]) (fun n ns => return ns.push n)
|
||||
let consts ← ci.type.foldRelevantConstants (init := #[]) (fun n ns => return ns.push n)
|
||||
logInfo m!"{consts}"
|
||||
|
||||
/-- info: [Array, Nat, LT.lt, Array.size, HAdd.hAdd, OfNat.ofNat, Array.swap, Not] -/
|
||||
#guard_msgs in
|
||||
run_meta do
|
||||
let ci ← getConstInfo `Array.eraseIdx.induct
|
||||
let consts ← foldRelevantConsts ci.type (init := #[]) (fun n ns => return ns.push n)
|
||||
let consts ← ci.type.foldRelevantConstants (init := #[]) (fun n ns => return ns.push n)
|
||||
logInfo m!"{consts}"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue