test(tests/playground/parser/parser): "liftable" longestMatch
For lists of size 0, 1, and 2, it avoids the overhead of creating temporary lists of closures. I measured the overhead with `test1.lean`, and there is none in these cases. `test1.lean` also has a test for length = 4, where the overhead is 7%. We only use `longestMatch` to implement the Pratt parser, so the lists should be small and the overhead is acceptable. If it is not, we can add back the `longestMatch` specialized for `TermParser`. cc @kha
This commit is contained in:
parent
5991337279
commit
014c7e3374
1 changed files with 116 additions and 91 deletions
|
|
@ -519,6 +519,101 @@ partial def identFnAux (startPos : Nat) (tk : Option TokenConfig) : Name → Par
|
|||
else
|
||||
mkTokenAndFixPos startPos tk s d
|
||||
|
||||
-- Keeps the outcome of the current (failed) attempt: position, cache and error of `d`
-- are preserved, and only the stack is shrunk to `oldStackSize`.
-- Used by `longestMatchStep` when both attempts failed and the current one stopped
-- at a later position than the previous one.
def ParserData.keepNewError (d : ParserData) (oldStackSize : Nat) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, cache, err⟩ := ⟨stack.shrink oldStackSize, pos, cache, err⟩
|
||||
|
||||
-- Discards the position and error of the current attempt: the stack is shrunk to
-- `oldStackSize`, and the position and error are replaced by `oldStopPos` and
-- `oldError`. The cache of `d` is kept.
-- Used by `longestMatchStep` when both attempts failed and the previous one stopped
-- at a later position than the current one.
def ParserData.keepPrevError (d : ParserData) (oldStackSize : Nat) (oldStopPos : String.Pos) (oldError : Option String) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, _, cache, _⟩ := ⟨stack.shrink oldStackSize, oldStopPos, cache, oldError⟩
|
||||
|
||||
-- Combines two error messages.
-- If `d` carries an error `err`, the stack is shrunk to `oldStackSize` and the message
-- becomes `err ++ "; " ++ oldError` (current message first, then the old one).
-- If `d` carries no error, `d` is returned unchanged (the stack is NOT shrunk).
def ParserData.mergeErrors (d : ParserData) (oldStackSize : Nat) (oldError : String) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, cache, some err⟩ := ⟨stack.shrink oldStackSize, pos, cache, some (err ++ "; " ++ oldError)⟩
|
||||
| other := other
|
||||
|
||||
-- Normalizes the output of one alternative so that it occupies a single stack slot
-- above `startSize`:
--   * stack size equals `startSize`      → push `Syntax.missing`;
--   * stack size equals `startSize + 1`  → return `d` unchanged;
--   * otherwise                          → wrap the entries from `startSize` to the top
--                                          in one `nullKind` node and push that node.
-- In the first and third cases the error field of the result is `none`; position and
-- cache are always kept.
def ParserData.mkLongestNodeAlt (d : ParserData) (startSize : Nat) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, cache, _⟩ :=
|
||||
if stack.size == startSize then ⟨stack.push Syntax.missing, pos, cache, none⟩ -- parser did not create any node, then we just add `Syntax.missing`
|
||||
else if stack.size == startSize + 1 then d
|
||||
else
|
||||
-- parser created more than one node, combine them into a single node
|
||||
let node := Syntax.node nullKind (stack.extract startSize stack.size) [] in
|
||||
let stack := stack.shrink startSize in
|
||||
⟨stack.push node, pos, cache, none⟩
|
||||
|
||||
-- Keeps only the top-most stack entry above `startStackSize`: the last entry is saved,
-- the stack is shrunk to `startStackSize`, and the saved entry is pushed back.
-- Position and cache are kept; the error field of the result is `none`.
def ParserData.keepLatest (d : ParserData) (startStackSize : Nat) : ParserData :=
match d with
| ⟨stk, pos, cache, _⟩ :=
  let top := stk.back in
  ⟨(stk.shrink startStackSize).push top, pos, cache, none⟩
|
||||
|
||||
-- Makes the newest alternative the only one: first normalizes what was pushed above
-- `prevStackSize` into a single node (`mkLongestNodeAlt`), then drops every other
-- entry above `startStackSize`, keeping just that node (`keepLatest`).
def ParserData.replaceLongest (d : ParserData) (startStackSize : Nat) (prevStackSize : Nat) : ParserData :=
(d.mkLongestNodeAlt prevStackSize).keepLatest startStackSize
|
||||
|
||||
-- One step of a longest match: runs a further alternative `p` and merges its outcome
-- with the outcome of the alternatives tried so far, which is carried by the incoming `d`.
--
-- `startSize` is the stack size at which the whole longest match started; `startPos`
-- is passed to `d.restore` before running `p` (presumably resetting the position so
-- that `p` starts from the same input position -- confirm against `ParserData.restore`).
--
-- The incoming error, position and stack size are saved first. After `p` has run:
--   * both succeeded: the attempt that stopped at the later position wins
--     (`replaceLongest` / `restore`); on a tie the new result is normalized with
--     `mkLongestNodeAlt prevSize` and kept next to the previous ones;
--   * previous succeeded, `p` failed: `d.restore prevSize prevStopPos`;
--   * both failed: the error of the attempt that stopped at the later position is kept;
--     on a tie the two messages are merged;
--   * previous failed, `p` succeeded: the result of `p` is normalized with
--     `mkLongestNodeAlt startSize`.
def longestMatchStep (startSize : Nat) (startPos : String.Pos) (p : ParserFn) : ParserFn :=
|
||||
λ s d,
|
||||
let prevErrorMsg := d.errorMsg in
|
||||
let prevStopPos := d.pos in
|
||||
let prevSize := d.stackSize in
|
||||
let d := d.restore prevSize startPos in
|
||||
let d := p s d in
|
||||
match prevErrorMsg, d.errorMsg with
|
||||
| none, none := -- both succeeded
|
||||
if d.pos > prevStopPos then d.replaceLongest startSize prevSize -- replace
|
||||
else if d.pos < prevStopPos then d.restore prevSize prevStopPos -- keep prev
|
||||
else d.mkLongestNodeAlt prevSize -- keep both
|
||||
| none, some _ := -- prev succeeded, current failed
|
||||
d.restore prevSize prevStopPos
|
||||
| some oldError, some _ := -- both failed
|
||||
if d.pos > prevStopPos then d.keepNewError prevSize
|
||||
else if d.pos < prevStopPos then d.keepPrevError prevSize prevStopPos prevErrorMsg
|
||||
else d.mergeErrors prevSize oldError
|
||||
| some _, none := -- prev failed, current succeeded
|
||||
d.mkLongestNodeAlt startSize
|
||||
|
||||
-- Finishes a longest match: when `d` has no error and the stack holds more than one
-- entry above `startSize`, the entries are combined with `d.mkNode choiceKind startSize`;
-- otherwise `d` is returned as is.
def longestMatchMkResult (startSize : Nat) (d : ParserData) : ParserData :=
|
||||
if !d.hasError && d.stackSize > startSize + 1 then d.mkNode choiceKind startSize else d
|
||||
|
||||
-- Processes the remaining alternatives of a longest match from left to right:
-- each one is run through `longestMatchStep`, and once the list is exhausted the
-- result is finalized with `longestMatchMkResult`.
-- `startSize` / `startPos` are forwarded unchanged to every step.
def longestMatchFnAux (startSize : Nat) (startPos : String.Pos) : List ParserFn → ParserFn
|
||||
| [] := λ _ d, longestMatchMkResult startSize d
|
||||
| (p::ps) := λ s d,
|
||||
let d := longestMatchStep startSize startPos p s d in
|
||||
longestMatchFnAux ps s d
|
||||
|
||||
-- Longest match over a single alternative: runs `p`; if it failed the result is
-- returned untouched, otherwise its output is normalized to a single stack entry
-- with `mkLongestNodeAlt`.
def longestMatchFn₁ (p : ParserFn) : ParserFn :=
λ s d₀,
  let baseSize := d₀.stackSize in
  let d₁ := p s d₀ in
  if d₁.hasError then d₁ else d₁.mkLongestNodeAlt baseSize
|
||||
|
||||
-- Longest match specialized to exactly two alternatives.
-- Runs `p` first: on error the stack is shrunk back to its initial size, otherwise the
-- output is normalized with `mkLongestNodeAlt`. Then `q` is run through
-- `longestMatchStep` and the result is finalized with `longestMatchMkResult`.
def longestMatchFn₂ (p q : ParserFn) : ParserFn :=
|
||||
λ s d,
|
||||
let startSize := d.stackSize in
|
||||
let startPos := d.pos in
|
||||
let d := p s d in
|
||||
let d := if d.hasError then d.shrinkStack startSize else d.mkLongestNodeAlt startSize in
|
||||
let d := longestMatchStep startSize startPos q s d in
|
||||
longestMatchMkResult startSize d
|
||||
|
||||
-- Longest match over an arbitrary list of alternatives.
--   * empty list: fails with the error "longest match: empty list";
--   * one alternative: delegates to `longestMatchFn₁`;
--   * otherwise: runs the first alternative, shrinks the stack back on error or
--     normalizes its output with `mkLongestNodeAlt` on success, and hands the
--     remaining alternatives to `longestMatchFnAux`.
def longestMatchFn : List ParserFn → ParserFn
| []      := λ _ d, d.mkError "longest match: empty list"
| [p]     := longestMatchFn₁ p
| (p::ps) := λ s d,
  let baseSize := d.stackSize in
  let basePos  := d.pos in
  let d := p s d in
  let d := if d.hasError then d.shrinkStack baseSize else d.mkLongestNodeAlt baseSize in
  longestMatchFnAux baseSize basePos ps s d
|
||||
|
||||
structure AbsParser (ρ : Type) :=
|
||||
(info : ParserInfo := {})
|
||||
(fn : ρ)
|
||||
|
|
@ -529,14 +624,16 @@ class ParserFnLift (ρ : Type) :=
|
|||
(lift {} : ParserFn → ρ)
|
||||
(map : (ParserFn → ParserFn) → ρ → ρ)
|
||||
(map₂ : (ParserFn → ParserFn → ParserFn) → ρ → ρ → ρ)
|
||||
(mapList : (List ParserFn → ParserFn) → List ρ → ρ)
|
||||
|
||||
-- Every type with a `ParserFnLift` instance is inhabited: its default element is the
-- lifting of the default `ParserFn`.
instance parserLiftInhabited {ρ : Type} [ParserFnLift ρ] : Inhabited ρ :=
|
||||
⟨ParserFnLift.lift (default _)⟩
|
||||
|
||||
-- `ParserFnLift` instance for `ParserFn` itself: `lift` is the identity and every
-- map operation applies `m` directly to its argument(s).
instance idParserLift : ParserFnLift ParserFn :=
|
||||
{ lift := λ p, p,
|
||||
map := λ m p, m p,
|
||||
map₂ := λ m p1 p2, m p1 p2 }
|
||||
{ lift := λ p, p,
|
||||
map := λ m p, m p,
|
||||
map₂ := λ m p1 p2, m p1 p2,
|
||||
mapList := λ m ps, m ps }
|
||||
|
||||
@[inline]
|
||||
def liftParser {ρ : Type} [ParserFnLift ρ] (info : ParserInfo) (fn : ParserFn) : AbsParser ρ :=
|
||||
|
|
@ -560,7 +657,8 @@ EnvParserFn (α → ρ) ρ
|
|||
-- `ParserFnLift` instance for functions taking an extra argument `a`: every operation
-- abstracts over `a`, applies the wrapped functions to it, and delegates to the
-- `ParserFnLift ρ` instance. `lift` ignores `a`.
instance envParserLift (α ρ : Type) [ParserFnLift ρ] : ParserFnLift (EnvParserFn α ρ) :=
|
||||
{ lift := λ p a, ParserFnLift.lift p,
|
||||
map := λ m p a, ParserFnLift.map m (p a),
|
||||
map₂ := λ m p1 p2 a, ParserFnLift.map₂ m (p1 a) (p2 a) }
|
||||
map₂ := λ m p1 p2 a, ParserFnLift.map₂ m (p1 a) (p2 a),
|
||||
mapList := λ m ps a, ParserFnLift.mapList m (ps.map (λ p, p a)) }
|
||||
|
||||
-- `ParserFnLift` instance for `RecParserFn α ρ`, obtained by reusing the instance for
-- `EnvParserFn (α → ρ) ρ`.
instance recParserLift (α ρ : Type) [ParserFnLift ρ] : ParserFnLift (RecParserFn α ρ) :=
|
||||
inferInstanceAs (ParserFnLift (EnvParserFn (α → ρ) ρ))
|
||||
|
|
@ -603,6 +701,20 @@ mapParser₂ sepByInfo (sepByFn allowTrailingSep) p sep
|
|||
-- Lifted `sepBy1Fn`: combines `p` and the separator parser `sep` through `mapParser₂`,
-- using `sepBy1Info` for the parser info. `allowTrailingSep` (default `false`) is
-- forwarded to `sepBy1Fn`.
@[inline] def sepBy1 {ρ : Type} [ParserFnLift ρ] (p sep : AbsParser ρ) (allowTrailingSep : Bool := false) : AbsParser ρ :=
|
||||
mapParser₂ sepBy1Info (sepBy1Fn allowTrailingSep) p sep
|
||||
|
||||
-- `ParserInfo` for a longest match over `ps`:
--   * `updateTokens` applies the `updateTokens` of every alternative, left to right,
--     to the given trie;
--   * `firstTokens` concatenates the `firstTokens` of all alternatives; because each
--     alternative's tokens are prepended to the accumulator, later alternatives come
--     first in the result.
def longestMatchInfo {ρ : Type} (ps : List (AbsParser ρ)) : ParserInfo :=
|
||||
{ updateTokens := λ trie, ps.foldl (λ trie p, p.info.updateTokens trie) trie,
|
||||
firstTokens := ps.foldl (λ tks p, p.info.firstTokens ++ tks) [] }
|
||||
|
||||
-- Lifts the longest-match combinators to any `ρ` with a `ParserFnLift` instance,
-- choosing the lifting operation by the length of the list:
--   * `[]`     → `lift` of `longestMatchFn []`;
--   * `[p]`    → `map` with `longestMatchFn₁`;
--   * `[p, q]` → `map₂` with `longestMatchFn₂`;
--   * longer   → `mapList` with `longestMatchFn` over the list of parser functions.
-- Only the last case builds an intermediate list with `ps.map`.
def liftLongestMatchFn {ρ : Type} [ParserFnLift ρ] : List (AbsParser ρ) → ρ
|
||||
| [] := ParserFnLift.lift (longestMatchFn [])
|
||||
| [p] := ParserFnLift.map longestMatchFn₁ p.fn
|
||||
| [p, q] := ParserFnLift.map₂ longestMatchFn₂ p.fn q.fn
|
||||
| ps := ParserFnLift.mapList longestMatchFn (ps.map (λ p, p.fn))
|
||||
|
||||
-- Longest-match parser over the alternatives `ps`, for any liftable parser type `ρ`:
-- the info comes from `longestMatchInfo` and the function from `liftLongestMatchFn`.
@[inline] def longestMatch {ρ : Type} [ParserFnLift ρ] (ps : List (AbsParser ρ)) : AbsParser ρ :=
|
||||
{ info := longestMatchInfo ps,
|
||||
fn := liftLongestMatchFn ps }
|
||||
|
||||
-- A `ParserFn` wrapped in `EnvParserFn` with `ParserConfig` as the environment type.
abbrev BasicParserFn : Type := EnvParserFn ParserConfig ParserFn
|
||||
-- Parser record (`AbsParser`) whose function component is a `BasicParserFn`.
abbrev BasicParser : Type := AbsParser BasicParserFn
|
||||
-- `EnvParserFn` over an environment of type `ρ`, wrapping a `RecParserFn Unit ParserFn`.
abbrev CmdParserFn (ρ : Type) : Type := EnvParserFn ρ (RecParserFn Unit ParserFn)
|
||||
|
|
@ -788,93 +900,6 @@ if d.hasError then
|
|||
else
|
||||
Except.ok d.stxStack.back
|
||||
|
||||
-- Keeps the outcome of the current (failed) attempt: position, cache and error of `d`
-- are preserved, and only the stack is shrunk to `oldStackSize`.
def ParserData.keepNewError (d : ParserData) (oldStackSize : Nat) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, cache, err⟩ := ⟨stack.shrink oldStackSize, pos, cache, err⟩
|
||||
|
||||
-- Discards the position and error of the current attempt: the stack is shrunk to
-- `oldStackSize`, and the position and error are replaced by `oldStopPos` and
-- `oldError`. The cache of `d` is kept.
def ParserData.keepPrevError (d : ParserData) (oldStackSize : Nat) (oldStopPos : String.Pos) (oldError : Option String) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, _, cache, _⟩ := ⟨stack.shrink oldStackSize, oldStopPos, cache, oldError⟩
|
||||
|
||||
-- Combines two error messages.
-- If `d` carries an error `err`, the stack is shrunk to `oldStackSize` and the message
-- becomes `err ++ "; " ++ oldError` (current message first, then the old one).
-- If `d` carries no error, `d` is returned unchanged (the stack is NOT shrunk).
def ParserData.mergeErrors (d : ParserData) (oldStackSize : Nat) (oldError : String) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, cache, some err⟩ := ⟨stack.shrink oldStackSize, pos, cache, some (err ++ "; " ++ oldError)⟩
|
||||
| other := other
|
||||
|
||||
-- Normalizes the output of one alternative so that it occupies a single stack slot
-- above `startSize`:
--   * stack size equals `startSize`      → push `Syntax.missing`;
--   * stack size equals `startSize + 1`  → return `d` unchanged;
--   * otherwise                          → wrap the entries from `startSize` to the top
--                                          in one `nullKind` node and push that node.
-- In the first and third cases the error field of the result is `none`.
def ParserData.mkLongestNodeAlt (d : ParserData) (startSize : Nat) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, cache, _⟩ :=
|
||||
if stack.size == startSize then ⟨stack.push Syntax.missing, pos, cache, none⟩ -- parser did not create any node, then we just add `Syntax.missing`
|
||||
else if stack.size == startSize + 1 then d
|
||||
else
|
||||
-- parser created more than one node, combine them into a single node
|
||||
let node := Syntax.node nullKind (stack.extract startSize stack.size) [] in
|
||||
let stack := stack.shrink startSize in
|
||||
⟨stack.push node, pos, cache, none⟩
|
||||
|
||||
-- Keeps only the top-most stack entry above `startStackSize`: the last entry is saved,
-- the stack is shrunk to `startStackSize`, and the saved entry is pushed back.
-- Position and cache are kept; the error field of the result is `none`.
def ParserData.keepLatest (d : ParserData) (startStackSize : Nat) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, cache, _⟩ :=
|
||||
let node := stack.back in
|
||||
let stack := stack.shrink startStackSize in
|
||||
let stack := stack.push node in
|
||||
⟨stack, pos, cache, none⟩
|
||||
|
||||
-- Makes the newest alternative the only one: first normalizes what was pushed above
-- `prevStackSize` into a single node (`mkLongestNodeAlt`), then drops every other
-- entry above `startStackSize`, keeping just that node (`keepLatest`).
def ParserData.replaceLongest (d : ParserData) (startStackSize : Nat) (prevStackSize : Nat) : ParserData :=
|
||||
let d := d.mkLongestNodeAlt prevStackSize in
|
||||
d.keepLatest startStackSize
|
||||
|
||||
-- `TermParserFn`-specific longest-match loop over the remaining alternatives.
-- The three extra arguments `tp`, `cfg`, `cp` are passed unchanged to every alternative
-- and to the recursive call.
--
-- Empty list: when `d` has no error and the stack holds more than one entry above
-- `startSize`, the entries are combined with `d.mkNode choiceKind startSize`.
--
-- Non-empty list: the incoming error, position and stack size are saved,
-- `d.restore prevSize startPos` is called, and `p` is run. Then:
--   * both succeeded: the attempt that stopped at the later position wins; on a tie
--     the new result is normalized and kept next to the previous ones;
--   * previous succeeded, `p` failed: `d.restore prevSize prevStopPos`;
--   * both failed: the error of the attempt that stopped at the later position is kept;
--     on a tie the two messages are merged;
--   * previous failed, `p` succeeded: the result of `p` is normalized with
--     `mkLongestNodeAlt startSize`.
-- In every case the loop continues with the remaining alternatives `ps`.
def longestMatchFnAux (startSize : Nat) (startPos : String.Pos) : List TermParserFn → TermParserFn
|
||||
| [] := λ _ _ _ _ d, if !d.hasError && d.stackSize > startSize + 1 then d.mkNode choiceKind startSize else d
|
||||
| (p::ps) := λ tp cfg cp s d,
|
||||
let prevErrorMsg := d.errorMsg in
|
||||
let prevStopPos := d.pos in
|
||||
let prevSize := d.stackSize in
|
||||
let d := d.restore prevSize startPos in
|
||||
let d := p tp cfg cp s d in
|
||||
match prevErrorMsg, d.errorMsg with
|
||||
| none, none := -- both succeeded
|
||||
let d :=
|
||||
if d.pos > prevStopPos then d.replaceLongest startSize prevSize -- replace
|
||||
else if d.pos < prevStopPos then d.restore prevSize prevStopPos -- keep prev
|
||||
else d.mkLongestNodeAlt prevSize in -- keep both
|
||||
longestMatchFnAux ps tp cfg cp s d
|
||||
| none, some _ := -- prev succeeded, current failed
|
||||
let d := d.restore prevSize prevStopPos in
|
||||
longestMatchFnAux ps tp cfg cp s d
|
||||
| some oldError, some _ := -- both failed
|
||||
let d :=
|
||||
if d.pos > prevStopPos then d.keepNewError prevSize
|
||||
else if d.pos < prevStopPos then d.keepPrevError prevSize prevStopPos prevErrorMsg
|
||||
else d.mergeErrors prevSize oldError in
|
||||
longestMatchFnAux ps tp cfg cp s d
|
||||
| some _, none := -- prev failed, current succeeded
|
||||
let d := d.mkLongestNodeAlt startSize in
|
||||
longestMatchFnAux ps tp cfg cp s d
|
||||
|
||||
-- `TermParserFn`-specific longest match over a list of alternatives.
--   * empty list: fails with the error "longest match: empty list";
--   * one alternative: runs it and, on success, normalizes its output with
--     `mkLongestNodeAlt`;
--   * otherwise: runs the first alternative, shrinks the stack back on error or
--     normalizes its output on success, and hands the remaining alternatives to
--     `longestMatchFnAux`.
-- The extra arguments `tp`, `cfg`, `cp` are passed unchanged to every alternative.
def longestMatchFn : List TermParserFn → TermParserFn
|
||||
| [] := λ _ _ _ _ d, d.mkError "longest match: empty list"
|
||||
| [p] := λ tp cfg cp s d,
|
||||
let startSize := d.stackSize in
|
||||
let d := p tp cfg cp s d in
|
||||
if d.hasError then d else d.mkLongestNodeAlt startSize
|
||||
| (p::ps) := λ tp cfg cp s d,
|
||||
let startSize := d.stackSize in
|
||||
let startPos := d.pos in
|
||||
let d := p tp cfg cp s d in
|
||||
if d.hasError then
|
||||
let d := d.shrinkStack startSize in
|
||||
longestMatchFnAux startSize startPos ps tp cfg cp s d
|
||||
else
|
||||
let d := d.mkLongestNodeAlt startSize in
|
||||
longestMatchFnAux startSize startPos ps tp cfg cp s d
|
||||
|
||||
-- Helper for testing `longestMatchFn`; it is not used directly.
-- Builds a `TermParser` from the alternatives `ps`: `updateTokens` applies every
-- alternative's `updateTokens` to the trie in order, `firstTokens` concatenates the
-- alternatives' `firstTokens`, and `fn` is `longestMatchFn` over their functions.
|
||||
def longestMatch (ps : List TermParser) : TermParser :=
|
||||
{ info := { updateTokens := λ trie, ps.foldl (λ trie p, p.info.updateTokens trie) trie,
|
||||
firstTokens := ps.foldl (λ tks p, p.info.firstTokens ++ tks) [] },
|
||||
fn := longestMatchFn (ps.map (λ p, p.fn)) }
|
||||
|
||||
-- Stopped here
|
||||
|
||||
@[noinline] def termPrattParser (tbl : TermParsingTables) (rbp : Nat) : TermParserFn :=
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue