From 222abdd43d48d397aa27bfb31b801b808f3fa2d1 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Tue, 3 Dec 2024 15:42:17 +1100 Subject: [PATCH] feat: simprocs for other Fin operations (#6295) This PR sets up simprocs for all the remaining operations defined in `Init.Data.Fin.Basic` --- src/Lean/Meta/LitValues.lean | 2 +- .../Meta/Tactic/Simp/BuiltinSimprocs/Fin.lean | 89 +++++++++++++++++++ tests/lean/run/simprocFin.lean | 75 ++++++++++++++++ 3 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 tests/lean/run/simprocFin.lean diff --git a/src/Lean/Meta/LitValues.lean b/src/Lean/Meta/LitValues.lean index fad889122d..3f28b22949 100644 --- a/src/Lean/Meta/LitValues.lean +++ b/src/Lean/Meta/LitValues.lean @@ -62,7 +62,7 @@ def getStringValue? (e : Expr) : (Option String) := | .lit (.strVal s) => some s | _ => none -/-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `Fin n` with value `v` -/ +/-- Return `some ⟨n, v⟩` if `e` is an `OfNat.ofNat` application encoding a `Fin n` with value `v` -/ def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run do let (v, type) ← getOfNatValue? e ``Fin let n ← getNatValue? (← whnfD type.appArg!) diff --git a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Fin.lean b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Fin.lean index 350f3cb175..39a4183d04 100644 --- a/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Fin.lean +++ b/src/Lean/Meta/Tactic/Simp/BuiltinSimprocs/Fin.lean @@ -20,6 +20,18 @@ def fromExpr? (e : Expr) : SimpM (Option Value) := do let some ⟨n, value⟩ ← getFinValue? e | return none return some { n, value } +@[inline] def reduceOp (declName : Name) (arity : Nat) (f : Nat → Nat) (op : {n : Nat} → Fin n → Fin (f n)) (e : Expr) : SimpM DStep := do + unless e.isAppOfArity declName arity do return .continue + let some v ← fromExpr? e.appArg! | return .continue + let v' := op v.value + return .done <| toExpr v' + +@[inline] def reduceNatOp (declName : Name) (arity : Nat) (f : Nat → Nat) (op : (n : Nat) → Fin (f n)) (e : Expr) : SimpM DStep := do + unless e.isAppOfArity declName arity do return .continue + let some v ← getNatValue? e.appArg! | return .continue + let v' := op v + return .done <| toExpr v' + @[inline] def reduceBin (declName : Name) (arity : Nat) (op : {n : Nat} → Fin n → Fin n → Fin n) (e : Expr) : SimpM DStep := do unless e.isAppOfArity declName arity do return .continue let some v₁ ← fromExpr? e.appFn!.appArg! | return .continue @@ -47,12 +59,23 @@ The following code assumes users did not override the `Fin n` instances for the If they do, they must disable the following `simprocs`. -/ +builtin_dsimproc [simp, seval] reduceSucc (Fin.succ _) := reduceOp ``Fin.succ 2 (· + 1) Fin.succ +builtin_dsimproc [simp, seval] reduceRev (Fin.rev _) := reduceOp ``Fin.rev 2 (·) Fin.rev +builtin_dsimproc [simp, seval] reduceLast (Fin.last _) := reduceNatOp ``Fin.last 1 (· + 1) Fin.last + builtin_dsimproc [simp, seval] reduceAdd ((_ + _ : Fin _)) := reduceBin ``HAdd.hAdd 6 (· + ·) builtin_dsimproc [simp, seval] reduceMul ((_ * _ : Fin _)) := reduceBin ``HMul.hMul 6 (· * ·) builtin_dsimproc [simp, seval] reduceSub ((_ - _ : Fin _)) := reduceBin ``HSub.hSub 6 (· - ·) builtin_dsimproc [simp, seval] reduceDiv ((_ / _ : Fin _)) := reduceBin ``HDiv.hDiv 6 (· / ·) builtin_dsimproc [simp, seval] reduceMod ((_ % _ : Fin _)) := reduceBin ``HMod.hMod 6 (· % ·) +builtin_dsimproc [simp, seval] reduceAnd ((_ &&& _ : Fin _)) := reduceBin ``HAnd.hAnd 6 (· &&& ·) +builtin_dsimproc [simp, seval] reduceOr ((_ ||| _ : Fin _)) := reduceBin ``HOr.hOr 6 (· ||| ·) +builtin_dsimproc [simp, seval] reduceXor ((_ ^^^ _ : Fin _)) := reduceBin ``HXor.hXor 6 (· ^^^ ·) + +builtin_dsimproc [simp, seval] reduceShiftLeft ((_ <<< _ : Fin _)) := reduceBin ``HShiftLeft.hShiftLeft 6 (· <<< ·) +builtin_dsimproc [simp, seval] reduceShiftRight ((_ >>> _ : Fin _)) := reduceBin ``HShiftRight.hShiftRight 6 (· >>> ·) + builtin_simproc [simp, seval] reduceLT (( _ : Fin _) < _) := reduceBinPred ``LT.lt 4 (. < .) builtin_simproc [simp, seval] reduceLE (( _ : Fin _) ≤ _) := reduceBinPred ``LE.le 4 (. ≤ .) builtin_simproc [simp, seval] reduceGT (( _ : Fin _) > _) := reduceBinPred ``GT.gt 4 (. > .) @@ -83,4 +106,70 @@ builtin_dsimproc [simp, seval] reduceFinMk (Fin.mk _ _) := fun e => do else return .continue +builtin_dsimproc [simp, seval] reduceOfNat' (Fin.ofNat' _ _) := fun e => do + unless e.isAppOfArity ``Fin.ofNat' 3 do return .continue + let some (n + 1) ← getNatValue? e.appFn!.appFn!.appArg! | return .continue + let some k ← getNatValue? e.appArg! | return .continue + return .done <| toExpr (Fin.ofNat' (n + 1) k) + +builtin_dsimproc [simp, seval] reduceCastSucc (Fin.castSucc _) := fun e => do + unless e.isAppOfArity ``Fin.castSucc 2 do return .continue + let some k ← fromExpr? e.appArg! | return .continue + return .done <| toExpr (castSucc k.value) + +builtin_dsimproc [simp, seval] reduceCastAdd (Fin.castAdd _ _) := fun e => do + unless e.isAppOfArity ``Fin.castAdd 3 do return .continue + let some m ← getNatValue? e.appFn!.appArg! | return .continue + let some k ← fromExpr? e.appArg! | return .continue + return .done <| toExpr (castAdd m k.value) + +builtin_dsimproc [simp, seval] reduceAddNat (Fin.addNat _ _) := fun e => do + unless e.isAppOfArity ``Fin.addNat 3 do return .continue + let some k ← fromExpr? e.appFn!.appArg! | return .continue + let some m ← getNatValue? e.appArg! | return .continue + return .done <| toExpr (addNat k.value m) + +builtin_dsimproc [simp, seval] reduceNatAdd (Fin.natAdd _ _) := fun e => do + unless e.isAppOfArity ``Fin.natAdd 3 do return .continue + let some m ← getNatValue? e.appFn!.appArg! | return .continue + let some k ← fromExpr? e.appArg! | return .continue + return .done <| toExpr (natAdd m k.value) + +builtin_dsimproc [simp, seval] reduceCastLT (Fin.castLT _ _) := fun e => do + unless e.isAppOfArity ``Fin.castLT 4 do return .continue + let some n ← getNatValue? e.appFn!.appFn!.appFn!.appArg! | return .continue + let some i ← fromExpr? e.appFn!.appArg! | return .continue + if h : i.value < n then + return .done <| toExpr (castLT i.value h) + else + return .continue + +builtin_dsimproc [simp, seval] reduceCastLE (Fin.castLE _ _) := fun e => do + unless e.isAppOfArity ``Fin.castLE 4 do return .continue + let some m ← getNatValue? e.appFn!.appFn!.appArg! | return .continue + let some i ← fromExpr? e.appArg! | return .continue + if h : i.n ≤ m then + return .done <| toExpr (castLE h i.value) + else + return .continue + +-- No simproc is needed for `Fin.cast`, as for explicit numbers `Fin.cast_refl` will apply. + +builtin_dsimproc [simp, seval] reduceSubNat (Fin.subNat _ _ _) := fun e => do + unless e.isAppOfArity ``Fin.subNat 4 do return .continue + let some m ← getNatValue? e.appFn!.appFn!.appArg! | return .continue + let some i ← fromExpr? e.appFn!.appArg! | return .continue + if h : m ≤ i.value then + return .done <| toExpr (subNat m (i.value.cast (by omega : i.n = (i.n - m) + m)) h) + else + return .continue + +builtin_dsimproc [simp, seval] reducePred (Fin.pred _ _) := fun e => do + unless e.isAppOfArity ``Fin.pred 3 do return .continue + let some ⟨(_ + 1), i⟩ ← fromExpr? e.appFn!.appArg! | return .continue + if h : i ≠ 0 then + return .done <| toExpr (pred i h) + else + return .continue + end Fin diff --git a/tests/lean/run/simprocFin.lean b/tests/lean/run/simprocFin.lean new file mode 100644 index 0000000000..875287b40e --- /dev/null +++ b/tests/lean/run/simprocFin.lean @@ -0,0 +1,75 @@ +variable (n : Nat) [NeZero n] + +/- basic operations -/ + +#check_simp (3 : Fin 7).succ ~> (4 : Fin 8) +#check_simp (6 : Fin 7).succ ~> (7 : Fin 8) +#check_simp Fin.last 0 ~> (0 : Fin 1) +#check_simp Fin.last 6 ~> (6 : Fin 7) +#check_simp Fin.ofNat' 6 3 ~> (3 : Fin 6) +#check_simp Fin.ofNat' 6 37 ~> (1 : Fin 6) +#check_simp Fin.rev (0 : Fin 7) ~> (6 : Fin 7) +#check_simp Fin.rev (3 : Fin 7) ~> (3 : Fin 7) +#check_simp Fin.castSucc (0 : Fin 7) ~> (0 : Fin 8) +#check_simp Fin.castSucc (3 : Fin 7) ~> (3 : Fin 8) +#check_simp Fin.castAdd 3 (0 : Fin 7) ~> (0 : Fin 10) +#check_simp Fin.castAdd 3 (3 : Fin 7) ~> (3 : Fin 10) +#check_simp Fin.castLT (3 : Fin 10) (by decide : 3 < 5) ~> (3 : Fin 5) +#check_simp Fin.castLE (by decide : 5 ≤ 37) (3 : Fin 5) ~> (3 : Fin 37) +#check_simp Fin.pred (3 : Fin 7) (by decide) ~> (2 : Fin 6) + +/- arithmetic operation tests -/ + +#check_simp (3 : Fin 7) + (1 : Fin 7) ~> 4 +#check_simp (3 : Fin 7) + (5 : Fin 7) ~> 1 +#check_simp (3 : Fin 7) * (1 : Fin 7) ~> 3 +#check_simp (3 : Fin 7) * (3 : Fin 7) ~> 2 +#check_simp (3 : Fin 7) - (1 : Fin 7) ~> 2 +#check_simp (3 : Fin 7) - (5 : Fin 7) ~> 5 +#check_simp (3 : Fin 7) / (1 : Fin 7) ~> 3 +#check_simp (3 : Fin 7) / (5 : Fin 7) ~> 0 +#check_simp (3 : Fin 7) % (0 : Fin 7) ~> 3 +#check_simp (3 : Fin 7) % (1 : Fin 7) ~> 0 +#check_simp (3 : Fin 7) % (5 : Fin 7) ~> 3 + +#check_simp (3 : Fin n) + (5 : Fin n) !~> +#check_simp (3 : Fin n) * (5 : Fin n) !~> +#check_simp (3 : Fin n) - (5 : Fin n) !~> +#check_simp (3 : Fin n) / (5 : Fin n) !~> +#check_simp (3 : Fin n) % (5 : Fin n) !~> + +#check_simp Fin.addNat (3 : Fin 7) 3 ~> (6 : Fin 10) +#check_simp Fin.natAdd 3 (3 : Fin 7) ~> (6 : Fin 10) +#check_simp Fin.subNat 2 (3 : Fin 7) (by decide) ~> (1 : Fin 5) + +/- bitwise operations -/ + +#check_simp (3 : Fin 7) &&& (1 : Fin 7) ~> 1 +#check_simp (3 : Fin 7) ||| (1 : Fin 7) ~> 3 +#check_simp (3 : Fin 7) ^^^ (1 : Fin 7) ~> 2 +#check_simp (3 : Fin 7) <<< (1 : Fin 7) ~> 6 +#check_simp (3 : Fin 7) >>> (1 : Fin 7) ~> 1 + +/- predicate tests -/ + +#check_simp (3 : Fin 7) < (1 : Fin 7) ~> False +#check_simp (3 : Fin 7) < (5 : Fin 7) ~> True +#check_simp (3 : Fin 7) ≤ (1 : Fin 7) ~> False +#check_simp (3 : Fin 7) ≤ (5 : Fin 7) ~> True +#check_simp (3 : Fin 7) > (1 : Fin 7) ~> True +#check_simp (3 : Fin 7) > (5 : Fin 7) ~> False +#check_simp (3 : Fin 7) ≥ (1 : Fin 7) ~> True +#check_simp (3 : Fin 7) ≥ (5 : Fin 7) ~> False +#check_simp (3 : Fin 7) = (1 : Fin 7) ~> False +#check_simp (3 : Fin 7) = (5 : Fin 7) ~> False +#check_simp (3 : Fin 7) = (3 : Fin 7) ~> True +#check_simp (3 : Fin 7) ≠ (1 : Fin 7) ~> True +#check_simp (3 : Fin 7) ≠ (3 : Fin 7) ~> False +#check_simp (3 : Fin 7) ≠ (5 : Fin 7) ~> True + +#check_simp (3 : Fin 7) == (1 : Fin 7) ~> false +#check_simp (3 : Fin 7) == (3 : Fin 7) ~> true +#check_simp (3 : Fin 7) == (5 : Fin 7) ~> false +#check_simp (3 : Fin 7) != (1 : Fin 7) ~> true +#check_simp (3 : Fin 7) != (3 : Fin 7) ~> false +#check_simp (3 : Fin 7) != (5 : Fin 7) ~> true