feat: generate f.eq_unfold lemmas (#5141)

With this, lean produces the following zoo of rewrite rules:
```
Option.map.eq_1      : Option.map f none = none
Option.map.eq_2      : Option.map f (some x) = some (f x)
Option.map.eq_def    : Option.map f p = match o with | none => none | (some x) => some (f x)
Option.map.eq_unfold : Option.map = fun f p => match o with | none => none | (some x) => some (f x)
```

The `f.eq_unfold` variant is especially useful to rewrite with `rw`
under
binders.

This implements and fixes #5110
This commit is contained in:
Joachim Breitner 2024-08-29 18:47:40 +02:00 committed by GitHub
parent aa3c87b2c7
commit a993934839
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 138 additions and 25 deletions

View file

@ -11,3 +11,4 @@ import Lean.Elab.PreDefinition.MkInhabitant
import Lean.Elab.PreDefinition.WF
import Lean.Elab.PreDefinition.Eqns
import Lean.Elab.PreDefinition.Nonrec.Eqns
import Lean.Elab.PreDefinition.EqUnfold

View file

@ -0,0 +1,62 @@
/-
Copyright (c) 2024 Lean FRO. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/
prelude
import Lean.Meta.Eqns
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Rfl
import Lean.Meta.Tactic.Intro
import Lean.Meta.Tactic.Apply
namespace Lean.Meta
/-- Try to close goal using `rfl` with smart unfolding turned off. -/
def tryURefl (mvarId : MVarId) : MetaM Bool :=
withOptions (smartUnfolding.set · false) do
try mvarId.refl; return true catch _ => return false
/--
Returns the "const unfold" theorem (`f.eq_unfold`) for the given declaration.
This is not extensible, and always builds on the unfold theorem (`f.eq_def`).
-/
def getConstUnfoldEqnFor? (declName : Name) : MetaM (Option Name) := do
let some unfoldEqnName ← getUnfoldEqnFor? (nonRec := true) declName | return none
let info ← getConstInfo unfoldEqnName
let type ← forallTelescope info.type fun xs eq => do
let some (_, lhs, rhs) := eq.eq? | throwError "Unexpected unfold theorem type {info.type}"
unless lhs.getAppFn.isConstOf declName do
throwError "Unexpected unfold theorem type {info.type}"
unless lhs.getAppArgs == xs do
throwError "Unexpected unfold theorem type {info.type}"
let type ← mkEq lhs.getAppFn (← mkLambdaFVars xs rhs)
return type
let value ← withNewMCtxDepth do
let main ← mkFreshExprSyntheticOpaqueMVar type
if (← tryURefl main.mvarId!) then -- try to make a rfl lemma if possible
instantiateMVars main
else forallTelescope info.type fun xs _eq => do
let mut proof ← mkConstWithLevelParams unfoldEqnName
proof := mkAppN proof xs
for x in xs.reverse do
proof ← mkLambdaFVars #[x] proof
proof ← mkAppM ``funext #[proof]
return proof
let name := .str declName eqUnfoldThmSuffix
addDecl <| Declaration.thmDecl {
name, type, value
levelParams := info.levelParams
}
return some name
builtin_initialize
registerReservedNameAction fun name => do
let .str p s := name | return false
unless (← getEnv).isSafeDefinition p do return false
if s == eqUnfoldThmSuffix then
return (← MetaM.run' <| getConstUnfoldEqnFor? p).isSome
return false
end Lean.Meta

View file

@ -43,15 +43,6 @@ def expandRHS? (mvarId : MVarId) : MetaM (Option MVarId) := do
let (true, rhs') := expand false rhs | return none
return some (← mvarId.replaceTargetDefEq (← mkEq lhs rhs'))
def funext? (mvarId : MVarId) : MetaM (Option MVarId) := do
let target ← mvarId.getType'
let some (_, _, rhs) := target.eq? | return none
unless rhs.isLambda do return none
commitWhenSome? do
let [mvarId] ← mvarId.apply (← mkConstWithFreshMVarLevels ``funext) | return none
let (_, mvarId) ← mvarId.intro1
return some mvarId
def simpMatch? (mvarId : MVarId) : MetaM (Option MVarId) := do
let mvarId' ← Split.simpMatchTarget mvarId
if mvarId != mvarId' then return some mvarId' else return none
@ -244,11 +235,6 @@ where
if let some mvarId ← expandRHS? mvarId then
return (← go mvarId)
-- The following `funext?` was producing an overapplied `lhs`. Possible refinement: only do it
-- if we want to apply `splitMatch` on the body of the lambda
/- if let some mvarId ← funext? mvarId then
return (← go mvarId) -/
if (← shouldUseSimpMatch (← mvarId.getType')) then
if let some mvarId ← simpMatch? mvarId then
return (← go mvarId)
@ -348,9 +334,6 @@ partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := do
let rec go (mvarId : MVarId) : MetaM Unit := do
if (← tryEqns mvarId) then
return ()
-- Remark: we removed funext? from `mkEqnTypes`
-- else if let some mvarId ← funext? mvarId then
-- go mvarId
if (← shouldUseSimpMatch (← mvarId.getType')) then
if let some mvarId ← simpMatch? mvarId then

View file

@ -66,27 +66,25 @@ def isEqnReservedNameSuffix (s : String) : Bool :=
eqnThmSuffixBasePrefix.isPrefixOf s && (s.drop 3).isNat
def unfoldThmSuffix := "eq_def"
/-- Returns `true` if `s == "eq_def"` -/
def isUnfoldReservedNameSuffix (s : String) : Bool :=
s == unfoldThmSuffix
def eqUnfoldThmSuffix := "eq_unfold"
/--
Throw an error if names for equation theorems for `declName` are not available.
-/
def ensureEqnReservedNamesAvailable (declName : Name) : CoreM Unit := do
ensureReservedNameAvailable declName eqUnfoldThmSuffix
ensureReservedNameAvailable declName unfoldThmSuffix
ensureReservedNameAvailable declName eqn1ThmSuffix
-- TODO: `declName` may need to reserve multiple `eq_<idx>` names, but we check only the first one.
-- Possible improvement: try to efficiently compute the number of equation theorems at declaration time, and check all of them.
/--
Ensures that `f.eq_def` and `f.eq_<idx>` are reserved names if `f` is a safe definition.
Ensures that `f.eq_def`, `f.unfold` and `f.eq_<idx>` are reserved names if `f` is a safe definition.
-/
builtin_initialize registerReservedNamePredicate fun env n =>
match n with
| .str p s =>
(isEqnReservedNameSuffix s || isUnfoldReservedNameSuffix s)
(isEqnReservedNameSuffix s || s == unfoldThmSuffix || s == eqUnfoldThmSuffix)
&& env.isSafeDefinition p
-- Remark: `f.match_<idx>.eq_<idx>` are private definitions and are not treated as reserved names
-- Reason: `f.match_<idx>.splitter is generated at the same time, and can eliminate into type.
@ -261,7 +259,7 @@ def registerGetUnfoldEqnFn (f : GetUnfoldEqnFn) : IO Unit := do
getUnfoldEqnFnsRef.modify (f :: ·)
/--
Returns an "unfold" theorem for the given declaration.
Returns an "unfold" theorem (`f.eq_def`) for the given declaration.
By default, we do not create unfold theorems for nonrecursive definitions.
You can use `nonRec := true` to override this behavior.
-/
@ -286,7 +284,7 @@ builtin_initialize
unless (← getEnv).isSafeDefinition p do return false
if isEqnReservedNameSuffix s then
return (← MetaM.run' <| getEqnsFor? p).isSome
if isUnfoldReservedNameSuffix s then
if s == unfoldThmSuffix then
return (← MetaM.run' <| getUnfoldEqnFor? p (nonRec := true)).isSome
return false

View file

@ -0,0 +1,69 @@
def Option_map (f : α → β) : Option α → Option β
| none => none
| some x => some (f x)
/--
info: equations:
theorem Option_map.eq_1.{u_1, u_2} : ∀ {α : Type u_1} {β : Type u_2} (f : α → β), Option_map f none = none
theorem Option_map.eq_2.{u_1, u_2} : ∀ {α : Type u_1} {β : Type u_2} (f : α → β) (x_1 : α),
Option_map f (some x_1) = some (f x_1)
-/
#guard_msgs in
#print equations Option_map
/--
info: Option_map.eq_def.{u_1, u_2} {α : Type u_1} {β : Type u_2} (f : α → β) :
∀ (x : Option α),
Option_map f x =
match x with
| none => none
| some x => some (f x)
-/
#guard_msgs in
#check Option_map.eq_def
/--
info: Option_map.eq_unfold.{u_1, u_2} :
@Option_map = fun {α} {β} f x =>
match x with
| none => none
| some x => some (f x)
-/
#guard_msgs in
#check Option_map.eq_unfold
def answer := 42
/-- info: answer.eq_unfold : answer = 42 -/
#guard_msgs in
#check answer.eq_unfold
-- structural recursion
def List_map (f : α → β) : List α → List β
| [] => []
| x::xs => f x :: List_map f xs
/--
info: List_map.eq_unfold.{u_1, u_2} :
@List_map = fun {α} {β} f x =>
match x with
| [] => []
| x :: xs => f x :: List_map f xs
-/
#guard_msgs in
#check List_map.eq_unfold
-- wf recursion
def List_map2 (f : α → β) : List α → List β
| [] => []
| x::xs => f x :: List_map2 f xs
termination_by l => l
/--
info: List_map2.eq_unfold.{u_1, u_2} :
@List_map2 = fun {α} {β} f x =>
match x with
| [] => []
| x :: xs => f x :: List_map2 f xs
-/
#guard_msgs in
#check List_map2.eq_unfold