refactor: add basic stringbuffer shim

This commit is contained in:
Lewis Russell
2025-03-01 14:44:29 +00:00
parent f517fcd148
commit e76a7e8afb
2 changed files with 137 additions and 105 deletions

View File

@@ -5,8 +5,8 @@ local validate, schedule_wrap = vim.validate, vim.schedule_wrap
--- Embeds the given string into a table and correctly computes `Content-Length`.
---
---@param message string
---@return string message with `Content-Length` attribute
--- @param message string
--- @return string message with `Content-Length` attribute
local function format_message_with_content_length(message)
return table.concat({
'Content-Length: ',
@@ -18,8 +18,8 @@ end
--- Extract content-length from the header
---
---@param header string The header to parse
---@return integer?
--- @param header string The header to parse
--- @return integer
local function get_content_length(header)
for line in header:gmatch('(.-)\r\n') do
if line == '' then
@@ -27,112 +27,12 @@ local function get_content_length(header)
end
local key, value = line:match('^%s*(%S+)%s*:%s*(%d+)%s*$')
if key and key:lower() == 'content-length' then
return tonumber(value)
return assert(tonumber(value))
end
end
error('Content-Length not found in header: ' .. header)
end
-- This is the start of any possible header patterns. The gsub converts it to a
-- case insensitive pattern.
local header_start_pattern = ('content'):gsub('%w', function(c)
return '[' .. c .. c:upper() .. ']'
end)
local has_strbuffer, strbuffer = pcall(require, 'string.buffer')
--- The actual workhorse.
---@type function
local request_parser_loop
if has_strbuffer then
request_parser_loop = function()
local buf = strbuffer.new()
while true do
local msg = buf:tostring()
local header_end = msg:find('\r\n\r\n', 1, true)
if header_end then
local header = buf:get(header_end + 1)
buf:skip(2) -- skip past header boundary
local content_length = get_content_length(header)
while #buf < content_length do
local chunk = coroutine.yield()
buf:put(chunk)
end
local body = buf:get(content_length)
local chunk = coroutine.yield(body)
buf:put(chunk)
else
local chunk = coroutine.yield()
buf:put(chunk)
end
end
end
else
request_parser_loop = function()
local buffer = '' -- only for header part
while true do
-- A message can only be complete if it has a double CRLF and also the full
-- payload, so first let's check for the CRLFs
local header_end, body_start = buffer:find('\r\n\r\n', 1, true)
-- Start parsing the headers
if header_end then
-- This is a workaround for servers sending initial garbage before
-- sending headers, such as if a bash script sends stdout. It assumes
-- that we know all of the headers ahead of time. At this moment, the
-- only valid headers start with "Content-*", so that's the thing we will
-- be searching for.
-- TODO(ashkan) I'd like to remove this, but it seems permanent :(
local buffer_start = buffer:find(header_start_pattern)
if not buffer_start then
error(
string.format(
"Headers were expected, a different response was received. The server response was '%s'.",
buffer
)
)
end
local header = buffer:sub(buffer_start, header_end + 1)
local content_length = get_content_length(header)
-- Use table instead of just string to buffer the message. It prevents
-- a ton of strings allocating.
-- ref. http://www.lua.org/pil/11.6.html
---@type string[]
local body_chunks = { buffer:sub(body_start + 1) }
local body_length = #body_chunks[1]
-- Keep waiting for data until we have enough.
while body_length < content_length do
---@type string
local chunk = coroutine.yield()
or error('Expected more data for the body. The server may have died.') -- TODO hmm.
table.insert(body_chunks, chunk)
body_length = body_length + #chunk
end
local last_chunk = body_chunks[#body_chunks]
body_chunks[#body_chunks] = last_chunk:sub(1, content_length - body_length - 1)
local rest = ''
if body_length > content_length then
rest = last_chunk:sub(content_length - body_length)
end
local body = table.concat(body_chunks)
-- Yield our data.
--- @type string
local data = coroutine.yield(body)
or error('Expected more data for the body. The server may have died.')
buffer = rest .. data
else
-- Get more data since we don't have enough.
--- @type string
local data = coroutine.yield()
or error('Expected more data for the header. The server may have died.')
buffer = buffer .. data
end
end
end
end
local M = {}
--- Mapping of error codes used by the client
@@ -249,6 +149,28 @@ local default_dispatchers = {
end,
}
local strbuffer = require('vim._stringbuffer')
local function request_parser_loop()
local buf = strbuffer.new()
while true do
local msg = buf:tostring()
local header_end = msg:find('\r\n\r\n', 1, true)
if header_end then
local header = buf:get(header_end + 1)
buf:skip(2) -- skip past header boundary
local content_length = get_content_length(header)
while strbuffer.len(buf) < content_length do
buf:put(coroutine.yield())
end
local body = buf:get(content_length)
buf:put(coroutine.yield(body))
else
buf:put(coroutine.yield())
end
end
end
--- @private
--- @param handle_body fun(body: string)
--- @param on_exit? fun()