From c87205bc9ba037248951caaa4f65772b4129fac4 Mon Sep 17 00:00:00 2001 From: JovanGerb <56355248+JovanGerb@users.noreply.github.com> Date: Wed, 19 Jun 2024 01:41:40 +0200 Subject: [PATCH] fix: eta reduce mvar assignments in `isDefEq` (#4387) I made a modification to the `mkLambdaFVars` function, adding a `etaReduce : Bool` parameter that determines whether a new lambda of the form `fun x => f x` should be replaced by `f`. I then set this option to true at `isDefEq` when processing metavariable assignments. This means that many unnecessary eta unreduced expression are now reduced. This is beneficial for users, so that they do not have to deal with such unreduced expressions. It is also beneficial for performance, leading to a 0.6% improvement in build instructions. Most notably, `Mathlib.Algebra.DirectLimit`, previously a top 50 slowest file, has sped up by 40%. Quite a number of proof in mathlib broke. Many of these involve removing a now unnecessary `simp only`. In other cases, a simp or rewrite doesn't work anymore, such as a `simp_rw [mul_comm]` that was used to rewrite `fun x => 2*x`, but now this term has turned into `HMul.hMul 2`. Closes #4386 --- src/Lean/Meta/Basic.lean | 4 +- src/Lean/Meta/ExprDefEq.lean | 4 +- src/Lean/Meta/SynthInstance.lean | 40 +++++++++---------- src/Lean/MetavarContext.lean | 40 ++++++++++++------- tests/lean/etaReducedMvarAssignments.lean | 8 ++++ ...taReducedMvarAssignments.lean.expected.out | 4 ++ 6 files changed, 60 insertions(+), 40 deletions(-) create mode 100644 tests/lean/etaReducedMvarAssignments.lean create mode 100644 tests/lean/etaReducedMvarAssignments.lean.expected.out diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index c11eabf36b..1aadbc64b7 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -912,8 +912,8 @@ def mkForallFVars (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedL /-- Takes an array `xs` of free variables and metavariables and a body term `e` and creates `fun ..xs => e`, suitably abstracting `e` and the types in `xs`. -/ -def mkLambdaFVars (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (binderInfoForMVars := BinderInfo.implicit) : MetaM Expr := - if xs.isEmpty then return e else liftMkBindingM <| MetavarContext.mkLambda xs e usedOnly usedLetOnly binderInfoForMVars +def mkLambdaFVars (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (etaReduce : Bool := false) (binderInfoForMVars := BinderInfo.implicit) : MetaM Expr := + if xs.isEmpty then return e else liftMkBindingM <| MetavarContext.mkLambda xs e usedOnly usedLetOnly etaReduce binderInfoForMVars def mkLetFVars (xs : Array Expr) (e : Expr) (usedLetOnly := true) (binderInfoForMVars := BinderInfo.implicit) : MetaM Expr := mkLambdaFVars xs e (usedLetOnly := usedLetOnly) (binderInfoForMVars := binderInfoForMVars) diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index 7e77eb078c..e2a40ca94f 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -419,10 +419,10 @@ where the index is the position in the local context. -/ private partial def mkLambdaFVarsWithLetDeps (xs : Array Expr) (v : Expr) : MetaM (Option Expr) := do if not (← hasLetDeclsInBetween) then - mkLambdaFVars xs v + mkLambdaFVars xs v (etaReduce := true) else let ys ← addLetDeps - mkLambdaFVars ys v + mkLambdaFVars ys v (etaReduce := true) where /-- Return true if there are let-declarions between `xs[0]` and `xs[xs.size-1]`. diff --git a/src/Lean/Meta/SynthInstance.lean b/src/Lean/Meta/SynthInstance.lean index 629578ce5f..9001694916 100644 --- a/src/Lean/Meta/SynthInstance.lean +++ b/src/Lean/Meta/SynthInstance.lean @@ -333,26 +333,6 @@ def getSubgoals (lctx : LocalContext) (localInsts : LocalInstances) (xs : Array subgoals := inst.synthOrder.map (mvars[·]!) |>.toList } -/-- -Similar to `mkLambdaFVars`, but ensures result is eta-reduced. -For example, suppose `e` is the local variable `inst x y`, and `xs` is `#[x, y]`, then -the result is `inst` instead of `fun x y => inst x y`. - -We added this auxiliary function because of aliases such as `DecidablePred`. For example, -consider the following definition. -``` -def filter (p : α → Prop) [inst : DecidablePred p] (xs : List α) : List α := - match xs with - | [] => [] - | x :: xs' => if p x then x :: filter p xs' else filter p xs' -``` -Without `mkLambdaFVars'`, the implicit instance at the `filter` applications would be `fun x => inst x` instead of `inst`. -Moreover, the equation lemmas associated with `filter` would have `fun x => inst x` on their right-hand-side. Then, -we would start getting terms such as `fun x => (fun x => inst x) x` when using the equational theorem. --/ -private def mkLambdaFVars' (xs : Array Expr) (e : Expr) : MetaM Expr := - return (← mkLambdaFVars xs e).eta - /-- Try to synthesize metavariable `mvar` using the instance `inst`. Remark: `mctx` is set using `withMCtx`. @@ -370,7 +350,23 @@ def tryResolve (mvar : Expr) (inst : Instance) : MetaM (Option (MetavarContext withTraceNode `Meta.synthInstance.tryResolve (withMCtx (← getMCtx) do return m!"{exceptOptionEmoji ·} {← instantiateMVars mvarTypeBody} ≟ {← instantiateMVars instTypeBody}") do if (← isDefEq mvarTypeBody instTypeBody) then - let instVal ← mkLambdaFVars' xs instVal + /- + We set `etaReduce := true`. + For example, suppose `e` is the local variable `inst x y`, and `xs` is `#[x, y]`, then + the result is `inst` instead of `fun x y => inst x y`. + + Consider the following definition. + ``` + def filter (p : α → Prop) [inst : DecidablePred p] (xs : List α) : List α := + match xs with + | [] => [] + | x :: xs' => if p x then x :: filter p xs' else filter p xs' + ``` + Without `etaReduce := true`, the implicit instance at the `filter` applications would be `fun x => inst x` instead of `inst`. + Moreover, the equation lemmas associated with `filter` would have `fun x => inst x` on their right-hand-side. Then, + we would start getting terms such as `fun x => (fun x => inst x) x` when using the equational theorem. + -/ + let instVal ← mkLambdaFVars xs instVal (etaReduce := true) if (← isDefEq mvar instVal) then return some ((← getMCtx), subgoals) return none @@ -483,7 +479,7 @@ private def removeUnusedArguments? (mctx : MetavarContext) (mvar : Expr) : MetaM let ys := ys.toArray let mvarType' ← mkForallFVars ys body withLocalDeclD `redf mvarType' fun f => do - let transformer ← mkLambdaFVars' #[f] (← mkLambdaFVars' xs (mkAppN f ys)) + let transformer ← mkLambdaFVars #[f] (← mkLambdaFVars xs (mkAppN f ys) (etaReduce := true)) (etaReduce := true) trace[Meta.synthInstance.unusedArgs] "{mvarType}\nhas unused arguments, reduced type{indentExpr mvarType'}\nTransformer{indentExpr transformer}" return some (mvarType', transformer) diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index 9895afd991..c8db6a4a42 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -1274,13 +1274,25 @@ partial def revert (xs : Array Expr) (mvarId : MVarId) : M (Expr × Array Expr) let e ← elimMVarDeps xs e pure (e.abstractRange i xs) +private def mkLambda' (x : Name) (bi : BinderInfo) (t : Expr) (b : Expr) (etaReduce : Bool) : Expr := + if etaReduce then + match b with + | .app f (.bvar 0) => + if !f.hasLooseBVar 0 then + f.lowerLooseBVars 1 1 + else + mkLambda x bi t b + | _ => mkLambda x bi t b + else + mkLambda x bi t b + /-- Similar to `LocalContext.mkBinding`, but handles metavariables correctly. If `usedOnly == true` then `forall` and `lambda` expressions are created only for used variables. If `usedLetOnly == true` then `let` expressions are created only for used (let-) variables. -/ -@[specialize] def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (e : Expr) (usedOnly : Bool) (usedLetOnly : Bool) : M (Expr × Nat) := do +@[specialize] def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (e : Expr) (usedOnly : Bool) (usedLetOnly : Bool) (etaReduce : Bool) : M Expr := do let e ← abstractRange xs xs.size e - xs.size.foldRevM (init := (e, 0)) fun i (e, num) => do + xs.size.foldRevM (init := e) fun i e => do let x := xs[i]! if x.isFVar then match lctx.getFVar! x with @@ -1289,27 +1301,27 @@ partial def revert (xs : Array Expr) (mvarId : MVarId) : M (Expr × Array Expr) let type := type.headBeta; let type ← abstractRange xs i type if isLambda then - return (Lean.mkLambda n bi type e, num + 1) + return mkLambda' n bi type e etaReduce else - return (Lean.mkForall n bi type e, num + 1) + return Lean.mkForall n bi type e else - return (e.lowerLooseBVars 1 1, num) + return e.lowerLooseBVars 1 1 | LocalDecl.ldecl _ _ n type value nonDep _ => if !usedLetOnly || e.hasLooseBVar 0 then let type ← abstractRange xs i type let value ← abstractRange xs i value - return (mkLet n type value e nonDep, num + 1) + return mkLet n type value e nonDep else - return (e.lowerLooseBVars 1 1, num) + return e.lowerLooseBVars 1 1 else let mvarDecl := (← get).mctx.getDecl x.mvarId! let type := mvarDecl.type.headBeta let type ← abstractRange xs i type let id ← if mvarDecl.userName.isAnonymous then mkFreshBinderName else pure mvarDecl.userName if isLambda then - return (Lean.mkLambda id (← read).binderInfoForMVars type e, num + 1) + return mkLambda' id (← read).binderInfoForMVars type e etaReduce else - return (Lean.mkForall id (← read).binderInfoForMVars type e, num + 1) + return Lean.mkForall id (← read).binderInfoForMVars type e end MkBinding @@ -1325,15 +1337,15 @@ def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool) : MkBinding def revert (xs : Array Expr) (mvarId : MVarId) (preserveOrder : Bool) : MkBindingM (Expr × Array Expr) := fun ctx => MkBinding.revert xs mvarId { preserveOrder, mainModule := ctx.mainModule } -def mkBinding (isLambda : Bool) (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM (Expr × Nat) := fun ctx => +def mkBinding (isLambda : Bool) (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (etaReduce := false) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM Expr := fun ctx => let mvarIdsToAbstract := xs.foldl (init := {}) fun s x => if x.isMVar then s.insert x.mvarId! else s - MkBinding.mkBinding isLambda ctx.lctx xs e usedOnly usedLetOnly { preserveOrder := false, binderInfoForMVars, mvarIdsToAbstract, mainModule := ctx.mainModule } + MkBinding.mkBinding isLambda ctx.lctx xs e usedOnly usedLetOnly etaReduce { preserveOrder := false, binderInfoForMVars, mvarIdsToAbstract, mainModule := ctx.mainModule } -@[inline] def mkLambda (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM Expr := - return (← mkBinding (isLambda := true) xs e usedOnly usedLetOnly binderInfoForMVars).1 +@[inline] def mkLambda (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (etaReduce := false) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM Expr := + return ← mkBinding (isLambda := true) xs e usedOnly usedLetOnly etaReduce binderInfoForMVars @[inline] def mkForall (xs : Array Expr) (e : Expr) (usedOnly : Bool := false) (usedLetOnly : Bool := true) (binderInfoForMVars := BinderInfo.implicit) : MkBindingM Expr := - return (← mkBinding (isLambda := false) xs e usedOnly usedLetOnly binderInfoForMVars).1 + return ← mkBinding (isLambda := false) xs e usedOnly usedLetOnly false binderInfoForMVars @[inline] def abstractRange (e : Expr) (n : Nat) (xs : Array Expr) : MkBindingM Expr := fun ctx => MkBinding.abstractRange xs n e { preserveOrder := false, mainModule := ctx.mainModule } diff --git a/tests/lean/etaReducedMvarAssignments.lean b/tests/lean/etaReducedMvarAssignments.lean new file mode 100644 index 0000000000..9594d194ae --- /dev/null +++ b/tests/lean/etaReducedMvarAssignments.lean @@ -0,0 +1,8 @@ +instance Pi.hasLe {ι : Type u} {α : ι → Type v} [∀ i, LE (α i)] : + LE (∀ i, α i) where le x y := ∀ i, x i ≤ y i + +variable {ι : Type u} {α : ι → Type v} [inst : (i : ι) → LE (α i)] + +set_option trace.Meta.isDefEq.assign true + +#check @Pi.hasLe ι _ inst diff --git a/tests/lean/etaReducedMvarAssignments.lean.expected.out b/tests/lean/etaReducedMvarAssignments.lean.expected.out new file mode 100644 index 0000000000..7c743ec28e --- /dev/null +++ b/tests/lean/etaReducedMvarAssignments.lean.expected.out @@ -0,0 +1,4 @@ +Pi.hasLe : LE ((i : ι) → α i) +[Meta.isDefEq.assign] ✅ ?m i := α i + [Meta.isDefEq.assign.beforeMkLambda] ?m [i] := α i + [Meta.isDefEq.assign.checkTypes] ✅ (?m : ι → Type ?u) := (α : ι → Type v)