From 62d2688579b8cb543ccae1952d8a44feff6839ec Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 26 Jan 2026 13:30:29 -0800 Subject: [PATCH] feat: eta-reduction support in `SymM` (#12168) This PR adds support for eta-reduction in `SymM`. --- src/Lean/Meta/Sym.lean | 1 + src/Lean/Meta/Sym/Eta.lean | 53 ++++++++++++++++++++++++ src/Lean/Meta/Sym/Pattern.lean | 14 ++++++- src/Lean/Meta/Sym/ProofInstInfo.lean | 4 +- tests/lean/run/sym_pattern_3.lean | 60 ++++++++++++++++++++++++++++ 5 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 src/Lean/Meta/Sym/Eta.lean create mode 100644 tests/lean/run/sym_pattern_3.lean diff --git a/src/Lean/Meta/Sym.lean b/src/Lean/Meta/Sym.lean index bd49ceadd3..daa49fd597 100644 --- a/src/Lean/Meta/Sym.lean +++ b/src/Lean/Meta/Sym.lean @@ -23,6 +23,7 @@ public import Lean.Meta.Sym.Apply public import Lean.Meta.Sym.InferType public import Lean.Meta.Sym.Simp public import Lean.Meta.Sym.Util +public import Lean.Meta.Sym.Eta public import Lean.Meta.Sym.Grind /-! diff --git a/src/Lean/Meta/Sym/Eta.lean b/src/Lean/Meta/Sym/Eta.lean new file mode 100644 index 0000000000..b887ea212f --- /dev/null +++ b/src/Lean/Meta/Sym/Eta.lean @@ -0,0 +1,53 @@ +/- +Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.ExprPtr +public import Lean.Meta.Basic +import Lean.Meta.Transform +namespace Lean.Meta.Sym +/-- +Checks if `body` is eta-expanded with `n` applications: `f (.bvar (n-1)) ... (.bvar 0)`. +Returns `f` if so and `f` has no loose bvars; otherwise returns `default`. +- `n`: number of remaining applications to check +- `i`: expected bvar index (starts at 0, increments with each application) +- `default`: returned when not eta-reducible (enables pointer equality check) +-/ +def etaReduceAux (body : Expr) (n : Nat) (i : Nat) (default : Expr) : Expr := Id.run do + match n with + | 0 => if body.hasLooseBVars then default else body + | n+1 => + let .app f (.bvar j) := body | default + if j == i then etaReduceAux f n (i+1) default else default + +/-- +If `e` is of the form `(fun x₁ ... xₙ => f x₁ ... xₙ)` and `f` does not contain `x₁`, ..., `xₙ`, +then returns `f`. Otherwise, returns `e`. + +Returns the original expression when not reducible to enable pointer equality checks. +-/ +public def etaReduce (e : Expr) : Expr := + go e 0 +where + go (body : Expr) (n : Nat) : Expr := + match body with + | .lam _ _ b _ => go b (n+1) + | _ => if n == 0 then e else etaReduceAux body n 0 e + +/-- Returns `true` if `e` can be eta-reduced. Uses pointer equality for efficiency. -/ +public def isEtaReducible (e : Expr) : Bool := + !isSameExpr e (etaReduce e) + +/-- Applies `etaReduce` to all subexpressions. Returns `e` unchanged if no subexpression is eta-reducible. -/ +public def etaReduceAll (e : Expr) : MetaM Expr := do + unless Option.isSome <| e.find? isEtaReducible do return e + let pre (e : Expr) : MetaM TransformStep := do + let e' := etaReduce e + if isSameExpr e e' then return .continue + else return .visit e' + Meta.transform e (pre := pre) + +end Lean.Meta.Sym diff --git a/src/Lean/Meta/Sym/Pattern.lean b/src/Lean/Meta/Sym/Pattern.lean index b658dc85f2..90502e0f92 100644 --- a/src/Lean/Meta/Sym/Pattern.lean +++ b/src/Lean/Meta/Sym/Pattern.lean @@ -18,6 +18,7 @@ import Lean.Meta.Sym.ProofInstInfo import Lean.Meta.Sym.AlphaShareBuilder import Lean.Meta.Sym.LitValues import Lean.Meta.Sym.Offset +import Lean.Meta.Sym.Eta namespace Lean.Meta.Sym open Internal @@ -323,7 +324,11 @@ def isAssignedMVar (e : Expr) : MetaM Bool := | _ => return false partial def process (p : Expr) (e : Expr) : UnifyM Bool := do - match p with + let e' := etaReduce e + if !isSameExpr e e' then + -- **Note**: We eagerly eta reduce patterns + process p e' + else match p with | .bvar bidx => assignExpr bidx e | .mdata _ p => process p e | .const declName us => @@ -723,7 +728,12 @@ def isDefEqApp (tFn : Expr) (t : Expr) (s : Expr) (_ : tFn = t.getAppFn) : DefEq @[export lean_sym_def_eq] def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do if isSameExpr t s then return true - match t, s with + -- **Note**: `etaReduce` is supposed to be fast, and does not allocate memory + let t' := etaReduce t + let s' := etaReduce s + if !isSameExpr t t' || !isSameExpr s s' then + isDefEqMain t' s' + else match t, s with | .lit l₁, .lit l₂ => return l₁ == l₂ | .sort u, .sort v => isLevelDefEqS u v | .lam .., .lam .. => isDefEqBindingS t s diff --git a/src/Lean/Meta/Sym/ProofInstInfo.lean b/src/Lean/Meta/Sym/ProofInstInfo.lean index ee5c59e05b..d203d5b410 100644 --- a/src/Lean/Meta/Sym/ProofInstInfo.lean +++ b/src/Lean/Meta/Sym/ProofInstInfo.lean @@ -9,6 +9,7 @@ public import Lean.Meta.Sym.SymM import Lean.Meta.Sym.IsClass import Lean.Meta.Sym.Util import Lean.Meta.Transform +import Lean.Meta.Sym.Eta namespace Lean.Meta.Sym /-- @@ -17,7 +18,8 @@ Preprocesses types that used for pattern matching and unification. public def preprocessType (type : Expr) : MetaM Expr := do let type ← Sym.unfoldReducible type let type ← Core.betaReduce type - zetaReduce type + let type ← zetaReduce type + etaReduceAll type /-- Analyzes whether the given free variables (aka arguments) are proofs or instances. diff --git a/tests/lean/run/sym_pattern_3.lean b/tests/lean/run/sym_pattern_3.lean new file mode 100644 index 0000000000..504c223f71 --- /dev/null +++ b/tests/lean/run/sym_pattern_3.lean @@ -0,0 +1,60 @@ +import Std.Data.HashMap +import Lean.Meta.Sym +import Lean.Meta.DiscrTree.Basic +open Lean Meta Sym Grind +set_option sym.debug true + +abbrev S := Nat +abbrev M α := StateM S α + +def Exec (s : S) (k : M α) (post : α → S → Prop) : Prop := + post (k s).1 (k s).2 + +theorem Exec.bind (k₁ : M α) (k₂ : α → M β) (post : β → S → Prop) : + Exec s k₁ (fun a s₁ => Exec s₁ (k₂ a) post) + → Exec s (k₁ >>= k₂) post := by + simp [Exec, Bind.bind, StateT.bind] + cases k₁ s; simp + +def goal := ∀ a b, Exec b (set a >>= fun _ => get) fun v _ => v = a +set_option pp.explicit true + +/-! +Recall that `SymM` patterns are eagerly eta-reduced. +Goals are not, but the pattern matcher/unifier performs eta whenever it is needed. +-/ + +/-- +info: Pattern: +@Exec #5 #4 (@bind (StateT Nat Id) (@Monad.toBind (StateT Nat Id) (@StateT.instMonad Nat Id Id.instMonad)) #6 #5 #3 #2) + #1 +--- +info: a b : Nat +⊢ @Exec Nat b + (@bind (fun α => StateT Nat Id α) (@Monad.toBind (fun α => StateT Nat Id α) (@StateT.instMonad Nat Id Id.instMonad)) + PUnit Nat (@set Nat (fun α => StateT Nat Id α) (@instMonadStateOfStateTOfMonad Nat Id Id.instMonad) a) fun x => + @get Nat (fun α => StateT Nat Id α) + (@instMonadStateOfMonadStateOf Nat (fun α => StateT Nat Id α) + (@instMonadStateOfStateTOfMonad Nat Id Id.instMonad))) + fun v x => @Eq Nat v a +--- +info: a b : Nat +⊢ @Exec PUnit b (@set Nat (fun α => StateT Nat Id α) (@instMonadStateOfStateTOfMonad Nat Id Id.instMonad) a) + fun a_1 s₁ => + @Exec Nat s₁ + (@get Nat (fun α => StateT Nat Id α) + (@instMonadStateOfMonadStateOf Nat (fun α => StateT Nat Id α) + (@instMonadStateOfStateTOfMonad Nat Id Id.instMonad))) + fun v x => @Eq Nat v a +-/ +#guard_msgs in +run_meta SymM.run do + let bindRule ← mkBackwardRuleFromDecl ``Exec.bind + let a ← unfoldDefinition (mkConst ``goal) + logInfo m!"Pattern:\n{bindRule.pattern.pattern}" + forallTelescope a fun _ body => do + let mvar ← mkFreshExprMVar body + let mvarId ← preprocessMVar mvar.mvarId! + logInfo mvarId + let .goals [mvarId] ← bindRule.apply mvarId | failure + logInfo mvarId