fix: ensure simp and dsimp do not unfold too much (#6397)
This PR ensures that `simp` and `dsimp` do not unfold definitions that are not intended to be unfolded by the user. See issue #5755 for an example affected by this issue. Closes #5755 --------- Co-authored-by: Kim Morrison <kim@tqft.net>
This commit is contained in:
parent
9e30ac3265
commit
16bc6ebcb6
10 changed files with 129 additions and 34 deletions
|
|
@ -954,13 +954,18 @@ theorem size_eq_of_beq [BEq α] {xs ys : Array α} (h : xs == ys) : xs.size = ys
|
|||
rw [Bool.eq_iff_iff]
|
||||
simp +contextual
|
||||
|
||||
private theorem beq_of_beq_singleton [BEq α] {a b : α} : #[a] == #[b] → a == b := by
|
||||
intro h
|
||||
have : isEqv #[a] #[b] BEq.beq = true := h
|
||||
simp [isEqv, isEqvAux] at this
|
||||
assumption
|
||||
|
||||
@[simp] theorem reflBEq_iff [BEq α] : ReflBEq (Array α) ↔ ReflBEq α := by
|
||||
constructor
|
||||
· intro h
|
||||
constructor
|
||||
intro a
|
||||
suffices (#[a] == #[a]) = true by
|
||||
simpa only [instBEq, isEqv, isEqvAux, Bool.and_true]
|
||||
apply beq_of_beq_singleton
|
||||
simp
|
||||
· intro h
|
||||
constructor
|
||||
|
|
@ -973,11 +978,9 @@ theorem size_eq_of_beq [BEq α] {xs ys : Array α} (h : xs == ys) : xs.size = ys
|
|||
· intro a b h
|
||||
apply singleton_inj.1
|
||||
apply eq_of_beq
|
||||
simp only [instBEq, isEqv, isEqvAux]
|
||||
simpa
|
||||
simpa [instBEq, isEqv, isEqvAux]
|
||||
· intro a
|
||||
suffices (#[a] == #[a]) = true by
|
||||
simpa only [instBEq, isEqv, isEqvAux, Bool.and_true]
|
||||
apply beq_of_beq_singleton
|
||||
simp
|
||||
· intro h
|
||||
constructor
|
||||
|
|
|
|||
|
|
@ -333,7 +333,7 @@ theorem lex_eq_true_iff_exists [BEq α] (lt : α → α → Bool) :
|
|||
cases l₂ with
|
||||
| nil => simp [lex]
|
||||
| cons b l₂ =>
|
||||
simp only [lex_cons_cons, Bool.or_eq_true, Bool.and_eq_true, ih, isEqv, length_cons]
|
||||
simp [lex_cons_cons, Bool.or_eq_true, Bool.and_eq_true, ih, isEqv, length_cons]
|
||||
constructor
|
||||
· rintro (hab | ⟨hab, ⟨h₁, h₂⟩ | ⟨i, h₁, h₂, w₁, w₂⟩⟩)
|
||||
· exact .inr ⟨0, by simp [hab]⟩
|
||||
|
|
@ -397,7 +397,7 @@ theorem lex_eq_false_iff_exists [BEq α] [PartialEquivBEq α] (lt : α → α
|
|||
cases l₂ with
|
||||
| nil => simp [lex]
|
||||
| cons b l₂ =>
|
||||
simp only [lex_cons_cons, Bool.or_eq_false_iff, Bool.and_eq_false_imp, ih, isEqv,
|
||||
simp [lex_cons_cons, Bool.or_eq_false_iff, Bool.and_eq_false_imp, ih, isEqv,
|
||||
Bool.and_eq_true, length_cons]
|
||||
constructor
|
||||
· rintro ⟨hab, h⟩
|
||||
|
|
|
|||
|
|
@ -259,7 +259,7 @@ theorem zip_map (f : α → γ) (g : β → δ) :
|
|||
| [], _ => rfl
|
||||
| _, [] => by simp only [map, zip_nil_right]
|
||||
| _ :: _, _ :: _ => by
|
||||
simp only [map, zip_cons_cons, zip_map, Prod.map]; constructor
|
||||
simp only [map, zip_cons_cons, zip_map, Prod.map]; try constructor -- TODO: remove try constructor after update stage0
|
||||
|
||||
theorem zip_map_left (f : α → γ) (l₁ : List α) (l₂ : List β) :
|
||||
zip (l₁.map f) l₂ = (zip l₁ l₂).map (Prod.map f id) := by rw [← zip_map, map_id]
|
||||
|
|
|
|||
|
|
@ -47,6 +47,17 @@ def isOfScientificLit (e : Expr) : Bool :=
|
|||
def isCharLit (e : Expr) : Bool :=
|
||||
e.isAppOfArity ``Char.ofNat 1 && e.appArg!.isRawNatLit
|
||||
|
||||
/--
|
||||
Unfold definition even if it is not marked as `@[reducible]`.
|
||||
Remark: We never unfold irreducible definitions. Mathlib relies on that in the implementation of the
|
||||
command `irreducible_def`.
|
||||
-/
|
||||
private def unfoldDefinitionAny? (e : Expr) : MetaM (Option Expr) := do
|
||||
if let .const declName _ := e.getAppFn then
|
||||
if (← isIrreducible declName) then
|
||||
return none
|
||||
unfoldDefinition? e (ignoreTransparency := true)
|
||||
|
||||
private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do
|
||||
matchConst e.getAppFn (fun _ => pure none) fun cinfo _ => do
|
||||
match (← getProjectionFnInfo? cinfo.name) with
|
||||
|
|
@ -83,7 +94,7 @@ private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do
|
|||
let major := e.getArg! projInfo.numParams
|
||||
unless (← isConstructorApp major) do
|
||||
return none
|
||||
reduceProjCont? (← withDefault <| unfoldDefinition? e)
|
||||
reduceProjCont? (← unfoldDefinitionAny? e)
|
||||
else
|
||||
-- `structure` projections
|
||||
reduceProjCont? (← unfoldDefinition? e)
|
||||
|
|
@ -133,7 +144,7 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do
|
|||
if cfg.unfoldPartialApp -- If we are unfolding partial applications, ignore issue #2042
|
||||
-- When smart unfolding is enabled, and `f` supports it, we don't need to worry about issue #2042
|
||||
|| (smartUnfolding.get options && (← getEnv).contains (mkSmartUnfoldingNameFor fName)) then
|
||||
withDefault <| unfoldDefinition? e
|
||||
unfoldDefinitionAny? e
|
||||
else
|
||||
-- `We are not unfolding partial applications, and `fName` does not have smart unfolding support.
|
||||
-- Thus, we must check whether the arity of the function >= number of arguments.
|
||||
|
|
@ -142,16 +153,16 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do
|
|||
let arity := value.getNumHeadLambdas
|
||||
-- Partially applied function, return `none`. See issue #2042
|
||||
if arity > e.getAppNumArgs then return none
|
||||
withDefault <| unfoldDefinition? e
|
||||
unfoldDefinitionAny? e
|
||||
if (← isProjectionFn fName) then
|
||||
return none -- should be reduced by `reduceProjFn?`
|
||||
else if ctx.config.autoUnfold then
|
||||
if ctx.simpTheorems.isErased (.decl fName) then
|
||||
return none
|
||||
else if hasSmartUnfoldingDecl (← getEnv) fName then
|
||||
withDefault <| unfoldDefinition? e
|
||||
unfoldDefinitionAny? e
|
||||
else if (← isMatchDef fName) then
|
||||
let some value ← withDefault <| unfoldDefinition? e | return none
|
||||
let some value ← unfoldDefinitionAny? e | return none
|
||||
let .reduced value ← withSimpMetaConfig <| reduceMatcher? value | return none
|
||||
return some value
|
||||
else
|
||||
|
|
|
|||
|
|
@ -64,10 +64,38 @@ def isAuxDef (constName : Name) : MetaM Bool := do
|
|||
let env ← getEnv
|
||||
return isAuxRecursor env constName || isNoConfusion env constName
|
||||
|
||||
@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit → MetaM α) (k : ConstantInfo → List Level → MetaM α) : MetaM α := do
|
||||
let .const name lvls := e
|
||||
/--
|
||||
Retrieves `ConstInfo` for `declName`.
|
||||
Remark: if `ignoreTransparency = false`, then `getUnfoldableConst?` is used.
|
||||
For example, if `ignoreTransparency = false` and `transparencyMode = .reducible` and `declName` is not reducible,
|
||||
then the result is `none`.
|
||||
-/
|
||||
private def getConstInfo? (declName : Name) (ignoreTransparency : Bool) : MetaM (Option ConstantInfo) := do
|
||||
if ignoreTransparency then
|
||||
return (← getEnv).find? declName
|
||||
else
|
||||
getUnfoldableConst? declName
|
||||
|
||||
/--
|
||||
Similar to `getConstInfo?` but using `getUnfoldableConstNoEx?`.
|
||||
-/
|
||||
private def getConstInfoNoEx? (declName : Name) (ignoreTransparency : Bool) : MetaM (Option ConstantInfo) := do
|
||||
if ignoreTransparency then
|
||||
return (← getEnv).find? declName
|
||||
else
|
||||
getUnfoldableConstNoEx? declName
|
||||
|
||||
/--
|
||||
If `e` is of the form `Expr.const declName us`, executes `k info us` if
|
||||
- `declName` is in the `Environment` and (is unfoldable or `ignoreTransparency = true`)
|
||||
- `info` is the `ConstantInfo` associated with `declName`.
|
||||
|
||||
Otherwise executes `failK`.
|
||||
-/
|
||||
@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit → MetaM α) (k : ConstantInfo → List Level → MetaM α) (ignoreTransparency := false) : MetaM α := do
|
||||
let .const declName lvls := e
|
||||
| failK ()
|
||||
let (some cinfo) ← getUnfoldableConst? name
|
||||
let some cinfo ← getConstInfo? declName ignoreTransparency
|
||||
| failK ()
|
||||
k cinfo lvls
|
||||
|
||||
|
|
@ -713,11 +741,14 @@ mutual
|
|||
else
|
||||
unfoldProjInst? e
|
||||
|
||||
/-- Unfold definition using "smart unfolding" if possible. -/
|
||||
partial def unfoldDefinition? (e : Expr) : MetaM (Option Expr) :=
|
||||
/--
|
||||
Unfold definition using "smart unfolding" if possible.
|
||||
If `ignoreTransparency = true`, then the definition is unfolded even if the transparency setting does not allow it.
|
||||
-/
|
||||
partial def unfoldDefinition? (e : Expr) (ignoreTransparency := false) : MetaM (Option Expr) :=
|
||||
match e with
|
||||
| .app f _ =>
|
||||
matchConstAux f.getAppFn (fun _ => unfoldProjInstWhenInstances? e) fun fInfo fLvls => do
|
||||
matchConstAux (ignoreTransparency := ignoreTransparency) f.getAppFn (fun _ => unfoldProjInstWhenInstances? e) fun fInfo fLvls => do
|
||||
if fInfo.levelParams.length != fLvls.length then
|
||||
return none
|
||||
else
|
||||
|
|
@ -756,7 +787,8 @@ mutual
|
|||
|
||||
Remark 2: the match expression reduces reduces to `cons a xs` when the discriminants are `⟨0, h⟩` and `xs`.
|
||||
|
||||
Remark 3: this check is unnecessary in most cases, but we don't need dependent elimination to trigger the issue fixed by this extra check. Here is another example that triggers the issue fixed by this check.
|
||||
Remark 3: this check is unnecessary in most cases, but we don't need dependent elimination to trigger the issue
|
||||
fixed by this extra check. Here is another example that triggers the issue fixed by this check.
|
||||
```
|
||||
def f : Nat → Nat → Nat
|
||||
| 0, y => y
|
||||
|
|
@ -788,7 +820,7 @@ mutual
|
|||
else
|
||||
unfoldDefault ()
|
||||
| .const declName lvls => do
|
||||
let some cinfo ← getUnfoldableConstNoEx? declName | pure none
|
||||
let some cinfo ← getConstInfoNoEx? declName ignoreTransparency | pure none
|
||||
-- check smart unfolding only after `getUnfoldableConstNoEx?` because smart unfoldings have a
|
||||
-- significant chance of not existing and `Environment.contains` misses are more costly
|
||||
if smartUnfolding.get (← getOptions) && (← getEnv).contains (mkSmartUnfoldingNameFor declName) then
|
||||
|
|
|
|||
|
|
@ -16,14 +16,14 @@ namespace Vector'
|
|||
(v.snoc x).nth k = x := by
|
||||
cases k; rename_i k hk
|
||||
induction v generalizing k <;> subst h
|
||||
· simp only [nth]
|
||||
· simp! [*]
|
||||
· simp only [nth, snoc]
|
||||
· simp! [*, nth]
|
||||
|
||||
theorem nth_snoc_eq_works (k: Fin (n+1))(v : Vector' α n)
|
||||
(h: k.val = n):
|
||||
(v.snoc x).nth k = x := by
|
||||
cases k; rename_i k hk
|
||||
induction v generalizing k <;> subst h
|
||||
· simp only [nth]
|
||||
· simp[*,nth]
|
||||
· simp only [nth, snoc]
|
||||
· simp [*, nth, snoc]
|
||||
end Vector'
|
||||
|
|
|
|||
49
tests/lean/run/5755.lean
Normal file
49
tests/lean/run/5755.lean
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
inductive C : Type where
|
||||
| C1 (b : Bool) : C
|
||||
| C2 (c1 c2 : C) : C
|
||||
deriving Inhabited
|
||||
|
||||
open C
|
||||
|
||||
def id1 (b : Bool) : C := C1 b
|
||||
|
||||
def id2 (c : C) : C :=
|
||||
match c with
|
||||
| C1 b => id1 b
|
||||
| C2 c1 c2 => C2 (id2 c1) (id2 c2)
|
||||
|
||||
theorem id2_is_idempotent : id2 (id2 c) ≠ id2 c :=
|
||||
match c with
|
||||
| C1 b => by
|
||||
guard_target =ₛ id2 (id2 (C1 b)) ≠ id2 (C1 b)
|
||||
dsimp only [id2]
|
||||
guard_target =ₛ id2 (id1 b) ≠ id1 b
|
||||
sorry
|
||||
| C2 _ _ => by
|
||||
sorry
|
||||
|
||||
example : id2 (id1 b) ≠ a := by
|
||||
fail_if_success dsimp only [id2]
|
||||
dsimp only [id2, id1]
|
||||
guard_target =ₛ C1 b ≠ a
|
||||
sorry
|
||||
|
||||
|
||||
/-
|
||||
Here is another problematic example that has been fixed.
|
||||
-/
|
||||
|
||||
|
||||
def f : Nat → Nat
|
||||
| 0 => 1
|
||||
| x+1 => 2 * f x
|
||||
|
||||
def fib : Nat → Nat
|
||||
| 0 => 1
|
||||
| 1 => 1
|
||||
| x+2 => fib (x+1) + fib x
|
||||
|
||||
example : 0 + f (fib 10000) = a := by
|
||||
simp [f] -- should not trigger max rec depth
|
||||
guard_target =ₛ f (fib 10000) = a
|
||||
sorry
|
||||
|
|
@ -45,10 +45,10 @@ end DataEntry
|
|||
|
||||
abbrev Header := List (DataType × String)
|
||||
|
||||
def Header.colTypes (h : Header) : List DataType :=
|
||||
@[simp] def Header.colTypes (h : Header) : List DataType :=
|
||||
h.map fun x => x.1
|
||||
|
||||
def Header.colNames (h : Header) : List String :=
|
||||
@[simp] def Header.colNames (h : Header) : List String :=
|
||||
h.map fun x => x.2
|
||||
|
||||
abbrev Row := List DataEntry
|
||||
|
|
@ -69,7 +69,7 @@ structure DataFrame where
|
|||
|
||||
namespace DataFrame
|
||||
|
||||
def empty (header : Header := []) : DataFrame :=
|
||||
@[simp] def empty (header : Header := []) : DataFrame :=
|
||||
⟨header, [], by simp⟩
|
||||
|
||||
theorem consistentConcatOfConsistentRow
|
||||
|
|
@ -88,12 +88,12 @@ def addRow (df : DataFrame) (row : List DataEntry)
|
|||
|
||||
end DataFrame
|
||||
|
||||
def h : Header := [(TInt, "id"), (TString, "name")]
|
||||
abbrev h : Header := [(TInt, "id"), (TString, "name")]
|
||||
|
||||
def r : List Row := [[1, "alex"]]
|
||||
abbrev r : List Row := [[1, "alex"]]
|
||||
|
||||
-- this no longer works
|
||||
def df1 : DataFrame := DataFrame.mk h r
|
||||
abbrev df1 : DataFrame := DataFrame.mk h r
|
||||
|
||||
-- and this ofc breaks now
|
||||
def df2 : DataFrame := df1.addRow [2, "juddy"]
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def Expr.concat : Expr → Expr → Expr
|
|||
theorem Expr.denote_concat (ctx : Context α) (a b : Expr) : denote ctx (concat a b) = denote ctx (Expr.op a b) := by
|
||||
induction a with
|
||||
| var i => rfl
|
||||
| op _ _ _ ih => simp [denote, ih, ctx.assoc]
|
||||
| op _ _ _ ih => simp [denote, concat, ih, ctx.assoc]
|
||||
|
||||
def Expr.flat : Expr → Expr
|
||||
| Expr.op a b => concat (flat a) (flat b)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ theorem concatEq (xs : List α) (h : xs ≠ []) : concat (dropLast xs) (last xs
|
|||
match xs, h with
|
||||
| [], h => contradiction
|
||||
| [x], h => rfl
|
||||
| x₁::x₂::xs, h => simp [concat, last, concatEq (x₂::xs) List.noConfusion]
|
||||
| x₁::x₂::xs, h => simp [concat, dropLast, last, concatEq (x₂::xs) List.noConfusion]
|
||||
|
||||
theorem lengthCons {α} (x : α) (xs : List α) : (x::xs).length = xs.length + 1 :=
|
||||
rfl
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue