From 0e81835fae1edbd26c659f2a8dbad5c18e4b225a Mon Sep 17 00:00:00 2001 From: Yi Ming Date: Sat, 2 May 2026 13:33:45 +0800 Subject: [PATCH] refactor(lsp): message stream abstraction for message framing --- runtime/lua/vim/lsp/rpc.lua | 84 ++++++++++++++---------------- runtime/lua/vim/net/_transport.lua | 55 +++++++++++++++++++ 2 files changed, 93 insertions(+), 46 deletions(-) diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua index 5026a132f5..a81dbc7280 100644 --- a/runtime/lua/vim/lsp/rpc.lua +++ b/runtime/lua/vim/lsp/rpc.lua @@ -226,35 +226,25 @@ end function M.create_read_loop(handle_body, on_exit, on_error) on_exit = on_exit or function() end on_error = on_error or function() end - local co = coroutine.create(message_decoder) - coroutine.resume(co) - return function(err, chunk) - if err then - on_error(err, M.client_errors.READ_ERROR) - return - end - - if not chunk then - on_exit() - return - end - - if coroutine.status(co) == 'dead' then - return - end - - while true do - local ok, res = coroutine.resume(co, chunk) - if not ok then - on_error(res, M.client_errors.INVALID_SERVER_MESSAGE) - break - elseif res then - handle_body(res) - chunk = '' + local message_stream = vim_transport.MessageStream.new( + message_decoder, + format_message_with_content_length, + function(err, chunk) + if err then + on_error(err, M.client_errors.READ_ERROR) + elseif chunk then + handle_body(chunk) else - break + on_exit() end + end, + function(err) + on_error(err, M.client_errors.INVALID_SERVER_MESSAGE) end + ) + + return function(err, chunk) + message_stream:feed(err, chunk) end end @@ -264,6 +254,7 @@ end --- @field private message_callbacks table dict of message_id to callback --- @field private notify_reply_callbacks table dict of message_id to callback --- @field private transport vim.Transport +--- @field private message_stream vim.MessageStream --- @field private dispatchers vim.lsp.rpc.Dispatchers --- --- See [vim.lsp.rpc.request()] @@ -282,8 +273,10 @@ local Client = {} ---@package ---@param dispatchers vim.lsp.rpc.Dispatchers ---@param transport vim.Transport +---@param decode fun(buf: vim._core.stringbuffer): string? +---@param format fun(msg: string): string ---@return vim.lsp.rpc.Client -function Client.new(dispatchers, transport) +function Client.new(dispatchers, transport, decode, format) local result = { message_index = 0, message_callbacks = {}, @@ -324,27 +317,26 @@ function Client.new(dispatchers, transport) ---@cast result vim.lsp.rpc.Client local self = setmetatable(result, { __index = Client }) - --- @param body string - local function handle_body(body) - self:handle_body(body) - end - local function on_exit() - ---@diagnostic disable-next-line: invisible - self.transport:terminate() - end - - --- @param errkind vim.lsp.rpc.ClientErrors - local function on_error(err, errkind) - self:on_error(errkind, err) - if errkind == M.client_errors.INVALID_SERVER_MESSAGE then + self.message_stream = vim_transport.MessageStream.new(decode, format, function(err, data) + if err then + self:on_error(M.client_errors.READ_ERROR, err) + elseif data then + self:handle_body(data) + else ---@diagnostic disable-next-line: invisible self.transport:terminate() end - end + end, function(err) + self:on_error(M.client_errors.INVALID_SERVER_MESSAGE, err) + ---@diagnostic disable-next-line: invisible + self.transport:terminate() + end) - local on_read = M.create_read_loop(handle_body, on_exit, on_error) - transport:listen(on_read, dispatchers.on_exit) + transport:listen(function(err, data) + ---@diagnostic disable-next-line: invisible + self.message_stream:feed(err, data) + end, dispatchers.on_exit) return self end @@ -356,7 +348,7 @@ function Client:encode_and_send(payload) end local jsonstr = vim.json.encode(payload) - self.transport:write(format_message_with_content_length(jsonstr)) + self.transport:write(self.message_stream.encode(jsonstr)) return true end @@ -637,7 +629,7 @@ function M.connect(host_or_path, port) dispatchers = merge_dispatchers(dispatchers) local transport = vim_transport.TransportConnect.new(host_or_path, port) - return Client.new(dispatchers, transport) + return Client.new(dispatchers, transport, message_decoder, format_message_with_content_length) end end @@ -665,7 +657,7 @@ function M.start(cmd, dispatchers, extra_spawn_params) dispatchers = merge_dispatchers(dispatchers) local transport = vim_transport.TransportRun.new(cmd, extra_spawn_params) - return Client.new(dispatchers, transport) + return Client.new(dispatchers, transport, message_decoder, format_message_with_content_length) end return M diff --git a/runtime/lua/vim/net/_transport.lua b/runtime/lua/vim/net/_transport.lua index b39462ad8c..b7c6e93aaa 100644 --- a/runtime/lua/vim/net/_transport.lua +++ b/runtime/lua/vim/net/_transport.lua @@ -182,7 +182,62 @@ function TransportConnect:terminate() end end +--- Create a message stream from a coroutine decoder. +--- +--- The decoder receives transport data from `coroutine.yield()` +--- and may yield a message body when a full message is available. +--- A nil yield means it needs more transport data. +--- Decoder errors are reported through `on_error`. +--- +---@class (private, exact) vim.MessageStream +---@field private co thread +---@field private decode fun() +---@field private on_read fun(err: string?, data: string?) +---@field private on_error fun(err: any) +---@field feed fun(self: vim.MessageStream, err: string?, data: string?) +---@field encode fun(msg: string): string +---@field new fun(decode: fun(), encode: (fun(msg: string): string), on_read: fun(err: string?, data: string?), on_error: fun(err: any)): vim.MessageStream +local MessageStream = {} + +function MessageStream.new(decode, encode, on_read, on_error) + local self = setmetatable({ + co = coroutine.create(decode), + decode = decode, + on_read = on_read, + on_error = on_error, + encode = encode, + }, { __index = MessageStream }) + + coroutine.resume(self.co) + return self +end + +---@param err string? +---@param data string? +function MessageStream:feed(err, data) + if err then + self.on_read(err, nil) + return + elseif data == nil then + self.on_read(nil, nil) + return + end + + while true do + local ok, body = coroutine.resume(self.co, data) + if not ok then + self.on_error(body) + return + elseif body == nil then + break + end + self.on_read(nil, body) + data = '' + end +end + return { TransportRun = TransportRun, TransportConnect = TransportConnect, + MessageStream = MessageStream, }