diff --git a/tests/bench/sym/simp_3.lean b/tests/bench/sym/simp_3.lean index d1a34d8125..46ef8369bc 100644 --- a/tests/bench/sym/simp_3.lean +++ b/tests/bench/sym/simp_3.lean @@ -12,6 +12,16 @@ def getProofSize (r : Sym.Simp.Result) : MetaM Nat := do | .rfl _ => return 0 | .step _ p _ => (ShareCommon.shareCommon' p).numObjs +def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do + match r with + | .rfl _ => return 0.0 + | .step _ p _ => + let p := ShareCommon.shareCommon' p + let startTime ← IO.monoNanosNow + Meta.checkWithKernel p + let endTime ← IO.monoNanosNow + return (endTime - startTime).toFloat / 1000000.0 + def mkSimpMethods : MetaM Sym.Simp.Methods := do let thms : Sym.Simp.Theorems := {} let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add) @@ -30,21 +40,37 @@ def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do -- | .rfl _ => logInfo "rfl" -- | .step e' h _ => -- logInfo e'; logInfo h - -- let startTime ← IO.monoNanosNow - -- checkWithKernel h - -- let endTime ← IO.monoNanosNow - -- let timeMs := (endTime - startTime).toFloat / 1000000.0 - -- logInfo s!"kernel time {timeMs} ms" return (r, timeMs) -def mkLetBench (n : Nat) (includeUnused : Bool) : MetaM Expr := do +def ppExample (e : Expr) : MetaM Unit := do + forallTelescope e fun _ e => do + IO.println "Example:" + IO.println (← ppExpr e) + IO.println "====>" + match (← simp e).1 with + | .rfl _ => IO.println "" + | .step e' _ _ => + IO.println (← ppExpr e') + IO.println "" + +def benchSimp (name : String) (e : Expr) (check := false) : MetaM Unit := + forallTelescope e fun _ e => do + let (r, timeMs) ← simp e + let proofSize ← getProofSize r + if check then + let kMs ← checkWithKernel r + IO.println s!"{name}: {timeMs}ms, kernel: {kMs}ms, proof_size={proofSize}" + else + IO.println s!"{name}: {timeMs}ms, proof_size={proofSize}" + +def mkHaveChainBench (n : Nat) (includeUnused : Bool) : MetaM Expr := do let zero := mkNatLit 0 let one := mkNatLit 1 let rec go (n : Nat) (xs : Array Expr) (v : Expr) (e : Expr) : MetaM Expr := do match n with | 0 => mkLetFVars (usedLetOnly := false) (generalizeNondepLet := false) xs e | n+1 => - if includeUnused && n % 2 == 0 then + if !includeUnused || n % 2 == 0 then withLetDecl (nondep := true) `x (mkConst ``Nat) (mkNatAdd zero (mkNatAdd v one)) fun x => go n (xs.push x) x (mkNatAdd zero (mkNatAdd e x)) else @@ -52,29 +78,81 @@ def mkLetBench (n : Nat) (includeUnused : Bool) : MetaM Expr := do go n (xs.push y) v (mkNatAdd zero (mkNatAdd e zero)) go n #[] zero zero -def benchLet (n : Nat) (includeUnused : Bool) : MetaM Unit := do - let e ← mkLetBench n includeUnused - let (r, timeMs) ← simp e - let proofSize ← getProofSize r - IO.println s!"have_{n}: {timeMs}ms, proof_size={proofSize}" +def benchHaveChain (n : Nat) (includeUnused : Bool) (check : Bool := false) : MetaM Unit := do + let e ← mkHaveChainBench n includeUnused + let name := if includeUnused then s!"have_chain_unused_{n}" else s!"have_chain_{n}" + benchSimp name e check + +def mkHaveParallelBench (n : Nat) (simpValues : Bool) : MetaM Expr := do + withLocalDeclD `x Nat.mkType fun x => do + let zero := mkNatLit 0 + let rec go (n : Nat) (xs : Array Expr) (e : Expr) : MetaM Expr := do + match n with + | 0 => mkLetFVars (usedLetOnly := false) (generalizeNondepLet := false) xs e + | n+1 => + let val := if simpValues then + -- Values should be in `simp` normal form + mkNatAdd x (mkNatLit n) + else + mkNatAdd zero (mkNatAdd x (mkNatLit n)) + withLetDecl (nondep := true) `y (mkConst ``Nat) val fun x => + go n (xs.push x) (mkNatAdd x e) + let r ← go n #[] zero + mkForallFVars #[x] r + +def benchHaveParallel (n : Nat) (simpValues : Bool) (check : Bool := false) : MetaM Unit := do + let e ← mkHaveParallelBench n simpValues + let name := if simpValues then s!"have_parallel_simp_vals_{n}" else s!"have_parallel_unsimp_vals_{n}" + benchSimp name e check + +def mkHaveChain1Bench (n : Nat) : MetaM Expr := do + let zero := mkNatLit 0 + let one := mkNatLit 1 + let rec go (n : Nat) (xs : Array Expr) (v : Expr) (e : Expr) : MetaM Expr := do + match n with + | 0 => mkLetFVars (usedLetOnly := false) (generalizeNondepLet := false) xs (mkNatAdd v e) + | n+1 => + withLetDecl (nondep := true) `x (mkConst ``Nat) (mkNatAdd zero (mkNatAdd v one)) fun x => + go n (xs.push x) x (mkNatAdd one e) + go n #[] zero zero + +def benchHaveChain1 (n : Nat) (check : Bool := false) : MetaM Unit := do + let e ← mkHaveChain1Bench n + benchSimp s!"have_chain1_{n}" e check + +def run (k : Nat → MetaM Unit) : MetaM Unit := do + for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 200, 300, 400, 500, + 600, 700, 800, 900, 1000, 1200, 1400, 1600, 1800, 2000] do + k n set_option maxRecDepth 100000 /-! ## Run all benchmarks -/ def runAllBenchmarks : MetaM Unit := do - IO.println "=== Simplifier Stress Tests ===" + IO.println "=== Simplifier Have-telescope Stress Tests ===" IO.println "" IO.println "" - IO.println "--- Benchmark 1: have block without unused variables ---" - for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 200, 300, 400, 500, - 600, 700, 800, 900, 1000, 1200, 1400, 1600, 1800, 2000] do - benchLet n false + IO.println "--- Benchmark 1: have-telescope chain without unused variables ---" + ppExample (← mkHaveChainBench 5 false) + run fun n => benchHaveChain n false (n < 110) - IO.println "--- Benchmark 2: have block with unused variables ---" - for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 200, 300, 400, 500, - 600, 700, 800, 900, 1000, 1200, 1400, 1600, 1800, 2000] do - benchLet n true + IO.println "--- Benchmark 2: have-telescope chain with unused variables ---" + ppExample (← mkHaveChainBench 5 true) + run fun n => benchHaveChain n true (n < 120) + + IO.println "--- Benchmark 3: have-telescope parallel declarations with simplified values ---" + ppExample (← mkHaveParallelBench 5 true) + run fun n => benchHaveParallel n true (n < 120) + + IO.println "--- Benchmark 4: have-telescope parallel declarations with unsimplified values ---" + ppExample (← mkHaveParallelBench 5 false) + run fun n => benchHaveParallel n false (n < 120) + + IO.println "" + IO.println "--- Benchmark 5: have-telescope chain with 1 dep ---" + ppExample (← mkHaveChain1Bench 5) + run fun n => benchHaveChain1 n (n < 600) #eval runAllBenchmarks