fix: eta reduce mvar assignments in isDefEq (#4387)

I made a modification to the `mkLambdaFVars` function, adding a
`etaReduce : Bool` parameter that determines whether a new lambda of the
form `fun x => f x` should be replaced by `f`. I then set this option to
true at `isDefEq` when processing metavariable assignments.

This means that many unnecessary eta unreduced expression are now
reduced. This is beneficial for users, so that they do not have to deal
with such unreduced expressions. It is also beneficial for performance,
leading to a 0.6% improvement in build instructions. Most notably,
`Mathlib.Algebra.DirectLimit`, previously a top 50 slowest file, has
sped up by 40%.

Quite a number of proof in mathlib broke. Many of these involve removing
a now unnecessary `simp only`. In other cases, a simp or rewrite doesn't
work anymore, such as a `simp_rw [mul_comm]` that was used to rewrite
`fun x => 2*x`, but now this term has turned into `HMul.hMul 2`.

Closes #4386
This commit is contained in:
JovanGerb 2024-06-19 01:41:40 +02:00 committed by GitHub
parent 294b1d5839
commit c87205bc9b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 60 additions and 40 deletions

View file

@ -912,8 +912,8 @@ def mkForallFVars (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedL
/-- Takes an array `xs` of free variables and metavariables and a
body term `e` and creates `fun ..xs => e`, suitably
abstracting `e` and the types in `xs`. -/
def mkLambdaFVars (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (binderInfoForMVars := BinderInfo.implicit) : MetaM Expr :=
if xs.isEmpty then return e else liftMkBindingM <| MetavarContext.mkLambda xs e usedOnly usedLetOnly binderInfoForMVars
def mkLambdaFVars (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (etaReduce : Bool := false) (binderInfoForMVars := BinderInfo.implicit) : MetaM Expr :=
if xs.isEmpty then return e else liftMkBindingM <| MetavarContext.mkLambda xs e usedOnly usedLetOnly etaReduce binderInfoForMVars
def mkLetFVars (xs : Array Expr) (e : Expr) (usedLetOnly := true) (binderInfoForMVars := BinderInfo.implicit) : MetaM Expr :=
mkLambdaFVars xs e (usedLetOnly := usedLetOnly) (binderInfoForMVars := binderInfoForMVars)

View file

@ -419,10 +419,10 @@ where the index is the position in the local context.
-/
private partial def mkLambdaFVarsWithLetDeps (xs : Array Expr) (v : Expr) : MetaM (Option Expr) := do
if not (← hasLetDeclsInBetween) then
mkLambdaFVars xs v
mkLambdaFVars xs v (etaReduce := true)
else
let ys ← addLetDeps
mkLambdaFVars ys v
mkLambdaFVars ys v (etaReduce := true)
where
/-- Return true if there are let-declarions between `xs[0]` and `xs[xs.size-1]`.

View file

@ -333,26 +333,6 @@ def getSubgoals (lctx : LocalContext) (localInsts : LocalInstances) (xs : Array
subgoals := inst.synthOrder.map (mvars[·]!) |>.toList
}
/--
Similar to `mkLambdaFVars`, but ensures result is eta-reduced.
For example, suppose `e` is the local variable `inst x y`, and `xs` is `#[x, y]`, then
the result is `inst` instead of `fun x y => inst x y`.
We added this auxiliary function because of aliases such as `DecidablePred`. For example,
consider the following definition.
```
def filter (p : α → Prop) [inst : DecidablePred p] (xs : List α) : List α :=
match xs with
| [] => []
| x :: xs' => if p x then x :: filter p xs' else filter p xs'
```
Without `mkLambdaFVars'`, the implicit instance at the `filter` applications would be `fun x => inst x` instead of `inst`.
Moreover, the equation lemmas associated with `filter` would have `fun x => inst x` on their right-hand-side. Then,
we would start getting terms such as `fun x => (fun x => inst x) x` when using the equational theorem.
-/
private def mkLambdaFVars' (xs : Array Expr) (e : Expr) : MetaM Expr :=
return (← mkLambdaFVars xs e).eta
/--
Try to synthesize metavariable `mvar` using the instance `inst`.
Remark: `mctx` is set using `withMCtx`.
@ -370,7 +350,23 @@ def tryResolve (mvar : Expr) (inst : Instance) : MetaM (Option (MetavarContext
withTraceNode `Meta.synthInstance.tryResolve (withMCtx (← getMCtx) do
return m!"{exceptOptionEmoji ·} {← instantiateMVars mvarTypeBody} ≟ {← instantiateMVars instTypeBody}") do
if (← isDefEq mvarTypeBody instTypeBody) then
let instVal ← mkLambdaFVars' xs instVal
/-
We set `etaReduce := true`.
For example, suppose `e` is the local variable `inst x y`, and `xs` is `#[x, y]`, then
the result is `inst` instead of `fun x y => inst x y`.
Consider the following definition.
```
def filter (p : α → Prop) [inst : DecidablePred p] (xs : List α) : List α :=
match xs with
| [] => []
| x :: xs' => if p x then x :: filter p xs' else filter p xs'
```
Without `etaReduce := true`, the implicit instance at the `filter` applications would be `fun x => inst x` instead of `inst`.
Moreover, the equation lemmas associated with `filter` would have `fun x => inst x` on their right-hand-side. Then,
we would start getting terms such as `fun x => (fun x => inst x) x` when using the equational theorem.
-/
let instVal ← mkLambdaFVars xs instVal (etaReduce := true)
if (← isDefEq mvar instVal) then
return some ((← getMCtx), subgoals)
return none
@ -483,7 +479,7 @@ private def removeUnusedArguments? (mctx : MetavarContext) (mvar : Expr) : MetaM
let ys := ys.toArray
let mvarType' ← mkForallFVars ys body
withLocalDeclD `redf mvarType' fun f => do
let transformer ← mkLambdaFVars' #[f] (← mkLambdaFVars' xs (mkAppN f ys))
let transformer ← mkLambdaFVars #[f] (← mkLambdaFVars xs (mkAppN f ys) (etaReduce := true)) (etaReduce := true)
trace[Meta.synthInstance.unusedArgs] "{mvarType}\nhas unused arguments, reduced type{indentExpr mvarType'}\nTransformer{indentExpr transformer}"
return some (mvarType', transformer)

View file

@ -1274,13 +1274,25 @@ partial def revert (xs : Array Expr) (mvarId : MVarId) : M (Expr × Array Expr)
let e ← elimMVarDeps xs e
pure (e.abstractRange i xs)
private def mkLambda' (x : Name) (bi : BinderInfo) (t : Expr) (b : Expr) (etaReduce : Bool) : Expr :=
if etaReduce then
match b with
| .app f (.bvar 0) =>
if !f.hasLooseBVar 0 then
f.lowerLooseBVars 1 1
else
mkLambda x bi t b
| _ => mkLambda x bi t b
else
mkLambda x bi t b
/--
Similar to `LocalContext.mkBinding`, but handles metavariables correctly.
If `usedOnly == true` then `forall` and `lambda` expressions are created only for used variables.
If `usedLetOnly == true` then `let` expressions are created only for used (let-) variables. -/
@[specialize] def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (e : Expr) (usedOnly : Bool) (usedLetOnly : Bool) : M (Expr × Nat) := do
@[specialize] def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (e : Expr) (usedOnly : Bool) (usedLetOnly : Bool) (etaReduce : Bool) : M Expr := do
let e ← abstractRange xs xs.size e
xs.size.foldRevM (init := (e, 0)) fun i (e, num) => do
xs.size.foldRevM (init := e) fun i e => do
let x := xs[i]!
if x.isFVar then
match lctx.getFVar! x with
@ -1289,27 +1301,27 @@ partial def revert (xs : Array Expr) (mvarId : MVarId) : M (Expr × Array Expr)
let type := type.headBeta;
let type ← abstractRange xs i type
if isLambda then
return (Lean.mkLambda n bi type e, num + 1)
return mkLambda' n bi type e etaReduce
else
return (Lean.mkForall n bi type e, num + 1)
return Lean.mkForall n bi type e
else
return (e.lowerLooseBVars 1 1, num)
return e.lowerLooseBVars 1 1
| LocalDecl.ldecl _ _ n type value nonDep _ =>
if !usedLetOnly || e.hasLooseBVar 0 then
let type ← abstractRange xs i type
let value ← abstractRange xs i value
return (mkLet n type value e nonDep, num + 1)
return mkLet n type value e nonDep
else
return (e.lowerLooseBVars 1 1, num)
return e.lowerLooseBVars 1 1
else
let mvarDecl := (← get).mctx.getDecl x.mvarId!
let type := mvarDecl.type.headBeta
let type ← abstractRange xs i type
let id ← if mvarDecl.userName.isAnonymous then mkFreshBinderName else pure mvarDecl.userName
if isLambda then
return (Lean.mkLambda id (← read).binderInfoForMVars type e, num + 1)
return mkLambda' id (← read).binderInfoForMVars type e etaReduce
else
return (Lean.mkForall id (← read).binderInfoForMVars type e, num + 1)
return Lean.mkForall id (← read).binderInfoForMVars type e
end MkBinding
@ -1325,15 +1337,15 @@ def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool) : MkBinding
def revert (xs : Array Expr) (mvarId : MVarId) (preserveOrder : Bool) : MkBindingM (Expr × Array Expr) := fun ctx =>
MkBinding.revert xs mvarId { preserveOrder, mainModule := ctx.mainModule }
def mkBinding (isLambda : Bool) (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM (Expr × Nat) := fun ctx =>
def mkBinding (isLambda : Bool) (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (etaReduce := false) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM Expr := fun ctx =>
let mvarIdsToAbstract := xs.foldl (init := {}) fun s x => if x.isMVar then s.insert x.mvarId! else s
MkBinding.mkBinding isLambda ctx.lctx xs e usedOnly usedLetOnly { preserveOrder := false, binderInfoForMVars, mvarIdsToAbstract, mainModule := ctx.mainModule }
MkBinding.mkBinding isLambda ctx.lctx xs e usedOnly usedLetOnly etaReduce { preserveOrder := false, binderInfoForMVars, mvarIdsToAbstract, mainModule := ctx.mainModule }
@[inline] def mkLambda (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM Expr :=
return (← mkBinding (isLambda := true) xs e usedOnly usedLetOnly binderInfoForMVars).1
@[inline] def mkLambda (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (etaReduce := false) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM Expr :=
return ← mkBinding (isLambda := true) xs e usedOnly usedLetOnly etaReduce binderInfoForMVars
@[inline] def mkForall (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM Expr :=
return (← mkBinding (isLambda := false) xs e usedOnly usedLetOnly binderInfoForMVars).1
return ← mkBinding (isLambda := false) xs e usedOnly usedLetOnly false binderInfoForMVars
@[inline] def abstractRange (e : Expr) (n : Nat) (xs : Array Expr) : MkBindingM Expr := fun ctx =>
MkBinding.abstractRange xs n e { preserveOrder := false, mainModule := ctx.mainModule }

View file

@ -0,0 +1,8 @@
instance Pi.hasLe {ι : Type u} {α : ι → Type v} [∀ i, LE (α i)] :
LE (∀ i, α i) where le x y := ∀ i, x i ≤ y i
variable {ι : Type u} {α : ι → Type v} [inst : (i : ι) → LE (α i)]
set_option trace.Meta.isDefEq.assign true
#check @Pi.hasLe ι _ inst

View file

@ -0,0 +1,4 @@
Pi.hasLe : LE ((i : ι) → α i)
[Meta.isDefEq.assign] ✅ ?m i := α i
[Meta.isDefEq.assign.beforeMkLambda] ?m [i] := α i
[Meta.isDefEq.assign.checkTypes] ✅ (?m : ι → Type ?u) := (α : ι → Type v)