fix: method spec theorems to be private when appropriate (#10406)

This PR improves upon #10302 to properly make the method spec theorems
private if the implementation function is not exposed.
This commit is contained in:
Joachim Breitner 2025-09-16 11:20:04 +02:00 committed by GitHub
parent a1cd945e82
commit ca10fd7c4f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 111 additions and 39 deletions

View file

@ -32,6 +32,8 @@ structure MethodSpecTheorem where
structure MethodSpecsInfo where
clsName : Name
/-- Whether the specs should be public or private -/
privateSpecs : Bool
/-- Array mapping field names to implementation functions. -/
fieldImpls : Array (Name × Name)
/-- rewrite rules to apply -/
@ -59,6 +61,7 @@ def getMethodSpecsInfo (instName : Name) : MetaM MethodSpecsInfo := do
let mut fieldImpls := #[]
let mut thms := #[]
let mut privateSpecs := false
for field in structInfo.fieldNames, arg in body.getAppArgsN structInfo.fieldNames.size do
if (← isProof arg) then continue
@ -72,6 +75,10 @@ def getMethodSpecsInfo (instName : Name) : MetaM MethodSpecsInfo := do
the instances' universe parameters\n {instInfo.levelParams.map mkLevelParam}"
unless xs == ys do
throwError "function `{f}` does not take its arguments in the same order as the instance"
let implName := f.constName!
let isExposed := !(← getEnv).header.isModule || (((← getEnv).setExporting true).find? implName).elim false (·.hasValue)
unless isExposed do
privateSpecs := true
-- Construct the replacement theorems
let some fieldInfo := getFieldInfo? (← getEnv) clsName field
| throwError "internal error: could not find field {field} in structure {clsName}"
@ -82,22 +89,23 @@ def getMethodSpecsInfo (instName : Name) : MetaM MethodSpecsInfo := do
let thm ← mkForallFVars xs eq
unless (← isDefEq lhs rhs) do
throwError "internal error: equation `{eq}` does not hold definitionally"
fieldImpls := fieldImpls.push (field, f.constName!)
thms := thms.push { name := f.constName!, levelParams := instInfo.levelParams, type := thm }
trace[Meta.MethodSpecs] "MethodSpecs for {instName}:\n{fieldImpls}\nthms: {thms.map (·.type)}"
fieldImpls := fieldImpls.push (field, implName)
thms := thms.push { name := implName, levelParams := instInfo.levelParams, type := thm }
trace[Meta.MethodSpecs] "MethodSpecs for {instName}:\n{fieldImpls}\n\
thms: {thms.map (·.type)}\nprivateSpecs: {privateSpecs}"
return {clsName, fieldImpls, thms}
return {clsName, fieldImpls, thms, privateSpecs}
public structure MethodSpecsAttrData where
clsName : Name
/-- Whether the specs should be public or private -/
privateSpecs : Bool
deriving Inhabited
def getParam (instName : Name) (_stx : Syntax) : AttrM MethodSpecsAttrData := do
-- Preflight check
let specsInfo ← (getMethodSpecsInfo instName).run'
return {
clsName := specsInfo.clsName
}
return { specsInfo with }
/--
Generate method specification theorems for the methods of the given type class instance.
@ -125,6 +133,31 @@ builtin_initialize methodSpecsSimpExtension : SimpExtension ←
registerSimpAttr `method_specs_simp
"simp lemma used to post-process the theorem created by `@[method_specs]`."
def mkSpecTheoremName (env : Environment) (instName : Name) (privateSpecs : Bool) (suffix : String) : Name :=
let thmName := instName.str suffix
if privateSpecs then mkPrivateName env thmName else thmName
def startsWithFollowedByNumber (s p : String) : Bool :=
s.startsWith p && (s.drop p.length).isNat
def isSpecThmLikeSuffix (fieldName : Name) (s : String) : Bool :=
s == s!"{fieldName}_spec" || startsWithFollowedByNumber s s!"{fieldName}_spec_"
/--
The spec theorem theorem for an instance can be private even if the instance itself is not.
So un-private the name here when looking for a declaration, and finally check if it matches.
Cf. `Lean.Meta.declFromEqLikeName`. Maybe worth collecting this logic in a central place.
-/
def isSpecThmNameFor (env : Environment) (name : Name) : Option Name := do
let .str p s := name | none
[p, privateToUserName p].firstM fun p => do
let attrData ← methodSpecsAttr.getParam? env p
for fieldName in getStructureFields env attrData.clsName do
if isSpecThmLikeSuffix fieldName s then
if name == mkSpecTheoremName env p attrData.privateSpecs s then
return p
none
def rewriteThm (ctx : Simp.Context) (simprocs : Simprocs)
(eqThmName destThmName : Name) : MetaM Unit := do
let thmInfo ← getConstVal eqThmName
@ -140,56 +173,52 @@ def rewriteThm (ctx : Simp.Context) (simprocs : Simprocs)
def genSpecs (instName : Name) : MetaM Unit := do
let methodSpecsInfo ← getMethodSpecsInfo instName
let key := instName.str s!"{methodSpecsInfo.fieldImpls[0]!.1}_spec"
let key := mkSpecTheoremName (← getEnv) instName methodSpecsInfo.privateSpecs s!"{methodSpecsInfo.fieldImpls[0]!.1}_spec"
realizeConst instName key doRealize
where
doRealize := do
let methodSpecsInfo ← getMethodSpecsInfo instName
let mut s ← methodSpecsSimpExtension.getTheorems
for thm in methodSpecsInfo.thms do
trace[Meta.MethodSpecs] "adding simp theorem for {thm.name} : {thm.type}"
s := s.addSimpTheorem <| ← mkDSimpTheorem (.other thm.name) thm.levelParams.toArray thm.type
let ctx ← Simp.mkContext
(simpTheorems := #[s])
(congrTheorems := (← getSimpCongrTheorems))
(config := { } ) -- Simp.neutralConfig with dsimp := true, letToHave := false })
let simprocs ← Simp.getSimprocs
withoutExporting (when := methodSpecsInfo.privateSpecs) do
let mut s ← methodSpecsSimpExtension.getTheorems
for thm in methodSpecsInfo.thms do
trace[Meta.MethodSpecs] "adding simp theorem for {thm.name} : {thm.type}"
s := s.addSimpTheorem <| ← mkDSimpTheorem (.other thm.name) thm.levelParams.toArray thm.type
let ctx ← Simp.mkContext
(simpTheorems := #[s])
(congrTheorems := (← getSimpCongrTheorems))
(config := { } ) -- Simp.neutralConfig with dsimp := true, letToHave := false })
let simprocs ← Simp.getSimprocs
for (fieldName, implName) in methodSpecsInfo.fieldImpls do
let some unfoldThm ← getUnfoldEqnFor? implName (nonRec := true)
| throwError "failed to generate unfolding theorem for {.ofConstName implName}"
rewriteThm ctx simprocs unfoldThm (instName.str s!"{fieldName}_spec")
let env ← getEnv
for (fieldName, implName) in methodSpecsInfo.fieldImpls do
let some unfoldThm ← getUnfoldEqnFor? implName (nonRec := true)
| throwError "failed to generate unfolding theorem for {.ofConstName implName}"
let thmName := mkSpecTheoremName env instName methodSpecsInfo.privateSpecs s!"{fieldName}_spec"
rewriteThm ctx simprocs unfoldThm thmName
if let some eqnThms ← getEqnsFor? implName then
for eqnThm in eqnThms, i in [:eqnThms.size] do
rewriteThm ctx simprocs eqnThm (instName.str s!"{fieldName}_spec_{i+1}")
if let some eqnThms ← getEqnsFor? implName then
for eqnThm in eqnThms, i in [:eqnThms.size] do
let thmName := mkSpecTheoremName env instName methodSpecsInfo.privateSpecs s!"{fieldName}_spec_{i+1}"
rewriteThm ctx simprocs eqnThm thmName
def startsWithFollowedByNumber (s p : String) : Bool :=
s.startsWith p && (s.drop p.length).isNat
def isSpecThmNameFor (env : Environment) (name : Name) : Option Name := do
let .str p n := name | none
let attrData ← methodSpecsAttr.getParam? env p
for fieldName in getStructureFields env attrData.clsName do
if n == s!"{fieldName}_spec" || startsWithFollowedByNumber n s!"{fieldName}_spec_" then
return p
none
public partial def getMethodSpecTheorem (instName : Name) (op : String) : MetaM (Option Name) := do
let env ← getEnv
let some _ := methodSpecsAttr.getParam? env instName | return none
realizeGlobalConstNoOverloadCore (instName.str s!"{op}_spec")
let some methodSpecInfos := methodSpecsAttr.getParam? env instName | return none
let thmName := mkSpecTheoremName env instName methodSpecInfos.privateSpecs s!"{op}_spec"
realizeGlobalConstNoOverloadCore thmName
public partial def getMethodSpecTheorems (instName : Name) (op : String) : MetaM (Option (Array Name)) := do
let some _ := methodSpecsAttr.getParam? (← getEnv) instName | return none
let some methodSpecInfos := methodSpecsAttr.getParam? (← getEnv) instName | return none
-- Realize spec theorems
let _ ← realizeGlobalConstNoOverloadCore (instName.str s!"{op}_spec")
let thmName := mkSpecTheoremName (← getEnv) instName methodSpecInfos.privateSpecs s!"{op}_spec"
let _ ← realizeGlobalConstNoOverloadCore thmName
-- Now collect the generated ones
let mut i := 0
let mut thms := #[]
let env ← getEnv
while true do
let thmName := instName.str s!"{op}_spec_{i+1}"
let thmName := mkSpecTheoremName (← getEnv) instName methodSpecInfos.privateSpecs s!"{op}_spec_{i+1}"
if env.containsOnBranch thmName then
thms := thms.push thmName
i := i + 1

View file

@ -1,5 +1,7 @@
#include "util/options.h"
// please update stage0
namespace lean {
options get_default_options() {
options opts;

View file

@ -1,4 +1,6 @@
module
-- set_option trace.Elab.Deriving.lawfulBEq true
-- set_option trace.Meta.MethodSpecs true
inductive L (α : Type u) where
| nil : L α
@ -9,3 +11,42 @@ example {n m : Nat} (h : n = m) :
(L.cons n (L.nil : L Nat) == L.cons m (L.nil : L Nat)) = true := by
simp [reduceBEq]
assumption
-- Module system interactions
namespace A
inductive L where | nil : L | cons : Nat → L → L deriving BEq
-- NB: Instance, op and theorem are private
/-- info: private def A.instBEqL : BEq L -/
#guard_msgs in #print sig instBEqL
/-- info: private def A.instBEqL.beq : L → L → Bool -/
#guard_msgs in #print sig instBEqL.beq
/-- info: private theorem A.instBEqL.beq_spec_1 : (L.nil == L.nil) = true -/
#guard_msgs(pass trace, all) in #print sig instBEqL.beq_spec_1
example : (L.cons n (L.nil : L) == L.cons m (L.nil : L)) ↔ n = m := by simp [reduceBEq]
end A
namespace B
public inductive L where | nil : L | cons : Nat → L → L deriving BEq
-- NB: Instance is public and exposed, op and theorem are private
/-- info: @[expose] def B.instBEqL : BEq L -/
#guard_msgs in #print sig instBEqL
/-- info: def B.instBEqL.beq : L → L → Bool -/
#guard_msgs in #print sig instBEqL.beq
-- NB: Private theorem
/-- info: private theorem B.instBEqL.beq_spec_1 : (L.nil == L.nil) = true -/
#guard_msgs(pass trace, all) in #print sig instBEqL.beq_spec_1
example : (L.cons n (L.nil : L) == L.cons m (L.nil : L)) ↔ n = m := by simp [reduceBEq]
end B
namespace C
public inductive L where | nil : L | cons : Nat → L → L deriving @[expose] BEq
-- NB: Public exposed instances, implementation and public theorem
/-- info: @[expose] def C.instBEqL : BEq L -/
#guard_msgs in #print sig instBEqL
/-- info: @[expose] def C.instBEqL.beq : L → L → Bool -/
#guard_msgs in #print sig instBEqL.beq
/-- info: theorem C.instBEqL.beq_spec_1 : (L.nil == L.nil) = true -/
#guard_msgs(pass trace, all) in #print sig instBEqL.beq_spec_1
example : (L.cons n (L.nil : L) == L.cons m (L.nil : L)) ↔ n = m := by simp [reduceBEq]
end C