This PR implements a new strategy for simplifying `have`-telescopes in `Sym.simp` that achieves linear kernel type-checking time instead of quadratic. ## Problem When simplifying deep `have`-telescopes, the previous approach using `have_congr'` produced proofs that type-checked in quadratic time. The simplifier itself was fast, but the kernel became the bottleneck for large telescopes. For example, at n=100: - **Before**: simp = 2.4ms, kernel = **225ms** - **After**: simp = 3.5ms, kernel = **10ms** The quadratic behavior occurred because the kernel creates fresh free variables for each binder when type-checking, destroying sharing and producing O(n²) intermediate terms. ## Solution We transform sequential `have`-telescopes into a parallel beta-application form: ``` have x₁ := v₁; have x₂ := v₂[x₁]; b[x₁, x₂] ↓ (definitionally equal) (fun x₁ x₂' => b[x₁, x₂' x₁]) v₁ (fun x₁ => v₂[x₁]) ``` This parallel form leverages the efficient simplifier for lambdas in `Sym.simp`. This form enables: 1. Independent simplification of each argument 2. Proof construction using standard congruence lemmas 3. Linear kernel type-checking time The algorithm has three phases: 1. **`toBetaApp`**: Transform telescope → parallel beta-application 2. **`simpBetaApp`**: Simplify using `congr`/`congrArg`/`congrFun'` and `simpLambda` 3. **`toHave`**: Convert back to `have` form ## Benchmark Results ### Benchmark 1: Chain with all variables used in body | n | Before (simp) | Before (kernel) | After (simp) | After (kernel) | |---|---------------|-----------------|--------------|----------------| | 50 | 1.2ms | 32ms | 1.6ms | 4.4ms | | 100 | 2.4ms | **225ms** | 3.5ms | **10ms** | | 200 | 4.5ms | — | 8.4ms | 27ms | | 500 | 11.7ms | — | 33.6ms | 128ms | ### Benchmark 3: Parallel declarations (simplified values) | n | Before (simp) | Before (kernel) | After (simp) | After (kernel) | |---|---------------|-----------------|--------------|----------------| | 50 | 0.5ms | 24ms | 0.8ms | 1.8ms | | 100 | 1.2ms | **169ms** | 1.8ms | **5.3ms** | | 200 | 2.2ms | — | 3.9ms | 17ms | | 500 | 5.9ms | — | 12.3ms | 93ms | ### Benchmark 5: Chain with single dependency | n | Before (simp) | Before (kernel) | After (simp) | After (kernel) | |---|---------------|-----------------|--------------|----------------| | 100 | 1.6ms | 6.2ms | 1.8ms | 6.2ms | | 200 | 2.8ms | 21.6ms | 4.4ms | 16.5ms | | 500 | 7.3ms | **125ms** | 12.8ms | **72ms** | Key observations: - Kernel time is now **linear** in telescope depth (previously quadratic) - Simp time increases slightly due to the transformation overhead - Total time (simp + kernel) is dramatically reduced for large telescopes - The improvement is most pronounced when the body depends on many variables ## Trade-offs - Proof sizes are larger (more congruence lemma applications) - Simp time has ~1.5x overhead from the transformation - For very small telescopes (n < 10), the overhead may not pay off The optimization targets the critical path: kernel type-checking was the bottleneck preventing scaling to realistic symbolic simulation workloads.
159 lines
5.6 KiB
Text
159 lines
5.6 KiB
Text
import Lean
|
||
open Lean Meta
|
||
opaque f : Nat → Nat
|
||
|
||
namespace SimpBench
|
||
/-!
|
||
## `SymM` Simplifier benchmarks
|
||
-/
|
||
|
||
def getProofSize (r : Sym.Simp.Result) : MetaM Nat := do
|
||
match r with
|
||
| .rfl _ => return 0
|
||
| .step _ p _ => (ShareCommon.shareCommon' p).numObjs
|
||
|
||
def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do
|
||
match r with
|
||
| .rfl _ => return 0.0
|
||
| .step _ p _ =>
|
||
let p := ShareCommon.shareCommon' p
|
||
let startTime ← IO.monoNanosNow
|
||
Meta.checkWithKernel p
|
||
let endTime ← IO.monoNanosNow
|
||
return (endTime - startTime).toFloat / 1000000.0
|
||
|
||
def mkSimpMethods : MetaM Sym.Simp.Methods := do
|
||
let thms : Sym.Simp.Theorems := {}
|
||
let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add)
|
||
let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.add_zero)
|
||
return { post := thms.rewrite }
|
||
|
||
def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
|
||
let e ← Grind.shareCommon e
|
||
let methods ← mkSimpMethods
|
||
let startTime ← IO.monoNanosNow
|
||
let r ← Sym.simp e methods { maxSteps := 100000000 }
|
||
let endTime ← IO.monoNanosNow
|
||
let timeMs := (endTime - startTime).toFloat / 1000000.0
|
||
-- logInfo e
|
||
-- match r with
|
||
-- | .rfl _ => logInfo "rfl"
|
||
-- | .step e' h _ =>
|
||
-- logInfo e'; logInfo h
|
||
return (r, timeMs)
|
||
|
||
def ppExample (e : Expr) : MetaM Unit := do
|
||
forallTelescope e fun _ e => do
|
||
IO.println "Example:"
|
||
IO.println (← ppExpr e)
|
||
IO.println "====>"
|
||
match (← simp e).1 with
|
||
| .rfl _ => IO.println "<no change>"
|
||
| .step e' h _ =>
|
||
IO.println (← ppExpr e')
|
||
IO.println (← ppExpr h)
|
||
IO.println ""
|
||
|
||
def benchSimp (name : String) (e : Expr) (check := false) : MetaM Unit :=
|
||
forallTelescope e fun _ e => do
|
||
let (r, timeMs) ← simp e
|
||
let proofSize ← getProofSize r
|
||
if check then
|
||
let kMs ← checkWithKernel r
|
||
IO.println s!"{name}: {timeMs}ms, kernel: {kMs}ms, proof_size={proofSize}"
|
||
else
|
||
IO.println s!"{name}: {timeMs}ms, proof_size={proofSize}"
|
||
|
||
def mkHaveChainBench (n : Nat) (includeUnused : Bool) : MetaM Expr := do
|
||
let zero := mkNatLit 0
|
||
let one := mkNatLit 1
|
||
let rec go (n : Nat) (xs : Array Expr) (v : Expr) (e : Expr) : MetaM Expr := do
|
||
match n with
|
||
| 0 => mkLetFVars (usedLetOnly := false) (generalizeNondepLet := false) xs e
|
||
| n+1 =>
|
||
if !includeUnused || n % 2 == 0 then
|
||
withLetDecl (nondep := true) `x (mkConst ``Nat) (mkNatAdd zero (mkNatAdd v one)) fun x =>
|
||
go n (xs.push x) x (mkNatAdd zero (mkNatAdd e x))
|
||
else
|
||
withLetDecl (nondep := true) `y (mkConst ``Nat) zero fun y =>
|
||
go n (xs.push y) v (mkNatAdd zero (mkNatAdd e zero))
|
||
go n #[] zero zero
|
||
|
||
def benchHaveChain (n : Nat) (includeUnused : Bool) (check : Bool := false) : MetaM Unit := do
|
||
let e ← mkHaveChainBench n includeUnused
|
||
let name := if includeUnused then s!"have_chain_unused_{n}" else s!"have_chain_{n}"
|
||
benchSimp name e check
|
||
|
||
def mkHaveParallelBench (n : Nat) (simpValues : Bool) : MetaM Expr := do
|
||
withLocalDeclD `x Nat.mkType fun x => do
|
||
let zero := mkNatLit 0
|
||
let rec go (n : Nat) (xs : Array Expr) (e : Expr) : MetaM Expr := do
|
||
match n with
|
||
| 0 => mkLetFVars (usedLetOnly := false) (generalizeNondepLet := false) xs e
|
||
| n+1 =>
|
||
let val := if simpValues then
|
||
-- Values should be in `simp` normal form
|
||
mkNatAdd x (mkNatLit n)
|
||
else
|
||
mkNatAdd zero (mkNatAdd x (mkNatLit n))
|
||
withLetDecl (nondep := true) `y (mkConst ``Nat) val fun x =>
|
||
go n (xs.push x) (mkNatAdd x e)
|
||
let r ← go n #[] zero
|
||
mkForallFVars #[x] r
|
||
|
||
def benchHaveParallel (n : Nat) (simpValues : Bool) (check : Bool := false) : MetaM Unit := do
|
||
let e ← mkHaveParallelBench n simpValues
|
||
let name := if simpValues then s!"have_parallel_simp_vals_{n}" else s!"have_parallel_unsimp_vals_{n}"
|
||
benchSimp name e check
|
||
|
||
def mkHaveChain1Bench (n : Nat) : MetaM Expr := do
|
||
let zero := mkNatLit 0
|
||
let one := mkNatLit 1
|
||
let rec go (n : Nat) (xs : Array Expr) (v : Expr) (e : Expr) : MetaM Expr := do
|
||
match n with
|
||
| 0 => mkLetFVars (usedLetOnly := false) (generalizeNondepLet := false) xs (mkNatAdd v e)
|
||
| n+1 =>
|
||
withLetDecl (nondep := true) `x (mkConst ``Nat) (mkNatAdd zero (mkNatAdd v one)) fun x =>
|
||
go n (xs.push x) x (mkNatAdd one e)
|
||
go n #[] zero zero
|
||
|
||
def benchHaveChain1 (n : Nat) (check : Bool := false) : MetaM Unit := do
|
||
let e ← mkHaveChain1Bench n
|
||
benchSimp s!"have_chain1_{n}" e check
|
||
|
||
def run (k : Nat → MetaM Unit) : MetaM Unit := do
|
||
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 200, 300, 400, 500] do
|
||
k n
|
||
|
||
set_option maxRecDepth 100000
|
||
|
||
/-! ## Run all benchmarks -/
|
||
def runAllBenchmarks : MetaM Unit := do
|
||
IO.println "=== Simplifier Have-telescope Stress Tests ==="
|
||
IO.println ""
|
||
|
||
IO.println ""
|
||
IO.println "--- Benchmark 1: have-telescope chain without unused variables ---"
|
||
ppExample (← mkHaveChainBench 5 false)
|
||
run fun n => benchHaveChain n false true
|
||
|
||
IO.println "--- Benchmark 2: have-telescope chain with unused variables ---"
|
||
ppExample (← mkHaveChainBench 5 true)
|
||
run fun n => benchHaveChain n true true
|
||
|
||
IO.println "--- Benchmark 3: have-telescope parallel declarations with simplified values ---"
|
||
ppExample (← mkHaveParallelBench 5 true)
|
||
run fun n => benchHaveParallel n true true
|
||
|
||
IO.println "--- Benchmark 4: have-telescope parallel declarations with unsimplified values ---"
|
||
ppExample (← mkHaveParallelBench 5 false)
|
||
run fun n => benchHaveParallel n false true
|
||
|
||
IO.println ""
|
||
IO.println "--- Benchmark 5: have-telescope chain with 1 dep ---"
|
||
ppExample (← mkHaveChain1Bench 5)
|
||
run fun n => benchHaveChain1 n true
|
||
|
||
#eval runAllBenchmarks
|
||
|
||
end SimpBench
|