fix: ensure that tail-recursive List.flatten is used everywhere (#12678)
This PR marks `List.flatten`, `List.flatMap`, `List.intercalate` as noncomputable to ensure that their `csimp` variants are used everywhere. We also mark `List.flatMapM` as noncomputable and provide a tail-recursive implementation, and mark `List.utf8Encode` as noncomputable, which only exists for specification purposes anyway (at this point). Closes #12676.
This commit is contained in:
parent
de65af8318
commit
87ec768a50
9 changed files with 51 additions and 102 deletions
|
|
@ -36,3 +36,4 @@ public import Init.Data.List.FinRange
|
|||
public import Init.Data.List.Lex
|
||||
public import Init.Data.List.Range
|
||||
public import Init.Data.List.Scan
|
||||
public import Init.Data.List.ControlImpl
|
||||
|
|
|
|||
|
|
@ -2186,7 +2186,7 @@ Examples:
|
|||
* `List.intercalate sep [a, b] = a ++ sep ++ b`
|
||||
* `List.intercalate sep [a, b, c] = a ++ sep ++ b ++ sep ++ c`
|
||||
-/
|
||||
def intercalate (sep : List α) (xs : List (List α)) : List α :=
|
||||
noncomputable def intercalate (sep : List α) (xs : List (List α)) : List α :=
|
||||
(intersperse sep xs).flatten
|
||||
|
||||
/-! ### eraseDupsBy -/
|
||||
|
|
|
|||
|
|
@ -219,9 +219,9 @@ def filterMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f
|
|||
Applies a monadic function that returns a list to each element of a list, from left to right, and
|
||||
concatenates the resulting lists.
|
||||
-/
|
||||
@[inline, expose]
|
||||
def flatMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (List β)) (as : List α) : m (List β) :=
|
||||
let rec @[specialize] loop
|
||||
@[expose]
|
||||
noncomputable def flatMapM {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (List β)) (as : List α) : m (List β) :=
|
||||
let rec loop
|
||||
| [], bs => pure bs.reverse.flatten
|
||||
| a :: as, bs => do
|
||||
let bs' ← f a
|
||||
|
|
|
|||
35
src/Init/Data/List/ControlImpl.lean
Normal file
35
src/Init/Data/List/ControlImpl.lean
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Markus Himmel
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.Data.List.Control
|
||||
public import Init.Data.List.Impl
|
||||
|
||||
public section
|
||||
|
||||
namespace List
|
||||
|
||||
/--
|
||||
Applies a monadic function that returns a list to each element of a list, from left to right, and
|
||||
concatenates the resulting lists.
|
||||
-/
|
||||
@[inline, expose]
|
||||
def flatMapMTR {m : Type u → Type v} [Monad m] {α : Type w} {β : Type u} (f : α → m (List β)) (as : List α) : m (List β) :=
|
||||
let rec @[specialize] loop
|
||||
| [], bs => pure bs.reverse.flatten
|
||||
| a :: as, bs => do
|
||||
let bs' ← f a
|
||||
loop as (bs' :: bs)
|
||||
loop as []
|
||||
|
||||
@[csimp] theorem flatMapM_eq_flatMapMTR : @flatMapM = @flatMapMTR := by
|
||||
funext m _ α β f l
|
||||
simp only [flatMapM, flatMapMTR]
|
||||
generalize [] = m
|
||||
fun_induction flatMapM.loop <;> simp_all [flatMapMTR.loop]
|
||||
|
||||
end List
|
||||
|
|
@ -3098,7 +3098,7 @@ Examples:
|
|||
* `[["a"], ["b", "c"]].flatten = ["a", "b", "c"]`
|
||||
* `[["a"], [], ["b", "c"], ["d", "e", "f"]].flatten = ["a", "b", "c", "d", "e", "f"]`
|
||||
-/
|
||||
def List.flatten : List (List α) → List α
|
||||
noncomputable def List.flatten : List (List α) → List α
|
||||
| nil => nil
|
||||
| cons l L => List.append l (flatten L)
|
||||
|
||||
|
|
@ -3125,7 +3125,7 @@ Examples:
|
|||
* `[2, 3, 2].flatMap List.range = [0, 1, 0, 1, 2, 0, 1]`
|
||||
* `["red", "blue"].flatMap String.toList = ['r', 'e', 'd', 'b', 'l', 'u', 'e']`
|
||||
-/
|
||||
@[inline] def List.flatMap {α : Type u} {β : Type v} (b : α → List β) (as : List α) : List β := flatten (map b as)
|
||||
@[inline] noncomputable def List.flatMap {α : Type u} {β : Type v} (b : α → List β) (as : List α) : List β := flatten (map b as)
|
||||
|
||||
/--
|
||||
`Array α` is the type of [dynamic arrays](https://en.wikipedia.org/wiki/Dynamic_array) with elements
|
||||
|
|
@ -3453,7 +3453,7 @@ def String.utf8EncodeChar (c : Char) : List UInt8 :=
|
|||
|
||||
/-- Encode a list of characters (Unicode scalar value) in UTF-8. This is an inefficient model
|
||||
implementation. Use `List.asString` instead. -/
|
||||
def List.utf8Encode (l : List Char) : ByteArray :=
|
||||
noncomputable def List.utf8Encode (l : List Char) : ByteArray :=
|
||||
l.flatMap String.utf8EncodeChar |>.toByteArray
|
||||
|
||||
/-- A byte array is valid UTF-8 if it is of the form `List.Internal.utf8Encode m` for some `m`.
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ public import Std.Data.DHashMap.RawDef
|
|||
public import Std.Data.Internal.List.Defs
|
||||
public import Std.Data.DHashMap.Internal.Index
|
||||
public import Init.Data.Nat.Power2.Basic
|
||||
import Init.Data.List.Impl
|
||||
import Init.Omega
|
||||
|
||||
public section
|
||||
|
|
|
|||
|
|
@ -1,92 +0,0 @@
|
|||
(Acc.rec, Acc.recC)
|
||||
(Array.instDecidableEmpEq, Array.instDecidableEmpEqImpl)
|
||||
(Array.instDecidableEq, Array.instDecidableEqImpl)
|
||||
(Array.instDecidableEqEmp, Array.instDecidableEqEmpImpl)
|
||||
(Array.pmap, Array.pmapImpl)
|
||||
(ByteArray.append, ByteArray.fastAppend)
|
||||
(List.append, List.appendTR)
|
||||
(List.dropLast, List.dropLastTR)
|
||||
(List.erase, List.eraseTR)
|
||||
(List.eraseIdx, List.eraseIdxTR)
|
||||
(List.eraseP, List.erasePTR)
|
||||
(List.filter, List.filterTR)
|
||||
(List.findRev?, List.findRev?TR)
|
||||
(List.findSomeRev?, List.findSomeRev?TR)
|
||||
(List.flatMap, List.flatMapTR)
|
||||
(List.flatten, List.flattenTR)
|
||||
(List.foldr, List.foldrTR)
|
||||
(List.insertIdx, List.insertIdxTR)
|
||||
(List.intercalate, List.intercalateTR)
|
||||
(List.intersperse, List.intersperseTR)
|
||||
(List.leftpad, List.leftpadTR)
|
||||
(List.length, List.lengthTR)
|
||||
(List.map, List.mapTR)
|
||||
(List.merge, List.MergeSort.Internal.mergeTR)
|
||||
(List.mergeSort, List.MergeSort.Internal.mergeSortTR₂)
|
||||
(List.modify, List.modifyTR)
|
||||
(List.pmap, List.pmapImpl)
|
||||
(List.range', List.range'TR)
|
||||
(List.replace, List.replaceTR)
|
||||
(List.replicate, List.replicateTR)
|
||||
(List.set, List.setTR)
|
||||
(List.take, List.takeTR)
|
||||
(List.takeWhile, List.takeWhileTR)
|
||||
(List.unzip, List.unzipTR)
|
||||
(List.zipIdx, List.zipIdxTR)
|
||||
(List.zipWith, List.zipWithTR)
|
||||
(Nat.all, Nat.allTR)
|
||||
(Nat.any, Nat.anyTR)
|
||||
(Nat.fold, Nat.foldTR)
|
||||
(Nat.rec, Nat.recCompiled)
|
||||
(Nat.repeat, Nat.repeatTR)
|
||||
(String.utf8EncodeChar, String.utf8EncodeCharFast)
|
||||
(Thunk.fn, Thunk.fnImpl)
|
||||
(Vector.pmap, Vector.pmapImpl)
|
||||
(String.Slice.Pos.next, String.Slice.Pos.nextFast)
|
||||
csimpCore.lean:56:0-56:11: error: ❌️ Docstring on `#guard_msgs` does not match generated message:
|
||||
|
||||
info: (Acc.rec, Acc.recC)
|
||||
(Array.instDecidableEmpEq, Array.instDecidableEmpEqImpl)
|
||||
(Array.instDecidableEq, Array.instDecidableEqImpl)
|
||||
(Array.instDecidableEqEmp, Array.instDecidableEqEmpImpl)
|
||||
(Array.pmap, Array.pmapImpl)
|
||||
(ByteArray.append, ByteArray.fastAppend)
|
||||
(List.append, List.appendTR)
|
||||
(List.dropLast, List.dropLastTR)
|
||||
(List.erase, List.eraseTR)
|
||||
+ (List.eraseIdx, List.eraseIdxTR)
|
||||
(List.eraseP, List.erasePTR)
|
||||
(List.filter, List.filterTR)
|
||||
(List.findRev?, List.findRev?TR)
|
||||
+ (List.findSomeRev?, List.findSomeRev?TR)
|
||||
(List.flatMap, List.flatMapTR)
|
||||
(List.flatten, List.flattenTR)
|
||||
(List.foldr, List.foldrTR)
|
||||
(List.insertIdx, List.insertIdxTR)
|
||||
(List.intercalate, List.intercalateTR)
|
||||
(List.intersperse, List.intersperseTR)
|
||||
(List.leftpad, List.leftpadTR)
|
||||
(List.length, List.lengthTR)
|
||||
(List.map, List.mapTR)
|
||||
(List.merge, List.MergeSort.Internal.mergeTR)
|
||||
(List.mergeSort, List.MergeSort.Internal.mergeSortTR₂)
|
||||
(List.modify, List.modifyTR)
|
||||
(List.pmap, List.pmapImpl)
|
||||
(List.range', List.range'TR)
|
||||
(List.replace, List.replaceTR)
|
||||
(List.replicate, List.replicateTR)
|
||||
(List.set, List.setTR)
|
||||
(List.take, List.takeTR)
|
||||
(List.takeWhile, List.takeWhileTR)
|
||||
(List.unzip, List.unzipTR)
|
||||
(List.zipIdx, List.zipIdxTR)
|
||||
(List.zipWith, List.zipWithTR)
|
||||
(Nat.all, Nat.allTR)
|
||||
(Nat.any, Nat.anyTR)
|
||||
(Nat.fold, Nat.foldTR)
|
||||
(Nat.rec, Nat.recCompiled)
|
||||
(Nat.repeat, Nat.repeatTR)
|
||||
(String.utf8EncodeChar, String.utf8EncodeCharFast)
|
||||
(Thunk.fn, Thunk.fnImpl)
|
||||
(Vector.pmap, Vector.pmapImpl)
|
||||
(String.Slice.Pos.next, String.Slice.Pos.nextFast)
|
||||
5
tests/lean/run/12676.lean
Normal file
5
tests/lean/run/12676.lean
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
module
|
||||
|
||||
/-- info: 10000000 -/
|
||||
#guard_msgs in
|
||||
#eval (List.range 10000000).flatMapM (m := Id) (fun d => pure [d]) |>.length
|
||||
|
|
@ -18,14 +18,13 @@ info: (Acc.rec, Acc.recC)
|
|||
(List.append, List.appendTR)
|
||||
(List.dropLast, List.dropLastTR)
|
||||
(List.erase, List.eraseTR)
|
||||
(List.eraseIdx, List.eraseIdxTR)
|
||||
(List.eraseP, List.erasePTR)
|
||||
(List.filter, List.filterTR)
|
||||
(List.findRev?, List.findRev?TR)
|
||||
(List.flatMap, List.flatMapTR)
|
||||
(List.flatten, List.flattenTR)
|
||||
(List.findSomeRev?, List.findSomeRev?TR)
|
||||
(List.foldr, List.foldrTR)
|
||||
(List.insertIdx, List.insertIdxTR)
|
||||
(List.intercalate, List.intercalateTR)
|
||||
(List.intersperse, List.intersperseTR)
|
||||
(List.leftpad, List.leftpadTR)
|
||||
(List.length, List.lengthTR)
|
||||
Loading…
Add table
Reference in a new issue