feat: add simpArrowTelescope for compact proofs of arrow simplification (#12152)
This PR adds `simpArrowTelescope`, a simproc that simplifies telescopes of non-dependent arrows (p₁ → p₂ → ... → q) while avoiding quadratic proof growth. When using `Expr.forallE` to represent nested implications, each nesting level bumps de Bruijn indices in subterms, destroying sharing even with hash-consing. For example, a free variable `x` gets different de Bruijn representations at each depth, causing proof terms to grow. `simpArrowTelescope` works by: - Converting arrows to `Arrow p q` (a definitional wrapper) - Simplifying each component - Converting back to `→` form Since `Arrow` arguments are not under binders, subterms remain identical across nesting levels and can be shared. The `simp_4` benchmark demonstrates the improvement: With `forallE`: ~160ms, proof_size ≈ 173k With `Arrow`: ~43ms, proof_size ≈ 16k Tradeoff: `simpArrowTelescope` misses simplifications that depend on the arrow structure (e.g., `p → p` to `True`), since post-methods aren't applied to intermediate arrows. Thus, it is not used by default. to use it, one has to set `simpArrowTelescope` as a `pre`-method.
This commit is contained in:
parent
9e241a4087
commit
ba8c2ed4ee
3 changed files with 131 additions and 19 deletions
|
|
@ -44,6 +44,32 @@ theorem implies_congr_left {p₁ p₂ : Sort u} {q : Sort v} (h : p₁ = p₂) :
|
|||
theorem implies_congr_right {p : Sort u} {q₁ q₂ : Sort v} (h : q₁ = q₂) : (p → q₁) = (p → q₂) :=
|
||||
h ▸ rfl
|
||||
|
||||
namespace Lean
|
||||
/--
|
||||
`Arrow α β` is definitionally equal to `α → β`, but represented as a function
|
||||
application rather than `Expr.forallE`.
|
||||
|
||||
This representation is useful for proof automation that builds nested implications
|
||||
like `pₙ → ... → p₂ → p₁`. With `Expr.forallE`, each nesting level introduces a
|
||||
binder that bumps de Bruijn indices in subterms, destroying sharing even with
|
||||
hash-consing. For example, if `p₁` contains `#20`, then at depth 2 it becomes `#21`,
|
||||
at depth 3 it becomes `#22`, etc., causing quadratic proof growth.
|
||||
|
||||
With `arrow`, both arguments are explicit (not under binders), so subterms remain
|
||||
identical across nesting levels and can be shared, yielding linear-sized proofs.
|
||||
-/
|
||||
def Arrow (α : Sort u) (β : Sort v) : Sort (imax u v) := α → β
|
||||
|
||||
theorem arrow_congr {p₁ p₂ : Sort u} {q₁ q₂ : Sort v} (h₁ : p₁ = p₂) (h₂ : q₁ = q₂) : Arrow p₁ q₁ = Arrow p₂ q₂ :=
|
||||
h₁ ▸ h₂ ▸ rfl
|
||||
|
||||
theorem arrow_congr_left {p₁ p₂ : Sort u} {q : Sort v} (h : p₁ = p₂) : Arrow p₁ q = Arrow p₂ q :=
|
||||
h ▸ rfl
|
||||
|
||||
theorem arrow_congr_right {p : Sort u} {q₁ q₂ : Sort v} (h : q₁ = q₂) : Arrow p q₁ = Arrow p q₂ :=
|
||||
h ▸ rfl
|
||||
end Lean
|
||||
|
||||
theorem iff_congr {p₁ p₂ q₁ q₂ : Prop} (h₁ : p₁ ↔ p₂) (h₂ : q₁ ↔ q₂) : (p₁ ↔ q₁) ↔ (p₂ ↔ q₂) :=
|
||||
Iff.of_eq (propext h₁ ▸ propext h₂ ▸ rfl)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ module
|
|||
prelude
|
||||
public import Lean.Meta.Sym.Simp.SimpM
|
||||
import Lean.Meta.Sym.AlphaShareBuilder
|
||||
import Lean.Meta.Sym.InferType
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
|
||||
/--
|
||||
|
|
@ -25,7 +26,7 @@ The proof uses the approach used in `mkFunextFor` followed by an `Eq.ndrec`.
|
|||
def mkForallCongrFor (xs : Array Expr) : MetaM Expr := do
|
||||
let prop := mkSort 0
|
||||
let type ← mkForallFVars xs prop
|
||||
let w ← getLevel type
|
||||
let w ← Meta.getLevel type
|
||||
withLocalDeclD `p type fun p =>
|
||||
withLocalDeclD `q type fun q => do
|
||||
let eq := mkApp3 (mkConst ``Eq [1]) prop (mkAppN p xs) (mkAppN q xs)
|
||||
|
|
@ -53,6 +54,78 @@ def mkForallCongrFor (xs : Array Expr) : MetaM Expr := do
|
|||
|
||||
open Internal
|
||||
|
||||
structure ArrowInfo where
|
||||
binderName : Name
|
||||
binderInfo : BinderInfo
|
||||
|
||||
structure ToArrowResult where
|
||||
arrow : Expr
|
||||
infos : List ArrowInfo
|
||||
v : Level
|
||||
|
||||
def toArrow (e : Expr) : SymM ToArrowResult := do
|
||||
if let .forallE n α β bi := e then
|
||||
if !β.hasLooseBVars then
|
||||
let { arrow, infos, v } ← toArrow β
|
||||
let u ← getLevel α
|
||||
let arrow ← mkAppS₂ (← mkConstS ``Arrow [u, v]) α arrow
|
||||
let info := { binderName := n, binderInfo := bi }
|
||||
return { arrow, v := mkLevelIMax' u v, infos := info :: infos }
|
||||
return { arrow := e, infos := [], v := (← getLevel e) }
|
||||
|
||||
def toForall (e : Expr) (infos : List ArrowInfo) : SymM Expr := do
|
||||
let { binderName, binderInfo, .. } :: infos := infos | return e
|
||||
let_expr Arrow α β := e | return e
|
||||
mkForallS binderName binderInfo α (← toForall β infos)
|
||||
|
||||
partial def simpArrows (e : Expr) : SimpM Result := do
|
||||
let_expr f@Arrow p q := e | simp e
|
||||
match (← simp p), (← simpArrows q) with
|
||||
| .rfl _, .rfl _ => return .rfl
|
||||
| .step p' h _, .rfl _ =>
|
||||
let e' ← mkAppS₂ f p' q
|
||||
return .step e' <| mkApp4 (mkConst ``arrow_congr_left f.constLevels!) p p' q h
|
||||
| .rfl _, .step q' h _ =>
|
||||
let e' ← mkAppS₂ f p q'
|
||||
return .step e' <| mkApp4 (mkConst ``arrow_congr_right f.constLevels!) p q q' h
|
||||
| .step p' h₁ _, .step q' h₂ _ =>
|
||||
let e' ← mkAppS₂ f p' q'
|
||||
return .step e' <| mkApp6 (mkConst ``arrow_congr f.constLevels!) p p' q q' h₁ h₂
|
||||
|
||||
/--
|
||||
Simplifies a telescope of non-dependent arrows `p₁ → p₂ → ... → pₙ → q` by:
|
||||
1. Converting to `Arrow p₁ (Arrow p₂ (... (Arrow pₙ q)))` (see `toArrow`)
|
||||
2. Simplifying each `pᵢ` and `q` (see `simpArrows`)
|
||||
3. Converting back to `→` form (see `toForall`)
|
||||
|
||||
Using `Arrow` (a definitional wrapper around `→`) avoids the quadratic proof growth that
|
||||
occurs with `Expr.forallE`. With `forallE`, each nesting level bumps de Bruijn indices in
|
||||
subterms, destroying sharing. For example, if each `pᵢ` contains a free variable `x`, the
|
||||
de Bruijn representation of `x` differs at each depth, preventing hash-consing from
|
||||
recognizing them as identical.
|
||||
|
||||
With `Arrow`, both arguments are explicit (not under binders), so subterms remain identical
|
||||
across nesting levels and can be shared, yielding linear-sized proofs.
|
||||
|
||||
**Tradeoff**: This function simplifies each `pᵢ` and `q` individually, but misses
|
||||
simplifications that depend on the arrow structure itself. For example, `q → p → p`
|
||||
won't be simplified to `True` (when `p : Prop`) because the simplifier does not have
|
||||
a chance to apply `post` methods to the intermediate arrow `p → p`.
|
||||
|
||||
Thus, this is a simproc that is meant to be used as a pre-method and marks the
|
||||
result as fully simplified to prevent `simpArrow` from being applied.
|
||||
-/
|
||||
public def simpArrowTelescope : Simproc := fun e => do
|
||||
unless e.isArrow do return .rfl -- not applicable
|
||||
let { arrow, infos, v } ← toArrow e
|
||||
let .step arrow' h _ ← simpArrows arrow | return .rfl (done := true)
|
||||
let e' ← toForall arrow' infos
|
||||
let α := mkSort v
|
||||
let v1 := v.succ
|
||||
let h := mkApp6 (mkConst ``Eq.trans [v1]) α e arrow arrow' (mkApp2 (mkConst ``Eq.refl [v1]) α arrow) h
|
||||
let h := mkApp6 (mkConst ``Eq.trans [v1]) α e arrow' e' h (mkApp2 (mkConst ``Eq.refl [v1]) α e')
|
||||
return .step e' h (done := true)
|
||||
|
||||
public def simpArrow (e : Expr) : SimpM Result := do
|
||||
let p := e.bindingDomain!
|
||||
let q := e.bindingBody!
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import Lean
|
||||
open Lean Meta
|
||||
opaque f : Nat → Nat
|
||||
|
||||
namespace SimpBench
|
||||
/-!
|
||||
## `SymM` Simplifier benchmarks
|
||||
|
|
@ -24,15 +22,18 @@ def checkWithKernel (r : Sym.Simp.Result) : MetaM Float := do
|
|||
let endTime ← IO.monoNanosNow
|
||||
return (endTime - startTime).toFloat / 1000000.0
|
||||
|
||||
def mkSimpMethods : MetaM Sym.Simp.Methods := do
|
||||
def mkSimpMethods (arrowTelescope : Bool) : MetaM Sym.Simp.Methods := do
|
||||
let thms : Sym.Simp.Theorems := {}
|
||||
let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.zero_add)
|
||||
let thms := thms.insert (← Sym.Simp.mkTheoremFromDecl ``Nat.add_zero)
|
||||
return { post := thms.rewrite }
|
||||
if arrowTelescope then
|
||||
return { pre := Sym.Simp.simpArrowTelescope, post := thms.rewrite }
|
||||
else
|
||||
return { post := thms.rewrite }
|
||||
|
||||
def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
|
||||
def simp (e : Expr) (arrowTelescope : Bool) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
|
||||
let e ← Grind.shareCommon e
|
||||
let methods ← mkSimpMethods
|
||||
let methods ← mkSimpMethods arrowTelescope
|
||||
let startTime ← IO.monoNanosNow
|
||||
let r ← Sym.simp e methods { maxSteps := 100000000 }
|
||||
let endTime ← IO.monoNanosNow
|
||||
|
|
@ -44,11 +45,11 @@ def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run do
|
|||
-- logInfo e'; logInfo h
|
||||
return (r, timeMs)
|
||||
|
||||
def ppExample (e : Expr) (info := false) : MetaM Unit := do
|
||||
def ppExample (e : Expr) (arrowTelescope : Bool) (info := false) : MetaM Unit := do
|
||||
IO.println "Example:"
|
||||
IO.println (← ppExpr e)
|
||||
IO.println "====>"
|
||||
match (← simp e).1 with
|
||||
match (← simp e arrowTelescope).1 with
|
||||
| .rfl _ => IO.println "<no change>"
|
||||
| .step e' h _ =>
|
||||
IO.println (← ppExpr e')
|
||||
|
|
@ -59,8 +60,8 @@ def ppExample (e : Expr) (info := false) : MetaM Unit := do
|
|||
IO.println (← ppExpr h)
|
||||
IO.println ""
|
||||
|
||||
def benchSimp (name : String) (e : Expr) (check := false) : MetaM Unit := do
|
||||
let (r, timeMs) ← simp e
|
||||
def benchSimp (name : String) (e : Expr) (arrowTelescope : Bool) (check := false) : MetaM Unit := do
|
||||
let (r, timeMs) ← simp e arrowTelescope
|
||||
let proofSize ← getProofSize r
|
||||
if check then
|
||||
let kMs ← checkWithKernel r
|
||||
|
|
@ -89,9 +90,14 @@ def mkForallBench (n : Nat) (useImplies : Bool) : MetaM Expr :=
|
|||
go n (← mkArrow (mkNatEq xs[n]! (mkNatAdd (mkNatLit 0) xs[n]!)) e)
|
||||
go n (mkConst ``True)
|
||||
|
||||
def benchForall (n : Nat) (useImplies : Bool) (check := false) : MetaM Unit := do
|
||||
let e ← mkForallBench n useImplies
|
||||
benchSimp s!"forall_{n}" e check
|
||||
inductive Kind where
|
||||
| implies
|
||||
| arrowTelescope
|
||||
| arrow
|
||||
|
||||
def benchForall (n : Nat) (kind : Kind) (check := false) : MetaM Unit := do
|
||||
let e ← mkForallBench n (kind matches .implies)
|
||||
benchSimp s!"forall_{n}" e (kind matches .arrowTelescope) check
|
||||
|
||||
set_option maxRecDepth 100000
|
||||
|
||||
|
|
@ -102,17 +108,24 @@ def runAllBenchmarks : MetaM Unit := do
|
|||
|
||||
IO.println ""
|
||||
IO.println "--- Benchmark 1: Forall Telescope block using arrows in the body ---"
|
||||
ppExample (← mkForallBench 5 false)
|
||||
ppExample (← mkForallBench 5 false) false
|
||||
|
||||
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do
|
||||
benchForall n false (n < 500)
|
||||
benchForall n .arrow (n < 500)
|
||||
|
||||
IO.println ""
|
||||
IO.println "--- Benchmark 1: Forall Telescope block using `implies` in the body ---"
|
||||
ppExample (← mkForallBench 5 true)
|
||||
IO.println "--- Benchmark 2: Forall Telescope block using arrow telescope in the body ---"
|
||||
ppExample (← mkForallBench 5 false) true
|
||||
|
||||
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do
|
||||
benchForall n true (n < 500)
|
||||
benchForall n .arrowTelescope (n < 500)
|
||||
|
||||
IO.println ""
|
||||
IO.println "--- Benchmark 3: Forall Telescope block using `implies` in the body ---"
|
||||
ppExample (← mkForallBench 5 true) false
|
||||
|
||||
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 300, 400] do
|
||||
benchForall n .implies (n < 500)
|
||||
|
||||
#eval runAllBenchmarks
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue