feat: 'unknown identifier' code actions (#7665)

This PR adds support for code actions that resolve 'unknown identifier'
errors by either importing the missing declaration or by changing the
identifier to one from the environment.

<details>
<summary>Demo (Click to open)</summary>


![Demo](https://github.com/user-attachments/assets/ba575860-b76d-4213-8cd7-a5525cd60287)
</details>

Specifically, the following kinds of code actions are added by this PR,
all of which are triggered on 'unknown identifier' errors:
- A code action to import the module containing the identifier at the
text cursor position.
- A code action to change the identifier at the text cursor position to
one from the environment.
- A source action to import the modules for all unambiguous identifiers
in the file.

### Details
When clicking on an identifier with an 'unknown identifier' diagnostic,
after a debounce delay of 1000ms, the language server looks up the
(potentially partial) identifier at the position of the cursor in the
global reference data structure by fuzzy-matching against all
identifiers and collects the 10 closest matching entries. This search
accounts for open namespaces at the position of the cursor, including
the namespace of the type / expected type when using dot notation. The
10 closest matching entries are then offered to the user as code
actions:
- If the suggested identifier is not contained in the environment, a
code action that imports the module that the identifier is contained in
and changes the identifier to the suggested one is offered. The
suggestion is inserted in a "minimal" manner, i.e. by accounting for
open namespaces.
- If the suggested identifier is contained in the environment, a code
action that only changes the identifier to the suggested one is offered.
- If the suggested identifier is not contained in the environment and
the suggested identifier is a perfectly unambiguous match, a source
action to import all unambiguous in the file is offered.

The source action to import all unambiguous identifiers can also always
be triggered by right-clicking in the document and selecting the 'Source
Action...' entry.

At the moment, for large projects, the search for closely matching
identifiers in the global reference data structure is still a bit slow.
I hope to optimize it next quarter.

### Implementation notes
- Since the global reference data structure is in the watchdog process,
whereas the elaboration information is in the file worker process, this
PR implements support for file worker -> watchdog requests, including a
new `$/lean/queryModule` request that can be used by the file worker to
request global identifier information.
- To identify 'unknown identifier' errors, several 'unknown identifier'
errors in the elaborator are tagged with a new tag.
- The debounce delay of 1000ms is necessary because VS Code will
re-request code actions while editing an unknown identifier and also
while hovering over the identifier.
- We also implement cancellation for these 'unknown identifier' code
actions. Once the file worker responds to the request as having been
cancelled, the watchdog cancels its computation of all corresponding
file worker -> watchdog requests, too.
- Aliases (i.e. `export`) are currently not accounted for. I've found
that we currently don't handle them correctly in auto-completion, too,
so we will likely add support for this later when fixing the
corresponding auto-completion issue.
- The new code actions added by this request support incrementality.
This commit is contained in:
Marc Huisinga 2025-04-02 11:43:40 +02:00 committed by GitHub
parent 5df4e48dc9
commit 336b68ec20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 1024 additions and 232 deletions

View file

@ -145,34 +145,12 @@ structure DiagnosticWith (α : Type) where
/-- An array of related diagnostic information, e.g. when symbol-names within a scope collide all definitions can be marked via this property. -/
relatedInformation? : Option (Array DiagnosticRelatedInformation) := none
/-- A data entry field that is preserved between a `textDocument/publishDiagnostics` notification and `textDocument/codeAction` request. -/
data?: Option Json := none
data? : Option Json := none
deriving Inhabited, BEq, ToJson, FromJson
def DiagnosticWith.fullRange (d : DiagnosticWith α) : Range :=
d.fullRange?.getD d.range
attribute [local instance] Ord.arrayOrd in
/-- Restriction of `DiagnosticWith` to properties that are displayed to users in the InfoView. -/
private structure DiagnosticWith.UserVisible (α : Type) where
range : Range
fullRange? : Option Range
severity? : Option DiagnosticSeverity
code? : Option DiagnosticCode
source? : Option String
message : α
tags? : Option (Array DiagnosticTag)
relatedInformation? : Option (Array DiagnosticRelatedInformation)
deriving Ord
/-- Extracts user-visible properties from the given `DiagnosticWith`. -/
private def DiagnosticWith.UserVisible.ofDiagnostic (d : DiagnosticWith α)
: DiagnosticWith.UserVisible α :=
{ d with }
/-- Compares `DiagnosticWith` instances modulo non-user-facing properties. -/
def compareByUserVisible [Ord α] (a b : DiagnosticWith α) : Ordering :=
compare (DiagnosticWith.UserVisible.ofDiagnostic a) (DiagnosticWith.UserVisible.ofDiagnostic b)
abbrev Diagnostic := DiagnosticWith String
/-- Parameters for the [`textDocument/publishDiagnostics` notification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#textDocument_publishDiagnostics). -/

View file

@ -7,6 +7,7 @@ Authors: Joscha Mennicken
prelude
import Lean.Expr
import Lean.Data.Lsp.Basic
import Lean.Data.JsonRpc
import Std.Data.TreeMap
set_option linter.missingDocs true -- keep it documented
@ -201,4 +202,62 @@ structure LeanStaleDependencyParams where
staleDependency : DocumentUri
deriving FromJson, ToJson
/-- LSP type for `Lean.OpenDecl`. -/
inductive OpenNamespace
/-- All declarations in `«namespace»` are opened, except for `exceptions`. -/
| allExcept («namespace» : Name) (exceptions : Array Name)
/-- The declaration `«from»` is renamed to `to`. -/
| renamed («from» : Name) (to : Name)
deriving FromJson, ToJson
/-- Query in the `$/lean/queryModule` watchdog <- worker request. -/
structure LeanModuleQuery where
/-- Identifier (potentially partial) to query. -/
identifier : String
/--
Namespaces that are open at the position of `identifier`.
Used for accurately matching declarations against `identifier` in context.
-/
openNamespaces : Array OpenNamespace
deriving FromJson, ToJson
/--
Used in the `$/lean/queryModule` watchdog <- worker request, which is used by the worker to
extract information from the .ilean information in the watchdog.
-/
structure LeanQueryModuleParams where
/--
The request ID in the context of which this worker -> watchdog request was emitted.
Used for cancelling this request in the watchdog.
-/
sourceRequestID : JsonRpc.RequestID
/-- Module queries for extracting .ilean information in the watchdog. -/
queries : Array LeanModuleQuery
deriving FromJson, ToJson
/-- Result entry of a module query. -/
structure LeanIdentifier where
/-- Module that `decl` is defined in. -/
module : Name
/-- Full name of the declaration that matches the query. -/
decl : Name
/-- Whether this `decl` matched the query exactly. -/
isExactMatch : Bool
deriving FromJson, ToJson
/--
Result for a single module query.
Identifiers in the response are sorted descendingly by how well they match the query.
-/
abbrev LeanQueriedModule := Array LeanIdentifier
/-- Response for the `$/lean/queryModule` watchdog <- worker request. -/
structure LeanQueryModuleResponse where
/--
Results for each query in `LeanQueryModuleParams`.
Positions correspond to `queries` in the parameter of the request.
-/
queryResults : Array LeanQueriedModule
deriving FromJson, ToJson, Inhabited
end Lean.Lsp

View file

@ -1222,8 +1222,8 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
-- Then search the environment
if let some (baseStructName, fullName) ← findMethod? structName (.mkSimple fieldName) then
return LValResolution.const baseStructName structName fullName
throwLValError e eType
m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
let msg := mkUnknownIdentifierMessage m!"invalid field '{fieldName}', the environment does not contain '{Name.mkStr structName fieldName}'"
throwLValError e eType msg
| none, LVal.fieldName _ _ (some suffix) _ =>
if e.isConst then
throwUnknownConstant (e.constName! ++ suffix)
@ -1502,7 +1502,7 @@ where
else if let some (fvar, []) ← resolveLocalName idNew then
return fvar
else
throwError "invalid dotted identifier notation, unknown identifier `{idNew}` from expected type{indentExpr expectedType}"
throwUnknownIdentifier m!"invalid dotted identifier notation, unknown identifier `{idNew}` from expected type{indentExpr expectedType}"
catch
| ex@(.error ..) =>
match (← unfoldDefinition? resultType) with
@ -1550,7 +1550,7 @@ private partial def elabAppFn (f : Syntax) (lvals : List LVal) (namedArgs : Arra
| `(@$_) => throwUnsupportedSyntax -- invalid occurrence of `@`
| `(_) => throwError "placeholders '_' cannot be used where a function is expected"
| `(.$id:ident) =>
addCompletionInfo <| CompletionInfo.dotId f id.getId (← getLCtx) expectedType?
addCompletionInfo <| CompletionInfo.dotId id id.getId (← getLCtx) expectedType?
let fConst ← resolveDotName id expectedType?
let s ← observing do
-- Use (force := true) because we want to record the result of .ident resolution even in patterns

View file

@ -1971,7 +1971,7 @@ where
isValidAutoBoundImplicitName n (relaxedAutoImplicit.get (← getOptions)) then
throwAutoBoundImplicitLocal n
else
throwError "unknown identifier '{Lean.mkConst n}'"
throwUnknownIdentifier m!"unknown identifier '{Lean.mkConst n}'"
mkConsts candidates explicitLevels
/--

View file

@ -69,9 +69,31 @@ protected def throwError [Monad m] [MonadError m] (msg : MessageData) : m α :=
let (ref, msg) ← AddErrorMessageContext.add ref msg
throw <| Exception.error ref msg
/--
Tag used for `unknown identifier` messages.
This tag is used by the 'import unknown identifier' code action to detect messages that should
prompt the code action.
-/
def unknownIdentifierMessageTag : Name := `unknownIdentifier
/--
Creates a `MessageData` that is tagged with `unknownIdentifierMessageTag`.
This tag is used by the 'import unknown identifier' code action to detect messages that should
prompt the code action.
-/
def mkUnknownIdentifierMessage (msg : MessageData) : MessageData :=
MessageData.tagged unknownIdentifierMessageTag msg
/--
Throw an unknown identifier error message that is tagged with `unknownIdentifierMessageTag`.
See also `mkUnknownIdentifierMessage`.
-/
def throwUnknownIdentifier [Monad m] [MonadError m] (msg : MessageData) : m α :=
Lean.throwError <| mkUnknownIdentifierMessage msg
/-- Throw an unknown constant error message. -/
def throwUnknownConstant [Monad m] [MonadError m] (constName : Name) : m α :=
Lean.throwError m!"unknown constant '{.ofConstName constName}'"
throwUnknownIdentifier m!"unknown constant '{.ofConstName constName}'"
/-- Throw an error exception using the given message data and reference syntax. -/
protected def throwErrorAt [Monad m] [MonadError m] (ref : Syntax) (msg : MessageData) : m α := do

View file

@ -80,7 +80,7 @@ partial def getFinishedPrefix : AsyncList ε α → BaseIO (List α × Option ε
else pure ⟨[], none, false⟩
partial def getFinishedPrefixWithTimeout (xs : AsyncList ε α) (timeoutMs : UInt32)
(cancelTk? : Option (ServerTask Unit) := none) : BaseIO (List α × Option ε × Bool) := do
(cancelTks : List (ServerTask Unit) := []) : BaseIO (List α × Option ε × Bool) := do
let timeoutTask : ServerTask (Unit ⊕ Except ε (AsyncList ε α)) ←
if timeoutMs == 0 then
pure <| ServerTask.pure (Sum.inl ())
@ -100,21 +100,17 @@ where
| delayed tl =>
let tl : ServerTask (Except ε (AsyncList ε α)) := tl
let tl := tl.mapCheap .inr
let cancelTk? := do return (← cancelTk?).mapCheap .inl
let tasks : { t : List _ // t.length > 0 } :=
match cancelTk? with
| none => ⟨[tl, timeoutTask], by exact Nat.zero_lt_succ _⟩
| some cancelTk => ⟨[tl, cancelTk, timeoutTask], by exact Nat.zero_lt_succ _⟩
let r ← ServerTask.waitAny tasks.val (h := tasks.property)
let cancelTks := cancelTks.map (·.mapCheap .inl)
let r ← ServerTask.waitAny (tl :: cancelTks ++ [timeoutTask])
match r with
| .inl _ => return ⟨[], none, false⟩ -- Timeout or cancellation - stop waiting
| .inr (.ok tl) => go timeoutTask tl
| .inr (.error e) => return ⟨[], some e, true⟩
partial def getFinishedPrefixWithConsistentLatency (xs : AsyncList ε α) (latencyMs : UInt32)
(cancelTk? : Option (ServerTask Unit) := none) : BaseIO (List α × Option ε × Bool) := do
(cancelTks : List (ServerTask Unit) := []) : BaseIO (List α × Option ε × Bool) := do
let timestamp ← IO.monoMsNow
let r ← xs.getFinishedPrefixWithTimeout latencyMs cancelTk?
let r ← xs.getFinishedPrefixWithTimeout latencyMs cancelTks
let passedTimeMs := (← IO.monoMsNow) - timestamp
let remainingLatencyMs := (latencyMs.toNat - passedTimeMs).toUInt32
sleepWithCancellation remainingLatencyMs
@ -123,14 +119,14 @@ where
sleepWithCancellation (sleepDurationMs : UInt32) : BaseIO Unit := do
if sleepDurationMs == 0 then
return
let some cancelTk := cancelTk?
| IO.sleep sleepDurationMs
return
if ← cancelTk.hasFinished then
if cancelTks.isEmpty then
IO.sleep sleepDurationMs
return
if ← cancelTks.anyM (·.hasFinished) then
return
let sleepTask ← Lean.Server.ServerTask.BaseIO.asTask do
IO.sleep sleepDurationMs
ServerTask.waitAny [sleepTask, cancelTk]
ServerTask.waitAny <| sleepTask :: cancelTks
end AsyncList

View file

@ -147,7 +147,7 @@ def handleCodeActionResolve (param : CodeAction) : RequestM (RequestTask CodeAct
let doc ← readDoc
let some data := param.data?
| throw (RequestError.invalidParams "Expected a data field on CodeAction.")
let data : CodeActionResolveData ← liftExcept <| Except.mapError RequestError.invalidParams <| fromJson? data
let data ← RequestM.parseRequestParams CodeActionResolveData data
let pos := doc.meta.text.lspPosToUtf8Pos data.params.range.end
withWaitFindSnap doc (fun s => s.endPos ≥ pos)
(notFoundX := throw <| RequestError.internalError "snapshot not found")

View file

@ -0,0 +1,341 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Marc Huisinga
-/
prelude
import Lean.Server.FileWorker.Utils
import Lean.Data.Lsp.Internal
import Lean.Server.Requests
import Lean.Server.Completion.CompletionInfoSelection
import Lean.Server.CodeActions.Basic
import Lean.Server.Completion.CompletionUtils
namespace Lean.Server.FileWorker
open Lean.Lsp
open Lean.Server.Completion
structure UnknownIdentifierInfo where
paramsRange : String.Range
diagRange : String.Range
def waitUnknownIdentifierRanges (doc : EditableDocument) (requestedRange : String.Range)
: BaseIO (Array String.Range) := do
let text := doc.meta.text
let msgLog := Language.toSnapshotTree doc.initSnap |>.collectMessagesInRange requestedRange |>.get
let mut ranges := #[]
for msg in msgLog.reportedPlusUnreported do
if ! msg.data.hasTag (· == unknownIdentifierMessageTag) then
continue
let msgRange : String.Range := ⟨text.ofPosition msg.pos, text.ofPosition <| msg.endPos.getD msg.pos⟩
if ! msgRange.overlaps requestedRange
(includeFirstStop := true) (includeSecondStop := true) then
continue
ranges := ranges.push msgRange
return ranges
structure Insertion where
fullName : Name
edit : TextEdit
structure Query extends LeanModuleQuery where
env : Environment
determineInsertion : Name → Insertion
partial def collectOpenNamespaces (currentNamespace : Name) (openDecls : List OpenDecl)
: Array OpenNamespace := Id.run do
let mut openNamespaces : Array OpenNamespace := #[]
let mut currentNamespace := currentNamespace
while ! currentNamespace.isAnonymous do
openNamespaces := openNamespaces.push <| .allExcept currentNamespace #[]
currentNamespace := currentNamespace.getPrefix
let openDeclNamespaces := openDecls.map fun
| .simple ns except => .allExcept ns except.toArray
| .explicit id declName => .renamed declName id
openNamespaces := openNamespaces ++ openDeclNamespaces.toArray
return openNamespaces
def computeFallbackQuery?
(doc : EditableDocument)
(requestedRange : String.Range)
(unknownIdentifierRange : String.Range)
(infoTree : Elab.InfoTree)
: Option Query := do
let text := doc.meta.text
let info? := infoTree.smallestInfo? fun info => Id.run do
let some range := info.range?
| return false
return range.overlaps requestedRange (includeFirstStop := true) (includeSecondStop := true)
let some (ctx, _) := info?
| none
return {
identifier := text.source.extract unknownIdentifierRange.start unknownIdentifierRange.stop
openNamespaces := collectOpenNamespaces ctx.currNamespace ctx.openDecls
env := ctx.env
determineInsertion decl :=
let minimizedId := minimizeGlobalIdentifierInContext ctx.currNamespace ctx.openDecls decl
{
fullName := minimizedId
edit := {
range := text.utf8RangeToLspRange unknownIdentifierRange
newText := minimizedId.toString
}
}
}
def computeIdQuery?
(doc : EditableDocument)
(ctx : Elab.ContextInfo)
(stx : Syntax)
(id : Name)
: Option Query := do
let some pos := stx.getPos? (canonicalOnly := true)
| none
let some tailPos := stx.getTailPos? (canonicalOnly := true)
| none
return {
identifier := id.toString
openNamespaces := collectOpenNamespaces ctx.currNamespace ctx.openDecls
env := ctx.env
determineInsertion decl :=
let minimizedId := minimizeGlobalIdentifierInContext ctx.currNamespace ctx.openDecls decl
{
fullName := minimizedId
edit := {
range := doc.meta.text.utf8RangeToLspRange ⟨pos, tailPos⟩
newText := minimizedId.toString
}
}
}
def computeDotQuery?
(doc : EditableDocument)
(ctx : Elab.ContextInfo)
(ti : Elab.TermInfo)
: IO (Option Query) := do
let text := doc.meta.text
let some pos := ti.stx.getPos? (canonicalOnly := true)
| return none
let some tailPos := ti.stx.getTailPos? (canonicalOnly := true)
| return none
let typeNames? : Option (Array Name) ← ctx.runMetaM ti.lctx do
try
return some <| ← getDotCompletionTypeNames (← Lean.instantiateMVars (← Lean.Meta.inferType ti.expr))
catch _ =>
return none
let some typeNames := typeNames?
| return none
return some {
identifier := text.source.extract pos tailPos
openNamespaces := typeNames.map (.allExcept · #[])
env := ctx.env
determineInsertion decl :=
{
fullName := decl
edit := {
range := text.utf8RangeToLspRange ⟨pos, tailPos⟩
newText := decl.getString!
}
}
}
def computeDotIdQuery?
(doc : EditableDocument)
(ctx : Elab.ContextInfo)
(stx : Syntax)
(id : Name)
(lctx : LocalContext)
(expectedType? : Option Expr)
: IO (Option Query) := do
let some pos := stx.getPos? (canonicalOnly := true)
| return none
let some tailPos := stx.getTailPos? (canonicalOnly := true)
| return none
let some expectedType := expectedType?
| return none
let typeNames? : Option (Array Name) ← ctx.runMetaM lctx do
let resultTypeFn := (← instantiateMVars expectedType).cleanupAnnotations.getAppFn.cleanupAnnotations
let .const .. := resultTypeFn
| return none
try
return some <| ← getDotCompletionTypeNames resultTypeFn
catch _ =>
return none
let some typeNames := typeNames?
| return none
return some {
identifier := id.toString
openNamespaces := typeNames.map (.allExcept · #[])
env := ctx.env
determineInsertion decl :=
{
fullName := decl
edit := {
range := doc.meta.text.utf8RangeToLspRange ⟨pos, tailPos⟩
newText := decl.getString!
}
}
}
def computeQuery?
(doc : EditableDocument)
(requestedRange : String.Range)
(unknownIdentifierRange : String.Range)
: RequestM (Option Query) := do
let text := doc.meta.text
let some (stx, infoTree) := RequestM.findCmdDataAtPos doc unknownIdentifierRange.stop (includeStop := true) |>.get
| return none
let completionInfo? : Option ContextualizedCompletionInfo := do
let (completionPartitions, _) := findPrioritizedCompletionPartitionsAt text unknownIdentifierRange.stop stx infoTree
let highestPrioPartition ← completionPartitions[0]?
let (completionInfo, _) ← highestPrioPartition[0]?
return completionInfo
let some ⟨_, ctx, info⟩ := completionInfo?
| return computeFallbackQuery? doc requestedRange unknownIdentifierRange infoTree
match info with
| .id (stx := stx) (id := id) .. =>
return computeIdQuery? doc ctx stx id
| .dot (termInfo := ti) .. =>
return ← computeDotQuery? doc ctx ti
| .dotId stx id lctx expectedType? =>
return ← computeDotIdQuery? doc ctx stx id lctx expectedType?
| _ => return none
def importAllUnknownIdentifiersProvider : Name := `unknownIdentifiers
def importAllUnknownIdentifiersCodeAction (params : CodeActionParams) (kind : String) : CodeAction := {
title := "Import all unambiguous unknown identifiers"
kind? := kind
data? := some <| toJson {
params,
providerName := importAllUnknownIdentifiersProvider
providerResultIndex := 0
: CodeActionResolveData
}
}
def handleUnknownIdentifierCodeAction
(id : JsonRpc.RequestID)
(params : CodeActionParams)
(requestedRange : String.Range)
(unknownIdentifierRanges : Array String.Range)
: RequestM (Array CodeAction) := do
let rc ← read
let doc := rc.doc
let text := doc.meta.text
let queries ← unknownIdentifierRanges.filterMapM fun unknownIdentifierRange =>
computeQuery? doc requestedRange unknownIdentifierRange
if queries.isEmpty then
return #[]
let responseTask ← RequestM.sendServerRequest LeanQueryModuleParams LeanQueryModuleResponse "$/lean/queryModule" {
sourceRequestID := id
queries := queries.map (·.toLeanModuleQuery)
: LeanQueryModuleParams
}
let r ← ServerTask.waitAny [
responseTask.mapCheap Sum.inl,
rc.cancelTk.requestCancellationTask.mapCheap Sum.inr
]
let .inl (.success response) := r
| RequestM.checkCancelled
return #[]
let headerStx := doc.initSnap.stx
let importInsertionPos : Lsp.Position :=
match headerStx.getTailPos? with
| some headerTailPos => {
line := (text.utf8PosToLspPos headerTailPos |>.line) + 1
character := 0
}
| none => { line := 0, character := 0 }
let importInsertionRange : Lsp.Range := ⟨importInsertionPos, importInsertionPos⟩
let mut unknownIdentifierCodeActions := #[]
let mut hasUnambigiousImportCodeAction := false
for q in queries, result in response.queryResults do
for ⟨mod, decl, isExactMatch⟩ in result do
let isDeclInEnv := q.env.contains decl
if ! isDeclInEnv && mod == q.env.mainModule then
-- Don't offer any code actions for identifiers defined further down in the same file
continue
let insertion := q.determineInsertion decl
if ! isDeclInEnv then
unknownIdentifierCodeActions := unknownIdentifierCodeActions.push {
title := s!"Import {insertion.fullName} from {mod}"
kind? := "quickfix"
edit? := WorkspaceEdit.ofTextDocumentEdit {
textDocument := doc.versionedIdentifier
edits := #[
{
range := importInsertionRange
newText := s!"import {mod}\n"
},
insertion.edit
]
}
}
if isExactMatch then
hasUnambigiousImportCodeAction := true
else
unknownIdentifierCodeActions := unknownIdentifierCodeActions.push {
title := s!"Change to {insertion.fullName}"
kind? := "quickfix"
edit? := WorkspaceEdit.ofTextDocumentEdit {
textDocument := doc.versionedIdentifier
edits := #[insertion.edit]
}
}
if hasUnambigiousImportCodeAction then
unknownIdentifierCodeActions := unknownIdentifierCodeActions.push <|
importAllUnknownIdentifiersCodeAction params "quickfix"
return unknownIdentifierCodeActions
def handleResolveImportAllUnknownIdentifiersCodeAction?
(id : JsonRpc.RequestID)
(action : CodeAction)
(unknownIdentifierRanges : Array String.Range)
: RequestM (Option CodeAction) := do
let rc ← read
let doc := rc.doc
let text := doc.meta.text
let queries ← unknownIdentifierRanges.filterMapM fun unknownIdentifierRange =>
computeQuery? doc ⟨0, text.source.endPos⟩ unknownIdentifierRange
if queries.isEmpty then
return none
let responseTask ← RequestM.sendServerRequest LeanQueryModuleParams LeanQueryModuleResponse "$/lean/queryModule" {
sourceRequestID := id
queries := queries.map (·.toLeanModuleQuery)
: LeanQueryModuleParams
}
let .success response := responseTask.get
| return none
let headerStx := doc.initSnap.stx
let importInsertionPos : Lsp.Position :=
match headerStx.getTailPos? with
| some headerTailPos => {
line := (text.utf8PosToLspPos headerTailPos |>.line) + 1
character := 0
}
| none => { line := 0, character := 0 }
let importInsertionRange : Lsp.Range := ⟨importInsertionPos, importInsertionPos⟩
let mut edits : Array TextEdit := #[]
let mut imports : Std.HashSet Name := ∅
for q in queries, result in response.queryResults do
let some ⟨mod, decl, _⟩ := result.find? fun id =>
id.isExactMatch && ! q.env.contains id.decl
| continue
if mod == q.env.mainModule then
continue
let insertion := q.determineInsertion decl
if ! imports.contains mod then
edits := edits.push {
range := importInsertionRange
newText := s!"import {mod}\n"
}
edits := edits.push insertion.edit
imports := imports.insert mod
return some { action with
edit? := WorkspaceEdit.ofTextDocumentEdit {
textDocument := doc.versionedIdentifier
edits
}
}

View file

@ -9,6 +9,7 @@ import Lean.Elab.Tactic.Doc
import Lean.Server.Completion.CompletionResolution
import Lean.Server.Completion.EligibleHeaderDecls
import Lean.Server.RequestCancellation
import Lean.Server.Completion.CompletionUtils
namespace Lean.Server.Completion
open Elab
@ -268,9 +269,6 @@ end IdCompletionUtils
section DotCompletionUtils
private def unfoldeDefinitionGuarded? (e : Expr) : MetaM (Option Expr) :=
try unfoldDefinition? e catch _ => pure none
/-- Return `true` if `e` is a `declName`-application, or can be unfolded (delta-reduced) to one. -/
private partial def isDefEqToAppOf (e : Expr) (declName : Name) : MetaM Bool := do
let isConstOf := match e.getAppFn with
@ -340,17 +338,11 @@ section DotCompletionUtils
Given a type, try to extract relevant type names for dot notation field completion.
We extract the type name, parent struct names, and unfold the type.
The process mimics the dot notation elaboration procedure at `App.lean` -/
private partial def getDotCompletionTypeNames (type : Expr) : MetaM NameSetModPrivate :=
return (← visit type |>.run RBTree.empty).2
where
visit (type : Expr) : StateRefT NameSetModPrivate MetaM Unit := do
let .const typeName _ := type.getAppFn | return ()
modify fun s => s.insert typeName
if isStructure (← getEnv) typeName then
for parentName in (← getAllParentStructures typeName) do
modify fun s => s.insert parentName
let some type ← unfoldeDefinitionGuarded? type | return ()
visit type
private def getDotCompletionTypeNameSet (type : Expr) : MetaM NameSetModPrivate := do
let mut set := .empty
for typeName in ← getDotCompletionTypeNames type do
set := set.insert typeName
return set
end DotCompletionUtils
@ -478,7 +470,7 @@ def dotCompletion
: CancellableM (Array CompletionItem) :=
runM params completionInfoPos ctx info.lctx do
let nameSet ← try
getDotCompletionTypeNames (← instantiateMVars (← inferType info.expr))
getDotCompletionTypeNameSet (← instantiateMVars (← inferType info.expr))
catch _ =>
pure RBTree.empty
if nameSet.isEmpty then
@ -513,7 +505,7 @@ def dotIdCompletion
| return ()
let nameSet ← try
getDotCompletionTypeNames resultTypeFn
getDotCompletionTypeNameSet resultTypeFn
catch _ =>
pure RBTree.empty

View file

@ -5,7 +5,7 @@ Authors: Leonardo de Moura, Marc Huisinga
-/
prelude
import Init.Prelude
import Lean.Elab.InfoTree.Types
import Lean.Meta.WHNF
namespace Lean.Server.Completion
open Elab
@ -19,4 +19,48 @@ structure ContextualizedCompletionInfo where
ctx : ContextInfo
info : CompletionInfo
partial def minimizeGlobalIdentifierInContext (currNamespace : Name) (openDecls : List OpenDecl) (id : Name)
: Name := Id.run do
let mut minimized := shortenIn id currNamespace
for openDecl in openDecls do
let candidate? := match openDecl with
| .simple ns except =>
let candidate := shortenIn id ns
if ! except.contains candidate then
some candidate
else
none
| .explicit alias declName =>
if declName == id then
some alias
else
none
if let some candidate := candidate? then
if candidate.getNumParts < minimized.getNumParts then
minimized := candidate
return minimized
where
shortenIn (id : Name) (contextNamespace : Name) : Name :=
if contextNamespace matches .anonymous then
id
else if contextNamespace.isPrefixOf id then
id.replacePrefix contextNamespace .anonymous
else
shortenIn id contextNamespace.getPrefix
def unfoldeDefinitionGuarded? (e : Expr) : MetaM (Option Expr) :=
try Lean.Meta.unfoldDefinition? e catch _ => pure none
partial def getDotCompletionTypeNames (type : Expr) : MetaM (Array Name) :=
return (← visit type |>.run #[]).2
where
visit (type : Expr) : StateRefT (Array Name) MetaM Unit := do
let .const typeName _ := type.getAppFn | return ()
modify fun s => s.push typeName
if isStructure (← getEnv) typeName then
for parentName in (← getAllParentStructures typeName) do
modify fun s => s.push parentName
let some type ← unfoldeDefinitionGuarded? type | return ()
visit type
end Lean.Server.Completion

View file

@ -29,6 +29,7 @@ import Lean.Server.FileWorker.SetupFile
import Lean.Server.Rpc.Basic
import Lean.Widget.InteractiveDiagnostic
import Lean.Server.Completion.ImportCompletion
import Lean.Server.CodeActions.UnknownIdentifier
/-!
For general server architecture, see `README.md`. For details of IPC communication, see `Watchdog.lean`.
@ -74,27 +75,28 @@ open Widget in
structure WorkerContext where
/-- Synchronized output channel for LSP messages. Notifications for outdated versions are
discarded on read. -/
chanOut : Std.Channel JsonRpc.Message
chanOut : Std.Channel JsonRpc.Message
/--
Latest document version received by the client, used for filtering out notifications from
previous versions.
-/
maxDocVersionRef : IO.Ref Int
freshRequestIdRef : IO.Ref Int
maxDocVersionRef : IO.Ref Int
freshRequestIdRef : IO.Ref Int
/--
Diagnostics that are included in every single `textDocument/publishDiagnostics` notification.
-/
stickyDiagnosticsRef : IO.Ref (Array InteractiveDiagnostic)
partialHandlersRef : IO.Ref (RBMap String PartialHandlerInfo compare)
hLog : FS.Stream
initParams : InitializeParams
processor : Parser.InputContext → BaseIO Lean.Language.Lean.InitialSnapshot
clientHasWidgets : Bool
stickyDiagnosticsRef : IO.Ref (Array InteractiveDiagnostic)
partialHandlersRef : IO.Ref (RBMap String PartialHandlerInfo compare)
pendingServerRequestsRef : IO.Ref (Std.TreeMap RequestID (IO.Promise (ServerRequestResponse Json)))
hLog : FS.Stream
initParams : InitializeParams
processor : Parser.InputContext → BaseIO Lean.Language.Lean.InitialSnapshot
clientHasWidgets : Bool
/--
Options defined on the worker cmdline (i.e. not including options from `setup-file`), used for
context-free tasks such as editing delay.
-/
cmdlineOpts : Options
cmdlineOpts : Options
def WorkerContext.modifyGetPartialHandler (ctx : WorkerContext) (method : String)
(f : PartialHandlerInfo → α × PartialHandlerInfo) : BaseIO α :=
@ -113,6 +115,31 @@ def WorkerContext.modifyPartialHandler (ctx : WorkerContext) (method : String)
def WorkerContext.updateRequestsInFlight (ctx : WorkerContext) (method : String) (f : Nat → Nat) : BaseIO Unit :=
ctx.modifyPartialHandler method fun h => { h with requestsInFlight := f h.requestsInFlight }
def WorkerContext.initPendingServerRequest
responseType [FromJson responseType] [Inhabited responseType]
(ctx : WorkerContext) (id : RequestID) :
BaseIO (ServerTask (ServerRequestResponse responseType)) := do
let responsePromise ← IO.Promise.new
ctx.pendingServerRequestsRef.modify (·.insert id responsePromise)
let responseTask := responsePromise.result!.asServerTask
let responseTask := responseTask.mapCheap fun
| .success response =>
match fromJson? response with
| .ok response => .success response
| .error message => .failure .invalidParams message
| .failure code message => .failure code message
return responseTask
def WorkerContext.resolveServerRequestResponse (ctx : WorkerContext) (id : RequestID)
(response : ServerRequestResponse Json) : BaseIO Unit := do
let responsePromise? ← ctx.pendingServerRequestsRef.modifyGet fun pendingServerRequests =>
let responsePromise? := pendingServerRequests.get? id
let pendingServerRequests := pendingServerRequests.erase id
(responsePromise?, pendingServerRequests)
let some responsePromise := responsePromise?
| return
responsePromise.resolve response
/-! # Asynchronous snapshot elaboration -/
section Elab
@ -409,6 +436,7 @@ section Initialization
let maxDocVersionRef ← IO.mkRef 0
let freshRequestIdRef ← IO.mkRef (0 : Int)
let stickyDiagnosticsRef ← IO.mkRef ∅
let pendingServerRequestsRef ← IO.mkRef ∅
let chanOut ← mkLspOutputChannel maxDocVersionRef
let srcSearchPathPromise ← IO.Promise.new
let timestamp ← IO.monoMsNow
@ -434,6 +462,7 @@ section Initialization
processor
clientHasWidgets
partialHandlersRef
pendingServerRequestsRef
maxDocVersionRef
freshRequestIdRef
cmdlineOpts := opts
@ -495,15 +524,26 @@ section Initialization
end Initialization
section ServerRequests
def sendServerRequest [ToJson α]
def sendServerRequest
paramType [ToJson paramType] responseType [FromJson responseType] [Inhabited responseType]
(ctx : WorkerContext)
(method : String)
(param : α)
: BaseIO Unit := do
(param : paramType)
: BaseIO (ServerTask (ServerRequestResponse responseType)) := do
let freshRequestId ← ctx.freshRequestIdRef.modifyGet fun freshRequestId =>
(freshRequestId, freshRequestId + 1)
let r : JsonRpc.Request α := ⟨freshRequestId, method, param⟩
let responseTask ← ctx.initPendingServerRequest responseType freshRequestId
let r : JsonRpc.Request paramType := ⟨freshRequestId, method, param⟩
ctx.chanOut.send r
return responseTask
def sendUntypedServerRequest
(ctx : WorkerContext)
(method : String)
(param : Json)
: BaseIO (ServerTask (ServerRequestResponse Json)) := do
sendServerRequest Json Json ctx method param
end ServerRequests
section Updates
@ -544,6 +584,7 @@ section NotificationHandling
cancelTk
hLog := ctx.hLog
initParams := ctx.initParams
serverRequestEmitter := sendUntypedServerRequest ctx
}
RequestM.runInIO (handleOnDidChange p) rc
if ¬ changes.isEmpty then
@ -634,32 +675,6 @@ section MessageHandling
: WorkerM Unit := do
updatePendingRequests (·.insert id r)
open Widget RequestM Language in
def handleGetInteractiveDiagnosticsRequest (params : GetInteractiveDiagnosticsParams) :
WorkerM (Array InteractiveDiagnostic) := do
let ctx ← read
let st ← get
-- NOTE: always uses latest document (which is the only one we can retrieve diagnostics for);
-- any race should be temporary as the client should re-request interactive diagnostics when
-- they receive the non-interactive diagnostics for the new document
let stickyDiags ← ctx.stickyDiagnosticsRef.get
let diags ← st.doc.diagnosticsRef.get
-- NOTE: does not wait for `lineRange?` to be fully elaborated, which would be problematic with
-- fine-grained incremental reporting anyway; instead, the client is obligated to resend the
-- request when the non-interactive diagnostics of this range have changed
return (stickyDiags ++ diags).filter fun diag =>
let r := diag.fullRange
let diagStartLine := r.start.line
let diagEndLine :=
if r.end.character == 0 then
r.end.line
else
r.end.line + 1
params.lineRange?.all fun ⟨s, e⟩ =>
-- does [s,e) intersect [diagStartLine,diagEndLine)?
s ≤ diagStartLine ∧ diagStartLine < e
diagStartLine ≤ s ∧ s < diagEndLine
def handleImportCompletionRequest (id : RequestID) (params : CompletionParams)
: WorkerM (ServerTask (Except Error AvailableImportsCache)) := do
let ctx ← read
@ -684,17 +699,9 @@ section MessageHandling
ctx.chanOut.send <| .response id (toJson completions)
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
def handleRequest (id : RequestID) (method : String) (params : Json)
: WorkerM Unit := do
def handleStatefulPreRequestSpecialCases (id : RequestID) (method : String) (params : Json) : WorkerM Bool := do
let ctx ← read
let st ← get
ctx.modifyPartialHandler method fun h => { h with
pendingRefreshInfo? := none
requestsInFlight := h.requestsInFlight + 1
}
-- special cases
try
match method with
-- needs access to `WorkerState.rpcSessions`
@ -702,48 +709,128 @@ section MessageHandling
let ps ← parseParams RpcConnectParams params
let resp ← handleRpcConnect ps
ctx.chanOut.send <| .response id (toJson resp)
return
| "$/lean/rpc/call" =>
let params ← parseParams Lsp.RpcCallParams params
-- needs access to `EditableDocumentCore.diagnosticsRef`
if params.method == `Lean.Widget.getInteractiveDiagnostics then
let some seshRef := st.rpcSessions.find? params.sessionId
| ctx.chanOut.send <| .responseError id .rpcNeedsReconnect "Outdated RPC session" none
let params ← IO.ofExcept (fromJson? params.params)
let resp ← handleGetInteractiveDiagnosticsRequest params
let resp ← seshRef.modifyGet fun st =>
rpcEncode resp st.objects |>.map (·) ({st with objects := ·})
ctx.chanOut.send <| .response id resp
return
return true
| "textDocument/completion" =>
let params ← parseParams CompletionParams params
-- must not wait on import processing snapshot
if ImportCompletion.isImportCompletionRequest st.doc.meta.text st.doc.initSnap.stx params
then
let importCachingTask ← handleImportCompletionRequest id params
set { st with importCachingTask? := some importCachingTask }
return
| _ => pure ()
-- Must not wait on import processing snapshot
if ! ImportCompletion.isImportCompletionRequest st.doc.meta.text st.doc.initSnap.stx params then
return false
let importCachingTask ← handleImportCompletionRequest id params
set { st with importCachingTask? := some importCachingTask }
return true
| _ =>
return false
catch e =>
ctx.chanOut.send <| .responseError id .internalError (toString e) none
return
return true
let cancelTk ← RequestCancellationToken.new
-- TODO: move into language-specific request handling
let rc : RequestContext :=
{ rpcSessions := st.rpcSessions
srcSearchPathTask := st.srcSearchPathTask
doc := st.doc
cancelTk
hLog := ctx.hLog
initParams := ctx.initParams }
let requestTask? ← EIO.toIO' <| handleLspRequest method params rc
let requestTask ← match requestTask? with
| Except.error e =>
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
pure <| ServerTask.pure <| .ok ()
| Except.ok requestTask => ServerTask.IO.mapTaskCheap (t := requestTask) fun
open Widget RequestM Language in
def handleGetInteractiveDiagnosticsRequest
(ctx : WorkerContext)
(params : GetInteractiveDiagnosticsParams)
: RequestM (Array InteractiveDiagnostic) := do
let doc ← readDoc
-- NOTE: always uses latest document (which is the only one we can retrieve diagnostics for);
-- any race should be temporary as the client should re-request interactive diagnostics when
-- they receive the non-interactive diagnostics for the new document
let stickyDiags ← ctx.stickyDiagnosticsRef.get
let diags ← doc.diagnosticsRef.get
-- NOTE: does not wait for `lineRange?` to be fully elaborated, which would be problematic with
-- fine-grained incremental reporting anyway; instead, the client is obligated to resend the
-- request when the non-interactive diagnostics of this range have changed
return (stickyDiags ++ diags).filter fun diag =>
let r := diag.fullRange
let diagStartLine := r.start.line
let diagEndLine :=
if r.end.character == 0 then
r.end.line
else
r.end.line + 1
params.lineRange?.all fun ⟨s, e⟩ =>
-- does [s,e) intersect [diagStartLine,diagEndLine)?
s ≤ diagStartLine ∧ diagStartLine < e
diagStartLine ≤ s ∧ s < diagEndLine
def handlePreRequestSpecialCases? (ctx : WorkerContext) (st : WorkerState)
(id : RequestID) (method : String) (params : Json)
: RequestM (Option (RequestTask (LspResponse Json))) := do
match method with
| "$/lean/rpc/call" =>
let params ← RequestM.parseRequestParams Lsp.RpcCallParams params
if params.method != `Lean.Widget.getInteractiveDiagnostics then
return none
let some seshRef := st.rpcSessions.find? params.sessionId
| throw RequestError.rpcNeedsReconnect
let params ← RequestM.parseRequestParams Widget.GetInteractiveDiagnosticsParams params.params
let resp ← handleGetInteractiveDiagnosticsRequest ctx params
let resp ← seshRef.modifyGet fun st =>
rpcEncode resp st.objects |>.map (·) ({st with objects := ·})
return some <| .pure { response := resp, isComplete := true }
| "codeAction/resolve" =>
let params ← RequestM.parseRequestParams CodeAction params
let some data := params.data?
| throw (RequestError.invalidParams "Expected a data field on CodeAction.")
let data ← RequestM.parseRequestParams CodeActionResolveData data
if data.providerName != importAllUnknownIdentifiersProvider then
return none
return some <| ← RequestM.asTask do
let fileRange := ⟨0, st.doc.meta.text.source.endPos⟩
let unknownIdentifierRanges ← waitUnknownIdentifierRanges st.doc fileRange
if unknownIdentifierRanges.isEmpty then
return { response := toJson params, isComplete := true }
let action? ← handleResolveImportAllUnknownIdentifiersCodeAction? id params unknownIdentifierRanges
let action := action?.getD params
return { response := toJson action, isComplete := true }
| _ =>
return none
def handlePostRequestSpecialCases (id : RequestID) (method : String) (params : Json)
(task : RequestTask (LspResponse Json)) : RequestM (RequestTask (LspResponse Json)) := do
let doc ← RequestM.readDoc
match method with
| "textDocument/codeAction" =>
let .ok (params : CodeActionParams) := fromJson? params
| return task
RequestM.mapRequestTaskCostly task fun r => do
let isSourceAction := params.context.only?.any fun only =>
only.contains "source" || only.contains "source.organizeImports"
if isSourceAction then
let unknownIdentifierRanges ← waitUnknownIdentifierRanges doc ⟨0, doc.meta.text.source.endPos⟩
if unknownIdentifierRanges.isEmpty then
return r
let .ok (codeActions : Array CodeAction) := fromJson? r.response
| return r
return { r with response := toJson <| codeActions.push <| importAllUnknownIdentifiersCodeAction params "source.organizeImports" }
else
let requestedRange := doc.meta.text.lspRangeToUtf8Range params.range
let unknownIdentifierRanges ← waitUnknownIdentifierRanges doc requestedRange
if unknownIdentifierRanges.isEmpty then
return r
let .ok (codeActions : Array CodeAction) := fromJson? r.response
| return r
RequestM.checkCancelled
-- Since computing the unknown identifier code actions is *really* expensive,
-- we only do it when the user has stopped typing for a second.
IO.sleep 1000
RequestM.checkCancelled
let unknownIdentifierCodeActions ← handleUnknownIdentifierCodeAction id params requestedRange unknownIdentifierRanges
return { r with response := toJson <| codeActions ++ unknownIdentifierCodeActions }
| _ =>
return task
def emitRequestResponse
(requestTask? : Except RequestError (RequestTask (LspResponse Json)))
(cancelTk : RequestCancellationToken)
(id : RequestID)
(method : String)
: WorkerM (ServerTask (Except Error Unit)) := do
let ctx ← read
match requestTask? with
| Except.error e =>
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
return ServerTask.pure <| .ok ()
| Except.ok requestTask =>
ServerTask.IO.mapTaskCheap (t := requestTask) fun
| Except.ok r => do
if ← cancelTk.wasCancelledByCancelRequest then
-- Try not to emit a partial response if this request was cancelled.
@ -754,10 +841,7 @@ section MessageHandling
emitResponse ctx (isComplete := r.isComplete) <| .response id (toJson r.response)
| Except.error e =>
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
queueRequest id { cancelTk, requestTask }
where
emitResponse (ctx : WorkerContext) (m : JsonRpc.Message) (isComplete : Bool) : IO Unit := do
ctx.chanOut.send m
let timestamp ← IO.monoMsNow
@ -770,8 +854,47 @@ section MessageHandling
some { lastRefreshTimestamp := timestamp, successiveRefreshAttempts := 0 }
}
def handleResponse (_ : RequestID) (_ : Json) : WorkerM Unit :=
return -- The only response that we currently expect here is always empty
def handleRequest (id : RequestID) (method : String) (params : Json)
: WorkerM Unit := do
let ctx ← read
let st ← get
ctx.modifyPartialHandler method fun h => { h with
pendingRefreshInfo? := none
requestsInFlight := h.requestsInFlight + 1
}
let hasHandledSpecialCase ← handleStatefulPreRequestSpecialCases id method params
if hasHandledSpecialCase then
return
let cancelTk ← RequestCancellationToken.new
-- TODO: move into language-specific request handling
let rc : RequestContext := {
rpcSessions := st.rpcSessions
srcSearchPathTask := st.srcSearchPathTask
doc := st.doc
cancelTk
hLog := ctx.hLog
initParams := ctx.initParams
serverRequestEmitter := sendUntypedServerRequest ctx
}
let requestTask? ← EIO.toIO' <| RequestM.run (rc := rc) do
if let some response ← handlePreRequestSpecialCases? ctx st id method params then
return response
let task ← handleLspRequest method params
let task ← handlePostRequestSpecialCases id method params task
return task
let requestTask ← emitRequestResponse requestTask? cancelTk id method
queueRequest id { cancelTk, requestTask }
def handleResponse (id : RequestID) (response : Json) : WorkerM Unit := do
let ctx ← read
ctx.resolveServerRequestResponse id (.success response)
def handleResponseError (id : RequestID) (code : ErrorCode) (message : String) : WorkerM Unit := do
let ctx ← read
ctx.resolveServerRequestResponse id (.failure code message)
end MessageHandling
@ -811,9 +934,8 @@ section MainLoop
| Message.response id result =>
handleResponse id result
mainLoop
| Message.responseError .. =>
-- Ignore all errors as we currently only handle a single request with an optional response
-- where failure is not an issue.
| Message.responseError id code message _ =>
handleResponseError id code message
mainLoop
| _ => throwServerError "Got invalid JSON-RPC message"
end MainLoop
@ -871,7 +993,7 @@ def runRefreshTasks : WorkerM (Array (ServerTask Unit)) := do
if cancelled then
return
continue
sendServerRequest ctx refreshMethod (none : Option Nat)
let _ ← sendServerRequest (Option Nat) (Option Nat) ctx refreshMethod none
return tasks
where

View file

@ -141,8 +141,8 @@ def handleInlayHints (p : InlayHintParams) (s : InlayHintState) :
| some lastEditTimestamp =>
let timeSinceLastEditMs := timestamp - lastEditTimestamp
inlayHintEditDelayMs - timeSinceLastEditMs
let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.cancellationTask)
if ← IO.hasFinished ctx.cancelTk.cancellationTask then
let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTks := ctx.cancelTk.cancellationTasks)
if ← ctx.cancelTk.wasCancelled then
-- Inlay hint request has been cancelled, either by a cancellation request or another edit.
-- In the former case, we will simply discard the result and respond with a request error
-- denoting cancellation.

View file

@ -165,7 +165,7 @@ def handleSemanticTokensFull (_ : SemanticTokensParams) (_ : SemanticTokensState
-- for the full file before sending a response. This means that the response will be incomplete,
-- which we mitigate by regularly sending `workspace/semanticTokens/refresh` requests in the
-- `FileWorker` to tell the client to re-compute the semantic tokens.
let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.cancellationTask)
let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTks := ctx.cancelTk.cancellationTasks)
let response ← computeSemanticTokens doc 0 none snaps
return ({ response, isComplete }, ⟨⟩)

View file

@ -528,26 +528,35 @@ def definitionOf?
return some ⟨⟨moduleUri, definitionRange⟩, definitionParentDeclInfo?⟩
return none
/-- A match in `References.definitionsMatching`. -/
structure MatchedDefinition (α β : Type) where
/-- Result of `filterMapMod`. -/
mod : α
/-- Result of `filterMapIdent`. -/
ident : β
/-- Definition range of matched identifier. -/
range : Range
/-- Yields all definitions matching the given `filter`. -/
def definitionsMatching
(self : References)
(srcSearchPath : SearchPath)
(filter : Name → Option α)
(maxAmount? : Option Nat := none) : IO $ Array (α × Location) := do
(self : References)
(filterMapMod : Name → IO (Option α))
(filterMapIdent : Name → IO (Option β))
(cancelTk? : Option CancelToken := none)
: IO (Array (MatchedDefinition α β)) := do
let mut result := #[]
for (module, refs) in self.allRefs do
let some path ← srcSearchPath.findModuleWithExt "lean" module
if let some cancelTk := cancelTk? then
if ← cancelTk.isSet then
return result
let some a ← filterMapMod module
| continue
let uri := System.Uri.pathToUri <| ← IO.FS.realPath path
for (ident, info) in refs do
let (RefIdent.const _ nameString, some ⟨definitionRange, _⟩) := (ident, info.definition?)
| continue
let some a := filter nameString.toName
let some b ← filterMapIdent nameString.toName
| continue
result := result.push (a, ⟨uri, definitionRange⟩)
if let some maxAmount := maxAmount? then
if result.size >= maxAmount then
return result
result := result.push ⟨a, b, definitionRange⟩
return result
end References

View file

@ -5,40 +5,52 @@ Authors: Marc Huisinga
-/
prelude
import Init.System.Promise
import Lean.Server.ServerTask
namespace Lean.Server
structure RequestCancellationToken where
cancelledByCancelRequest : IO.Ref Bool
cancelledByEdit : IO.Ref Bool
cancellationPromise : IO.Promise Unit
cancelledByCancelRequest : IO.Ref Bool
cancelledByEdit : IO.Ref Bool
requestCancellationPromise : IO.Promise Unit
editCancellationPromise : IO.Promise Unit
namespace RequestCancellationToken
def new : BaseIO RequestCancellationToken := do
return {
cancelledByCancelRequest := ← IO.mkRef false
cancelledByEdit := ← IO.mkRef false
cancellationPromise := ← IO.Promise.new
cancelledByCancelRequest := ← IO.mkRef false
cancelledByEdit := ← IO.mkRef false
requestCancellationPromise := ← IO.Promise.new
editCancellationPromise := ← IO.Promise.new
}
def cancelByCancelRequest (tk : RequestCancellationToken) : BaseIO Unit := do
tk.cancelledByCancelRequest.set true
tk.cancellationPromise.resolve ()
tk.requestCancellationPromise.resolve ()
def cancelByEdit (tk : RequestCancellationToken) : BaseIO Unit := do
tk.cancelledByEdit.set true
tk.cancellationPromise.resolve ()
tk.editCancellationPromise.resolve ()
def cancellationTask (tk : RequestCancellationToken) : Task Unit :=
tk.cancellationPromise.result!
def requestCancellationTask (tk : RequestCancellationToken): ServerTask Unit :=
tk.requestCancellationPromise.result!
def editCancellationTask (tk : RequestCancellationToken) : ServerTask Unit :=
tk.editCancellationPromise.result!
def cancellationTasks (tk : RequestCancellationToken) : List (ServerTask Unit) :=
[tk.requestCancellationTask, tk.editCancellationTask]
def wasCancelledByCancelRequest (tk : RequestCancellationToken) : BaseIO Bool :=
tk.cancelledByCancelRequest.get
def wasCancelledByEdit (tk : RequestCancellationToken) : BaseIO Bool := do
def wasCancelledByEdit (tk : RequestCancellationToken) : BaseIO Bool :=
tk.cancelledByEdit.get
def wasCancelled (tk : RequestCancellationToken) : BaseIO Bool := do
return (← tk.wasCancelledByCancelRequest) || (← tk.wasCancelledByEdit)
end RequestCancellationToken
structure RequestCancellation where

View file

@ -88,6 +88,18 @@ partial def SnapshotTree.findInfoTreeAtPos (text : FileMap) (tree : SnapshotTree
| return (none, .proceed (foldChildren := true))
return (infoTree, .done)
partial def SnapshotTree.collectMessagesInRange (tree : SnapshotTree)
(requestedRange : String.Range) : ServerTask MessageLog :=
tree.foldSnaps (init := .empty) fun snap log => Id.run do
let some stx := snap.stx?
| return .pure (log, .proceed (foldChildren := true))
let some range := stx.getRangeWithTrailing? (canonicalOnly := true)
| return .pure (log, .proceed (foldChildren := true))
if ! range.overlaps requestedRange (includeFirstStop := true) (includeSecondStop := true) then
return .pure (log, .proceed (foldChildren := false))
return snap.task.asServerTask.mapCheap fun tree => Id.run do
return (log ++ tree.element.diagnostics.msgLog, .proceed (foldChildren := true))
end Lean.Language
namespace Lean.Server
@ -109,7 +121,7 @@ def methodNotFound (method : String) : RequestError :=
message := s!"No request handler found for '{method}'" }
def invalidParams (message : String) : RequestError :=
{code := ErrorCode.invalidParams, message}
{ code := ErrorCode.invalidParams, message }
def internalError (message : String) : RequestError :=
{ code := ErrorCode.internalError, message }
@ -117,6 +129,9 @@ def internalError (message : String) : RequestError :=
def requestCancelled : RequestError :=
{ code := ErrorCode.requestCancelled, message := "" }
def rpcNeedsReconnect : RequestError :=
{ code := ErrorCode.rpcNeedsReconnect, message := "Outdated RPC session" }
def ofException (e : Lean.Exception) : IO RequestError :=
return internalError (← e.toMessageData.toString)
@ -132,17 +147,27 @@ end RequestError
def parseRequestParams (paramType : Type) [FromJson paramType] (params : Json)
: Except RequestError paramType :=
fromJson? params |>.mapError fun inner =>
{ code := JsonRpc.ErrorCode.parseError
message := s!"Cannot parse request params: {params.compress}\n{inner}" }
fromJson? params |>.mapError fun inner => {
code := JsonRpc.ErrorCode.invalidParams
message := s!"Cannot parse request params: {params.compress}\n{inner}"
}
inductive ServerRequestResponse (α : Type) where
| success (response : α)
| failure (code : JsonRpc.ErrorCode) (message : String)
deriving Inhabited
abbrev ServerRequestEmitter := (method : String) → (param : Json)
→ BaseIO (ServerTask (ServerRequestResponse Json))
structure RequestContext where
rpcSessions : RBMap UInt64 (IO.Ref FileWorker.RpcSession) compare
srcSearchPathTask : ServerTask SearchPath
doc : FileWorker.EditableDocument
hLog : IO.FS.Stream
initParams : Lsp.InitializeParams
cancelTk : RequestCancellationToken
rpcSessions : RBMap UInt64 (IO.Ref FileWorker.RpcSession) compare
srcSearchPathTask : ServerTask SearchPath
doc : FileWorker.EditableDocument
hLog : IO.FS.Stream
initParams : Lsp.InitializeParams
cancelTk : RequestCancellationToken
serverRequestEmitter : ServerRequestEmitter
def RequestContext.srcSearchPath (rc : RequestContext) : SearchPath :=
rc.srcSearchPathTask.get
@ -152,6 +177,9 @@ abbrev RequestT m := ReaderT RequestContext <| ExceptT RequestError m
/-- Workers execute request handlers in this monad. -/
abbrev RequestM := ReaderT RequestContext <| EIO RequestError
def RequestM.run (act : RequestM α) (rc : RequestContext) : EIO RequestError α :=
act rc
abbrev RequestTask.pure (a : α) : RequestTask α := ServerTask.pure (.ok a)
instance : MonadLift IO RequestM where
@ -209,11 +237,49 @@ def bindTaskCostly (t : ServerTask α) (f : α → RequestM (RequestTask β)) :
let rc ← readThe RequestContext
ServerTask.EIO.bindTaskCostly t (f · rc)
def mapRequestTaskCheap (t : RequestTask α) (f : α → RequestM β) : RequestM (RequestTask β) := do
mapTaskCheap (t := t) fun
| .error e => throw e
| .ok r => f r
def mapRequestTaskCostly (t : RequestTask α) (f : α → RequestM β) : RequestM (RequestTask β) := do
mapTaskCostly (t := t) fun
| .error e => throw e
| .ok r => f r
def bindRequestTaskCheap (t : RequestTask α) (f : α → RequestM (RequestTask β)) : RequestM (RequestTask β) := do
bindTaskCheap (t := t) fun
| .error e => throw e
| .ok r => f r
def bindRequestTaskCostly (t : RequestTask α) (f : α → RequestM (RequestTask β)) : RequestM (RequestTask β) := do
bindTaskCostly (t := t) fun
| .error e => throw e
| .ok r => f r
def parseRequestParams (paramType : Type) [FromJson paramType] (params : Json)
: RequestM paramType :=
EIO.ofExcept <| Server.parseRequestParams paramType params
def checkCancelled : RequestM Unit := do
let rc ← readThe RequestContext
if ← rc.cancelTk.wasCancelledByCancelRequest then
throw .requestCancelled
def sendServerRequest
paramType [ToJson paramType] responseType [FromJson responseType] [Inhabited responseType]
(method : String)
(param : paramType)
: RequestM (ServerTask (ServerRequestResponse responseType)) := do
let ctx ← read
let task ← ctx.serverRequestEmitter method (toJson param)
return task.mapCheap fun
| ServerRequestResponse.success response =>
match fromJson? response with
| .ok (response : responseType) => ServerRequestResponse.success response
| .error err => ServerRequestResponse.failure .parseError s!"Cannot parse server request response: {response.compress}\n{err}"
| ServerRequestResponse.failure code msg => ServerRequestResponse.failure code msg
def waitFindSnapAux (notFoundX : RequestM α) (x : Snapshot → RequestM α)
: Except IO.Error (Option Snapshot) → RequestM α
/- The elaboration task that we're waiting for may be aborted if the file contents change.
@ -385,7 +451,7 @@ def registerLspRequestHandler (method : String)
let fileSource := fun j =>
parseRequestParams paramType j |>.map Lsp.fileSource
let handle := fun j => do
let params ← liftExcept <| parseRequestParams paramType j
let params ← RequestM.parseRequestParams paramType j
let t ← handler params
pure <| t.mapCheap <| Except.map ToJson.toJson
@ -412,7 +478,7 @@ def chainLspRequestHandler (method : String)
let t ← oldHandler.handle j
let t := t.mapCheap fun x => x.bind fun j => FromJson.fromJson? j |>.mapError fun e =>
.internalError s!"Failed to parse original LSP response for `{method}` when chaining: {e}"
let params ← liftExcept <| parseRequestParams paramType j
let params ← RequestM.parseRequestParams paramType j
let t ← handler params t
pure <| t.mapCheap <| Except.map ToJson.toJson
@ -493,7 +559,7 @@ private def overrideStatefulLspRequestHandler
let stateRef ← IO.mkRef initState
let pureHandle : Json → Dynamic → RequestM (LspResponse Json × Dynamic) := fun param state => do
let param ← liftExcept <| parseRequestParams paramType param
let param ← RequestM.parseRequestParams paramType param
let state ← getState! method state stateType
let (r, state') ← handler param state
return ({ r with response := toJson r.response }, Dynamic.mk state')

View file

@ -414,6 +414,131 @@ section ServerM
s.references.modify fun refs =>
refs.finalizeWorkerRefs module params.version params.references
def emitServerRequestResponse [ToJson α] (fw : FileWorker) (r : Response α) : IO Unit := do
if ! ((← fw.state.atomically get) matches .running) then
return
try
fw.stdin.writeLspResponse r
catch _ =>
pure ()
def emitServerRequestResponseError (fw : FileWorker) (r : ResponseError Unit) : IO Unit := do
if ! ((← fw.state.atomically get) matches .running) then
return
try
fw.stdin.writeLspResponseError r
catch _ =>
pure ()
structure ModuleQueryMatchScore where
isExactMatch : Bool
score : Float
def ModuleQueryMatchScore.compare (ms1 ms2 : ModuleQueryMatchScore) : Ordering :=
let ⟨e1, s1⟩ := ms1
let ⟨e2, s2⟩ := ms2
if e1 && !e2 then
.gt
else if !e1 && e2 then
.lt
else
let d := s1 - s2
if d >= 0.0001 then
.gt
else if d <= -0.0001 then
.lt
else
.eq
structure ModuleQueryMatch extends ModuleQueryMatchScore where
decl : Name
declAsString : String
def ModuleQueryMatch.fastCompare (m1 m2 : ModuleQueryMatch) : Ordering :=
let ⟨ms1, _, s1⟩ := m1
let ⟨ms2, _, s2⟩ := m2
let r := ms1.compare ms2
if r != .eq then
r
else
Ord.compare s2.length s1.length
def ModuleQueryMatch.compare (m1 m2 : ModuleQueryMatch) : Ordering :=
let d1 := m1.decl
let d2 := m2.decl
if d2.isSuffixOf d1 then
.lt
else if d1.isSuffixOf d2 then
.gt
else
m1.fastCompare m2
def matchAgainstQuery? (query : LeanModuleQuery) (decl : Name) : Option ModuleQueryMatch := do
if isPrivateName decl then
none
let mut bestMatch? : Option ModuleQueryMatch := matchDecl? decl decl.toString
for openNamespace in query.openNamespaces do
match openNamespace with
| .allExcept «namespace» exceptions =>
if exceptions.contains decl then
continue
if ! «namespace».isPrefixOf decl then
continue
let namespacedDecl : Name := decl.replacePrefix «namespace» .anonymous
let match? := matchDecl? decl namespacedDecl.toString
bestMatch? := chooseBestMatch? bestMatch? match?
| .renamed «from» to =>
if decl != «from» then
continue
let match? := matchDecl? decl to.toString
bestMatch? := chooseBestMatch? bestMatch? match?
bestMatch?
where
matchDecl? (decl : Name) (identifier : String) : Option ModuleQueryMatch := do
if identifier == query.identifier then
return { decl, declAsString := decl.toString, isExactMatch := true, score := 1.0 }
let score ← FuzzyMatching.fuzzyMatchScoreWithThreshold? query.identifier identifier
return { decl, declAsString := decl.toString, isExactMatch := false, score }
chooseBestMatch? : Option ModuleQueryMatch → Option ModuleQueryMatch → Option ModuleQueryMatch
| none, none => none
| none, some m => some m
| some m, none => some m
| some m1, some m2 =>
if m1.compare m2 == .lt then
m2
else
m1
def handleQueryModule (fw : FileWorker) (id : RequestID) (params : LeanQueryModuleParams)
: ServerM (ServerTask Unit × CancelToken) := do
let s ← read
let cancelTk ← CancelToken.new
let task ← ServerTask.IO.asTask do
let refs ← s.references.get
let mut queryResults : Array LeanQueriedModule := #[]
for query in params.queries do
let filterMapMod mod := pure <| some mod
let filterMapIdent decl := pure <| matchAgainstQuery? query decl
let symbols ← refs.definitionsMatching filterMapMod filterMapIdent cancelTk
let sorted := symbols.qsort fun { ident := m1, .. } { ident := m2, .. } =>
m1.fastCompare m2 == .gt
let result : LeanQueriedModule := sorted.extract 0 10 |>.map fun m => {
module := m.mod
decl := m.ident.decl
isExactMatch := m.ident.isExactMatch
}
queryResults := queryResults.push result
if ← cancelTk.isSet then
emitServerRequestResponseError fw {
id, code := ErrorCode.requestCancelled, message := ""
}
return
emitServerRequestResponse fw {
id, result := { queryResults }
: Response LeanQueryModuleResponse
}
return (task.mapCheap (fun _ => ()), cancelTk)
/--
Updates the global import data with the import closure provided by the file worker after it
successfully processed its header.
@ -432,28 +557,42 @@ section ServerM
| Except.error e => WorkerEvent.ioError e
where
loop : ServerM WorkerEvent := do
let uri := fw.doc.uri
let o := (←read).hOut
let msg ←
try
fw.stdout.readLspMessage
catch _ =>
let exitCode ← fw.waitForProc
-- Remove surviving descendant processes, if any, such as from nested builds.
-- On Windows, we instead rely on elan doing this.
try fw.proc.kill catch _ => pure ()
-- TODO: Wait for process group to finish
match exitCode with
| 0 => return .terminated
| 2 => return .importsChanged
| _ => return .crashed exitCode
let mut pendingWorkerToWatchdogRequests : Std.TreeMap RequestID (ServerTask Unit × CancelToken) := ∅
while true do
let msg ←
try
fw.stdout.readLspMessage
catch _ =>
let exitCode ← fw.waitForProc
-- Remove surviving descendant processes, if any, such as from nested builds.
-- On Windows, we instead rely on elan doing this.
try fw.proc.kill catch _ => pure ()
-- TODO: Wait for process group to finish
match exitCode with
| 0 => return .terminated
| 2 => return .importsChanged
| _ => return .crashed exitCode
let (_, pendingWorkerToWatchdogRequests') ←
StateT.run (s := pendingWorkerToWatchdogRequests) <| handleMessage msg
pendingWorkerToWatchdogRequests := ∅
for (id, task, cancelTk) in pendingWorkerToWatchdogRequests' do
if ← task.hasFinished then
continue
pendingWorkerToWatchdogRequests := pendingWorkerToWatchdogRequests.insert id (task, cancelTk)
return .terminated
handleMessage (msg : JsonRpc.Message)
: StateT (Std.TreeMap RequestID (ServerTask Unit × CancelToken)) ServerM Unit :=
-- When the file worker is terminated by the main thread, the client can immediately launch
-- another file worker using `didOpen`. In this case, even when this task and the old file
-- worker process haven't terminated yet, we want to avoid emitting diagnostics and responses
-- from the old process, so that they can't race with one another in the client.
fw.state.atomically do
fw.state.atomically (m := StateT (Std.TreeMap RequestID (ServerTask Unit × CancelToken)) ServerM) do
let s ← get
let o := (← read).hOut
let uri := fw.doc.uri
if s matches .terminating then
return
-- Re. `o.writeLspMessage msg`:
@ -468,14 +607,24 @@ section ServerM
-- that were still pending.
if wasPending then
o.writeLspMessage msg
| Message.responseError id _ _ _ => do
| Message.responseError id code _ _ => do
let wasPending ← erasePendingRequest uri id
if code matches .requestCancelled then
let pendingWorkerToWatchdogRequests ← getThe (Std.TreeMap RequestID (ServerTask Unit × CancelToken))
if let some (_, cancelTk) := pendingWorkerToWatchdogRequests.get? id then
cancelTk.set
if wasPending then
o.writeLspMessage msg
| Message.request id method params? =>
let globalID ← (←read).serverRequestData.modifyGet
(·.trackOutboundRequest fw.doc.uri id)
o.writeLspMessage (Message.request globalID method params?)
if method == "$/lean/queryModule" then
if let some params := params? then
if let .ok (params : LeanQueryModuleParams) := fromJson? <| toJson params then
let (task, cancelTk) ← handleQueryModule fw id params
modifyThe (Std.TreeMap RequestID (ServerTask Unit × CancelToken)) (·.insert params.sourceRequestID (task, cancelTk))
else
let globalID ← (← read).serverRequestData.modifyGet
(·.trackOutboundRequest fw.doc.uri id)
o.writeLspMessage (Message.request globalID method params?)
| Message.notification "$/lean/ileanInfoUpdate" params =>
if let some params := params then
if let Except.ok params := FromJson.fromJson? <| ToJson.toJson params then
@ -491,8 +640,6 @@ section ServerM
| _ =>
o.writeLspMessage msg
loop
def startFileWorker (m : DocumentMeta) : ServerM Unit := do
let st ← read
st.hOut.writeLspMessage <| mkFileProgressAtPosNotification m 0
@ -818,19 +965,27 @@ def handleWorkspaceSymbol (p : WorkspaceSymbolParams) : ReaderT ReferenceRequest
if p.query.isEmpty then
return #[]
let references := (← read).references
let srcSearchPath := (← read).srcSearchPath
let symbols ← references.definitionsMatching srcSearchPath (maxAmount? := none)
fun name =>
let name := privateToUserName? name |>.getD name
if let some score := fuzzyMatchScoreWithThreshold? p.query name.toString then
some (name.toString, score)
else
none
let srcSearchPath : Lean.SearchPath := (← read).srcSearchPath
let filterMapMod mod := do
let some path ← srcSearchPath.findModuleWithExt "lean" mod
| return none
let uri := System.Uri.pathToUri <| ← IO.FS.realPath path
return some uri
let filterMapIdent ident := do
let ident := privateToUserName? ident |>.getD ident
if let some score := fuzzyMatchScoreWithThreshold? p.query ident.toString then
return some (ident.toString, score)
else
return none
let symbols ← references.definitionsMatching filterMapMod filterMapIdent
return symbols
|>.qsort (fun ((_, s1), _) ((_, s2), _) => s1 > s2)
|>.qsort (fun { ident := (_, s1), .. } { ident := (_, s2), .. } => s1 > s2)
|>.extract 0 100 -- max amount
|>.map fun ((name, _), location) =>
{ name, kind := SymbolKind.constant, location }
|>.map fun m => {
name := m.ident.1
kind := SymbolKind.constant
location := { uri := m.mod, range := m.range }
}
def handlePrepareRename (p : PrepareRenameParams) : ReaderT ReferenceRequestContext IO (Option Range) := do
-- This just checks that the cursor is over a renameable identifier
@ -1189,7 +1344,7 @@ def mkLeanServerCapabilities : ServerCapabilities := {
}
codeActionProvider? := some {
resolveProvider? := true,
codeActionKinds? := some #["quickfix", "refactor"]
codeActionKinds? := some #["quickfix", "refactor", "source.organizeImports"]
}
inlayHintProvider? := some {
resolveProvider? := false

View file

@ -65,10 +65,6 @@ where
| .trace .., _ => .text "(trace)"
tt.stripTags
/-- Compares interactive diagnostics modulo `TaggedText` tags and traces. -/
def compareAsDiagnostics (a b : InteractiveDiagnostic) : Ordering :=
compareByUserVisible a.toDiagnostic b.toDiagnostic
end InteractiveDiagnostic
private def mkPPContext (nCtx : NamingContext) (ctx : MessageDataContext) : PPContext := {