From 9aa6448fa96d20462758152e69b7b7ebdbb3a6e6 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Mon, 15 Sep 2025 16:58:00 +0200 Subject: [PATCH] feat: use @[method_specs] when deriving BEq and Ord (#10346) This PR lets `deriving BEq` and `deriving Ord` use `@[method_specs]` from #10302 when applicable (i.e. when not using `partial`). --- src/Lean/Elab/Deriving/BEq.lean | 25 ++++++----- src/Lean/Elab/Deriving/Ord.lean | 10 +++-- src/Lean/Elab/Deriving/Util.lean | 4 +- tests/lean/run/methodSpecsDeriving.lean | 58 +++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 18 deletions(-) create mode 100644 tests/lean/run/methodSpecsDeriving.lean diff --git a/src/Lean/Elab/Deriving/BEq.lean b/src/Lean/Elab/Deriving/BEq.lean index 1da6841336..1d4e647d33 100644 --- a/src/Lean/Elab/Deriving/BEq.lean +++ b/src/Lean/Elab/Deriving/BEq.lean @@ -6,13 +6,13 @@ Authors: Leonardo de Moura module prelude -public import Lean.Meta.Transform -public import Lean.Elab.Deriving.Basic -public import Lean.Elab.Deriving.Util +public import Lean.Data.Options +import Lean.Meta.Transform +import Lean.Elab.Deriving.Basic +import Lean.Elab.Deriving.Util +import Lean.Meta.Eqns import Lean.Meta.SameCtorUtils -public section - namespace Lean.Elab.Deriving.BEq open Lean.Parser.Term open Meta @@ -122,18 +122,16 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do $auxDefs:command* end) -private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext ``BEq "beq" declName +def mkBEqInstanceCmds (ctx : Context) (declName : Name) : TermElabM (Array Syntax) := do let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName]) trace[Elab.Deriving.beq] "\n{cmds}" return cmds -private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do +def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do let auxFunName := ctx.auxFunNames[0]! `(def $(mkIdent auxFunName):ident (x y : $(mkCIdent name)) : Bool := x.ctorIdx == y.ctorIdx) -private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do - let ctx ← mkContext ``BEq "beq" name +def mkBEqEnumCmd (ctx : Context) (name : Name): TermElabM (Array Syntax) := do let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq #[name]) trace[Elab.Deriving.beq] "\n{cmds}" return cmds @@ -142,12 +140,15 @@ open Command def mkBEqInstance (declName : Name) : CommandElabM Unit := do withoutExposeFromCtors declName do + let ctx ← liftTermElabM <| mkContext ``BEq "beq" declName let cmds ← liftTermElabM <| if (← isEnumType declName) then - mkBEqEnumCmd declName + mkBEqEnumCmd ctx declName else - mkBEqInstanceCmds declName + mkBEqInstanceCmds ctx declName cmds.forM elabCommand + unless ctx.usePartial do + elabCommand (← `(attribute [method_specs] $(mkIdent ctx.instName):ident)) def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do if (← declNames.allM isInductive) then diff --git a/src/Lean/Elab/Deriving/Ord.lean b/src/Lean/Elab/Deriving/Ord.lean index 1151e9db09..e2234ee6ba 100644 --- a/src/Lean/Elab/Deriving/Ord.lean +++ b/src/Lean/Elab/Deriving/Ord.lean @@ -79,11 +79,11 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do let indVal := ctx.typeInfos[i]! let header ← mkOrdHeader indVal let mut body ← mkMatch header indVal - if ctx.usePartial || indVal.isRec then + if ctx.usePartial then let letDecls ← mkLocalInstanceLetDecls ctx `Ord header.argNames body ← mkLet letDecls body let binders := header.binders - if ctx.usePartial || indVal.isRec then + if ctx.usePartial then `(partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term) else `(def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term) @@ -98,8 +98,10 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do end) private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do - let ctx ← mkContext ``Ord "ord" declName - let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName]) + let ctx ← mkContext ``Ord "ord" declName (supportsRec := false) + let mut cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName]) + unless ctx.usePartial do + cmds := cmds.push (← `(command| attribute [method_specs] $(mkIdent ctx.instName):ident)) trace[Elab.Deriving.ord] "\n{cmds}" return cmds diff --git a/src/Lean/Elab/Deriving/Util.lean b/src/Lean/Elab/Deriving/Util.lean index 55832e2593..8582b13cb6 100644 --- a/src/Lean/Elab/Deriving/Util.lean +++ b/src/Lean/Elab/Deriving/Util.lean @@ -91,7 +91,7 @@ structure Context where usePartial : Bool -def mkContext (className : Name) (fnPrefix : String) (typeName : Name) : TermElabM Context := do +def mkContext (className : Name) (fnPrefix : String) (typeName : Name) (supportsRec := true ): TermElabM Context := do let indVal ← getConstInfoInduct typeName let mut typeInfos := #[] for typeName in indVal.all do @@ -109,7 +109,7 @@ def mkContext (className : Name) (fnPrefix : String) (typeName : Name) : TermEla for i in [:indVal.all.length] do auxFunNames := auxFunNames.push (instName ++ .mkSimple s!"{fnPrefix}_{i+1}") trace[Elab.Deriving] "instName: {instName} auxFunNames: {auxFunNames}" - let usePartial := indVal.isNested || typeInfos.size > 1 + let usePartial := indVal.isNested || typeInfos.size > 1 || (indVal.isRec && !supportsRec) return { instName := instName typeInfos := typeInfos diff --git a/tests/lean/run/methodSpecsDeriving.lean b/tests/lean/run/methodSpecsDeriving.lean new file mode 100644 index 0000000000..40c186aa23 --- /dev/null +++ b/tests/lean/run/methodSpecsDeriving.lean @@ -0,0 +1,58 @@ +inductive L (α : Type) where + | nil : L α + | cons : α → L α → L α +deriving BEq, Ord + +/-- +info: theorem instBEqL.beq_spec_2 : ∀ {α : Type} [inst : BEq α] (a : α) (a_1 : L α) (b : α) (b_1 : L α), + (L.cons a a_1 == L.cons b b_1) = (a == b && a_1 == b_1) +-/ +#guard_msgs in +#print sig instBEqL.beq_spec_2 + +-- Ord does not support recursive types yet: + +/-- error: Unknown constant `instOrdL.compare_spec_2` -/ +#guard_msgs in +#print sig instOrdL.compare_spec_2 + +inductive O (α : Type u) where + | none + | some : α → O α +deriving BEq, Ord + +/-- +info: theorem instBEqO.beq_spec_2.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (a b : α), (O.some a == O.some b) = (a == b) +-/ +#guard_msgs in #print sig instBEqO.beq_spec_2 +/-- +info: theorem instOrdO.compare_spec_1.{u_1} : ∀ {α : Type u_1} [inst : Ord α], compare O.none O.none = Ordering.eq +-/ +#guard_msgs in #print sig instOrdO.compare_spec_1 +/-- +info: theorem instOrdO.compare_spec_2.{u_1} : ∀ {α : Type u_1} [inst : Ord α] (x : O α), + (x = O.none → False) → compare O.none x = Ordering.lt +-/ +#guard_msgs in #print sig instOrdO.compare_spec_2 +/-- +info: theorem instOrdO.compare_spec_4.{u_1} : ∀ {α : Type u_1} [inst : Ord α] (a b : α), + compare (O.some a) (O.some b) = (compare a b).then Ordering.eq +-/ +#guard_msgs in #print sig instOrdO.compare_spec_4 + +-- Mutual inductive (not supported yet, but should be) + +mutual +inductive Tree (α : Type) where + | node : TreeList α → Tree α + deriving BEq +inductive TreeList (α : Type) where + | nil : TreeList α + | cons : Tree α → TreeList α → TreeList α + deriving BEq +end + + +/-- error: Unknown constant `instBEqTree.beq_spec_1` -/ +#guard_msgs in +#print sig instBEqTree.beq_spec_1