perf: de-fuel some recursive definitions in Core (#11416)

This PR follows up on #7965 and avoids manual fuel constructions in some
recursive definitions.
This commit is contained in:
Joachim Breitner 2025-12-05 17:16:31 +01:00 committed by GitHub
parent 9cbff55c56
commit d4c832ecb0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 93 additions and 138 deletions

View file

@ -64,103 +64,73 @@ not a sequence of Unicode scalar values.
-/
@[inline, expose]
def ByteArray.utf8Decode? (b : ByteArray) : Option (Array Char) :=
go (b.size + 1) 0 #[] (by simp) (by simp)
go 0 #[] (by simp)
where
go (fuel : Nat) (i : Nat) (acc : Array Char) (hi : i ≤ b.size) (hf : b.size - i < fuel) : Option (Array Char) :=
match fuel, hf with
| fuel + 1, _ =>
if i = b.size then
some acc
else
match h : utf8DecodeChar? b i with
| none => none
| some c => go fuel (i + c.utf8Size) (acc.push c)
(le_size_of_utf8DecodeChar?_eq_some h)
(have := c.utf8Size_pos; have := le_size_of_utf8DecodeChar?_eq_some h; by omega)
termination_by structural fuel
@[semireducible]
go (i : Nat) (acc : Array Char) (hi : i ≤ b.size) : Option (Array Char) :=
if i < b.size then
match h : utf8DecodeChar? b i with
| none => none
| some c => go (i + c.utf8Size) (acc.push c) (le_size_of_utf8DecodeChar?_eq_some h)
else
some acc
termination_by b.size - i
decreasing_by have := c.utf8Size_pos; omega
@[expose, extern "lean_string_validate_utf8"]
def ByteArray.validateUTF8 (b : @& ByteArray) : Bool :=
go (b.size + 1) 0 (by simp) (by simp)
go 0 (by simp)
where
go (fuel : Nat) (i : Nat) (hi : i ≤ b.size) (hf : b.size - i < fuel) : Bool :=
match fuel, hf with
| fuel + 1, _ =>
if hi : i = b.size then
true
else
match h : validateUTF8At b i with
| false => false
| true => go fuel (i + b[i].utf8ByteSize (isUTF8FirstByte_of_validateUTF8At h))
?_ ?_
termination_by structural fuel
@[semireducible]
go (i : Nat) (hi : i ≤ b.size) : Bool :=
if hi : i < b.size then
match h : validateUTF8At b i with
| false => false
| true => go (i + b[i].utf8ByteSize (isUTF8FirstByte_of_validateUTF8At h)) ?_
else
true
termination_by b.size - i
decreasing_by
have := b[i].utf8ByteSize_pos (isUTF8FirstByte_of_validateUTF8At h); omega
finally
all_goals rw [ByteArray.validateUTF8At_eq_isSome_utf8DecodeChar?] at h
· rw [← ByteArray.utf8Size_utf8DecodeChar (h := h)]
exact add_utf8Size_utf8DecodeChar_le_size
· rw [← ByteArray.utf8Size_utf8DecodeChar (h := h)]
have := add_utf8Size_utf8DecodeChar_le_size (h := h)
have := (b.utf8DecodeChar i h).utf8Size_pos
omega
theorem ByteArray.isSome_utf8Decode?Go_eq_validateUTF8Go {b : ByteArray} {fuel : Nat}
{i : Nat} {acc : Array Char} {hi : i ≤ b.size} {hf : b.size - i < fuel} :
(utf8Decode?.go b fuel i acc hi hf).isSome = validateUTF8.go b fuel i hi hf := by
theorem ByteArray.isSome_utf8Decode?Go_eq_validateUTF8Go {b : ByteArray}
{i : Nat} {acc : Array Char} {hi : i ≤ b.size} :
(utf8Decode?.go b i acc hi).isSome = validateUTF8.go b i hi := by
fun_induction utf8Decode?.go with
| case1 => simp [validateUTF8.go]
| case2 i acc hi fuel hf h₁ h₂ =>
simp only [Option.isSome_none, validateUTF8.go, h₁, ↓reduceDIte, Bool.false_eq]
| case1 i acc hi h₁ h₂ =>
unfold validateUTF8.go
simp only [Option.isSome_none, ↓reduceDIte, Bool.false_eq, h₁]
split
· rfl
· rename_i heq
simp [validateUTF8At_eq_isSome_utf8DecodeChar?, h₂] at heq
| case3 i acc hi fuel hf h₁ c h₂ ih =>
simp [validateUTF8.go, h₁]
| case2 i acc hi h₁ c h₂ ih =>
unfold validateUTF8.go
simp only [↓reduceDIte, ih, h₁]
split
· rename_i heq
simp [validateUTF8At_eq_isSome_utf8DecodeChar?, h₂] at heq
· rw [ih]
congr
· congr
rw [← ByteArray.utf8Size_utf8DecodeChar (h := by simp [h₂])]
simp [utf8DecodeChar, h₂]
| case3 => unfold validateUTF8.go; simp [*]
theorem ByteArray.isSome_utf8Decode?_eq_validateUTF8 {b : ByteArray} :
b.utf8Decode?.isSome = b.validateUTF8 :=
b.isSome_utf8Decode?Go_eq_validateUTF8Go
theorem ByteArray.utf8Decode?.go.congr {b b' : ByteArray} {fuel fuel' i i' : Nat} {acc acc' : Array Char} {hi hi' hf hf'}
(hbb' : b = b') (hii' : i = i') (hacc : acc = acc') :
ByteArray.utf8Decode?.go b fuel i acc hi hf = ByteArray.utf8Decode?.go b' fuel' i' acc' hi' hf' := by
subst hbb' hii' hacc
fun_induction ByteArray.utf8Decode?.go b fuel i acc hi hf generalizing fuel' with
| case1 =>
rw [go.eq_def]
split
simp
| case2 =>
rw [go.eq_def]
split <;> split
· simp_all
· split <;> simp_all
| case3 =>
conv => rhs; rw [go.eq_def]
split <;> split
· simp_all
· split
· simp_all
· rename_i c₁ hc₁ ih _ _ _ _ _ c₂ hc₂
obtain rfl : c₁ = c₂ := by rw [← Option.some_inj, ← hc₁, ← hc₂]
apply ih
@[simp]
theorem ByteArray.utf8Decode?_empty : ByteArray.empty.utf8Decode? = some #[] := by
simp [utf8Decode?, utf8Decode?.go]
private theorem ByteArray.isSome_utf8Decode?go_iff {b : ByteArray} {fuel i : Nat} {hi : i ≤ b.size} {hf} {acc : Array Char} :
(ByteArray.utf8Decode?.go b fuel i acc hi hf).isSome ↔ IsValidUTF8 (b.extract i b.size) := by
private theorem ByteArray.isSome_utf8Decode?go_iff {b : ByteArray} {hi : i ≤ b.size} {acc : Array Char} :
(ByteArray.utf8Decode?.go b i acc hi).isSome ↔ IsValidUTF8 (b.extract i b.size) := by
fun_induction ByteArray.utf8Decode?.go with
| case1 => simp
| case2 fuel i hi hf acc h₁ h₂ =>
| case1 i hi acc h₁ h₂ =>
simp only [Option.isSome_none, Bool.false_eq_true, false_iff]
rintro ⟨l, hl⟩
have : l ≠ [] := by
@ -170,7 +140,7 @@ private theorem ByteArray.isSome_utf8Decode?go_iff {b : ByteArray} {fuel i : Nat
rw [← l.cons_head_tail this] at hl
rw [utf8DecodeChar?_eq_utf8DecodeChar?_extract, hl, List.utf8DecodeChar?_utf8Encode_cons] at h₂
simp at h₂
| case3 i acc hi fuel hf h₁ c h₂ ih =>
| case2 i acc hi h₁ c h₂ ih =>
rw [ih]
have h₂' := h₂
rw [utf8DecodeChar?_eq_utf8DecodeChar?_extract] at h₂'
@ -179,6 +149,9 @@ private theorem ByteArray.isSome_utf8Decode?go_iff {b : ByteArray} {fuel i : Nat
(le_size_of_utf8DecodeChar?_eq_some h₂)] at hl ⊢
rw [ByteArray.append_inj_left hl (by have := le_size_of_utf8DecodeChar?_eq_some h₂; simp; omega),
← List.utf8Encode_singleton, isValidUTF8_utf8Encode_singleton_append_iff]
| case3 i =>
have : i = b.size := by omega
simp [*]
theorem ByteArray.isSome_utf8Decode?_iff {b : ByteArray} :
b.utf8Decode?.isSome ↔ IsValidUTF8 b := by
@ -305,27 +278,21 @@ theorem String.length_toList {s : String} : s.toList.length = s.length := (rfl)
@[deprecated String.length_toList (since := "2025-10-30")]
theorem String.length_data {b : String} : b.toList.length = b.length := (rfl)
private theorem ByteArray.utf8Decode?go_eq_utf8Decode?go_extract {b : ByteArray} {fuel i : Nat} {hi : i ≤ b.size} {hf} {acc : Array Char} :
utf8Decode?.go b fuel i acc hi hf = (utf8Decode?.go (b.extract i b.size) fuel 0 #[] (by simp) (by simp [hf])).map (acc ++ ·) := by
fun_cases utf8Decode?.go b fuel i acc hi hf with
| case1 =>
rw [utf8Decode?.go]
simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Nat.zero_add, List.push_toArray,
List.nil_append]
rw [if_pos (by omega)]
simp
| case2 fuel hf₁ h₁ h₂ hf₂ =>
private theorem ByteArray.utf8Decode?go_eq_utf8Decode?go_extract {b : ByteArray} {hi : i ≤ b.size} {acc : Array Char} :
utf8Decode?.go b i acc hi = (utf8Decode?.go (b.extract i b.size) 0 #[] (by simp)).map (acc ++ ·) := by
fun_cases utf8Decode?.go b i acc hi with
| case1 h₁ h₂ =>
rw [utf8Decode?.go]
simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Nat.zero_add, List.push_toArray,
List.nil_append]
rw [if_neg (by omega)]
rw [if_pos (by omega)]
rw [utf8DecodeChar?_eq_utf8DecodeChar?_extract] at h₂
split <;> simp_all
| case3 fuel hf₁ h₁ c h₂ hf₂ =>
| case2 h₁ c h₂ =>
conv => rhs; rw [utf8Decode?.go]
simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Nat.zero_add, List.push_toArray,
List.nil_append]
rw [if_neg (by omega)]
rw [if_pos (by omega)]
rw [utf8DecodeChar?_eq_utf8DecodeChar?_extract] at h₂
split
· simp_all
@ -338,20 +305,25 @@ private theorem ByteArray.utf8Decode?go_eq_utf8Decode?go_extract {b : ByteArray}
simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Option.map_map, ByteArray.extract_extract]
have : (fun x => acc ++ x) ∘ (fun x => #[c] ++ x) = fun x => acc.push c ++ x := by funext; simp
simp [(by omega : i + (b.size - i) = b.size), this]
termination_by fuel
| case3 =>
rw [utf8Decode?.go]
simp only [size_extract, Nat.le_refl, Nat.min_eq_left, Nat.zero_add, List.push_toArray,
List.nil_append]
rw [if_neg (by omega)]
simp
termination_by b.size - i
theorem ByteArray.utf8Decode?_utf8Encode_singleton_append {l : ByteArray} {c : Char} :
([c].utf8Encode ++ l).utf8Decode? = l.utf8Decode?.map (#[c] ++ ·) := by
rw [utf8Decode?, utf8Decode?.go,
if_neg (by simp [List.utf8Encode_singleton]; have := c.utf8Size_pos; omega)]
if_pos (by simp [List.utf8Encode_singleton]; have := c.utf8Size_pos; omega)]
split
· simp_all [List.utf8DecodeChar?_utf8Encode_singleton_append]
· rename_i d h
obtain rfl : c = d := by simpa [List.utf8DecodeChar?_utf8Encode_singleton_append] using h
rw [utf8Decode?go_eq_utf8Decode?go_extract, utf8Decode?]
simp only [List.push_toArray, List.nil_append, Nat.zero_add]
congr 1
apply ByteArray.utf8Decode?.go.congr _ rfl rfl
congr 2
apply extract_append_eq_right _ (by simp)
simp [List.utf8Encode_singleton]

View file

@ -1441,6 +1441,9 @@ public def utf8ByteSize (c : UInt8) (_h : c.IsUTF8FirstByte) : Nat :=
else
4
public theorem utf8ByteSize_pos (c : UInt8) (h : c.IsUTF8FirstByte) : 0 < c.utf8ByteSize h := by
fun_cases utf8ByteSize <;> simp
def _root_.ByteArray.utf8DecodeChar?.FirstByte.utf8ByteSize : FirstByte → Nat
| .invalid => 0
| .done => 1

View file

@ -79,9 +79,9 @@ theorem Poly.denoteN_append {α} [NatModule α] (ctx : Context α) (p₁ p₂ :
attribute [local simp] Poly.denoteN_append
theorem Poly.denoteN_combine' {α} [NatModule α] (ctx : Context α) (fuel : Nat) (p₁ p₂ : Poly)
: p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine' fuel p₂).denoteN ctx = p₁.denoteN ctx + p₂.denoteN ctx := by
fun_induction p₁.combine' fuel p₂ <;> intro h₁ h₂ <;> try simp [*, zero_add, add_zero]
theorem Poly.denoteN_combine {α} [NatModule α] (ctx : Context α) (p₁ p₂ : Poly)
: p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine p₂).denoteN ctx = p₁.denoteN ctx + p₂.denoteN ctx := by
fun_induction p₁.combine p₂ <;> intro h₁ h₂ <;> try simp [*, zero_add, add_zero]
next hx _ h ih =>
simp at hx
simp +zetaDelta at h
@ -103,10 +103,6 @@ theorem Poly.denoteN_combine' {α} [NatModule α] (ctx : Context α) (fuel : Nat
cases h₂; next h₂ =>
simp [ih h₁ h₂, *]; ac_rfl
theorem Poly.denoteN_combine {α} [NatModule α] (ctx : Context α) (p₁ p₂ : Poly)
: p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine p₂).denoteN ctx = p₁.denoteN ctx + p₂.denoteN ctx := by
intros; simp [combine, denoteN_combine', *]
theorem Poly.denoteN_mul' {α} [NatModule α] (ctx : Context α) (p : Poly) (k : Nat) : p.NonnegCoeffs → (p.mul' k).denoteN ctx = k • p.denoteN ctx := by
induction p <;> simp [mul', *, nsmul_zero]
next ih =>
@ -151,9 +147,8 @@ theorem Poly.append_Nonneg (p₁ p₂ : Poly) : p₁.NonnegCoeffs → p₂.Nonne
fun_induction append <;> intro h₁ h₂; simp [*]
next ih => cases h₁; constructor; assumption; apply ih <;> assumption
theorem Poly.combine'_Nonneg (fuel : Nat) (p₁ p₂ : Poly) : p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine' fuel p₂).NonnegCoeffs := by
fun_induction Poly.combine'
next => apply Poly.append_Nonneg
theorem Poly.combine_Nonneg (p₁ p₂ : Poly) : p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine p₂).NonnegCoeffs := by
fun_induction Poly.combine
next => intros; assumption
next => intros; assumption
next ih =>
@ -172,9 +167,6 @@ theorem Poly.combine'_Nonneg (fuel : Nat) (p₁ p₂ : Poly) : p₁.NonnegCoeffs
constructor; assumption
apply ih; constructor; assumption; assumption; assumption
theorem Poly.combine_Nonneg (p₁ p₂ : Poly) : p₁.NonnegCoeffs → p₂.NonnegCoeffs → (p₁.combine p₂).NonnegCoeffs := by
simp [combine]; apply Poly.combine'_Nonneg
theorem Expr.toPolyN_Nonneg (e : Expr) : e.toPolyN.NonnegCoeffs := by
fun_induction toPolyN <;> try constructor <;> simp
next => constructor; simp; constructor

View file

@ -123,26 +123,22 @@ def Poly.append (p₁ p₂ : Poly) : Poly :=
| .nil => p₂
| .add k x p₁ => .add k x (append p₁ p₂)
def Poly.combine' (fuel : Nat) (p₁ p₂ : Poly) : Poly :=
match fuel with
| 0 => p₁.append p₂
| fuel + 1 => match p₁, p₂ with
def Poly.combine (p₁ p₂ : Poly) : Poly :=
match p₁, p₂ with
| .nil, p₂ => p₂
| p₁, .nil => p₁
| .add a₁ x₁ p₁, .add a₂ x₂ p₂ =>
bif Nat.beq x₁ x₂ then
let a := a₁ + a₂
bif a == 0 then
combine' fuel p₁ p₂
combine p₁ p₂
else
.add a x₁ (combine' fuel p₁ p₂)
.add a x₁ (combine p₁ p₂)
else bif Nat.blt x₂ x₁ then
.add a₁ x₁ (combine' fuel p₁ (.add a₂ x₂ p₂))
.add a₁ x₁ (combine p₁ (.add a₂ x₂ p₂))
else
.add a₂ x₂ (combine' fuel (.add a₁ x₁ p₁) p₂)
def Poly.combine (p₁ p₂ : Poly) : Poly :=
combine' 100000000 p₁ p₂
.add a₂ x₂ (combine (.add a₁ x₁ p₁) p₂)
termination_by sizeOf p₁ + sizeOf p₂
/-- Converts the given expression into a polynomial. -/
def Expr.toPoly' (e : Expr) : Poly :=
@ -205,8 +201,8 @@ theorem Poly.denote_append {α} [IntModule α] (ctx : Context α) (p₁ p₂ : P
attribute [local simp] Poly.denote_append
theorem Poly.denote_combine' {α} [IntModule α] (ctx : Context α) (fuel : Nat) (p₁ p₂ : Poly) : (p₁.combine' fuel p₂).denote ctx = p₁.denote ctx + p₂.denote ctx := by
fun_induction p₁.combine' fuel p₂ <;>
theorem Poly.denote_combine {α} [IntModule α] (ctx : Context α) (p₁ p₂ : Poly) : (p₁.combine p₂).denote ctx = p₁.denote ctx + p₂.denote ctx := by
fun_induction p₁.combine p₂ <;>
simp_all +zetaDelta [denote]
next h _ =>
rw [Int.add_comm] at h
@ -214,9 +210,6 @@ theorem Poly.denote_combine' {α} [IntModule α] (ctx : Context α) (fuel : Nat)
next => rw [add_zsmul]; ac_rfl
all_goals ac_rfl
theorem Poly.denote_combine {α} [IntModule α] (ctx : Context α) (p₁ p₂ : Poly) : (p₁.combine p₂).denote ctx = p₁.denote ctx + p₂.denote ctx := by
simp [combine, denote_combine']
attribute [local simp] Poly.denote_combine
private theorem Expr.denote_toPoly'_go {α} [IntModule α] {k p} (ctx : Context α) (e : Expr)

View file

@ -889,26 +889,22 @@ where
| .num k' => acc.insert (k*k' % c) m
| .add k' m' p => go p (acc.insert (k*k' % c) (m.mul_nc m'))
def Poly.combineC (p₁ p₂ : Poly) (c : Nat) : Poly :=
go hugeFuel p₁ p₂
where
go (fuel : Nat) (p₁ p₂ : Poly) : Poly :=
match fuel with
| 0 => p₁.concat p₂
| fuel + 1 => match p₁, p₂ with
| .num k₁, .num k₂ => .num ((k₁ + k₂) % c)
| .num k₁, .add k₂ m₂ p₂ => addConstC (.add k₂ m₂ p₂) k₁ c
| .add k₁ m₁ p₁, .num k₂ => addConstC (.add k₁ m₁ p₁) k₂ c
| .add k₁ m₁ p₁, .add k₂ m₂ p₂ =>
match m₁.grevlex m₂ with
| .eq =>
let k := (k₁ + k₂) % c
bif k == 0 then
go fuel p₁ p₂
else
.add k m₁ (go fuel p₁ p₂)
| .gt => .add k₁ m₁ (go fuel p₁ (.add k₂ m₂ p₂))
| .lt => .add k₂ m₂ (go fuel (.add k₁ m₁ p₁) p₂)
@[semireducible]
def Poly.combineC (p₁ p₂ : Poly) (c : Nat) : Poly := match p₁, p₂ with
| .num k₁, .num k₂ => .num ((k₁ + k₂) % c)
| .num k₁, .add k₂ m₂ p₂ => addConstC (.add k₂ m₂ p₂) k₁ c
| .add k₁ m₁ p₁, .num k₂ => addConstC (.add k₁ m₁ p₁) k₂ c
| .add k₁ m₁ p₁, .add k₂ m₂ p₂ =>
match m₁.grevlex m₂ with
| .eq =>
let k := (k₁ + k₂) % c
bif k == 0 then
combineC p₁ p₂ c
else
.add k m₁ (combineC p₁ p₂ c)
| .gt => .add k₁ m₁ (combineC p₁ (.add k₂ m₂ p₂) c)
| .lt => .add k₂ m₂ (combineC (.add k₁ m₁ p₁) p₂ c)
termination_by sizeOf p₁ + sizeOf p₂
def Poly.mulC (p₁ : Poly) (p₂ : Poly) (c : Nat) : Poly :=
go p₁ (.num 0)
@ -1432,10 +1428,9 @@ theorem Poly.denote_mulMonC_nc {α c} [Ring α] [IsCharP α c] (ctx : Context α
theorem Poly.denote_combineC {α c} [Ring α] [IsCharP α c] (ctx : Context α) (p₁ p₂ : Poly)
: (combineC p₁ p₂ c).denote ctx = p₁.denote ctx + p₂.denote ctx := by
unfold combineC; generalize hugeFuel = fuel
fun_induction combineC.go
<;> simp [*, denote_concat, denote_addConstC, denote, intCast_add,
add_comm, add_left_comm, add_assoc, IsCharP.intCast_emod, zsmul_eq_intCast_mul]
fun_induction combineC
<;> simp [*, denote_addConstC, denote, intCast_add, add_comm, add_left_comm, add_assoc,
IsCharP.intCast_emod, zsmul_eq_intCast_mul]
next hg _ h _ =>
simp +zetaDelta at h
rw [← add_assoc, Mon.eq_of_grevlex hg, ← right_distrib, ← intCast_add,