feat: add simpArrowTelescope for compact proofs of arrow simplification (#12152)

This PR adds `simpArrowTelescope`, a simproc that simplifies telescopes
of non-dependent arrows (p₁ → p₂ → ... → q) while avoiding quadratic
proof growth.

When using `Expr.forallE` to represent nested implications, each nesting
level bumps de Bruijn indices in subterms, destroying sharing even with
hash-consing. For example, a free variable `x` gets different de Bruijn
representations at each depth, causing proof terms to grow.

`simpArrowTelescope` works by:

- Converting arrows to `Arrow p q` (a definitional wrapper)
- Simplifying each component
- Converting back to `→` form

Since `Arrow` arguments are not under binders, subterms remain identical
across nesting levels and can be shared.

The `simp_4` benchmark demonstrates the improvement:

With `forallE`: ~160ms, proof_size ≈ 173k
With `Arrow`: ~43ms, proof_size ≈ 16k
Tradeoff: `simpArrowTelescope` misses simplifications that depend on the
arrow structure (e.g., `p → p` to `True`), since post-methods aren't
applied to intermediate arrows. Thus, it is not used by default. To use
it, one has to set `simpArrowTelescope` as a `pre`-method.
This commit is contained in:
Leonardo de Moura 2026-01-25 12:43:59 -08:00 committed by GitHub
parent 9e241a4087
commit ba8c2ed4ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 131 additions and 19 deletions

View file

@ -44,6 +44,32 @@ theorem implies_congr_left {p₁ p₂ : Sort u} {q : Sort v} (h : p₁ = p₂) :
/-- Congruence in the codomain of a non-dependent arrow: `q₁ = q₂` yields `(p → q₁) = (p → q₂)`. -/
theorem implies_congr_right {p : Sort u} {q₁ q₂ : Sort v} (h : q₁ = q₂) : (p → q₁) = (p → q₂) :=
  congrArg (fun (q : Sort v) => p → q) h
namespace Lean

/--
`Arrow α β` unfolds to the function type `α → β`, but as an expression it is an ordinary
application of the constant `Arrow` to two arguments rather than an `Expr.forallE` node.

Proof automation that builds nested implications such as `pₙ → ... → p₂ → p₁` benefits from
this representation. With `Expr.forallE`, every nesting level introduces a binder, and each
binder bumps the de Bruijn indices of the subterms below it. Sharing is lost even with
hash-consing: if `p₁` contains `#20`, it becomes `#21` at depth 2, `#22` at depth 3, and so
on, which makes proofs grow quadratically.

With `Arrow`, both arguments are explicit (not under binders), so subterms are identical at
every nesting level and can be shared, yielding linear-sized proofs.
-/
def Arrow (α : Sort u) (β : Sort v) : Sort (imax u v) := α → β

/-- Congruence for `Arrow` in both arguments. -/
theorem arrow_congr {p₁ p₂ : Sort u} {q₁ q₂ : Sort v} (h₁ : p₁ = p₂) (h₂ : q₁ = q₂) : Arrow p₁ q₁ = Arrow p₂ q₂ :=
  congr (congrArg Arrow h₁) h₂

/-- Congruence for `Arrow` in its first argument only. -/
theorem arrow_congr_left {p₁ p₂ : Sort u} {q : Sort v} (h : p₁ = p₂) : Arrow p₁ q = Arrow p₂ q :=
  congrArg (fun (p : Sort u) => Arrow p q) h

/-- Congruence for `Arrow` in its second argument only. -/
theorem arrow_congr_right {p : Sort u} {q₁ q₂ : Sort v} (h : q₁ = q₂) : Arrow p q₁ = Arrow p q₂ :=
  congrArg (fun (q : Sort v) => Arrow p q) h

end Lean
/-- Congruence for `↔`: equivalent sides may be replaced on both sides of an `↔`. -/
theorem iff_congr {p₁ p₂ q₁ q₂ : Prop} (h₁ : p₁ ↔ p₂) (h₂ : q₁ ↔ q₂) : (p₁ ↔ q₁) ↔ (p₂ ↔ q₂) :=
  Iff.intro
    (fun h => (h₁.symm.trans h).trans h₂)
    (fun h => (h₁.trans h).trans h₂.symm)

View file

@ -7,6 +7,7 @@ module
prelude
public import Lean.Meta.Sym.Simp.SimpM
import Lean.Meta.Sym.AlphaShareBuilder
import Lean.Meta.Sym.InferType
namespace Lean.Meta.Sym.Simp
/--
@ -25,7 +26,7 @@ The proof uses the approach used in `mkFunextFor` followed by an `Eq.ndrec`.
def mkForallCongrFor (xs : Array Expr) : MetaM Expr := do
let prop := mkSort 0
let type ← mkForallFVars xs prop
let w ← getLevel type
let w ← Meta.getLevel type
withLocalDeclD `p type fun p =>
withLocalDeclD `q type fun q => do
let eq := mkApp3 (mkConst ``Eq [1]) prop (mkAppN p xs) (mkAppN q xs)
@ -53,6 +54,78 @@ def mkForallCongrFor (xs : Array Expr) : MetaM Expr := do
open Internal
/--
Binder metadata of one `Expr.forallE` node that `toArrow` replaced by an `Arrow` application.
`toForall` uses it to rebuild the binder with the same name and binder info.
-/
structure ArrowInfo where
/-- Binder name of the original `forallE`. -/
binderName : Name
/-- Binder info of the original `forallE`. -/
binderInfo : BinderInfo
/-- Result of `toArrow`. -/
structure ToArrowResult where
/-- The converted expression; it is the input itself when no arrow was converted. -/
arrow : Expr
/-- One entry per converted arrow, outermost first. -/
infos : List ArrowInfo
/--
Universe level computed by `toArrow` for `arrow`: `getLevel` of the unconverted tail, combined
with `mkLevelIMax'` at each converted arrow. `simpArrowTelescope` uses `mkSort v` as the type
of the equalities it builds.
-/
v : Level
/--
Converts the leading non-dependent arrows of `e` into nested `Arrow` applications.

Conversion stops at the first expression that is not a `forallE`, or whose body has loose
bound variables (a dependent arrow); that tail is kept unchanged. The binder name and binder
info of every converted arrow are recorded, outermost first, so that `toForall` can restore
them.
-/
def toArrow (e : Expr) : SymM ToArrowResult := do
  match e with
  | .forallE name dom body bi =>
    if body.hasLooseBVars then
      -- Dependent arrow: leave it as is.
      return { arrow := e, infos := [], v := (← getLevel e) }
    else
      -- Convert the codomain first, then wrap it together with the domain.
      let tail ← toArrow body
      let u ← getLevel dom
      let arrowConst ← mkConstS ``Arrow [u, tail.v]
      let arrow ← mkAppS₂ arrowConst dom tail.arrow
      let info : ArrowInfo := { binderName := name, binderInfo := bi }
      return { arrow, infos := info :: tail.infos, v := mkLevelIMax' u tail.v }
  | _ =>
    return { arrow := e, infos := [], v := (← getLevel e) }
/--
Inverse of `toArrow`: turns nested `Arrow` applications back into `forallE` form, consuming one
`ArrowInfo` per level to restore the binder name and binder info.

Returns `e` unchanged as soon as `infos` is exhausted or `e` is not an `Arrow` application.
-/
def toForall (e : Expr) (infos : List ArrowInfo) : SymM Expr := do
  match infos with
  | [] => return e
  | info :: rest =>
    let_expr Arrow dom cod := e | return e
    let body ← toForall cod rest
    mkForallS info.binderName info.binderInfo dom body
/--
Simplifies a nested `Arrow` application `Arrow p₁ (Arrow p₂ (... q))`.

Every domain is simplified with `simp` and the codomain is processed recursively; an expression
that is not an `Arrow` application is handed to `simp` directly. Depending on which side
changed, the proofs are combined with `arrow_congr_left`, `arrow_congr_right`, or
`arrow_congr`. Returns `.rfl` when neither side changed.
-/
partial def simpArrows (e : Expr) : SimpM Result := do
  let_expr arrowFn@Arrow dom cod := e | simp e
  let domResult ← simp dom
  let codResult ← simpArrows cod
  -- The congruence lemmas take the same universe levels as the `Arrow` constant itself.
  let congrThm (declName : Name) : Expr := mkConst declName arrowFn.constLevels!
  match domResult, codResult with
  | .rfl _, .rfl _ =>
    return .rfl
  | .step dom' hDom _, .rfl _ =>
    let e' ← mkAppS₂ arrowFn dom' cod
    let proof := mkApp4 (congrThm ``arrow_congr_left) dom dom' cod hDom
    return .step e' proof
  | .rfl _, .step cod' hCod _ =>
    let e' ← mkAppS₂ arrowFn dom cod'
    let proof := mkApp4 (congrThm ``arrow_congr_right) dom cod cod' hCod
    return .step e' proof
  | .step dom' hDom _, .step cod' hCod _ =>
    let e' ← mkAppS₂ arrowFn dom' cod'
    let proof := mkApp6 (congrThm ``arrow_congr) dom dom' cod cod' hDom hCod
    return .step e' proof
/--
Simproc for telescopes of non-dependent arrows `p₁ → p₂ → ... → pₙ → q`. It proceeds in three
steps:
1. Convert the telescope to `Arrow p₁ (Arrow p₂ (... (Arrow pₙ q)))` (see `toArrow`).
2. Simplify each `pᵢ` and `q` (see `simpArrows`).
3. Convert the result back to `→` form (see `toForall`).

`Arrow` is a definitional wrapper around `→`. Using it avoids the quadratic proof growth that
`Expr.forallE` causes: with `forallE`, each nesting level bumps the de Bruijn indices of the
subterms below it, so a free variable `x` occurring in every `pᵢ` has a different
representation at each depth, and hash-consing cannot identify the occurrences. The arguments
of `Arrow` are explicit (not under binders), so subterms stay identical across nesting levels
and can be shared, yielding linear-sized proofs.

**Tradeoff**: each `pᵢ` and `q` is simplified individually, so simplifications that depend on
the arrow structure itself are missed. For example, `q → p → p` is not simplified to `True`
(when `p : Prop`), because the simplifier never gets to apply `post` methods to the
intermediate arrow `p → p`.

This simproc is therefore meant to be used as a pre-method. It marks its result as fully
simplified to prevent `simpArrow` from being applied.
-/
public def simpArrowTelescope : Simproc := fun e => do
  if !e.isArrow then
    return .rfl -- not applicable
  let { arrow, infos, v } ← toArrow e
  match (← simpArrows arrow) with
  | .step arrow' hArrow _ =>
    let e' ← toForall arrow' infos
    -- All equalities below are between elements of `Sort v`, hence at universe level `v + 1`.
    let sort := mkSort v
    let eqLevels := [v.succ]
    let eqRefl (a : Expr) : Expr :=
      mkApp2 (mkConst ``Eq.refl eqLevels) sort a
    let eqTrans (a b c hab hbc : Expr) : Expr :=
      mkApp6 (mkConst ``Eq.trans eqLevels) sort a b c hab hbc
    -- `Eq.refl arrow` is used as the proof of `e = arrow`, then chained with `arrow = arrow'`.
    let hToArrow' := eqTrans e arrow arrow' (eqRefl arrow) hArrow
    -- `Eq.refl e'` is used as the proof of `arrow' = e'`.
    let hFinal := eqTrans e arrow' e' hToArrow' (eqRefl e')
    return .step e' hFinal (done := true)
  | _ =>
    return .rfl (done := true)
public def simpArrow (e : Expr) : SimpM Result := do
let p := e.bindingDomain!
let q := e.bindingBody!

View file

@ -1,7 +1,5 @@
import Lean
open Lean Meta
opaque f : Nat → Nat
namespace SimpBench
/-!
## `SymM` Simplifier benchmarks
@ -24,15 +22,18 @@ def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do
let endTime ← IO.monoNanosNow
return (endTime - startTime).toFloat / 1000000.0
def mkSimpMethods : MetaM Sym.Simp.Methods := do
def mkSimpMethods (arrowTelescope : Bool) : MetaM Sym.Simp.Methods := do
let thms : Sym.Simp.Theorems := {}
let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add)
let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.add_zero)
return { post := thms.rewrite }
if arrowTelescope then
return { pre := Sym.Simp.simpArrowTelescope, post := thms.rewrite }
else
return { post := thms.rewrite }
def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
def simp (e : Expr) (arrowTelescope : Bool) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
let e ← Grind.shareCommon e
let methods ← mkSimpMethods
let methods ← mkSimpMethods arrowTelescope
let startTime ← IO.monoNanosNow
let r ← Sym.simp e methods { maxSteps := 100000000 }
let endTime ← IO.monoNanosNow
@ -44,11 +45,11 @@ def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
-- logInfo e'; logInfo h
return (r, timeMs)
def ppExample (e : Expr) (info := false) : MetaM Unit := do
def ppExample (e : Expr) (arrowTelescope : Bool) (info := false) : MetaM Unit := do
IO.println "Example:"
IO.println (← ppExpr e)
IO.println "====>"
match (← simp e).1 with
match (← simp e arrowTelescope).1 with
| .rfl _ => IO.println "<no change>"
| .step e' h _ =>
IO.println (← ppExpr e')
@ -59,8 +60,8 @@ def ppExample (e : Expr) (info := false) : MetaM Unit := do
IO.println (← ppExpr h)
IO.println ""
def benchSimp (name : String) (e : Expr) (check := false) : MetaM Unit := do
let (r, timeMs) ← simp e
def benchSimp (name : String) (e : Expr) (arrowTelescope : Bool) (check := false) : MetaM Unit := do
let (r, timeMs) ← simp e arrowTelescope
let proofSize ← getProofSize r
if check then
let kMs ← checkWithKernel r
@ -89,9 +90,14 @@ def mkForallBench (n : Nat) (useImplies : Bool) : MetaM Expr :=
go n (← mkArrow (mkNatEq xs[n]! (mkNatAdd (mkNatLit 0) xs[n]!)) e)
go n (mkConst ``True)
def benchForall (n : Nat) (useImplies : Bool) (check := false) : MetaM Unit := do
let e ← mkForallBench n useImplies
benchSimp s!"forall_{n}" e check
/-- Selects which variant of the forall-telescope benchmark `benchForall` runs. -/
inductive Kind where
/-- Build the benchmark term with `useImplies := true`; the arrow-telescope flag stays off. -/
| implies
/-- Build the benchmark term with `useImplies := false` and turn the arrow-telescope flag on. -/
| arrowTelescope
/-- Build the benchmark term with `useImplies := false`; the arrow-telescope flag stays off. -/
| arrow
/--
Runs the forall-telescope benchmark of size `n` for the given `kind`.

`kind` decides both how the benchmark term is built (`useImplies` for `.implies`) and whether
the simplifier is run with the arrow-telescope flag on (`.arrowTelescope`). `check` is
forwarded to `benchSimp`.
-/
def benchForall (n : Nat) (kind : Kind) (check := false) : MetaM Unit := do
  let useImplies := kind matches .implies
  let useTelescope := kind matches .arrowTelescope
  let goal ← mkForallBench n useImplies
  benchSimp s!"forall_{n}" goal useTelescope check
set_option maxRecDepth 100000
@ -102,17 +108,24 @@ def runAllBenchmarks : MetaM Unit := do
IO.println ""
IO.println "--- Benchmark 1: Forall Telescope block using arrows in the body ---"
ppExample (← mkForallBench 5 false)
ppExample (← mkForallBench 5 false) false
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do
benchForall n false (n < 500)
benchForall n .arrow (n < 500)
IO.println ""
IO.println "--- Benchmark 1: Forall Telescope block using `implies` in the body ---"
ppExample (← mkForallBench 5 true)
IO.println "--- Benchmark 2: Forall Telescope block using arrow telescope in the body ---"
ppExample (← mkForallBench 5 false) true
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do
benchForall n true (n < 500)
benchForall n .arrowTelescope (n < 500)
IO.println ""
IO.println "--- Benchmark 3: Forall Telescope block using `implies` in the body ---"
ppExample (← mkForallBench 5 true) false
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do
benchForall n .implies (n < 500)
#eval runAllBenchmarks