diff --git a/src/Init/Data/Array/BinSearch.lean b/src/Init/Data/Array/BinSearch.lean index 530c5938b5..5656bcaab8 100644 --- a/src/Init/Data/Array/BinSearch.lean +++ b/src/Init/Data/Array/BinSearch.lean @@ -29,37 +29,38 @@ binSearchAux lt id as k lo hi @[inline] def binSearchContains {α : Type} [Inhabited α] (as : Array α) (k : α) (lt : α → α → Bool) (lo := 0) (hi := as.size - 1) : Bool := binSearchAux lt Option.isSome as k lo hi -@[specialize] partial def binInsertAuxAux {α : Type u} [Inhabited α] +@[specialize] private partial def binInsertAux {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α] (lt : α → α → Bool) - (merge : α → α) - (add : Unit → α) + (merge : α → m α) + (add : Unit → m α) (as : Array α) - (k : α) : Nat → Nat → Array α + (k : α) : Nat → Nat → m (Array α) | lo, hi => -- as[lo] < k < as[hi] - let m := (lo + hi)/2; - if lt (as.get! m) k then - if m == lo then as.insertAt (lo+1) (add ()) - else binInsertAuxAux m hi - else if lt k (as.get! m) then - binInsertAuxAux lo m - else - as.modify m $ fun a => merge a + let mid := (lo + hi)/2; + let midVal := as.get! mid; + if lt midVal k then + if mid == lo then do v ← add (); pure $ as.insertAt (lo+1) v + else binInsertAux mid hi + else if lt k midVal then + binInsertAux lo mid + else do + as.modifyM mid $ fun v => merge v -@[specialize] partial def binInsertAux {α : Type u} [Inhabited α] +@[specialize] partial def binInsertM {α : Type u} {m : Type u → Type v} [Monad m] [Inhabited α] (lt : α → α → Bool) - (merge : α → α) - (add : Unit → α) + (merge : α → m α) + (add : Unit → m α) (as : Array α) - (k : α) : Array α := -if as.isEmpty then as.push (add ()) -else if lt k (as.get! 0) then as.insertAt 0 (add ()) -else if !lt (as.get! 0) k then as.modify 0 $ fun a => merge a -else if lt as.back k then as.push (add ()) -else if !lt k as.back then as.modify (as.size - 1) $ fun a => merge a -else binInsertAuxAux lt merge add as k 0 (as.size - 1) + (k : α) : m (Array α) := +if as.isEmpty then do v ← add (); pure $ as.push v +else if lt k (as.get! 0) then do v ← add (); pure $ as.insertAt 0 v +else if !lt (as.get! 0) k then as.modifyM 0 $ merge +else if lt as.back k then do v ← add (); pure $ as.push v +else if !lt k as.back then as.modifyM (as.size - 1) $ merge +else binInsertAux lt merge add as k 0 (as.size - 1) @[inline] def binInsert {α : Type u} [Inhabited α] (lt : α → α → Bool) (as : Array α) (k : α) : Array α := -binInsertAux lt (fun _ => k) (fun _ => k) as k +Id.run $ binInsertM lt (fun _ => k) (fun _ => k) as k end Array diff --git a/tests/playground/DiscrTree.lean b/tests/playground/DiscrTree.lean index 08a870d610..c8d22fbd4d 100644 --- a/tests/playground/DiscrTree.lean +++ b/tests/playground/DiscrTree.lean @@ -79,8 +79,8 @@ partial def insertAux {α} [HasBeq α] (v : α) : Array Term → Trie α → Tri let todo := todo.pop; let todo := appendTodo todo t.args; let k := t.key; - node vs $ - cs.binInsertAux + node vs $ Id.run $ + cs.binInsertM (fun a b => a.1 < b.1) (fun ⟨_, s⟩ => (k, insertAux todo s)) -- merge with existing (fun _ => (k, createNodes v todo)) -- add new node