refactor(lsp): avoid using coroutine when parsing frames

This commit is contained in:
Yi Ming
2026-05-22 17:33:09 +08:00
parent 0e81835fae
commit 680ab13951
2 changed files with 60 additions and 47 deletions

View File

@@ -25,16 +25,18 @@ end
--- We ignore lines ending with `\n` that don't contain `content-length`, since some servers
--- write log to standard output and there's no way to avoid it.
--- See https://github.com/neovim/neovim/pull/35743#pullrequestreview-3379705828
--- @param header string The header to parse
--- @param ptr vim._core.stringbuffer.ptr The ptr to buffer to parse
--- @param start integer The starting index of the buffer to parse, 0-based
--- @param len integer The length of the header to parse
--- @return integer
local function get_content_length(header)
local function get_content_length(ptr, start, len)
local state = 'name'
local i, len = 1, #header
local i, end_ = start, start + len
local j, name = 1, 'content-length'
local buf = strbuffer.new()
local digit = true
while i <= len do
local c = header:byte(i)
while i < end_ do
local c = ptr[i]
if state == 'name' then
if c >= 65 and c <= 90 then -- lower case
c = c + 32
@@ -54,7 +56,7 @@ local function get_content_length(header)
i = i - 1
end
elseif state == 'value' then
if c == 13 and header:byte(i + 1) == 10 then -- must end with \r\n
if c == 13 and ptr[i + 1] == 10 then -- must end with \r\n
local value = buf:get()
if digit then
return vim._assert_integer(value)
@@ -73,7 +75,11 @@ local function get_content_length(header)
end
i = i + 1
end
error('Content-Length not found in header: ' .. header)
local header = strbuffer.new()
for k = start, end_ - 1 do
header:put(string.char(ptr[k]))
end
error('Content-Length not found in header: ' .. header:tostring())
end
local M = {}
@@ -190,33 +196,36 @@ local default_dispatchers = {
end,
}
--- @async
local function message_decoder()
local strbuf = strbuffer.new()
while true do
local header_len ---@type integer?
local ptr, len = strbuf:ref()
for i = 0, len - 4 do
-- Find the header boundary "\r\n\r\n"
-- (compare bytes instead of string.find(), to avoid a string alloc).
if ptr[i] == 13 and ptr[i + 1] == 10 and ptr[i + 2] == 13 and ptr[i + 3] == 10 then
header_len = i + 2
break
end
end
if header_len then
local header = strbuf:get(header_len)
strbuf:skip(2) -- skip past header boundary
local content_length = get_content_length(header)
while strbuffer.len(strbuf) < content_length do
strbuf:put(coroutine.yield())
end
local body = strbuf:get(content_length)
strbuf:put(coroutine.yield(body))
else
strbuf:put(coroutine.yield())
--- Parse one `Content-Length` framed message from `strbuf`.
---
--- Returns a body after consuming one full frame, returns nil if more bytes are needed.
--- Raises an error if the buffered data is not a valid frame.
---
---@param strbuf vim._core.stringbuffer
---@return string?
local function message_decoder(strbuf)
local header_len ---@type integer?
local ptr, len = strbuf:ref()
for i = 0, len - 4 do
-- Find the header boundary "\r\n\r\n"
-- (compare bytes instead of string.find(), to avoid a string alloc).
if ptr[i] == 13 and ptr[i + 1] == 10 and ptr[i + 2] == 13 and ptr[i + 3] == 10 then
header_len = i + 2
break
end
end
if not header_len then
return nil
end
local content_length = get_content_length(ptr, 0, header_len)
if strbuffer.len(strbuf) < header_len + 2 + content_length then
return nil
end
strbuf:skip(header_len + 2) -- skip past header boundary
return strbuf:get(content_length)
end
--- @private

View File

@@ -1,5 +1,6 @@
local uv = vim.uv
local log = require('vim.lsp.log')
local strbuffer = require('vim._core.stringbuffer')
--- Interface for transport implementations.
---
@@ -182,34 +183,36 @@ function TransportConnect:terminate()
end
end
--- Create a message stream from a coroutine decoder.
--- Create a message stream from a decoder.
---
--- The decoder receives transport data from `coroutine.yield()`
--- and may yield a message body when a full message is available.
--- A nil yield means it needs more transport data.
--- Decoder errors are reported through `on_error`.
--- The decoder consumes from the given string buffer
--- and returns a message body when a full message is available.
--- `nil` means it needs more transport data.
--- decoder errors are reported through `on_error`.
---
---@class (private, exact) vim.MessageStream
---@field private co thread
---@field private decode fun()
---@field private strbuf string.buffer
---@field private decode fun(strbuf: string.buffer): string?
---@field private on_read fun(err: string?, data: string?)
---@field private on_error fun(err: any)
---@field feed fun(self: vim.MessageStream, err: string?, data: string?)
---@field encode fun(msg: string): string
---@field new fun(decode: fun(), encode: (fun(msg: string): string), on_read: fun(err: string?, data: string?), on_error: fun(err: any)): vim.MessageStream
---@field new fun(decode: (fun(strbuf: string.buffer): string?), encode: (fun(msg: string): string), on_read: fun(err: string?, data: string?), on_error: fun(err: any)): vim.MessageStream
local MessageStream = {}
---@param decode fun(strbuf: string.buffer): string?
---@param encode fun(msg: string): string
---@param on_read fun(err: string?, data: string?)
---@param on_error fun(err: any)
---@return vim.MessageStream
function MessageStream.new(decode, encode, on_read, on_error)
local self = setmetatable({
co = coroutine.create(decode),
return setmetatable({
strbuf = strbuffer.new(),
decode = decode,
on_read = on_read,
on_error = on_error,
encode = encode,
}, { __index = MessageStream })
coroutine.resume(self.co)
return self
end
---@param err string?
@@ -223,8 +226,10 @@ function MessageStream:feed(err, data)
return
end
self.strbuf:put(data)
while true do
local ok, body = coroutine.resume(self.co, data)
local ok, body = pcall(self.decode, self.strbuf)
if not ok then
self.on_error(body)
return
@@ -232,7 +237,6 @@ function MessageStream:feed(err, data)
break
end
self.on_read(nil, body)
data = ''
end
end