feat: generate f.eq_unfold lemmas (#5141)
With this, lean produces the following zoo of rewrite rules: ``` Option.map.eq_1 : Option.map f none = none Option.map.eq_2 : Option.map f (some x) = some (f x) Option.map.eq_def : Option.map f p = match o with | none => none | (some x) => some (f x) Option.map.eq_unfold : Option.map = fun f p => match o with | none => none | (some x) => some (f x) ``` The `f.eq_unfold` variant is especially useful to rewrite with `rw` under binders. This implements and fixes #5110
This commit is contained in:
parent
aa3c87b2c7
commit
a993934839
5 changed files with 138 additions and 25 deletions
|
|
@ -11,3 +11,4 @@ import Lean.Elab.PreDefinition.MkInhabitant
|
|||
import Lean.Elab.PreDefinition.WF
|
||||
import Lean.Elab.PreDefinition.Eqns
|
||||
import Lean.Elab.PreDefinition.Nonrec.Eqns
|
||||
import Lean.Elab.PreDefinition.EqUnfold
|
||||
|
|
|
|||
62
src/Lean/Elab/PreDefinition/EqUnfold.lean
Normal file
62
src/Lean/Elab/PreDefinition/EqUnfold.lean
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
/-
|
||||
Copyright (c) 2024 Lean FRO. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Joachim Breitner
|
||||
-/
|
||||
prelude
|
||||
import Lean.Meta.Eqns
|
||||
import Lean.Meta.Tactic.Util
|
||||
import Lean.Meta.Tactic.Rfl
|
||||
import Lean.Meta.Tactic.Intro
|
||||
import Lean.Meta.Tactic.Apply
|
||||
|
||||
namespace Lean.Meta
|
||||
|
||||
/-- Try to close goal using `rfl` with smart unfolding turned off. -/
|
||||
def tryURefl (mvarId : MVarId) : MetaM Bool :=
|
||||
withOptions (smartUnfolding.set · false) do
|
||||
try mvarId.refl; return true catch _ => return false
|
||||
|
||||
/--
|
||||
Returns the "const unfold" theorem (`f.eq_unfold`) for the given declaration.
|
||||
This is not extensible, and always builds on the unfold theorem (`f.eq_def`).
|
||||
-/
|
||||
def getConstUnfoldEqnFor? (declName : Name) : MetaM (Option Name) := do
|
||||
let some unfoldEqnName ← getUnfoldEqnFor? (nonRec := true) declName | return none
|
||||
let info ← getConstInfo unfoldEqnName
|
||||
let type ← forallTelescope info.type fun xs eq => do
|
||||
let some (_, lhs, rhs) := eq.eq? | throwError "Unexpected unfold theorem type {info.type}"
|
||||
unless lhs.getAppFn.isConstOf declName do
|
||||
throwError "Unexpected unfold theorem type {info.type}"
|
||||
unless lhs.getAppArgs == xs do
|
||||
throwError "Unexpected unfold theorem type {info.type}"
|
||||
let type ← mkEq lhs.getAppFn (← mkLambdaFVars xs rhs)
|
||||
return type
|
||||
let value ← withNewMCtxDepth do
|
||||
let main ← mkFreshExprSyntheticOpaqueMVar type
|
||||
if (← tryURefl main.mvarId!) then -- try to make a rfl lemma if possible
|
||||
instantiateMVars main
|
||||
else forallTelescope info.type fun xs _eq => do
|
||||
let mut proof ← mkConstWithLevelParams unfoldEqnName
|
||||
proof := mkAppN proof xs
|
||||
for x in xs.reverse do
|
||||
proof ← mkLambdaFVars #[x] proof
|
||||
proof ← mkAppM ``funext #[proof]
|
||||
return proof
|
||||
let name := .str declName eqUnfoldThmSuffix
|
||||
addDecl <| Declaration.thmDecl {
|
||||
name, type, value
|
||||
levelParams := info.levelParams
|
||||
}
|
||||
return some name
|
||||
|
||||
|
||||
builtin_initialize
|
||||
registerReservedNameAction fun name => do
|
||||
let .str p s := name | return false
|
||||
unless (← getEnv).isSafeDefinition p do return false
|
||||
if s == eqUnfoldThmSuffix then
|
||||
return (← MetaM.run' <| getConstUnfoldEqnFor? p).isSome
|
||||
return false
|
||||
|
||||
end Lean.Meta
|
||||
|
|
@ -43,15 +43,6 @@ def expandRHS? (mvarId : MVarId) : MetaM (Option MVarId) := do
|
|||
let (true, rhs') := expand false rhs | return none
|
||||
return some (← mvarId.replaceTargetDefEq (← mkEq lhs rhs'))
|
||||
|
||||
def funext? (mvarId : MVarId) : MetaM (Option MVarId) := do
|
||||
let target ← mvarId.getType'
|
||||
let some (_, _, rhs) := target.eq? | return none
|
||||
unless rhs.isLambda do return none
|
||||
commitWhenSome? do
|
||||
let [mvarId] ← mvarId.apply (← mkConstWithFreshMVarLevels ``funext) | return none
|
||||
let (_, mvarId) ← mvarId.intro1
|
||||
return some mvarId
|
||||
|
||||
def simpMatch? (mvarId : MVarId) : MetaM (Option MVarId) := do
|
||||
let mvarId' ← Split.simpMatchTarget mvarId
|
||||
if mvarId != mvarId' then return some mvarId' else return none
|
||||
|
|
@ -244,11 +235,6 @@ where
|
|||
if let some mvarId ← expandRHS? mvarId then
|
||||
return (← go mvarId)
|
||||
|
||||
-- The following `funext?` was producing an overapplied `lhs`. Possible refinement: only do it
|
||||
-- if we want to apply `splitMatch` on the body of the lambda
|
||||
/- if let some mvarId ← funext? mvarId then
|
||||
return (← go mvarId) -/
|
||||
|
||||
if (← shouldUseSimpMatch (← mvarId.getType')) then
|
||||
if let some mvarId ← simpMatch? mvarId then
|
||||
return (← go mvarId)
|
||||
|
|
@ -348,9 +334,6 @@ partial def mkUnfoldProof (declName : Name) (mvarId : MVarId) : MetaM Unit := do
|
|||
let rec go (mvarId : MVarId) : MetaM Unit := do
|
||||
if (← tryEqns mvarId) then
|
||||
return ()
|
||||
-- Remark: we removed funext? from `mkEqnTypes`
|
||||
-- else if let some mvarId ← funext? mvarId then
|
||||
-- go mvarId
|
||||
|
||||
if (← shouldUseSimpMatch (← mvarId.getType')) then
|
||||
if let some mvarId ← simpMatch? mvarId then
|
||||
|
|
|
|||
|
|
@ -66,27 +66,25 @@ def isEqnReservedNameSuffix (s : String) : Bool :=
|
|||
eqnThmSuffixBasePrefix.isPrefixOf s && (s.drop 3).isNat
|
||||
|
||||
def unfoldThmSuffix := "eq_def"
|
||||
|
||||
/-- Returns `true` if `s == "eq_def"` -/
|
||||
def isUnfoldReservedNameSuffix (s : String) : Bool :=
|
||||
s == unfoldThmSuffix
|
||||
def eqUnfoldThmSuffix := "eq_unfold"
|
||||
|
||||
/--
|
||||
Throw an error if names for equation theorems for `declName` are not available.
|
||||
-/
|
||||
def ensureEqnReservedNamesAvailable (declName : Name) : CoreM Unit := do
|
||||
ensureReservedNameAvailable declName eqUnfoldThmSuffix
|
||||
ensureReservedNameAvailable declName unfoldThmSuffix
|
||||
ensureReservedNameAvailable declName eqn1ThmSuffix
|
||||
-- TODO: `declName` may need to reserve multiple `eq_<idx>` names, but we check only the first one.
|
||||
-- Possible improvement: try to efficiently compute the number of equation theorems at declaration time, and check all of them.
|
||||
|
||||
/--
|
||||
Ensures that `f.eq_def` and `f.eq_<idx>` are reserved names if `f` is a safe definition.
|
||||
Ensures that `f.eq_def`, `f.unfold` and `f.eq_<idx>` are reserved names if `f` is a safe definition.
|
||||
-/
|
||||
builtin_initialize registerReservedNamePredicate fun env n =>
|
||||
match n with
|
||||
| .str p s =>
|
||||
(isEqnReservedNameSuffix s || isUnfoldReservedNameSuffix s)
|
||||
(isEqnReservedNameSuffix s || s == unfoldThmSuffix || s == eqUnfoldThmSuffix)
|
||||
&& env.isSafeDefinition p
|
||||
-- Remark: `f.match_<idx>.eq_<idx>` are private definitions and are not treated as reserved names
|
||||
-- Reason: `f.match_<idx>.splitter is generated at the same time, and can eliminate into type.
|
||||
|
|
@ -261,7 +259,7 @@ def registerGetUnfoldEqnFn (f : GetUnfoldEqnFn) : IO Unit := do
|
|||
getUnfoldEqnFnsRef.modify (f :: ·)
|
||||
|
||||
/--
|
||||
Returns an "unfold" theorem for the given declaration.
|
||||
Returns an "unfold" theorem (`f.eq_def`) for the given declaration.
|
||||
By default, we do not create unfold theorems for nonrecursive definitions.
|
||||
You can use `nonRec := true` to override this behavior.
|
||||
-/
|
||||
|
|
@ -286,7 +284,7 @@ builtin_initialize
|
|||
unless (← getEnv).isSafeDefinition p do return false
|
||||
if isEqnReservedNameSuffix s then
|
||||
return (← MetaM.run' <| getEqnsFor? p).isSome
|
||||
if isUnfoldReservedNameSuffix s then
|
||||
if s == unfoldThmSuffix then
|
||||
return (← MetaM.run' <| getUnfoldEqnFor? p (nonRec := true)).isSome
|
||||
return false
|
||||
|
||||
|
|
|
|||
69
tests/lean/run/unfoldLemma.lean
Normal file
69
tests/lean/run/unfoldLemma.lean
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
def Option_map (f : α → β) : Option α → Option β
|
||||
| none => none
|
||||
| some x => some (f x)
|
||||
|
||||
/--
|
||||
info: equations:
|
||||
theorem Option_map.eq_1.{u_1, u_2} : ∀ {α : Type u_1} {β : Type u_2} (f : α → β), Option_map f none = none
|
||||
theorem Option_map.eq_2.{u_1, u_2} : ∀ {α : Type u_1} {β : Type u_2} (f : α → β) (x_1 : α),
|
||||
Option_map f (some x_1) = some (f x_1)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#print equations Option_map
|
||||
|
||||
/--
|
||||
info: Option_map.eq_def.{u_1, u_2} {α : Type u_1} {β : Type u_2} (f : α → β) :
|
||||
∀ (x : Option α),
|
||||
Option_map f x =
|
||||
match x with
|
||||
| none => none
|
||||
| some x => some (f x)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check Option_map.eq_def
|
||||
|
||||
/--
|
||||
info: Option_map.eq_unfold.{u_1, u_2} :
|
||||
@Option_map = fun {α} {β} f x =>
|
||||
match x with
|
||||
| none => none
|
||||
| some x => some (f x)
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check Option_map.eq_unfold
|
||||
|
||||
def answer := 42
|
||||
|
||||
/-- info: answer.eq_unfold : answer = 42 -/
|
||||
#guard_msgs in
|
||||
#check answer.eq_unfold
|
||||
|
||||
-- structural recursion
|
||||
def List_map (f : α → β) : List α → List β
|
||||
| [] => []
|
||||
| x::xs => f x :: List_map f xs
|
||||
/--
|
||||
info: List_map.eq_unfold.{u_1, u_2} :
|
||||
@List_map = fun {α} {β} f x =>
|
||||
match x with
|
||||
| [] => []
|
||||
| x :: xs => f x :: List_map f xs
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check List_map.eq_unfold
|
||||
|
||||
-- wf recursion
|
||||
def List_map2 (f : α → β) : List α → List β
|
||||
| [] => []
|
||||
| x::xs => f x :: List_map2 f xs
|
||||
termination_by l => l
|
||||
|
||||
/--
|
||||
info: List_map2.eq_unfold.{u_1, u_2} :
|
||||
@List_map2 = fun {α} {β} f x =>
|
||||
match x with
|
||||
| [] => []
|
||||
| x :: xs => f x :: List_map2 f xs
|
||||
-/
|
||||
#guard_msgs in
|
||||
#check List_map2.eq_unfold
|
||||
Loading…
Add table
Reference in a new issue