perf: Use flat ByteArrays in Trie (#2529)

This commit is contained in:
Joachim Breitner 2023-09-20 13:22:37 +02:00 committed by GitHub
parent de76a5d922
commit b2d668c340
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 282 additions and 80 deletions

View file

@ -33,6 +33,10 @@ opaque fromUTF8Unchecked (a : @& ByteArray) : String
@[extern "lean_string_to_utf8"]
opaque toUTF8 (a : @& String) : ByteArray
/-- Accesses the byte at index `n` in the UTF-8 encoding of the `String` `s`. O(1).
The proof `h` states that `n` lies within the encoded byte size of `s`; the native
implementation `lean_string_get_byte_fast` indexes the string data directly. -/
@[extern "lean_string_get_byte_fast"]
opaque getUtf8Byte (s : @& String) (n : Nat) (h : n < s.utf8ByteSize) : UInt8
/-- If the iterator `i` still has a next character (`i.hasNext`), then advancing it
strictly decreases its `sizeOf`. -/
theorem Iterator.sizeOf_next_lt_of_hasNext (i : String.Iterator) (h : i.hasNext) : sizeOf i.next < sizeOf i := by
cases i; rename_i s pos; simp [Iterator.next, Iterator.sizeOf_eq]; simp [Iterator.hasNext] at h
exact Nat.sub_lt_sub_left h (String.lt_next s pos)

View file

@ -1,110 +1,202 @@
/-
Copyright (c) 2018 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Sebastian Ullrich, Leonardo de Moura
Author: Sebastian Ullrich, Leonardo de Moura, Joachim Breitner
Trie for tokenizing the Lean language
A string trie data structure, used for tokenizing the Lean language
-/
import Lean.Data.Format
namespace Lean
namespace Parser
namespace Data
/-
## Implementation notes
Tries have typically many nodes with small degree, where a linear scan
through the (compact) `ByteArray` is faster than using binary search or
search trees like `RBTree`.
Moreover, many nodes have degree 1, which justifies the special case `Node1`
constructor.
The code would be a bit less repetitive if we used something like the following
```
mutual
def Trie α := Option α × ByteAssoc α
inductive ByteAssoc α where
| leaf : Trie α
| node1 : UInt8 → Trie α → Trie α
| node : ByteArray → Array (Trie α) → Trie α
end
```
but that would come at the cost of extra indirections.
-/
/-- A Trie is a key-value store where the keys are of type `String`,
and the internal structure is a tree that branches on the bytes of the string. -/
inductive Trie (α : Type) where
| Node : Option α → RBNode Char (fun _ => Trie α) → Trie α
| leaf : Option α → Trie α
| node1 : Option α → UInt8 → Trie α → Trie α
| node : Option α → ByteArray → Array (Trie α) → Trie α
namespace Trie
variable {α : Type}
def empty : Trie α :=
⟨none, RBNode.leaf⟩
/-- The empty `Trie` -/
def empty : Trie α := leaf none
instance : EmptyCollection (Trie α) :=
⟨empty⟩
instance : Inhabited (Trie α) where
default := Node none RBNode.leaf
default := empty
partial def insert (t : Trie α) (s : String) (val : α) : Trie α :=
let rec insertEmpty (i : String.Pos) : Trie α :=
match s.atEnd i with
| true => Trie.Node (some val) RBNode.leaf
| false =>
let c := s.get i
let t := insertEmpty (s.next i)
Trie.Node none (RBNode.singleton c t)
/-- Insert or update the value at the given key `s`. -/
partial def upsert (t : Trie α) (s : String) (f : Option α → α) : Trie α :=
let rec insertEmpty (i : Nat) : Trie α :=
if h : i < s.utf8ByteSize then
let c := s.getUtf8Byte i h
let t := insertEmpty (i + 1)
node1 none c t
else
leaf (f .none)
let rec loop
| Trie.Node v m, i =>
match s.atEnd i with
| true => Trie.Node (some val) m -- overrides old value
| false =>
let c := s.get i
let i := s.next i
let t := match RBNode.find compare m c with
| none => insertEmpty i
| some t => loop t i
Trie.Node v (RBNode.insert compare m c t)
loop t 0
| i, leaf v =>
if h : i < s.utf8ByteSize then
let c := s.getUtf8Byte i h
let t := insertEmpty (i + 1)
node1 v c t
else
leaf (f v)
| i, node1 v c' t' =>
if h : i < s.utf8ByteSize then
let c := s.getUtf8Byte i h
if c == c'
then node1 v c' (loop (i + 1) t')
else
let t := insertEmpty (i + 1)
node v (.mk #[c, c']) #[t, t']
else
node1 (f v) c' t'
| i, node v cs ts =>
if h : i < s.utf8ByteSize then
let c := s.getUtf8Byte i h
match cs.findIdx? (· == c) with
| none =>
let t := insertEmpty (i + 1)
node v (cs.push c) (ts.push t)
| some idx =>
node v cs (ts.modify idx (loop (i + 1)))
else
node (f v) cs ts
loop 0 t
/-- Stores `val` under the key `s`, replacing any value already stored there.
Implemented via `upsert` with an update function that ignores the previous value. -/
partial def insert (t : Trie α) (s : String) (val : α) : Trie α :=
  t.upsert s fun _ => val
/-- Looks up a value at the given key `s`. -/
partial def find? (t : Trie α) (s : String) : Option α :=
let rec loop
| Trie.Node val m, i =>
match s.atEnd i with
| true => val
| false =>
let c := s.get i
let i := s.next i
match RBNode.find compare m c with
| i, leaf val =>
if i < s.utf8ByteSize then
none
else
val
| i, node1 val c' t' =>
if h : i < s.utf8ByteSize then
let c := s.getUtf8Byte i h
if c == c'
then loop (i + 1) t'
else none
else
val
| i, node val cs ts =>
if h : i < s.utf8ByteSize then
let c := s.getUtf8Byte i h
match cs.findIdx? (· == c) with
| none => none
| some t => loop t i
loop t 0
| some idx => loop (i + 1) (ts.get! idx)
else
val
loop 0 t
/-- Return values that match the given `prefix` -/
partial def findPrefix (t : Trie α) (pre : String) : Array α :=
go t 0 |>.run #[] |>.2
where
go (t : Trie α) (i : String.Pos) : StateM (Array α) Unit :=
if pre.atEnd i then
collect t
else
let k := pre.get i
let i := pre.next i
let ⟨_, cs⟩ := t
cs.forM fun k' c => do
if k == k' then go c i
/-- Collects every value stored in the trie into an `Array`, in no particular order. -/
partial def values (t : Trie α) : Array α := go t #[]
where
  /-- Appends to `acc` the value stored at the given node (if any), followed by the
  values of its subtries, visited from left to right. -/
  go : Trie α → Array α → Array α
    | leaf stored?, acc =>
      if let some a := stored? then acc.push a else acc
    | node1 stored? _ child, acc =>
      let acc := if let some a := stored? then acc.push a else acc
      go child acc
    | node stored? _ children, acc =>
      let acc := if let some a := stored? then acc.push a else acc
      children.foldl (fun acc child => go child acc) acc
collect (t : Trie α) : StateM (Array α) Unit := do
let ⟨a?, cs⟩ := t
if let some a := a? then
modify (·.push a)
cs.forM fun _ c => collect c
/-- Returns all values whose keys have the given string `pre` as a prefix, in no particular order. -/
partial def findPrefix (t : Trie α) (pre : String) : Array α := go t 0
where
  /-- Descends from `t` along the bytes of `pre`, starting at byte offset `i`.
  Once `pre` is exhausted, every value in the remaining subtrie is a match;
  if the trie has no edge for the next byte, there are no matches. -/
  go (t : Trie α) (i : Nat) : Array α :=
    if h : i < pre.utf8ByteSize then
      let b := pre.getUtf8Byte i h
      match t with
      | leaf _ => #[]
      | node1 _ b' child =>
        if b == b' then go child (i + 1) else #[]
      | node _ bytes children =>
        match bytes.findIdx? (· == b) with
        | some idx => go (children.get! idx) (i + 1)
        | none => #[]
    else
      t.values
private def updtAcc (v : Option α) (i : String.Pos) (acc : String.Pos × Option α) : String.Pos × Option α :=
match v, acc with
| some v, (_, _) => (i, some v) -- we pattern match on `acc` to enable memory reuse
| none, acc => acc
partial def matchPrefix (s : String) (t : Trie α) (i : String.Pos) : String.Pos × Option α :=
/-- Find the longest _key_ in the trie that is contained in the given string `s` at position `i`,
and return the associated value. -/
partial def matchPrefix (s : String) (t : Trie α) (i : String.Pos) : Option α :=
let rec loop
| Trie.Node v m, i, acc =>
match s.atEnd i with
| true => updtAcc v i acc
| false =>
let acc := updtAcc v i acc
let c := s.get i
let i := s.next i
match RBNode.find compare m c with
| some t => loop t i acc
| none => acc
loop t i (i, none)
| leaf v, _, res =>
if v.isSome then v else res
| node1 v c' t', i, res =>
let res := if v.isSome then v else res
if h : i < s.utf8ByteSize then
let c := s.getUtf8Byte i h
if c == c'
then loop t' (i + 1) res
else res
else
res
| node v cs ts, i, res =>
let res := if v.isSome then v else res
if h : i < s.utf8ByteSize then
let c := s.getUtf8Byte i h
match cs.findIdx? (· == c) with
| none => res
| some idx => loop (ts.get! idx) (i + 1) res
else
res
loop t i.byteIdx none
private partial def toStringAux {α : Type} : Trie α → List Format
| Trie.Node _ map => map.fold (fun Fs c t =>
format (repr c) :: (Format.group $ Format.nest 2 $ flip Format.joinSep Format.line $ toStringAux t) :: Fs) []
| leaf _ => []
| node1 _ c t =>
[ format (repr c), Format.group $ Format.nest 4 $ flip Format.joinSep Format.line $ toStringAux t ]
| node _ cs ts =>
List.join $ List.zipWith (fun c t =>
[ format (repr c), (Format.group $ Format.nest 4 $ flip Format.joinSep Format.line $ toStringAux t) ]
) cs.toList ts.toList
instance {α : Type} : ToString (Trie α) :=
⟨fun t => (flip Format.joinSep Format.line $ toStringAux t).pretty⟩
end Trie
end Parser
end Data
end Lean

View file

@ -821,7 +821,7 @@ private def tokenFnAux : ParserFn := fun c s =>
else if curr == '`' && isIdFirstOrBeginEscape (getNext input i) then
nameLitAux i c s
else
let (_, tk) := c.tokens.matchPrefix input i
let tk := c.tokens.matchPrefix input i
identFnAux i tk .anonymous c s
private def updateTokenCache (startPos : String.Pos) (s : ParserState) : ParserState :=

View file

@ -31,7 +31,7 @@ def minPrec : Nat := eval_prec min
abbrev Token := String
abbrev TokenTable := Trie Token
abbrev TokenTable := Lean.Data.Trie Token
abbrev SyntaxNodeKindSet := PersistentHashMap SyntaxNodeKind Unit

View file

@ -1014,6 +1014,11 @@ static inline uint32_t lean_string_utf8_get_fast(b_lean_obj_arg s, b_lean_obj_ar
if ((c & 0x80) == 0) return c;
return lean_string_utf8_get_fast_cold(str, idx, lean_string_size(s), c);
}
/* Returns the byte at offset `i` (an unboxed scalar) of the character data of
   string `s`. No bounds check is performed here. */
static inline uint8_t lean_string_get_byte_fast(b_lean_obj_arg s, b_lean_obj_arg i) {
    size_t const offset = lean_unbox(i);
    return (uint8_t)lean_string_cstr(s)[offset];
}
LEAN_SHARED lean_obj_res lean_string_utf8_next(b_lean_obj_arg s, b_lean_obj_arg i);
LEAN_SHARED lean_obj_res lean_string_utf8_next_fast_cold(size_t i, unsigned char c);

View file

@ -6,7 +6,7 @@ options get_default_options() {
// see https://leanprover.github.io/lean4/doc/dev/bootstrap.html#further-bootstrapping-complications
#if LEAN_IS_STAGE0 == 1
// switch to `true` for ABI-breaking changes affecting meta code
opts = opts.update({"interpreter", "prefer_native"}, false);
opts = opts.update({"interpreter", "prefer_native"}, true);
// switch to `true` for changing built-in parsers used in quotations
opts = opts.update({"internal", "parseQuotWithCurrentStage"}, false);
// toggling `parseQuotWithCurrentStage` may also require toggling the following option if macros/syntax

88
tests/compiler/trie.lean Normal file
View file

@ -0,0 +1,88 @@
import Lean.Data.Trie
/-!
# Tests for the trie data structure
This test tests the `Lean.Data.Trie` data structure by bisimulation with a simple `Array String`:
It performs a sequence of trie creation steps, and after each step checks
whether the trie is operationally equivalent to the array of strings.
This test does not bother with values that are different from the `String` they are stored under.
This test does not test `upsert`; since `Trie.insert` goes through it, it should be sufficient
(and it would make this test approach more complicated.)
-/
open Lean.Data
/-- These keys are used in `T.check` below, where each one is looked up via `find?`,
`findPrefix` and `matchPrefix`. Also include keys for negative lookup tests here! -/
def keys : Array String := #[
"",
"h",
"hello",
"helloo",
"hellooo",
"helloooooo",
"hella",
"hellx",
"hö",
"hü",
"hä",
"💩"
]
/-- The test state: a trie paired with a plain array of the strings inserted so far,
which serves as the reference model. -/
def T := Trie String × Array String
/-- The initial test state: an empty trie and an empty reference array. -/
def T.empty : T := (Trie.empty, #[])
/-- Inserts `s` (as both key and value) into the trie, and records it in the
reference array unless it is already present there. -/
def T.insert : T → String → T := fun (trie, ref) s =>
  let ref' := if ref.contains s then ref else ref.push s
  (trie.insert s s, ref')
/-- Test helper: sorts an array of strings in ascending `<` order, so that results
can be compared independently of their original order. -/
def Array.sorted : Array String → Array String := fun xs =>
  xs.qsort (· < ·)
/-- The intended semantics of `Trie.findPrefix`: all stored strings that start with `s`. -/
def Array.findPrefix : Array String → String → Array String := fun stored s =>
  stored.filter s.isPrefixOf
/-- The intended semantics of `Trie.matchPrefix`: the longest prefix of `s` that is
stored in the array, or `none` if no prefix of `s` (not even `""`) is stored. -/
def Array.matchPrefix : Array String → String → Option String := fun stored s =>
  -- Try prefix lengths from longest (`s.length`) down to `0`; the first hit wins.
  (List.range (s.length + 1)).reverse.findSome? fun len =>
    let candidate := s.take len
    if (stored.find? (· == candidate)).isSome then some candidate else none
/-- Compares the trie `t` against the reference array `a` on every entry of `keys`,
printing a diagnostic line for each observed difference. Prints nothing when the
two agree. -/
def T.check : T → IO Unit := fun (t,a) => do
  -- `find?` must agree with a linear search in the reference array
  for s in keys do
    unless t.find? s = a.find? (· == s) do
      IO.println s!"find? differs: key = {s}"
  -- `findPrefix` must return the same set of values (compared after sorting)
  for s in keys do
    unless (t.findPrefix s).sorted = (a.findPrefix s).sorted do
      IO.println s!"findPrefix differs: key = {s}"
  -- `matchPrefix` must agree, both at position 0 and at a non-zero start position
  for s in keys do
    unless t.matchPrefix s 0 = a.matchPrefix s do
      IO.println s!"matchPrefix differs: key = {s}, got: {t.matchPrefix s 0} exp: {a.matchPrefix s} "
    let s' := "somePrefix" ++ s
    unless t.matchPrefix s' ((0 : String.Pos) + "somePrefix") = a.matchPrefix s do
      IO.println s!"matchPrefix differs (with prefix): key = {s}"
/-- Runs each insertion sequence against a fresh trie, checking the trie against the
reference array before the first insertion and after every insertion. -/
def main : IO Unit := do
  -- Add tricky insert sequences here:
  let scenarios : Array (Array String) := #[
    #["hello", "hella", "hellooo", "h", "hö", "hü", "💩", "", "hü"],
    #["", "helooooo"]
  ]
  for scenario in scenarios do
    IO.println "Resetting trie"
    let mut state : T := T.empty
    state.check
    for s in scenario do
      IO.println s!"Inserting {s}"
      state := state.insert s
      state.check

View file

@ -0,0 +1,13 @@
Resetting trie
Inserting hello
Inserting hella
Inserting hellooo
Inserting h
Inserting hö
Inserting hü
Inserting 💩
Inserting
Inserting hü
Resetting trie
Inserting
Inserting helooooo

View file

@ -168,8 +168,8 @@ do tk ← monadLift peekToken,
| Syntax.atom ⟨_, sym⟩ := do
cfg ← readCfg,
(match cfg.tokens.matchPrefix sym.mkIterator with
| some ⟨_, tkCfg⟩ := pure tkCfg.lbp
| _ := error "currLbp: unreachable")
| some tkCfg := pure tkCfg.lbp
| _ := error "currLbp: unreachable")
| Syntax.rawNode {kind := @number, ..} := pure maxPrec
| Syntax.rawNode {kind := @stringLit, ..} := pure maxPrec
| Syntax.ident _ := pure maxPrec
@ -182,7 +182,7 @@ do tk ← monadLift peekToken,
match tk with
| Syntax.atom ⟨_, sym⟩ := do
cfg ← read,
-- some ⟨_, tkCfg⟩ ← pure (cfg.tokens.matchPrefix sym.mkIterator) | error "currLbp: unreachable",
-- some tkCfg ← pure (cfg.tokens.matchPrefix sym.mkIterator) | error "currLbp: unreachable",
pure 0
| Syntax.ident _ := pure maxPrec
| Syntax.rawNode {kind := @number, ..} := pure maxPrec

View file

@ -764,7 +764,7 @@ private def tokenFnAux : BasicParserFn
else if c.isDigit then
numberFnAux s d
else
let (_, tk) := cfg.tokens.matchPrefix s i in
let tk := cfg.tokens.matchPrefix s i in
identFnAux i tk Name.anonymous s d
private def updateCache (startPos : Nat) (d : ParserData) : ParserData :=