feat: use @[method_specs] when deriving BEq and Ord (#10346)
This PR lets `deriving BEq` and `deriving Ord` use `@[method_specs]` from #10302 when applicable (i.e. when not using `partial`).
This commit is contained in:
parent
3bea7e209e
commit
9aa6448fa9
4 changed files with 79 additions and 18 deletions
|
|
@ -6,13 +6,13 @@ Authors: Leonardo de Moura
|
|||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Meta.Transform
|
||||
public import Lean.Elab.Deriving.Basic
|
||||
public import Lean.Elab.Deriving.Util
|
||||
public import Lean.Data.Options
|
||||
import Lean.Meta.Transform
|
||||
import Lean.Elab.Deriving.Basic
|
||||
import Lean.Elab.Deriving.Util
|
||||
import Lean.Meta.Eqns
|
||||
import Lean.Meta.SameCtorUtils
|
||||
|
||||
public section
|
||||
|
||||
namespace Lean.Elab.Deriving.BEq
|
||||
open Lean.Parser.Term
|
||||
open Meta
|
||||
|
|
@ -122,18 +122,16 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
|
|||
$auxDefs:command*
|
||||
end)
|
||||
|
||||
private def mkBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
|
||||
let ctx ← mkContext ``BEq "beq" declName
|
||||
def mkBEqInstanceCmds (ctx : Context) (declName : Name) : TermElabM (Array Syntax) := do
|
||||
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `BEq #[declName])
|
||||
trace[Elab.Deriving.beq] "\n{cmds}"
|
||||
return cmds
|
||||
|
||||
private def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
|
||||
def mkBEqEnumFun (ctx : Context) (name : Name) : TermElabM Syntax := do
|
||||
let auxFunName := ctx.auxFunNames[0]!
|
||||
`(def $(mkIdent auxFunName):ident (x y : $(mkCIdent name)) : Bool := x.ctorIdx == y.ctorIdx)
|
||||
|
||||
private def mkBEqEnumCmd (name : Name): TermElabM (Array Syntax) := do
|
||||
let ctx ← mkContext ``BEq "beq" name
|
||||
def mkBEqEnumCmd (ctx : Context) (name : Name): TermElabM (Array Syntax) := do
|
||||
let cmds := #[← mkBEqEnumFun ctx name] ++ (← mkInstanceCmds ctx `BEq #[name])
|
||||
trace[Elab.Deriving.beq] "\n{cmds}"
|
||||
return cmds
|
||||
|
|
@ -142,12 +140,15 @@ open Command
|
|||
|
||||
def mkBEqInstance (declName : Name) : CommandElabM Unit := do
|
||||
withoutExposeFromCtors declName do
|
||||
let ctx ← liftTermElabM <| mkContext ``BEq "beq" declName
|
||||
let cmds ← liftTermElabM <|
|
||||
if (← isEnumType declName) then
|
||||
mkBEqEnumCmd declName
|
||||
mkBEqEnumCmd ctx declName
|
||||
else
|
||||
mkBEqInstanceCmds declName
|
||||
mkBEqInstanceCmds ctx declName
|
||||
cmds.forM elabCommand
|
||||
unless ctx.usePartial do
|
||||
elabCommand (← `(attribute [method_specs] $(mkIdent ctx.instName):ident))
|
||||
|
||||
def mkBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
|
||||
if (← declNames.allM isInductive) then
|
||||
|
|
|
|||
|
|
@ -79,11 +79,11 @@ def mkAuxFunction (ctx : Context) (i : Nat) : TermElabM Command := do
|
|||
let indVal := ctx.typeInfos[i]!
|
||||
let header ← mkOrdHeader indVal
|
||||
let mut body ← mkMatch header indVal
|
||||
if ctx.usePartial || indVal.isRec then
|
||||
if ctx.usePartial then
|
||||
let letDecls ← mkLocalInstanceLetDecls ctx `Ord header.argNames
|
||||
body ← mkLet letDecls body
|
||||
let binders := header.binders
|
||||
if ctx.usePartial || indVal.isRec then
|
||||
if ctx.usePartial then
|
||||
`(partial def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term)
|
||||
else
|
||||
`(def $(mkIdent auxFunName):ident $binders:bracketedBinder* : Ordering := $body:term)
|
||||
|
|
@ -98,8 +98,10 @@ def mkMutualBlock (ctx : Context) : TermElabM Syntax := do
|
|||
end)
|
||||
|
||||
private def mkOrdInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
|
||||
let ctx ← mkContext ``Ord "ord" declName
|
||||
let cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName])
|
||||
let ctx ← mkContext ``Ord "ord" declName (supportsRec := false)
|
||||
let mut cmds := #[← mkMutualBlock ctx] ++ (← mkInstanceCmds ctx `Ord #[declName])
|
||||
unless ctx.usePartial do
|
||||
cmds := cmds.push (← `(command| attribute [method_specs] $(mkIdent ctx.instName):ident))
|
||||
trace[Elab.Deriving.ord] "\n{cmds}"
|
||||
return cmds
|
||||
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ structure Context where
|
|||
usePartial : Bool
|
||||
|
||||
|
||||
def mkContext (className : Name) (fnPrefix : String) (typeName : Name) : TermElabM Context := do
|
||||
def mkContext (className : Name) (fnPrefix : String) (typeName : Name) (supportsRec := true ): TermElabM Context := do
|
||||
let indVal ← getConstInfoInduct typeName
|
||||
let mut typeInfos := #[]
|
||||
for typeName in indVal.all do
|
||||
|
|
@ -109,7 +109,7 @@ def mkContext (className : Name) (fnPrefix : String) (typeName : Name) : TermEla
|
|||
for i in [:indVal.all.length] do
|
||||
auxFunNames := auxFunNames.push (instName ++ .mkSimple s!"{fnPrefix}_{i+1}")
|
||||
trace[Elab.Deriving] "instName: {instName} auxFunNames: {auxFunNames}"
|
||||
let usePartial := indVal.isNested || typeInfos.size > 1
|
||||
let usePartial := indVal.isNested || typeInfos.size > 1 || (indVal.isRec && !supportsRec)
|
||||
return {
|
||||
instName := instName
|
||||
typeInfos := typeInfos
|
||||
|
|
|
|||
58
tests/lean/run/methodSpecsDeriving.lean
Normal file
58
tests/lean/run/methodSpecsDeriving.lean
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
inductive L (α : Type) where
|
||||
| nil : L α
|
||||
| cons : α → L α → L α
|
||||
deriving BEq, Ord
|
||||
|
||||
/--
|
||||
info: theorem instBEqL.beq_spec_2 : ∀ {α : Type} [inst : BEq α] (a : α) (a_1 : L α) (b : α) (b_1 : L α),
|
||||
(L.cons a a_1 == L.cons b b_1) = (a == b && a_1 == b_1)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print sig instBEqL.beq_spec_2
|
||||
|
||||
-- Ord does not support recursive types yet:
|
||||
|
||||
/-- error: Unknown constant `instOrdL.compare_spec_2` -/
|
||||
#guard_msgs in
|
||||
#print sig instOrdL.compare_spec_2
|
||||
|
||||
inductive O (α : Type u) where
|
||||
| none
|
||||
| some : α → O α
|
||||
deriving BEq, Ord
|
||||
|
||||
/--
|
||||
info: theorem instBEqO.beq_spec_2.{u_1} : ∀ {α : Type u_1} [inst : BEq α] (a b : α), (O.some a == O.some b) = (a == b)
|
||||
-/
|
||||
#guard_msgs in #print sig instBEqO.beq_spec_2
|
||||
/--
|
||||
info: theorem instOrdO.compare_spec_1.{u_1} : ∀ {α : Type u_1} [inst : Ord α], compare O.none O.none = Ordering.eq
|
||||
-/
|
||||
#guard_msgs in #print sig instOrdO.compare_spec_1
|
||||
/--
|
||||
info: theorem instOrdO.compare_spec_2.{u_1} : ∀ {α : Type u_1} [inst : Ord α] (x : O α),
|
||||
(x = O.none → False) → compare O.none x = Ordering.lt
|
||||
-/
|
||||
#guard_msgs in #print sig instOrdO.compare_spec_2
|
||||
/--
|
||||
info: theorem instOrdO.compare_spec_4.{u_1} : ∀ {α : Type u_1} [inst : Ord α] (a b : α),
|
||||
compare (O.some a) (O.some b) = (compare a b).then Ordering.eq
|
||||
-/
|
||||
#guard_msgs in #print sig instOrdO.compare_spec_4
|
||||
|
||||
-- Mutual inductive (not supported yet, but should be)
|
||||
|
||||
mutual
|
||||
inductive Tree (α : Type) where
|
||||
| node : TreeList α → Tree α
|
||||
deriving BEq
|
||||
inductive TreeList (α : Type) where
|
||||
| nil : TreeList α
|
||||
| cons : Tree α → TreeList α → TreeList α
|
||||
deriving BEq
|
||||
end
|
||||
|
||||
|
||||
/-- error: Unknown constant `instBEqTree.beq_spec_1` -/
|
||||
#guard_msgs in
|
||||
#print sig instBEqTree.beq_spec_1
|
||||
Loading…
Add table
Reference in a new issue