diff --git a/src/Lean/Elab/BuiltinNotation.lean b/src/Lean/Elab/BuiltinNotation.lean index 8d3cf8b446..96f7b903d3 100644 --- a/src/Lean/Elab/BuiltinNotation.lean +++ b/src/Lean/Elab/BuiltinNotation.lean @@ -89,15 +89,20 @@ private def elabParserMacroAux (prec e : Term) (withAnonymousAntiquot : Bool) : match extractMacroScopes declName with | { name := .str _ s, .. } => let kind := quote declName - let s := quote s - ``(withAntiquot (mkAntiquot $s $kind $(quote withAnonymousAntiquot)) (leadingNode $kind $prec $e)) + let mut p ← ``(withAntiquot + (mkAntiquot $(quote s) $kind $(quote withAnonymousAntiquot)) + (leadingNode $kind $prec $e)) + -- cache only unparameterized parsers + if (← getLCtx).all (·.isAuxDecl) then + p ← ``(withCache $kind $p) + return p | _ => throwError "invalid `leading_parser` macro, unexpected declaration name" @[builtin_term_elab «leading_parser»] def elabLeadingParserMacro : TermElab := - adaptExpander fun stx => match stx with - | `(leading_parser $[: $prec?]? $[(withAnonymousAntiquot := $anon?)]? $e) => - elabParserMacroAux (prec?.getD (quote Parser.maxPrec)) e (anon?.all (·.raw.isOfKind ``Parser.Term.trueVal)) - | _ => throwUnsupportedSyntax + adaptExpander fun + | `(leading_parser $[: $prec?]? $[(withAnonymousAntiquot := $anon?)]? $e) => + elabParserMacroAux (prec?.getD (quote Parser.maxPrec)) e (anon?.all (·.raw.isOfKind ``Parser.Term.trueVal)) + | _ => throwUnsupportedSyntax private def elabTParserMacroAux (prec lhsPrec e : Term) : TermElabM Syntax := do let declName? ← getDeclName? diff --git a/src/Lean/Parser/Basic.lean b/src/Lean/Parser/Basic.lean index 9dab951edc..43564172b8 100644 --- a/src/Lean/Parser/Basic.lean +++ b/src/Lean/Parser/Basic.lean @@ -160,24 +160,24 @@ structure TokenCacheEntry where stopPos : String.Pos := 0 token : Syntax := Syntax.missing -structure CategoryCacheKey extends ParserContextCacheKey where - cat : Name - pos : String.Pos +structure ParserCacheKey extends ParserContextCacheKey where + parserName : Name + pos : String.Pos deriving BEq, Hashable -structure CategoryCacheEntry where +structure ParserCacheEntry where stx : Syntax lhsPrec : Nat newPos : String.Pos errorMsg : Option Error structure ParserCache where - tokenCache : TokenCacheEntry - categoryCache : HashMap CategoryCacheKey CategoryCacheEntry + tokenCache : TokenCacheEntry + parserCache : HashMap ParserCacheKey ParserCacheEntry def initCacheForInput (input : String) : ParserCache where - tokenCache := { startPos := input.endPos + ' ' /- make sure it is not a valid position -/ } - categoryCache := {} + tokenCache := { startPos := input.endPos + ' ' /- make sure it is not a valid position -/ } + parserCache := {} structure ParserState where stxStack : Array Syntax := #[] @@ -483,6 +483,23 @@ def suppressInsideQuot (p : Parser) : Parser := { fn := suppressInsideQuotFn p.fn } +def withCacheFn (parserName : Name) (p : ParserFn) : ParserFn := fun c s => Id.run do + let key := ⟨c.toParserContextCacheKey, parserName, s.pos⟩ + if let some r := s.cache.parserCache.find? key then + -- TODO: turn this into a proper trace once we have these in the parser + --dbg_trace "parser cache hit: {parserName}:{s.pos} -> {r.stx}" + match s with + | ⟨stack, _, _, cache, _⟩ => return ⟨stack.push r.stx, r.lhsPrec, r.newPos, cache, r.errorMsg⟩ + let initStackSz := s.stackSize + let s := p c s + if s.stackSize != initStackSz + 1 then + panic! s!"withCacheFn: unexpected stack growth {s.stxStack}" + { s with cache.parserCache := s.cache.parserCache.insert key ⟨s.stxStack.back, s.lhsPrec, s.pos, s.errorMsg⟩ } + +def withCache (parserName : Name) (p : Parser) : Parser where + info := p.info + fn := withCacheFn parserName p.fn + def leadingNode (n : SyntaxNodeKind) (prec : Nat) (p : Parser) : Parser := checkPrec prec >> node n p >> setLhsPrec prec @@ -1708,20 +1725,8 @@ builtin_initialize categoryParserFnExtension : EnvExtension CategoryParserFn ← def categoryParserFn (catName : Name) : ParserFn := fun ctx s => categoryParserFnExtension.getState ctx.env catName ctx s -def categoryParser (catName : Name) (prec : Nat) : Parser := { - fn := fun c s => Id.run do - let c := { c with prec } - let key := ⟨c.toParserContextCacheKey, catName, s.pos⟩ - if let some r := s.cache.categoryCache.find? key then - match s with - | ⟨stack, _, _, cache, _⟩ => return ⟨stack.push r.stx, r.lhsPrec, r.newPos, cache, r.errorMsg⟩ - let initStackSz := s.stackSize - let s := categoryParserFn catName c s - if s.stackSize > initStackSz + 1 then - panic! s!"categoryParser: unexpected stack growth {s.stxStack}" - let s := if s.stackSize == initStackSz then s.pushSyntax .missing else s - { s with cache.categoryCache := s.cache.categoryCache.insert key ⟨s.stxStack.back, s.lhsPrec, s.pos, s.errorMsg⟩ } -} +def categoryParser (catName : Name) (prec : Nat) : Parser where + fn c s := withCacheFn catName (categoryParserFn catName) { c with prec } s -- Define `termParser` here because we need it for antiquotations def termParser (prec : Nat := 0) : Parser := diff --git a/src/Lean/Parser/Extension.lean b/src/Lean/Parser/Extension.lean index 71f3acceca..a7e9e5c8e9 100644 --- a/src/Lean/Parser/Extension.lean +++ b/src/Lean/Parser/Extension.lean @@ -280,7 +280,7 @@ partial def compileParserDescr (categories : ParserCategories) (d : ParserDescr) | ParserDescr.unary n d => return (← getUnaryAlias parserAliasesRef n) (← visit d) | ParserDescr.binary n d₁ d₂ => return (← getBinaryAlias parserAliasesRef n) (← visit d₁) (← visit d₂) | ParserDescr.node k prec d => return leadingNode k prec (← visit d) - | ParserDescr.nodeWithAntiquot n k d => return nodeWithAntiquot n k (← visit d) (anonymous := true) + | ParserDescr.nodeWithAntiquot n k d => return withCache k (nodeWithAntiquot n k (← visit d) (anonymous := true)) | ParserDescr.sepBy p sep psep trail => return sepBy (← visit p) sep (← visit psep) trail | ParserDescr.sepBy1 p sep psep trail => return sepBy1 (← visit p) sep (← visit psep) trail | ParserDescr.trailingNode k prec lhsPrec d => return trailingNode k prec lhsPrec (← visit d) diff --git a/src/Lean/ParserCompiler.lean b/src/Lean/ParserCompiler.lean index 1128960b99..ba73523eb6 100644 --- a/src/Lean/ParserCompiler.lean +++ b/src/Lean/ParserCompiler.lean @@ -37,9 +37,7 @@ partial def parserNodeKind? (e : Expr) : MetaM (Option Name) := do let e ← whnfCore e if e matches Expr.lam .. then lambdaLetTelescope e fun _ e => parserNodeKind? e - else if e.isAppOfArity ``nodeWithAntiquot 4 then - reduceEval? (e.getArg! 1) - else if e.isAppOfArity ``withAntiquot 2 then + else if e.isAppOfArity ``nodeWithAntiquot 4 || e.isAppOfArity ``withAntiquot 2 || e.isAppOfArity ``withCache 2 then parserNodeKind? (e.getArg! 1) else if e.isAppOfArity ``leadingNode 3 || e.isAppOfArity ``trailingNode 4 || e.isAppOfArity ``node 2 then reduceEval? (e.getArg! 0) diff --git a/src/Lean/PrettyPrinter/Formatter.lean b/src/Lean/PrettyPrinter/Formatter.lean index 87c613c1b0..8f86ee48d6 100644 --- a/src/Lean/PrettyPrinter/Formatter.lean +++ b/src/Lean/PrettyPrinter/Formatter.lean @@ -305,6 +305,9 @@ def node.formatter (k : SyntaxNodeKind) (p : Formatter) : Formatter := do checkKind k; visitArgs p +@[combinator_formatter withCache] +def withCache.formatter (_parserName : Name) (p : Formatter) : Formatter := p + @[combinator_formatter trailingNode] def trailingNode.formatter (k : SyntaxNodeKind) (_ _ : Nat) (p : Formatter) : Formatter := do checkKind k diff --git a/src/Lean/PrettyPrinter/Parenthesizer.lean b/src/Lean/PrettyPrinter/Parenthesizer.lean index 1ff7eae7bd..f57c80f170 100644 --- a/src/Lean/PrettyPrinter/Parenthesizer.lean +++ b/src/Lean/PrettyPrinter/Parenthesizer.lean @@ -395,6 +395,9 @@ def node.parenthesizer (k : SyntaxNodeKind) (p : Parenthesizer) : Parenthesizer def checkPrec.parenthesizer (prec : Nat) : Parenthesizer := addPrecCheck prec +@[combinator_parenthesizer withCache] +def withCache.parenthesizer (_parserName : Name) (p : Parenthesizer) : Parenthesizer := p + @[combinator_parenthesizer leadingNode] def leadingNode.parenthesizer (k : SyntaxNodeKind) (prec : Nat) (p : Parenthesizer) : Parenthesizer := do node.parenthesizer k p diff --git a/tests/lean/1021.lean.expected.out b/tests/lean/1021.lean.expected.out index 12fdf9d22a..9983172c94 100644 --- a/tests/lean/1021.lean.expected.out +++ b/tests/lean/1021.lean.expected.out @@ -1,8 +1,8 @@ -some { range := { pos := { line := 128, column := 42 }, +some { range := { pos := { line := 133, column := 42 }, charUtf16 := 42, - endPos := { line := 134, column := 31 }, + endPos := { line := 139, column := 31 }, endCharUtf16 := 31 }, - selectionRange := { pos := { line := 128, column := 46 }, + selectionRange := { pos := { line := 133, column := 46 }, charUtf16 := 46, - endPos := { line := 128, column := 58 }, + endPos := { line := 133, column := 58 }, endCharUtf16 := 58 } }