From 970261b1e108bdb1e4536b83d6600886511dc8a6 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Wed, 13 Nov 2024 16:54:29 +0100 Subject: [PATCH] perf: optimize Nat.Linear.Expr.toPoly (#6062) --- src/Init/Data/Nat/Linear.lean | 44 ++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/Init/Data/Nat/Linear.lean b/src/Init/Data/Nat/Linear.lean index 2a77c7ab8e..69d8bd854c 100644 --- a/src/Init/Data/Nat/Linear.lean +++ b/src/Init/Data/Nat/Linear.lean @@ -146,12 +146,16 @@ def Poly.combineAux (fuel : Nat) (p₁ p₂ : Poly) : Poly := def Poly.combine (p₁ p₂ : Poly) : Poly := combineAux hugeFuel p₁ p₂ -def Expr.toPoly : Expr → Poly - | Expr.num k => bif k == 0 then [] else [ (k, fixedVar) ] - | Expr.var i => [(1, i)] - | Expr.add a b => a.toPoly ++ b.toPoly - | Expr.mulL k a => a.toPoly.mul k - | Expr.mulR a k => a.toPoly.mul k +def Expr.toPoly (e : Expr) := go 1 e [] +where + -- Implementation note: This assembles the result using difference lists + -- to avoid `++` on lists. + go (coeff : Nat) : Expr → (Poly → Poly) + | Expr.num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·) + | Expr.var i => ((coeff, i) :: ·) + | Expr.add a b => go coeff a ∘ go coeff b + | Expr.mulL k a + | Expr.mulR a k => bif k == 0 then id else go (coeff * k) a def Poly.norm (p : Poly) : Poly := p.sort.fuse @@ -516,13 +520,25 @@ theorem Poly.denote_combine (ctx : Context) (p₁ p₂ : Poly) : (p₁.combine p attribute [local simp] Poly.denote_combine +theorem Expr.denote_toPoly_go (ctx : Context) (e : Expr) : + (toPoly.go k e p).denote ctx = k * e.denote ctx + p.denote ctx := by + induction k, e using Expr.toPoly.go.induct generalizing p with + | case1 k k' => + simp only [toPoly.go] + by_cases h : k' == 0 + · simp [h, eq_of_beq h] + · simp [h, Var.denote] + | case2 k i => simp [toPoly.go] + | case3 k a b iha ihb => simp [toPoly.go, iha, ihb] + | case4 k k' a ih + | case5 k a k' ih => + simp only [toPoly.go, denote, mul_eq] + by_cases h : k' == 0 + · simp [h, eq_of_beq h] + · simp [h, cond_false, ih, Nat.mul_assoc] + theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.denote ctx := by - induction e with - | num k => by_cases h : k == 0 <;> simp [toPoly, h, Var.denote]; simp [eq_of_beq h] - | var i => simp [toPoly] - | add a b iha ihb => simp [toPoly, iha, ihb] - | mulL k a ih => simp [toPoly, ih, -Poly.mul] - | mulR k a ih => simp [toPoly, ih, -Poly.mul] + simp [toPoly, Expr.denote_toPoly_go] attribute [local simp] Expr.denote_toPoly @@ -554,8 +570,8 @@ theorem ExprCnstr.denote_toPoly (ctx : Context) (c : ExprCnstr) : c.toPoly.denot cases c; rename_i eq lhs rhs simp [ExprCnstr.denote, PolyCnstr.denote, ExprCnstr.toPoly]; by_cases h : eq = true <;> simp [h] - · simp [Poly.denote_eq, Expr.toPoly] - · simp [Poly.denote_le, Expr.toPoly] + · simp [Poly.denote_eq] + · simp [Poly.denote_le] attribute [local simp] ExprCnstr.denote_toPoly