From 87ec768a509c45987a7d73258da0cc1ef047925c Mon Sep 17 00:00:00 2001 From: Markus Himmel Date: Wed, 25 Feb 2026 07:24:15 +0100 Subject: [PATCH] fix: ensure that tail-recursive `List.flatten` is used everywhere (#12678) This PR marks `List.flatten`, `List.flatMap`, `List.intercalate` as noncomputable to ensure that their `csimp` variants are used everywhere. We also mark `List.flatMapM` as noncomputable and provide a tail-recursive implementation, and mark `List.utf8Encode` as noncomputable, which only exists for specification purposes anyway (at this point). Closes #12676. --- src/Init/Data/List.lean | 1 + src/Init/Data/List/Basic.lean | 2 +- src/Init/Data/List/Control.lean | 6 +- src/Init/Data/List/ControlImpl.lean | 35 +++++++++ src/Init/Prelude.lean | 6 +- src/Std/Data/DHashMap/Internal/Defs.lean | 1 + tests/lean/csimpCore.lean.expected.out | 92 ------------------------ tests/lean/run/12676.lean | 5 ++ tests/lean/{ => run}/csimpCore.lean | 5 +- 9 files changed, 51 insertions(+), 102 deletions(-) create mode 100644 src/Init/Data/List/ControlImpl.lean delete mode 100644 tests/lean/csimpCore.lean.expected.out create mode 100644 tests/lean/run/12676.lean rename tests/lean/{ => run}/csimpCore.lean (94%) diff --git a/src/Init/Data/List.lean b/src/Init/Data/List.lean index 8d8f4b8622..a72de15475 100644 --- a/src/Init/Data/List.lean +++ b/src/Init/Data/List.lean @@ -36,3 +36,4 @@ public import Init.Data.List.FinRange public import Init.Data.List.Lex public import Init.Data.List.Range public import Init.Data.List.Scan +public import Init.Data.List.ControlImpl diff --git a/src/Init/Data/List/Basic.lean b/src/Init/Data/List/Basic.lean index f5d8700c0f..ed925dd1fd 100644 --- a/src/Init/Data/List/Basic.lean +++ b/src/Init/Data/List/Basic.lean @@ -2186,7 +2186,7 @@ Examples: * `List.intercalate sep [a, b] = a ++ sep ++ b` * `List.intercalate sep [a, b, c] = a ++ sep ++ b ++ sep ++ c` -/ -def intercalate (sep : List α) (xs : List (List α)) : List α := +noncomputable def intercalate (sep : List α) (xs : List (List α)) : List α := (intersperse sep xs).flatten /-! ### eraseDupsBy -/ diff --git a/src/Init/Data/List/Control.lean b/src/Init/Data/List/Control.lean index 0d35b80b42..277fda5e1b 100644 --- a/src/Init/Data/List/Control.lean +++ b/src/Init/Data/List/Control.lean @@ -219,9 +219,9 @@ def filterMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f Applies a monadic function that returns a list to each element of a list, from left to right, and concatenates the resulting lists. -/ -@[inline, expose] -def flatMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (List β)) (as : List α) : m (List β) := - let rec @[specialize] loop +@[expose] +noncomputable def flatMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (List β)) (as : List α) : m (List β) := + let rec loop | [], bs => pure bs.reverse.flatten | a :: as, bs => do let bs' ← f a diff --git a/src/Init/Data/List/ControlImpl.lean b/src/Init/Data/List/ControlImpl.lean new file mode 100644 index 0000000000..f479d16566 --- /dev/null +++ b/src/Init/Data/List/ControlImpl.lean @@ -0,0 +1,35 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author: Markus Himmel +-/ +module + +prelude +public import Init.Data.List.Control +public import Init.Data.List.Impl + +public section + +namespace List + +/-- +Applies a monadic function that returns a list to each element of a list, from left to right, and +concatenates the resulting lists. +-/ +@[inline, expose] +def flatMapMTR {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (List β)) (as : List α) : m (List β) := + let rec @[specialize] loop + | [], bs => pure bs.reverse.flatten + | a :: as, bs => do + let bs' ← f a + loop as (bs' :: bs) + loop as [] + +@[csimp] theorem flatMapM_eq_flatMapMTR : @flatMapM = @flatMapMTR := by + funext m _ α β f l + simp only [flatMapM, flatMapMTR] + generalize [] = m + fun_induction flatMapM.loop <;> simp_all [flatMapMTR.loop] + +end List diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index a64b888b55..1592caf251 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -3098,7 +3098,7 @@ Examples: * `[["a"], ["b", "c"]].flatten = ["a", "b", "c"]` * `[["a"], [], ["b", "c"], ["d", "e", "f"]].flatten = ["a", "b", "c", "d", "e", "f"]` -/ -def List.flatten : List (List α) → List α +noncomputable def List.flatten : List (List α) → List α | nil => nil | cons l L => List.append l (flatten L) @@ -3125,7 +3125,7 @@ Examples: * `[2, 3, 2].flatMap List.range = [0, 1, 0, 1, 2, 0, 1]` * `["red", "blue"].flatMap String.toList = ['r', 'e', 'd', 'b', 'l', 'u', 'e']` -/ -@[inline] def List.flatMap {α : Type u} {β : Type v} (b : α → List β) (as : List α) : List β := flatten (map b as) +@[inline] noncomputable def List.flatMap {α : Type u} {β : Type v} (b : α → List β) (as : List α) : List β := flatten (map b as) /-- `Array α` is the type of [dynamic arrays](https://en.wikipedia.org/wiki/Dynamic_array) with elements @@ -3453,7 +3453,7 @@ def String.utf8EncodeChar (c : Char) : List UInt8 := /-- Encode a list of characters (Unicode scalar value) in UTF-8. This is an inefficient model implementation. Use `List.asString` instead. -/ -def List.utf8Encode (l : List Char) : ByteArray := +noncomputable def List.utf8Encode (l : List Char) : ByteArray := l.flatMap String.utf8EncodeChar |>.toByteArray /-- A byte array is valid UTF-8 if it is of the form `List.Internal.utf8Encode m` for some `m`. diff --git a/src/Std/Data/DHashMap/Internal/Defs.lean b/src/Std/Data/DHashMap/Internal/Defs.lean index bcbd862f06..ddf61c7248 100644 --- a/src/Std/Data/DHashMap/Internal/Defs.lean +++ b/src/Std/Data/DHashMap/Internal/Defs.lean @@ -11,6 +11,7 @@ public import Std.Data.DHashMap.RawDef public import Std.Data.Internal.List.Defs public import Std.Data.DHashMap.Internal.Index public import Init.Data.Nat.Power2.Basic +import Init.Data.List.Impl import Init.Omega public section diff --git a/tests/lean/csimpCore.lean.expected.out b/tests/lean/csimpCore.lean.expected.out deleted file mode 100644 index 62c46a0d21..0000000000 --- a/tests/lean/csimpCore.lean.expected.out +++ /dev/null @@ -1,92 +0,0 @@ -(Acc.rec, Acc.recC) -(Array.instDecidableEmpEq, Array.instDecidableEmpEqImpl) -(Array.instDecidableEq, Array.instDecidableEqImpl) -(Array.instDecidableEqEmp, Array.instDecidableEqEmpImpl) -(Array.pmap, Array.pmapImpl) -(ByteArray.append, ByteArray.fastAppend) -(List.append, List.appendTR) -(List.dropLast, List.dropLastTR) -(List.erase, List.eraseTR) -(List.eraseIdx, List.eraseIdxTR) -(List.eraseP, List.erasePTR) -(List.filter, List.filterTR) -(List.findRev?, List.findRev?TR) -(List.findSomeRev?, List.findSomeRev?TR) -(List.flatMap, List.flatMapTR) -(List.flatten, List.flattenTR) -(List.foldr, List.foldrTR) -(List.insertIdx, List.insertIdxTR) -(List.intercalate, List.intercalateTR) -(List.intersperse, List.intersperseTR) -(List.leftpad, List.leftpadTR) -(List.length, List.lengthTR) -(List.map, List.mapTR) -(List.merge, List.MergeSort.Internal.mergeTR) -(List.mergeSort, List.MergeSort.Internal.mergeSortTR₂) -(List.modify, List.modifyTR) -(List.pmap, List.pmapImpl) -(List.range', List.range'TR) -(List.replace, List.replaceTR) -(List.replicate, List.replicateTR) -(List.set, List.setTR) -(List.take, List.takeTR) -(List.takeWhile, List.takeWhileTR) -(List.unzip, List.unzipTR) -(List.zipIdx, List.zipIdxTR) -(List.zipWith, List.zipWithTR) -(Nat.all, Nat.allTR) -(Nat.any, Nat.anyTR) -(Nat.fold, Nat.foldTR) -(Nat.rec, Nat.recCompiled) -(Nat.repeat, Nat.repeatTR) -(String.utf8EncodeChar, String.utf8EncodeCharFast) -(Thunk.fn, Thunk.fnImpl) -(Vector.pmap, Vector.pmapImpl) -(String.Slice.Pos.next, String.Slice.Pos.nextFast) -csimpCore.lean:56:0-56:11: error: ❌️ Docstring on `#guard_msgs` does not match generated message: - - info: (Acc.rec, Acc.recC) - (Array.instDecidableEmpEq, Array.instDecidableEmpEqImpl) - (Array.instDecidableEq, Array.instDecidableEqImpl) - (Array.instDecidableEqEmp, Array.instDecidableEqEmpImpl) - (Array.pmap, Array.pmapImpl) - (ByteArray.append, ByteArray.fastAppend) - (List.append, List.appendTR) - (List.dropLast, List.dropLastTR) - (List.erase, List.eraseTR) -+ (List.eraseIdx, List.eraseIdxTR) - (List.eraseP, List.erasePTR) - (List.filter, List.filterTR) - (List.findRev?, List.findRev?TR) -+ (List.findSomeRev?, List.findSomeRev?TR) - (List.flatMap, List.flatMapTR) - (List.flatten, List.flattenTR) - (List.foldr, List.foldrTR) - (List.insertIdx, List.insertIdxTR) - (List.intercalate, List.intercalateTR) - (List.intersperse, List.intersperseTR) - (List.leftpad, List.leftpadTR) - (List.length, List.lengthTR) - (List.map, List.mapTR) - (List.merge, List.MergeSort.Internal.mergeTR) - (List.mergeSort, List.MergeSort.Internal.mergeSortTR₂) - (List.modify, List.modifyTR) - (List.pmap, List.pmapImpl) - (List.range', List.range'TR) - (List.replace, List.replaceTR) - (List.replicate, List.replicateTR) - (List.set, List.setTR) - (List.take, List.takeTR) - (List.takeWhile, List.takeWhileTR) - (List.unzip, List.unzipTR) - (List.zipIdx, List.zipIdxTR) - (List.zipWith, List.zipWithTR) - (Nat.all, Nat.allTR) - (Nat.any, Nat.anyTR) - (Nat.fold, Nat.foldTR) - (Nat.rec, Nat.recCompiled) - (Nat.repeat, Nat.repeatTR) - (String.utf8EncodeChar, String.utf8EncodeCharFast) - (Thunk.fn, Thunk.fnImpl) - (Vector.pmap, Vector.pmapImpl) - (String.Slice.Pos.next, String.Slice.Pos.nextFast) diff --git a/tests/lean/run/12676.lean b/tests/lean/run/12676.lean new file mode 100644 index 0000000000..5f3c889028 --- /dev/null +++ b/tests/lean/run/12676.lean @@ -0,0 +1,5 @@ +module + +/-- info: 10000000 -/ +#guard_msgs in +#eval (List.range 10000000).flatMapM (m := Id) (fun d => pure [d]) |>.length diff --git a/tests/lean/csimpCore.lean b/tests/lean/run/csimpCore.lean similarity index 94% rename from tests/lean/csimpCore.lean rename to tests/lean/run/csimpCore.lean index 3cecc06017..26be0fda76 100644 --- a/tests/lean/csimpCore.lean +++ b/tests/lean/run/csimpCore.lean @@ -18,14 +18,13 @@ info: (Acc.rec, Acc.recC) (List.append, List.appendTR) (List.dropLast, List.dropLastTR) (List.erase, List.eraseTR) +(List.eraseIdx, List.eraseIdxTR) (List.eraseP, List.erasePTR) (List.filter, List.filterTR) (List.findRev?, List.findRev?TR) -(List.flatMap, List.flatMapTR) -(List.flatten, List.flattenTR) +(List.findSomeRev?, List.findSomeRev?TR) (List.foldr, List.foldrTR) (List.insertIdx, List.insertIdxTR) -(List.intercalate, List.intercalateTR) (List.intersperse, List.intersperseTR) (List.leftpad, List.leftpadTR) (List.length, List.lengthTR)