refactor(lsp): message stream abstraction for message framing

This commit is contained in:
Yi Ming
2026-05-02 13:33:45 +08:00
parent 929e644a5a
commit 0e81835fae
2 changed files with 93 additions and 46 deletions

View File

@@ -226,35 +226,25 @@ end
function M.create_read_loop(handle_body, on_exit, on_error)
on_exit = on_exit or function() end
on_error = on_error or function() end
local co = coroutine.create(message_decoder)
coroutine.resume(co)
return function(err, chunk)
if err then
on_error(err, M.client_errors.READ_ERROR)
return
end
if not chunk then
on_exit()
return
end
if coroutine.status(co) == 'dead' then
return
end
while true do
local ok, res = coroutine.resume(co, chunk)
if not ok then
on_error(res, M.client_errors.INVALID_SERVER_MESSAGE)
break
elseif res then
handle_body(res)
chunk = ''
local message_stream = vim_transport.MessageStream.new(
message_decoder,
format_message_with_content_length,
function(err, chunk)
if err then
on_error(err, M.client_errors.READ_ERROR)
elseif chunk then
handle_body(chunk)
else
break
on_exit()
end
end,
function(err)
on_error(err, M.client_errors.INVALID_SERVER_MESSAGE)
end
)
return function(err, chunk)
message_stream:feed(err, chunk)
end
end
@@ -264,6 +254,7 @@ end
--- @field private message_callbacks table<integer, function> dict of message_id to callback
--- @field private notify_reply_callbacks table<integer, function> dict of message_id to callback
--- @field private transport vim.Transport
--- @field private message_stream vim.MessageStream
--- @field private dispatchers vim.lsp.rpc.Dispatchers
---
--- See [vim.lsp.rpc.request()]
@@ -282,8 +273,10 @@ local Client = {}
---@package
---@param dispatchers vim.lsp.rpc.Dispatchers
---@param transport vim.Transport
---@param decode fun(buf: vim._core.stringbuffer): string?
---@param format fun(msg: string): string
---@return vim.lsp.rpc.Client
function Client.new(dispatchers, transport)
function Client.new(dispatchers, transport, decode, format)
local result = {
message_index = 0,
message_callbacks = {},
@@ -324,27 +317,26 @@ function Client.new(dispatchers, transport)
---@cast result vim.lsp.rpc.Client
local self = setmetatable(result, { __index = Client })
--- @param body string
local function handle_body(body)
self:handle_body(body)
end
local function on_exit()
---@diagnostic disable-next-line: invisible
self.transport:terminate()
end
--- @param errkind vim.lsp.rpc.ClientErrors
local function on_error(err, errkind)
self:on_error(errkind, err)
if errkind == M.client_errors.INVALID_SERVER_MESSAGE then
self.message_stream = vim_transport.MessageStream.new(decode, format, function(err, data)
if err then
self:on_error(M.client_errors.READ_ERROR, err)
elseif data then
self:handle_body(data)
else
---@diagnostic disable-next-line: invisible
self.transport:terminate()
end
end
end, function(err)
self:on_error(M.client_errors.INVALID_SERVER_MESSAGE, err)
---@diagnostic disable-next-line: invisible
self.transport:terminate()
end)
local on_read = M.create_read_loop(handle_body, on_exit, on_error)
transport:listen(on_read, dispatchers.on_exit)
transport:listen(function(err, data)
---@diagnostic disable-next-line: invisible
self.message_stream:feed(err, data)
end, dispatchers.on_exit)
return self
end
@@ -356,7 +348,7 @@ function Client:encode_and_send(payload)
end
local jsonstr = vim.json.encode(payload)
self.transport:write(format_message_with_content_length(jsonstr))
self.transport:write(self.message_stream.encode(jsonstr))
return true
end
@@ -637,7 +629,7 @@ function M.connect(host_or_path, port)
dispatchers = merge_dispatchers(dispatchers)
local transport = vim_transport.TransportConnect.new(host_or_path, port)
return Client.new(dispatchers, transport)
return Client.new(dispatchers, transport, message_decoder, format_message_with_content_length)
end
end
@@ -665,7 +657,7 @@ function M.start(cmd, dispatchers, extra_spawn_params)
dispatchers = merge_dispatchers(dispatchers)
local transport = vim_transport.TransportRun.new(cmd, extra_spawn_params)
return Client.new(dispatchers, transport)
return Client.new(dispatchers, transport, message_decoder, format_message_with_content_length)
end
return M