From ba8c2ed4ee14763f40fbcd986a3454b5bf0bfca8 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 25 Jan 2026 12:43:59 -0800 Subject: [PATCH] feat: add `simpArrowTelescope` for compact proofs of arrow simplification (#12152) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `simpArrowTelescope`, a simproc that simplifies telescopes of non-dependent arrows (p₁ → p₂ → ... → q) while avoiding quadratic proof growth. When using `Expr.forallE` to represent nested implications, each nesting level bumps de Bruijn indices in subterms, destroying sharing even with hash-consing. For example, a free variable `x` gets different de Bruijn representations at each depth, causing proof terms to grow. `simpArrowTelescope` works by: - Converting arrows to `Arrow p q` (a definitional wrapper) - Simplifying each component - Converting back to `→` form Since `Arrow` arguments are not under binders, subterms remain identical across nesting levels and can be shared. The `simp_4` benchmark demonstrates the improvement: With `forallE`: ~160ms, proof_size ≈ 173k With `Arrow`: ~43ms, proof_size ≈ 16k Tradeoff: `simpArrowTelescope` misses simplifications that depend on the arrow structure (e.g., `p → p` to `True`), since post-methods aren't applied to intermediate arrows. Thus, it is not used by default. To use it, one has to set `simpArrowTelescope` as a `pre`-method. 
--- src/Init/SimpLemmas.lean | 26 +++++++++++ src/Lean/Meta/Sym/Simp/Forall.lean | 75 +++++++++++++++++++++++++++++- tests/bench/sym/simp_4.lean | 49 ++++++++++++------- 3 files changed, 131 insertions(+), 19 deletions(-) diff --git a/src/Init/SimpLemmas.lean b/src/Init/SimpLemmas.lean index d948aa32bc..82beccd132 100644 --- a/src/Init/SimpLemmas.lean +++ b/src/Init/SimpLemmas.lean @@ -44,6 +44,32 @@ theorem implies_congr_left {p₁ p₂ : Sort u} {q : Sort v} (h : p₁ = p₂) : theorem implies_congr_right {p : Sort u} {q₁ q₂ : Sort v} (h : q₁ = q₂) : (p → q₁) = (p → q₂) := h ▸ rfl +namespace Lean +/-- +`Arrow α β` is definitionally equal to `α → β`, but represented as a function +application rather than `Expr.forallE`. + +This representation is useful for proof automation that builds nested implications +like `pₙ → ... → p₂ → p₁`. With `Expr.forallE`, each nesting level introduces a +binder that bumps de Bruijn indices in subterms, destroying sharing even with +hash-consing. For example, if `p₁` contains `#20`, then at depth 2 it becomes `#21`, +at depth 3 it becomes `#22`, etc., causing quadratic proof growth. + +With `Arrow`, both arguments are explicit (not under binders), so subterms remain +identical across nesting levels and can be shared, yielding linear-sized proofs. 
+-/ +def Arrow (α : Sort u) (β : Sort v) : Sort (imax u v) := α → β + +theorem arrow_congr {p₁ p₂ : Sort u} {q₁ q₂ : Sort v} (h₁ : p₁ = p₂) (h₂ : q₁ = q₂) : Arrow p₁ q₁ = Arrow p₂ q₂ := + h₁ ▸ h₂ ▸ rfl + +theorem arrow_congr_left {p₁ p₂ : Sort u} {q : Sort v} (h : p₁ = p₂) : Arrow p₁ q = Arrow p₂ q := + h ▸ rfl + +theorem arrow_congr_right {p : Sort u} {q₁ q₂ : Sort v} (h : q₁ = q₂) : Arrow p q₁ = Arrow p q₂ := + h ▸ rfl +end Lean + theorem iff_congr {p₁ p₂ q₁ q₂ : Prop} (h₁ : p₁ ↔ p₂) (h₂ : q₁ ↔ q₂) : (p₁ ↔ q₁) ↔ (p₂ ↔ q₂) := Iff.of_eq (propext h₁ ▸ propext h₂ ▸ rfl) diff --git a/src/Lean/Meta/Sym/Simp/Forall.lean b/src/Lean/Meta/Sym/Simp/Forall.lean index 4e37f4220c..0964275e5f 100644 --- a/src/Lean/Meta/Sym/Simp/Forall.lean +++ b/src/Lean/Meta/Sym/Simp/Forall.lean @@ -7,6 +7,7 @@ module prelude public import Lean.Meta.Sym.Simp.SimpM import Lean.Meta.Sym.AlphaShareBuilder +import Lean.Meta.Sym.InferType namespace Lean.Meta.Sym.Simp /-- @@ -25,7 +26,7 @@ The proof uses the approach used in `mkFunextFor` followed by an `Eq.ndrec`. 
def mkForallCongrFor (xs : Array Expr) : MetaM Expr := do let prop := mkSort 0 let type ← mkForallFVars xs prop - let w ← getLevel type + let w ← Meta.getLevel type withLocalDeclD `p type fun p => withLocalDeclD `q type fun q => do let eq := mkApp3 (mkConst ``Eq [1]) prop (mkAppN p xs) (mkAppN q xs) @@ -53,6 +54,78 @@ def mkForallCongrFor (xs : Array Expr) : MetaM Expr := do open Internal +structure ArrowInfo where + binderName : Name + binderInfo : BinderInfo + +structure ToArrowResult where + arrow : Expr + infos : List ArrowInfo + v : Level + +def toArrow (e : Expr) : SymM ToArrowResult := do + if let .forallE n α β bi := e then + if !β.hasLooseBVars then + let { arrow, infos, v } ← toArrow β + let u ← getLevel α + let arrow ← mkAppS₂ (← mkConstS ``Arrow [u, v]) α arrow + let info := { binderName := n, binderInfo := bi } + return { arrow, v := mkLevelIMax' u v, infos := info :: infos } + return { arrow := e, infos := [], v := (← getLevel e) } + +def toForall (e : Expr) (infos : List ArrowInfo) : SymM Expr := do + let { binderName, binderInfo, .. } :: infos := infos | return e + let_expr Arrow α β := e | return e + mkForallS binderName binderInfo α (← toForall β infos) + +partial def simpArrows (e : Expr) : SimpM Result := do + let_expr f@Arrow p q := e | simp e + match (← simp p), (← simpArrows q) with + | .rfl _, .rfl _ => return .rfl + | .step p' h _, .rfl _ => + let e' ← mkAppS₂ f p' q + return .step e' <| mkApp4 (mkConst ``arrow_congr_left f.constLevels!) p p' q h + | .rfl _, .step q' h _ => + let e' ← mkAppS₂ f p q' + return .step e' <| mkApp4 (mkConst ``arrow_congr_right f.constLevels!) p q q' h + | .step p' h₁ _, .step q' h₂ _ => + let e' ← mkAppS₂ f p' q' + return .step e' <| mkApp6 (mkConst ``arrow_congr f.constLevels!) p p' q q' h₁ h₂ + +/-- +Simplifies a telescope of non-dependent arrows `p₁ → p₂ → ... → pₙ → q` by: +1. Converting to `Arrow p₁ (Arrow p₂ (... (Arrow pₙ q)))` (see `toArrow`) +2. Simplifying each `pᵢ` and `q` (see `simpArrows`) +3. 
Converting back to `→` form (see `toForall`) + +Using `Arrow` (a definitional wrapper around `→`) avoids the quadratic proof growth that +occurs with `Expr.forallE`. With `forallE`, each nesting level bumps de Bruijn indices in +subterms, destroying sharing. For example, if each `pᵢ` contains a free variable `x`, the +de Bruijn representation of `x` differs at each depth, preventing hash-consing from +recognizing them as identical. + +With `Arrow`, both arguments are explicit (not under binders), so subterms remain identical +across nesting levels and can be shared, yielding linear-sized proofs. + +**Tradeoff**: This function simplifies each `pᵢ` and `q` individually, but misses +simplifications that depend on the arrow structure itself. For example, `q → p → p` +won't be simplified to `True` (when `p : Prop`) because the simplifier does not have +a chance to apply `post` methods to the intermediate arrow `p → p`. + +Thus, this is a simproc that is meant to be used as a pre-method and marks the +result as fully simplified to prevent `simpArrow` from being applied. +-/ +public def simpArrowTelescope : Simproc := fun e => do + unless e.isArrow do return .rfl -- not applicable + let { arrow, infos, v } ← toArrow e + let .step arrow' h _ ← simpArrows arrow | return .rfl (done := true) + let e' ← toForall arrow' infos + let α := mkSort v + let v1 := v.succ + let h := mkApp6 (mkConst ``Eq.trans [v1]) α e arrow arrow' (mkApp2 (mkConst ``Eq.refl [v1]) α arrow) h + let h := mkApp6 (mkConst ``Eq.trans [v1]) α e arrow' e' h (mkApp2 (mkConst ``Eq.refl [v1]) α e') + return .step e' h (done := true) + public def simpArrow (e : Expr) : SimpM Result := do let p := e.bindingDomain! let q := e.bindingBody! diff --git a/tests/bench/sym/simp_4.lean b/tests/bench/sym/simp_4.lean index c1b2b41c17..b3fbea2a46 100644 --- a/tests/bench/sym/simp_4.lean +++ b/tests/bench/sym/simp_4.lean @@ -1,7 +1,5 @@ import Lean open Lean Meta -opaque f : Nat → Nat - namespace SimpBench /-! 
## `SymM` Simplifier benchmarks @@ -24,15 +22,18 @@ def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do let endTime ← IO.monoNanosNow return (endTime - startTime).toFloat / 1000000.0 -def mkSimpMethods : MetaM Sym.Simp.Methods := do +def mkSimpMethods (arrowTelescope : Bool) : MetaM Sym.Simp.Methods := do let thms : Sym.Simp.Theorems := {} let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add) let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.add_zero) - return { post := thms.rewrite } + if arrowTelescope then + return { pre := Sym.Simp.simpArrowTelescope, post := thms.rewrite } + else + return { post := thms.rewrite } -def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do +def simp (e : Expr) (arrowTelescope : Bool) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do let e ← Grind.shareCommon e - let methods ← mkSimpMethods + let methods ← mkSimpMethods arrowTelescope let startTime ← IO.monoNanosNow let r ← Sym.simp e methods { maxSteps := 100000000 } let endTime ← IO.monoNanosNow @@ -44,11 +45,11 @@ def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do -- logInfo e'; logInfo h return (r, timeMs) -def ppExample (e : Expr) (info := false) : MetaM Unit := do +def ppExample (e : Expr) (arrowTelescope : Bool) (info := false) : MetaM Unit := do IO.println "Example:" IO.println (← ppExpr e) IO.println "====>" - match (← simp e).1 with + match (← simp e arrowTelescope).1 with | .rfl _ => IO.println "" | .step e' h _ => IO.println (← ppExpr e') @@ -59,8 +60,8 @@ def ppExample (e : Expr) (info := false) : MetaM Unit := do IO.println (← ppExpr h) IO.println "" -def benchSimp (name : String) (e : Expr) (check := false) : MetaM Unit := do - let (r, timeMs) ← simp e +def benchSimp (name : String) (e : Expr) (arrowTelescope : Bool) (check := false) : MetaM Unit := do + let (r, timeMs) ← simp e arrowTelescope let proofSize ← getProofSize r if check then let kMs ← checkWithKernel r @@ -89,9 +90,14 @@ def 
mkForallBench (n : Nat) (useImplies : Bool) : MetaM Expr := go n (← mkArrow (mkNatEq xs[n]! (mkNatAdd (mkNatLit 0) xs[n]!)) e) go n (mkConst ``True) -def benchForall (n : Nat) (useImplies : Bool) (check := false) : MetaM Unit := do - let e ← mkForallBench n useImplies - benchSimp s!"forall_{n}" e check +inductive Kind where + | implies + | arrowTelescope + | arrow + +def benchForall (n : Nat) (kind : Kind) (check := false) : MetaM Unit := do + let e ← mkForallBench n (kind matches .implies) + benchSimp s!"forall_{n}" e (kind matches .arrowTelescope) check set_option maxRecDepth 100000 @@ -102,17 +108,24 @@ def runAllBenchmarks : MetaM Unit := do IO.println "" IO.println "--- Benchmark 1: Forall Telescope block using arrows in the body ---" - ppExample (← mkForallBench 5 false) + ppExample (← mkForallBench 5 false) false for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do - benchForall n false (n < 500) + benchForall n .arrow (n < 500) IO.println "" - IO.println "--- Benchmark 1: Forall Telescope block using `implies` in the body ---" - ppExample (← mkForallBench 5 true) + IO.println "--- Benchmark 2: Forall Telescope block using arrow telescope in the body ---" + ppExample (← mkForallBench 5 false) true for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do - benchForall n true (n < 500) + benchForall n .arrowTelescope (n < 500) + + IO.println "" + IO.println "--- Benchmark 3: Forall Telescope block using `implies` in the body ---" + ppExample (← mkForallBench 5 true) false + + for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do + benchForall n .implies (n < 500) #eval runAllBenchmarks