From 05fb67af90c5fca323a6b2af47513587563323c9 Mon Sep 17 00:00:00 2001 From: Marc Huisinga Date: Fri, 14 Feb 2025 12:55:43 +0100 Subject: [PATCH] feat: request cancellation (#7054) This PR adds language server support for request cancellation to the following expensive requests: Code actions, auto-completion, document symbols, folding ranges and semantic highlighting. This means that when the client informs the language server that a request is stale (e.g. because it belongs to a previous state of the document), the language server will now prematurely cancel the computation of the response in order to reduce the CPU load for requests that will be discarded by the client anyways. --- src/Lean/Server/CodeActions/Basic.lean | 1 + src/Lean/Server/Completion.lean | 4 +- .../Completion/CompletionCollectors.lean | 39 +++++++--- src/Lean/Server/FileWorker.lean | 10 ++- src/Lean/Server/FileWorker/InlayHints.lean | 3 +- .../Server/FileWorker/RequestHandling.lean | 17 ++-- .../FileWorker/SemanticHighlighting.lean | 6 +- src/Lean/Server/RequestCancellation.lean | 77 +++++++++++++++++++ src/Lean/Server/Requests.lean | 53 +++---------- 9 files changed, 142 insertions(+), 68 deletions(-) create mode 100644 src/Lean/Server/RequestCancellation.lean diff --git a/src/Lean/Server/CodeActions/Basic.lean b/src/Lean/Server/CodeActions/Basic.lean index 43963cac6d..2ca55171fd 100644 --- a/src/Lean/Server/CodeActions/Basic.lean +++ b/src/Lean/Server/CodeActions/Basic.lean @@ -125,6 +125,7 @@ def handleCodeAction (params : CodeActionParams) : RequestM (RequestTask (Array let caps ← names.mapM evalCodeActionProvider return (← builtinCodeActionProviders.get).toList.toArray ++ Array.zip names caps caps.flatMapM fun (providerName, cap) => do + RequestM.checkCancelled let cas ← cap params snap cas.mapIdxM fun i lca => do if lca.lazy?.isNone then return lca.eager diff --git a/src/Lean/Server/Completion.lean b/src/Lean/Server/Completion.lean index 9ffb290958..61446b366b 100644 --- a/src/Lean/Server/Completion.lean +++ b/src/Lean/Server/Completion.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura, Marc Huisinga -/ prelude import Lean.Server.Completion.CompletionCollectors +import Lean.Server.RequestCancellation import Std.Data.HashMap namespace Lean.Server.Completion @@ -61,11 +62,12 @@ partial def find? (cmdStx : Syntax) (infoTree : InfoTree) (caps : ClientCapabilities) - : IO CompletionList := do + : CancellableM CompletionList := do let prioritizedPartitions := findPrioritizedCompletionPartitionsAt fileMap hoverPos cmdStx infoTree let mut allCompletions := #[] for partition in prioritizedPartitions do for (i, completionInfoPos) in partition do + CancellableM.checkCancelled let completions : Array ScoredCompletionItem ← match i.info with | .id stx id danglingDot lctx .. => diff --git a/src/Lean/Server/Completion/CompletionCollectors.lean b/src/Lean/Server/Completion/CompletionCollectors.lean index 97ff51ce9d..4498432614 100644 --- a/src/Lean/Server/Completion/CompletionCollectors.lean +++ b/src/Lean/Server/Completion/CompletionCollectors.lean @@ -8,6 +8,7 @@ import Lean.Data.FuzzyMatching import Lean.Elab.Tactic.Doc import Lean.Server.Completion.CompletionResolution import Lean.Server.Completion.EligibleHeaderDecls +import Lean.Server.RequestCancellation namespace Lean.Server.Completion open Elab @@ -36,7 +37,7 @@ section Infrastructure Monad used for completion computation that allows modifying a completion `State` and reading `CompletionParams`. -/ - private abbrev M := ReaderT Context $ StateRefT State MetaM + private abbrev M := ReaderT Context $ StateRefT State $ CancellableT MetaM /-- Adds a new completion item to the state in `M`. -/ private def addItem @@ -114,10 +115,13 @@ section Infrastructure (ctx : ContextInfo) (lctx : LocalContext) (x : M Unit) - : IO (Array ScoredCompletionItem) := - ctx.runMetaM lctx do - let (_, s) ← x.run ⟨params, completionInfoPos⟩ |>.run {} - return s.items + : CancellableM (Array ScoredCompletionItem) := do + let tk ← read + let r ← ctx.runMetaM lctx do + x.run ⟨params, completionInfoPos⟩ |>.run {} |>.run tk + match r with + | .error _ => throw .requestCancelled + | .ok (_, s) => return s.items end Infrastructure @@ -161,6 +165,16 @@ section Utils return fuzzyMatchScoreWithThreshold? s₁ s₂ |>.map (declName, · / (p₂.getNumParts + 1).toFloat) return none + private def forEligibleDeclsWithCancellationM [Monad m] [MonadEnv m] + [MonadLiftT (ST IO.RealWorld) m] [MonadCancellable m] [MonadLiftT IO m] + (f : Name → ConstantInfo → m PUnit) : m PUnit := do + let _ ← StateT.run (s := 0) <| forEligibleDeclsM fun decl ci => do + modify (· + 1) + if (← get) >= 10000 then + RequestCancellation.check + set <| 0 + f decl ci + end Utils section IdCompletionUtils @@ -349,7 +363,7 @@ private def idCompletionCore addUnresolvedCompletionItem localDecl.userName (.fvar localDecl.fvarId) (kind := CompletionItemKind.variable) score -- search for matches in the environment let env ← getEnv - forEligibleDeclsM fun declName c => do + forEligibleDeclsWithCancellationM fun declName c => do let bestMatch? ← (·.2) <$> StateT.run (s := none) do let matchUsingNamespace (ns : Name) : StateT (Option (Name × Float)) M Unit := do let some (label, score) ← matchDecl? ns id danglingDot declName @@ -380,6 +394,7 @@ private def idCompletionCore matchUsingNamespace Name.anonymous if let some (bestLabel, bestScore) := bestMatch? then addUnresolvedCompletionItem bestLabel (.const declName) (← getCompletionKindForDecl c) bestScore + RequestCancellation.check let matchAlias (ns : Name) (alias : Name) : Option Float := -- Recall that aliases may not be atomic and include the namespace where they were created. if ns.isPrefixOf alias then @@ -434,7 +449,7 @@ def idCompletion (id : Name) (hoverInfo : HoverInfo) (danglingDot : Bool) - : IO (Array ScoredCompletionItem) := + : CancellableM (Array ScoredCompletionItem) := runM params completionInfoPos ctx lctx do idCompletionCore ctx stx id hoverInfo danglingDot @@ -443,7 +458,7 @@ def dotCompletion (completionInfoPos : Nat) (ctx : ContextInfo) (info : TermInfo) - : IO (Array ScoredCompletionItem) := + : CancellableM (Array ScoredCompletionItem) := runM params completionInfoPos ctx info.lctx do let nameSet ← try getDotCompletionTypeNames (← instantiateMVars (← inferType info.expr)) @@ -452,7 +467,7 @@ def dotCompletion if nameSet.isEmpty then return - forEligibleDeclsM fun declName c => do + forEligibleDeclsWithCancellationM fun declName c => do let unnormedTypeName := declName.getPrefix if ! nameSet.contains unnormedTypeName then return @@ -471,7 +486,7 @@ def dotIdCompletion (lctx : LocalContext) (id : Name) (expectedType? : Option Expr) - : IO (Array ScoredCompletionItem) := + : CancellableM (Array ScoredCompletionItem) := runM params completionInfoPos ctx lctx do let some expectedType := expectedType? | return () @@ -485,7 +500,7 @@ def dotIdCompletion catch _ => pure RBTree.empty - forEligibleDeclsM fun declName c => do + forEligibleDeclsWithCancellationM fun declName c => do let unnormedTypeName := declName.getPrefix if ! nameSet.contains unnormedTypeName then return @@ -513,7 +528,7 @@ def fieldIdCompletion (lctx : LocalContext) (id : Option Name) (structName : Name) - : IO (Array ScoredCompletionItem) := + : CancellableM (Array ScoredCompletionItem) := runM params completionInfoPos ctx lctx do let idStr := id.map (·.toString) |>.getD "" let fieldNames := getStructureFieldsFlattened (← getEnv) structName (includeSubobjectFields := false) diff --git a/src/Lean/Server/FileWorker.lean b/src/Lean/Server/FileWorker.lean index 7d8cb36ea8..e84320dcd1 100644 --- a/src/Lean/Server/FileWorker.lean +++ b/src/Lean/Server/FileWorker.lean @@ -543,14 +543,14 @@ section NotificationHandling let newDocText := foldDocumentChanges changes oldDoc.meta.text updateDocument ⟨docId.uri, newVersion, newDocText, oldDoc.meta.dependencyBuildMode⟩ for (_, r) in st.pendingRequests do - r.cancelTk.cancel .edit + r.cancelTk.cancelByEdit def handleCancelRequest (p : CancelParams) : WorkerM Unit := do let st ← get let some r := st.pendingRequests.find? p.id | return - r.cancelTk.cancel .cancelRequest + r.cancelTk.cancelByCancelRequest set <| { st with pendingRequests := st.pendingRequests.erase p.id } /-- @@ -741,6 +741,12 @@ section MessageHandling pure <| Task.pure <| .ok () | Except.ok t => (IO.mapTask · t) fun | Except.ok r => do + if ← cancelTk.wasCancelledByCancelRequest then + -- Try not to emit a partial response if this request was cancelled. + -- Clients usually discard responses for requests that they cancelled anyways, + -- but it's still good to send less over the wire in this case. + emitResponse ctx (isComplete := false) <| RequestError.requestCancelled.toLspResponseError id + return emitResponse ctx (isComplete := r.isComplete) <| .response id (toJson r.response) | Except.error e => emitResponse ctx (isComplete := false) <| e.toLspResponseError id diff --git a/src/Lean/Server/FileWorker/InlayHints.lean b/src/Lean/Server/FileWorker/InlayHints.lean index 339b468e3e..bdc3216b39 100644 --- a/src/Lean/Server/FileWorker/InlayHints.lean +++ b/src/Lean/Server/FileWorker/InlayHints.lean @@ -121,7 +121,7 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) : | some lastEditTimestamp => let timeSinceLastEditMs := timestamp - lastEditTimestamp inlayHintEditDelayMs - timeSinceLastEditMs - let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.truncatedTask) + let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.cancellationTask) let finishedRange? : Option String.Range := do return ⟨⟨0⟩, ← List.max? <| snaps.map (fun s => s.endPos)⟩ let oldInlayHints := @@ -143,7 +143,6 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) : let lspInlayHints ← inlayHints.mapM (·.toLspInlayHint srcSearchPath ctx.doc.meta.text) let r := { response := lspInlayHints, isComplete } let s := { s with oldInlayHints := inlayHints } - RequestM.checkCanceled return (r, s) def handleInlayHintsDidChange (p : DidChangeTextDocumentParams) diff --git a/src/Lean/Server/FileWorker/RequestHandling.lean b/src/Lean/Server/FileWorker/RequestHandling.lean index 3349dacfc4..aae83847d2 100644 --- a/src/Lean/Server/FileWorker/RequestHandling.lean +++ b/src/Lean/Server/FileWorker/RequestHandling.lean @@ -426,13 +426,14 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams) let t := doc.cmdSnaps.waitAll mapTask t fun (snaps, _) => do let mut stxs := snaps.map (·.stx) - return { syms := toDocumentSymbols doc.meta.text stxs #[] [] } + return { syms := ← toDocumentSymbols doc.meta.text stxs #[] [] } where toDocumentSymbols (text : FileMap) (stxs : List Syntax) (syms : Array DocumentSymbol) (stack : List NamespaceEntry) : - Array DocumentSymbol := + RequestM (Array DocumentSymbol) := do + RequestM.checkCancelled match stxs with - | [] => stack.foldl (fun syms entry => entry.finish text syms none) syms + | [] => return stack.foldl (fun syms entry => entry.finish text syms none) syms | stx::stxs => match stx with | `(namespace $id) => let entry := { name := id.getId.componentsRev, stx, selection := id, prevSiblings := syms } @@ -455,9 +456,9 @@ where let syms := entry.finish text syms stx popStack (n - entry.name.length) syms stack popStack (id.map (·.getId.getNumParts) |>.getD 1) syms stack - | _ => Id.run do + | _ => do unless stx.isOfKind ``Lean.Parser.Command.declaration do - return toDocumentSymbols text stxs syms stack + return ← toDocumentSymbols text stxs syms stack if let some stxRange := stx.getRange? then let (name, selection) := match stx with | `($_:declModifiers $_:attrKind instance $[$np:namedPrio]? $[$id$[.{$ls,*}]?]? $sig:declSig $_) => @@ -475,7 +476,7 @@ where range := stxRange.toLspRange text selectionRange := selRange.toLspRange text } - return toDocumentSymbols text stxs (syms.push sym) stack + return ← toDocumentSymbols text stxs (syms.push sym) stack toDocumentSymbols text stxs syms stack partial def handleFoldingRange (_ : FoldingRangeParams) @@ -494,7 +495,9 @@ partial def handleFoldingRange (_ : FoldingRangeParams) if let (_, start)::rest := sections then addRange text FoldingRangeKind.region start text.source.endPos addRanges text rest [] - | stx::stxs => match stx with + | stx::stxs => do + RequestM.checkCancelled + match stx with | `(namespace $id) => addRanges text ((id.getId.getNumParts, stx.getPos?)::sections) stxs | `(section $(id)?) => diff --git a/src/Lean/Server/FileWorker/SemanticHighlighting.lean b/src/Lean/Server/FileWorker/SemanticHighlighting.lean index fc12536b39..d9d666b66e 100644 --- a/src/Lean/Server/FileWorker/SemanticHighlighting.lean +++ b/src/Lean/Server/FileWorker/SemanticHighlighting.lean @@ -147,13 +147,12 @@ def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos) -- for the full file before sending a response. This means that the response will be incomplete, -- which we mitigate by regularly sending `workspace/semanticTokens/refresh` requests in the -- `FileWorker` to tell the client to re-compute the semantic tokens. - let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.truncatedTask) + let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.cancellationTask) asTask <| do return { response := ← run doc snaps, isComplete } | some endPos => let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos) mapTask t fun (snaps, _) => do - RequestM.checkCanceled return { response := ← run doc snaps, isComplete := true } where run doc snaps : RequestM SemanticTokens := do @@ -164,8 +163,11 @@ where let syntaxBasedSemanticTokens := collectSyntaxBasedSemanticTokens s.stx let infoBasedSemanticTokens := collectInfoBasedSemanticTokens s.infoTree leanSemanticTokens := leanSemanticTokens ++ syntaxBasedSemanticTokens ++ infoBasedSemanticTokens + RequestM.checkCancelled let absoluteLspSemanticTokens := computeAbsoluteLspSemanticTokens doc.meta.text beginPos endPos? leanSemanticTokens + RequestM.checkCancelled let absoluteLspSemanticTokens := filterDuplicateSemanticTokens absoluteLspSemanticTokens + RequestM.checkCancelled let semanticTokens := computeDeltaLspSemanticTokens absoluteLspSemanticTokens return semanticTokens diff --git a/src/Lean/Server/RequestCancellation.lean b/src/Lean/Server/RequestCancellation.lean new file mode 100644 index 0000000000..72f68d1422 --- /dev/null +++ b/src/Lean/Server/RequestCancellation.lean @@ -0,0 +1,77 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Marc Huisinga +-/ +prelude +import Init.System.Promise + +namespace Lean.Server + +structure RequestCancellationToken where + cancelledByCancelRequest : IO.Ref Bool + cancelledByEdit : IO.Ref Bool + cancellationPromise : IO.Promise Unit + +namespace RequestCancellationToken + +def new : IO RequestCancellationToken := do + return { + cancelledByCancelRequest := ← IO.mkRef false + cancelledByEdit := ← IO.mkRef false + cancellationPromise := ← IO.Promise.new + } + +def cancelByCancelRequest (tk : RequestCancellationToken) : IO Unit := do + tk.cancelledByCancelRequest.set true + tk.cancellationPromise.resolve () + +def cancelByEdit (tk : RequestCancellationToken) : IO Unit := do + tk.cancelledByEdit.set true + tk.cancellationPromise.resolve () + +def cancellationTask (tk : RequestCancellationToken) : Task Unit := + tk.cancellationPromise.result! + +def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool := + tk.cancelledByCancelRequest.get + +def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do + tk.cancelledByEdit.get + +end RequestCancellationToken + +structure RequestCancellation where + +def RequestCancellation.requestCancelled : RequestCancellation := {} + +abbrev CancellableT m := ReaderT RequestCancellationToken (ExceptT RequestCancellation m) +abbrev CancellableM := CancellableT IO + +def CancellableT.run (tk : RequestCancellationToken) (x : CancellableT m α) : + m (Except RequestCancellation α) := + x tk + +def CancellableM.run (tk : RequestCancellationToken) (x : CancellableM α) : + IO (Except RequestCancellation α) := + CancellableT.run tk x + +def CancellableT.checkCancelled [Monad m] [MonadLiftT IO m] : CancellableT m Unit := do + let tk ← read + if ← tk.wasCancelledByCancelRequest then + throw .requestCancelled + +def CancellableM.checkCancelled : CancellableM Unit := + CancellableT.checkCancelled + +class MonadCancellable (m : Type → Type v) where + checkCancelled : m PUnit + +instance (m n) [MonadLift m n] [MonadCancellable m] : MonadCancellable n where + checkCancelled := liftM (MonadCancellable.checkCancelled : m PUnit) + +instance [Monad m] [MonadLiftT IO m] : MonadCancellable (CancellableT m) where + checkCancelled := CancellableT.checkCancelled + +def RequestCancellation.check [MonadCancellable m] : m Unit := + MonadCancellable.checkCancelled diff --git a/src/Lean/Server/Requests.lean b/src/Lean/Server/Requests.lean index 096e014059..98eb052ca2 100644 --- a/src/Lean/Server/Requests.lean +++ b/src/Lean/Server/Requests.lean @@ -11,6 +11,8 @@ import Lean.Data.Json import Lean.Data.Lsp import Lean.Elab.Command +import Lean.Server.RequestCancellation + import Lean.Server.FileSource import Lean.Server.FileWorker.Utils @@ -127,47 +129,6 @@ def toLspResponseError (id : RequestID) (e : RequestError) : ResponseError Unit end RequestError -inductive RequestCancellationCause where - | cancelRequest - | edit - deriving Inhabited, BEq - -structure RequestCancellationToken where - promise : IO.Promise RequestCancellationCause - -namespace RequestCancellationToken - -def new : IO RequestCancellationToken := do - return { promise := ← IO.Promise.new } - -def cancel (tk : RequestCancellationToken) (cause : RequestCancellationCause) : IO Unit := - tk.promise.resolve cause - -def task (tk : RequestCancellationToken) : Task RequestCancellationCause := - tk.promise.result! - -def truncatedTask (tk : RequestCancellationToken) : Task Unit := - tk.task.map (sync := true) fun _ => () - -def cancelled? (tk : RequestCancellationToken) : IO (Option RequestCancellationCause) := do - let t := tk.task - if ← IO.hasFinished t then - return some t.get - else - return none - -def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool := do - let some c ← tk.cancelled? - | return false - return c matches .cancelRequest - -def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do - let some c ← tk.cancelled? - | return false - return c matches .edit - -end RequestCancellationToken - def parseRequestParams (paramType : Type) [FromJson paramType] (params : Json) : Except RequestError paramType := fromJson? params |>.mapError fun inner => @@ -201,6 +162,14 @@ instance : MonadLift (EIO Exception) RequestM where | .error e => throw <| ← RequestError.ofException e | .ok v => return v +instance : MonadLift CancellableM RequestM where + monadLift x := do + let ctx ← read + let r ← x.run ctx.cancelTk + match r with + | .error _ => throw RequestError.requestCancelled + | .ok v => return v + namespace RequestM open FileWorker open Snapshots @@ -224,7 +193,7 @@ def bindTask (t : Task α) (f : α → RequestM (RequestTask β)) : RequestM (Re let rc ← readThe RequestContext EIO.bindTask t (f · rc) -def checkCanceled : RequestM Unit := do +def checkCancelled : RequestM Unit := do let rc ← readThe RequestContext if ← rc.cancelTk.wasCancelledByCancelRequest then throw .requestCancelled