diff --git a/runtime/lua/vim/lsp/rpc.lua b/runtime/lua/vim/lsp/rpc.lua index 4feb154b5e..9a63ef9cad 100644 --- a/runtime/lua/vim/lsp/rpc.lua +++ b/runtime/lua/vim/lsp/rpc.lua @@ -175,27 +175,34 @@ end --- @private --- @param handle_body fun(body: string) --- @param on_exit? fun() ---- @param on_error fun(err: any) +--- @param on_error? fun(err: any, errkind: vim.lsp.rpc.ClientErrors) function M.create_read_loop(handle_body, on_exit, on_error) - local parse_chunk = coroutine.wrap(request_parser_loop) --[[@as fun(chunk: string?): string]] - parse_chunk() + on_exit = on_exit or function() end + on_error = on_error or function() end + local co = coroutine.create(request_parser_loop) + coroutine.resume(co) return function(err, chunk) if err then - on_error(err) + on_error(err, M.client_errors.READ_ERROR) return end if not chunk then - if on_exit then - on_exit() - end + on_exit() + return + end + + if coroutine.status(co) == 'dead' then return end while true do - local body = parse_chunk(chunk) - if body then - handle_body(body) + local ok, res = coroutine.resume(co, chunk) + if not ok then + on_error(res, M.client_errors.INVALID_SERVER_MESSAGE) + break + elseif res then + handle_body(res) chunk = '' else break @@ -547,8 +554,12 @@ local function create_client_read_loop(client, on_exit) client:handle_body(body) end - local function on_error(err) - client:on_error(M.client_errors.READ_ERROR, err) + --- @param errkind vim.lsp.rpc.ClientErrors + local function on_error(err, errkind) + client:on_error(errkind, err) + if errkind == M.client_errors.INVALID_SERVER_MESSAGE then + client.transport:terminate() + end end return M.create_read_loop(handle_body, on_exit, on_error) diff --git a/test/functional/plugin/lsp_spec.lua b/test/functional/plugin/lsp_spec.lua index 71fdefff23..66475290fd 100644 --- a/test/functional/plugin/lsp_spec.lua +++ b/test/functional/plugin/lsp_spec.lua @@ -1971,6 +1971,43 @@ describe('LSP', function() } end) + it('should catch error while parsing invalid header', function() + local header = 'Content-Length: \r\n' + local called = false + exec_lua(function() + local server = assert(vim.uv.new_tcp()) + server:bind('127.0.0.1', 0) + server:listen(1, function(e) + assert(not e, e) + local socket = assert(vim.uv.new_tcp()) + server:accept(socket) + socket:write(header .. '\r\n', function() + socket:shutdown() + server:close() + end) + end) + local client = assert(vim.uv.new_tcp()) + local on_read = require('vim.lsp.rpc').create_read_loop(function() end, function() + client:close() + end, function(err, code) + vim.rpcnotify(1, 'error', err, code) + end) + client:connect('127.0.0.1', server:getsockname().port, function() + client:read_start(on_read) + end) + end) + n.run(nil, function(method, args) + local err, code = unpack(args) --- @type string, number + eq('error', method) + eq(1, code) + matches(vim.pesc('Content-Length not found in header: ' .. header) .. '$', err) + called = true + stop() + return NIL + end, nil, 1000) + eq(true, called) + end) + it('should not trim vim.NIL from the end of a list', function() local expected_handlers = { { NIL, {}, { method = 'shutdown', client_id = 1 } },