feat: use @[method_specs] when deriving BEq and Ord (#10346)

This PR lets `deriving BEq` and `deriving Ord` use `@[method_specs]`
from #10302 when applicable (i.e. when not using `partial`).
This commit is contained in:
Joachim Breitner 2025-09-15 16:58:00 +02:00 committed by GitHub
parent 3bea7e209e
commit 9aa6448fa9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 79 additions and 18 deletions

View file

@ -6,13 +6,13 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Meta.Transform
public import Lean.Elab.Deriving.Basic
public import Lean.Elab.Deriving.Util
public import Lean.Data.Options
import Lean.Meta.Transform
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Lean.Meta.Eqns
import Lean.Meta.SameCtorUtils
public section
namespace Lean.Elab.Deriving.BEq
open Lean.Parser.Term
open Meta
@ -122,18 +122,16 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
$auxDefs:command*
end)
private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext ``BEq "beq" declName
def mkBEqInstanceCmds (ctx : Context) (declName : Name) : TermElabM (Array Syntax) := do
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName])
trace[Elab.Deriving.beq] "\n{cmds}"
return cmds
private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
let auxFunName := ctx.auxFunNames[0]!
`(def $(mkIdent auxFunName):ident (x y : $(mkCIdent name)) : Bool := x.ctorIdx == y.ctorIdx)
private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do
let ctx ← mkContext ``BEq "beq" name
def mkBEqEnumCmd (ctx : Context) (name : Name): TermElabM (Array Syntax) := do
let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq #[name])
trace[Elab.Deriving.beq] "\n{cmds}"
return cmds
@ -142,12 +140,15 @@ open Command
def mkBEqInstance (declName : Name) : CommandElabM Unit := do
withoutExposeFromCtors declName do
let ctx ← liftTermElabM <| mkContext ``BEq "beq" declName
let cmds ← liftTermElabM <|
if (← isEnumType declName) then
mkBEqEnumCmd declName
mkBEqEnumCmd ctx declName
else
mkBEqInstanceCmds declName
mkBEqInstanceCmds ctx declName
cmds.forM elabCommand
unless ctx.usePartial do
elabCommand (← `(attribute [method_specs] $(mkIdent ctx.instName):ident))
def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) then

View file

@ -79,11 +79,11 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
let indVal := ctx.typeInfos[i]!
let header ← mkOrdHeader indVal
let mut body ← mkMatch header indVal
if ctx.usePartial || indVal.isRec then
if ctx.usePartial then
let letDecls ← mkLocalInstanceLetDecls ctx `Ord header.argNames
body ← mkLet letDecls body
let binders := header.binders
if ctx.usePartial || indVal.isRec then
if ctx.usePartial then
`(partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term)
else
`(def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term)
@ -98,8 +98,10 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
end)
private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let ctx ← mkContext ``Ord "ord" declName
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName])
let ctx ← mkContext ``Ord "ord" declName (supportsRec := false)
let mut cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName])
unless ctx.usePartial do
cmds := cmds.push (← `(command| attribute [method_specs] $(mkIdent ctx.instName):ident))
trace[Elab.Deriving.ord] "\n{cmds}"
return cmds

View file

@ -91,7 +91,7 @@ structure Context where
usePartial : Bool
def mkContext (className : Name) (fnPrefix : String) (typeName : Name) : TermElabM Context := do
def mkContext (className : Name) (fnPrefix : String) (typeName : Name) (supportsRec := true ): TermElabM Context := do
let indVal ← getConstInfoInduct typeName
let mut typeInfos := #[]
for typeName in indVal.all do
@ -109,7 +109,7 @@ def mkContext (className : Name) (fnPrefix : String) (typeName : Name) : TermEla
for i in [:indVal.all.length] do
auxFunNames := auxFunNames.push (instName ++ .mkSimple s!"{fnPrefix}_{i+1}")
trace[Elab.Deriving] "instName: {instName} auxFunNames: {auxFunNames}"
let usePartial := indVal.isNested || typeInfos.size > 1
let usePartial := indVal.isNested || typeInfos.size > 1 || (indVal.isRec && !supportsRec)
return {
instName := instName
typeInfos := typeInfos

View file

@ -0,0 +1,58 @@
inductive L (α : Type) where
| nil : L α
| cons : α → L α → L α
deriving BEq, Ord
/--
info: theorem instBEqL.beq_spec_2 : ∀ {α : Type} [inst : BEq α] (a : α) (a_1 : L α) (b : α) (b_1 : L α),
(L.cons a a_1 == L.cons b b_1) = (a == b && a_1 == b_1)
-/
#guard_msgs in
#print sig instBEqL.beq_spec_2
-- Ord does not support recursive types yet:
/-- error: Unknown constant `instOrdL.compare_spec_2` -/
#guard_msgs in
#print sig instOrdL.compare_spec_2
inductive O (α : Type u) where
| none
| some : α → O α
deriving BEq, Ord
/--
info: theorem instBEqO.beq_spec_2.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (a b : α), (O.some a == O.some b) = (a == b)
-/
#guard_msgs in #print sig instBEqO.beq_spec_2
/--
info: theorem instOrdO.compare_spec_1.{u_1} : ∀ {α : Type u_1} [inst : Ord α], compare O.none O.none = Ordering.eq
-/
#guard_msgs in #print sig instOrdO.compare_spec_1
/--
info: theorem instOrdO.compare_spec_2.{u_1} : ∀ {α : Type u_1} [inst : Ord α] (x : O α),
(x = O.none → False) → compare O.none x = Ordering.lt
-/
#guard_msgs in #print sig instOrdO.compare_spec_2
/--
info: theorem instOrdO.compare_spec_4.{u_1} : ∀ {α : Type u_1} [inst : Ord α] (a b : α),
compare (O.some a) (O.some b) = (compare a b).then Ordering.eq
-/
#guard_msgs in #print sig instOrdO.compare_spec_4
-- Mutual inductive (not supported yet, but should be)
mutual
inductive Tree (α : Type) where
| node : TreeList α → Tree α
deriving BEq
inductive TreeList (α : Type) where
| nil : TreeList α
| cons : Tree α → TreeList α → TreeList α
deriving BEq
end
/-- error: Unknown constant `instBEqTree.beq_spec_1` -/
#guard_msgs in
#print sig instBEqTree.beq_spec_1