fix: ensure simp and dsimp do not unfold too much (#6397)

This PR ensures that `simp` and `dsimp` do not unfold definitions that
are not intended to be unfolded by the user. See issue #5755 for an
example affected by this issue.

Closes #5755

---------

Co-authored-by: Kim Morrison <kim@tqft.net>
This commit is contained in:
Leonardo de Moura 2024-12-21 05:16:15 +01:00 committed by GitHub
parent 9e30ac3265
commit 16bc6ebcb6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 129 additions and 34 deletions

View file

@ -954,13 +954,18 @@ theorem size_eq_of_beq [BEq α] {xs ys : Array α} (h : xs == ys) : xs.size = ys
rw [Bool.eq_iff_iff]
simp +contextual
private theorem beq_of_beq_singleton [BEq α] {a b : α} : #[a] == #[b] → a == b := by
intro h
have : isEqv #[a] #[b] BEq.beq = true := h
simp [isEqv, isEqvAux] at this
assumption
@[simp] theorem reflBEq_iff [BEq α] : ReflBEq (Array α) ↔ ReflBEq α := by
constructor
· intro h
constructor
intro a
suffices (#[a] == #[a]) = true by
simpa only [instBEq, isEqv, isEqvAux, Bool.and_true]
apply beq_of_beq_singleton
simp
· intro h
constructor
@ -973,11 +978,9 @@ theorem size_eq_of_beq [BEq α] {xs ys : Array α} (h : xs == ys) : xs.size = ys
· intro a b h
apply singleton_inj.1
apply eq_of_beq
simp only [instBEq, isEqv, isEqvAux]
simpa
simpa [instBEq, isEqv, isEqvAux]
· intro a
suffices (#[a] == #[a]) = true by
simpa only [instBEq, isEqv, isEqvAux, Bool.and_true]
apply beq_of_beq_singleton
simp
· intro h
constructor

View file

@ -333,7 +333,7 @@ theorem lex_eq_true_iff_exists [BEq α] (lt : αα → Bool) :
cases l₂ with
| nil => simp [lex]
| cons b l₂ =>
simp only [lex_cons_cons, Bool.or_eq_true, Bool.and_eq_true, ih, isEqv, length_cons]
simp [lex_cons_cons, Bool.or_eq_true, Bool.and_eq_true, ih, isEqv, length_cons]
constructor
· rintro (hab | ⟨hab, ⟨h₁, h₂⟩ | ⟨i, h₁, h₂, w₁, w₂⟩⟩)
· exact .inr ⟨0, by simp [hab]⟩
@ -397,7 +397,7 @@ theorem lex_eq_false_iff_exists [BEq α] [PartialEquivBEq α] (lt : αα
cases l₂ with
| nil => simp [lex]
| cons b l₂ =>
simp only [lex_cons_cons, Bool.or_eq_false_iff, Bool.and_eq_false_imp, ih, isEqv,
simp [lex_cons_cons, Bool.or_eq_false_iff, Bool.and_eq_false_imp, ih, isEqv,
Bool.and_eq_true, length_cons]
constructor
· rintro ⟨hab, h⟩

View file

@ -259,7 +259,7 @@ theorem zip_map (f : αγ) (g : β → δ) :
| [], _ => rfl
| _, [] => by simp only [map, zip_nil_right]
| _ :: _, _ :: _ => by
simp only [map, zip_cons_cons, zip_map, Prod.map]; constructor
simp only [map, zip_cons_cons, zip_map, Prod.map]; try constructor -- TODO: remove try constructor after update stage0
theorem zip_map_left (f : αγ) (l₁ : List α) (l₂ : List β) :
zip (l₁.map f) l₂ = (zip l₁ l₂).map (Prod.map f id) := by rw [← zip_map, map_id]

View file

@ -47,6 +47,17 @@ def isOfScientificLit (e : Expr) : Bool :=
def isCharLit (e : Expr) : Bool :=
e.isAppOfArity ``Char.ofNat 1 && e.appArg!.isRawNatLit
/--
Unfold definition even if it is not marked as `@[reducible]`.
Remark: We never unfold irreducible definitions. Mathlib relies on that in the implementation of the
command `irreducible_def`.
-/
private def unfoldDefinitionAny? (e : Expr) : MetaM (Option Expr) := do
if let .const declName _ := e.getAppFn then
if (← isIrreducible declName) then
return none
unfoldDefinition? e (ignoreTransparency := true)
private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do
matchConst e.getAppFn (fun _ => pure none) fun cinfo _ => do
match (← getProjectionFnInfo? cinfo.name) with
@ -83,7 +94,7 @@ private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do
let major := e.getArg! projInfo.numParams
unless (← isConstructorApp major) do
return none
reduceProjCont? (← withDefault <| unfoldDefinition? e)
reduceProjCont? (← unfoldDefinitionAny? e)
else
-- `structure` projections
reduceProjCont? (← unfoldDefinition? e)
@ -133,7 +144,7 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do
if cfg.unfoldPartialApp -- If we are unfolding partial applications, ignore issue #2042
-- When smart unfolding is enabled, and `f` supports it, we don't need to worry about issue #2042
|| (smartUnfolding.get options && (← getEnv).contains (mkSmartUnfoldingNameFor fName)) then
withDefault <| unfoldDefinition? e
unfoldDefinitionAny? e
else
-- `We are not unfolding partial applications, and `fName` does not have smart unfolding support.
-- Thus, we must check whether the arity of the function >= number of arguments.
@ -142,16 +153,16 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do
let arity := value.getNumHeadLambdas
-- Partially applied function, return `none`. See issue #2042
if arity > e.getAppNumArgs then return none
withDefault <| unfoldDefinition? e
unfoldDefinitionAny? e
if (← isProjectionFn fName) then
return none -- should be reduced by `reduceProjFn?`
else if ctx.config.autoUnfold then
if ctx.simpTheorems.isErased (.decl fName) then
return none
else if hasSmartUnfoldingDecl (← getEnv) fName then
withDefault <| unfoldDefinition? e
unfoldDefinitionAny? e
else if (← isMatchDef fName) then
let some value ← withDefault <| unfoldDefinition? e | return none
let some value ← unfoldDefinitionAny? e | return none
let .reduced value ← withSimpMetaConfig <| reduceMatcher? value | return none
return some value
else

View file

@ -64,10 +64,38 @@ def isAuxDef (constName : Name) : MetaM Bool := do
let env ← getEnv
return isAuxRecursor env constName || isNoConfusion env constName
@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit → MetaM α) (k : ConstantInfo → List Level → MetaM α) : MetaM α := do
let .const name lvls := e
/--
Retrieves `ConstInfo` for `declName`.
Remark: if `ignoreTransparency = false`, then `getUnfoldableConst?` is used.
For example, if `ignoreTransparency = false` and `transparencyMode = .reducible` and `declName` is not reducible,
then the result is `none`.
-/
private def getConstInfo? (declName : Name) (ignoreTransparency : Bool) : MetaM (Option ConstantInfo) := do
if ignoreTransparency then
return (← getEnv).find? declName
else
getUnfoldableConst? declName
/--
Similar to `getConstInfo?` but using `getUnfoldableConstNoEx?`.
-/
private def getConstInfoNoEx? (declName : Name) (ignoreTransparency : Bool) : MetaM (Option ConstantInfo) := do
if ignoreTransparency then
return (← getEnv).find? declName
else
getUnfoldableConstNoEx? declName
/--
If `e` is of the form `Expr.const declName us`, executes `k info us` if
- `declName` is in the `Environment` and (is unfoldable or `ignoreTransparency = true`)
- `info` is the `ConstantInfo` associated with `declName`.
Otherwise executes `failK`.
-/
@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit → MetaM α) (k : ConstantInfo → List Level → MetaM α) (ignoreTransparency := false) : MetaM α := do
let .const declName lvls := e
| failK ()
let (some cinfo) ← getUnfoldableConst? name
let some cinfo ← getConstInfo? declName ignoreTransparency
| failK ()
k cinfo lvls
@ -713,11 +741,14 @@ mutual
else
unfoldProjInst? e
/-- Unfold definition using "smart unfolding" if possible. -/
partial def unfoldDefinition? (e : Expr) : MetaM (Option Expr) :=
/--
Unfold definition using "smart unfolding" if possible.
If `ignoreTransparency = true`, then the definition is unfolded even if the transparency setting does not allow it.
-/
partial def unfoldDefinition? (e : Expr) (ignoreTransparency := false) : MetaM (Option Expr) :=
match e with
| .app f _ =>
matchConstAux f.getAppFn (fun _ => unfoldProjInstWhenInstances? e) fun fInfo fLvls => do
matchConstAux (ignoreTransparency := ignoreTransparency) f.getAppFn (fun _ => unfoldProjInstWhenInstances? e) fun fInfo fLvls => do
if fInfo.levelParams.length != fLvls.length then
return none
else
@ -756,7 +787,8 @@ mutual
Remark 2: the match expression reduces reduces to `cons a xs` when the discriminants are `⟨0, h⟩` and `xs`.
Remark 3: this check is unnecessary in most cases, but we don't need dependent elimination to trigger the issue fixed by this extra check. Here is another example that triggers the issue fixed by this check.
Remark 3: this check is unnecessary in most cases, but we don't need dependent elimination to trigger the issue
fixed by this extra check. Here is another example that triggers the issue fixed by this check.
```
def f : Nat → Nat → Nat
| 0, y => y
@ -788,7 +820,7 @@ mutual
else
unfoldDefault ()
| .const declName lvls => do
let some cinfo ← getUnfoldableConstNoEx? declName | pure none
let some cinfo ← getConstInfoNoEx? declName ignoreTransparency | pure none
-- check smart unfolding only after `getUnfoldableConstNoEx?` because smart unfoldings have a
-- significant chance of not existing and `Environment.contains` misses are more costly
if smartUnfolding.get (← getOptions) && (← getEnv).contains (mkSmartUnfoldingNameFor declName) then

View file

@ -16,14 +16,14 @@ namespace Vector'
(v.snoc x).nth k = x := by
cases k; rename_i k hk
induction v generalizing k <;> subst h
· simp only [nth]
· simp! [*]
· simp only [nth, snoc]
· simp! [*, nth]
theorem nth_snoc_eq_works (k: Fin (n+1))(v : Vector' α n)
(h: k.val = n):
(v.snoc x).nth k = x := by
cases k; rename_i k hk
induction v generalizing k <;> subst h
· simp only [nth]
· simp[*,nth]
· simp only [nth, snoc]
· simp [*, nth, snoc]
end Vector'

49
tests/lean/run/5755.lean Normal file
View file

@ -0,0 +1,49 @@
inductive C : Type where
| C1 (b : Bool) : C
| C2 (c1 c2 : C) : C
deriving Inhabited
open C
def id1 (b : Bool) : C := C1 b
def id2 (c : C) : C :=
match c with
| C1 b => id1 b
| C2 c1 c2 => C2 (id2 c1) (id2 c2)
theorem id2_is_idempotent : id2 (id2 c) ≠ id2 c :=
match c with
| C1 b => by
guard_target =ₛ id2 (id2 (C1 b)) ≠ id2 (C1 b)
dsimp only [id2]
guard_target =ₛ id2 (id1 b) ≠ id1 b
sorry
| C2 _ _ => by
sorry
example : id2 (id1 b) ≠ a := by
fail_if_success dsimp only [id2]
dsimp only [id2, id1]
guard_target =ₛ C1 b ≠ a
sorry
/-
Here is another problematic example that has been fixed.
-/
def f : Nat → Nat
| 0 => 1
| x+1 => 2 * f x
def fib : Nat → Nat
| 0 => 1
| 1 => 1
| x+2 => fib (x+1) + fib x
example : 0 + f (fib 10000) = a := by
simp [f] -- should not trigger max rec depth
guard_target =ₛ f (fib 10000) = a
sorry

View file

@ -45,10 +45,10 @@ end DataEntry
abbrev Header := List (DataType × String)
def Header.colTypes (h : Header) : List DataType :=
@[simp] def Header.colTypes (h : Header) : List DataType :=
h.map fun x => x.1
def Header.colNames (h : Header) : List String :=
@[simp] def Header.colNames (h : Header) : List String :=
h.map fun x => x.2
abbrev Row := List DataEntry
@ -69,7 +69,7 @@ structure DataFrame where
namespace DataFrame
def empty (header : Header := []) : DataFrame :=
@[simp] def empty (header : Header := []) : DataFrame :=
⟨header, [], by simp⟩
theorem consistentConcatOfConsistentRow
@ -88,12 +88,12 @@ def addRow (df : DataFrame) (row : List DataEntry)
end DataFrame
def h : Header := [(TInt, "id"), (TString, "name")]
abbrev h : Header := [(TInt, "id"), (TString, "name")]
def r : List Row := [[1, "alex"]]
abbrev r : List Row := [[1, "alex"]]
-- this no longer works
def df1 : DataFrame := DataFrame.mk h r
abbrev df1 : DataFrame := DataFrame.mk h r
-- and this ofc breaks now
def df2 : DataFrame := df1.addRow [2, "juddy"]

View file

@ -32,7 +32,7 @@ def Expr.concat : Expr → Expr → Expr
theorem Expr.denote_concat (ctx : Context α) (a b : Expr) : denote ctx (concat a b) = denote ctx (Expr.op a b) := by
induction a with
| var i => rfl
| op _ _ _ ih => simp [denote, ih, ctx.assoc]
| op _ _ _ ih => simp [denote, concat, ih, ctx.assoc]
def Expr.flat : Expr → Expr
| Expr.op a b => concat (flat a) (flat b)

View file

@ -20,7 +20,7 @@ theorem concatEq (xs : List α) (h : xs ≠ []) : concat (dropLast xs) (last xs
match xs, h with
| [], h => contradiction
| [x], h => rfl
| x₁::x₂::xs, h => simp [concat, last, concatEq (x₂::xs) List.noConfusion]
| x₁::x₂::xs, h => simp [concat, dropLast, last, concatEq (x₂::xs) List.noConfusion]
theorem lengthCons {α} (x : α) (xs : List α) : (x::xs).length = xs.length + 1 :=
rfl