feat: @[method_specs] to generate specification theorems from class instances (#10302)

This PR introduces the `@[specs]` attribute. It can be applied to
(certain) type class instances and define “specification theorems” for
the class’ operations, by taking the equational theorems of the
implementation function mentioned in the type class instance and
rephrasing them in terms of the overloaded operations. Fixes #5295.

Example:

```
inductive L α where
  | nil  : L α
  | cons : α → L α → L α

def L.beqImpl [BEq α] : L α → L α → Bool
  | nil, nil           => true
  | cons x xs, cons y ys => x == y && L.beqImpl xs ys
  | _, _               => false

@[method_specs] instance [BEq α] : BEq (L α) := ⟨L.beqImpl⟩

/--
info: theorem instBEqL.beq_spec_2.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (x_2 : α) (xs : L α) (y : α) (ys : L α),
  (L.cons x_2 xs == L.cons y ys) = (x_2 == y && xs == ys)
-/
#guard_msgs(pass trace, all) in
#print sig instBEqL.beq_spec_2
```

It also introduces the `method_specs_norm` simpset to allow registering
further normalization of the theorems. The intended use of this is to
rewrite, say, `Append.append` to the `HAppend.hAppend` (i.e. `++`) that
the user wants to see. Library annotations to follow in a separate PR.
This commit is contained in:
Joachim Breitner 2025-09-15 13:17:06 +02:00 committed by GitHub
parent 97464c9d7f
commit 88fa4212d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 395 additions and 10 deletions

View file

@ -2262,6 +2262,18 @@ such as replacing `if c then _ else _` with `if h : c then _ else _` or `xs.map`
-/
syntax (name := wf_preprocess) "wf_preprocess" (Tactic.simpPre <|> Tactic.simpPost)? patternIgnore("← " <|> "<- ")? (ppSpace prio)? : attr
/--
Theorems tagged with the `method_specs_simp` attribute are used by `@[method_specs]` to further
rewrite the theorem statement. This is primarily used to rewrite type class methods further to
the desired user-visible form, e.g. from `Append.append` to `HAppend.hAppend`, which has the familiar
notation associated.
The `method_specs` theorems are created on demand (using the realizable constant feature). Thus,
this simp set should behave the same in all modules. Do not add theorems to it except in the module
defining the thing you are rewriting.
-/
syntax (name := method_specs_simp) "method_specs_simp" (Tactic.simpPre <|> Tactic.simpPost)? patternIgnore("← " <|> "<- ")? (ppSpace prio)? : attr
/-- The possible `norm_cast` kinds: `elim`, `move`, or `squash`. -/
syntax normCastLabel := &"elim" <|> &"move" <|> &"squash"

View file

@ -57,5 +57,4 @@ public import Lean.Meta.Diagnostics
public import Lean.Meta.BinderNameHint
public import Lean.Meta.TryThis
public import Lean.Meta.Hint
public section
public import Lean.Meta.MethodSpecs

View file

@ -0,0 +1,211 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/
module
prelude
public import Init.System.IO
public import Lean.Attributes
public import Lean.Meta.Tactic.Simp.SimpTheorems
import Lean.Meta.Basic
import Lean.Structure
import Lean.Meta.CtorRecognizer
import Lean.Meta.InferType
import Lean.Meta.AppBuilder
import Lean.ReservedNameAction
import Lean.Meta.Tactic.Simp.SimpTheorems
import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Tactic.Simp.Main
namespace Lean
open Meta
structure MethodSpecTheorem where
/-- Name of the implementation function -/
name : Name
levelParams : List Name
/-- `opImpl = Cls.op instClsT` -/
type : Expr
structure MethodSpecsInfo where
clsName : Name
/-- Array mapping field names to implementation functions. -/
fieldImpls : Array (Name × Name)
/-- rewrite rules to apply -/
thms : Array MethodSpecTheorem
/--
This function checks the `instName` for eligibility and collects the information to rewrite.
It is run twice: when setting the `@[specs]` attribute as a preflight check, and when actually realizing
the constants.
-/
def getMethodSpecsInfo (instName : Name) : MetaM MethodSpecsInfo := do
let instInfo ← getConstInfoDefn instName
let some clsName ← isClass? instInfo.type
| throwError "expected `{.ofConstName instName}` to be a type class instance, but its \
type{inlineExpr instInfo.type}does not look like a class."
let instArity ← forallTelescopeReducing instInfo.type fun xs _ => pure xs.size
let some structInfo := getStructureInfo? (← getEnv) clsName
| throwError "`{.ofConstName clsName}` is not a structure"
lambdaTelescope instInfo.value fun xs body => do
let inst := mkAppN (mkConst instInfo.name (instInfo.levelParams.map mkLevelParam)) xs
let clsApp ← instantiateForall instInfo.type xs
unless xs.size == instArity && (← isConstructorApp body) do
throwError "the definition of `{.ofConstName instName}` does not have the expected shape"
let mut fieldImpls := #[]
let mut thms := #[]
for field in structInfo.fieldNames, arg in body.getAppArgsN structInfo.fieldNames.size do
if (← isProof arg) then continue
let arg := arg.eta
let f := arg.getAppFn
let ys := arg.getAppArgs
unless f.isConst do
throwError "field `{field}` of the instance is not an application of a constant"
unless f.constLevels! == instInfo.levelParams.map mkLevelParam do
throwError "function `{f}` is called with universe parameters\n {f.constLevels!}\nwhich differs from \
the instances' universe parameters\n {instInfo.levelParams.map mkLevelParam}"
unless xs == ys do
throwError "function `{f}` does not take its arguments in the same order as the instance"
-- Construct the replacement theorems
let some fieldInfo := getFieldInfo? (← getEnv) clsName field
| throwError "internal error: could not find field {field} in structure {clsName}"
let lhs := arg
let projFn := mkConst fieldInfo.projFn clsApp.getAppFn.constLevels!
let rhs := mkAppN projFn (clsApp.getAppArgs ++ #[inst])
let eq ← mkEq lhs rhs
let thm ← mkForallFVars xs eq
unless (← isDefEq lhs rhs) do
throwError "internal error: equation `{eq}` does not hold definitionally"
fieldImpls := fieldImpls.push (field, f.constName!)
thms := thms.push { name := f.constName!, levelParams := instInfo.levelParams, type := thm }
trace[Meta.MethodSpecs] "MethodSpecs for {instName}:\n{fieldImpls}\nthms: {thms.map (·.type)}"
return {clsName, fieldImpls, thms}
public structure MethodSpecsAttrData where
clsName : Name
deriving Inhabited
def getParam (instName : Name) (_stx : Syntax) : AttrM MethodSpecsAttrData := do
-- Preflight check
let specsInfo ← (getMethodSpecsInfo instName).run'
return {
clsName := specsInfo.clsName
}
/--
Generate method specification theorems for the methods of the given type class instance.
This expects all (non-proof) methods of the instance to be defined via separate helper functions,
which must take the same arguments as the instance itself, in the same order.
If it is applied to an instance
```
instance instClsT : Cls T where op := opImpl
```
it produces a theorem `instClsT.op_spec` based on `opImpl.eq_def`, but phrased in terms of the
overloaded `Cls.op` operation, and similarly `instClsT.op_spec_<n>` based on the equational theorems
`opImpl.eq_<n>`.
-/
@[builtin_doc]
builtin_initialize methodSpecsAttr : ParametricAttribute MethodSpecsAttrData ←
registerParametricAttribute {
name := `method_specs
descr := "generate method specification theorems"
getParam
}
builtin_initialize methodSpecsSimpExtension : SimpExtension ←
registerSimpAttr `method_specs_simp
"simp lemma used to post-process the theorem created by `@[method_specs]`."
def rewriteThm (ctx : Simp.Context) (simprocs : Simprocs)
(eqThmName destThmName : Name) : MetaM Unit := do
let thmInfo ← getConstVal eqThmName
let (result, _) ← simp thmInfo.type ctx (simprocs := #[simprocs])
trace[Meta.MethodSpecs] "type for {destThmName}:{indentExpr result.expr}"
let value ← result.mkEqMPR <| mkConst eqThmName (thmInfo.levelParams.map mkLevelParam)
addDecl <| Declaration.thmDecl {
name := destThmName
levelParams := thmInfo.levelParams
type := result.expr
value := value
}
def genSpecs (instName : Name) : MetaM Unit := do
let methodSpecsInfo ← getMethodSpecsInfo instName
let key := instName.str s!"{methodSpecsInfo.fieldImpls[0]!.1}_spec"
realizeConst instName key doRealize
where
doRealize := do
let methodSpecsInfo ← getMethodSpecsInfo instName
let mut s ← methodSpecsSimpExtension.getTheorems
for thm in methodSpecsInfo.thms do
trace[Meta.MethodSpecs] "adding simp theorem for {thm.name} : {thm.type}"
s := s.addSimpTheorem <| ← mkDSimpTheorem (.other thm.name) thm.levelParams.toArray thm.type
let ctx ← Simp.mkContext
(simpTheorems := #[s])
(congrTheorems := (← getSimpCongrTheorems))
(config := { } ) -- Simp.neutralConfig with dsimp := true, letToHave := false })
let simprocs ← Simp.getSimprocs
for (fieldName, implName) in methodSpecsInfo.fieldImpls do
let some unfoldThm ← getUnfoldEqnFor? implName (nonRec := true)
| throwError "failed to generate unfolding theorem for {.ofConstName implName}"
rewriteThm ctx simprocs unfoldThm (instName.str s!"{fieldName}_spec")
if let some eqnThms ← getEqnsFor? implName then
for eqnThm in eqnThms, i in [:eqnThms.size] do
rewriteThm ctx simprocs eqnThm (instName.str s!"{fieldName}_spec_{i+1}")
def startsWithFollowedByNumber (s p : String) : Bool :=
s.startsWith p && (s.drop p.length).isNat
def isSpecThmNameFor (env : Environment) (name : Name) : Option Name := do
let .str p n := name | none
let attrData ← methodSpecsAttr.getParam? env p
for fieldName in getStructureFields env attrData.clsName do
if n == s!"{fieldName}_spec" || startsWithFollowedByNumber n s!"{fieldName}_spec_" then
return p
none
public partial def getMethodSpecTheorem (instName : Name) (op : String) : MetaM (Option Name) := do
let env ← getEnv
let some _ := methodSpecsAttr.getParam? env instName | return none
realizeGlobalConstNoOverloadCore (instName.str s!"{op}_spec")
public partial def getMethodSpecTheorems (instName : Name) (op : String) : MetaM (Option (Array Name)) := do
let some _ := methodSpecsAttr.getParam? (← getEnv) instName | return none
-- Realize spec theorems
let _ ← realizeGlobalConstNoOverloadCore (instName.str s!"{op}_spec")
-- Now collect the generated ones
let mut i := 0
let mut thms := #[]
let env ← getEnv
while true do
let thmName := instName.str s!"{op}_spec_{i+1}"
if env.containsOnBranch thmName then
thms := thms.push thmName
i := i + 1
else
break
return some thms
builtin_initialize
registerReservedNamePredicate fun env name => isSpecThmNameFor env name |>.isSome
registerReservedNameAction fun name => do
if let some instName := isSpecThmNameFor (← getEnv) name then
(genSpecs instName).run'
return true
return false
builtin_initialize
Lean.registerTraceClass `Meta.MethodSpecs

View file

@ -122,6 +122,7 @@ private def useImplicitDefEqProof (thm : SimpTheorem) : SimpM Bool := do
private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (thm : SimpTheorem) (numExtraArgs : Nat) : SimpM (Option Result) := do
recordTriedSimpTheorem thm.origin
let rec go (e : Expr) : SimpM (Option Result) := do
trace[Debug.Meta.Tactic.simp] "trying {← ppSimpTheorem thm} to rewrite{indentExpr e}"
if (← withSimpMetaConfig <| isDefEq lhs e) then
unless (← synthesizeArgs thm.origin bis xs) do
return none
@ -142,6 +143,9 @@ private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInf
we seldom have assigned metavariables in goals.
-/
if (← instantiateMVars e) == rhs then
trace[Debug.Meta.Tactic.simp] "Not applying {← ppSimpTheorem thm} with type\
{indentExpr type}\nto{indentExpr e}\nas the result is structurally equal \
to the original expression"
return none
if thm.perm then
/-

View file

@ -364,17 +364,19 @@ private def checkTypeIsProp (type : Expr) : MetaM Unit :=
unless (← isProp type) do
throwError "Invalid simp theorem: Expected a proposition, but found{indentExpr type}"
private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array Name) (proof : Expr) (post : Bool) (prio : Nat) (noIndexAtArgs : Bool) : MetaM SimpTheorem := do
assert! origin != .fvar ⟨.anonymous⟩
let type ← instantiateMVars (← inferType e)
private def mkSimpTheoremKeys (type : Expr) (noIndexAtArgs : Bool) : MetaM (Array SimpTheoremKey × Bool) := do
withNewMCtxDepth do
let (_, _, type) ← forallMetaTelescopeReducing type
let type ← whnfR type
let (keys, perm) ←
match type.eq? with
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs noIndexAtArgs, ← isPerm lhs rhs)
| none => throwError "Unexpected kind of simp theorem{indentExpr type}"
return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) }
match type.eq? with
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs noIndexAtArgs, ← isPerm lhs rhs)
| none => throwError "Unexpected kind of simp theorem{indentExpr type}"
private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array Name) (proof : Expr) (post : Bool) (prio : Nat) (noIndexAtArgs : Bool) : MetaM SimpTheorem := do
assert! origin != .fvar ⟨.anonymous⟩
let type ← instantiateMVars (← inferType e)
let (keys, perm) ← mkSimpTheoremKeys type noIndexAtArgs
return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) }
/--
Creates a `SimpTheorem` from a global theorem.
@ -427,6 +429,19 @@ def mkSimpTheoremFromExpr (id : Origin) (levelParams : Array Name) (proof : Expr
(← preprocessProof proof inv).mapM fun val =>
mkSimpTheoremCore id val levelParams val post prio (noIndexAtArgs := true)
/-- Creates a `SimpTheorem` from a definitional equality. -/
def mkDSimpTheorem (id : Origin) (levelParams : Array Name) (type : Expr)
(post := true) (prio : Nat := eval_prio default) (config : ConfigWithKey := simpGlobalConfig) :
MetaM SimpTheorem := do
withConfigWithKey config do
let (keys, perm) ← mkSimpTheoremKeys type (noIndexAtArgs := true)
let proof ← forallTelescopeReducing type fun xs r => do
let some (_, lhs, _rhs) := r.eq?
| throwError "Unexpected kind of dsimp theorem{indentExpr type}"
-- We need to wrap the proof in a type hint, else the type is lost
mkExpectedTypeHint (← mkLambdaFVars xs (← mkEqRefl lhs)) type
return { origin := id, keys, perm, post, levelParams, proof, priority := prio, rfl := true }
/--
A simp theorem or information about a declaration to unfold by simp.
This is stored in the oleans to implement the `simp` attribute and user-defined simp sets.

View file

@ -0,0 +1,144 @@
inductive L α where
| nil : L α
| cons : α → L α → L α
def L.beqImpl [BEq α] : L α → L α → Bool
| nil, nil => true
| cons x xs, cons y ys => x == y && L.beqImpl xs ys
| _, _ => false
@[method_specs] instance [BEq α] : BEq (L α) := ⟨L.beqImpl⟩
/--
info: theorem instBEqL.beq_spec.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (x x_1 : L α),
(x == x_1) =
match x, x_1 with
| L.nil, L.nil => true
| L.cons x xs, L.cons y ys => x == y && xs == ys
| x, x_2 => false
-/
#guard_msgs(pass trace, all) in
#print sig instBEqL.beq_spec
/--
info: theorem instBEqL.beq_spec_1.{u_1} : ∀ {α : Type u_1} [inst : BEq α], (L.nil == L.nil) = true
-/
#guard_msgs(pass trace, all) in
#print sig instBEqL.beq_spec_1
/--
info: theorem instBEqL.beq_spec_2.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (x_2 : α) (xs : L α) (y : α) (ys : L α),
(L.cons x_2 xs == L.cons y ys) = (x_2 == y && xs == ys)
-/
#guard_msgs(pass trace, all) in
#print sig instBEqL.beq_spec_2
/--
info: theorem instBEqL.beq_spec_3.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (x x_1 : L α),
(x = L.nil → x_1 = L.nil → False) →
(∀ (x_2 : α) (xs : L α) (y : α) (ys : L α), x = L.cons x_2 xs → x_1 = L.cons y ys → False) → (x == x_1) = false
-/
#guard_msgs(pass trace, all) in
#print sig instBEqL.beq_spec_3
/-- error: Unknown constant `instBEqL.beq_spec_4` -/
#guard_msgs(pass trace, all) in
#print sig instBEqL.beq_spec_4
-- Other names are not reserved
/-- error: Unknown constant `instBEqL.eq_spec` -/
#guard_msgs in #print sig instBEqL.eq_spec
/-- error: Unknown constant `instBEqL.beq_spec_` -/
#guard_msgs in #print sig instBEqL.beq_spec_
-- Test rewriting all the way to HAppend
@[method_specs_simp] theorem Append.append_eq_happend :
@Append.append α inst = @HAppend.hAppend α α α (@instHAppendOfAppend α inst) := rfl
def L.append {α : Type u} : L α → L α → L α
| nil, ys => ys
| cons x xs, ys => cons x (L.append xs ys)
@[method_specs] instance (α : Type u) : Append (L α) where
append := L.append
/--
info: theorem instAppendL.append_spec_2.{u} : ∀ {α : Type u} (x : L α) (x_2 : α) (xs : L α),
L.cons x_2 xs ++ x = L.cons x_2 (xs ++ x)
-/
#guard_msgs in #print sig instAppendL.append_spec_2
-- Test that rewriting works with non-rfl theorem too
class Cls α where op : αα
class HCls α where hOp : αα
instance instHClsOfCls [Cls α] : HCls α where hOp := Cls.op
-- NB: Not a rfl theorem
@[method_specs_simp] theorem Cls.op_eq_hOp : @Cls.op α inst = @HCls.hOp α (@instHClsOfCls α inst) := (rfl)
def L.op {α : Type u} : L α → L α
| nil => nil
| cons x xs => cons x (L.op xs)
@[method_specs] instance : Cls (L α) where op := L.op
/--
info: theorem instClsL.op_spec_2.{u} : ∀ {α : Type u} (x_1 : α) (xs : L α),
HCls.hOp (L.cons x_1 xs) = L.cons x_1 (HCls.hOp xs)
-/
#guard_msgs in
#print sig instClsL.op_spec_2
/-!
Now some error conditions
-/
/-- error: `Foo` is not a definition -/
#guard_msgs in @[method_specs] inductive Foo
/--
error: expected `foo` to be a type class instance, but its type `Nat` does not look like a class.
-/
#guard_msgs in @[method_specs] def foo := 1
structure S where field : Nat
/--
error: expected `aS` to be a type class instance, but its type `S` does not look like a class.
-/
#guard_msgs in @[method_specs] def aS : S := ⟨1⟩
@[class] inductive indClass where | mk
/-- error: `indClass` is not a structure -/
#guard_msgs in @[method_specs] def instIndClass : indClass := .mk
-- This used to fail until we eta-reduced the field values
@[method_specs] instance anotherInstBEqL [BEq α] : BEq (L α) := ⟨fun x y => L.beqImpl x y⟩
def L.badBeqImpl {α : Type u} : L α → L α → Bool
| nil, nil => true
| cons _ xs, cons _ ys => L.badBeqImpl xs ys
| _, _ => false
/-- error: function `@L.badBeqImpl` does not take its arguments in the same order as the instance -/
#guard_msgs in
@[method_specs] instance badInstBEqL [BEq α] : BEq (L α) := ⟨L.badBeqImpl⟩
-- L.append has a more general type (in terms of universes)
-- than the instance below.
-- This should be caught and warned about.
def L.badAppend : L α → L α → L α
| nil, ys => ys
| cons x xs, ys => cons x (L.badAppend xs ys)
/--
error: function `@L.badAppend` is called with universe parameters
[u+1]
which differs from the instances' universe parameters
[u]
-/
#guard_msgs in
@[method_specs] instance badInstAppendL (α : Type u) : Append (L α) where
append := L.badAppend