fix: remove global NatCast (Fin n) instance (#8620)

This PR removes the `NatCast (Fin n)` global instance (both the direct
instance, and the indirect one via `Lean.Grind.Semiring`), as that
instance causes causes `x < n` (for `x : Fin k`, `n : Nat`) to be
elaborated as `x < ↑n` rather than `↑x < n`, which is undesirable. Note
however that in Mathlib this happens anyway!
This commit is contained in:
Kim Morrison 2025-06-04 16:58:39 +10:00 committed by GitHub
parent c12159b519
commit 4500a7f02b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 106 additions and 31 deletions

View file

@ -319,6 +319,7 @@ theorem ofFin_ofNat (n : Nat) :
@[simp] theorem ofFin_neg {x : Fin (2 ^ w)} : ofFin (-x) = -(ofFin x) := by
rfl
open Fin.NatCast in
@[simp, norm_cast] theorem ofFin_natCast (n : Nat) : ofFin (n : Fin (2^w)) = (n : BitVec w) := by
rfl
@ -337,6 +338,7 @@ theorem toFin_zero : toFin (0 : BitVec w) = 0 := rfl
theorem toFin_one : toFin (1 : BitVec w) = 1 := by
rw [toFin_inj]; simp only [ofNat_eq_ofNat, ofFin_ofNat]
open Fin.NatCast in
@[simp, norm_cast] theorem toFin_natCast (n : Nat) : toFin (n : BitVec w) = (n : Fin (2^w)) := by
rfl

View file

@ -102,9 +102,30 @@ theorem dite_val {n : Nat} {c : Prop} [Decidable c] {x y : Fin n} :
(if c then x else y).val = if c then x.val else y.val := by
by_cases c <;> simp [*]
instance (n : Nat) [NeZero n] : NatCast (Fin n) where
namespace NatCast
/--
This is not a global instance, but may be activated locally via `open Fin.NatCast in ...`.
This is not an instance because the `binop%` elaborator assumes that
there are no non-trivial coercion loops,
but this introduces a coercion from `Nat` to `Fin n` and back.
Non-trivial loops lead to undesirable and counterintuitive elaboration behavior.
For example, for `x : Fin k` and `n : Nat`,
it causes `x < n` to be elaborated as `x < ↑n` rather than `↑x < n`,
silently introducing wraparound arithmetic.
Note: as of 2025-06-03, Mathlib has such a coercion for `Fin n` anyway!
-/
@[expose]
def instNatCast (n : Nat) [NeZero n] : NatCast (Fin n) where
natCast a := Fin.ofNat n a
attribute [scoped instance] instNatCast
end NatCast
@[expose]
def intCast [NeZero n] (a : Int) : Fin n :=
if 0 ≤ a then
@ -112,9 +133,22 @@ def intCast [NeZero n] (a : Int) : Fin n :=
else
- Fin.ofNat n a.natAbs
instance (n : Nat) [NeZero n] : IntCast (Fin n) where
namespace IntCast
/--
This is not a global instance, but may be activated locally via `open Fin.IntCast in ...`.
See the doc-string for `Fin.NatCast.instNatCast` for more details.
-/
@[expose]
def instIntCast (n : Nat) [NeZero n] : IntCast (Fin n) where
intCast := Fin.intCast
attribute [scoped instance] instIntCast
end IntCast
open IntCast in
theorem intCast_def {n : Nat} [NeZero n] (x : Int) :
(x : Fin n) = if 0 ≤ x then Fin.ofNat n x.natAbs else -Fin.ofNat n x.natAbs := rfl

View file

@ -71,7 +71,9 @@ class CommRing (α : Type u) extends Ring α, CommSemiring α
attribute [instance 100] Semiring.toAdd Semiring.toMul Semiring.toHPow Ring.toNeg Ring.toSub
-- This is a low-priority instance, to avoid conflicts with existing `OfNat`, `NatCast`, and `IntCast` instances.
attribute [instance 100] Semiring.ofNat Semiring.natCast Ring.intCast
attribute [instance 100] Semiring.ofNat
attribute [local instance] Semiring.natCast Ring.intCast
namespace Semiring

View file

@ -14,22 +14,6 @@ namespace Lean.Grind
namespace Fin
instance (n : Nat) [NeZero n] : NatCast (Fin n) where
natCast a := Fin.ofNat n a
@[expose]
def intCast [NeZero n] (a : Int) : Fin n :=
if 0 ≤ a then
Fin.ofNat n a.natAbs
else
- Fin.ofNat n a.natAbs
instance (n : Nat) [NeZero n] : IntCast (Fin n) where
intCast := Fin.intCast
theorem intCast_def {n : Nat} [NeZero n] (x : Int) :
(x : Fin n) = if 0 ≤ x then Fin.ofNat n x.natAbs else -Fin.ofNat n x.natAbs := rfl
-- TODO: we should replace this at runtime with either repeated squaring,
-- or a GMP accelerated function.
@[expose]
@ -78,18 +62,22 @@ theorem sub_eq_add_neg [NeZero n] (a b : Fin n) : a - b = a + -b := by
cases a; cases b; simp [Fin.neg_def, Fin.sub_def, Fin.add_def, Nat.add_comm]
private theorem neg_neg [NeZero n] (a : Fin n) : - - a = a := by
cases a; simp [Fin.neg_def, Fin.sub_def];
cases a; simp [Fin.neg_def, Fin.sub_def]
next a h => cases a; simp; next a =>
rw [Nat.self_sub_mod n (a+1)]
have : NeZero (n - (a + 1)) := ⟨by omega⟩
rw [Nat.self_sub_mod, Nat.sub_sub_eq_min, Nat.min_eq_right (Nat.le_of_lt h)]
open Fin.NatCast Fin.IntCast in
theorem intCast_neg [NeZero n] (i : Int) : Int.cast (R := Fin n) (-i) = - Int.cast (R := Fin n) i := by
simp [Int.cast, IntCast.intCast, Fin.intCast]; split <;> split <;> try omega
simp [Int.cast, IntCast.intCast, Fin.intCast]
split <;> split <;> try omega
next h₁ h₂ => simp [Int.le_antisymm h₁ h₂, Fin.neg_def]
next => simp [Fin.neg_neg]
instance (n : Nat) [NeZero n] : CommRing (Fin n) where
natCast := Fin.NatCast.instNatCast n
intCast := Fin.IntCast.instIntCast n
add_assoc := Fin.add_assoc
add_comm := Fin.add_comm
add_zero := Fin.add_zero

View file

@ -15,6 +15,9 @@ import Init.Grind.CommRing.Basic
namespace Lean.Grind
namespace CommRing
-- These are no longer global instances, so we need to turn them on here.
attribute [local instance] Semiring.natCast Ring.intCast
abbrev Var := Nat
inductive Expr where

View file

@ -58,21 +58,37 @@ private def getPowFn (type : Expr) (u : Level) (semiringInst : Expr) : GoalM Exp
internalizeFn <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst
private def getIntCastFn (type : Expr) (u : Level) (ringInst : Expr) : GoalM Expr := do
let instType := mkApp (mkConst ``IntCast [u]) type
let .some inst ← trySynthInstance instType |
throwError "failed to find instance for ring intCast{indentExpr instType}"
let inst' := mkApp2 (mkConst ``Grind.Ring.intCast [u]) type ringInst
unless (← withDefault <| isDefEq inst inst') do
throwError "instance for intCast{indentExpr inst}\nis not definitionally equal to the `Grind.Ring` one{indentExpr inst'}"
let instType := mkApp (mkConst ``IntCast [u]) type
-- Note that `Ring.intCast` is not registered as a global instance
-- (to avoid introducing unwanted coercions)
-- so merely having a `Ring α` instance
-- does not guarantee that an `IntCast α` will be available.
-- When both are present we verify that they are defeq,
-- and otherwise fall back to the field of the `Ring α` instance that we already have.
let inst ← match (← trySynthInstance instType).toOption with
| none => pure inst'
| some inst =>
unless (← withDefault <| isDefEq inst inst') do
throwError "instance for intCast{indentExpr inst}\nis not definitionally equal to the `Grind.Ring` one{indentExpr inst'}"
pure inst
internalizeFn <| mkApp2 (mkConst ``IntCast.intCast [u]) type inst
private def getNatCastFn (type : Expr) (u : Level) (semiringInst : Expr) : GoalM Expr := do
let instType := mkApp (mkConst ``NatCast [u]) type
let .some inst ← trySynthInstance instType |
throwError "failed to find instance for ring natCast{indentExpr instType}"
let inst' := mkApp2 (mkConst ``Grind.Semiring.natCast [u]) type semiringInst
unless (← withDefault <| isDefEq inst inst') do
throwError "instance for natCast{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
let instType := mkApp (mkConst ``NatCast [u]) type
-- Note that `Semiring.natCast` is not registered as a global instance
-- (to avoid introducing unwanted coercions)
-- so merely having a `Semiring α` instance
-- does not guarantee that an `NatCast α` will be available.
-- When both are present we verify that they are defeq,
-- and otherwise fall back to the field of the `Semiring α` instance that we already have.
let inst ← match (← trySynthInstance instType).toOption with
| none => pure inst'
| some inst =>
unless (← withDefault <| isDefEq inst inst') do
throwError "instance for natCast{indentExpr inst}\nis not definitionally equal to the `Grind.Semiring` one{indentExpr inst'}"
pure inst
internalizeFn <| mkApp2 (mkConst ``NatCast.natCast [u]) type inst
/--

View file

@ -0,0 +1,30 @@
set_option pp.mvars false
-- We first verify that there is no global coercion from `Nat` to `Fin n`.
-- Such a coercion would frequently introduce unexpected modular arithmetic.
/--
error: type mismatch
n
has type
Nat : Type
but is expected to have type
Fin 3 : Type
---
info: fun n => sorry : (n : Nat) → ?_ n
-/
#guard_msgs in #check fun (n : Nat) => (n : Fin 3)
-- This instance is available via `open Fin.NatCast in ...`
section
open Fin.NatCast
variable (m : Nat) (n : Fin 3)
/-- info: n < ↑m : Prop -/
#guard_msgs in #check n < m
end
example (x : Fin (n + 1)) (h : x < n) : Fin (n + 1) := x.succ.castLT (by simp [h])