diff --git a/src/Lean/Elab/PreDefinition/PartialFixpoint/Induction.lean b/src/Lean/Elab/PreDefinition/PartialFixpoint/Induction.lean index f68dc457cb..4cfc80496a 100644 --- a/src/Lean/Elab/PreDefinition/PartialFixpoint/Induction.lean +++ b/src/Lean/Elab/PreDefinition/PartialFixpoint/Induction.lean @@ -121,17 +121,19 @@ private def numberNames (n : Nat) (base : String) : Array Name := .ofFn (n := n) fun ⟨i, _⟩ => if n == 1 then .mkSimple base else .mkSimple s!"{base}_{i+1}" -def getInductionPrinciplePostfix (name : Name) : MetaM Name := do +def getInductionPrinciplePostfix (name : Name) (isMutual : Bool) : MetaM Name := do let some eqnInfo := eqnInfoExt.find? (← getEnv) name | throwError "{name} is not defined by partial_fixpoint, inductive_fixpoint, nor coinductive_fixpoint" let idx := eqnInfo.declNames.idxOf name let some res := eqnInfo.fixpointType[idx]? | throwError "Cannot get fixpoint type for {name}" - match res with - | .partialFixpoint => return `fixpoint_induct - | .inductiveFixpoint => return `induct - | .coinductiveFixpoint => return `coinduct + match res, isMutual with + | .partialFixpoint, false => return `fixpoint_induct + | .partialFixpoint, true => throwError "`mutual_induct` is only defined for (co)inductive predicates, not for `partial_fixpoint`" + | .inductiveFixpoint, false => return `induct + | .coinductiveFixpoint, false => return `coinduct + | _, true => return `mutual_induct -def deriveInduction (name : Name) : MetaM Unit := do - let postFix ← getInductionPrinciplePostfix name +def deriveInduction (name : Name) (isMutual : Bool) : MetaM Unit := do + let postFix ← getInductionPrinciplePostfix name isMutual let inductName := name ++ postFix realizeConst name inductName do trace[Elab.definition.partialFixpoint] "Called deriveInduction for {inductName}" @@ -193,8 +195,12 @@ def deriveInduction (name : Name) : MetaM Unit := do -- We apply all the premises let packedPremise ← PProdN.mk 0 motiveVars let e' := mkApp e' packedPremise - -- For each element of the mutual block, we project out the appropriate element - let e' ← PProdN.projM infos.size (eqnInfo.declNames.idxOf name) e' + -- For the `mutual_induct` variant, we are done. + -- Else, project out the appropriate element + let e' ← if isMutual then + pure e' + else + PProdN.projM infos.size (eqnInfo.declNames.idxOf name) e' -- Finally, we bind all the free variables with lambdas let e' ← mkLambdaFVars motiveVars e' let e' ← mkLambdaFVars predVars e' @@ -302,6 +308,10 @@ def isInductName (env : Environment) (name : Name) : Bool := Id.run do let idx := eqnInfo.declNames.idxOf p return isInductiveFixpoint eqnInfo.fixpointType[idx]! return false + | "mutual_induct" => + if let some eqnInfo := eqnInfoExt.find? env p then + return eqnInfo.fixpointType.all isLatticeTheoretic && eqnInfo.declNames.size > 1 + return false | _ => return false builtin_initialize @@ -309,8 +319,9 @@ builtin_initialize registerReservedNameAction fun name => do if isInductName (← getEnv) name then - let .str p _ := name | return false - MetaM.run' <| deriveInduction p + let .str p s := name | return false + let isMutual := s.endsWith "mutual_induct" + MetaM.run' <| deriveInduction p isMutual return true return false @@ -362,7 +373,7 @@ def derivePartialCorrectness (name : Name) : MetaM Unit := do realizeConst name inductName do let fixpointInductThm := name ++ `fixpoint_induct unless (← getEnv).contains fixpointInductThm do - deriveInduction name + deriveInduction name false prependError m!"Cannot derive partial correctness theorem (please report this issue)" do let some eqnInfo := eqnInfoExt.find? (← getEnv) name | diff --git a/tests/lean/run/coinductive_predicates.lean b/tests/lean/run/coinductive_predicates.lean index 30c4ac2e7b..765c358042 100644 --- a/tests/lean/run/coinductive_predicates.lean +++ b/tests/lean/run/coinductive_predicates.lean @@ -16,6 +16,12 @@ info: infseq.coinduct.{u_1} {α : Sort u_1} (R : α → α → Prop) (pred : α -/ #guard_msgs in #check infseq.coinduct +/-- +error: Unknown constant `infseq.mutual_induct` +-/ +#guard_msgs in +#check infseq.mutual_induct + -- Simple proof by coinduction theorem cycle_infseq {R : α → α → Prop} (x : α) : R x x → infseq R x := by apply @infseq.coinduct α R (λ m => R m m) diff --git a/tests/lean/run/mutual_coinduction.lean b/tests/lean/run/mutual_coinduction.lean index 0e9d9f0658..b074575383 100644 --- a/tests/lean/run/mutual_coinduction.lean +++ b/tests/lean/run/mutual_coinduction.lean @@ -13,12 +13,23 @@ namespace MutualCoinduction -/ #guard_msgs in #check MutualCoinduction.f.coinduct - + /-- + info: MutualCoinduction.f.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : pred_1 → pred_2) (hyp_2 : pred_2 → pred_1) : + (pred_1 → f) ∧ (pred_2 → g) + -/ + #guard_msgs in + #check MutualCoinduction.f.mutual_induct /-- info: MutualCoinduction.g.coinduct (pred_1 pred_2 : Prop) (hyp_1 : pred_1 → pred_2) (hyp_2 : pred_2 → pred_1) : pred_2 → g -/ #guard_msgs in #check MutualCoinduction.g.coinduct + /-- + info: MutualCoinduction.g.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : pred_1 → pred_2) (hyp_2 : pred_2 → pred_1) : + (pred_1 → f) ∧ (pred_2 → g) + -/ + #guard_msgs in + #check MutualCoinduction.g.mutual_induct end MutualCoinduction namespace MutualInduction @@ -36,12 +47,23 @@ namespace MutualInduction -/ #guard_msgs in #check MutualInduction.f.induct - - /-- + /-- + info: MutualInduction.f.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : pred_2 → pred_1) (hyp_2 : pred_1 → pred_2) : + (f → pred_1) ∧ (g → pred_2) + -/ + #guard_msgs in + #check MutualInduction.f.mutual_induct + /-- info: MutualInduction.g.induct (pred_1 pred_2 : Prop) (hyp_1 : pred_2 → pred_1) (hyp_2 : pred_1 → pred_2) : g → pred_2 -/ #guard_msgs in #check MutualInduction.g.induct + /-- + info: MutualInduction.g.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : pred_2 → pred_1) (hyp_2 : pred_1 → pred_2) : + (f → pred_1) ∧ (g → pred_2) + -/ + #guard_msgs in + #check MutualInduction.g.mutual_induct end MutualInduction namespace MixedInductionCoinduction @@ -61,14 +83,24 @@ namespace MixedInductionCoinduction -/ #guard_msgs in #check f.induct - - /-- - info: MixedInductionCoinduction.g.coinduct (pred_1 pred_2 : Prop) (hyp_1 : (pred_2 → pred_1) → pred_1) + /-- + info: MixedInductionCoinduction.f.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : (pred_2 → pred_1) → pred_1) + (hyp_2 : pred_2 → pred_1 → pred_2) : (f → pred_1) ∧ (pred_2 → g) + -/ + #guard_msgs in + #check f.mutual_induct + /-- + info: MixedInductionCoinduction.g.coinduct (pred_1 pred_2 : Prop) (hyp_1 : (pred_2 → pred_1) → pred_1) (hyp_2 : pred_2 → pred_1 → pred_2) : pred_2 → g -/ #guard_msgs in #check g.coinduct - + /-- + info: MixedInductionCoinduction.g.mutual_induct (pred_1 pred_2 : Prop) (hyp_1 : (pred_2 → pred_1) → pred_1) + (hyp_2 : pred_2 → pred_1 → pred_2) : (f → pred_1) ∧ (pred_2 → g) + -/ + #guard_msgs in + #check g.mutual_induct end MixedInductionCoinduction namespace DifferentPredicateTypes @@ -89,7 +121,14 @@ namespace DifferentPredicateTypes -/ #guard_msgs in #check f.coinduct - + /-- + info: DifferentPredicateTypes.f.mutual_induct (pred_1 : Nat → Prop) (pred_2 : Nat → Nat → Prop) + (hyp_1 : ∀ (x : Nat), pred_1 x → pred_2 (x + 1) (x + 2)) + (hyp_2 : ∀ (x x_1 : Nat), pred_2 x x_1 → pred_1 (x + 2) ∨ pred_2 (x_1 + 1) x_1) : + (∀ (x : Nat), pred_1 x → f x) ∧ ∀ (x x_1 : Nat), pred_2 x x_1 → g x x_1 + -/ + #guard_msgs in + #check f.mutual_induct /-- info: DifferentPredicateTypes.g.coinduct (pred_1 : Nat → Prop) (pred_2 : Nat → Nat → Prop) (hyp_1 : ∀ (x : Nat), pred_1 x → pred_2 (x + 1) (x + 2)) @@ -98,6 +137,12 @@ namespace DifferentPredicateTypes -/ #guard_msgs in #check g.coinduct - - + /-- + info: DifferentPredicateTypes.g.mutual_induct (pred_1 : Nat → Prop) (pred_2 : Nat → Nat → Prop) + (hyp_1 : ∀ (x : Nat), pred_1 x → pred_2 (x + 1) (x + 2)) + (hyp_2 : ∀ (x x_1 : Nat), pred_2 x x_1 → pred_1 (x + 2) ∨ pred_2 (x_1 + 1) x_1) : + (∀ (x : Nat), pred_1 x → f x) ∧ ∀ (x x_1 : Nat), pred_2 x x_1 → g x x_1 + -/ + #guard_msgs in + #check g.mutual_induct end DifferentPredicateTypes