perf: cache leading_parser and syntax as well

We better hope the `leading_parser`s are closed terms
This commit is contained in:
Sebastian Ullrich 2022-11-05 18:41:12 +01:00
parent da6efe1bca
commit ed03ff9d00
7 changed files with 50 additions and 36 deletions

View file

@ -89,15 +89,20 @@ private def elabParserMacroAux (prec e : Term) (withAnonymousAntiquot : Bool) :
match extractMacroScopes declName with
| { name := .str _ s, .. } =>
let kind := quote declName
let s := quote s
``(withAntiquot (mkAntiquot $s $kind $(quote withAnonymousAntiquot)) (leadingNode $kind $prec $e))
let mut p ← ``(withAntiquot
(mkAntiquot $(quote s) $kind $(quote withAnonymousAntiquot))
(leadingNode $kind $prec $e))
-- cache only unparameterized parsers
if (← getLCtx).all (·.isAuxDecl) then
p ← ``(withCache $kind $p)
return p
| _ => throwError "invalid `leading_parser` macro, unexpected declaration name"
@[builtin_term_elab «leading_parser»] def elabLeadingParserMacro : TermElab :=
adaptExpander fun stx => match stx with
| `(leading_parser $[: $prec?]? $[(withAnonymousAntiquot := $anon?)]? $e) =>
elabParserMacroAux (prec?.getD (quote Parser.maxPrec)) e (anon?.all (·.raw.isOfKind ``Parser.Term.trueVal))
| _ => throwUnsupportedSyntax
adaptExpander fun
| `(leading_parser $[: $prec?]? $[(withAnonymousAntiquot := $anon?)]? $e) =>
elabParserMacroAux (prec?.getD (quote Parser.maxPrec)) e (anon?.all (·.raw.isOfKind ``Parser.Term.trueVal))
| _ => throwUnsupportedSyntax
private def elabTParserMacroAux (prec lhsPrec e : Term) : TermElabM Syntax := do
let declName? ← getDeclName?

View file

@ -160,24 +160,24 @@ structure TokenCacheEntry where
stopPos : String.Pos := 0
token : Syntax := Syntax.missing
structure CategoryCacheKey extends ParserContextCacheKey where
cat : Name
pos : String.Pos
structure ParserCacheKey extends ParserContextCacheKey where
parserName : Name
pos : String.Pos
deriving BEq, Hashable
structure CategoryCacheEntry where
structure ParserCacheEntry where
stx : Syntax
lhsPrec : Nat
newPos : String.Pos
errorMsg : Option Error
structure ParserCache where
tokenCache : TokenCacheEntry
categoryCache : HashMap CategoryCacheKey CategoryCacheEntry
tokenCache : TokenCacheEntry
parserCache : HashMap ParserCacheKey ParserCacheEntry
def initCacheForInput (input : String) : ParserCache where
tokenCache := { startPos := input.endPos + ' ' /- make sure it is not a valid position -/ }
categoryCache := {}
tokenCache := { startPos := input.endPos + ' ' /- make sure it is not a valid position -/ }
parserCache := {}
structure ParserState where
stxStack : Array Syntax := #[]
@ -483,6 +483,23 @@ def suppressInsideQuot (p : Parser) : Parser := {
fn := suppressInsideQuotFn p.fn
}
def withCacheFn (parserName : Name) (p : ParserFn) : ParserFn := fun c s => Id.run do
let key := ⟨c.toParserContextCacheKey, parserName, s.pos⟩
if let some r := s.cache.parserCache.find? key then
-- TODO: turn this into a proper trace once we have these in the parser
--dbg_trace "parser cache hit: {parserName}:{s.pos} -> {r.stx}"
match s with
| ⟨stack, _, _, cache, _⟩ => return ⟨stack.push r.stx, r.lhsPrec, r.newPos, cache, r.errorMsg⟩
let initStackSz := s.stackSize
let s := p c s
if s.stackSize != initStackSz + 1 then
panic! s!"withCacheFn: unexpected stack growth {s.stxStack}"
{ s with cache.parserCache := s.cache.parserCache.insert key ⟨s.stxStack.back, s.lhsPrec, s.pos, s.errorMsg⟩ }
def withCache (parserName : Name) (p : Parser) : Parser where
info := p.info
fn := withCacheFn parserName p.fn
def leadingNode (n : SyntaxNodeKind) (prec : Nat) (p : Parser) : Parser :=
checkPrec prec >> node n p >> setLhsPrec prec
@ -1708,20 +1725,8 @@ builtin_initialize categoryParserFnExtension : EnvExtension CategoryParserFn ←
def categoryParserFn (catName : Name) : ParserFn := fun ctx s =>
categoryParserFnExtension.getState ctx.env catName ctx s
def categoryParser (catName : Name) (prec : Nat) : Parser := {
fn := fun c s => Id.run do
let c := { c with prec }
let key := ⟨c.toParserContextCacheKey, catName, s.pos⟩
if let some r := s.cache.categoryCache.find? key then
match s with
| ⟨stack, _, _, cache, _⟩ => return ⟨stack.push r.stx, r.lhsPrec, r.newPos, cache, r.errorMsg⟩
let initStackSz := s.stackSize
let s := categoryParserFn catName c s
if s.stackSize > initStackSz + 1 then
panic! s!"categoryParser: unexpected stack growth {s.stxStack}"
let s := if s.stackSize == initStackSz then s.pushSyntax .missing else s
{ s with cache.categoryCache := s.cache.categoryCache.insert key ⟨s.stxStack.back, s.lhsPrec, s.pos, s.errorMsg⟩ }
}
def categoryParser (catName : Name) (prec : Nat) : Parser where
fn c s := withCacheFn catName (categoryParserFn catName) { c with prec } s
-- Define `termParser` here because we need it for antiquotations
def termParser (prec : Nat := 0) : Parser :=

View file

@ -280,7 +280,7 @@ partial def compileParserDescr (categories : ParserCategories) (d : ParserDescr)
| ParserDescr.unary n d => return (← getUnaryAlias parserAliasesRef n) (← visit d)
| ParserDescr.binary n d₁ d₂ => return (← getBinaryAlias parserAliasesRef n) (← visit d₁) (← visit d₂)
| ParserDescr.node k prec d => return leadingNode k prec (← visit d)
| ParserDescr.nodeWithAntiquot n k d => return nodeWithAntiquot n k (← visit d) (anonymous := true)
| ParserDescr.nodeWithAntiquot n k d => return withCache k (nodeWithAntiquot n k (← visit d) (anonymous := true))
| ParserDescr.sepBy p sep psep trail => return sepBy (← visit p) sep (← visit psep) trail
| ParserDescr.sepBy1 p sep psep trail => return sepBy1 (← visit p) sep (← visit psep) trail
| ParserDescr.trailingNode k prec lhsPrec d => return trailingNode k prec lhsPrec (← visit d)

View file

@ -37,9 +37,7 @@ partial def parserNodeKind? (e : Expr) : MetaM (Option Name) := do
let e ← whnfCore e
if e matches Expr.lam .. then
lambdaLetTelescope e fun _ e => parserNodeKind? e
else if e.isAppOfArity ``nodeWithAntiquot 4 then
reduceEval? (e.getArg! 1)
else if e.isAppOfArity ``withAntiquot 2 then
else if e.isAppOfArity ``nodeWithAntiquot 4 || e.isAppOfArity ``withAntiquot 2 || e.isAppOfArity ``withCache 2 then
parserNodeKind? (e.getArg! 1)
else if e.isAppOfArity ``leadingNode 3 || e.isAppOfArity ``trailingNode 4 || e.isAppOfArity ``node 2 then
reduceEval? (e.getArg! 0)

View file

@ -305,6 +305,9 @@ def node.formatter (k : SyntaxNodeKind) (p : Formatter) : Formatter := do
checkKind k;
visitArgs p
@[combinator_formatter withCache]
def withCache.formatter (_parserName : Name) (p : Formatter) : Formatter := p
@[combinator_formatter trailingNode]
def trailingNode.formatter (k : SyntaxNodeKind) (_ _ : Nat) (p : Formatter) : Formatter := do
checkKind k

View file

@ -395,6 +395,9 @@ def node.parenthesizer (k : SyntaxNodeKind) (p : Parenthesizer) : Parenthesizer
def checkPrec.parenthesizer (prec : Nat) : Parenthesizer :=
addPrecCheck prec
@[combinator_parenthesizer withCache]
def withCache.parenthesizer (_parserName : Name) (p : Parenthesizer) : Parenthesizer := p
@[combinator_parenthesizer leadingNode]
def leadingNode.parenthesizer (k : SyntaxNodeKind) (prec : Nat) (p : Parenthesizer) : Parenthesizer := do
node.parenthesizer k p

View file

@ -1,8 +1,8 @@
some { range := { pos := { line := 128, column := 42 },
some { range := { pos := { line := 133, column := 42 },
charUtf16 := 42,
endPos := { line := 134, column := 31 },
endPos := { line := 139, column := 31 },
endCharUtf16 := 31 },
selectionRange := { pos := { line := 128, column := 46 },
selectionRange := { pos := { line := 133, column := 46 },
charUtf16 := 46,
endPos := { line := 128, column := 58 },
endPos := { line := 133, column := 58 },
endCharUtf16 := 58 } }