From dc65bb080aa7e511d359bcafbf733ce51bfa4ece Mon Sep 17 00:00:00 2001 From: Gabriel Ebner Date: Sun, 9 Jan 2022 12:38:53 +0100 Subject: [PATCH] fix: race condition in RPC request handler --- src/Lean/Server/FileWorker.lean | 15 +++++++-------- src/Lean/Server/FileWorker/Utils.lean | 4 ++-- src/Lean/Server/Rpc/RequestHandling.lean | 8 ++------ 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/Lean/Server/FileWorker.lean b/src/Lean/Server/FileWorker.lean index 86bf6c7ba8..bb5ef45128 100644 --- a/src/Lean/Server/FileWorker.lean +++ b/src/Lean/Server/FileWorker.lean @@ -304,20 +304,19 @@ def handleRpcRelease (p : Lsp.RpcReleaseParams) : WorkerM Unit := do -- NOTE(WN): when the worker restarts e.g. due to changed imports, we may receive `rpc/release` -- for the previous RPC session. This is fine, just ignore. if let some seshRef := st.rpcSessions.find? p.sessionId then - let mut sesh ← seshRef.get - for ref in p.refs do - sesh := sesh.release ref |>.snd - sesh ← sesh.keptAlive - seshRef.set sesh + let monoMsNow ← IO.monoMsNow + seshRef.modify fun sesh => Id.run do + let mut sesh := sesh + for ref in p.refs do + sesh := sesh.release ref |>.snd + sesh.keptAlive monoMsNow def handleRpcKeepAlive (p : Lsp.RpcKeepAliveParams) : WorkerM Unit := do let st ← get match st.rpcSessions.find? p.sessionId with | none => return | some seshRef => - let sesh ← seshRef.get - let sesh ← sesh.keptAlive - seshRef.set sesh + seshRef.modify (·.keptAlive (← IO.monoMsNow)) end NotificationHandling diff --git a/src/Lean/Server/FileWorker/Utils.lean b/src/Lean/Server/FileWorker/Utils.lean index df05b44b2f..ddadeaf94d 100644 --- a/src/Lean/Server/FileWorker/Utils.lean +++ b/src/Lean/Server/FileWorker/Utils.lean @@ -96,8 +96,8 @@ def release (st : RpcSession) (ref : Lsp.RpcRef) : Bool × RpcSession := let released := st.aliveRefs.contains ref (released, { st with aliveRefs := st.aliveRefs.erase ref }) -def keptAlive (s : RpcSession) : IO RpcSession := do - return { s with expireTime := (← IO.monoMsNow) + keepAliveTimeMs } +def keptAlive (monoMsNow : Nat) (s : RpcSession) : RpcSession := + { s with expireTime := monoMsNow + keepAliveTimeMs } def hasExpired (s : RpcSession) : IO Bool := return s.expireTime ≤ (← IO.monoMsNow) diff --git a/src/Lean/Server/Rpc/RequestHandling.lean b/src/Lean/Server/Rpc/RequestHandling.lean index 61269d33d6..603cb7d06a 100644 --- a/src/Lean/Server/Rpc/RequestHandling.lean +++ b/src/Lean/Server/Rpc/RequestHandling.lean @@ -44,12 +44,10 @@ def registerRpcCallHandler (method : Name) let some seshRef ← rc.rpcSessions.find? seshId | throwThe RequestError { code := JsonRpc.ErrorCode.rpcNeedsReconnect message := s!"Outdated RPC session" } - let sesh ← seshRef.get - let t ← RequestM.asTask do let paramsLsp ← liftExcept <| parseRequestParams paramLspType j let act := rpcDecode (α := paramType) (β := paramLspType) (m := StateM FileWorker.RpcSession) paramsLsp - match act.run' sesh with + match act.run' (← seshRef.get) with | Except.ok v => return v | Except.error e => throwThe RequestError { code := JsonRpc.ErrorCode.invalidParams @@ -64,9 +62,7 @@ def registerRpcCallHandler (method : Name) | Except.error e => throw e | Except.ok ret => do let act := rpcEncode (α := respType) (β := respLspType) (m := StateM FileWorker.RpcSession) ret - let (retLsp, sesh') := act.run sesh - seshRef.set sesh' - return toJson retLsp + toJson (← seshRef.modifyGet act.run) rpcProcedures.modify fun ps => ps.insert method ⟨wrapper⟩