From ca10fd7c4fb8ce39dad2d3bad542408f17fb5d50 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Tue, 16 Sep 2025 11:20:04 +0200 Subject: [PATCH] fix: method spec theorems to be private when appropriate (#10406) This PR improves upon #10302 to properly make the method spec theorems private if the implementation function is not exposed. --- src/Lean/Meta/MethodSpecs.lean | 107 +++++++++++++++++---------- src/stdlib_flags.h | 2 + tests/lean/run/reduceBEqSimproc.lean | 41 ++++++++++ 3 files changed, 111 insertions(+), 39 deletions(-) diff --git a/src/Lean/Meta/MethodSpecs.lean b/src/Lean/Meta/MethodSpecs.lean index dd822215c0..0b8a2865c9 100644 --- a/src/Lean/Meta/MethodSpecs.lean +++ b/src/Lean/Meta/MethodSpecs.lean @@ -32,6 +32,8 @@ structure MethodSpecTheorem where structure MethodSpecsInfo where clsName : Name + /-- Whether the specs should be public or private -/ + privateSpecs : Bool /-- Array mapping field names to implementation functions. -/ fieldImpls : Array (Name × Name) /-- rewrite rules to apply -/ @@ -59,6 +61,7 @@ def getMethodSpecsInfo (instName : Name) : MetaM MethodSpecsInfo := do let mut fieldImpls := #[] let mut thms := #[] + let mut privateSpecs := false for field in structInfo.fieldNames, arg in body.getAppArgsN structInfo.fieldNames.size do if (← isProof arg) then continue @@ -72,6 +75,10 @@ def getMethodSpecsInfo (instName : Name) : MetaM MethodSpecsInfo := do the instances' universe parameters\n {instInfo.levelParams.map mkLevelParam}" unless xs == ys do throwError "function `{f}` does not take its arguments in the same order as the instance" + let implName := f.constName! + let isExposed := !(← getEnv).header.isModule || (((← getEnv).setExporting true).find? implName).elim false (·.hasValue) + unless isExposed do + privateSpecs := true -- Construct the replacement theorems let some fieldInfo := getFieldInfo? (← getEnv) clsName field | throwError "internal error: could not find field {field} in structure {clsName}" @@ -82,22 +89,23 @@ def getMethodSpecsInfo (instName : Name) : MetaM MethodSpecsInfo := do let thm ← mkForallFVars xs eq unless (← isDefEq lhs rhs) do throwError "internal error: equation `{eq}` does not hold definitionally" - fieldImpls := fieldImpls.push (field, f.constName!) - thms := thms.push { name := f.constName!, levelParams := instInfo.levelParams, type := thm } - trace[Meta.MethodSpecs] "MethodSpecs for {instName}:\n{fieldImpls}\nthms: {thms.map (·.type)}" + fieldImpls := fieldImpls.push (field, implName) + thms := thms.push { name := implName, levelParams := instInfo.levelParams, type := thm } + trace[Meta.MethodSpecs] "MethodSpecs for {instName}:\n{fieldImpls}\n\ + thms: {thms.map (·.type)}\nprivateSpecs: {privateSpecs}" - return {clsName, fieldImpls, thms} + return {clsName, fieldImpls, thms, privateSpecs} public structure MethodSpecsAttrData where clsName : Name + /-- Whether the specs should be public or private -/ + privateSpecs : Bool deriving Inhabited def getParam (instName : Name) (_stx : Syntax) : AttrM MethodSpecsAttrData := do -- Preflight check let specsInfo ← (getMethodSpecsInfo instName).run' - return { - clsName := specsInfo.clsName - } + return { specsInfo with } /-- Generate method specification theorems for the methods of the given type class instance. @@ -125,6 +133,31 @@ builtin_initialize methodSpecsSimpExtension : SimpExtension ← registerSimpAttr `method_specs_simp "simp lemma used to post-process the theorem created by `@[method_specs]`." +def mkSpecTheoremName (env : Environment) (instName : Name) (privateSpecs : Bool) (suffix : String) : Name := + let thmName := instName.str suffix + if privateSpecs then mkPrivateName env thmName else thmName + +def startsWithFollowedByNumber (s p : String) : Bool := + s.startsWith p && (s.drop p.length).isNat + +def isSpecThmLikeSuffix (fieldName : Name) (s : String) : Bool := + s == s!"{fieldName}_spec" || startsWithFollowedByNumber s s!"{fieldName}_spec_" + +/-- +The spec theorem theorem for an instance can be private even if the instance itself is not. +So un-private the name here when looking for a declaration, and finally check if it matches. +Cf. `Lean.Meta.declFromEqLikeName`. Maybe worth collecting this logic in a central place. +-/ +def isSpecThmNameFor (env : Environment) (name : Name) : Option Name := do + let .str p s := name | none + [p, privateToUserName p].firstM fun p => do + let attrData ← methodSpecsAttr.getParam? env p + for fieldName in getStructureFields env attrData.clsName do + if isSpecThmLikeSuffix fieldName s then + if name == mkSpecTheoremName env p attrData.privateSpecs s then + return p + none + def rewriteThm (ctx : Simp.Context) (simprocs : Simprocs) (eqThmName destThmName : Name) : MetaM Unit := do let thmInfo ← getConstVal eqThmName @@ -140,56 +173,52 @@ def rewriteThm (ctx : Simp.Context) (simprocs : Simprocs) def genSpecs (instName : Name) : MetaM Unit := do let methodSpecsInfo ← getMethodSpecsInfo instName - let key := instName.str s!"{methodSpecsInfo.fieldImpls[0]!.1}_spec" + let key := mkSpecTheoremName (← getEnv) instName methodSpecsInfo.privateSpecs s!"{methodSpecsInfo.fieldImpls[0]!.1}_spec" realizeConst instName key doRealize where doRealize := do let methodSpecsInfo ← getMethodSpecsInfo instName - let mut s ← methodSpecsSimpExtension.getTheorems - for thm in methodSpecsInfo.thms do - trace[Meta.MethodSpecs] "adding simp theorem for {thm.name} : {thm.type}" - s := s.addSimpTheorem <| ← mkDSimpTheorem (.other thm.name) thm.levelParams.toArray thm.type - let ctx ← Simp.mkContext - (simpTheorems := #[s]) - (congrTheorems := (← getSimpCongrTheorems)) - (config := { } ) -- Simp.neutralConfig with dsimp := true, letToHave := false }) - let simprocs ← Simp.getSimprocs + withoutExporting (when := methodSpecsInfo.privateSpecs) do + let mut s ← methodSpecsSimpExtension.getTheorems + for thm in methodSpecsInfo.thms do + trace[Meta.MethodSpecs] "adding simp theorem for {thm.name} : {thm.type}" + s := s.addSimpTheorem <| ← mkDSimpTheorem (.other thm.name) thm.levelParams.toArray thm.type + let ctx ← Simp.mkContext + (simpTheorems := #[s]) + (congrTheorems := (← getSimpCongrTheorems)) + (config := { } ) -- Simp.neutralConfig with dsimp := true, letToHave := false }) + let simprocs ← Simp.getSimprocs - for (fieldName, implName) in methodSpecsInfo.fieldImpls do - let some unfoldThm ← getUnfoldEqnFor? implName (nonRec := true) - | throwError "failed to generate unfolding theorem for {.ofConstName implName}" - rewriteThm ctx simprocs unfoldThm (instName.str s!"{fieldName}_spec") + let env ← getEnv + for (fieldName, implName) in methodSpecsInfo.fieldImpls do + let some unfoldThm ← getUnfoldEqnFor? implName (nonRec := true) + | throwError "failed to generate unfolding theorem for {.ofConstName implName}" + let thmName := mkSpecTheoremName env instName methodSpecsInfo.privateSpecs s!"{fieldName}_spec" + rewriteThm ctx simprocs unfoldThm thmName - if let some eqnThms ← getEqnsFor? implName then - for eqnThm in eqnThms, i in [:eqnThms.size] do - rewriteThm ctx simprocs eqnThm (instName.str s!"{fieldName}_spec_{i+1}") + if let some eqnThms ← getEqnsFor? implName then + for eqnThm in eqnThms, i in [:eqnThms.size] do + let thmName := mkSpecTheoremName env instName methodSpecsInfo.privateSpecs s!"{fieldName}_spec_{i+1}" + rewriteThm ctx simprocs eqnThm thmName -def startsWithFollowedByNumber (s p : String) : Bool := - s.startsWith p && (s.drop p.length).isNat - -def isSpecThmNameFor (env : Environment) (name : Name) : Option Name := do - let .str p n := name | none - let attrData ← methodSpecsAttr.getParam? env p - for fieldName in getStructureFields env attrData.clsName do - if n == s!"{fieldName}_spec" || startsWithFollowedByNumber n s!"{fieldName}_spec_" then - return p - none public partial def getMethodSpecTheorem (instName : Name) (op : String) : MetaM (Option Name) := do let env ← getEnv - let some _ := methodSpecsAttr.getParam? env instName | return none - realizeGlobalConstNoOverloadCore (instName.str s!"{op}_spec") + let some methodSpecInfos := methodSpecsAttr.getParam? env instName | return none + let thmName := mkSpecTheoremName env instName methodSpecInfos.privateSpecs s!"{op}_spec" + realizeGlobalConstNoOverloadCore thmName public partial def getMethodSpecTheorems (instName : Name) (op : String) : MetaM (Option (Array Name)) := do - let some _ := methodSpecsAttr.getParam? (← getEnv) instName | return none + let some methodSpecInfos := methodSpecsAttr.getParam? (← getEnv) instName | return none -- Realize spec theorems - let _ ← realizeGlobalConstNoOverloadCore (instName.str s!"{op}_spec") + let thmName := mkSpecTheoremName (← getEnv) instName methodSpecInfos.privateSpecs s!"{op}_spec" + let _ ← realizeGlobalConstNoOverloadCore thmName -- Now collect the generated ones let mut i := 0 let mut thms := #[] let env ← getEnv while true do - let thmName := instName.str s!"{op}_spec_{i+1}" + let thmName := mkSpecTheoremName (← getEnv) instName methodSpecInfos.privateSpecs s!"{op}_spec_{i+1}" if env.containsOnBranch thmName then thms := thms.push thmName i := i + 1 diff --git a/src/stdlib_flags.h b/src/stdlib_flags.h index 79a0e58edd..f4d5f7ddae 100644 --- a/src/stdlib_flags.h +++ b/src/stdlib_flags.h @@ -1,5 +1,7 @@ #include "util/options.h" +// please update stage0 + namespace lean { options get_default_options() { options opts; diff --git a/tests/lean/run/reduceBEqSimproc.lean b/tests/lean/run/reduceBEqSimproc.lean index de8b19d312..b536387d43 100644 --- a/tests/lean/run/reduceBEqSimproc.lean +++ b/tests/lean/run/reduceBEqSimproc.lean @@ -1,4 +1,6 @@ +module -- set_option trace.Elab.Deriving.lawfulBEq true +-- set_option trace.Meta.MethodSpecs true inductive L (α : Type u) where | nil : L α @@ -9,3 +11,42 @@ example {n m : Nat} (h : n = m) : (L.cons n (L.nil : L Nat) == L.cons m (L.nil : L Nat)) = true := by simp [reduceBEq] assumption + +-- Module system interactions + +namespace A +inductive L where | nil : L | cons : Nat → L → L deriving BEq +-- NB: Instance, op and theorem are private +/-- info: private def A.instBEqL : BEq L -/ +#guard_msgs in #print sig instBEqL +/-- info: private def A.instBEqL.beq : L → L → Bool -/ +#guard_msgs in #print sig instBEqL.beq +/-- info: private theorem A.instBEqL.beq_spec_1 : (L.nil == L.nil) = true -/ +#guard_msgs(pass trace, all) in #print sig instBEqL.beq_spec_1 +example : (L.cons n (L.nil : L) == L.cons m (L.nil : L)) ↔ n = m := by simp [reduceBEq] +end A + +namespace B +public inductive L where | nil : L | cons : Nat → L → L deriving BEq +-- NB: Instance is public and exposed, op and theorem are private +/-- info: @[expose] def B.instBEqL : BEq L -/ +#guard_msgs in #print sig instBEqL +/-- info: def B.instBEqL.beq : L → L → Bool -/ +#guard_msgs in #print sig instBEqL.beq +-- NB: Private theorem +/-- info: private theorem B.instBEqL.beq_spec_1 : (L.nil == L.nil) = true -/ +#guard_msgs(pass trace, all) in #print sig instBEqL.beq_spec_1 +example : (L.cons n (L.nil : L) == L.cons m (L.nil : L)) ↔ n = m := by simp [reduceBEq] +end B + +namespace C +public inductive L where | nil : L | cons : Nat → L → L deriving @[expose] BEq +-- NB: Public exposed instances, implementation and public theorem +/-- info: @[expose] def C.instBEqL : BEq L -/ +#guard_msgs in #print sig instBEqL +/-- info: @[expose] def C.instBEqL.beq : L → L → Bool -/ +#guard_msgs in #print sig instBEqL.beq +/-- info: theorem C.instBEqL.beq_spec_1 : (L.nil == L.nil) = true -/ +#guard_msgs(pass trace, all) in #print sig instBEqL.beq_spec_1 +example : (L.cons n (L.nil : L) == L.cons m (L.nil : L)) ↔ n = m := by simp [reduceBEq] +end C