From 1b42255493cf2e0966092a541e0dae7ffffb2e09 Mon Sep 17 00:00:00 2001 From: Wojciech Nawrocki Date: Sat, 17 Jul 2021 14:23:22 -0700 Subject: [PATCH] feat: check RPC reference types --- src/Lean/Server/FileWorker/Rpc.lean | 48 +++++++++++++++++++++++---- src/Lean/Server/FileWorker/Utils.lean | 6 ++-- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/Lean/Server/FileWorker/Rpc.lean b/src/Lean/Server/FileWorker/Rpc.lean index ab462994ca..2700826f44 100644 --- a/src/Lean/Server/FileWorker/Rpc.lean +++ b/src/Lean/Server/FileWorker/Rpc.lean @@ -1,4 +1,5 @@ import Lean.Elab.Command +import Lean.Elab.Term import Lean.Meta.Basic import Lean.Data.Lsp.Extra import Lean.Server.Requests @@ -60,17 +61,17 @@ structure WithRpcRef (α : Type u) where namespace WithRpcRef -private unsafe def encode (r : WithRpcRef α) : RequestM Lsp.RpcRef := do +protected unsafe def encodeUnsafe (typeName : Name) (r : WithRpcRef α) : RequestM Lsp.RpcRef := do let rc ← read let some rpcSesh ← rc.rpcSesh? | throwThe IO.Error "internal server error: forgot to validate RPC session" let ref ← IO.UntypedRef.mkRefUnsafe r.val let uid := ref.ptr -- TODO random uuid? - rpcSesh.aliveRefs.modify fun refs => refs.insert uid ref + rpcSesh.aliveRefs.modify fun refs => refs.insert uid (typeName, ref) return uid -private unsafe def decodeAs (α : Type) (r : Lsp.RpcRef) : RequestM (WithRpcRef α) := do +protected unsafe def decodeUnsafeAs (α) (typeName : Name) (r : Lsp.RpcRef) : RequestM (WithRpcRef α) := do let rc ← read let some rpcSesh ← rc.rpcSesh? | throwThe IO.Error "internal server error: forgot to validate RPC session" @@ -78,11 +79,44 @@ private unsafe def decodeAs (α : Type) (r : Lsp.RpcRef) : RequestM (WithRpcRef match (← rpcSesh.aliveRefs.get).find? r with | none => throwThe RequestError { code := JsonRpc.ErrorCode.invalidParams message := s!"RPC reference '{r}' is not valid" } - | some r => WithRpcRef.mk <$> r.getAsUnsafe α + | some (nm, ref) => + if nm != typeName then + throwThe RequestError { code := JsonRpc.ErrorCode.invalidParams + message := s!"RPC call type mismatch in reference '{r}'\n" ++ + "expected '{typeName}', got '{nm}'" } + WithRpcRef.mk <$> ref.getAsUnsafe α -unsafe instance : LspEncoding (WithRpcRef α) Lsp.RpcRef where - lspEncode r := encode r - lspDecode r := decodeAs α r +-- TODO(WN): Make this a parameterised `deriving LspEncoding (ref := true)` +open Elab Command Term in +elab "mkWithRefInstance" typeId:ident : command => do + let env ← getEnv + let tps ← liftTermElabM none do + resolveName typeId typeId.getId [] [] none + for (tp, _) in tps do + if let some typeNm := tp.constName? then + -- TODO(WN): check that `tp` is not a scalar type + let cmds ← `( + namespace $typeId:ident + private unsafe def encodeUnsafe (r : WithRpcRef $typeId:ident) : RequestM Lsp.RpcRef := + WithRpcRef.encodeUnsafe $(quote typeNm) r + + @[implementedBy encodeUnsafe] + private constant encode (r : WithRpcRef $typeId:ident) : RequestM Lsp.RpcRef + + private unsafe def decodeUnsafe (r : Lsp.RpcRef) : RequestM (WithRpcRef $typeId:ident) := + WithRpcRef.decodeUnsafeAs $typeId:ident $(quote typeNm) r + + @[implementedBy decodeUnsafe] + private constant decode (r : Lsp.RpcRef) : RequestM (WithRpcRef $typeId:ident) + + instance : LspEncoding (WithRpcRef $typeId:ident) Lsp.RpcRef where + lspEncode a := encode a + lspDecode a := decode a + end $typeId:ident + ) + Command.elabCommand cmds + return + throwError "unknown type '{typeId}'" end WithRpcRef diff --git a/src/Lean/Server/FileWorker/Utils.lean b/src/Lean/Server/FileWorker/Utils.lean index 21241e46c9..fa3c08a578 100644 --- a/src/Lean/Server/FileWorker/Utils.lean +++ b/src/Lean/Server/FileWorker/Utils.lean @@ -55,10 +55,12 @@ structure EditableDocument where structure RpcSession where sessionId : USize - /-- Objects that are being kept alive for the RPC client, mapped to by their RPC reference. + /-- Objects that are being kept alive for the RPC client, together with their type names, + mapped to by their RPC reference. + Note that we may currently have multiple references to the same object. It is only disposed of once all of those are gone. This simplifies the client a bit as it can drop every reference received separately. -/ - aliveRefs : IO.Ref (Std.PersistentHashMap USize UntypedRef) + aliveRefs : IO.Ref (Std.PersistentHashMap USize (Name × UntypedRef)) end Lean.Server.FileWorker