feat: deriving ReflBEq and LawfulBEq (#10351)

This PR adds the ability to do `deriving ReflBEq, LawfulBEq`. Both
classes have to listed in the `deriving` clause. For `ReflBEq`, a simple
`simp`-based proof is used. For `LawfulBEq`, a dedicated,
syntax-directed tactic is used that should work for derived `BEq`
instances. This is meant to work with `deriving BEq` (but you can try to
use it on hand-rolled `@[methods_specs] instance : BEq…` instances).
Does not support mutual or nested inductives.
This commit is contained in:
Joachim Breitner 2025-09-16 14:58:01 +02:00 committed by GitHub
parent 917715c862
commit 186f5a6960
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 336 additions and 11 deletions

View file

@ -45,3 +45,4 @@ public import Init.Try
public import Init.BinderNameHint
public import Init.Task
public import Init.MethodSpecsSimp
public import Init.LawfulBEqTactics

View file

@ -0,0 +1,103 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/
module
prelude
public import Init.Prelude
public import Init.Notation
public import Init.Tactics
public import Init.Core
import Init.Data.Bool
import Init.ByCases
public section
namespace DerivingHelpers
macro "deriving_ReflEq_tactic" : tactic => `(tactic|(
intro x
induction x
all_goals
simp only [ BEq.refl, ↓reduceDIte, Bool.and_true, *, reduceBEq ]
))
theorem and_true_curry {a b : Bool} {P : Prop}
(h : a → b → P) : (a && b) → P := by
rw [Bool.and_eq_true_iff]
intro h'
apply h h'.1 h'.2
theorem deriving_lawful_beq_helper_dep {x y : α} [BEq α] [ReflBEq α]
{t : (x == y) = true → Bool} {P : Prop}
(inst : (x == y) = true → x = y)
(k : (h : x = y) → t (h ▸ ReflBEq.rfl) = true → P) :
(if h : (x == y) then t h else false) = true → P := by
intro h
by_cases hxy : x = y
· subst hxy
apply k rfl
rw [dif_pos (BEq.refl x)] at h
exact h
· by_cases hxy' : x == y
· exact False.elim <| hxy (inst hxy')
· rw [dif_neg hxy'] at h
contradiction
theorem deriving_lawful_beq_helper_nd {x y : α} [BEq α] [ReflBEq α]
{P : Prop}
(inst : (x == y) = true → x = y)
(k : x = y → P) :
(x == y) = true → P := by
intro h
by_cases hxy : x = y
· subst hxy
apply k rfl
· exact False.elim <| hxy (inst h)
end DerivingHelpers
syntax "deriving_LawfulEq_tactic_step" : tactic
macro_rules
| `(tactic| deriving_LawfulEq_tactic_step) =>
`(tactic| fail "deriving_LawfulEq_tactic_step failed")
macro_rules
| `(tactic| deriving_LawfulEq_tactic_step) =>
`(tactic| ( change dite (_ == _) _ _ = true → _
refine DerivingHelpers.deriving_lawful_beq_helper_dep ?_ ?_
· solve | apply_assumption | simp | fail "could not discharge eq_of_beq assumption"
intro h
cases h
dsimp only
))
macro_rules
| `(tactic| deriving_LawfulEq_tactic_step) =>
`(tactic| ( change (_ == _) = true → _
refine DerivingHelpers.deriving_lawful_beq_helper_nd ?_ ?_
· solve | apply_assumption | simp | fail "could not discharge eq_of_beq assumption"
intro h
subst h
))
macro_rules
| `(tactic| deriving_LawfulEq_tactic_step) =>
`(tactic| refine DerivingHelpers.and_true_curry ?_)
macro_rules
| `(tactic| deriving_LawfulEq_tactic_step) =>
`(tactic| rfl)
macro_rules
| `(tactic| deriving_LawfulEq_tactic_step) =>
`(tactic| intro _; trivial)
macro "deriving_LawfulEq_tactic" : tactic => `(tactic|(
intro x
induction x
all_goals
intro y
cases y
all_goals
simp only [reduceBEq]
repeat deriving_LawfulEq_tactic_step
))

View file

@ -19,5 +19,5 @@ public import Lean.Elab.Deriving.SizeOf
public import Lean.Elab.Deriving.Hashable
public import Lean.Elab.Deriving.Ord
public import Lean.Elab.Deriving.ToExpr
public section
public import Lean.Elab.Deriving.ReflBEq
public import Lean.Elab.Deriving.LawfulBEq

View file

@ -0,0 +1,58 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/
module
prelude
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Init.LawfulBEqTactics
namespace Lean.Elab.Deriving.LawfulBEq
open Lean.Parser.Term
open Meta
open TSyntax.Compat in
open Parser.Tactic in
def mkLawfulBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let indVal ← getConstInfoInduct declName
if indVal.all.length > 1 then
throwError "Deriving `LawfulBEq` for mutual inductives is not supported"
if indVal.isNested then
throwError "Deriving `LawfulBEq` for nested inductives is not supported"
let argNames ← mkInductArgNames indVal
let binders ← mkImplicitBinders argNames
let binders := binders ++ (← mkInstImplicitBinders ``BEq indVal argNames)
let binders := binders ++ (← mkInstImplicitBinders ``LawfulBEq indVal argNames)
let indType ← mkInductiveApp indVal argNames
let type ← `($(mkCIdent ``LawfulBEq) $indType)
let instCmd ← `(
instance $binders:implicitBinder* : $type := LawfulBEq.mk (by deriving_LawfulEq_tactic)
)
let cmds := #[instCmd]
trace[Elab.Deriving.lawfulBEq] "\n{cmds}"
return cmds
open Command
def mkLawfulBEqInstance (declName : Name) : CommandElabM Unit := do
withoutExposeFromCtors declName do
let cmds ← liftTermElabM <| mkLawfulBEqInstanceCmds declName
cmds.forM elabCommand
def mkLawfulBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) then
for declName in declNames do
mkLawfulBEqInstance declName
return true
else
return false
builtin_initialize
registerDerivingHandler ``LawfulBEq mkLawfulBEqInstanceHandler
registerTraceClass `Elab.Deriving.lawfulBEq
end Lean.Elab.Deriving.LawfulBEq

View file

@ -0,0 +1,57 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/
module
prelude
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Init.LawfulBEqTactics
namespace Lean.Elab.Deriving.ReflBEq
open Lean.Parser.Term
open Meta
open TSyntax.Compat in
open Parser.Tactic in
def mkReflBEqInstanceCmds (declName : Name) : TermElabM (Array Syntax) := do
let indVal ← getConstInfoInduct declName
if indVal.all.length > 1 then
throwError "Deriving `ReflBEq` for mutual inductives is not supported"
if indVal.isNested then
throwError "Deriving `ReflBEq` for nested inductives is not supported"
let argNames ← mkInductArgNames indVal
let binders ← mkImplicitBinders argNames
let binders := binders ++ (← mkInstImplicitBinders ``BEq indVal argNames)
let binders := binders ++ (← mkInstImplicitBinders ``ReflBEq indVal argNames)
let indType ← mkInductiveApp indVal argNames
let type ← `($(mkCIdent ``ReflBEq) $indType)
let instCmd ← `( instance $binders:implicitBinder* : $type where
rfl := by deriving_ReflEq_tactic)
let cmds := #[instCmd]
trace[Elab.Deriving.reflBEq] "\n{cmds}"
return cmds
open Command
def mkReflBEqInstance (declName : Name) : CommandElabM Unit := do
withoutExposeFromCtors declName do
let cmds ← liftTermElabM <| mkReflBEqInstanceCmds declName
cmds.forM elabCommand
def mkReflBEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
if (← declNames.allM isInductive) then
for declName in declNames do
mkReflBEqInstance declName
return true
else
return false
builtin_initialize
registerDerivingHandler ``ReflBEq mkReflBEqInstanceHandler
registerTraceClass `Elab.Deriving.reflBEq
end Lean.Elab.Deriving.ReflBEq

View file

@ -90,18 +90,23 @@ structure Context where
auxFunNames : Array Name
usePartial : Bool
/--
Anticipates the default instance name for a derived instance.
-/
def mkInstName (className indName : Name) : TermElabM Name := do
let indVal ← getConstInfoInduct indName
let argNames ← mkInductArgNames indVal
let binders ← mkImplicitBinders argNames
let indType ← mkInductiveApp indVal argNames
let type ← `($(mkCIdent className) $indType)
NameGen.mkBaseNameWithSuffix' "inst" (binders.map (·.raw)) type
def mkContext (className : Name) (fnPrefix : String) (typeName : Name) (supportsRec := true ): TermElabM Context := do
let indVal ← getConstInfoInduct typeName
let mut typeInfos := #[]
for typeName in indVal.all do
typeInfos := typeInfos.push (← getConstInfoInduct typeName)
let instName ← do -- anticipate the instance name
let argNames ← mkInductArgNames indVal
let binders ← mkImplicitBinders argNames
let indType ← mkInductiveApp indVal argNames
let type ← `($(mkCIdent className) $indType)
NameGen.mkBaseNameWithSuffix' "inst" (binders.map (·.raw)) type
let instName ← mkInstName className typeName
let mut auxFunNames := #[]
if indVal.all.length = 1 then
auxFunNames := auxFunNames.push (instName ++ .mkSimple fnPrefix)

View file

@ -1,5 +1,7 @@
#include "util/options.h"
// please update stage0
namespace lean {
options get_default_options() {
options opts;

View file

@ -49,6 +49,11 @@ def exclusions : Std.HashMap Lean.Name (Std.HashSet ExclusionKind) := .ofList [
(``SizeOf, { .singleton, .struct, .sum })
]
def dependencies : Std.HashMap Lean.Name (Array Lean.Name) := .ofList [
(``ReflBEq, #[``BEq]),
(``LawfulBEq, #[``BEq, ``ReflBEq])
]
open Lean Meta Elab Command in
set_option hygiene false in
#eval show CommandElabM Unit from do
@ -59,12 +64,13 @@ set_option hygiene false in
withoutModifyingEnv do
let hasExcl (kind : ExclusionKind) := (·.contains kind) <$> exclusions[cls]? |>.getD false
let s ← getThe Command.State
let classes := ((dependencies[cls]? |>.getD #[]).push cls).map mkIdent
unless hasExcl .singleton do
Command.elabCommand (← `(structure B where deriving $(mkIdent cls):ident))
Command.elabCommand (← `(structure B where deriving $[$classes:ident],*))
unless hasExcl .struct do
Command.elabCommand (← `(structure C where x : Nat deriving $(mkIdent cls):ident))
Command.elabCommand (← `(structure C where x : Nat deriving $[$classes:ident],*))
unless hasExcl .sum do
Command.elabCommand (← `(inductive D where | mk₁ : Bool → D | mk₂ : Bool → D deriving $(mkIdent cls):ident))
Command.elabCommand (← `(inductive D where | mk₁ : Bool → D | mk₂ : Bool → D deriving $[$classes:ident],*))
let msgs := (← getThe Command.State).messages.unreported
set s
if msgs.any (·.severity == .error) then

View file

@ -0,0 +1,93 @@
-- set_option trace.Elab.Deriving.lawfulBEq true
inductive L (α : Type u) where
| nil : L α
| cons : α → L α → L α
deriving BEq, ReflBEq, LawfulBEq
/-- info: theorem instReflBEqL.{u_1} : ∀ {α : Type u_1} [inst : BEq α] [ReflBEq α], ReflBEq (L α) -/
#guard_msgs in
#print sig instReflBEqL
inductive Vec (α : Type u) : Nat → Type u where
| nil : Vec α 0
| cons : ∀ {n}, α → Vec α n → Vec α (n+1)
deriving BEq, ReflBEq, LawfulBEq
/--
info: theorem instReflBEqVec.{u_1} : ∀ {α : Type u_1} {a : Nat} [inst : BEq α] [ReflBEq α], ReflBEq (Vec α a)
-/
#guard_msgs in
#print sig instReflBEqVec
inductive Enum
| mk1 | mk2 | mk3
deriving BEq, ReflBEq, LawfulBEq
/-- info: theorem instReflBEqEnum : ReflBEq Enum -/
#guard_msgs in
#print sig instReflBEqEnum
-- The following type has `Eq.rec`s in its `BEq` implementation,
-- but `simp` seems to handle that just fine
inductive WithHEq (α : Type u) : Nat → Type u where
| nil : WithHEq α 0
| cons : ∀ {n m} , α → WithHEq α n → WithHEq α m → WithHEq α (n+1)
deriving BEq, ReflBEq, LawfulBEq
/--
info: instReflBEqWithHEq.{u_1} {α✝ : Type u_1} {a✝ : Nat} [BEq α✝] [ReflBEq α✝] : ReflBEq (WithHEq α✝ a✝)
-/
#guard_msgs in
#check instReflBEqWithHEq
/--
info: instLawfulBEqWithHEq.{u_1} {α✝ : Type u_1} {a✝ : Nat} [BEq α✝] [LawfulBEq α✝] : LawfulBEq (WithHEq α✝ a✝)
-/
#guard_msgs in
#check instLawfulBEqWithHEq
-- No `BEq` derived? Not a great error message yet, but the error location helps, so good enough.
/--
error: failed to synthesize
BEq Foo
Hint: Additional diagnostic information may be available using the `set_option diagnostics true` command.
-/
#guard_msgs in
structure Foo where
deriving ReflBEq
-- No `ReflBEq` but `LawfulBEq`? ot a great error message yet.
/--
@ +2:16...25
error: failed to synthesize
ReflBEq Bar
Hint: Additional diagnostic information may be available using the `set_option diagnostics true` command.
-/
#guard_msgs (positions := true) in
structure Bar where
deriving BEq, LawfulBEq
/--
@ +5:16...23
error: Deriving `ReflBEq` for mutual inductives is not supported
-/
#guard_msgs (positions := true) in
mutual
inductive Tree (α : Type u) where
| node : TreeList α → Tree α
| leaf : α → Tree α
deriving BEq, ReflBEq, LawfulBEq
inductive TreeList (α : Type u) where
| nil : TreeList α
| cons : Tree α → TreeList α → TreeList α
deriving BEq
end