From d4c832ecb01bfb22117dcc79875796126a0c06d7 Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Fri, 5 Dec 2025 17:16:31 +0100 Subject: [PATCH] perf: de-fuel some recursive definitions in Core (#11416) This PR follows up on #7965 and avoids manual fuel constructions in some recursive definitions. --- src/Init/Data/String/Basic.lean | 142 +++++++++-------------- src/Init/Data/String/Decode.lean | 3 + src/Init/Grind/Module/NatModuleNorm.lean | 18 +-- src/Init/Grind/Ordered/Linarith.lean | 25 ++-- src/Init/Grind/Ring/CommSolver.lean | 43 +++---- 5 files changed, 93 insertions(+), 138 deletions(-) diff --git a/src/Init/Data/String/Basic.lean b/src/Init/Data/String/Basic.lean index f2dca5c35c..97b8ce2629 100644 --- a/src/Init/Data/String/Basic.lean +++ b/src/Init/Data/String/Basic.lean @@ -64,103 +64,73 @@ not a sequence of Unicode scalar values. -/ @[inline, expose] def ByteArray.utf8Decode? (b : ByteArray) : Option (Array Char) := - go (b.size + 1) 0 #[] (by simp) (by simp) + go 0 #[] (by simp) where - go (fuel : Nat) (i : Nat) (acc : Array Char) (hi : i ≤ b.size) (hf : b.size - i < fuel) : Option (Array Char) := - match fuel, hf with - | fuel + 1, _ => - if i = b.size then - some acc - else - match h : utf8DecodeChar? b i with - | none => none - | some c => go fuel (i + c.utf8Size) (acc.push c) - (le_size_of_utf8DecodeChar?_eq_some h) - (have := c.utf8Size_pos; have := le_size_of_utf8DecodeChar?_eq_some h; by omega) - termination_by structural fuel + @[semireducible] + go (i : Nat) (acc : Array Char) (hi : i ≤ b.size) : Option (Array Char) := + if i < b.size then + match h : utf8DecodeChar? b i with + | none => none + | some c => go (i + c.utf8Size) (acc.push c) (le_size_of_utf8DecodeChar?_eq_some h) + else + some acc + termination_by b.size - i + decreasing_by have := c.utf8Size_pos; omega @[expose, extern "lean_string_validate_utf8"] def ByteArray.validateUTF8 (b : @& ByteArray) : Bool := - go (b.size + 1) 0 (by simp) (by simp) + go 0 (by simp) where - go (fuel : Nat) (i : Nat) (hi : i ≤ b.size) (hf : b.size - i < fuel) : Bool := - match fuel, hf with - | fuel + 1, _ => - if hi : i = b.size then - true - else - match h : validateUTF8At b i with - | false => false - | true => go fuel (i + b[i].utf8ByteSize (isUTF8FirstByte_of_validateUTF8At h)) - ?_ ?_ - termination_by structural fuel + @[semireducible] + go (i : Nat) (hi : i ≤ b.size) : Bool := + if hi : i < b.size then + match h : validateUTF8At b i with + | false => false + | true => go (i + b[i].utf8ByteSize (isUTF8FirstByte_of_validateUTF8At h)) ?_ + else + true + termination_by b.size - i + decreasing_by + have := b[i].utf8ByteSize_pos (isUTF8FirstByte_of_validateUTF8At h); omega finally all_goals rw [ByteArray.validateUTF8At_eq_isSome_utf8DecodeChar?] at h · rw [← ByteArray.utf8Size_utf8DecodeChar (h := h)] exact add_utf8Size_utf8DecodeChar_le_size - · rw [← ByteArray.utf8Size_utf8DecodeChar (h := h)] - have := add_utf8Size_utf8DecodeChar_le_size (h := h) - have := (b.utf8DecodeChar i h).utf8Size_pos - omega -theorem ByteArray.isSome_utf8Decode?Go_eq_validateUTF8Go {b : ByteArray} {fuel : Nat} - {i : Nat} {acc : Array Char} {hi : i ≤ b.size} {hf : b.size - i < fuel} : - (utf8Decode?.go b fuel i acc hi hf).isSome = validateUTF8.go b fuel i hi hf := by +theorem ByteArray.isSome_utf8Decode?Go_eq_validateUTF8Go {b : ByteArray} + {i : Nat} {acc : Array Char} {hi : i ≤ b.size} : + (utf8Decode?.go b i acc hi).isSome = validateUTF8.go b i hi := by fun_induction utf8Decode?.go with - | case1 => simp [validateUTF8.go] - | case2 i acc hi fuel hf h₁ h₂ => - simp only [Option.isSome_none, validateUTF8.go, h₁, ↓reduceDIte, Bool.false_eq] + | case1 i acc hi h₁ h₂ => + unfold validateUTF8.go + simp only [Option.isSome_none, ↓reduceDIte, Bool.false_eq, h₁] split · rfl · rename_i heq simp [validateUTF8At_eq_isSome_utf8DecodeChar?, h₂] at heq - | case3 i acc hi fuel hf h₁ c h₂ ih => - simp [validateUTF8.go, h₁] + | case2 i acc hi h₁ c h₂ ih => + unfold validateUTF8.go + simp only [↓reduceDIte, ih, h₁] split · rename_i heq simp [validateUTF8At_eq_isSome_utf8DecodeChar?, h₂] at heq - · rw [ih] - congr + · congr rw [← ByteArray.utf8Size_utf8DecodeChar (h := by simp [h₂])] simp [utf8DecodeChar, h₂] + | case3 => unfold validateUTF8.go; simp [*] theorem ByteArray.isSome_utf8Decode?_eq_validateUTF8 {b : ByteArray} : b.utf8Decode?.isSome = b.validateUTF8 := b.isSome_utf8Decode?Go_eq_validateUTF8Go -theorem ByteArray.utf8Decode?.go.congr {b b' : ByteArray} {fuel fuel' i i' : Nat} {acc acc' : Array Char} {hi hi' hf hf'} - (hbb' : b = b') (hii' : i = i') (hacc : acc = acc') : - ByteArray.utf8Decode?.go b fuel i acc hi hf = ByteArray.utf8Decode?.go b' fuel' i' acc' hi' hf' := by - subst hbb' hii' hacc - fun_induction ByteArray.utf8Decode?.go b fuel i acc hi hf generalizing fuel' with - | case1 => - rw [go.eq_def] - split - simp - | case2 => - rw [go.eq_def] - split <;> split - · simp_all - · split <;> simp_all - | case3 => - conv => rhs; rw [go.eq_def] - split <;> split - · simp_all - · split - · simp_all - · rename_i c₁ hc₁ ih _ _ _ _ _ c₂ hc₂ - obtain rfl : c₁ = c₂ := by rw [← Option.some_inj, ← hc₁, ← hc₂] - apply ih - @[simp] theorem ByteArray.utf8Decode?_empty : ByteArray.empty.utf8Decode? = some #[] := by simp [utf8Decode?, utf8Decode?.go] -private theorem ByteArray.isSome_utf8Decode?go_iff {b : ByteArray} {fuel i : Nat} {hi : i ≤ b.size} {hf} {acc : Array Char} : - (ByteArray.utf8Decode?.go b fuel i acc hi hf).isSome ↔ IsValidUTF8 (b.extract i b.size) := by +private theorem ByteArray.isSome_utf8Decode?go_iff {b : ByteArray} {hi : i ≤ b.size} {acc : Array Char} : + (ByteArray.utf8Decode?.go b i acc hi).isSome ↔ IsValidUTF8 (b.extract i b.size) := by fun_induction ByteArray.utf8Decode?.go with - | case1 => simp - | case2 fuel i hi hf acc h₁ h₂ => + | case1 i hi acc h₁ h₂ => simp only [Option.isSome_none, Bool.false_eq_true, false_iff] rintro ⟨l, hl⟩ have : l ≠ [] := by @@ -170,7 +140,7 @@ private theorem ByteArray.isSome_utf8Decode?go_iff {b : ByteArray} {fuel i : Nat rw [← l.cons_head_tail this] at hl rw [utf8DecodeChar?_eq_utf8DecodeChar?_extract, hl, List.utf8DecodeChar?_utf8Encode_cons] at h₂ simp at h₂ - | case3 i acc hi fuel hf h₁ c h₂ ih => + | case2 i acc hi h₁ c h₂ ih => rw [ih] have h₂' := h₂ rw [utf8DecodeChar?_eq_utf8DecodeChar?_extract] at h₂' @@ -179,6 +149,9 @@ private theorem ByteArray.isSome_utf8Decode?go_iff {b : ByteArray} {fuel i : Nat (le_size_of_utf8DecodeChar?_eq_some h₂)] at hl ⊢ rw [ByteArray.append_inj_left hl (by have := le_size_of_utf8DecodeChar?_eq_some h₂; simp; omega), ← List.utf8Encode_singleton, isValidUTF8_utf8Encode_singleton_append_iff] + | case3 i => + have : i = b.size := by omega + simp [*] theorem ByteArray.isSome_utf8Decode?_iff {b : ByteArray} : b.utf8Decode?.isSome ↔ IsValidUTF8 b := by @@ -305,27 +278,21 @@ theorem String.length_toList {s : String} : s.toList.length = s.length := (rfl) @[deprecated String.length_toList (since := "2025-10-30")] theorem String.length_data {b : String} : b.toList.length = b.length := (rfl) -private theorem ByteArray.utf8Decode?go_eq_utf8Decode?go_extract {b : ByteArray} {fuel i : Nat} {hi : i ≤ b.size} {hf} {acc : Array Char} : - utf8Decode?.go b fuel i acc hi hf = (utf8Decode?.go (b.extract i b.size) fuel 0 #[] (by simp) (by simp [hf])).map (acc ++ ·) := by - fun_cases utf8Decode?.go b fuel i acc hi hf with - | case1 => - rw [utf8Decode?.go] - simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Nat.zero_add, List.push_toArray, - List.nil_append] - rw [if_pos (by omega)] - simp - | case2 fuel hf₁ h₁ h₂ hf₂ => +private theorem ByteArray.utf8Decode?go_eq_utf8Decode?go_extract {b : ByteArray} {hi : i ≤ b.size} {acc : Array Char} : + utf8Decode?.go b i acc hi = (utf8Decode?.go (b.extract i b.size) 0 #[] (by simp)).map (acc ++ ·) := by + fun_cases utf8Decode?.go b i acc hi with + | case1 h₁ h₂ => rw [utf8Decode?.go] simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Nat.zero_add, List.push_toArray, List.nil_append] - rw [if_neg (by omega)] + rw [if_pos (by omega)] rw [utf8DecodeChar?_eq_utf8DecodeChar?_extract] at h₂ split <;> simp_all - | case3 fuel hf₁ h₁ c h₂ hf₂ => + | case2 h₁ c h₂ => conv => rhs; rw [utf8Decode?.go] simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Nat.zero_add, List.push_toArray, List.nil_append] - rw [if_neg (by omega)] + rw [if_pos (by omega)] rw [utf8DecodeChar?_eq_utf8DecodeChar?_extract] at h₂ split · simp_all @@ -338,20 +305,25 @@ private theorem ByteArray.utf8Decode?go_eq_utf8Decode?go_extract {b : ByteArray} simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Option.map_map, ByteArray.extract_extract] have : (fun x => acc ++ x) ∘ (fun x => #[c] ++ x) = fun x => acc.push c ++ x := by funext; simp simp [(by omega : i + (b.size - i) = b.size), this] -termination_by fuel + | case3 => + rw [utf8Decode?.go] + simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Nat.zero_add, List.push_toArray, + List.nil_append] + rw [if_neg (by omega)] + simp +termination_by b.size - i theorem ByteArray.utf8Decode?_utf8Encode_singleton_append {l : ByteArray} {c : Char} : ([c].utf8Encode ++ l).utf8Decode? = l.utf8Decode?.map (#[c] ++ ·) := by rw [utf8Decode?, utf8Decode?.go, - if_neg (by simp [List.utf8Encode_singleton]; have := c.utf8Size_pos; omega)] + if_pos (by simp [List.utf8Encode_singleton]; have := c.utf8Size_pos; omega)] split · simp_all [List.utf8DecodeChar?_utf8Encode_singleton_append] · rename_i d h obtain rfl : c = d := by simpa [List.utf8DecodeChar?_utf8Encode_singleton_append] using h rw [utf8Decode?go_eq_utf8Decode?go_extract, utf8Decode?] simp only [List.push_toArray, List.nil_append, Nat.zero_add] - congr 1 - apply ByteArray.utf8Decode?.go.congr _ rfl rfl + congr 2 apply extract_append_eq_right _ (by simp) simp [List.utf8Encode_singleton] diff --git a/src/Init/Data/String/Decode.lean b/src/Init/Data/String/Decode.lean index 2db172c61d..98dfa5aa6b 100644 --- a/src/Init/Data/String/Decode.lean +++ b/src/Init/Data/String/Decode.lean @@ -1441,6 +1441,9 @@ public def utf8ByteSize (c : UInt8) (_h : c.IsUTF8FirstByte) : Nat := else 4 +public theorem utf8ByteSize_pos (c : UInt8) (h : c.IsUTF8FirstByte) : 0 < c.utf8ByteSize h := by + fun_cases utf8ByteSize <;> simp + def _root_.ByteArray.utf8DecodeChar?.FirstByte.utf8ByteSize : FirstByte → Nat | .invalid => 0 | .done => 1 diff --git a/src/Init/Grind/Module/NatModuleNorm.lean b/src/Init/Grind/Module/NatModuleNorm.lean index c7fbdd4585..6a1bbb93ff 100644 --- a/src/Init/Grind/Module/NatModuleNorm.lean +++ b/src/Init/Grind/Module/NatModuleNorm.lean @@ -79,9 +79,9 @@ theorem Poly.denoteN_append {α} [NatModule α] (ctx : Context α) (p₁ p₂ : attribute [local simp] Poly.denoteN_append -theorem Poly.denoteN_combine' {α} [NatModule α] (ctx : Context α) (fuel : Nat) (p₁ p₂ : Poly) - : p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine' fuel p₂).denoteN ctx = p₁.denoteN ctx + p₂.denoteN ctx := by - fun_induction p₁.combine' fuel p₂ <;> intro h₁ h₂ <;> try simp [*, zero_add, add_zero] +theorem Poly.denoteN_combine {α} [NatModule α] (ctx : Context α) (p₁ p₂ : Poly) + : p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine p₂).denoteN ctx = p₁.denoteN ctx + p₂.denoteN ctx := by + fun_induction p₁.combine p₂ <;> intro h₁ h₂ <;> try simp [*, zero_add, add_zero] next hx _ h ih => simp at hx simp +zetaDelta at h @@ -103,10 +103,6 @@ theorem Poly.denoteN_combine' {α} [NatModule α] (ctx : Context α) (fuel : Nat cases h₂; next h₂ => simp [ih h₁ h₂, *]; ac_rfl -theorem Poly.denoteN_combine {α} [NatModule α] (ctx : Context α) (p₁ p₂ : Poly) - : p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine p₂).denoteN ctx = p₁.denoteN ctx + p₂.denoteN ctx := by - intros; simp [combine, denoteN_combine', *] - theorem Poly.denoteN_mul' {α} [NatModule α] (ctx : Context α) (p : Poly) (k : Nat) : p.NonnegCoeffs → (p.mul' k).denoteN ctx = k • p.denoteN ctx := by induction p <;> simp [mul', *, nsmul_zero] next ih => @@ -151,9 +147,8 @@ theorem Poly.append_Nonneg (p₁ p₂ : Poly) : p₁.NonnegCoeffs → p₂.Nonne fun_induction append <;> intro h₁ h₂; simp [*] next ih => cases h₁; constructor; assumption; apply ih <;> assumption -theorem Poly.combine'_Nonneg (fuel : Nat) (p₁ p₂ : Poly) : p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine' fuel p₂).NonnegCoeffs := by - fun_induction Poly.combine' - next => apply Poly.append_Nonneg +theorem Poly.combine_Nonneg (p₁ p₂ : Poly) : p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine p₂).NonnegCoeffs := by + fun_induction Poly.combine next => intros; assumption next => intros; assumption next ih => @@ -172,9 +167,6 @@ theorem Poly.combine'_Nonneg (fuel : Nat) (p₁ p₂ : Poly) : p₁.NonnegCoeffs constructor; assumption apply ih; constructor; assumption; assumption; assumption -theorem Poly.combine_Nonneg (p₁ p₂ : Poly) : p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine p₂).NonnegCoeffs := by - simp [combine]; apply Poly.combine'_Nonneg - theorem Expr.toPolyN_Nonneg (e : Expr) : e.toPolyN.NonnegCoeffs := by fun_induction toPolyN <;> try constructor <;> simp next => constructor; simp; constructor diff --git a/src/Init/Grind/Ordered/Linarith.lean b/src/Init/Grind/Ordered/Linarith.lean index 65acedb558..450d1b0a31 100644 --- a/src/Init/Grind/Ordered/Linarith.lean +++ b/src/Init/Grind/Ordered/Linarith.lean @@ -123,26 +123,22 @@ def Poly.append (p₁ p₂ : Poly) : Poly := | .nil => p₂ | .add k x p₁ => .add k x (append p₁ p₂) -def Poly.combine' (fuel : Nat) (p₁ p₂ : Poly) : Poly := - match fuel with - | 0 => p₁.append p₂ - | fuel + 1 => match p₁, p₂ with +def Poly.combine (p₁ p₂ : Poly) : Poly := + match p₁, p₂ with | .nil, p₂ => p₂ | p₁, .nil => p₁ | .add a₁ x₁ p₁, .add a₂ x₂ p₂ => bif Nat.beq x₁ x₂ then let a := a₁ + a₂ bif a == 0 then - combine' fuel p₁ p₂ + combine p₁ p₂ else - .add a x₁ (combine' fuel p₁ p₂) + .add a x₁ (combine p₁ p₂) else bif Nat.blt x₂ x₁ then - .add a₁ x₁ (combine' fuel p₁ (.add a₂ x₂ p₂)) + .add a₁ x₁ (combine p₁ (.add a₂ x₂ p₂)) else - .add a₂ x₂ (combine' fuel (.add a₁ x₁ p₁) p₂) - -def Poly.combine (p₁ p₂ : Poly) : Poly := - combine' 100000000 p₁ p₂ + .add a₂ x₂ (combine (.add a₁ x₁ p₁) p₂) + termination_by sizeOf p₁ + sizeOf p₂ /-- Converts the given expression into a polynomial. -/ def Expr.toPoly' (e : Expr) : Poly := @@ -205,8 +201,8 @@ theorem Poly.denote_append {α} [IntModule α] (ctx : Context α) (p₁ p₂ : P attribute [local simp] Poly.denote_append -theorem Poly.denote_combine' {α} [IntModule α] (ctx : Context α) (fuel : Nat) (p₁ p₂ : Poly) : (p₁.combine' fuel p₂).denote ctx = p₁.denote ctx + p₂.denote ctx := by - fun_induction p₁.combine' fuel p₂ <;> +theorem Poly.denote_combine {α} [IntModule α] (ctx : Context α) (p₁ p₂ : Poly) : (p₁.combine p₂).denote ctx = p₁.denote ctx + p₂.denote ctx := by + fun_induction p₁.combine p₂ <;> simp_all +zetaDelta [denote] next h _ => rw [Int.add_comm] at h @@ -214,9 +210,6 @@ theorem Poly.denote_combine' {α} [IntModule α] (ctx : Context α) (fuel : Nat) next => rw [add_zsmul]; ac_rfl all_goals ac_rfl -theorem Poly.denote_combine {α} [IntModule α] (ctx : Context α) (p₁ p₂ : Poly) : (p₁.combine p₂).denote ctx = p₁.denote ctx + p₂.denote ctx := by - simp [combine, denote_combine'] - attribute [local simp] Poly.denote_combine private theorem Expr.denote_toPoly'_go {α} [IntModule α] {k p} (ctx : Context α) (e : Expr) diff --git a/src/Init/Grind/Ring/CommSolver.lean b/src/Init/Grind/Ring/CommSolver.lean index d3ddd0bb3d..59557b1436 100644 --- a/src/Init/Grind/Ring/CommSolver.lean +++ b/src/Init/Grind/Ring/CommSolver.lean @@ -889,26 +889,22 @@ where | .num k' => acc.insert (k*k' % c) m | .add k' m' p => go p (acc.insert (k*k' % c) (m.mul_nc m')) -def Poly.combineC (p₁ p₂ : Poly) (c : Nat) : Poly := - go hugeFuel p₁ p₂ -where - go (fuel : Nat) (p₁ p₂ : Poly) : Poly := - match fuel with - | 0 => p₁.concat p₂ - | fuel + 1 => match p₁, p₂ with - | .num k₁, .num k₂ => .num ((k₁ + k₂) % c) - | .num k₁, .add k₂ m₂ p₂ => addConstC (.add k₂ m₂ p₂) k₁ c - | .add k₁ m₁ p₁, .num k₂ => addConstC (.add k₁ m₁ p₁) k₂ c - | .add k₁ m₁ p₁, .add k₂ m₂ p₂ => - match m₁.grevlex m₂ with - | .eq => - let k := (k₁ + k₂) % c - bif k == 0 then - go fuel p₁ p₂ - else - .add k m₁ (go fuel p₁ p₂) - | .gt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂)) - | .lt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂) +@[semireducible] +def Poly.combineC (p₁ p₂ : Poly) (c : Nat) : Poly := match p₁, p₂ with + | .num k₁, .num k₂ => .num ((k₁ + k₂) % c) + | .num k₁, .add k₂ m₂ p₂ => addConstC (.add k₂ m₂ p₂) k₁ c + | .add k₁ m₁ p₁, .num k₂ => addConstC (.add k₁ m₁ p₁) k₂ c + | .add k₁ m₁ p₁, .add k₂ m₂ p₂ => + match m₁.grevlex m₂ with + | .eq => + let k := (k₁ + k₂) % c + bif k == 0 then + combineC p₁ p₂ c + else + .add k m₁ (combineC p₁ p₂ c) + | .gt => .add k₁ m₁ (combineC p₁ (.add k₂ m₂ p₂) c) + | .lt => .add k₂ m₂ (combineC (.add k₁ m₁ p₁) p₂ c) +termination_by sizeOf p₁ + sizeOf p₂ def Poly.mulC (p₁ : Poly) (p₂ : Poly) (c : Nat) : Poly := go p₁ (.num 0) @@ -1432,10 +1428,9 @@ theorem Poly.denote_mulMonC_nc {α c} [Ring α] [IsCharP α c] (ctx : Context α theorem Poly.denote_combineC {α c} [Ring α] [IsCharP α c] (ctx : Context α) (p₁ p₂ : Poly) : (combineC p₁ p₂ c).denote ctx = p₁.denote ctx + p₂.denote ctx := by - unfold combineC; generalize hugeFuel = fuel - fun_induction combineC.go - <;> simp [*, denote_concat, denote_addConstC, denote, intCast_add, - add_comm, add_left_comm, add_assoc, IsCharP.intCast_emod, zsmul_eq_intCast_mul] + fun_induction combineC + <;> simp [*, denote_addConstC, denote, intCast_add, add_comm, add_left_comm, add_assoc, + IsCharP.intCast_emod, zsmul_eq_intCast_mul] next hg _ h _ => simp +zetaDelta at h rw [← add_assoc, Mon.eq_of_grevlex hg, ← right_distrib, ← intCast_add,