feat: request cancellation (#7054)

This PR adds language server support for request cancellation to the
following expensive requests: Code actions, auto-completion, document
symbols, folding ranges and semantic highlighting. This means that when
the client informs the language server that a request is stale (e.g.
because it belongs to a previous state of the document), the language
server will now prematurely cancel the computation of the response in
order to reduce the CPU load for requests that will be discarded by the
client anyways.
This commit is contained in:
Marc Huisinga 2025-02-14 12:55:43 +01:00 committed by GitHub
parent 22d1d04059
commit 05fb67af90
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 142 additions and 68 deletions

View file

@ -125,6 +125,7 @@ def handleCodeAction (params : CodeActionParams) : RequestM (RequestTask (Array
let caps ← names.mapM evalCodeActionProvider
return (← builtinCodeActionProviders.get).toList.toArray ++ Array.zip names caps
caps.flatMapM fun (providerName, cap) => do
RequestM.checkCancelled
let cas ← cap params snap
cas.mapIdxM fun i lca => do
if lca.lazy?.isNone then return lca.eager

View file

@ -5,6 +5,7 @@ Authors: Leonardo de Moura, Marc Huisinga
-/
prelude
import Lean.Server.Completion.CompletionCollectors
import Lean.Server.RequestCancellation
import Std.Data.HashMap
namespace Lean.Server.Completion
@ -61,11 +62,12 @@ partial def find?
(cmdStx : Syntax)
(infoTree : InfoTree)
(caps : ClientCapabilities)
: IO CompletionList := do
: CancellableM CompletionList := do
let prioritizedPartitions := findPrioritizedCompletionPartitionsAt fileMap hoverPos cmdStx infoTree
let mut allCompletions := #[]
for partition in prioritizedPartitions do
for (i, completionInfoPos) in partition do
CancellableM.checkCancelled
let completions : Array ScoredCompletionItem ←
match i.info with
| .id stx id danglingDot lctx .. =>

View file

@ -8,6 +8,7 @@ import Lean.Data.FuzzyMatching
import Lean.Elab.Tactic.Doc
import Lean.Server.Completion.CompletionResolution
import Lean.Server.Completion.EligibleHeaderDecls
import Lean.Server.RequestCancellation
namespace Lean.Server.Completion
open Elab
@ -36,7 +37,7 @@ section Infrastructure
Monad used for completion computation that allows modifying a completion `State` and reading
`CompletionParams`.
-/
private abbrev M := ReaderT Context $ StateRefT State MetaM
private abbrev M := ReaderT Context $ StateRefT State $ CancellableT MetaM
/-- Adds a new completion item to the state in `M`. -/
private def addItem
@ -114,10 +115,13 @@ section Infrastructure
(ctx : ContextInfo)
(lctx : LocalContext)
(x : M Unit)
: IO (Array ScoredCompletionItem) :=
ctx.runMetaM lctx do
let (_, s) ← x.run ⟨params, completionInfoPos⟩ |>.run {}
return s.items
: CancellableM (Array ScoredCompletionItem) := do
let tk ← read
let r ← ctx.runMetaM lctx do
x.run ⟨params, completionInfoPos⟩ |>.run {} |>.run tk
match r with
| .error _ => throw .requestCancelled
| .ok (_, s) => return s.items
end Infrastructure
@ -161,6 +165,16 @@ section Utils
return fuzzyMatchScoreWithThreshold? s₁ s₂ |>.map (declName, · / (p₂.getNumParts + 1).toFloat)
return none
private def forEligibleDeclsWithCancellationM [Monad m] [MonadEnv m]
[MonadLiftT (ST IO.RealWorld) m] [MonadCancellable m] [MonadLiftT IO m]
(f : Name → ConstantInfo → m PUnit) : m PUnit := do
let _ ← StateT.run (s := 0) <| forEligibleDeclsM fun decl ci => do
modify (· + 1)
if (← get) >= 10000 then
RequestCancellation.check
set <| 0
f decl ci
end Utils
section IdCompletionUtils
@ -349,7 +363,7 @@ private def idCompletionCore
addUnresolvedCompletionItem localDecl.userName (.fvar localDecl.fvarId) (kind := CompletionItemKind.variable) score
-- search for matches in the environment
let env ← getEnv
forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let bestMatch? ← (·.2) <$> StateT.run (s := none) do
let matchUsingNamespace (ns : Name) : StateT (Option (Name × Float)) M Unit := do
let some (label, score) ← matchDecl? ns id danglingDot declName
@ -380,6 +394,7 @@ private def idCompletionCore
matchUsingNamespace Name.anonymous
if let some (bestLabel, bestScore) := bestMatch? then
addUnresolvedCompletionItem bestLabel (.const declName) (← getCompletionKindForDecl c) bestScore
RequestCancellation.check
let matchAlias (ns : Name) (alias : Name) : Option Float :=
-- Recall that aliases may not be atomic and include the namespace where they were created.
if ns.isPrefixOf alias then
@ -434,7 +449,7 @@ def idCompletion
(id : Name)
(hoverInfo : HoverInfo)
(danglingDot : Bool)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
idCompletionCore ctx stx id hoverInfo danglingDot
@ -443,7 +458,7 @@ def dotCompletion
(completionInfoPos : Nat)
(ctx : ContextInfo)
(info : TermInfo)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx info.lctx do
let nameSet ← try
getDotCompletionTypeNames (← instantiateMVars (← inferType info.expr))
@ -452,7 +467,7 @@ def dotCompletion
if nameSet.isEmpty then
return
forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let unnormedTypeName := declName.getPrefix
if ! nameSet.contains unnormedTypeName then
return
@ -471,7 +486,7 @@ def dotIdCompletion
(lctx : LocalContext)
(id : Name)
(expectedType? : Option Expr)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
let some expectedType := expectedType?
| return ()
@ -485,7 +500,7 @@ def dotIdCompletion
catch _ =>
pure RBTree.empty
forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let unnormedTypeName := declName.getPrefix
if ! nameSet.contains unnormedTypeName then
return
@ -513,7 +528,7 @@ def fieldIdCompletion
(lctx : LocalContext)
(id : Option Name)
(structName : Name)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
let idStr := id.map (·.toString) |>.getD ""
let fieldNames := getStructureFieldsFlattened (← getEnv) structName (includeSubobjectFields := false)

View file

@ -543,14 +543,14 @@ section NotificationHandling
let newDocText := foldDocumentChanges changes oldDoc.meta.text
updateDocument ⟨docId.uri, newVersion, newDocText, oldDoc.meta.dependencyBuildMode⟩
for (_, r) in st.pendingRequests do
r.cancelTk.cancel .edit
r.cancelTk.cancelByEdit
def handleCancelRequest (p : CancelParams) : WorkerM Unit := do
let st ← get
let some r := st.pendingRequests.find? p.id
| return
r.cancelTk.cancel .cancelRequest
r.cancelTk.cancelByCancelRequest
set <| { st with pendingRequests := st.pendingRequests.erase p.id }
/--
@ -741,6 +741,12 @@ section MessageHandling
pure <| Task.pure <| .ok ()
| Except.ok t => (IO.mapTask · t) fun
| Except.ok r => do
if ← cancelTk.wasCancelledByCancelRequest then
-- Try not to emit a partial response if this request was cancelled.
-- Clients usually discard responses for requests that they cancelled anyways,
-- but it's still good to send less over the wire in this case.
emitResponse ctx (isComplete := false) <| RequestError.requestCancelled.toLspResponseError id
return
emitResponse ctx (isComplete := r.isComplete) <| .response id (toJson r.response)
| Except.error e =>
emitResponse ctx (isComplete := false) <| e.toLspResponseError id

View file

@ -121,7 +121,7 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) :
| some lastEditTimestamp =>
let timeSinceLastEditMs := timestamp - lastEditTimestamp
inlayHintEditDelayMs - timeSinceLastEditMs
let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.truncatedTask)
let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.cancellationTask)
let finishedRange? : Option String.Range := do
return ⟨⟨0⟩, ← List.max? <| snaps.map (fun s => s.endPos)⟩
let oldInlayHints :=
@ -143,7 +143,6 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) :
let lspInlayHints ← inlayHints.mapM (·.toLspInlayHint srcSearchPath ctx.doc.meta.text)
let r := { response := lspInlayHints, isComplete }
let s := { s with oldInlayHints := inlayHints }
RequestM.checkCanceled
return (r, s)
def handleInlayHintsDidChange (p : DidChangeTextDocumentParams)

View file

@ -426,13 +426,14 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams)
let t := doc.cmdSnaps.waitAll
mapTask t fun (snaps, _) => do
let mut stxs := snaps.map (·.stx)
return { syms := toDocumentSymbols doc.meta.text stxs #[] [] }
return { syms := toDocumentSymbols doc.meta.text stxs #[] [] }
where
toDocumentSymbols (text : FileMap) (stxs : List Syntax)
(syms : Array DocumentSymbol) (stack : List NamespaceEntry) :
Array DocumentSymbol :=
RequestM (Array DocumentSymbol) := do
RequestM.checkCancelled
match stxs with
| [] => stack.foldl (fun syms entry => entry.finish text syms none) syms
| [] => return stack.foldl (fun syms entry => entry.finish text syms none) syms
| stx::stxs => match stx with
| `(namespace $id) =>
let entry := { name := id.getId.componentsRev, stx, selection := id, prevSiblings := syms }
@ -455,9 +456,9 @@ where
let syms := entry.finish text syms stx
popStack (n - entry.name.length) syms stack
popStack (id.map (·.getId.getNumParts) |>.getD 1) syms stack
| _ => Id.run do
| _ => do
unless stx.isOfKind ``Lean.Parser.Command.declaration do
return toDocumentSymbols text stxs syms stack
return toDocumentSymbols text stxs syms stack
if let some stxRange := stx.getRange? then
let (name, selection) := match stx with
| `($_:declModifiers $_:attrKind instance $[$np:namedPrio]? $[$id$[.{$ls,*}]?]? $sig:declSig $_) =>
@ -475,7 +476,7 @@ where
range := stxRange.toLspRange text
selectionRange := selRange.toLspRange text
}
return toDocumentSymbols text stxs (syms.push sym) stack
return toDocumentSymbols text stxs (syms.push sym) stack
toDocumentSymbols text stxs syms stack
partial def handleFoldingRange (_ : FoldingRangeParams)
@ -494,7 +495,9 @@ partial def handleFoldingRange (_ : FoldingRangeParams)
if let (_, start)::rest := sections then
addRange text FoldingRangeKind.region start text.source.endPos
addRanges text rest []
| stx::stxs => match stx with
| stx::stxs => do
RequestM.checkCancelled
match stx with
| `(namespace $id) =>
addRanges text ((id.getId.getNumParts, stx.getPos?)::sections) stxs
| `(section $(id)?) =>

View file

@ -147,13 +147,12 @@ def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos)
-- for the full file before sending a response. This means that the response will be incomplete,
-- which we mitigate by regularly sending `workspace/semanticTokens/refresh` requests in the
-- `FileWorker` to tell the client to re-compute the semantic tokens.
let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.truncatedTask)
let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.cancellationTask)
asTask <| do
return { response := ← run doc snaps, isComplete }
| some endPos =>
let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos)
mapTask t fun (snaps, _) => do
RequestM.checkCanceled
return { response := ← run doc snaps, isComplete := true }
where
run doc snaps : RequestM SemanticTokens := do
@ -164,8 +163,11 @@ where
let syntaxBasedSemanticTokens := collectSyntaxBasedSemanticTokens s.stx
let infoBasedSemanticTokens := collectInfoBasedSemanticTokens s.infoTree
leanSemanticTokens := leanSemanticTokens ++ syntaxBasedSemanticTokens ++ infoBasedSemanticTokens
RequestM.checkCancelled
let absoluteLspSemanticTokens := computeAbsoluteLspSemanticTokens doc.meta.text beginPos endPos? leanSemanticTokens
RequestM.checkCancelled
let absoluteLspSemanticTokens := filterDuplicateSemanticTokens absoluteLspSemanticTokens
RequestM.checkCancelled
let semanticTokens := computeDeltaLspSemanticTokens absoluteLspSemanticTokens
return semanticTokens

View file

@ -0,0 +1,77 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Marc Huisinga
-/
prelude
import Init.System.Promise
namespace Lean.Server
structure RequestCancellationToken where
cancelledByCancelRequest : IO.Ref Bool
cancelledByEdit : IO.Ref Bool
cancellationPromise : IO.Promise Unit
namespace RequestCancellationToken
def new : IO RequestCancellationToken := do
return {
cancelledByCancelRequest := ← IO.mkRef false
cancelledByEdit := ← IO.mkRef false
cancellationPromise := ← IO.Promise.new
}
def cancelByCancelRequest (tk : RequestCancellationToken) : IO Unit := do
tk.cancelledByCancelRequest.set true
tk.cancellationPromise.resolve ()
def cancelByEdit (tk : RequestCancellationToken) : IO Unit := do
tk.cancelledByEdit.set true
tk.cancellationPromise.resolve ()
def cancellationTask (tk : RequestCancellationToken) : Task Unit :=
tk.cancellationPromise.result!
def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool :=
tk.cancelledByCancelRequest.get
def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do
tk.cancelledByEdit.get
end RequestCancellationToken
structure RequestCancellation where
def RequestCancellation.requestCancelled : RequestCancellation := {}
abbrev CancellableT m := ReaderT RequestCancellationToken (ExceptT RequestCancellation m)
abbrev CancellableM := CancellableT IO
def CancellableT.run (tk : RequestCancellationToken) (x : CancellableT m α) :
m (Except RequestCancellation α) :=
x tk
def CancellableM.run (tk : RequestCancellationToken) (x : CancellableM α) :
IO (Except RequestCancellation α) :=
CancellableT.run tk x
def CancellableT.checkCancelled [Monad m] [MonadLiftT IO m] : CancellableT m Unit := do
let tk ← read
if ← tk.wasCancelledByCancelRequest then
throw .requestCancelled
def CancellableM.checkCancelled : CancellableM Unit :=
CancellableT.checkCancelled
class MonadCancellable (m : Type → Type v) where
checkCancelled : m PUnit
instance (m n) [MonadLift m n] [MonadCancellable m] : MonadCancellable n where
checkCancelled := liftM (MonadCancellable.checkCancelled : m PUnit)
instance [Monad m] [MonadLiftT IO m] : MonadCancellable (CancellableT m) where
checkCancelled := CancellableT.checkCancelled
def RequestCancellation.check [MonadCancellable m] : m Unit :=
MonadCancellable.checkCancelled

View file

@ -11,6 +11,8 @@ import Lean.Data.Json
import Lean.Data.Lsp
import Lean.Elab.Command
import Lean.Server.RequestCancellation
import Lean.Server.FileSource
import Lean.Server.FileWorker.Utils
@ -127,47 +129,6 @@ def toLspResponseError (id : RequestID) (e : RequestError) : ResponseError Unit
end RequestError
inductive RequestCancellationCause where
| cancelRequest
| edit
deriving Inhabited, BEq
structure RequestCancellationToken where
promise : IO.Promise RequestCancellationCause
namespace RequestCancellationToken
def new : IO RequestCancellationToken := do
return { promise := ← IO.Promise.new }
def cancel (tk : RequestCancellationToken) (cause : RequestCancellationCause) : IO Unit :=
tk.promise.resolve cause
def task (tk : RequestCancellationToken) : Task RequestCancellationCause :=
tk.promise.result!
def truncatedTask (tk : RequestCancellationToken) : Task Unit :=
tk.task.map (sync := true) fun _ => ()
def cancelled? (tk : RequestCancellationToken) : IO (Option RequestCancellationCause) := do
let t := tk.task
if ← IO.hasFinished t then
return some t.get
else
return none
def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool := do
let some c ← tk.cancelled?
| return false
return c matches .cancelRequest
def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do
let some c ← tk.cancelled?
| return false
return c matches .edit
end RequestCancellationToken
def parseRequestParams (paramType : Type) [FromJson paramType] (params : Json)
: Except RequestError paramType :=
fromJson? params |>.mapError fun inner =>
@ -201,6 +162,14 @@ instance : MonadLift (EIO Exception) RequestM where
| .error e => throw <| ← RequestError.ofException e
| .ok v => return v
instance : MonadLift CancellableM RequestM where
monadLift x := do
let ctx ← read
let r ← x.run ctx.cancelTk
match r with
| .error _ => throw RequestError.requestCancelled
| .ok v => return v
namespace RequestM
open FileWorker
open Snapshots
@ -224,7 +193,7 @@ def bindTask (t : Task α) (f : α → RequestM (RequestTask β)) : RequestM (Re
let rc ← readThe RequestContext
EIO.bindTask t (f · rc)
def checkCanceled : RequestM Unit := do
def checkCancelled : RequestM Unit := do
let rc ← readThe RequestContext
if ← rc.cancelTk.wasCancelledByCancelRequest then
throw .requestCancelled