feat: add mutual_induct for (co)inductive predicates in mutual blocks (#9628)

This PR introduces a `mutual_induct` variant of the generated
(co)induction proof principle for mutually defined (co)inductive
predicates. Unlike the standard (co)induction principle (which projects
conclusions separately for each predicate), `mutual_induct` produces a
conjunction of all conclusions.

## Example

Given the following mutual definition:

```lean4
mutual
  def f : Prop := g
  coinductive_fixpoint

  def g : Prop := f
  coinductive_fixpoint
end
```

Standard coinduction principles:
```lean4 
f.coind : ∀ (pred_1 pred_2 : Prop), (pred_1 → pred_2) → (pred_2 → pred_1) → pred_1 → f
g.coind : ∀ (pred_1 pred_2 : Prop), (pred_1 → pred_2) → (pred_2 → pred_1) → pred_2 → g
```

New `mutual_induct`principle:
```lean4
f.mutual_induct: ∀ (pred_1 pred_2 : Prop), (pred_1 → pred_2) → (pred_2 → pred_1) → (pred_1 → f) ∧ (pred_2 → g)
```

---------

Co-authored-by: Joachim Breitner <mail@joachim-breitner.de>
This commit is contained in:
Wojciech Rozowski 2025-07-31 13:39:52 +01:00 committed by GitHub
parent 5f20213876
commit fa449aab14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 84 additions and 22 deletions

View file

@ -121,17 +121,19 @@ private def numberNames (n : Nat) (base : String) : Array Name :=
.ofFn (n := n) fun ⟨i, _⟩ =>
if n == 1 then .mkSimple base else .mkSimple s!"{base}_{i+1}"
def getInductionPrinciplePostfix (name : Name) : MetaM Name := do
def getInductionPrinciplePostfix (name : Name) (isMutual : Bool) : MetaM Name := do
let some eqnInfo := eqnInfoExt.find? (← getEnv) name | throwError "{name} is not defined by partial_fixpoint, inductive_fixpoint, nor coinductive_fixpoint"
let idx := eqnInfo.declNames.idxOf name
let some res := eqnInfo.fixpointType[idx]? | throwError "Cannot get fixpoint type for {name}"
match res with
| .partialFixpoint => return `fixpoint_induct
| .inductiveFixpoint => return `induct
| .coinductiveFixpoint => return `coinduct
match res, isMutual with
| .partialFixpoint, false => return `fixpoint_induct
| .partialFixpoint, true => throwError "`mutual_induct` is only defined for (co)inductive predicates, not for `partial_fixpoint`"
| .inductiveFixpoint, false => return `induct
| .coinductiveFixpoint, false => return `coinduct
| _, true => return `mutual_induct
def deriveInduction (name : Name) : MetaM Unit := do
let postFix ← getInductionPrinciplePostfix name
def deriveInduction (name : Name) (isMutual : Bool) : MetaM Unit := do
let postFix ← getInductionPrinciplePostfix name isMutual
let inductName := name ++ postFix
realizeConst name inductName do
trace[Elab.definition.partialFixpoint] "Called deriveInduction for {inductName}"
@ -193,8 +195,12 @@ def deriveInduction (name : Name) : MetaM Unit := do
-- We apply all the premises
let packedPremise ← PProdN.mk 0 motiveVars
let e' := mkApp e' packedPremise
-- For each element of the mutual block, we project out the appropriate element
let e' ← PProdN.projM infos.size (eqnInfo.declNames.idxOf name) e'
-- For the `mutual_induct` variant, we are done.
-- Else, project out the appropriate element
let e' ← if isMutual then
pure e'
else
PProdN.projM infos.size (eqnInfo.declNames.idxOf name) e'
-- Finally, we bind all the free variables with lambdas
let e' ← mkLambdaFVars motiveVars e'
let e' ← mkLambdaFVars predVars e'
@ -302,6 +308,10 @@ def isInductName (env : Environment) (name : Name) : Bool := Id.run do
let idx := eqnInfo.declNames.idxOf p
return isInductiveFixpoint eqnInfo.fixpointType[idx]!
return false
| "mutual_induct" =>
if let some eqnInfo := eqnInfoExt.find? env p then
return eqnInfo.fixpointType.all isLatticeTheoretic && eqnInfo.declNames.size > 1
return false
| _ => return false
builtin_initialize
@ -309,8 +319,9 @@ builtin_initialize
registerReservedNameAction fun name => do
if isInductName (← getEnv) name then
let .str p _ := name | return false
MetaM.run' <| deriveInduction p
let .str p s := name | return false
let isMutual := s.endsWith "mutual_induct"
MetaM.run' <| deriveInduction p isMutual
return true
return false
@ -362,7 +373,7 @@ def derivePartialCorrectness (name : Name) : MetaM Unit := do
realizeConst name inductName do
let fixpointInductThm := name ++ `fixpoint_induct
unless (← getEnv).contains fixpointInductThm do
deriveInduction name
deriveInduction name false
prependError m!"Cannot derive partial correctness theorem (please report this issue)" do
let some eqnInfo := eqnInfoExt.find? (← getEnv) name |

View file

@ -16,6 +16,12 @@ info: infseq.coinduct.{u_1} {α : Sort u_1} (R : αα → Prop) (pred : α
-/
#guard_msgs in #check infseq.coinduct
/--
error: Unknown constant `infseq.mutual_induct`
-/
#guard_msgs in
#check infseq.mutual_induct
-- Simple proof by coinduction
theorem cycle_infseq {R : αα → Prop} (x : α) : R x x → infseq R x := by
apply @infseq.coinduct α R (λ m => R m m)

View file

@ -13,12 +13,23 @@ namespace MutualCoinduction
-/
#guard_msgs in
#check MutualCoinduction.f.coinduct
/--
info: MutualCoinduction.f.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : pred_1 → pred_2) (hyp_2 : pred_2 → pred_1) :
(pred_1 → f) ∧ (pred_2 → g)
-/
#guard_msgs in
#check MutualCoinduction.f.mutual_induct
/--
info: MutualCoinduction.g.coinduct (pred_1 pred_2 : Prop) (hyp_1 : pred_1 → pred_2) (hyp_2 : pred_2 → pred_1) : pred_2 → g
-/
#guard_msgs in
#check MutualCoinduction.g.coinduct
/--
info: MutualCoinduction.g.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : pred_1 → pred_2) (hyp_2 : pred_2 → pred_1) :
(pred_1 → f) ∧ (pred_2 → g)
-/
#guard_msgs in
#check MutualCoinduction.g.mutual_induct
end MutualCoinduction
namespace MutualInduction
@ -36,12 +47,23 @@ namespace MutualInduction
-/
#guard_msgs in
#check MutualInduction.f.induct
/--
/--
info: MutualInduction.f.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : pred_2 → pred_1) (hyp_2 : pred_1 → pred_2) :
(f → pred_1) ∧ (g → pred_2)
-/
#guard_msgs in
#check MutualInduction.f.mutual_induct
/--
info: MutualInduction.g.induct (pred_1 pred_2 : Prop) (hyp_1 : pred_2 → pred_1) (hyp_2 : pred_1 → pred_2) : g → pred_2
-/
#guard_msgs in
#check MutualInduction.g.induct
/--
info: MutualInduction.g.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : pred_2 → pred_1) (hyp_2 : pred_1 → pred_2) :
(f → pred_1) ∧ (g → pred_2)
-/
#guard_msgs in
#check MutualInduction.g.mutual_induct
end MutualInduction
namespace MixedInductionCoinduction
@ -61,14 +83,24 @@ namespace MixedInductionCoinduction
-/
#guard_msgs in
#check f.induct
/--
info: MixedInductionCoinduction.g.coinduct (pred_1 pred_2 : Prop) (hyp_1 : (pred_2 → pred_1) → pred_1)
/--
info: MixedInductionCoinduction.f.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : (pred_2 → pred_1) → pred_1)
(hyp_2 : pred_2 → pred_1 → pred_2) : (f → pred_1) ∧ (pred_2 → g)
-/
#guard_msgs in
#check f.mutual_induct
/--
info: MixedInductionCoinduction.g.coinduct (pred_1 pred_2 : Prop) (hyp_1 : (pred_2 → pred_1) → pred_1)
(hyp_2 : pred_2 → pred_1 → pred_2) : pred_2 → g
-/
#guard_msgs in
#check g.coinduct
/--
info: MixedInductionCoinduction.g.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : (pred_2 → pred_1) → pred_1)
(hyp_2 : pred_2 → pred_1 → pred_2) : (f → pred_1) ∧ (pred_2 → g)
-/
#guard_msgs in
#check g.mutual_induct
end MixedInductionCoinduction
namespace DifferentPredicateTypes
@ -89,7 +121,14 @@ namespace DifferentPredicateTypes
-/
#guard_msgs in
#check f.coinduct
/--
info: DifferentPredicateTypes.f.mutual_induct (pred_1 : Nat → Prop) (pred_2 : Nat → Nat → Prop)
(hyp_1 : ∀ (x : Nat), pred_1 x → pred_2 (x + 1) (x + 2))
(hyp_2 : ∀ (x x_1 : Nat), pred_2 x x_1 → pred_1 (x + 2) pred_2 (x_1 + 1) x_1) :
(∀ (x : Nat), pred_1 x → f x) ∧ ∀ (x x_1 : Nat), pred_2 x x_1 → g x x_1
-/
#guard_msgs in
#check f.mutual_induct
/--
info: DifferentPredicateTypes.g.coinduct (pred_1 : Nat → Prop) (pred_2 : Nat → Nat → Prop)
(hyp_1 : ∀ (x : Nat), pred_1 x → pred_2 (x + 1) (x + 2))
@ -98,6 +137,12 @@ namespace DifferentPredicateTypes
-/
#guard_msgs in
#check g.coinduct
/--
info: DifferentPredicateTypes.g.mutual_induct (pred_1 : Nat → Prop) (pred_2 : Nat → Nat → Prop)
(hyp_1 : ∀ (x : Nat), pred_1 x → pred_2 (x + 1) (x + 2))
(hyp_2 : ∀ (x x_1 : Nat), pred_2 x x_1 → pred_1 (x + 2) pred_2 (x_1 + 1) x_1) :
(∀ (x : Nat), pred_1 x → f x) ∧ ∀ (x x_1 : Nat), pred_2 x x_1 → g x x_1
-/
#guard_msgs in
#check g.mutual_induct
end DifferentPredicateTypes