From a99393483949dc56d692f3a37dbcfc51db568341 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Thu, 29 Aug 2024 18:47:40 +0200 Subject: [PATCH] feat: generate `f.eq_unfold` lemmas (#5141) With this, lean produces the following zoo of rewrite rules: ``` Option.map.eq_1 : Option.map f none = none Option.map.eq_2 : Option.map f (some x) = some (f x) Option.map.eq_def : Option.map f p = match o with | none => none | (some x) => some (f x) Option.map.eq_unfold : Option.map = fun f p => match o with | none => none | (some x) => some (f x) ``` The `f.eq_unfold` variant is especially useful to rewrite with `rw` under binders. This implements and fixes #5110 --- src/Lean/Elab/PreDefinition.lean | 1 + src/Lean/Elab/PreDefinition/EqUnfold.lean | 62 ++++++++++++++++++++ src/Lean/Elab/PreDefinition/Eqns.lean | 17 ------ src/Lean/Meta/Eqns.lean | 14 ++--- tests/lean/run/unfoldLemma.lean | 69 +++++++++++++++++++++++ 5 files changed, 138 insertions(+), 25 deletions(-) create mode 100644 src/Lean/Elab/PreDefinition/EqUnfold.lean create mode 100644 tests/lean/run/unfoldLemma.lean diff --git a/src/Lean/Elab/PreDefinition.lean b/src/Lean/Elab/PreDefinition.lean index 20aa90ac24..55116017a0 100644 --- a/src/Lean/Elab/PreDefinition.lean +++ b/src/Lean/Elab/PreDefinition.lean @@ -11,3 +11,4 @@ import Lean.Elab.PreDefinition.MkInhabitant import Lean.Elab.PreDefinition.WF import Lean.Elab.PreDefinition.Eqns import Lean.Elab.PreDefinition.Nonrec.Eqns +import Lean.Elab.PreDefinition.EqUnfold diff --git a/src/Lean/Elab/PreDefinition/EqUnfold.lean b/src/Lean/Elab/PreDefinition/EqUnfold.lean new file mode 100644 index 0000000000..983b2202e1 --- /dev/null +++ b/src/Lean/Elab/PreDefinition/EqUnfold.lean @@ -0,0 +1,62 @@ +/- +Copyright (c) 2024 Lean FRO. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joachim Breitner +-/ +prelude +import Lean.Meta.Eqns +import Lean.Meta.Tactic.Util +import Lean.Meta.Tactic.Rfl +import Lean.Meta.Tactic.Intro +import Lean.Meta.Tactic.Apply + +namespace Lean.Meta + +/-- Try to close goal using `rfl` with smart unfolding turned off. -/ +def tryURefl (mvarId : MVarId) : MetaM Bool := + withOptions (smartUnfolding.set · false) do + try mvarId.refl; return true catch _ => return false + +/-- +Returns the "const unfold" theorem (`f.eq_unfold`) for the given declaration. +This is not extensible, and always builds on the unfold theorem (`f.eq_def`). +-/ +def getConstUnfoldEqnFor? (declName : Name) : MetaM (Option Name) := do + let some unfoldEqnName ← getUnfoldEqnFor? (nonRec := true) declName | return none + let info ← getConstInfo unfoldEqnName + let type ← forallTelescope info.type fun xs eq => do + let some (_, lhs, rhs) := eq.eq? | throwError "Unexpected unfold theorem type {info.type}" + unless lhs.getAppFn.isConstOf declName do + throwError "Unexpected unfold theorem type {info.type}" + unless lhs.getAppArgs == xs do + throwError "Unexpected unfold theorem type {info.type}" + let type ← mkEq lhs.getAppFn (← mkLambdaFVars xs rhs) + return type + let value ← withNewMCtxDepth do + let main ← mkFreshExprSyntheticOpaqueMVar type + if (← tryURefl main.mvarId!) then -- try to make a rfl lemma if possible + instantiateMVars main + else forallTelescope info.type fun xs _eq => do + let mut proof ← mkConstWithLevelParams unfoldEqnName + proof := mkAppN proof xs + for x in xs.reverse do + proof ← mkLambdaFVars #[x] proof + proof ← mkAppM ``funext #[proof] + return proof + let name := .str declName eqUnfoldThmSuffix + addDecl <| Declaration.thmDecl { + name, type, value + levelParams := info.levelParams + } + return some name + + +builtin_initialize + registerReservedNameAction fun name => do + let .str p s := name | return false + unless (← getEnv).isSafeDefinition p do return false + if s == eqUnfoldThmSuffix then + return (← MetaM.run' <| getConstUnfoldEqnFor? p).isSome + return false + +end Lean.Meta diff --git a/src/Lean/Elab/PreDefinition/Eqns.lean b/src/Lean/Elab/PreDefinition/Eqns.lean index 7b61e0ab8c..11ec07514d 100644 --- a/src/Lean/Elab/PreDefinition/Eqns.lean +++ b/src/Lean/Elab/PreDefinition/Eqns.lean @@ -43,15 +43,6 @@ def expandRHS? (mvarId : MVarId) : MetaM (Option MVarId) := do let (true, rhs') := expand false rhs | return none return some (← mvarId.replaceTargetDefEq (← mkEq lhs rhs')) -def funext? (mvarId : MVarId) : MetaM (Option MVarId) := do - let target ← mvarId.getType' - let some (_, _, rhs) := target.eq? | return none - unless rhs.isLambda do return none - commitWhenSome? do - let [mvarId] ← mvarId.apply (← mkConstWithFreshMVarLevels ``funext) | return none - let (_, mvarId) ← mvarId.intro1 - return some mvarId - def simpMatch? (mvarId : MVarId) : MetaM (Option MVarId) := do let mvarId' ← Split.simpMatchTarget mvarId if mvarId != mvarId' then return some mvarId' else return none @@ -244,11 +235,6 @@ where if let some mvarId ← expandRHS? mvarId then return (← go mvarId) - -- The following `funext?` was producing an overapplied `lhs`. Possible refinement: only do it - -- if we want to apply `splitMatch` on the body of the lambda - /- if let some mvarId ← funext? mvarId then - return (← go mvarId) -/ - if (← shouldUseSimpMatch (← mvarId.getType')) then if let some mvarId ← simpMatch? mvarId then return (← go mvarId) @@ -348,9 +334,6 @@ partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := do let rec go (mvarId : MVarId) : MetaM Unit := do if (← tryEqns mvarId) then return () - -- Remark: we removed funext? from `mkEqnTypes` - -- else if let some mvarId ← funext? mvarId then - -- go mvarId if (← shouldUseSimpMatch (← mvarId.getType')) then if let some mvarId ← simpMatch? mvarId then diff --git a/src/Lean/Meta/Eqns.lean b/src/Lean/Meta/Eqns.lean index b14f1ea47f..8707ca4edc 100644 --- a/src/Lean/Meta/Eqns.lean +++ b/src/Lean/Meta/Eqns.lean @@ -66,27 +66,25 @@ def isEqnReservedNameSuffix (s : String) : Bool := eqnThmSuffixBasePrefix.isPrefixOf s && (s.drop 3).isNat def unfoldThmSuffix := "eq_def" - -/-- Returns `true` if `s == "eq_def"` -/ -def isUnfoldReservedNameSuffix (s : String) : Bool := - s == unfoldThmSuffix +def eqUnfoldThmSuffix := "eq_unfold" /-- Throw an error if names for equation theorems for `declName` are not available. -/ def ensureEqnReservedNamesAvailable (declName : Name) : CoreM Unit := do + ensureReservedNameAvailable declName eqUnfoldThmSuffix ensureReservedNameAvailable declName unfoldThmSuffix ensureReservedNameAvailable declName eqn1ThmSuffix -- TODO: `declName` may need to reserve multiple `eq_` names, but we check only the first one. -- Possible improvement: try to efficiently compute the number of equation theorems at declaration time, and check all of them. /-- -Ensures that `f.eq_def` and `f.eq_` are reserved names if `f` is a safe definition. +Ensures that `f.eq_def`, `f.unfold` and `f.eq_` are reserved names if `f` is a safe definition. -/ builtin_initialize registerReservedNamePredicate fun env n => match n with | .str p s => - (isEqnReservedNameSuffix s || isUnfoldReservedNameSuffix s) + (isEqnReservedNameSuffix s || s == unfoldThmSuffix || s == eqUnfoldThmSuffix) && env.isSafeDefinition p -- Remark: `f.match_.eq_` are private definitions and are not treated as reserved names -- Reason: `f.match_.splitter is generated at the same time, and can eliminate into type. @@ -261,7 +259,7 @@ def registerGetUnfoldEqnFn (f : GetUnfoldEqnFn) : IO Unit := do getUnfoldEqnFnsRef.modify (f :: ·) /-- -Returns an "unfold" theorem for the given declaration. +Returns an "unfold" theorem (`f.eq_def`) for the given declaration. By default, we do not create unfold theorems for nonrecursive definitions. You can use `nonRec := true` to override this behavior. -/ @@ -286,7 +284,7 @@ builtin_initialize unless (← getEnv).isSafeDefinition p do return false if isEqnReservedNameSuffix s then return (← MetaM.run' <| getEqnsFor? p).isSome - if isUnfoldReservedNameSuffix s then + if s == unfoldThmSuffix then return (← MetaM.run' <| getUnfoldEqnFor? p (nonRec := true)).isSome return false diff --git a/tests/lean/run/unfoldLemma.lean b/tests/lean/run/unfoldLemma.lean new file mode 100644 index 0000000000..01cb1d88a2 --- /dev/null +++ b/tests/lean/run/unfoldLemma.lean @@ -0,0 +1,69 @@ +def Option_map (f : α → β) : Option α → Option β + | none => none + | some x => some (f x) + +/-- +info: equations: +theorem Option_map.eq_1.{u_1, u_2} : ∀ {α : Type u_1} {β : Type u_2} (f : α → β), Option_map f none = none +theorem Option_map.eq_2.{u_1, u_2} : ∀ {α : Type u_1} {β : Type u_2} (f : α → β) (x_1 : α), + Option_map f (some x_1) = some (f x_1) +-/ +#guard_msgs in +#print equations Option_map + +/-- +info: Option_map.eq_def.{u_1, u_2} {α : Type u_1} {β : Type u_2} (f : α → β) : + ∀ (x : Option α), + Option_map f x = + match x with + | none => none + | some x => some (f x) +-/ +#guard_msgs in +#check Option_map.eq_def + +/-- +info: Option_map.eq_unfold.{u_1, u_2} : + @Option_map = fun {α} {β} f x => + match x with + | none => none + | some x => some (f x) +-/ +#guard_msgs in +#check Option_map.eq_unfold + +def answer := 42 + +/-- info: answer.eq_unfold : answer = 42 -/ +#guard_msgs in +#check answer.eq_unfold + +-- structural recursion +def List_map (f : α → β) : List α → List β + | [] => [] + | x::xs => f x :: List_map f xs +/-- +info: List_map.eq_unfold.{u_1, u_2} : + @List_map = fun {α} {β} f x => + match x with + | [] => [] + | x :: xs => f x :: List_map f xs +-/ +#guard_msgs in +#check List_map.eq_unfold + +-- wf recursion +def List_map2 (f : α → β) : List α → List β + | [] => [] + | x::xs => f x :: List_map2 f xs +termination_by l => l + +/-- +info: List_map2.eq_unfold.{u_1, u_2} : + @List_map2 = fun {α} {β} f x => + match x with + | [] => [] + | x :: xs => f x :: List_map2 f xs +-/ +#guard_msgs in +#check List_map2.eq_unfold