From 8d7f0ea2f250e7adaece1636b8ec5272b95d2ccd Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 7 Feb 2022 17:24:32 -0800 Subject: [PATCH] feat: add `removeUnnecessaryCasts` see #988 --- src/Lean/Meta/Tactic/Simp/Main.lean | 25 +++++++++++++++++++++++-- tests/lean/arrayGetU.lean | 17 +++++++++++++++++ tests/lean/arrayGetU.lean.expected.out | 14 ++++++++++++++ 3 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 tests/lean/arrayGetU.lean create mode 100644 tests/lean/arrayGetU.lean.expected.out diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 4825b57338..26b2342ce0 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -142,9 +142,30 @@ def getSimpLetCase (n : Name) (t : Expr) (v : Expr) (b : Expr) : MetaM SimpLetCa else return SimpLetCase.dep +#check Eq.ndrec + /-- Given the application `e`, remove unnecessary casts of the form `Eq.rec a rfl` and `Eq.ndrec a rfl`. -/ -def removeUnnecessaryCasts (e : Expr) : MetaM Expr := - return e -- TODO +partial def removeUnnecessaryCasts (e : Expr) : MetaM Expr := do + let mut args := e.getAppArgs + let mut modified := false + for i in [:args.size] do + let arg := args[i] + if isDummyEqRec arg then + args := args.set! i (elimDummyEqRec arg) + modified := true + if modified then + return mkAppN e.getAppFn args + else + return e +where + isDummyEqRec (e : Expr) : Bool := + (e.isAppOfArity ``Eq.rec 6 || e.isAppOfArity ``Eq.ndrec 6) && e.appArg!.isAppOf ``Eq.refl + + elimDummyEqRec (e : Expr) : Expr := + if isDummyEqRec e then + elimDummyEqRec e.appFn!.appFn!.appArg! + else + e partial def simp (e : Expr) : M Result := withIncRecDepth do checkMaxHeartbeats "simp" diff --git a/tests/lean/arrayGetU.lean b/tests/lean/arrayGetU.lean new file mode 100644 index 0000000000..8b926f3ecc --- /dev/null +++ b/tests/lean/arrayGetU.lean @@ -0,0 +1,17 @@ +def f (a : Array Nat) (i : Nat) (v : Nat) (h : i < a.size) : Array Nat := + a.set ⟨i, h⟩ (a.get ⟨i, h⟩ + v) + +set_option pp.proofs true + +theorem ex1 (h₃ : i = j) : f a i (0 + v) h₁ = f a j v h₂ := by + simp + trace_state + simp [h₃] + +theorem ex2 (h₃ : i = j) : f a (0 + i) (0 + v) h₁ = f a j v h₂ := by + simp + trace_state + simp [h₃] + +theorem ex3 (h₃ : i = j) : f a (0 + i) (0 + v) h₁ = f a j v h₂ := by + simp [h₃] diff --git a/tests/lean/arrayGetU.lean.expected.out b/tests/lean/arrayGetU.lean.expected.out new file mode 100644 index 0000000000..c24c94975e --- /dev/null +++ b/tests/lean/arrayGetU.lean.expected.out @@ -0,0 +1,14 @@ +i j : Nat +a : Array Nat +v : Nat +h₁ : i < Array.size a +h₂ : j < Array.size a +h₃ : i = j +⊢ f a i v h₁ = f a j v h₂ +i j : Nat +a : Array Nat +v : Nat +h₁ : 0 + i < Array.size a +h₂ : j < Array.size a +h₃ : i = j +⊢ f a i v (Nat.zero_add i ▸ h₁) = f a j v h₂