From 75570f327fd481cc3ce60a9f833e7c7b0abd8f71 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Wed, 19 Nov 2025 10:53:09 +0100 Subject: [PATCH] refactor: thunk field-less alternatives of casesOnSameCtor (#11254) This RP adds a `Unit` argument to `casesOnSameCtor` to make it behave moere similar to a matcher. Follow up in spirit to #11239. --- src/Lean/Elab/Deriving/BEq.lean | 3 +++ src/Lean/Elab/Deriving/DecEq.lean | 3 +++ src/Lean/Elab/Deriving/Ord.lean | 3 +++ src/Lean/Meta/Constructions/CasesOnSameCtor.lean | 7 +++++-- tests/lean/decEqMutualInductives.lean.expected.out | 4 ++-- tests/lean/run/casesOnSameCtor.lean | 10 +++++----- 6 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/Lean/Elab/Deriving/BEq.lean b/src/Lean/Elab/Deriving/BEq.lean index 5351178d38..cefb56773b 100644 --- a/src/Lean/Elab/Deriving/BEq.lean +++ b/src/Lean/Elab/Deriving/BEq.lean @@ -164,6 +164,9 @@ def mkMatchNew (header : Header) (indVal : InductiveVal) (auxFunName : Name) : T rhs_empty := false else rhs ← `($a:ident == $b:ident && $rhs) + if ctorArgs1.isEmpty then + -- Unit thunking argument + ctorArgs1 := ctorArgs1.push (← `(())) `(@fun $ctorArgs1.reverse:term* $ctorArgs2.reverse:term* =>$rhs:term) if indVal.numCtors == 1 then `( $(mkCIdent casesOnSameCtorName) $x1:term $x2:term rfl $alts:term* ) diff --git a/src/Lean/Elab/Deriving/DecEq.lean b/src/Lean/Elab/Deriving/DecEq.lean index 5b78b047c3..0a042bd076 100644 --- a/src/Lean/Elab/Deriving/DecEq.lean +++ b/src/Lean/Elab/Deriving/DecEq.lean @@ -140,6 +140,9 @@ def mkMatchNew (ctx : Context) (header : Header) (indVal : InductiveVal) : TermE let recField := indValNum.map (ctx.auxFunNames[·]!) let isProof ← isProp xType todo := todo.push (a, b, recField, isProof) + if ctorArgs1.isEmpty then + -- Unit thunking argument + ctorArgs1 := ctorArgs1.push (← `(())) let rhs ← mkSameCtorRhs todo.toList `(@fun $ctorArgs1:term* $ctorArgs2:term* =>$rhs:term) if indVal.numCtors == 1 then diff --git a/src/Lean/Elab/Deriving/Ord.lean b/src/Lean/Elab/Deriving/Ord.lean index 28a45ac123..1998f12a1e 100644 --- a/src/Lean/Elab/Deriving/Ord.lean +++ b/src/Lean/Elab/Deriving/Ord.lean @@ -118,6 +118,9 @@ def mkMatchNew (header : Header) (indVal : InductiveVal) : TermElabM Term := do else rhsCont := fun rhs => `(Ordering.then (compare $a $b) $rhs) >>= rhsCont let rhs ← rhsCont (← `(Ordering.eq)) + if ctorArgs1.isEmpty then + -- Unit thunking argument + ctorArgs1 := ctorArgs1.push (← `(())) `(@fun $ctorArgs1:term* $ctorArgs2:term* =>$rhs:term) if indVal.numCtors == 1 then `( $(mkCIdent casesOnSameCtorName) $x1:term $x2:term rfl $alts:term* ) diff --git a/src/Lean/Meta/Constructions/CasesOnSameCtor.lean b/src/Lean/Meta/Constructions/CasesOnSameCtor.lean index a9fc675994..92f0d0d615 100644 --- a/src/Lean/Meta/Constructions/CasesOnSameCtor.lean +++ b/src/Lean/Meta/Constructions/CasesOnSameCtor.lean @@ -160,6 +160,7 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit := let ctorApp2 := mkAppN ctor fields2 let e := mkAppN motive (is ++ #[ctorApp1, ctorApp2, (← mkEqRefl (mkNatLit i))]) let e ← mkForallFVars zs12 e + let e ← if zs12.isEmpty then mkArrow (mkConst ``Unit) e else pure e let name := match ctorName with | Name.str _ s => Name.mkSimple s | _ => Name.mkSimple s!"alt{i+1}" @@ -190,8 +191,10 @@ public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit := let goal := alt.mvarId! let some (goal, _) ← Cases.unifyEqs? newRefls.size goal {} | throwError "unifyEqns? unexpectedly closed goal" - let [] ← goal.apply alts[i]! - | throwError "could not apply {alts[i]!} to close\n{goal}" + let hyp := alts[i]! + let hyp := if zs1.isEmpty && zs2.isEmpty then mkApp hyp (mkConst ``Unit.unit) else hyp + let [] ← goal.apply hyp + | throwError "could not apply {hyp} to close\n{goal}" mkLambdaFVars (zs1 ++ zs2) (← instantiateMVars alt) let casesOn2 := mkAppN casesOn2 alts' let casesOn2 := mkAppN casesOn2 newRefls diff --git a/tests/lean/decEqMutualInductives.lean.expected.out b/tests/lean/decEqMutualInductives.lean.expected.out index 42c9afc14d..e9c5504d2e 100644 --- a/tests/lean/decEqMutualInductives.lean.expected.out +++ b/tests/lean/decEqMutualInductives.lean.expected.out @@ -65,7 +65,7 @@ def instDecidableEqListTree.decEq_2 (x✝² : @B.ListTree✝) (x✝³ : @B.ListTree✝) : Decidable✝ (x✝² = x✝³) := match decEq✝ (B.ListTree.ctorIdx✝ x✝²) (B.ListTree.ctorIdx✝ x✝³) with | .isTrue✝¹ h✝¹ => - B.ListTree.match_on_same_ctor✝ x✝² x✝³ h✝¹ (@fun => isTrue✝¹ rfl✝) + B.ListTree.match_on_same_ctor✝ x✝² x✝³ h✝¹ (@fun () => isTrue✝¹ rfl✝) @fun a✝¹ a✝² b✝¹ b✝² => let inst✝¹ := instDecidableEqListTree.decEq_1 @a✝¹ @b✝¹; if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; @@ -84,7 +84,7 @@ def instDecidableEqFoo₁.decEq_1 (x✝ : @B.Foo₁✝) (x✝¹ : @B.Foo₁✝) : Decidable✝ (x✝ = x✝¹) := match decEq✝ (B.Foo₁.ctorIdx✝ x✝) (B.Foo₁.ctorIdx✝ x✝¹) with | .isTrue✝ h✝ => - B.Foo₁.match_on_same_ctor✝ x✝ x✝¹ h✝ (@fun => isTrue✝ rfl✝) + B.Foo₁.match_on_same_ctor✝ x✝ x✝¹ h✝ (@fun () => isTrue✝ rfl✝) @fun a✝ b✝ => let inst✝ := instDecidableEqFoo₁.decEq_2 @a✝ @b✝; if h✝¹ : @a✝ = @b✝ then by subst h✝¹; exact isTrue✝¹ rfl✝¹ diff --git a/tests/lean/run/casesOnSameCtor.lean b/tests/lean/run/casesOnSameCtor.lean index 78e89b20fd..2be962b6ba 100644 --- a/tests/lean/run/casesOnSameCtor.lean +++ b/tests/lean/run/casesOnSameCtor.lean @@ -27,7 +27,7 @@ info: Vec.match_on_same_ctor.het.{u_1, u} {α : Type u} {motive : {a : Nat} → /-- info: Vec.match_on_same_ctor.{u_1, u} {α : Type u} {motive : {a : Nat} → (t t_1 : Vec α a) → t.ctorIdx = t_1.ctorIdx → Sort u_1} {a✝ : Nat} (t t✝ : Vec α a✝) - (h : t.ctorIdx = t✝.ctorIdx) (nil : motive nil nil ⋯) + (h : t.ctorIdx = t✝.ctorIdx) (nil : Unit → motive nil nil ⋯) (cons : (a : α) → {n : Nat} → (a_1 : Vec α n) → (a' : α) → (a'_1 : Vec α n) → motive (cons a a_1) (cons a' a'_1) ⋯) : motive t t✝ h -/ @@ -54,10 +54,10 @@ info: Vec.match_on_same_ctor.splitter.{u_1, u} {α : Type u} /-- info: Vec.match_on_same_ctor.eq_2.{u_1, u} {α : Type u} {motive : {a : Nat} → (t t_1 : Vec α a) → t.ctorIdx = t_1.ctorIdx → Sort u_1} (a✝ : α) (n : Nat) (a✝¹ : Vec α n) - (a'✝ : α) (a'✝¹ : Vec α n) (nil : motive nil nil ⋯) + (a'✝ : α) (a'✝¹ : Vec α n) (nil : Unit → motive nil nil ⋯) (cons : (a : α) → {n : Nat} → (a_1 : Vec α n) → (a' : α) → (a'_1 : Vec α n) → motive (cons a a_1) (cons a' a'_1) ⋯) : (match n + 1, Vec.cons a✝ a✝¹, Vec.cons a'✝ a'✝¹ with - | 0, Vec.nil, Vec.nil, ⋯ => nil + | 0, Vec.nil, Vec.nil, ⋯ => nil () | n + 1, Vec.cons a a_1, Vec.cons a' a'_1, ⋯ => cons a a_1 a' a'_1) = cons a✝ a✝¹ a'✝ a'✝¹ -/ @@ -72,7 +72,7 @@ info: Vec.match_on_same_ctor.eq_2.{u_1, u} {α : Type u} def decEqVec {α} {a} [DecidableEq α] (x : @Vec α a) (x_1 : @Vec α a) : Decidable (x = x_1) := if h : Vec.ctorIdx x = Vec.ctorIdx x_1 then - Vec.match_on_same_ctor x x_1 h (isTrue rfl) + Vec.match_on_same_ctor x x_1 h (fun _ => isTrue rfl) @fun a_1 _ a_2 b b_1 => if h_1 : @a_1 = @b then by subst h_1 @@ -137,7 +137,7 @@ run_meta mkCasesOnSameCtor `List.match_on_same_ctor ``List /-- info: List.match_on_same_ctor.{u_1, u} {α : Type u} {motive : (t t_1 : List α) → t.ctorIdx = t_1.ctorIdx → Sort u_1} - (t t✝ : List α) (h : t.ctorIdx = t✝.ctorIdx) (nil : motive [] [] ⋯) + (t t✝ : List α) (h : t.ctorIdx = t✝.ctorIdx) (nil : Unit → motive [] [] ⋯) (cons : (head : α) → (tail : List α) → (head' : α) → (tail' : List α) → motive (head :: tail) (head' :: tail') ⋯) : motive t t✝ h -/