feat: simplify lambdas in Sym.simp (#11898)
This PR adds support for simplifying lambda expressions in `Sym.simp`. It is much more efficient than standard simp for very large lambda expressions with many binders. The key idea is to generate a custom function extensionality theorem for the type of the lambda being simplified. This technique is compatible with the standard `simp` tactic, and will be ported in a separate PR. <img width="581" height="455" alt="image" src="https://github.com/user-attachments/assets/5911dc6c-03f0-48ed-843b-b8cb4f67ee61" /> ### `lambda` benchmark summary | Lambda size | MetaM (ms) | SymM (ms) | Speedup | |-------------|------------|-----------|---------| | 50 | 22.7 | 0.74 | ~31× | | 100 | 120.5 | 1.75 | ~69× | | 150 | 359.6 | 2.90 | ~124× | | 200 | 809.5 | 4.51 | ~180× |
This commit is contained in:
parent
35d8925c50
commit
f1c903ca65
5 changed files with 203 additions and 3 deletions
54
src/Lean/Meta/Sym/Simp/Funext.lean
Normal file
54
src/Lean/Meta/Sym/Simp/Funext.lean
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Basic
|
||||
import Lean.Meta.InferType
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
/--
|
||||
Given `xs` containing free variables
|
||||
`(x₁ : α₁) (x₂ : α₂[x₁]) ... (xₙ : αₙ[x₁, ..., x_{n-1}])`
|
||||
and `β` a type of the form `β[x₁, ..., xₙ]`,
|
||||
creates the custom function extensionality theorem
|
||||
```
|
||||
∀ (f g : (x₁ : α₁) → (x₂ : α₂[x₁]) → ... → (xₙ : αₙ[x₁, ..., x_{n-1}]) → β[x₁, ..., xₙ])
|
||||
(h : ∀ x₁ ... xₙ, f x₁ ... xₙ = g x₁ ... xₙ),
|
||||
f = g
|
||||
```
|
||||
The theorem has three arguments `f`, `g`, and `h`.
|
||||
This auxiliary theorem is used by the simplifier when visiting lambda expressions.
|
||||
-/
|
||||
public def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do
|
||||
let type ← mkForallFVars xs β
|
||||
let v ← getLevel β
|
||||
withLocalDeclD `f type fun f =>
|
||||
withLocalDeclD `g type fun g => do
|
||||
let lhs := mkAppN f xs
|
||||
let rhs := mkAppN g xs
|
||||
let p := mkApp3 (mkConst ``Eq [v]) β lhs rhs
|
||||
let p ← mkForallFVars xs p
|
||||
withLocalDeclD `h p fun h => do
|
||||
let mut result := mkAppN h xs |>.abstract xs
|
||||
let mut i := xs.size
|
||||
let mut β := β.abstract xs
|
||||
let mut v := v
|
||||
let mut f := mkAppN f xs |>.abstract xs
|
||||
let mut g := mkAppN g xs |>.abstract xs
|
||||
while i > 0 do
|
||||
i := i - 1
|
||||
let x := xs[i]!
|
||||
let α_i ← inferType x
|
||||
let u_i ← getLevel α_i
|
||||
let α_i := α_i.abstractRange i xs
|
||||
f := f.appFn!.lowerLooseBVars 1 1
|
||||
g := g.appFn!.lowerLooseBVars 1 1
|
||||
result := mkLambda `x default α_i result
|
||||
result := mkApp5 (mkConst ``funext [u_i, v]) α_i (mkLambda `x .default α_i β) f g result
|
||||
β := mkForall `x .default α_i β
|
||||
v := mkLevelIMax' u_i v
|
||||
mkLambdaFVars #[f, g, h] result
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
@ -7,15 +7,35 @@ module
|
|||
prelude
|
||||
public import Lean.Meta.Sym.Simp.SimpM
|
||||
import Lean.Meta.Tactic.Grind.AlphaShareBuilder
|
||||
import Lean.Meta.Sym.InferType
|
||||
import Lean.Meta.Sym.Simp.Result
|
||||
import Lean.Meta.Sym.Simp.Simproc
|
||||
import Lean.Meta.Sym.Simp.Congr
|
||||
import Lean.Meta.Sym.Simp.Funext
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
open Grind
|
||||
|
||||
def simpLambda (_ : Expr) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return .rfl
|
||||
def simpLambda (e : Expr) : SimpM Result := do
|
||||
-- **TODO**: Add free variable reuse
|
||||
lambdaTelescope e fun xs b => do
|
||||
match (← simp b) with
|
||||
| .rfl => return .rfl
|
||||
| .step b' h =>
|
||||
let h ← mkLambdaFVars xs h
|
||||
-- **TODO**: Add `mkLambdaFVarsS`?
|
||||
let e' ← shareCommonInc (← mkLambdaFVars xs b')
|
||||
let funext ← getFunext xs b
|
||||
return .step e' (mkApp3 funext e e' h)
|
||||
where
|
||||
getFunext (xs : Array Expr) (b : Expr) : SimpM Expr := do
|
||||
let key ← inferType e
|
||||
if let some h := (← get).funext.find? { expr := key } then
|
||||
return h
|
||||
else
|
||||
let β ← inferType b
|
||||
let h ← mkFunextFor xs β
|
||||
modify fun s => { s with funext := s.funext.insert { expr := key } h }
|
||||
return h
|
||||
|
||||
def simpForall (_ : Expr) : SimpM Result := do
|
||||
-- **TODO**
|
||||
|
|
|
|||
|
|
@ -138,6 +138,8 @@ structure State where
|
|||
binderStack : List (ExprPtr × FVarId) := []
|
||||
/-- Number of steps performed so far. -/
|
||||
numSteps := 0
|
||||
/-- Cache for generated funext theorems -/
|
||||
funext : PHashMap ExprPtr Expr := {}
|
||||
|
||||
/-- Monad for the structural simplifier, layered on top of `SymM`. -/
|
||||
abbrev SimpM := ReaderT MethodsRef $ ReaderT Context StateRefT State SymM
|
||||
|
|
|
|||
60
tests/bench/sym/meta_simp_2.lean
Normal file
60
tests/bench/sym/meta_simp_2.lean
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import Lean
|
||||
open Lean Meta
|
||||
opaque f : Nat → Nat
|
||||
|
||||
namespace SimpBench
|
||||
|
||||
/-!
|
||||
## `MetaM` Simplifier benchmarks
|
||||
-/
|
||||
def getProofSize (r : Simp.Result) : MetaM Nat := do
|
||||
(← r.getProof).numObjs
|
||||
|
||||
def mkSimpContext (config : Simp.Config := {}) : MetaM Simp.Context := do
|
||||
let s : SimpTheorems := {}
|
||||
let s ← s.addConst ``Nat.zero_add
|
||||
let config := { config with implicitDefEqProofs := false }
|
||||
Simp.mkContext config #[s] {}
|
||||
|
||||
def simp (e : Expr) : MetaM (Simp.Result × Float) := Sym.SymM.run' do
|
||||
-- let e ← Grind.shareCommon e
|
||||
let startTime ← IO.monoNanosNow
|
||||
let (r, _) ← Meta.simp e (← mkSimpContext)
|
||||
let endTime ← IO.monoNanosNow
|
||||
-- logInfo e
|
||||
-- logInfo r.expr
|
||||
-- check (← r.getProof)
|
||||
let timeMs := (endTime - startTime).toFloat / 1000000.0
|
||||
return (r, timeMs)
|
||||
|
||||
def mkLambdaBench (n : Nat) : MetaM Expr := do
|
||||
let zero := mkNatLit 0
|
||||
let rec go (n : Nat) (xs : Array Expr) (e : Expr) : MetaM Expr := do
|
||||
match n with
|
||||
| 0 => mkLambdaFVars xs e
|
||||
| n+1 =>
|
||||
withLocalDeclD `x (mkConst ``Nat) fun x =>
|
||||
go n (xs.push x) (mkNatAdd zero (mkNatAdd e x))
|
||||
go n #[] zero
|
||||
|
||||
def benchLambda (n : Nat) : MetaM Unit := do
|
||||
let e ← mkLambdaBench n
|
||||
let (r, timeMs) ← simp e
|
||||
let proofSize ← getProofSize r
|
||||
IO.println s!"lambda_{n}: {timeMs}ms, proof_size={proofSize}"
|
||||
|
||||
set_option maxRecDepth 100000
|
||||
|
||||
/-! ## Run all benchmarks -/
|
||||
def runAllBenchmarks : MetaM Unit := do
|
||||
IO.println "=== Simplifier Stress Tests ==="
|
||||
IO.println ""
|
||||
|
||||
IO.println ""
|
||||
IO.println "--- Benchmark 1: Transitivity chain ---"
|
||||
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200] do
|
||||
benchLambda n
|
||||
|
||||
#eval runAllBenchmarks
|
||||
|
||||
end SimpBench
|
||||
64
tests/bench/sym/simp_2.lean
Normal file
64
tests/bench/sym/simp_2.lean
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
import Lean
|
||||
open Lean Meta
|
||||
opaque f : Nat → Nat
|
||||
|
||||
namespace SimpBench
|
||||
/-!
|
||||
## `SymM` Simplifier benchmarks
|
||||
-/
|
||||
|
||||
def getProofSize (r : Sym.Simp.Result) : MetaM Nat :=
|
||||
match r with
|
||||
| .rfl => return 0
|
||||
| .step _ p => p.numObjs
|
||||
|
||||
def mkSimpMethods : MetaM Sym.Simp.Methods := do
|
||||
let thms : Sym.Simp.Theorems := {}
|
||||
let thm ← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add
|
||||
let thms := thms.insert thm
|
||||
return { post := thms.rewrite }
|
||||
|
||||
def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run' do
|
||||
let e ← Grind.shareCommon e
|
||||
let methods ← mkSimpMethods
|
||||
let startTime ← IO.monoNanosNow
|
||||
let r ← Sym.simp e methods { maxSteps := 100000000 }
|
||||
let endTime ← IO.monoNanosNow
|
||||
-- logInfo e
|
||||
-- match r with
|
||||
-- | .rfl => logInfo "rfl"
|
||||
-- | .step e' h => logInfo e'; logInfo h; check h
|
||||
let timeMs := (endTime - startTime).toFloat / 1000000.0
|
||||
return (r, timeMs)
|
||||
|
||||
def mkLambdaBench (n : Nat) : MetaM Expr := do
|
||||
let zero := mkNatLit 0
|
||||
let rec go (n : Nat) (xs : Array Expr) (e : Expr) : MetaM Expr := do
|
||||
match n with
|
||||
| 0 => mkLambdaFVars xs e
|
||||
| n+1 =>
|
||||
withLocalDeclD `x (mkConst ``Nat) fun x =>
|
||||
go n (xs.push x) (mkNatAdd zero (mkNatAdd e x))
|
||||
go n #[] zero
|
||||
|
||||
def benchLambda (n : Nat) : MetaM Unit := do
|
||||
let e ← mkLambdaBench n
|
||||
let (r, timeMs) ← simp e
|
||||
let proofSize ← getProofSize r
|
||||
IO.println s!"lambda_{n}: {timeMs}ms, proof_size={proofSize}"
|
||||
|
||||
set_option maxRecDepth 100000
|
||||
|
||||
/-! ## Run all benchmarks -/
|
||||
def runAllBenchmarks : MetaM Unit := do
|
||||
IO.println "=== Simplifier Stress Tests ==="
|
||||
IO.println ""
|
||||
|
||||
IO.println ""
|
||||
IO.println "--- Benchmark 1: Lambda block ---"
|
||||
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200] do
|
||||
benchLambda n
|
||||
|
||||
#eval runAllBenchmarks
|
||||
|
||||
end SimpBench
|
||||
Loading…
Add table
Reference in a new issue