fix: ensure that tail-recursive List.flatten is used everywhere (#12678)

This PR marks `List.flatten`, `List.flatMap`, `List.intercalate` as
noncomputable to ensure that their `csimp` variants are used everywhere.

We also mark `List.flatMapM` as noncomputable and provide a
tail-recursive implementation, and mark `List.utf8Encode` as
noncomputable, which only exists for specification purposes anyway (at
this point).

Closes #12676.
This commit is contained in:
Markus Himmel 2026-02-25 07:24:15 +01:00 committed by GitHub
parent de65af8318
commit 87ec768a50
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 51 additions and 102 deletions

View file

@ -36,3 +36,4 @@ public import Init.Data.List.FinRange
public import Init.Data.List.Lex
public import Init.Data.List.Range
public import Init.Data.List.Scan
public import Init.Data.List.ControlImpl

View file

@ -2186,7 +2186,7 @@ Examples:
* `List.intercalate sep [a, b] = a ++ sep ++ b`
* `List.intercalate sep [a, b, c] = a ++ sep ++ b ++ sep ++ c`
-/
def intercalate (sep : List α) (xs : List (List α)) : List α :=
noncomputable def intercalate (sep : List α) (xs : List (List α)) : List α :=
(intersperse sep xs).flatten
/-! ### eraseDupsBy -/

View file

@ -219,9 +219,9 @@ def filterMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f
Applies a monadic function that returns a list to each element of a list, from left to right, and
concatenates the resulting lists.
-/
@[inline, expose]
def flatMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (List β)) (as : List α) : m (List β) :=
let rec @[specialize] loop
@[expose]
noncomputable def flatMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (List β)) (as : List α) : m (List β) :=
let rec loop
| [], bs => pure bs.reverse.flatten
| a :: as, bs => do
let bs' ← f a

View file

@ -0,0 +1,35 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Markus Himmel
-/
module
prelude
public import Init.Data.List.Control
public import Init.Data.List.Impl
public section
namespace List
/--
Applies a monadic function that returns a list to each element of a list, from left to right, and
concatenates the resulting lists.
-/
@[inline, expose]
def flatMapMTR {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (List β)) (as : List α) : m (List β) :=
let rec @[specialize] loop
| [], bs => pure bs.reverse.flatten
| a :: as, bs => do
let bs' ← f a
loop as (bs' :: bs)
loop as []
@[csimp] theorem flatMapM_eq_flatMapMTR : @flatMapM = @flatMapMTR := by
funext m _ α β f l
simp only [flatMapM, flatMapMTR]
generalize [] = m
fun_induction flatMapM.loop <;> simp_all [flatMapMTR.loop]
end List

View file

@ -3098,7 +3098,7 @@ Examples:
* `[["a"], ["b", "c"]].flatten = ["a", "b", "c"]`
* `[["a"], [], ["b", "c"], ["d", "e", "f"]].flatten = ["a", "b", "c", "d", "e", "f"]`
-/
def List.flatten : List (List α) → List α
noncomputable def List.flatten : List (List α) → List α
| nil => nil
| cons l L => List.append l (flatten L)
@ -3125,7 +3125,7 @@ Examples:
* `[2, 3, 2].flatMap List.range = [0, 1, 0, 1, 2, 0, 1]`
* `["red", "blue"].flatMap String.toList = ['r', 'e', 'd', 'b', 'l', 'u', 'e']`
-/
@[inline] def List.flatMap {α : Type u} {β : Type v} (b : α → List β) (as : List α) : List β := flatten (map b as)
@[inline] noncomputable def List.flatMap {α : Type u} {β : Type v} (b : α → List β) (as : List α) : List β := flatten (map b as)
/--
`Array α` is the type of [dynamic arrays](https://en.wikipedia.org/wiki/Dynamic_array) with elements
@ -3453,7 +3453,7 @@ def String.utf8EncodeChar (c : Char) : List UInt8 :=
/-- Encode a list of characters (Unicode scalar value) in UTF-8. This is an inefficient model
implementation. Use `List.asString` instead. -/
def List.utf8Encode (l : List Char) : ByteArray :=
noncomputable def List.utf8Encode (l : List Char) : ByteArray :=
l.flatMap String.utf8EncodeChar |>.toByteArray
/-- A byte array is valid UTF-8 if it is of the form `List.Internal.utf8Encode m` for some `m`.

View file

@ -11,6 +11,7 @@ public import Std.Data.DHashMap.RawDef
public import Std.Data.Internal.List.Defs
public import Std.Data.DHashMap.Internal.Index
public import Init.Data.Nat.Power2.Basic
import Init.Data.List.Impl
import Init.Omega
public section

View file

@ -1,92 +0,0 @@
(Acc.rec, Acc.recC)
(Array.instDecidableEmpEq, Array.instDecidableEmpEqImpl)
(Array.instDecidableEq, Array.instDecidableEqImpl)
(Array.instDecidableEqEmp, Array.instDecidableEqEmpImpl)
(Array.pmap, Array.pmapImpl)
(ByteArray.append, ByteArray.fastAppend)
(List.append, List.appendTR)
(List.dropLast, List.dropLastTR)
(List.erase, List.eraseTR)
(List.eraseIdx, List.eraseIdxTR)
(List.eraseP, List.erasePTR)
(List.filter, List.filterTR)
(List.findRev?, List.findRev?TR)
(List.findSomeRev?, List.findSomeRev?TR)
(List.flatMap, List.flatMapTR)
(List.flatten, List.flattenTR)
(List.foldr, List.foldrTR)
(List.insertIdx, List.insertIdxTR)
(List.intercalate, List.intercalateTR)
(List.intersperse, List.intersperseTR)
(List.leftpad, List.leftpadTR)
(List.length, List.lengthTR)
(List.map, List.mapTR)
(List.merge, List.MergeSort.Internal.mergeTR)
(List.mergeSort, List.MergeSort.Internal.mergeSortTR₂)
(List.modify, List.modifyTR)
(List.pmap, List.pmapImpl)
(List.range', List.range'TR)
(List.replace, List.replaceTR)
(List.replicate, List.replicateTR)
(List.set, List.setTR)
(List.take, List.takeTR)
(List.takeWhile, List.takeWhileTR)
(List.unzip, List.unzipTR)
(List.zipIdx, List.zipIdxTR)
(List.zipWith, List.zipWithTR)
(Nat.all, Nat.allTR)
(Nat.any, Nat.anyTR)
(Nat.fold, Nat.foldTR)
(Nat.rec, Nat.recCompiled)
(Nat.repeat, Nat.repeatTR)
(String.utf8EncodeChar, String.utf8EncodeCharFast)
(Thunk.fn, Thunk.fnImpl)
(Vector.pmap, Vector.pmapImpl)
(String.Slice.Pos.next, String.Slice.Pos.nextFast)
csimpCore.lean:56:0-56:11: error: ❌️ Docstring on `#guard_msgs` does not match generated message:
info: (Acc.rec, Acc.recC)
(Array.instDecidableEmpEq, Array.instDecidableEmpEqImpl)
(Array.instDecidableEq, Array.instDecidableEqImpl)
(Array.instDecidableEqEmp, Array.instDecidableEqEmpImpl)
(Array.pmap, Array.pmapImpl)
(ByteArray.append, ByteArray.fastAppend)
(List.append, List.appendTR)
(List.dropLast, List.dropLastTR)
(List.erase, List.eraseTR)
+ (List.eraseIdx, List.eraseIdxTR)
(List.eraseP, List.erasePTR)
(List.filter, List.filterTR)
(List.findRev?, List.findRev?TR)
+ (List.findSomeRev?, List.findSomeRev?TR)
(List.flatMap, List.flatMapTR)
(List.flatten, List.flattenTR)
(List.foldr, List.foldrTR)
(List.insertIdx, List.insertIdxTR)
(List.intercalate, List.intercalateTR)
(List.intersperse, List.intersperseTR)
(List.leftpad, List.leftpadTR)
(List.length, List.lengthTR)
(List.map, List.mapTR)
(List.merge, List.MergeSort.Internal.mergeTR)
(List.mergeSort, List.MergeSort.Internal.mergeSortTR₂)
(List.modify, List.modifyTR)
(List.pmap, List.pmapImpl)
(List.range', List.range'TR)
(List.replace, List.replaceTR)
(List.replicate, List.replicateTR)
(List.set, List.setTR)
(List.take, List.takeTR)
(List.takeWhile, List.takeWhileTR)
(List.unzip, List.unzipTR)
(List.zipIdx, List.zipIdxTR)
(List.zipWith, List.zipWithTR)
(Nat.all, Nat.allTR)
(Nat.any, Nat.anyTR)
(Nat.fold, Nat.foldTR)
(Nat.rec, Nat.recCompiled)
(Nat.repeat, Nat.repeatTR)
(String.utf8EncodeChar, String.utf8EncodeCharFast)
(Thunk.fn, Thunk.fnImpl)
(Vector.pmap, Vector.pmapImpl)
(String.Slice.Pos.next, String.Slice.Pos.nextFast)

View file

@ -0,0 +1,5 @@
module
/-- info: 10000000 -/
#guard_msgs in
#eval (List.range 10000000).flatMapM (m := Id) (fun d => pure [d]) |>.length

View file

@ -18,14 +18,13 @@ info: (Acc.rec, Acc.recC)
(List.append, List.appendTR)
(List.dropLast, List.dropLastTR)
(List.erase, List.eraseTR)
(List.eraseIdx, List.eraseIdxTR)
(List.eraseP, List.erasePTR)
(List.filter, List.filterTR)
(List.findRev?, List.findRev?TR)
(List.flatMap, List.flatMapTR)
(List.flatten, List.flattenTR)
(List.findSomeRev?, List.findSomeRev?TR)
(List.foldr, List.foldrTR)
(List.insertIdx, List.insertIdxTR)
(List.intercalate, List.intercalateTR)
(List.intersperse, List.intersperseTR)
(List.leftpad, List.leftpadTR)
(List.length, List.lengthTR)