feat: simprocs for other Fin operations (#6295)
This PR sets up simprocs for all the remaining operations defined in `Init.Data.Fin.Basic`
This commit is contained in:
parent
490be9282e
commit
222abdd43d
3 changed files with 165 additions and 1 deletions
|
|
@ -62,7 +62,7 @@ def getStringValue? (e : Expr) : (Option String) :=
|
|||
| .lit (.strVal s) => some s
|
||||
| _ => none
|
||||
|
||||
/-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `Fin n` with value `v` -/
|
||||
/-- Return `some ⟨n, v⟩` if `e` is an `OfNat.ofNat` application encoding a `Fin n` with value `v` -/
|
||||
def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run do
|
||||
let (v, type) ← getOfNatValue? e ``Fin
|
||||
let n ← getNatValue? (← whnfD type.appArg!)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,18 @@ def fromExpr? (e : Expr) : SimpM (Option Value) := do
|
|||
let some ⟨n, value⟩ ← getFinValue? e | return none
|
||||
return some { n, value }
|
||||
|
||||
@[inline] def reduceOp (declName : Name) (arity : Nat) (f : Nat → Nat) (op : {n : Nat} → Fin n → Fin (f n)) (e : Expr) : SimpM DStep := do
|
||||
unless e.isAppOfArity declName arity do return .continue
|
||||
let some v ← fromExpr? e.appArg! | return .continue
|
||||
let v' := op v.value
|
||||
return .done <| toExpr v'
|
||||
|
||||
@[inline] def reduceNatOp (declName : Name) (arity : Nat) (f : Nat → Nat) (op : (n : Nat) → Fin (f n)) (e : Expr) : SimpM DStep := do
|
||||
unless e.isAppOfArity declName arity do return .continue
|
||||
let some v ← getNatValue? e.appArg! | return .continue
|
||||
let v' := op v
|
||||
return .done <| toExpr v'
|
||||
|
||||
@[inline] def reduceBin (declName : Name) (arity : Nat) (op : {n : Nat} → Fin n → Fin n → Fin n) (e : Expr) : SimpM DStep := do
|
||||
unless e.isAppOfArity declName arity do return .continue
|
||||
let some v₁ ← fromExpr? e.appFn!.appArg! | return .continue
|
||||
|
|
@ -47,12 +59,23 @@ The following code assumes users did not override the `Fin n` instances for the
|
|||
If they do, they must disable the following `simprocs`.
|
||||
-/
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceSucc (Fin.succ _) := reduceOp ``Fin.succ 2 (· + 1) Fin.succ
|
||||
builtin_dsimproc [simp, seval] reduceRev (Fin.rev _) := reduceOp ``Fin.rev 2 (·) Fin.rev
|
||||
builtin_dsimproc [simp, seval] reduceLast (Fin.last _) := reduceNatOp ``Fin.last 1 (· + 1) Fin.last
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceAdd ((_ + _ : Fin _)) := reduceBin ``HAdd.hAdd 6 (· + ·)
|
||||
builtin_dsimproc [simp, seval] reduceMul ((_ * _ : Fin _)) := reduceBin ``HMul.hMul 6 (· * ·)
|
||||
builtin_dsimproc [simp, seval] reduceSub ((_ - _ : Fin _)) := reduceBin ``HSub.hSub 6 (· - ·)
|
||||
builtin_dsimproc [simp, seval] reduceDiv ((_ / _ : Fin _)) := reduceBin ``HDiv.hDiv 6 (· / ·)
|
||||
builtin_dsimproc [simp, seval] reduceMod ((_ % _ : Fin _)) := reduceBin ``HMod.hMod 6 (· % ·)
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceAnd ((_ &&& _ : Fin _)) := reduceBin ``HAnd.hAnd 6 (· &&& ·)
|
||||
builtin_dsimproc [simp, seval] reduceOr ((_ ||| _ : Fin _)) := reduceBin ``HOr.hOr 6 (· ||| ·)
|
||||
builtin_dsimproc [simp, seval] reduceXor ((_ ^^^ _ : Fin _)) := reduceBin ``HXor.hXor 6 (· ^^^ ·)
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceShiftLeft ((_ <<< _ : Fin _)) := reduceBin ``HShiftLeft.hShiftLeft 6 (· <<< ·)
|
||||
builtin_dsimproc [simp, seval] reduceShiftRight ((_ >>> _ : Fin _)) := reduceBin ``HShiftRight.hShiftRight 6 (· >>> ·)
|
||||
|
||||
builtin_simproc [simp, seval] reduceLT (( _ : Fin _) < _) := reduceBinPred ``LT.lt 4 (. < .)
|
||||
builtin_simproc [simp, seval] reduceLE (( _ : Fin _) ≤ _) := reduceBinPred ``LE.le 4 (. ≤ .)
|
||||
builtin_simproc [simp, seval] reduceGT (( _ : Fin _) > _) := reduceBinPred ``GT.gt 4 (. > .)
|
||||
|
|
@ -83,4 +106,70 @@ builtin_dsimproc [simp, seval] reduceFinMk (Fin.mk _ _) := fun e => do
|
|||
else
|
||||
return .continue
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceOfNat' (Fin.ofNat' _ _) := fun e => do
|
||||
unless e.isAppOfArity ``Fin.ofNat' 3 do return .continue
|
||||
let some (n + 1) ← getNatValue? e.appFn!.appFn!.appArg! | return .continue
|
||||
let some k ← getNatValue? e.appArg! | return .continue
|
||||
return .done <| toExpr (Fin.ofNat' (n + 1) k)
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceCastSucc (Fin.castSucc _) := fun e => do
|
||||
unless e.isAppOfArity ``Fin.castSucc 2 do return .continue
|
||||
let some k ← fromExpr? e.appArg! | return .continue
|
||||
return .done <| toExpr (castSucc k.value)
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceCastAdd (Fin.castAdd _ _) := fun e => do
|
||||
unless e.isAppOfArity ``Fin.castAdd 3 do return .continue
|
||||
let some m ← getNatValue? e.appFn!.appArg! | return .continue
|
||||
let some k ← fromExpr? e.appArg! | return .continue
|
||||
return .done <| toExpr (castAdd m k.value)
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceAddNat (Fin.addNat _ _) := fun e => do
|
||||
unless e.isAppOfArity ``Fin.addNat 3 do return .continue
|
||||
let some k ← fromExpr? e.appFn!.appArg! | return .continue
|
||||
let some m ← getNatValue? e.appArg! | return .continue
|
||||
return .done <| toExpr (addNat k.value m)
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceNatAdd (Fin.natAdd _ _) := fun e => do
|
||||
unless e.isAppOfArity ``Fin.natAdd 3 do return .continue
|
||||
let some m ← getNatValue? e.appFn!.appArg! | return .continue
|
||||
let some k ← fromExpr? e.appArg! | return .continue
|
||||
return .done <| toExpr (natAdd m k.value)
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceCastLT (Fin.castLT _ _) := fun e => do
|
||||
unless e.isAppOfArity ``Fin.castLT 4 do return .continue
|
||||
let some n ← getNatValue? e.appFn!.appFn!.appFn!.appArg! | return .continue
|
||||
let some i ← fromExpr? e.appFn!.appArg! | return .continue
|
||||
if h : i.value < n then
|
||||
return .done <| toExpr (castLT i.value h)
|
||||
else
|
||||
return .continue
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceCastLE (Fin.castLE _ _) := fun e => do
|
||||
unless e.isAppOfArity ``Fin.castLE 4 do return .continue
|
||||
let some m ← getNatValue? e.appFn!.appFn!.appArg! | return .continue
|
||||
let some i ← fromExpr? e.appArg! | return .continue
|
||||
if h : i.n ≤ m then
|
||||
return .done <| toExpr (castLE h i.value)
|
||||
else
|
||||
return .continue
|
||||
|
||||
-- No simproc is needed for `Fin.cast`, as for explicit numbers `Fin.cast_refl` will apply.
|
||||
|
||||
builtin_dsimproc [simp, seval] reduceSubNat (Fin.subNat _ _ _) := fun e => do
|
||||
unless e.isAppOfArity ``Fin.subNat 4 do return .continue
|
||||
let some m ← getNatValue? e.appFn!.appFn!.appArg! | return .continue
|
||||
let some i ← fromExpr? e.appFn!.appArg! | return .continue
|
||||
if h : m ≤ i.value then
|
||||
return .done <| toExpr (subNat m (i.value.cast (by omega : i.n = (i.n - m) + m)) h)
|
||||
else
|
||||
return .continue
|
||||
|
||||
builtin_dsimproc [simp, seval] reducePred (Fin.pred _ _) := fun e => do
|
||||
unless e.isAppOfArity ``Fin.pred 3 do return .continue
|
||||
let some ⟨(_ + 1), i⟩ ← fromExpr? e.appFn!.appArg! | return .continue
|
||||
if h : i ≠ 0 then
|
||||
return .done <| toExpr (pred i h)
|
||||
else
|
||||
return .continue
|
||||
|
||||
end Fin
|
||||
|
|
|
|||
75
tests/lean/run/simprocFin.lean
Normal file
75
tests/lean/run/simprocFin.lean
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
variable (n : Nat) [NeZero n]
|
||||
|
||||
/- basic operations -/
|
||||
|
||||
#check_simp (3 : Fin 7).succ ~> (4 : Fin 8)
|
||||
#check_simp (6 : Fin 7).succ ~> (7 : Fin 8)
|
||||
#check_simp Fin.last 0 ~> (0 : Fin 1)
|
||||
#check_simp Fin.last 6 ~> (6 : Fin 7)
|
||||
#check_simp Fin.ofNat' 6 3 ~> (3 : Fin 6)
|
||||
#check_simp Fin.ofNat' 6 37 ~> (1 : Fin 6)
|
||||
#check_simp Fin.rev (0 : Fin 7) ~> (6 : Fin 7)
|
||||
#check_simp Fin.rev (3 : Fin 7) ~> (3 : Fin 7)
|
||||
#check_simp Fin.castSucc (0 : Fin 7) ~> (0 : Fin 8)
|
||||
#check_simp Fin.castSucc (3 : Fin 7) ~> (3 : Fin 8)
|
||||
#check_simp Fin.castAdd 3 (0 : Fin 7) ~> (0 : Fin 10)
|
||||
#check_simp Fin.castAdd 3 (3 : Fin 7) ~> (3 : Fin 10)
|
||||
#check_simp Fin.castLT (3 : Fin 10) (by decide : 3 < 5) ~> (3 : Fin 5)
|
||||
#check_simp Fin.castLE (by decide : 5 ≤ 37) (3 : Fin 5) ~> (3 : Fin 37)
|
||||
#check_simp Fin.pred (3 : Fin 7) (by decide) ~> (2 : Fin 6)
|
||||
|
||||
/- arithmetic operation tests -/
|
||||
|
||||
#check_simp (3 : Fin 7) + (1 : Fin 7) ~> 4
|
||||
#check_simp (3 : Fin 7) + (5 : Fin 7) ~> 1
|
||||
#check_simp (3 : Fin 7) * (1 : Fin 7) ~> 3
|
||||
#check_simp (3 : Fin 7) * (3 : Fin 7) ~> 2
|
||||
#check_simp (3 : Fin 7) - (1 : Fin 7) ~> 2
|
||||
#check_simp (3 : Fin 7) - (5 : Fin 7) ~> 5
|
||||
#check_simp (3 : Fin 7) / (1 : Fin 7) ~> 3
|
||||
#check_simp (3 : Fin 7) / (5 : Fin 7) ~> 0
|
||||
#check_simp (3 : Fin 7) % (0 : Fin 7) ~> 3
|
||||
#check_simp (3 : Fin 7) % (1 : Fin 7) ~> 0
|
||||
#check_simp (3 : Fin 7) % (5 : Fin 7) ~> 3
|
||||
|
||||
#check_simp (3 : Fin n) + (5 : Fin n) !~>
|
||||
#check_simp (3 : Fin n) * (5 : Fin n) !~>
|
||||
#check_simp (3 : Fin n) - (5 : Fin n) !~>
|
||||
#check_simp (3 : Fin n) / (5 : Fin n) !~>
|
||||
#check_simp (3 : Fin n) % (5 : Fin n) !~>
|
||||
|
||||
#check_simp Fin.addNat (3 : Fin 7) 3 ~> (6 : Fin 10)
|
||||
#check_simp Fin.natAdd 3 (3 : Fin 7) ~> (6 : Fin 10)
|
||||
#check_simp Fin.subNat 2 (3 : Fin 7) (by decide) ~> (1 : Fin 5)
|
||||
|
||||
/- bitwise operations -/
|
||||
|
||||
#check_simp (3 : Fin 7) &&& (1 : Fin 7) ~> 1
|
||||
#check_simp (3 : Fin 7) ||| (1 : Fin 7) ~> 3
|
||||
#check_simp (3 : Fin 7) ^^^ (1 : Fin 7) ~> 2
|
||||
#check_simp (3 : Fin 7) <<< (1 : Fin 7) ~> 6
|
||||
#check_simp (3 : Fin 7) >>> (1 : Fin 7) ~> 1
|
||||
|
||||
/- predicate tests -/
|
||||
|
||||
#check_simp (3 : Fin 7) < (1 : Fin 7) ~> False
|
||||
#check_simp (3 : Fin 7) < (5 : Fin 7) ~> True
|
||||
#check_simp (3 : Fin 7) ≤ (1 : Fin 7) ~> False
|
||||
#check_simp (3 : Fin 7) ≤ (5 : Fin 7) ~> True
|
||||
#check_simp (3 : Fin 7) > (1 : Fin 7) ~> True
|
||||
#check_simp (3 : Fin 7) > (5 : Fin 7) ~> False
|
||||
#check_simp (3 : Fin 7) ≥ (1 : Fin 7) ~> True
|
||||
#check_simp (3 : Fin 7) ≥ (5 : Fin 7) ~> False
|
||||
#check_simp (3 : Fin 7) = (1 : Fin 7) ~> False
|
||||
#check_simp (3 : Fin 7) = (5 : Fin 7) ~> False
|
||||
#check_simp (3 : Fin 7) = (3 : Fin 7) ~> True
|
||||
#check_simp (3 : Fin 7) ≠ (1 : Fin 7) ~> True
|
||||
#check_simp (3 : Fin 7) ≠ (3 : Fin 7) ~> False
|
||||
#check_simp (3 : Fin 7) ≠ (5 : Fin 7) ~> True
|
||||
|
||||
#check_simp (3 : Fin 7) == (1 : Fin 7) ~> false
|
||||
#check_simp (3 : Fin 7) == (3 : Fin 7) ~> true
|
||||
#check_simp (3 : Fin 7) == (5 : Fin 7) ~> false
|
||||
#check_simp (3 : Fin 7) != (1 : Fin 7) ~> true
|
||||
#check_simp (3 : Fin 7) != (3 : Fin 7) ~> false
|
||||
#check_simp (3 : Fin 7) != (5 : Fin 7) ~> true
|
||||
Loading…
Add table
Reference in a new issue