diff --git a/src/Lean/Parser/Basic.lean b/src/Lean/Parser/Basic.lean index 23f370e525..66b377b6f3 100644 --- a/src/Lean/Parser/Basic.lean +++ b/src/Lean/Parser/Basic.lean @@ -581,6 +581,26 @@ fun c s => { info := sepBy1Info p.info sep.info, fn := sepBy1Fn allowTrailingSep p.fn sep.fn unboxSingleton } +/- Apply `f` to the syntax object produced by `p` -/ +@[inline] def withResultOfFn (p : ParserFn) (f : Syntax → Syntax) : ParserFn := +fun c s => + let s := p c s; + if s.hasError then s + else + let stx := s.stxStack.back; + s.popSyntax.pushSyntax (f stx) + +@[noinline] def withResultOfInfo (p : ParserInfo) : ParserInfo := +{ collectTokens := p.collectTokens, + collectKinds := p.collectKinds } + +@[inline] def withResultOf (p : Parser) (f : Syntax → Syntax) : Parser := +{ info := withResultOfInfo p.info, + fn := withResultOfFn p.fn f } + +abbrev unboxSingleton (p : Parser) : Parser := +withResultOf p fun stx => if stx.getNumArgs == 1 then stx.getArg 0 else stx + @[specialize] partial def satisfyFn (p : Char → Bool) (errorMsg : String := "unexpected character") : ParserFn | c, s => let i := s.pos; diff --git a/src/Lean/PrettyPrinter/Formatter.lean b/src/Lean/PrettyPrinter/Formatter.lean index 79c07419fa..9e8f127222 100644 --- a/src/Lean/PrettyPrinter/Formatter.lean +++ b/src/Lean/PrettyPrinter/Formatter.lean @@ -323,6 +323,10 @@ else def optional.formatter (p : Formatter) : Formatter := do concatArgs p +@[combinatorFormatter Parser.withResultOf] +def withResultOf.formatter (p : Formatter) (f : Syntax → Syntax) : Formatter := do +concatArgs p + @[combinatorFormatter sepBy] def sepBy.formatter (p pSep : Formatter) : Formatter := do stx ← getCur; diff --git a/src/Lean/PrettyPrinter/Parenthesizer.lean b/src/Lean/PrettyPrinter/Parenthesizer.lean index c4bf6288a2..9b17b33e9d 100644 --- a/src/Lean/PrettyPrinter/Parenthesizer.lean +++ b/src/Lean/PrettyPrinter/Parenthesizer.lean @@ -408,6 +408,10 @@ else def optional.parenthesizer (p : Parenthesizer) : Parenthesizer := do visitArgs p +@[combinatorParenthesizer Lean.Parser.withResultOf] +def withResultOf.parenthesizer (p : Parenthesizer) (f : Syntax → Syntax) : Parenthesizer := do +visitArgs p + @[combinatorParenthesizer Lean.Parser.sepBy] def sepBy.parenthesizer (p pSep : Parenthesizer) : Parenthesizer := do stx ← getCur;