From 7a8c2daf96ecdf6b41950bea40f5f1b4e31ecb73 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Sun, 26 Oct 2025 15:57:26 +1100 Subject: [PATCH] feat: sine qua non premise selection --- src/Lean/Data/SMap.lean | 2 + src/Lean/PremiseSelection.lean | 1 + src/Lean/PremiseSelection/Basic.lean | 109 +++++++++- src/Lean/PremiseSelection/MePo.lean | 26 +-- src/Lean/PremiseSelection/SineQuaNon.lean | 196 ++++++++++++++++++ .../PremiseSelection/SymbolFrequency.lean | 110 +++++----- tests/lean/run/premise_selection_mepo.lean | 19 +- .../run/premise_selection_sine_qua_non.lean | 46 ++++ tests/lean/run/symbolFrequency.lean | 2 +- .../symbolFrequency_foldRelevantConsts.lean | 6 +- 10 files changed, 412 insertions(+), 105 deletions(-) create mode 100644 src/Lean/PremiseSelection/SineQuaNon.lean create mode 100644 tests/lean/run/premise_selection_sine_qua_non.lean diff --git a/src/Lean/Data/SMap.lean b/src/Lean/Data/SMap.lean index 41c4d79333..fdc76ae5c8 100644 --- a/src/Lean/Data/SMap.lean +++ b/src/Lean/Data/SMap.lean @@ -32,7 +32,9 @@ namespace Lean -/ structure SMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where stage₁ : Bool := true + /-- Imported constants. -/ map₁ : Std.HashMap α β := {} + /-- Local constants defined in the current module. 
-/ map₂ : PHashMap α β := {} namespace SMap diff --git a/src/Lean/PremiseSelection.lean b/src/Lean/PremiseSelection.lean index 4c3804ba8e..bc6af08a22 100644 --- a/src/Lean/PremiseSelection.lean +++ b/src/Lean/PremiseSelection.lean @@ -9,3 +9,4 @@ prelude import Lean.PremiseSelection.Basic import Lean.PremiseSelection.SymbolFrequency import Lean.PremiseSelection.MePo +import Lean.PremiseSelection.SineQuaNon diff --git a/src/Lean/PremiseSelection/Basic.lean b/src/Lean/PremiseSelection/Basic.lean index 1ec615f646..5efeeef304 100644 --- a/src/Lean/PremiseSelection/Basic.lean +++ b/src/Lean/PremiseSelection/Basic.lean @@ -27,6 +27,90 @@ Lean does not provide a default premise selector, so this module is intended to with a downstream package which registers a premise selector. -/ +namespace Lean.Expr.FoldRelevantConstantsImpl + +open Lean Meta + +unsafe structure State where + visited : PtrSet Expr := mkPtrSet + visitedConsts : NameHashSet := {} + +unsafe abbrev FoldM := StateT State MetaM + +unsafe def fold {α : Type} (f : Name → α → MetaM α) (e : Expr) (acc : α) : FoldM α := + let rec visit (e : Expr) (acc : α) : FoldM α := do + if (← get).visited.contains e then + return acc + modify fun s => { s with visited := s.visited.insert e } + if ← isProof e then + -- Don't visit proofs. + return acc + match e with + | .forallE n d b bi => + let r ← visit d acc + withLocalDecl n bi d fun x => + visit (b.instantiate1 x) r + | .lam n d b bi => + let r ← visit d acc + withLocalDecl n bi d fun x => + visit (b.instantiate1 x) r + | .mdata _ b => visit b acc + | .letE n t v b nondep => + let r₁ ← visit t acc + let r₂ ← visit v r₁ + withLetDecl n t v (nondep := nondep) fun x => + visit (b.instantiate1 x) r₂ + | .app f a => + let fi ← getFunInfo f (some 1) + if fi.paramInfo[0]!.isInstImplicit then + -- Don't visit instance-implicit arguments. 
+ visit f acc + else + visit a (← visit f acc) + | .proj _ _ b => visit b acc + | .const c _ => + if (← get).visitedConsts.contains c then + return acc + else + modify fun s => { s with visitedConsts := s.visitedConsts.insert c } + if ← isInstance c then + return acc + else + f c acc + | _ => return acc + visit e acc + +@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := + (fold f e init).run' {} + +end FoldRelevantConstantsImpl + +/-- Apply `f` to every constant occurring in `e` once, skipping instance arguments and proofs. -/ +@[implemented_by FoldRelevantConstantsImpl.foldUnsafe] +public opaque foldRelevantConstants {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := pure init + +/-- Collect the constants occurring in `e` (once each), skipping instance arguments and proofs. -/ +public def relevantConstants (e : Expr) : MetaM (Array Name) := foldRelevantConstants e #[] (fun n ns => return ns.push n) + +/-- Collect the constants occurring in `e` (once each), skipping instance arguments and proofs. -/ +public def relevantConstantsAsSet (e : Expr) : MetaM NameSet := foldRelevantConstants e ∅ (fun n ns => return ns.insert n) + +end Lean.Expr + +open Lean Meta MVarId in +public def Lean.MVarId.getConstants (g : MVarId) : MetaM NameSet := withContext g do + let mut c := (← g.getType).getUsedConstantsAsSet + for t in (← getLocalHyps) do + c := c ∪ (← inferType t).getUsedConstantsAsSet + return c + +open Lean Meta MVarId in +public def Lean.MVarId.getRelevantConstants (g : MVarId) : MetaM NameSet := withContext g do + let mut c ← (← g.getType).relevantConstantsAsSet + for t in (← getLocalHyps) do + c := c ∪ (← (← inferType t).relevantConstantsAsSet) + return c + @[expose] public section namespace Lean.PremiseSelection @@ -130,25 +214,37 @@ end Selector section DenyList -/-- Premises from a module whose name has one of the following components are not retrieved. 
-/ +/-- +Premises from a module whose name has one of the following components are not retrieved. + +Use `run_cmd modifyEnv fun env => moduleDenyListExt.addEntry env module` to add a module to the deny list. +-/ builtin_initialize moduleDenyListExt : SimplePersistentEnvExtension String (List String) ← registerSimplePersistentEnvExtension { addEntryFn := (·.cons) - addImportedFn := mkStateFromImportedEntries (·.cons) [] + addImportedFn := mkStateFromImportedEntries (·.cons) ["Lake", "Lean", "Internal", "Tactic"] } -/-- A premise whose name has one of the following components is not retrieved. -/ +/-- +A premise whose name has one of the following components is not retrieved. + +Use `run_cmd modifyEnv fun env => nameDenyListExt.addEntry env name` to add a name to the deny list. +-/ builtin_initialize nameDenyListExt : SimplePersistentEnvExtension String (List String) ← registerSimplePersistentEnvExtension { addEntryFn := (·.cons) - addImportedFn := mkStateFromImportedEntries (·.cons) [] + addImportedFn := mkStateFromImportedEntries (·.cons) ["Lake", "Lean", "Internal", "Tactic"] } -/-- A premise whose `type.getForallBody.getAppFn` is a constant that has one of these prefixes is not retrieved. -/ +/-- +A premise whose `type.getForallBody.getAppFn` is a constant that has one of these prefixes is not retrieved. + +Use `run_cmd modifyEnv fun env => typePrefixDenyListExt.addEntry env typePrefix` to add a type prefix to the deny list. 
+-/ builtin_initialize typePrefixDenyListExt : SimplePersistentEnvExtension Name (List Name) ← registerSimplePersistentEnvExtension { addEntryFn := (·.cons) - addImportedFn := mkStateFromImportedEntries (·.cons) [] + addImportedFn := mkStateFromImportedEntries (·.cons) [`Lake, `Lean] } def isDeniedModule (env : Environment) (moduleName : Name) : Bool := @@ -157,6 +253,7 @@ def isDeniedModule (env : Environment) (moduleName : Name) : Bool := def isDeniedPremise (env : Environment) (name : Name) : Bool := Id.run do if name == ``sorryAx then return true if name.isInternalDetail then return true + if Lean.Meta.isInstanceCore env name then return true if (nameDenyListExt.getState env).any (fun p => name.anyS (· == p)) then return true if let some moduleIdx := env.getModuleIdxFor? name then let moduleName := env.header.moduleNames[moduleIdx.toNat]! diff --git a/src/Lean/PremiseSelection/MePo.lean b/src/Lean/PremiseSelection/MePo.lean index 108090789a..531afca6ab 100644 --- a/src/Lean/PremiseSelection/MePo.lean +++ b/src/Lean/PremiseSelection/MePo.lean @@ -7,6 +7,7 @@ module prelude public import Lean.PremiseSelection.Basic +import Lean.PremiseSelection.SymbolFrequency import Lean.Meta.Basic /-! @@ -24,14 +25,6 @@ namespace Lean.PremiseSelection.MePo builtin_initialize registerTraceClass `mepo -def symbolFrequency (env : Environment) : NameMap Nat := Id.run do - -- TODO: ideally this would use a precomputed frequency map, as this is too slow. - let mut map := {} - for (_, ci) in env.constants do - for n' in ci.type.getUsedConstantsAsSet do - map := map.alter n' fun i? 
=> some (i?.getD 0 + 1) - return map - def weightedScore (weight : Name → Float) (relevant candidate : NameSet) : Float := let S := candidate let R := relevant ∩ S @@ -71,26 +64,19 @@ def mepo (initialRelevant : NameSet) (score : NameSet → NameSet → Float) (ac p := p + (1 - p) / c return accepted.qsort (fun a b => a.score > b.score) -open Lean Meta MVarId in -def _root_.Lean.MVarId.getConstants (g : MVarId) : MetaM NameSet := withContext g do - let mut c := (← g.getType).getUsedConstantsAsSet - for t in (← getLocalHyps) do - c := c ∪ (← inferType t).getUsedConstantsAsSet - return c - end MePo open MePo -- The values of p := 0.6 and c := 2.4 are taken from the MePo paper, and need to be tuned. public def mepoSelector (useRarity : Bool) (p : Float := 0.6) (c : Float := 2.4) : Selector := fun g config => do - let constants ← g.getConstants + let constants ← g.getRelevantConstants let env ← getEnv - let score := if useRarity then - let frequency := symbolFrequency env - frequencyScore (frequency.getD · 0) + let score ← if useRarity then do + let frequency ← symbolFrequencyMap + pure <| frequencyScore (fun n => frequency.getD n 0) else - unweightedScore + pure <| unweightedScore let accept := fun ci => return !isDeniedPremise env ci.name let suggestions ← mepo constants score accept config.maxSuggestions p c let suggestions := suggestions diff --git a/src/Lean/PremiseSelection/SineQuaNon.lean b/src/Lean/PremiseSelection/SineQuaNon.lean new file mode 100644 index 0000000000..a8ca7521da --- /dev/null +++ b/src/Lean/PremiseSelection/SineQuaNon.lean @@ -0,0 +1,196 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Kim Morrison +-/ +module + +prelude +public import Lean.CoreM +public import Lean.Meta.Basic +import Lean.Meta.Instances +import Lean.PremiseSelection.SymbolFrequency +public import Lean.PremiseSelection.Basic + +/-! 
+# Sine Qua Non premise selection + +This is an implementation of the "Sine Qua Non" premise selection algorithm, from +"Sine Qua Non for Large Theory Reasoning" by Hodor and Voronkov. + +It needs to be tuned and evaluated for Lean. +-/ + +namespace Lean.PremiseSelection.SineQuaNon + +builtin_initialize registerTraceClass `sineQuaNon + +/-- +Constants which should not be used as triggers. + +Use `run_cmd modifyEnv fun env => triggerDenyListExt.addEntry env trigger` to add a trigger to the deny list. +-/ +builtin_initialize triggerDenyListExt : SimplePersistentEnvExtension Name NameSet ← + registerSimplePersistentEnvExtension { + addEntryFn := (·.insert) + addImportedFn := mkStateFromImportedEntries (·.insert) + (NameSet.ofList [`Eq, `BEq, `BEq.beq, `LE.le, `LT.lt, `GE.ge, `GT.gt, + `Bool.not, `Bool.and, `Bool.or, `Bool.xor, `Bool.true, `Bool.false, + `Not, `And, `Or, `Xor, + `ite, `dite, `Exists, `OfNat, `OfNat.ofNat, `SizeOf, `SizeOf.sizeOf]) + } + +/-- +Return the relevant constants (i.e. ignoring instances and proofs) +which appear in the type of `ci` and which are approximately least frequent in the library +(relative to other constants appearing in the type of `ci`). 
+-/ +def triggerSymbols (ci : ConstantInfo) (maxTolerance : Float := 3.0) : MetaM (Array (Name × Float)) := do + let denyList := triggerDenyListExt.getState (← getEnv) + let consts ← ci.type.relevantConstants + let frequencies ← consts.filterMapM fun n => do + if denyList.contains n then + return none + let f := (← symbolFrequency n) + (← localSymbolFrequency n) + return if f = 0 then + none + else + some (n, f.toFloat) + if frequencies.isEmpty then + return #[] + let minFrequency := frequencies.foldl (fun acc (_, f) => min acc f) (frequencies[0]!.2) + return frequencies.filterMap + (fun (n, f) => if f ≤ minFrequency * maxTolerance then some (n, f / minFrequency) else none) + +def _root_.List.orderedInsert (r : α → α → Bool := by exact (· ≤ ·)) (a : α) : List α → List α + | [] => [a] + | b :: l => if r a b then a :: b :: l else b :: orderedInsert r a l + +def insertTrigger (map : NameMap (List (Name × Float))) (trigger decl : Name) (tolerance : Float) : + NameMap (List (Name × Float)) := + map.insert trigger (map.getD trigger [] |>.orderedInsert (fun x y => x.2 ≤ y.2) (decl, tolerance)) + +def prepareTriggers (names : Array Name) (maxTolerance : Float := 3.0) : MetaM (NameMap (List (Name × Float))) := do + let mut map := {} + let env ← getEnv + let names := names.filter fun n => + !isDeniedPremise env n && Lean.wasOriginallyTheorem env n + for name in names do + let triggers ← triggerSymbols (← getConstInfo name) maxTolerance + for (trigger, tolerance) in triggers do + map := insertTrigger map trigger name tolerance + return map + +/-- +Combine two trigger maps, taking the sorted union of the triggered theorems for each symbol. +If one map is much larger than the other, it should be the first argument. +-/ +def combineTriggers (map₁ map₂ : NameMap (List (Name × Float))) : NameMap (List (Name × Float)) := Id.run do + let mut map := map₁ + for (trigger, decls₂) in map₂ do + map := match map₁.find? 
trigger with + | none => map.insert trigger decls₂ + | some decls₁ => map.insert trigger (decls₂.foldl (init := decls₁) (fun acc (decl, tolerance) => acc.orderedInsert (fun x y => x.2 ≤ y.2) (decl, tolerance))) + return map + +/-- +The state is just an array of arrays of maps. +We don't assemble these on import for efficiency reasons: most modules will not query this extension. + +Instead, we use an `IO.Ref` below so that within each module we can assemble the global `NameMap (List (Name × Float))` once. + +Since we never modify the extension state except on export, the `IO.Ref` does not need updating after first access. +-/ +builtin_initialize sineQuaNonExt : PersistentEnvExtension (NameMap (List (Name × Float))) Empty (Array (Array (NameMap (List (Name × Float))))) ← + registerPersistentEnvExtension { + name := `sineQuaNon + mkInitial := pure ∅ + addImportedFn := fun mapss _ => pure mapss + addEntryFn := nofun + -- TODO: it would be nice to avoid the `toArray` here, e.g. via iterators. + exportEntriesFnEx := fun env _ _ => env.unsafeRunMetaM do return #[← prepareTriggers (env.constants.map₂.toArray.map (·.1))] + statsFn := fun _ => "sine qua non premise selection extension" + } + +/-- A global `IO.Ref` containing the "sine qua non" triggers. This is initialized on first use. -/ +builtin_initialize sineQuaNonTriggersRef : IO.Ref (Option (NameMap (List (Name × Float)))) ← IO.mkRef none + +/-- The "sine qua non" triggers for imported constants. This is initialized on first use. 
-/ +def sineQuaNonTriggerMap : CoreM (NameMap (List (Name × Float))) := do + match ← sineQuaNonTriggersRef.get with + | some map => return map + | none => + let mapss := sineQuaNonExt.getState (← getEnv) + let map := mapss.foldl (init := {}) fun acc maps => maps.foldl (init := acc) fun acc map => combineTriggers acc map + sineQuaNonTriggersRef.set (some map) + return map + +public def sineQuaNonTheorems (trigger : Name) : CoreM (List (Name × Float)) := do + let map ← sineQuaNonTriggerMap + return map.getD trigger [] + +def sineQuaNonTriggersFor (decl : Name) : CoreM (List (Name × Float)) := do + let r ← sineQuaNonTriggerMap + return r.toList.filterMap fun (t, v) => + (v.find? fun (n, _) => n == decl) |>.map fun (_, f) => (t, f) + +local instance : Ord (Float × Name) where + compare x y := if x.1 < y.1 then .lt else if x.1 > y.1 then .gt else Name.cmp x.2 y.2 + +def frequencyScore (n : Name) (frequencyWeight : Float := 0.01) : MetaM Float := do + let f ← symbolFrequency n + return 1.0 + frequencyWeight * (f + 1).toFloat.log2 + +/-- +This isn't exactly what's described in the paper. + +We select theorems in a priority order, where the priority is `1.5 ^ (trigger depth) * Π (tolerances)`. + +The `1.5` factor could be tuned. 
+-/ +public partial def sineQuaNon (names : NameSet) (maxSuggestions : Nat) (depthFactor := 1.5) (frequencyWeight : Float := 0.01) : + MetaM (Array Suggestion) := do + let denyList := triggerDenyListExt.getState (← getEnv) + let targets := names \ denyList + let r ← go denyList targets + (Std.TreeSet.ofList (← targets.toList.mapM (fun n => return (← frequencyScore n frequencyWeight, n)))) #[] {} + return r.map (fun (n, f) => { name := n, score := 1 / f }) +where go (denyList : NameSet)(pastTriggers : NameSet) (triggerQueue : Std.TreeSet (Float × Name) compare) + (acceptedTheorems : Array (Name × Float)) (queuedTheorems : Std.TreeSet (Float × Name) compare) : MetaM (Array (Name × Float)) := do + if acceptedTheorems.size ≥ maxSuggestions then return acceptedTheorems else + -- Is there a companion to `min?` that gives the minimum element along with the rest of the set? + match triggerQueue.min? with + | some (tf, t) => do + let qf? := queuedTheorems.min?.map (·.1) + if match qf? with | none => true | some qf => tf < qf then + trace[sineQuaNon] m!"\ + acceptedTheorems: {acceptedTheorems}\n\ + pastTriggers: {pastTriggers.toList}\n\ + triggerQueue: {triggerQueue.toList}\n\ + queuedTheorems: {queuedTheorems.toList}" + let theorems ← sineQuaNonTheorems t + return ← go denyList pastTriggers (triggerQueue.erase (tf, t)) acceptedTheorems + (theorems.foldl (init := queuedTheorems) fun acc (p, pf) => acc.insert (pf * tf, p)) + | none => pure () + match queuedTheorems.min? 
with + | none => return acceptedTheorems + | some (qf, q) => + let ci ← getConstInfo q + let (pastTriggers', triggersQueue') ← (← ci.type.relevantConstants).foldlM (init := (pastTriggers, triggerQueue)) + fun ⟨pastTriggers', triggersQueue'⟩ n => do + if pastTriggers'.contains n || denyList.contains n then + pure ⟨pastTriggers', triggersQueue'⟩ + else + pure <| ⟨pastTriggers'.insert n, triggersQueue'.insert (qf * depthFactor * (← frequencyScore n frequencyWeight), n)⟩ + go denyList pastTriggers' triggersQueue' (acceptedTheorems.push (q, qf)) (queuedTheorems.erase (qf, q)) + +end SineQuaNon + +open SineQuaNon + +public def sineQuaNonSelector (depthFactor : Float := 1.5) : Selector := fun g config => do + let constants ← g.getRelevantConstants + let suggestions ← sineQuaNon constants config.maxSuggestions depthFactor + return suggestions.take config.maxSuggestions + +end Lean.PremiseSelection diff --git a/src/Lean/PremiseSelection/SymbolFrequency.lean b/src/Lean/PremiseSelection/SymbolFrequency.lean index 072779589e..3290c4e359 100644 --- a/src/Lean/PremiseSelection/SymbolFrequency.lean +++ b/src/Lean/PremiseSelection/SymbolFrequency.lean @@ -7,9 +7,11 @@ module prelude public import Lean.CoreM +public import Lean.Meta.Basic import Lean.Meta.InferType import Lean.Meta.FunInfo import Lean.AddDecl +import Lean.PremiseSelection.Basic /-! # Symbol frequency @@ -19,67 +21,55 @@ This module provides a persistent environment extension for computing the freque namespace Lean.PremiseSelection -namespace FoldRelevantConstsImpl +/-- +Collect the frequencies for constants occurring in declarations defined in the current module, +skipping instance arguments and proofs. 
+-/ +public def localSymbolFrequencyMap : MetaM (NameMap Nat) := do + let env := (← getEnv) + env.constants.map₂.foldlM (init := ∅) (fun acc m ci => do + if isDeniedPremise env m || !Lean.wasOriginallyTheorem env m then + pure acc + else + ci.type.foldRelevantConstants (init := acc) fun n' acc => return acc.alter n' fun i? => some (i?.getD 0 + 1)) -open Lean Meta +/-- +A global `IO.Ref` containing the local symbol frequency map. This is initialized on first use. +-/ +builtin_initialize localSymbolFrequencyMapRef : IO.Ref (Option (NameMap Nat)) ← IO.mkRef none -unsafe structure State where - visited : PtrSet Expr := mkPtrSet - visitedConsts : NameHashSet := {} +/-- +A cached version of the local symbol frequency map. -unsafe abbrev FoldM := StateT State MetaM +Note that the local symbol frequency map changes during elaboration of a file, +so if this is called at different times it may give the wrong result. +The intended use case is that it is only called by environment extension export functions, +i.e. after all declarations have been elaborated. +-/ +def cachedLocalSymbolFrequencyMap : MetaM (NameMap Nat) := do + match ← localSymbolFrequencyMapRef.get with + | some map => return map + | none => + let map ← localSymbolFrequencyMap + localSymbolFrequencyMapRef.set (some map) + return map -unsafe def fold {α : Type} (f : Name → α → MetaM α) (e : Expr) (acc : α) : FoldM α := - let rec visit (e : Expr) (acc : α) : FoldM α := do - if (← get).visited.contains e then - return acc - modify fun s => { s with visited := s.visited.insert e } - if ← isProof e then - -- Don't visit proofs. 
- return acc - match e with - | .forallE n d b bi => - let r ← visit d acc - withLocalDecl n bi d fun x => - visit (b.instantiate1 x) r - | .lam n d b bi => - let r ← visit d acc - withLocalDecl n bi d fun x => - visit (b.instantiate1 x) r - | .mdata _ b => visit b acc - | .letE n t v b nondep => - let r₁ ← visit t acc - let r₂ ← visit v r₁ - withLetDecl n t v (nondep := nondep) fun x => - visit (b.instantiate1 x) r₂ - | .app f a => - let fi ← getFunInfo f (some 1) - if fi.paramInfo[0]!.isInstImplicit then - -- Don't visit implicit arguments. - visit f acc - else - visit a (← visit f acc) - | .proj _ _ b => visit b acc - | .const c _ => - if (← get).visitedConsts.contains c then - return acc - else - modify fun s => { s with visitedConsts := s.visitedConsts.insert c }; - f c acc - | _ => return acc - visit e acc +/-- +Return the number of times a `Name` appears +in the signatures of (non-internal) theorems in the current module, +skipping instance arguments and proofs. -@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := - (fold f e init).run' {} +Note that this is cached, and so returns the frequency within theorems that had been elaborated +when the function is first called (with any argument). +-/ +public def localSymbolFrequency (n : Name) : MetaM Nat := do + return (← cachedLocalSymbolFrequencyMap) |>.getD n 0 -end FoldRelevantConstsImpl - -/-- Apply `f` to every constant occurring in `e` once, skipping instance arguments and proofs. -/ -@[implemented_by FoldRelevantConstsImpl.foldUnsafe] -opaque foldRelevantConsts {α : Type} (e : Expr) (init : α) (f : Name → α → MetaM α) : MetaM α := pure init - -/-- Helper function for running `MetaM` code during module export. We have nothing but an `Environment` available. -/ -private def runMetaM [Inhabited α] (env : Environment) (x : MetaM α) : α := +/-- +Helper function for running `MetaM` code during module export, when there is nothing but an `Environment` available. 
+Panics on errors. +-/ +public def _root_.Lean.Environment.unsafeRunMetaM [Inhabited α] (env : Environment) (x : MetaM α) : α := match unsafe unsafeEIO ((((withoutExporting x).run' {} {}).run' { fileName := "symbolFrequency", fileMap := default } { env })) with | Except.ok a => a | Except.error ex => panic! match unsafe unsafeIO ex.toMessageData.toString with @@ -100,13 +90,7 @@ builtin_initialize symbolFrequencyExt : PersistentEnvExtension (NameMap Nat) Emp mkInitial := pure ∅ addImportedFn := fun mapss _ => pure mapss addEntryFn := nofun - exportEntriesFnEx := fun env _ _ => runMetaM env do - let r ← env.constants.map₂.foldlM (init := (∅ : NameMap Nat)) (fun acc n ci => do - if n.isInternalDetail || !Lean.wasOriginallyTheorem env n then - pure acc - else - foldRelevantConsts ci.type (init := acc) fun n' acc => pure (acc.alter n' fun i? => some (i?.getD 0 + 1))) - return #[r] + exportEntriesFnEx := fun env _ _ => env.unsafeRunMetaM do return #[← cachedLocalSymbolFrequencyMap] statsFn := fun _ => "symbol frequency extension" } @@ -118,7 +102,7 @@ private local instance : Add (NameMap Nat) where add x y := y.foldl (init := x) fun x' n c => x'.insert n (x'.getD n 0 + c) /-- The symbol frequency map for imported constants. This is initialized on first use. -/ -def symbolFrequencyMap : CoreM (NameMap Nat) := do +public def symbolFrequencyMap : CoreM (NameMap Nat) := do match ← symbolFrequencyMapRef.get with | some map => return map | none => diff --git a/tests/lean/run/premise_selection_mepo.lean b/tests/lean/run/premise_selection_mepo.lean index 3d6abf3002..4eb8ed0cf7 100644 --- a/tests/lean/run/premise_selection_mepo.lean +++ b/tests/lean/run/premise_selection_mepo.lean @@ -6,21 +6,16 @@ example (a b : Int) : a + b = b + a := by suggest_premises sorry --- #time example (x y z : List Int) : x ++ y ++ z = x ++ (y ++ z) := by suggest_premises sorry --- `useRarity` is too slow in practice: it requires analyzing all the types in the environment. 
--- It would need to be cached. +set_premise_selector Lean.PremiseSelection.mepoSelector (useRarity := true) --- set_premise_selector Lean.PremiseSelection.mepoSelector (useRarity := true) +example (a b : Int) : a + b = b + a := by + suggest_premises + sorry --- example (a b : Int) : a + b = b + a := by --- suggest_premises --- sorry - --- #time --- example (x y z : List Int) : x ++ y ++ z = x ++ (y ++ z) := by --- suggest_premises --- sorry +example (x y z : List Int) : x ++ y ++ z = x ++ (y ++ z) := by + suggest_premises + sorry diff --git a/tests/lean/run/premise_selection_sine_qua_non.lean b/tests/lean/run/premise_selection_sine_qua_non.lean new file mode 100644 index 0000000000..6c45d7f271 --- /dev/null +++ b/tests/lean/run/premise_selection_sine_qua_non.lean @@ -0,0 +1,46 @@ +module +import all Lean.PremiseSelection.SineQuaNon +import Lean.Meta.Basic +import Std.Data.ExtHashMap + +open Lean PremiseSelection SineQuaNon + +set_premise_selector Lean.PremiseSelection.sineQuaNonSelector + +example {x : Dyadic} {prec : Int} : x.roundDown prec ≤ x := by + fail_if_success grind + grind +premises + +example {x : Dyadic} {prec : Int} : (x.roundUp prec).precision ≤ some prec := by + fail_if_success grind + grind +premises + +/-- info: [(HAppend.hAppend, 1.000000)] -/ +#guard_msgs in +run_meta do + let r ← triggerSymbols (← getConstInfo `List.append_assoc) + logInfo m!"{r}" + +/-- info: [(HAppend.hAppend, 1.000000)] -/ +#guard_msgs in +run_meta do + let r ← sineQuaNonTriggersFor `List.append_assoc + logInfo m!"{r}" + +/-- info: true -/ +#guard_msgs in +run_meta do + let r ← sineQuaNonTheorems `Std.ExtHashMap.erase + logInfo m!"{r.contains (`Std.ExtHashMap.getElem_erase, 1.00)}" + +/-- info: [Std.ExtHashMap.contains, Std.ExtHashMap.erase] -/ +#guard_msgs in +run_meta do + let r ← triggerSymbols (← getConstInfo `Std.ExtHashMap.contains_erase) + logInfo m!"{r.map (·.1)}" + +/-- info: [Std.ExtHashMap.contains, Std.ExtHashMap.erase] -/ +#guard_msgs in +run_meta do + let r ← 
sineQuaNonTriggersFor `Std.ExtHashMap.contains_erase + logInfo m!"{r.map (·.1)}" diff --git a/tests/lean/run/symbolFrequency.lean b/tests/lean/run/symbolFrequency.lean index a1eb800364..cf24567383 100644 --- a/tests/lean/run/symbolFrequency.lean +++ b/tests/lean/run/symbolFrequency.lean @@ -7,4 +7,4 @@ open Lean PremiseSelection #guard_msgs in run_meta do let f ← symbolFrequency `Nat - logInfo m!"{decide (10000 < f)}" + logInfo m!"{decide (5000 < f)}" diff --git a/tests/lean/run/symbolFrequency_foldRelevantConsts.lean b/tests/lean/run/symbolFrequency_foldRelevantConsts.lean index ce425c0aa9..c724b32285 100644 --- a/tests/lean/run/symbolFrequency_foldRelevantConsts.lean +++ b/tests/lean/run/symbolFrequency_foldRelevantConsts.lean @@ -9,19 +9,19 @@ open Lean PremiseSelection #guard_msgs in run_meta do let ci ← getConstInfo `List.append_assoc - let consts ← foldRelevantConsts ci.type (init := #[]) (fun n ns => return ns.push n) + let consts ← ci.type.foldRelevantConstants (init := #[]) (fun n ns => return ns.push n) logInfo m!"{consts}" /-- info: [List, Ne, HAppend.hAppend, List.nil, Eq, List.head] -/ #guard_msgs in run_meta do let ci ← getConstInfo `List.head_append_right - let consts ← foldRelevantConsts ci.type (init := #[]) (fun n ns => return ns.push n) + let consts ← ci.type.foldRelevantConstants (init := #[]) (fun n ns => return ns.push n) logInfo m!"{consts}" /-- info: [Array, Nat, LT.lt, Array.size, HAdd.hAdd, OfNat.ofNat, Array.swap, Not] -/ #guard_msgs in run_meta do let ci ← getConstInfo `Array.eraseIdx.induct - let consts ← foldRelevantConsts ci.type (init := #[]) (fun n ns => return ns.push n) + let consts ← ci.type.foldRelevantConstants (init := #[]) (fun n ns => return ns.push n) logInfo m!"{consts}"