diff --git a/src/Init/Util.lean b/src/Init/Util.lean index f00c12ff03..5c0d06fbd4 100644 --- a/src/Init/Util.lean +++ b/src/Init/Util.lean @@ -51,6 +51,13 @@ if ptrAddrUnsafe a == ptrAddrUnsafe b then true else k () def withPtrEq {α : Type u} (a b : α) (k : Unit → Bool) (h : a = b → k () = true) : Bool := k () +-- `withPtrEq` for `DecidableEq` +@[inline] def withPtrEqDecEq {α : Type u} (a b : α) (k : Unit → Decidable (a = b)) : Decidable (a = b) := +let aux := withPtrEq a b (fun _ => @decide (a = b) (k ())) (fun h => @decideEqTrue _ (k ()) h); +match aux, rfl : forall x (h : _ = x), _ with +| true, h => isTrue (@ofDecideEqTrue _ (k ()) h) +| false, h => isFalse (@ofDecideEqFalse _ (k ()) h) + @[implementedBy withPtrAddrUnsafe] def withPtrAddr {α : Type u} {β : Type v} (a : α) (k : USize → β) (h : ∀ u₁ u₂, k u₁ = k u₂) : β := k 0 diff --git a/tests/lean/run/listDecEq.lean b/tests/lean/run/listDecEq.lean new file mode 100644 index 0000000000..f55fdb3280 --- /dev/null +++ b/tests/lean/run/listDecEq.lean @@ -0,0 +1,16 @@ +-- List decidable equality using `withPtrEqDecEq` + +def listDecEqAux {α} [s : DecidableEq α] : ∀ (as bs : List α), Decidable (as = bs) +| [], [] => isTrue rfl +| [], b::bs => isFalse $ fun h => List.noConfusion h +| a::as, [] => isFalse $ fun h => List.noConfusion h +| a::as, b::bs => + match s a b with + | isTrue h₁ => + match withPtrEqDecEq as bs (fun _ => listDecEqAux as bs) with + | isTrue h₂ => isTrue $ h₁ ▸ h₂ ▸ rfl + | isFalse h₂ => isFalse $ fun h => List.noConfusion h $ fun _ h₃ => absurd h₃ h₂ + | isFalse h₁ => isFalse $ fun h => List.noConfusion h $ fun h₂ _ => absurd h₂ h₁ + +instance List.optimizedDecEq {α} [DecidableEq α] : DecidableEq (List α) := +fun a b => withPtrEqDecEq a b (fun _ => listDecEqAux a b)