perf: Use flat ByteArrays in Trie (#2529)
This commit is contained in:
parent
de76a5d922
commit
b2d668c340
10 changed files with 282 additions and 80 deletions
|
|
@ -33,6 +33,10 @@ opaque fromUTF8Unchecked (a : @& ByteArray) : String
|
|||
@[extern "lean_string_to_utf8"]
|
||||
opaque toUTF8 (a : @& String) : ByteArray
|
||||
|
||||
/-- Accesses a byte in the UTF-8 encoding of the `String`. O(1) -/
|
||||
@[extern "lean_string_get_byte_fast"]
|
||||
opaque getUtf8Byte (s : @& String) (n : Nat) (h : n < s.utf8ByteSize) : UInt8
|
||||
|
||||
theorem Iterator.sizeOf_next_lt_of_hasNext (i : String.Iterator) (h : i.hasNext) : sizeOf i.next < sizeOf i := by
|
||||
cases i; rename_i s pos; simp [Iterator.next, Iterator.sizeOf_eq]; simp [Iterator.hasNext] at h
|
||||
exact Nat.sub_lt_sub_left h (String.lt_next s pos)
|
||||
|
|
|
|||
|
|
@ -1,110 +1,202 @@
|
|||
/-
|
||||
Copyright (c) 2018 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Sebastian Ullrich, Leonardo de Moura
|
||||
Author: Sebastian Ullrich, Leonardo de Moura, Joachim Breitner
|
||||
|
||||
Trie for tokenizing the Lean language
|
||||
A string trie data structure, used for tokenizing the Lean language
|
||||
-/
|
||||
import Lean.Data.Format
|
||||
|
||||
namespace Lean
|
||||
namespace Parser
|
||||
namespace Data
|
||||
|
||||
/-
|
||||
## Implementation notes
|
||||
|
||||
Tries have typically many nodes with small degree, where a linear scan
|
||||
through the (compact) `ByteArray` is faster than using binary search or
|
||||
search trees like `RBTree`.
|
||||
|
||||
Moreover, many nodes have degree 1, which justifies the special case `Node1`
|
||||
constructor.
|
||||
|
||||
The code would be a bit less repetitive if we used something like the following
|
||||
```
|
||||
mutual
|
||||
def Trie α := Option α × ByteAssoc α
|
||||
|
||||
inductive ByteAssoc α where
|
||||
| leaf : Trie α
|
||||
| node1 : UInt8 → Trie α → Trie α
|
||||
| node : ByteArray → Array (Trie α) → Trie α
|
||||
end
|
||||
```
|
||||
but that would come at the cost of extra indirections.
|
||||
-/
|
||||
|
||||
/-- A Trie is a key-value store where the keys are of type `String`,
|
||||
and the internal structure is a tree that branches on the bytes of the string. -/
|
||||
inductive Trie (α : Type) where
|
||||
| Node : Option α → RBNode Char (fun _ => Trie α) → Trie α
|
||||
| leaf : Option α → Trie α
|
||||
| node1 : Option α → UInt8 → Trie α → Trie α
|
||||
| node : Option α → ByteArray → Array (Trie α) → Trie α
|
||||
|
||||
namespace Trie
|
||||
variable {α : Type}
|
||||
|
||||
def empty : Trie α :=
|
||||
⟨none, RBNode.leaf⟩
|
||||
/-- The empty `Trie` -/
|
||||
def empty : Trie α := leaf none
|
||||
|
||||
instance : EmptyCollection (Trie α) :=
|
||||
⟨empty⟩
|
||||
|
||||
instance : Inhabited (Trie α) where
|
||||
default := Node none RBNode.leaf
|
||||
default := empty
|
||||
|
||||
partial def insert (t : Trie α) (s : String) (val : α) : Trie α :=
|
||||
let rec insertEmpty (i : String.Pos) : Trie α :=
|
||||
match s.atEnd i with
|
||||
| true => Trie.Node (some val) RBNode.leaf
|
||||
| false =>
|
||||
let c := s.get i
|
||||
let t := insertEmpty (s.next i)
|
||||
Trie.Node none (RBNode.singleton c t)
|
||||
/-- Insert or update the value at the given key `s`: `f` receives the existing value (`none` if the key is absent) and returns the value to store. -/
|
||||
partial def upsert (t : Trie α) (s : String) (f : Option α → α) : Trie α :=
|
||||
let rec insertEmpty (i : Nat) : Trie α :=
|
||||
if h : i < s.utf8ByteSize then
|
||||
let c := s.getUtf8Byte i h
|
||||
let t := insertEmpty (i + 1)
|
||||
node1 none c t
|
||||
else
|
||||
leaf (f .none)
|
||||
let rec loop
|
||||
| Trie.Node v m, i =>
|
||||
match s.atEnd i with
|
||||
| true => Trie.Node (some val) m -- overrides old value
|
||||
| false =>
|
||||
let c := s.get i
|
||||
let i := s.next i
|
||||
let t := match RBNode.find compare m c with
|
||||
| none => insertEmpty i
|
||||
| some t => loop t i
|
||||
Trie.Node v (RBNode.insert compare m c t)
|
||||
loop t 0
|
||||
| i, leaf v =>
|
||||
if h : i < s.utf8ByteSize then
|
||||
let c := s.getUtf8Byte i h
|
||||
let t := insertEmpty (i + 1)
|
||||
node1 v c t
|
||||
else
|
||||
leaf (f v)
|
||||
| i, node1 v c' t' =>
|
||||
if h : i < s.utf8ByteSize then
|
||||
let c := s.getUtf8Byte i h
|
||||
if c == c'
|
||||
then node1 v c' (loop (i + 1) t')
|
||||
else
|
||||
let t := insertEmpty (i + 1)
|
||||
node v (.mk #[c, c']) #[t, t']
|
||||
else
|
||||
node1 (f v) c' t'
|
||||
| i, node v cs ts =>
|
||||
if h : i < s.utf8ByteSize then
|
||||
let c := s.getUtf8Byte i h
|
||||
match cs.findIdx? (· == c) with
|
||||
| none =>
|
||||
let t := insertEmpty (i + 1)
|
||||
node v (cs.push c) (ts.push t)
|
||||
| some idx =>
|
||||
node v cs (ts.modify idx (loop (i + 1)))
|
||||
else
|
||||
node (f v) cs ts
|
||||
loop 0 t
|
||||
|
||||
/-- Inserts a value at the given key `s`, overriding an existing value if present. -/
|
||||
partial def insert (t : Trie α) (s : String) (val : α) : Trie α :=
|
||||
upsert t s (fun _ => val)
|
||||
|
||||
/-- Looks up a value at the given key `s`. -/
|
||||
partial def find? (t : Trie α) (s : String) : Option α :=
|
||||
let rec loop
|
||||
| Trie.Node val m, i =>
|
||||
match s.atEnd i with
|
||||
| true => val
|
||||
| false =>
|
||||
let c := s.get i
|
||||
let i := s.next i
|
||||
match RBNode.find compare m c with
|
||||
| i, leaf val =>
|
||||
if i < s.utf8ByteSize then
|
||||
none
|
||||
else
|
||||
val
|
||||
| i, node1 val c' t' =>
|
||||
if h : i < s.utf8ByteSize then
|
||||
let c := s.getUtf8Byte i h
|
||||
if c == c'
|
||||
then loop (i + 1) t'
|
||||
else none
|
||||
else
|
||||
val
|
||||
| i, node val cs ts =>
|
||||
if h : i < s.utf8ByteSize then
|
||||
let c := s.getUtf8Byte i h
|
||||
match cs.findIdx? (· == c) with
|
||||
| none => none
|
||||
| some t => loop t i
|
||||
loop t 0
|
||||
| some idx => loop (i + 1) (ts.get! idx)
|
||||
else
|
||||
val
|
||||
loop 0 t
|
||||
|
||||
/-- Return values that match the given `prefix` -/
|
||||
partial def findPrefix (t : Trie α) (pre : String) : Array α :=
|
||||
go t 0 |>.run #[] |>.2
|
||||
where
|
||||
go (t : Trie α) (i : String.Pos) : StateM (Array α) Unit :=
|
||||
if pre.atEnd i then
|
||||
collect t
|
||||
else
|
||||
let k := pre.get i
|
||||
let i := pre.next i
|
||||
let ⟨_, cs⟩ := t
|
||||
cs.forM fun k' c => do
|
||||
if k == k' then go c i
|
||||
/-- Returns an `Array` of all values in the trie, in no particular order. -/
|
||||
partial def values (t : Trie α) : Array α := go t |>.run #[] |>.2
|
||||
where
|
||||
go : Trie α → StateM (Array α) Unit
|
||||
| leaf a? => do
|
||||
if let some a := a? then
|
||||
modify (·.push a)
|
||||
| node1 a? _ t' => do
|
||||
if let some a := a? then
|
||||
modify (·.push a)
|
||||
go t'
|
||||
| node a? _ ts => do
|
||||
if let some a := a? then
|
||||
modify (·.push a)
|
||||
ts.forM fun t' => go t'
|
||||
|
||||
collect (t : Trie α) : StateM (Array α) Unit := do
|
||||
let ⟨a?, cs⟩ := t
|
||||
if let some a := a? then
|
||||
modify (·.push a)
|
||||
cs.forM fun _ c => collect c
|
||||
/-- Returns all values whose keys have the given string `pre` as a prefix, in no particular order. -/
|
||||
partial def findPrefix (t : Trie α) (pre : String) : Array α := go t 0
|
||||
where
|
||||
go (t : Trie α) (i : Nat) : Array α :=
|
||||
if h : i < pre.utf8ByteSize then
|
||||
let c := pre.getUtf8Byte i h
|
||||
match t with
|
||||
| leaf _val => .empty
|
||||
| node1 _val c' t' =>
|
||||
if c == c'
|
||||
then go t' (i + 1)
|
||||
else .empty
|
||||
| node _val cs ts =>
|
||||
match cs.findIdx? (· == c) with
|
||||
| none => .empty
|
||||
| some idx => go (ts.get! idx) (i + 1)
|
||||
else
|
||||
t.values
|
||||
|
||||
private def updtAcc (v : Option α) (i : String.Pos) (acc : String.Pos × Option α) : String.Pos × Option α :=
|
||||
match v, acc with
|
||||
| some v, (_, _) => (i, some v) -- we pattern match on `acc` to enable memory reuse
|
||||
| none, acc => acc
|
||||
|
||||
partial def matchPrefix (s : String) (t : Trie α) (i : String.Pos) : String.Pos × Option α :=
|
||||
/-- Find the longest _key_ in the trie that is contained in the given string `s` at position `i`,
|
||||
and return the associated value. -/
|
||||
partial def matchPrefix (s : String) (t : Trie α) (i : String.Pos) : Option α :=
|
||||
let rec loop
|
||||
| Trie.Node v m, i, acc =>
|
||||
match s.atEnd i with
|
||||
| true => updtAcc v i acc
|
||||
| false =>
|
||||
let acc := updtAcc v i acc
|
||||
let c := s.get i
|
||||
let i := s.next i
|
||||
match RBNode.find compare m c with
|
||||
| some t => loop t i acc
|
||||
| none => acc
|
||||
loop t i (i, none)
|
||||
| leaf v, _, res =>
|
||||
if v.isSome then v else res
|
||||
| node1 v c' t', i, res =>
|
||||
let res := if v.isSome then v else res
|
||||
if h : i < s.utf8ByteSize then
|
||||
let c := s.getUtf8Byte i h
|
||||
if c == c'
|
||||
then loop t' (i + 1) res
|
||||
else res
|
||||
else
|
||||
res
|
||||
| node v cs ts, i, res =>
|
||||
let res := if v.isSome then v else res
|
||||
if h : i < s.utf8ByteSize then
|
||||
let c := s.getUtf8Byte i h
|
||||
match cs.findIdx? (· == c) with
|
||||
| none => res
|
||||
| some idx => loop (ts.get! idx) (i + 1) res
|
||||
else
|
||||
res
|
||||
loop t i.byteIdx none
|
||||
|
||||
private partial def toStringAux {α : Type} : Trie α → List Format
|
||||
| Trie.Node _ map => map.fold (fun Fs c t =>
|
||||
format (repr c) :: (Format.group $ Format.nest 2 $ flip Format.joinSep Format.line $ toStringAux t) :: Fs) []
|
||||
| leaf _ => []
|
||||
| node1 _ c t =>
|
||||
[ format (repr c), Format.group $ Format.nest 4 $ flip Format.joinSep Format.line $ toStringAux t ]
|
||||
| node _ cs ts =>
|
||||
List.join $ List.zipWith (fun c t =>
|
||||
[ format (repr c), (Format.group $ Format.nest 4 $ flip Format.joinSep Format.line $ toStringAux t) ]
|
||||
) cs.toList ts.toList
|
||||
|
||||
instance {α : Type} : ToString (Trie α) :=
|
||||
⟨fun t => (flip Format.joinSep Format.line $ toStringAux t).pretty⟩
|
||||
|
||||
end Trie
|
||||
|
||||
end Parser
|
||||
end Data
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -821,7 +821,7 @@ private def tokenFnAux : ParserFn := fun c s =>
|
|||
else if curr == '`' && isIdFirstOrBeginEscape (getNext input i) then
|
||||
nameLitAux i c s
|
||||
else
|
||||
let (_, tk) := c.tokens.matchPrefix input i
|
||||
let tk := c.tokens.matchPrefix input i
|
||||
identFnAux i tk .anonymous c s
|
||||
|
||||
private def updateTokenCache (startPos : String.Pos) (s : ParserState) : ParserState :=
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def minPrec : Nat := eval_prec min
|
|||
|
||||
abbrev Token := String
|
||||
|
||||
abbrev TokenTable := Trie Token
|
||||
abbrev TokenTable := Lean.Data.Trie Token
|
||||
|
||||
abbrev SyntaxNodeKindSet := PersistentHashMap SyntaxNodeKind Unit
|
||||
|
||||
|
|
|
|||
|
|
@ -1014,6 +1014,11 @@ static inline uint32_t lean_string_utf8_get_fast(b_lean_obj_arg s, b_lean_obj_ar
|
|||
if ((c & 0x80) == 0) return c;
|
||||
return lean_string_utf8_get_fast_cold(str, idx, lean_string_size(s), c);
|
||||
}
|
||||
static inline uint8_t lean_string_get_byte_fast(b_lean_obj_arg s, b_lean_obj_arg i) {
|
||||
char const * str = lean_string_cstr(s);
|
||||
size_t idx = lean_unbox(i);
|
||||
return str[idx];
|
||||
}
|
||||
|
||||
LEAN_SHARED lean_obj_res lean_string_utf8_next(b_lean_obj_arg s, b_lean_obj_arg i);
|
||||
LEAN_SHARED lean_obj_res lean_string_utf8_next_fast_cold(size_t i, unsigned char c);
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ options get_default_options() {
|
|||
// see https://leanprover.github.io/lean4/doc/dev/bootstrap.html#further-bootstrapping-complications
|
||||
#if LEAN_IS_STAGE0 == 1
|
||||
// switch to `true` for ABI-breaking changes affecting meta code
|
||||
opts = opts.update({"interpreter", "prefer_native"}, false);
|
||||
opts = opts.update({"interpreter", "prefer_native"}, true);
|
||||
// switch to `true` for changing built-in parsers used in quotations
|
||||
opts = opts.update({"internal", "parseQuotWithCurrentStage"}, false);
|
||||
// toggling `parseQuotWithCurrentStage` may also require toggling the following option if macros/syntax
|
||||
|
|
|
|||
88
tests/compiler/trie.lean
Normal file
88
tests/compiler/trie.lean
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
import Lean.Data.Trie
|
||||
|
||||
/-!
|
||||
|
||||
# Tests for the trie data structure
|
||||
|
||||
This test tests the `Lean.Data.Trie` data structure by bisimulation with a simple `Array String`:
|
||||
It performs a sequence of trie creation steps, and after each step checks
|
||||
whether the trie is operationally equivalent to the array of strings.
|
||||
|
||||
This test does not bother with values that are different from the `String` they are stored under.
|
||||
This test does not test `upsert`; since `Trie.insert` goes through it, it should be sufficient
|
||||
(and it would make this test approach more complicated.)
|
||||
-/
|
||||
|
||||
open Lean.Data
|
||||
|
||||
/-- These keys are used in `T.check` below. Also include keys for negative lookup tests here! -/
|
||||
def keys : Array String := #[
|
||||
"",
|
||||
"h",
|
||||
"hello",
|
||||
"helloo",
|
||||
"hellooo",
|
||||
"helloooooo",
|
||||
"hella",
|
||||
"hellx",
|
||||
"hö",
|
||||
"hü",
|
||||
"hä",
|
||||
"💩"
|
||||
]
|
||||
|
||||
/-- A trie together with a reference value as an array of values -/
|
||||
def T := Trie String × Array String
|
||||
|
||||
def T.empty : T := (.empty, .empty)
|
||||
|
||||
def T.insert : T → String → T := fun (t,a) s =>
|
||||
(t.insert s s, if a.contains s then a else a.push s)
|
||||
|
||||
/-- A convenience function for use in this test case -/
|
||||
def Array.sorted : Array String → Array String := fun a =>
|
||||
a.qsort (fun s1 s2 => s1 < s2)
|
||||
|
||||
/-- The intended semantics of `Trie.findPrefix` -/
|
||||
def Array.findPrefix : Array String → String → Array String := fun a s =>
|
||||
a.filter (fun s' => s.isPrefixOf s')
|
||||
|
||||
/-- The intended semantics of `Trie.matchPrefix`: Longest prefix found in trie -/
|
||||
def Array.matchPrefix : Array String → String → Option String := fun a s => Id.run do
|
||||
for i in List.reverse (List.range (s.length + 1)) do
|
||||
let pfix := s.take i
|
||||
if let some _ := a.find? (· == pfix) then
|
||||
return some pfix
|
||||
return none
|
||||
|
||||
|
||||
def T.check : T → IO Unit := fun (t,a) => do
|
||||
-- Check lookup equivalence
|
||||
keys.forM fun s => do
|
||||
unless t.find? s = a.find? (· == s) do
|
||||
IO.println s!"find? differs: key = {s}"
|
||||
-- Check findPrefix equivalence
|
||||
keys.forM fun s => do
|
||||
unless (t.findPrefix s).sorted = (a.findPrefix s).sorted do
|
||||
IO.println s!"findPrefix differs: key = {s}"
|
||||
-- Check matchPrefix equivalence
|
||||
keys.forM fun s => do
|
||||
unless t.matchPrefix s 0 = a.matchPrefix s do
|
||||
IO.println s!"matchPrefix differs: key = {s}, got: {t.matchPrefix s 0} exp: {a.matchPrefix s} "
|
||||
let s' := "somePrefix" ++ s
|
||||
unless t.matchPrefix s' ((0 : String.Pos) + "somePrefix") = a.matchPrefix s do
|
||||
IO.println s!"matchPrefix differs (with prefix): key = {s}"
|
||||
|
||||
def main : IO Unit := do
|
||||
-- Add tricky insert sequences here:
|
||||
for seq in #[
|
||||
#["hello", "hella", "hellooo", "h", "hö", "hü", "💩", "", "hü"],
|
||||
#["", "helooooo"]
|
||||
] do
|
||||
IO.println "Resetting trie"
|
||||
let mut t : T := T.empty
|
||||
t.check
|
||||
for s in seq do
|
||||
IO.println s!"Inserting {s}"
|
||||
t := t.insert s
|
||||
t.check
|
||||
13
tests/compiler/trie.lean.expected.out
Normal file
13
tests/compiler/trie.lean.expected.out
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
Resetting trie
|
||||
Inserting hello
|
||||
Inserting hella
|
||||
Inserting hellooo
|
||||
Inserting h
|
||||
Inserting hö
|
||||
Inserting hü
|
||||
Inserting 💩
|
||||
Inserting
|
||||
Inserting hü
|
||||
Resetting trie
|
||||
Inserting
|
||||
Inserting helooooo
|
||||
|
|
@ -168,8 +168,8 @@ do tk ← monadLift peekToken,
|
|||
| Syntax.atom ⟨_, sym⟩ := do
|
||||
cfg ← readCfg,
|
||||
(match cfg.tokens.matchPrefix sym.mkIterator with
|
||||
| some ⟨_, tkCfg⟩ := pure tkCfg.lbp
|
||||
| _ := error "currLbp: unreachable")
|
||||
| some tkCfg := pure tkCfg.lbp
|
||||
| _ := error "currLbp: unreachable")
|
||||
| Syntax.rawNode {kind := @number, ..} := pure maxPrec
|
||||
| Syntax.rawNode {kind := @stringLit, ..} := pure maxPrec
|
||||
| Syntax.ident _ := pure maxPrec
|
||||
|
|
@ -182,7 +182,7 @@ do tk ← monadLift peekToken,
|
|||
match tk with
|
||||
| Syntax.atom ⟨_, sym⟩ := do
|
||||
cfg ← read,
|
||||
-- some ⟨_, tkCfg⟩ ← pure (cfg.tokens.matchPrefix sym.mkIterator) | error "currLbp: unreachable",
|
||||
-- some tkCfg ← pure (cfg.tokens.matchPrefix sym.mkIterator) | error "currLbp: unreachable",
|
||||
pure 0
|
||||
| Syntax.ident _ := pure maxPrec
|
||||
| Syntax.rawNode {kind := @number, ..} := pure maxPrec
|
||||
|
|
|
|||
|
|
@ -764,7 +764,7 @@ private def tokenFnAux : BasicParserFn
|
|||
else if c.isDigit then
|
||||
numberFnAux s d
|
||||
else
|
||||
let (_, tk) := cfg.tokens.matchPrefix s i in
|
||||
let tk := cfg.tokens.matchPrefix s i in
|
||||
identFnAux i tk Name.anonymous s d
|
||||
|
||||
private def updateCache (startPos : Nat) (d : ParserData) : ParserData :=
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue