From 680ab139513f1c201e2add1197eaca2e189125fc Mon Sep 17 00:00:00 2001 From: Yi Ming Date: Fri, 22 May 2026 17:33:09 +0800 Subject: [PATCH] refactor(lsp): avoid using coroutine when parsing frames --- runtime/lua/vim/lsp/rpc.lua | 73 +++++++++++++++++------------- runtime/lua/vim/net/_transport.lua | 34 ++++++++------ 2 files changed, 60 insertions(+), 47 deletions(-) diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua index a81dbc7280..11993c90fd 100644 --- a/runtime/lua/vim/lsp/rpc.lua +++ b/runtime/lua/vim/lsp/rpc.lua @@ -25,16 +25,18 @@ end --- We ignore lines ending with `\n` that don't contain `content-length`, since some servers --- write log to standard output and there's no way to avoid it. --- See https://github.com/neovim/neovim/pull/35743#pullrequestreview-3379705828 ---- @param header string The header to parse +--- @param ptr vim._core.stringbuffer.ptr The ptr to buffer to parse +--- @param start integer The starting index of the buffer to parse, 0-based +--- @param len integer The length of the header to parse --- @return integer -local function get_content_length(header) +local function get_content_length(ptr, start, len) local state = 'name' - local i, len = 1, #header + local i, end_ = start, start + len local j, name = 1, 'content-length' local buf = strbuffer.new() local digit = true - while i <= len do - local c = header:byte(i) + while i < end_ do + local c = ptr[i] if state == 'name' then if c >= 65 and c <= 90 then -- lower case c = c + 32 @@ -54,7 +56,7 @@ local function get_content_length(header) i = i - 1 end elseif state == 'value' then - if c == 13 and header:byte(i + 1) == 10 then -- must end with \r\n + if c == 13 and ptr[i + 1] == 10 then -- must end with \r\n local value = buf:get() if digit then return vim._assert_integer(value) @@ -73,7 +75,11 @@ local function get_content_length(header) end i = i + 1 end - error('Content-Length not found in header: ' .. header) + local header = strbuffer.new() + for k = start, end_ - 1 do + header:put(string.char(ptr[k])) + end + error('Content-Length not found in header: ' .. header:tostring()) end local M = {} @@ -190,33 +196,36 @@ local default_dispatchers = { end, } ---- @async -local function message_decoder() - local strbuf = strbuffer.new() - while true do - local header_len ---@type integer? - local ptr, len = strbuf:ref() - for i = 0, len - 4 do - -- Find the header boundary "\r\n\r\n" - -- (compare bytes instead of string.find(), to avoid a string alloc). - if ptr[i] == 13 and ptr[i + 1] == 10 and ptr[i + 2] == 13 and ptr[i + 3] == 10 then - header_len = i + 2 - break - end - end - if header_len then - local header = strbuf:get(header_len) - strbuf:skip(2) -- skip past header boundary - local content_length = get_content_length(header) - while strbuffer.len(strbuf) < content_length do - strbuf:put(coroutine.yield()) - end - local body = strbuf:get(content_length) - strbuf:put(coroutine.yield(body)) - else - strbuf:put(coroutine.yield()) +--- Parse one `Content-Length` framed message from `strbuf`. +--- +--- Returns a body after consuming one full frame, returns nil if more bytes are needed. +--- Raises an error if the buffered data is not a valid frame. +--- +---@param strbuf vim._core.stringbuffer +---@return string? +local function message_decoder(strbuf) + local header_len ---@type integer? + local ptr, len = strbuf:ref() + for i = 0, len - 4 do + -- Find the header boundary "\r\n\r\n" + -- (compare bytes instead of string.find(), to avoid a string alloc). + if ptr[i] == 13 and ptr[i + 1] == 10 and ptr[i + 2] == 13 and ptr[i + 3] == 10 then + header_len = i + 2 + break end end + + if not header_len then + return nil + end + + local content_length = get_content_length(ptr, 0, header_len) + if strbuffer.len(strbuf) < header_len + 2 + content_length then + return nil + end + + strbuf:skip(header_len + 2) -- skip past header boundary + return strbuf:get(content_length) end --- @private diff --git a/runtime/lua/vim/net/_transport.lua b/runtime/lua/vim/net/_transport.lua index b7c6e93aaa..b138fc327c 100644 --- a/runtime/lua/vim/net/_transport.lua +++ b/runtime/lua/vim/net/_transport.lua @@ -1,5 +1,6 @@ local uv = vim.uv local log = require('vim.lsp.log') +local strbuffer = require('vim._core.stringbuffer') --- Interface for transport implementations. --- @@ -182,34 +183,36 @@ function TransportConnect:terminate() end end ---- Create a message stream from a coroutine decoder. +--- Create a message stream from a decoder. --- ---- The decoder receives transport data from `coroutine.yield()` ---- and may yield a message body when a full message is available. ---- A nil yield means it needs more transport data. ---- Decoder errors are reported through `on_error`. +--- The decoder consumes from the given string buffer +--- and returns a message body when a full message is available. +--- `nil` means it needs more transport data. +--- decoder errors are reported through `on_error`. --- ---@class (private, exact) vim.MessageStream ----@field private co thread ----@field private decode fun() +---@field private strbuf string.buffer +---@field private decode fun(strbuf: string.buffer): string? ---@field private on_read fun(err: string?, data: string?) ---@field private on_error fun(err: any) ---@field feed fun(self: vim.MessageStream, err: string?, data: string?) ---@field encode fun(msg: string): string ----@field new fun(decode: fun(), encode: (fun(msg: string): string), on_read: fun(err: string?, data: string?), on_error: fun(err: any)): vim.MessageStream +---@field new fun(decode: (fun(strbuf: string.buffer): string?), encode: (fun(msg: string): string), on_read: fun(err: string?, data: string?), on_error: fun(err: any)): vim.MessageStream local MessageStream = {} +---@param decode fun(strbuf: string.buffer): string? +---@param encode fun(msg: string): string +---@param on_read fun(err: string?, data: string?) +---@param on_error fun(err: any) +---@return vim.MessageStream function MessageStream.new(decode, encode, on_read, on_error) - local self = setmetatable({ - co = coroutine.create(decode), + return setmetatable({ + strbuf = strbuffer.new(), decode = decode, on_read = on_read, on_error = on_error, encode = encode, }, { __index = MessageStream }) - - coroutine.resume(self.co) - return self end ---@param err string? @@ -223,8 +226,10 @@ function MessageStream:feed(err, data) return end + self.strbuf:put(data) + while true do - local ok, body = coroutine.resume(self.co, data) + local ok, body = pcall(self.decode, self.strbuf) if not ok then self.on_error(body) return @@ -232,7 +237,6 @@ function MessageStream:feed(err, data) break end self.on_read(nil, body) - data = '' end end