lean4-htt/tests/bench/sym/simp_3.lean
Leonardo de Moura d57f71c1c0
perf: optimize kernel type-checking for have-telescope simplification in Sym.simp (#11967)
This PR implements a new strategy for simplifying `have`-telescopes in
`Sym.simp` that achieves linear kernel type-checking time instead of
quadratic.

## Problem

When simplifying deep `have`-telescopes, the previous approach using
`have_congr'` produced proofs that type-checked in quadratic time. The
simplifier itself was fast, but the kernel became the bottleneck for
large telescopes.

For example, at n=100:
- **Before**: simp = 2.4ms, kernel = **225ms**
- **After**: simp = 3.5ms, kernel = **10ms**

The quadratic behavior occurred because the kernel creates fresh free
variables for each binder when type-checking, destroying sharing and
producing O(n²) intermediate terms.

## Solution

We transform sequential `have`-telescopes into a parallel
beta-application form:

```
have x₁ := v₁; have x₂ := v₂[x₁]; b[x₁, x₂]
  ↓ (definitionally equal)
(fun x₁ x₂' => b[x₁, x₂' x₁]) v₁ (fun x₁ => v₂[x₁])
```

This parallel form leverages the efficient lambda simplifier in
`Sym.simp` and enables:
1. Independent simplification of each argument
2. Proof construction using standard congruence lemmas
3. Linear kernel type-checking time

The algorithm has three phases:
1. **`toBetaApp`**: Transform telescope → parallel beta-application
2. **`simpBetaApp`**: Simplify using `congr`/`congrArg`/`congrFun'` and
`simpLambda`
3. **`toHave`**: Convert back to `have` form

## Benchmark Results

### Benchmark 1: Chain with all variables used in body

| n | Before (simp) | Before (kernel) | After (simp) | After (kernel) |
|---|---------------|-----------------|--------------|----------------|
| 50 | 1.2ms | 32ms | 1.6ms | 4.4ms |
| 100 | 2.4ms | **225ms** | 3.5ms | **10ms** |
| 200 | 4.5ms | — | 8.4ms | 27ms |
| 500 | 11.7ms | — | 33.6ms | 128ms |

### Benchmark 3: Parallel declarations (simplified values)

| n | Before (simp) | Before (kernel) | After (simp) | After (kernel) |
|---|---------------|-----------------|--------------|----------------|
| 50 | 0.5ms | 24ms | 0.8ms | 1.8ms |
| 100 | 1.2ms | **169ms** | 1.8ms | **5.3ms** |
| 200 | 2.2ms | — | 3.9ms | 17ms |
| 500 | 5.9ms | — | 12.3ms | 93ms |

### Benchmark 5: Chain with single dependency

| n | Before (simp) | Before (kernel) | After (simp) | After (kernel) |
|---|---------------|-----------------|--------------|----------------|
| 100 | 1.6ms | 6.2ms | 1.8ms | 6.2ms |
| 200 | 2.8ms | 21.6ms | 4.4ms | 16.5ms |
| 500 | 7.3ms | **125ms** | 12.8ms | **72ms** |

Key observations:
- Kernel time is now **linear** in telescope depth (previously
quadratic)
- Simp time increases slightly due to the transformation overhead
- Total time (simp + kernel) is dramatically reduced for large
telescopes
- The improvement is most pronounced when the body depends on many
variables

## Trade-offs

- Proof sizes are larger (more congruence lemma applications)
- Simp time has ~1.5x overhead from the transformation
- For very small telescopes (n < 10), the overhead may not pay off

The optimization targets the critical path: kernel type-checking was the
bottleneck preventing scaling to realistic symbolic simulation
workloads.
2026-01-11 02:20:47 +00:00

159 lines
5.6 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import Lean
open Lean Meta
-- Opaque `Nat → Nat` function symbol, declared before the `SimpBench` namespace is opened.
-- NOTE(review): none of the definitions visible in this file reference `f` — confirm it is still needed.
opaque f : Nat → Nat
namespace SimpBench
/-!
## `SymM` Simplifier benchmarks
-/
/--
Size metric for the proof carried by a simplification result.

* `.rfl`: there is no proof term, so the size is `0`.
* `.step _ p _`: the `numObjs` count of `p` after it has been passed through
  `ShareCommon.shareCommon'`.
-/
def getProofSize (r : Sym.Simp.Result) : MetaM Nat := do
  match r with
  | .step _ proof _ =>
    let shared := ShareCommon.shareCommon' proof
    shared.numObjs
  | .rfl _ => return 0
/--
Measures how long `Meta.checkWithKernel` takes on the proof of `r`, in milliseconds.

Returns `0.0` for `.rfl` results (nothing to check). For `.step` results the proof is
first passed through `ShareCommon.shareCommon'`; that pass happens before the timer
starts, so only the kernel check itself is measured.
-/
def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do
  -- Anything that is not a `.step` carries no proof to check.
  let .step _ proof _ := r | return 0.0
  let shared := ShareCommon.shareCommon' proof
  let t0 ← IO.monoNanosNow
  Meta.checkWithKernel shared
  let t1 ← IO.monoNanosNow
  -- Nanoseconds → milliseconds.
  let elapsedNs := t1 - t0
  return elapsedNs.toFloat / 1000000.0
/--
Builds the simplifier methods used by every benchmark in this file.

The rewrite set contains exactly two theorems, inserted in this order:
`Nat.zero_add` and `Nat.add_zero`. The resulting `rewrite` procedure is installed as
the `post` method; all other fields keep their defaults.
-/
def mkSimpMethods : MetaM Sym.Simp.Methods := do
  let mut rules : Sym.Simp.Theorems := {}
  for declName in [``Nat.zero_add, ``Nat.add_zero] do
    rules := rules.insert (← Sym.Simp.mkTheoremFromDecl declName)
  return { post := rules.rewrite }
/--
Simplifies `e` with `Sym.simp` and reports how long the call took.

The whole computation runs inside `Sym.SymM.run`. Before timing starts, `e` is passed
through `Grind.shareCommon` and the methods are built with `mkSimpMethods`; only the
`Sym.simp` call itself sits between the two clock reads.

Returns the simplification result paired with the elapsed time in milliseconds.
-/
def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
  let shared ← Grind.shareCommon e
  let methods ← mkSimpMethods
  let t0 ← IO.monoNanosNow
  -- The step budget is set very high so the largest benchmark inputs are not cut short.
  let result ← Sym.simp shared methods { maxSteps := 100000000 }
  let t1 ← IO.monoNanosNow
  -- Debugging aid (disabled): log the input and the outcome.
  --   logInfo shared
  --   match result with
  --   | .rfl _ => logInfo "rfl"
  --   | .step e' h _ => logInfo e'; logInfo h
  return (result, (t1 - t0).toFloat / 1000000.0)
/--
Prints one worked example to stdout: the input expression, a `====>` separator, and
the outcome of running `simp` on it.

For a `.rfl` result the text `<no change>` is printed; for a `.step e' h _` result the
two payload expressions `e'` and `h` are pretty-printed (presumably the simplified term
and its proof — confirm against `Sym.Simp.Result`).
-/
def ppExample (e : Expr) : MetaM Unit := do
-- Open the leading `∀` binders of `e`; the inner `e` shadows the parameter and is the body.
forallTelescope e fun _ e => do
IO.println "Example:"
IO.println (← ppExpr e)
IO.println "====>"
-- Only the result component is needed here; the timing component is discarded.
match (← simp e).1 with
| .rfl _ => IO.println "<no change>"
| .step e' h _ =>
IO.println (← ppExpr e')
IO.println (← ppExpr h)
-- NOTE(review): indentation was lost in this copy, so it cannot be told from here whether
-- this blank line belongs to the `.step` arm or runs after the `match` — confirm.
IO.println ""
/--
Runs one benchmark and prints a single report line.

The leading `∀` binders of `e` are opened with `forallTelescope`, the body is simplified
with `simp`, and the proof size is computed with `getProofSize`. When `check` is `true`
the proof is additionally timed through `checkWithKernel` and the report line includes
the kernel time; otherwise only the simplifier time and proof size are printed.
-/
def benchSimp (name : String) (e : Expr) (check := false) : MetaM Unit :=
  forallTelescope e fun _ body => do
    let (result, simpMs) ← simp body
    let size ← getProofSize result
    match check with
    | false =>
      IO.println s!"{name}: {simpMs}ms, proof_size={size}"
    | true =>
      let kernelMs ← checkWithKernel result
      IO.println s!"{name}: {simpMs}ms, kernel: {kernelMs}ms, proof_size={size}"
/--
Builds a chain of `n` non-dependent `let` declarations (`nondep := true`) over `Nat`.

A "used" declaration `x := 0 + (prev + 1)` refers to the previous used variable `prev`
(initially the literal `0`) and is added to the body as `0 + (body + x)`.
When `includeUnused` is `true`, every step whose remaining count is odd instead
introduces `y := 0`, which is never referenced; the body becomes `0 + (body + 0)`.

All declarations are kept when the telescope is closed (`usedLetOnly := false`).
-/
def mkHaveChainBench (n : Nat) (includeUnused : Bool) : MetaM Expr := do
  let zero := mkNatLit 0
  let one := mkNatLit 1
  let rec go (remaining : Nat) (fvars : Array Expr) (prev : Expr) (body : Expr) : MetaM Expr := do
    match remaining with
    | 0 => mkLetFVars (usedLetOnly := false) (generalizeNondepLet := false) fvars body
    | m+1 =>
      let natTy := mkConst ``Nat
      if includeUnused && m % 2 != 0 then
        -- Unused declaration: `prev` is carried over unchanged.
        withLetDecl (nondep := true) `y natTy zero fun y =>
          go m (fvars.push y) prev (mkNatAdd zero (mkNatAdd body zero))
      else
        -- Used declaration: becomes the new `prev` and is added to the body.
        let value := mkNatAdd zero (mkNatAdd prev one)
        withLetDecl (nondep := true) `x natTy value fun x =>
          go m (fvars.push x) x (mkNatAdd zero (mkNatAdd body x))
  go n #[] zero zero
/--
Benchmarks the expression produced by `mkHaveChainBench n includeUnused`.

The report label is `have_chain_unused_{n}` when `includeUnused` is set and
`have_chain_{n}` otherwise; `check` is forwarded to `benchSimp`.
-/
def benchHaveChain (n : Nat) (includeUnused : Bool) (check : Bool := false) : MetaM Unit := do
  let label :=
    match includeUnused with
    | true  => s!"have_chain_unused_{n}"
    | false => s!"have_chain_{n}"
  benchSimp label (← mkHaveChainBench n includeUnused) check
/--
Builds `∀ x : Nat, <telescope>` where the telescope consists of `n` non-dependent `let`
declarations whose values mention only the bound variable `x`, never each other.

For the step with remaining count `m` the value is `x + m` when `simpValues` is `true`
(already in `simp` normal form) and `0 + (x + m)` otherwise. Each declared variable `y`
is prepended to the body as `y + body`, starting from the literal `0`.
-/
def mkHaveParallelBench (n : Nat) (simpValues : Bool) : MetaM Expr := do
  withLocalDeclD `x Nat.mkType fun x => do
    let zero := mkNatLit 0
    let rec go (remaining : Nat) (fvars : Array Expr) (body : Expr) : MetaM Expr := do
      match remaining with
      | 0 => mkLetFVars (usedLetOnly := false) (generalizeNondepLet := false) fvars body
      | m+1 =>
        let offset := mkNatAdd x (mkNatLit m)
        -- `offset` alone is already in `simp` normal form; the `0 + _` wrapper is not.
        let value := if simpValues then offset else mkNatAdd zero offset
        withLetDecl (nondep := true) `y (mkConst ``Nat) value fun y =>
          go m (fvars.push y) (mkNatAdd y body)
    mkForallFVars #[x] (← go n #[] zero)
/--
Benchmarks the expression produced by `mkHaveParallelBench n simpValues`.

The report label is `have_parallel_simp_vals_{n}` when `simpValues` is set and
`have_parallel_unsimp_vals_{n}` otherwise; `check` is forwarded to `benchSimp`.
-/
def benchHaveParallel (n : Nat) (simpValues : Bool) (check : Bool := false) : MetaM Unit := do
  let label :=
    match simpValues with
    | true  => s!"have_parallel_simp_vals_{n}"
    | false => s!"have_parallel_unsimp_vals_{n}"
  benchSimp label (← mkHaveParallelBench n simpValues) check
/--
Builds a chain of `n` non-dependent `let` declarations in which each value
`x := 0 + (prev + 1)` depends on the previous variable, but the body refers to only
one variable: the final body is `last + acc`, where `last` is the most recently
declared variable (or the literal `0` when `n = 0`) and `acc` is `1 + (1 + … + 0)`
with one `1 + _` layer per declaration.
-/
def mkHaveChain1Bench (n : Nat) : MetaM Expr := do
  let zero := mkNatLit 0
  let one := mkNatLit 1
  let rec go (remaining : Nat) (fvars : Array Expr) (prev : Expr) (acc : Expr) : MetaM Expr := do
    match remaining with
    | 0 =>
      let body := mkNatAdd prev acc
      mkLetFVars (usedLetOnly := false) (generalizeNondepLet := false) fvars body
    | m+1 =>
      let value := mkNatAdd zero (mkNatAdd prev one)
      withLetDecl (nondep := true) `x (mkConst ``Nat) value fun x =>
        go m (fvars.push x) x (mkNatAdd one acc)
  go n #[] zero zero
/--
Benchmarks the expression produced by `mkHaveChain1Bench n` under the label
`have_chain1_{n}`; `check` is forwarded to `benchSimp`.
-/
def benchHaveChain1 (n : Nat) (check : Bool := false) : MetaM Unit := do
  benchSimp s!"have_chain1_{n}" (← mkHaveChain1Bench n) check
/--
Invokes the benchmark action `k` once for every problem size in `sizes`, in order.

`sizes` defaults to the sweep originally hard-coded here (10 to 150 in steps of 10,
then 200, 300, 400, 500), so existing calls of the form `run k` behave exactly as
before. Passing an explicit list allows a shorter or more targeted sweep.
-/
def run (k : Nat → MetaM Unit)
    (sizes : List Nat :=
      [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 200, 300, 400, 500]) :
    MetaM Unit := do
  for n in sizes do
    k n
-- Raise the recursion-depth limit for the commands that follow.
-- NOTE(review): presumably needed for the deepest telescopes built below (n up to 500) — confirm.
set_option maxRecDepth 100000
/-! ## Run all benchmarks -/

/--
Runs the five have-telescope benchmark families in order.

Each family prints its header line, shows one small worked example (size 5) via
`ppExample`, and then sweeps the problem sizes via `run` with kernel checking enabled.
-/
def runAllBenchmarks : MetaM Unit := do
  -- One benchmark family: header, worked example, then the size sweep.
  let runSection (title : String) (mkExample : MetaM Expr) (bench : Nat → MetaM Unit) :
      MetaM Unit := do
    IO.println title
    ppExample (← mkExample)
    run bench
  IO.println "=== Simplifier Have-telescope Stress Tests ==="
  IO.println ""
  IO.println ""
  runSection "--- Benchmark 1: have-telescope chain without unused variables ---"
    (mkHaveChainBench 5 false) (fun n => benchHaveChain n false true)
  runSection "--- Benchmark 2: have-telescope chain with unused variables ---"
    (mkHaveChainBench 5 true) (fun n => benchHaveChain n true true)
  runSection "--- Benchmark 3: have-telescope parallel declarations with simplified values ---"
    (mkHaveParallelBench 5 true) (fun n => benchHaveParallel n true true)
  runSection "--- Benchmark 4: have-telescope parallel declarations with unsimplified values ---"
    (mkHaveParallelBench 5 false) (fun n => benchHaveParallel n false true)
  IO.println ""
  runSection "--- Benchmark 5: have-telescope chain with 1 dep ---"
    (mkHaveChain1Bench 5) (fun n => benchHaveChain1 n true)
-- Execute the whole benchmark suite when this file is elaborated.
#eval runAllBenchmarks
end SimpBench