From f1c903ca654ec1fc413194620ff2df73ad871acd Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 4 Jan 2026 17:00:30 -0800 Subject: [PATCH] feat: simplify lambdas in `Sym.simp` (#11898) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds support for simplifying lambda expressions in `Sym.simp`. It is much more efficient than standard simp for very large lambda expressions with many binders. The key idea is to generate a custom function extensionality theorem for the type of the lambda being simplified. This technique is compatible with the standard `simp` tactic, and will be ported in a separate PR. image ### `lambda` benchmark summary | Lambda size | MetaM (ms) | SymM (ms) | Speedup | |-------------|------------|-----------|---------| | 50 | 22.7 | 0.74 | ~31× | | 100 | 120.5 | 1.75 | ~69× | | 150 | 359.6 | 2.90 | ~124× | | 200 | 809.5 | 4.51 | ~180× | --- src/Lean/Meta/Sym/Simp/Funext.lean | 54 +++++++++++++++++++++++++ src/Lean/Meta/Sym/Simp/Main.lean | 26 ++++++++++-- src/Lean/Meta/Sym/Simp/SimpM.lean | 2 + tests/bench/sym/meta_simp_2.lean | 60 ++++++++++++++++++++++++++++ tests/bench/sym/simp_2.lean | 64 ++++++++++++++++++++++++++++++ 5 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 src/Lean/Meta/Sym/Simp/Funext.lean create mode 100644 tests/bench/sym/meta_simp_2.lean create mode 100644 tests/bench/sym/simp_2.lean diff --git a/src/Lean/Meta/Sym/Simp/Funext.lean b/src/Lean/Meta/Sym/Simp/Funext.lean new file mode 100644 index 0000000000..2a625302bd --- /dev/null +++ b/src/Lean/Meta/Sym/Simp/Funext.lean @@ -0,0 +1,54 @@ +/- +Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Basic +import Lean.Meta.InferType +namespace Lean.Meta.Sym.Simp +/-- +Given `xs` containing free variables +`(x₁ : α₁) (x₂ : α₂[x₁]) ... (xₙ : αₙ[x₁, ..., x_{n-1}])` +and `β` a type of the form `β[x₁, ..., xₙ]`, +creates the custom function extensionality theorem +``` +∀ (f g : (x₁ : α₁) → (x₂ : α₂[x₁]) → ... → (xₙ : αₙ[x₁, ..., x_{n-1}]) → β[x₁, ..., xₙ]) + (h : ∀ x₁ ... xₙ, f x₁ ... xₙ = g x₁ ... xₙ), + f = g +``` +The theorem has three arguments `f`, `g`, and `h`. +This auxiliary theorem is used by the simplifier when visiting lambda expressions. +-/ +public def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do + let type ← mkForallFVars xs β + let v ← getLevel β + withLocalDeclD `f type fun f => + withLocalDeclD `g type fun g => do + let lhs := mkAppN f xs + let rhs := mkAppN g xs + let p := mkApp3 (mkConst ``Eq [v]) β lhs rhs + let p ← mkForallFVars xs p + withLocalDeclD `h p fun h => do + let mut result := mkAppN h xs |>.abstract xs + let mut i := xs.size + let mut β := β.abstract xs + let mut v := v + let mut f := mkAppN f xs |>.abstract xs + let mut g := mkAppN g xs |>.abstract xs + while i > 0 do + i := i - 1 + let x := xs[i]! + let α_i ← inferType x + let u_i ← getLevel α_i + let α_i := α_i.abstractRange i xs + f := f.appFn!.lowerLooseBVars 1 1 + g := g.appFn!.lowerLooseBVars 1 1 + result := mkLambda `x default α_i result + result := mkApp5 (mkConst ``funext [u_i, v]) α_i (mkLambda `x .default α_i β) f g result + β := mkForall `x .default α_i β + v := mkLevelIMax' u_i v + mkLambdaFVars #[f, g, h] result + +end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/Simp/Main.lean b/src/Lean/Meta/Sym/Simp/Main.lean index 11ed515129..70c1b75f10 100644 --- a/src/Lean/Meta/Sym/Simp/Main.lean +++ b/src/Lean/Meta/Sym/Simp/Main.lean @@ -7,15 +7,35 @@ module prelude public import Lean.Meta.Sym.Simp.SimpM import Lean.Meta.Tactic.Grind.AlphaShareBuilder +import Lean.Meta.Sym.InferType import Lean.Meta.Sym.Simp.Result import Lean.Meta.Sym.Simp.Simproc import Lean.Meta.Sym.Simp.Congr +import Lean.Meta.Sym.Simp.Funext namespace Lean.Meta.Sym.Simp open Grind -def simpLambda (_ : Expr) : SimpM Result := do - -- **TODO** - return .rfl +def simpLambda (e : Expr) : SimpM Result := do + -- **TODO**: Add free variable reuse + lambdaTelescope e fun xs b => do + match (← simp b) with + | .rfl => return .rfl + | .step b' h => + let h ← mkLambdaFVars xs h + -- **TODO**: Add `mkLambdaFVarsS`? + let e' ← shareCommonInc (← mkLambdaFVars xs b') + let funext ← getFunext xs b + return .step e' (mkApp3 funext e e' h) +where + getFunext (xs : Array Expr) (b : Expr) : SimpM Expr := do + let key ← inferType e + if let some h := (← get).funext.find? { expr := key } then + return h + else + let β ← inferType b + let h ← mkFunextFor xs β + modify fun s => { s with funext := s.funext.insert { expr := key } h } + return h def simpForall (_ : Expr) : SimpM Result := do -- **TODO** diff --git a/src/Lean/Meta/Sym/Simp/SimpM.lean b/src/Lean/Meta/Sym/Simp/SimpM.lean index cca7d355b8..b17fdcf963 100644 --- a/src/Lean/Meta/Sym/Simp/SimpM.lean +++ b/src/Lean/Meta/Sym/Simp/SimpM.lean @@ -138,6 +138,8 @@ structure State where binderStack : List (ExprPtr × FVarId) := [] /-- Number of steps performed so far. -/ numSteps := 0 + /-- Cache for generated funext theorems -/ + funext : PHashMap ExprPtr Expr := {} /-- Monad for the structural simplifier, layered on top of `SymM`. -/ abbrev SimpM := ReaderT MethodsRef $ ReaderT Context StateRefT State SymM diff --git a/tests/bench/sym/meta_simp_2.lean b/tests/bench/sym/meta_simp_2.lean new file mode 100644 index 0000000000..97809b0ad3 --- /dev/null +++ b/tests/bench/sym/meta_simp_2.lean @@ -0,0 +1,60 @@ +import Lean +open Lean Meta +opaque f : Nat → Nat + +namespace SimpBench + +/-! +## `MetaM` Simplifier benchmarks +-/ +def getProofSize (r : Simp.Result) : MetaM Nat := do + (← r.getProof).numObjs + +def mkSimpContext (config : Simp.Config := {}) : MetaM Simp.Context := do + let s : SimpTheorems := {} + let s ← s.addConst ``Nat.zero_add + let config := { config with implicitDefEqProofs := false } + Simp.mkContext config #[s] {} + +def simp (e : Expr) : MetaM (Simp.Result × Float) := Sym.SymM.run' do + -- let e ← Grind.shareCommon e + let startTime ← IO.monoNanosNow + let (r, _) ← Meta.simp e (← mkSimpContext) + let endTime ← IO.monoNanosNow + -- logInfo e + -- logInfo r.expr + -- check (← r.getProof) + let timeMs := (endTime - startTime).toFloat / 1000000.0 + return (r, timeMs) + +def mkLambdaBench (n : Nat) : MetaM Expr := do + let zero := mkNatLit 0 + let rec go (n : Nat) (xs : Array Expr) (e : Expr) : MetaM Expr := do + match n with + | 0 => mkLambdaFVars xs e + | n+1 => + withLocalDeclD `x (mkConst ``Nat) fun x => + go n (xs.push x) (mkNatAdd zero (mkNatAdd e x)) + go n #[] zero + +def benchLambda (n : Nat) : MetaM Unit := do + let e ← mkLambdaBench n + let (r, timeMs) ← simp e + let proofSize ← getProofSize r + IO.println s!"lambda_{n}: {timeMs}ms, proof_size={proofSize}" + +set_option maxRecDepth 100000 + +/-! ## Run all benchmarks -/ +def runAllBenchmarks : MetaM Unit := do + IO.println "=== Simplifier Stress Tests ===" + IO.println "" + + IO.println "" + IO.println "--- Benchmark 1: Transitivity chain ---" + for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200] do + benchLambda n + +#eval runAllBenchmarks + +end SimpBench diff --git a/tests/bench/sym/simp_2.lean b/tests/bench/sym/simp_2.lean new file mode 100644 index 0000000000..f8e75a767c --- /dev/null +++ b/tests/bench/sym/simp_2.lean @@ -0,0 +1,64 @@ +import Lean +open Lean Meta +opaque f : Nat → Nat + +namespace SimpBench +/-! +## `SymM` Simplifier benchmarks +-/ + +def getProofSize (r : Sym.Simp.Result) : MetaM Nat := + match r with + | .rfl => return 0 + | .step _ p => p.numObjs + +def mkSimpMethods : MetaM Sym.Simp.Methods := do + let thms : Sym.Simp.Theorems := {} + let thm ← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add + let thms := thms.insert thm + return { post := thms.rewrite } + +def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run' do + let e ← Grind.shareCommon e + let methods ← mkSimpMethods + let startTime ← IO.monoNanosNow + let r ← Sym.simp e methods { maxSteps := 100000000 } + let endTime ← IO.monoNanosNow + -- logInfo e + -- match r with + -- | .rfl => logInfo "rfl" + -- | .step e' h => logInfo e'; logInfo h; check h + let timeMs := (endTime - startTime).toFloat / 1000000.0 + return (r, timeMs) + +def mkLambdaBench (n : Nat) : MetaM Expr := do + let zero := mkNatLit 0 + let rec go (n : Nat) (xs : Array Expr) (e : Expr) : MetaM Expr := do + match n with + | 0 => mkLambdaFVars xs e + | n+1 => + withLocalDeclD `x (mkConst ``Nat) fun x => + go n (xs.push x) (mkNatAdd zero (mkNatAdd e x)) + go n #[] zero + +def benchLambda (n : Nat) : MetaM Unit := do + let e ← mkLambdaBench n + let (r, timeMs) ← simp e + let proofSize ← getProofSize r + IO.println s!"lambda_{n}: {timeMs}ms, proof_size={proofSize}" + +set_option maxRecDepth 100000 + +/-! ## Run all benchmarks -/ +def runAllBenchmarks : MetaM Unit := do + IO.println "=== Simplifier Stress Tests ===" + IO.println "" + + IO.println "" + IO.println "--- Benchmark 1: Lambda block ---" + for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200] do + benchLambda n + +#eval runAllBenchmarks + +end SimpBench