From 16bc6ebcb6753ebd53eea4e9fbb078744c33cbd2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 21 Dec 2024 05:16:15 +0100 Subject: [PATCH] fix: ensure `simp` and `dsimp` do not unfold too much (#6397) This PR ensures that `simp` and `dsimp` do not unfold definitions that are not intended to be unfolded by the user. See issue #5755 for an example affected by this issue. Closes #5755 --------- Co-authored-by: Kim Morrison --- src/Init/Data/Array/Lemmas.lean | 15 +++++---- src/Init/Data/List/Lex.lean | 4 +-- src/Init/Data/List/Zip.lean | 2 +- src/Lean/Meta/Tactic/Simp/Main.lean | 21 ++++++++++--- src/Lean/Meta/WHNF.lean | 48 +++++++++++++++++++++++----- tests/lean/run/1024.lean | 8 ++--- tests/lean/run/5755.lean | 49 +++++++++++++++++++++++++++++ tests/lean/run/946.lean | 12 +++---- tests/lean/run/ac_expr.lean | 2 +- tests/lean/run/concatElim.lean | 2 +- 10 files changed, 129 insertions(+), 34 deletions(-) create mode 100644 tests/lean/run/5755.lean diff --git a/src/Init/Data/Array/Lemmas.lean b/src/Init/Data/Array/Lemmas.lean index 655b4ad609..2e6db2b2aa 100644 --- a/src/Init/Data/Array/Lemmas.lean +++ b/src/Init/Data/Array/Lemmas.lean @@ -954,13 +954,18 @@ theorem size_eq_of_beq [BEq α] {xs ys : Array α} (h : xs == ys) : xs.size = ys rw [Bool.eq_iff_iff] simp +contextual +private theorem beq_of_beq_singleton [BEq α] {a b : α} : #[a] == #[b] → a == b := by + intro h + have : isEqv #[a] #[b] BEq.beq = true := h + simp [isEqv, isEqvAux] at this + assumption + @[simp] theorem reflBEq_iff [BEq α] : ReflBEq (Array α) ↔ ReflBEq α := by constructor · intro h constructor intro a - suffices (#[a] == #[a]) = true by - simpa only [instBEq, isEqv, isEqvAux, Bool.and_true] + apply beq_of_beq_singleton simp · intro h constructor @@ -973,11 +978,9 @@ theorem size_eq_of_beq [BEq α] {xs ys : Array α} (h : xs == ys) : xs.size = ys · intro a b h apply singleton_inj.1 apply eq_of_beq - simp only [instBEq, isEqv, isEqvAux] - simpa + simpa [instBEq, isEqv, isEqvAux] · intro a - 
suffices (#[a] == #[a]) = true by - simpa only [instBEq, isEqv, isEqvAux, Bool.and_true] + apply beq_of_beq_singleton simp · intro h constructor diff --git a/src/Init/Data/List/Lex.lean b/src/Init/Data/List/Lex.lean index c4f9b63fda..b6a72e14c9 100644 --- a/src/Init/Data/List/Lex.lean +++ b/src/Init/Data/List/Lex.lean @@ -333,7 +333,7 @@ theorem lex_eq_true_iff_exists [BEq α] (lt : α → α → Bool) : cases l₂ with | nil => simp [lex] | cons b l₂ => - simp only [lex_cons_cons, Bool.or_eq_true, Bool.and_eq_true, ih, isEqv, length_cons] + simp [lex_cons_cons, Bool.or_eq_true, Bool.and_eq_true, ih, isEqv, length_cons] constructor · rintro (hab | ⟨hab, ⟨h₁, h₂⟩ | ⟨i, h₁, h₂, w₁, w₂⟩⟩) · exact .inr ⟨0, by simp [hab]⟩ @@ -397,7 +397,7 @@ theorem lex_eq_false_iff_exists [BEq α] [PartialEquivBEq α] (lt : α → α cases l₂ with | nil => simp [lex] | cons b l₂ => - simp only [lex_cons_cons, Bool.or_eq_false_iff, Bool.and_eq_false_imp, ih, isEqv, + simp [lex_cons_cons, Bool.or_eq_false_iff, Bool.and_eq_false_imp, ih, isEqv, Bool.and_eq_true, length_cons] constructor · rintro ⟨hab, h⟩ diff --git a/src/Init/Data/List/Zip.lean b/src/Init/Data/List/Zip.lean index 59e3691b7a..8b0b80808b 100644 --- a/src/Init/Data/List/Zip.lean +++ b/src/Init/Data/List/Zip.lean @@ -259,7 +259,7 @@ theorem zip_map (f : α → γ) (g : β → δ) : | [], _ => rfl | _, [] => by simp only [map, zip_nil_right] | _ :: _, _ :: _ => by - simp only [map, zip_cons_cons, zip_map, Prod.map]; constructor + simp only [map, zip_cons_cons, zip_map, Prod.map]; try constructor -- TODO: remove try constructor after update stage0 theorem zip_map_left (f : α → γ) (l₁ : List α) (l₂ : List β) : zip (l₁.map f) l₂ = (zip l₁ l₂).map (Prod.map f id) := by rw [← zip_map, map_id] diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index e9d8868787..e1c63556ac 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -47,6 +47,17 @@ def isOfScientificLit (e : Expr) : Bool 
:= def isCharLit (e : Expr) : Bool := e.isAppOfArity ``Char.ofNat 1 && e.appArg!.isRawNatLit +/-- +Unfold definition even if it is not marked as `@[reducible]`. +Remark: We never unfold irreducible definitions. Mathlib relies on that in the implementation of the +command `irreducible_def`. +-/ +private def unfoldDefinitionAny? (e : Expr) : MetaM (Option Expr) := do + if let .const declName _ := e.getAppFn then + if (← isIrreducible declName) then + return none + unfoldDefinition? e (ignoreTransparency := true) + private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do matchConst e.getAppFn (fun _ => pure none) fun cinfo _ => do match (← getProjectionFnInfo? cinfo.name) with @@ -83,7 +94,7 @@ private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do let major := e.getArg! projInfo.numParams unless (← isConstructorApp major) do return none - reduceProjCont? (← withDefault <| unfoldDefinition? e) + reduceProjCont? (← unfoldDefinitionAny? e) else -- `structure` projections reduceProjCont? (← unfoldDefinition? e) @@ -133,7 +144,7 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do if cfg.unfoldPartialApp -- If we are unfolding partial applications, ignore issue #2042 -- When smart unfolding is enabled, and `f` supports it, we don't need to worry about issue #2042 || (smartUnfolding.get options && (← getEnv).contains (mkSmartUnfoldingNameFor fName)) then - withDefault <| unfoldDefinition? e + unfoldDefinitionAny? e else -- `We are not unfolding partial applications, and `fName` does not have smart unfolding support. -- Thus, we must check whether the arity of the function >= number of arguments. @@ -142,16 +153,16 @@ private def unfold? (e : Expr) : SimpM (Option Expr) := do let arity := value.getNumHeadLambdas -- Partially applied function, return `none`. See issue #2042 if arity > e.getAppNumArgs then return none - withDefault <| unfoldDefinition? e + unfoldDefinitionAny? 
e if (← isProjectionFn fName) then return none -- should be reduced by `reduceProjFn?` else if ctx.config.autoUnfold then if ctx.simpTheorems.isErased (.decl fName) then return none else if hasSmartUnfoldingDecl (← getEnv) fName then - withDefault <| unfoldDefinition? e + unfoldDefinitionAny? e else if (← isMatchDef fName) then - let some value ← withDefault <| unfoldDefinition? e | return none + let some value ← unfoldDefinitionAny? e | return none let .reduced value ← withSimpMetaConfig <| reduceMatcher? value | return none return some value else diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index bad1ed8cc3..97009be71f 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -64,10 +64,38 @@ def isAuxDef (constName : Name) : MetaM Bool := do let env ← getEnv return isAuxRecursor env constName || isNoConfusion env constName -@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit → MetaM α) (k : ConstantInfo → List Level → MetaM α) : MetaM α := do - let .const name lvls := e +/-- +Retrieves `ConstantInfo` for `declName`. +Remark: if `ignoreTransparency = false`, then `getUnfoldableConst?` is used. +For example, if `ignoreTransparency = false` and `transparencyMode = .reducible` and `declName` is not reducible, +then the result is `none`. +-/ +private def getConstInfo? (declName : Name) (ignoreTransparency : Bool) : MetaM (Option ConstantInfo) := do + if ignoreTransparency then + return (← getEnv).find? declName + else + getUnfoldableConst? declName + +/-- +Similar to `getConstInfo?` but using `getUnfoldableConstNoEx?`. +-/ +private def getConstInfoNoEx? (declName : Name) (ignoreTransparency : Bool) : MetaM (Option ConstantInfo) := do + if ignoreTransparency then + return (← getEnv).find? declName + else + getUnfoldableConstNoEx? 
declName + +/-- +If `e` is of the form `Expr.const declName us`, executes `k info us` if +- `declName` is in the `Environment` and (is unfoldable or `ignoreTransparency = true`) +- `info` is the `ConstantInfo` associated with `declName`. + +Otherwise executes `failK`. +-/ +@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit → MetaM α) (k : ConstantInfo → List Level → MetaM α) (ignoreTransparency := false) : MetaM α := do + let .const declName lvls := e | failK () - let (some cinfo) ← getUnfoldableConst? name + let some cinfo ← getConstInfo? declName ignoreTransparency | failK () k cinfo lvls @@ -713,11 +741,14 @@ mutual else unfoldProjInst? e - /-- Unfold definition using "smart unfolding" if possible. -/ - partial def unfoldDefinition? (e : Expr) : MetaM (Option Expr) := + /-- + Unfold definition using "smart unfolding" if possible. + If `ignoreTransparency = true`, then the definition is unfolded even if the transparency setting does not allow it. + -/ + partial def unfoldDefinition? (e : Expr) (ignoreTransparency := false) : MetaM (Option Expr) := match e with | .app f _ => - matchConstAux f.getAppFn (fun _ => unfoldProjInstWhenInstances? e) fun fInfo fLvls => do + matchConstAux (ignoreTransparency := ignoreTransparency) f.getAppFn (fun _ => unfoldProjInstWhenInstances? e) fun fInfo fLvls => do if fInfo.levelParams.length != fLvls.length then return none else @@ -756,7 +787,8 @@ mutual Remark 2: the match expression reduces reduces to `cons a xs` when the discriminants are `⟨0, h⟩` and `xs`. - Remark 3: this check is unnecessary in most cases, but we don't need dependent elimination to trigger the issue fixed by this extra check. Here is another example that triggers the issue fixed by this check. + Remark 3: this check is unnecessary in most cases, but we don't need dependent elimination to trigger the issue + fixed by this extra check. Here is another example that triggers the issue fixed by this check. 
``` def f : Nat → Nat → Nat | 0, y => y @@ -788,7 +820,7 @@ mutual else unfoldDefault () | .const declName lvls => do - let some cinfo ← getUnfoldableConstNoEx? declName | pure none + let some cinfo ← getConstInfoNoEx? declName ignoreTransparency | pure none -- check smart unfolding only after `getUnfoldableConstNoEx?` because smart unfoldings have a -- significant chance of not existing and `Environment.contains` misses are more costly if smartUnfolding.get (← getOptions) && (← getEnv).contains (mkSmartUnfoldingNameFor declName) then diff --git a/tests/lean/run/1024.lean b/tests/lean/run/1024.lean index d584f48b86..6f47597e18 100644 --- a/tests/lean/run/1024.lean +++ b/tests/lean/run/1024.lean @@ -16,14 +16,14 @@ namespace Vector' (v.snoc x).nth k = x := by cases k; rename_i k hk induction v generalizing k <;> subst h - · simp only [nth] - · simp! [*] + · simp only [nth, snoc] + · simp! [*, nth] theorem nth_snoc_eq_works (k: Fin (n+1))(v : Vector' α n) (h: k.val = n): (v.snoc x).nth k = x := by cases k; rename_i k hk induction v generalizing k <;> subst h - · simp only [nth] - · simp[*,nth] + · simp only [nth, snoc] + · simp [*, nth, snoc] end Vector' diff --git a/tests/lean/run/5755.lean b/tests/lean/run/5755.lean new file mode 100644 index 0000000000..0eb39b85d1 --- /dev/null +++ b/tests/lean/run/5755.lean @@ -0,0 +1,49 @@ +inductive C : Type where +| C1 (b : Bool) : C +| C2 (c1 c2 : C) : C +deriving Inhabited + +open C + +def id1 (b : Bool) : C := C1 b + +def id2 (c : C) : C := + match c with + | C1 b => id1 b + | C2 c1 c2 => C2 (id2 c1) (id2 c2) + +theorem id2_is_idempotent : id2 (id2 c) ≠ id2 c := + match c with + | C1 b => by + guard_target =ₛ id2 (id2 (C1 b)) ≠ id2 (C1 b) + dsimp only [id2] + guard_target =ₛ id2 (id1 b) ≠ id1 b + sorry + | C2 _ _ => by + sorry + +example : id2 (id1 b) ≠ a := by + fail_if_success dsimp only [id2] + dsimp only [id2, id1] + guard_target =ₛ C1 b ≠ a + sorry + + +/- +Here is another problematic example that has been fixed. 
+-/ + + +def f : Nat → Nat + | 0 => 1 + | x+1 => 2 * f x + +def fib : Nat → Nat + | 0 => 1 + | 1 => 1 + | x+2 => fib (x+1) + fib x + +example : 0 + f (fib 10000) = a := by + simp [f] -- should not trigger max rec depth + guard_target =ₛ f (fib 10000) = a + sorry diff --git a/tests/lean/run/946.lean b/tests/lean/run/946.lean index 82d9d656cc..7dd7531e16 100644 --- a/tests/lean/run/946.lean +++ b/tests/lean/run/946.lean @@ -45,10 +45,10 @@ end DataEntry abbrev Header := List (DataType × String) -def Header.colTypes (h : Header) : List DataType := +@[simp] def Header.colTypes (h : Header) : List DataType := h.map fun x => x.1 -def Header.colNames (h : Header) : List String := +@[simp] def Header.colNames (h : Header) : List String := h.map fun x => x.2 abbrev Row := List DataEntry @@ -69,7 +69,7 @@ structure DataFrame where namespace DataFrame -def empty (header : Header := []) : DataFrame := +@[simp] def empty (header : Header := []) : DataFrame := ⟨header, [], by simp⟩ theorem consistentConcatOfConsistentRow @@ -88,12 +88,12 @@ def addRow (df : DataFrame) (row : List DataEntry) end DataFrame -def h : Header := [(TInt, "id"), (TString, "name")] +abbrev h : Header := [(TInt, "id"), (TString, "name")] -def r : List Row := [[1, "alex"]] +abbrev r : List Row := [[1, "alex"]] -- this no longer works -def df1 : DataFrame := DataFrame.mk h r +abbrev df1 : DataFrame := DataFrame.mk h r -- and this ofc breaks now def df2 : DataFrame := df1.addRow [2, "juddy"] diff --git a/tests/lean/run/ac_expr.lean b/tests/lean/run/ac_expr.lean index 223f077ade..8203e73fce 100644 --- a/tests/lean/run/ac_expr.lean +++ b/tests/lean/run/ac_expr.lean @@ -32,7 +32,7 @@ def Expr.concat : Expr → Expr → Expr theorem Expr.denote_concat (ctx : Context α) (a b : Expr) : denote ctx (concat a b) = denote ctx (Expr.op a b) := by induction a with | var i => rfl - | op _ _ _ ih => simp [denote, ih, ctx.assoc] + | op _ _ _ ih => simp [denote, concat, ih, ctx.assoc] def Expr.flat : Expr → Expr | Expr.op a b 
=> concat (flat a) (flat b) diff --git a/tests/lean/run/concatElim.lean b/tests/lean/run/concatElim.lean index 39b9dd4933..f597d8fbaf 100644 --- a/tests/lean/run/concatElim.lean +++ b/tests/lean/run/concatElim.lean @@ -20,7 +20,7 @@ theorem concatEq (xs : List α) (h : xs ≠ []) : concat (dropLast xs) (last xs match xs, h with | [], h => contradiction | [x], h => rfl - | x₁::x₂::xs, h => simp [concat, last, concatEq (x₂::xs) List.noConfusion] + | x₁::x₂::xs, h => simp [concat, dropLast, last, concatEq (x₂::xs) List.noConfusion] theorem lengthCons {α} (x : α) (xs : List α) : (x::xs).length = xs.length + 1 := rfl