refactor(lsp): better encapsulation and readability, inline unnecessary

This commit is contained in:
Yi Ming
2026-03-29 21:58:10 +08:00
parent 4784f96e59
commit dbd93de62a
2 changed files with 125 additions and 125 deletions

View File

@@ -1,74 +1,76 @@
local uv = vim.uv
local log = require('vim.lsp.log')
local is_win = vim.fn.has('win32') == 1
--- Checks whether a given path exists and is a directory.
---@param filename string path to check
---@return boolean
local function is_dir(filename)
local stat = uv.fs_stat(filename)
return stat and stat.type == 'directory' or false
end
--- Interface for transport implementations.
---
--- @class (private) vim.lsp.rpc.Transport
--- @field listen fun(self: vim.lsp.rpc.Transport, on_read: fun(err: any, data: string), on_exit: fun(code: integer, signal: integer))
--- @field write fun(self: vim.lsp.rpc.Transport, msg: string)
--- @field is_closing fun(self: vim.lsp.rpc.Transport): boolean
--- @field terminate fun(self: vim.lsp.rpc.Transport)
--- @class (private,exact) vim.lsp.rpc.Transport.Run : vim.lsp.rpc.Transport
--- @field new fun(): vim.lsp.rpc.Transport.Run
--- Transport backed by newly spawned process using `vim.system()`.
---
--- @class (private) vim.lsp.rpc.Transport.Run : vim.lsp.rpc.Transport
--- @field cmd string[] Command to start the LSP server.
--- @field extra_spawn_params? vim.lsp.rpc.ExtraSpawnParams
--- @field sysobj? vim.SystemObj
local TransportRun = {}
--- @return vim.lsp.rpc.Transport.Run
function TransportRun.new()
return setmetatable({}, { __index = TransportRun })
end
--- @param cmd string[] Command to start the LSP server.
--- @param extra_spawn_params? vim.lsp.rpc.ExtraSpawnParams
--- @return vim.lsp.rpc.Transport.Run
function TransportRun.new(cmd, extra_spawn_params)
return setmetatable({
cmd = cmd,
extra_spawn_params = extra_spawn_params,
}, { __index = TransportRun })
end
--- @param on_read fun(err: any, data: string)
--- @param on_exit fun(code: integer, signal: integer)
function TransportRun:run(cmd, extra_spawn_params, on_read, on_exit)
function TransportRun:listen(on_read, on_exit)
local function on_stderr(_, chunk)
if chunk then
log.error('rpc', cmd[1], 'stderr', chunk)
log.error('rpc', self.cmd[1], 'stderr', chunk)
end
end
extra_spawn_params = extra_spawn_params or {}
self.extra_spawn_params = self.extra_spawn_params or {}
if extra_spawn_params.cwd then
assert(is_dir(extra_spawn_params.cwd), 'cwd must be a directory')
if self.extra_spawn_params.cwd then
local stat = uv.fs_stat(self.extra_spawn_params.cwd)
assert(stat and stat.type == 'directory' or false, 'cwd must be a directory')
end
local detached = not is_win
if extra_spawn_params.detached ~= nil then
detached = extra_spawn_params.detached
-- Default to non-detached on Windows.
local detached = vim.fn.has('win32') ~= 1
if self.extra_spawn_params.detached ~= nil then
detached = self.extra_spawn_params.detached
end
local ok, sysobj_or_err = pcall(vim.system, cmd, {
---@type boolean, vim.SystemObj|string
local ok, sysobj_or_err = pcall(vim.system, self.cmd, {
stdin = true,
stdout = on_read,
stderr = on_stderr,
cwd = extra_spawn_params.cwd,
env = extra_spawn_params.env,
cwd = self.extra_spawn_params.cwd,
env = self.extra_spawn_params.env,
detach = detached,
}, function(obj)
on_exit(obj.code, obj.signal)
end)
if not ok then
local err = sysobj_or_err --[[@as string]]
if not ok then ---@cast sysobj_or_err string
local err = sysobj_or_err
local sfx = err:match('ENOENT')
and '. The language server is either not installed, missing from PATH, or not executable.'
or string.format(' with error message: %s', err)
error(('Spawning language server with cmd: `%s` failed%s'):format(vim.inspect(cmd), sfx))
end
error(('Spawning language server with cmd: `%s` failed%s'):format(vim.inspect(self.cmd), sfx))
end ---@cast sysobj_or_err vim.SystemObj
self.sysobj = sysobj_or_err --[[@as vim.SystemObj]]
self.sysobj = sysobj_or_err
end
function TransportRun:write(msg)
@@ -80,24 +82,35 @@ function TransportRun:is_closing()
end
function TransportRun:terminate()
assert(self.sysobj):kill(15)
local sysobj = assert(self.sysobj)
if sysobj:is_closing() then
return
end
sysobj:kill(15)
end
--- @class (private,exact) vim.lsp.rpc.Transport.Connect : vim.lsp.rpc.Transport
--- @field new fun(): vim.lsp.rpc.Transport.Connect
--- Transport backed by an existing `uv.uv_pipe_t` or `uv.uv_tcp_t` connection.
---
--- @class (private) vim.lsp.rpc.Transport.Connect : vim.lsp.rpc.Transport
--- @field host_or_path string
--- @field port? integer
--- @field handle? uv.uv_pipe_t|uv.uv_tcp_t
--- Connect returns a PublicClient synchronously so the caller
--- can immediately send messages before the connection is established
--- -> Need to buffer them until that happens
--- can immediately send messages before the connection is established.
--- These messages are buffered in `msgbuf`.
--- @field connected boolean
--- @field closing boolean
--- @field msgbuf vim.Ringbuf
--- @field on_exit? fun(code: integer, signal: integer)
local TransportConnect = {}
--- @param host_or_path string
--- @param port? integer
--- @return vim.lsp.rpc.Transport.Connect
function TransportConnect.new()
function TransportConnect.new(host_or_path, port)
return setmetatable({
host_or_path = host_or_path,
port = port,
connected = false,
-- size should be enough because the client can't really do anything until initialization is done
-- which required a response from the server - implying the connection got established
@@ -106,20 +119,18 @@ function TransportConnect.new()
}, { __index = TransportConnect })
end
--- @param host_or_path string
--- @param port? integer
--- @param on_read fun(err: any, data: string)
--- @param on_exit? fun(code: integer, signal: integer)
function TransportConnect:connect(host_or_path, port, on_read, on_exit)
function TransportConnect:listen(on_read, on_exit)
self.on_exit = on_exit
self.handle = (
port and assert(uv.new_tcp(), 'Could not create new TCP socket')
self.port and assert(uv.new_tcp(), 'Could not create new TCP socket')
or assert(uv.new_pipe(false), 'Pipe could not be opened.')
)
local function on_connect(err)
if err then
local address = not port and host_or_path or (host_or_path .. ':' .. port)
local address = not self.port and self.host_or_path or (self.host_or_path .. ':' .. self.port)
vim.schedule(function()
vim.notify(
string.format('Could not connect to %s, reason: %s', address, vim.inspect(err)),
@@ -135,15 +146,14 @@ function TransportConnect:connect(host_or_path, port, on_read, on_exit)
end
end
if not port then
self.handle:connect(host_or_path, on_connect)
if not self.port then
self.handle:connect(self.host_or_path, on_connect)
return
end
--- @diagnostic disable-next-line:param-type-mismatch bad UV typing
local info = uv.getaddrinfo(host_or_path, nil)
local resolved_host = info and info[1] and info[1].addr or host_or_path
self.handle:connect(resolved_host, port, on_connect)
local info = uv.getaddrinfo(self.host_or_path, nil)
local resolved_host = info and info[1] and info[1].addr or self.host_or_path
self.handle:connect(resolved_host, self.port, on_connect)
end
function TransportConnect:write(msg)

View File

@@ -2,7 +2,7 @@ local log = require('vim.lsp.log')
local protocol = require('vim.lsp.protocol')
local lsp_transport = require('vim.lsp._transport')
local strbuffer = require('vim._core.stringbuffer')
local validate, schedule_wrap = vim.validate, vim.schedule_wrap
local validate = vim.validate
--- Embeds the given string into a table and correctly computes `Content-Length`.
---
@@ -316,7 +316,30 @@ function Client.new(dispatchers, transport)
return result:_notify(method, params)
end
return setmetatable(result, { __index = Client })
---@cast result vim.lsp.rpc.Client
local self = setmetatable(result, { __index = Client })
--- @param body string
local function handle_body(body)
self:handle_body(body)
end
local function on_exit()
---@diagnostic disable-next-line: invisible
self.transport:terminate()
end
--- @param errkind vim.lsp.rpc.ClientErrors
local function on_error(err, errkind)
self:on_error(errkind, err)
if errkind == M.client_errors.INVALID_SERVER_MESSAGE then
---@diagnostic disable-next-line: invisible
self.transport:terminate()
end
end
local on_read = M.create_read_loop(handle_body, on_exit, on_error)
transport:listen(on_read, dispatchers.on_exit)
return self
end
---@private
@@ -380,47 +403,37 @@ function Client:_request(method, params, callback, notify_reply_callback)
return false
end
self.message_callbacks[message_id] = schedule_wrap(callback)
self.message_callbacks[message_id] = vim.schedule_wrap(callback)
if notify_reply_callback then
self.notify_reply_callbacks[message_id] = schedule_wrap(notify_reply_callback)
self.notify_reply_callbacks[message_id] = vim.schedule_wrap(notify_reply_callback)
end
return result, message_id
end
---@package
---@param errkind vim.lsp.rpc.ClientErrors
---@param ... any
function Client:on_error(errkind, ...)
---@param err any
function Client:on_error(errkind, err)
assert(M.client_errors[errkind])
-- TODO what to do if this fails?
pcall(self.dispatchers.on_error, errkind, ...)
end
---@private
---@param errkind integer
---@param status boolean
---@param head any
---@param ... any
---@return boolean status
---@return any head
---@return any? ...
function Client:pcall_handler(errkind, status, head, ...)
if not status then
self:on_error(errkind, head, ...)
return status, head
end
return status, head, ...
pcall(self.dispatchers.on_error, errkind, err)
end
---@private
---@param errkind integer
---@param fn function
---@param ... any
---@return boolean status
---@return any head
---@return any? ...
---@return boolean success
---@return any result
---@return any ...
function Client:try_call(errkind, fn, ...)
return self:pcall_handler(errkind, pcall(fn, ...))
local args = vim.F.pack_len(...)
return xpcall(function()
-- PUC Lua 5.1 xpcall() does not support forwarding extra arguments.
return fn(vim.F.unpack_len(args))
end, function(err)
self:on_error(errkind, err)
end)
end
-- TODO periodically check message_callbacks for old requests past a certain
@@ -434,25 +447,28 @@ function Client:handle_body(body)
if not ok then
self:on_error(M.client_errors.INVALID_SERVER_JSON, decoded)
return
elseif type(decoded) ~= 'table' then
self:on_error(M.client_errors.INVALID_SERVER_MESSAGE, decoded)
return
end
log.debug('rpc.receive', decoded)
if type(decoded) ~= 'table' then
self:on_error(M.client_errors.INVALID_SERVER_MESSAGE, decoded)
elseif type(decoded.method) == 'string' and decoded.id then
local err --- @type lsp.ResponseError?
-- Received a request.
if type(decoded.method) == 'string' and decoded.id then
-- Schedule here so that the users functions don't trigger an error and
-- we can still use the result.
vim.schedule(coroutine.wrap(function()
local status, result
status, result, err = self:try_call(
--- @type boolean, any, lsp.ResponseError?
local success, result, err = self:try_call(
M.client_errors.SERVER_REQUEST_HANDLER_ERROR,
self.dispatchers.server_request,
decoded.method,
decoded.params
)
log.debug('server_request: callback result', { status = status, result = result, err = err })
if status then
log.debug('server_request: callback result', { status = success, result = result, err = err })
-- Dispatcher returns without an exception.
if success then
if result == nil and err == nil then
error(
string.format(
@@ -480,10 +496,12 @@ function Client:handle_body(body)
end
self:send_response(decoded.id, err, result)
end))
-- Proceed only if exactly one of 'result' or 'error' is present, as required by the LSP spec:
-- - If 'error' is nil, then 'result' must be present.
-- - If 'result' is nil, then 'error' must be present (and not vim.NIL).
elseif
-- Received a response to a request we sent.
-- Proceed only if exactly one of 'result' or 'error' is present,
-- as required by the JSON-RPC spec:
-- * If 'error' is nil, then 'result' must be present.
-- * If 'result' is nil, then 'error' must be present (and not vim.NIL).
decoded.id
and (
(decoded.error == nil and decoded.result ~= nil)
@@ -537,7 +555,7 @@ function Client:handle_body(body)
log.error('No callback found for server response id ' .. result_id)
end
elseif type(decoded.method) == 'string' then
-- Notification
-- Received a notification.
self:try_call(
M.client_errors.NOTIFICATION_HANDLER_ERROR,
self.dispatchers.notification,
@@ -565,11 +583,11 @@ local function merge_dispatchers(dispatchers)
---@type vim.lsp.rpc.Dispatchers
local merged = {
notification = (
dispatchers.notification and schedule_wrap(dispatchers.notification)
dispatchers.notification and vim.schedule_wrap(dispatchers.notification)
or default_dispatchers.notification
),
on_error = (
dispatchers.on_error and schedule_wrap(dispatchers.on_error)
dispatchers.on_error and vim.schedule_wrap(dispatchers.on_error)
or default_dispatchers.on_error
),
on_exit = dispatchers.on_exit or default_dispatchers.on_exit,
@@ -578,26 +596,6 @@ local function merge_dispatchers(dispatchers)
return merged
end
--- @param client vim.lsp.rpc.Client
--- @param on_exit? fun()
local function create_client_read_loop(client, on_exit)
--- @param body string
local function handle_body(body)
client:handle_body(body)
end
--- @param errkind vim.lsp.rpc.ClientErrors
local function on_error(err, errkind)
client:on_error(errkind, err)
if errkind == M.client_errors.INVALID_SERVER_MESSAGE then
---@diagnostic disable-next-line: invisible
client.transport:terminate()
end
end
return M.create_read_loop(handle_body, on_exit, on_error)
end
--- Create a LSP RPC client factory that connects to either:
---
--- - a named pipe (windows)
@@ -611,6 +609,8 @@ end
---@param port integer? TCP port to connect to. If absent the first argument must be a pipe
---@return fun(dispatchers: vim.lsp.rpc.Dispatchers): vim.lsp.rpc.Client
function M.connect(host_or_path, port)
log.info('Connecting RPC client', { host_or_path = host_or_path, port = port })
validate('host_or_path', host_or_path, 'string')
validate('port', port, 'number', true)
@@ -619,14 +619,8 @@ function M.connect(host_or_path, port)
dispatchers = merge_dispatchers(dispatchers)
local transport = lsp_transport.TransportConnect.new()
local client = Client.new(dispatchers, transport)
local on_read = create_client_read_loop(client, function()
transport:terminate()
end)
transport:connect(host_or_path, port, on_read, dispatchers.on_exit)
return client
local transport = lsp_transport.TransportConnect.new(host_or_path, port)
return Client.new(dispatchers, transport)
end
end
@@ -653,12 +647,8 @@ function M.start(cmd, dispatchers, extra_spawn_params)
dispatchers = merge_dispatchers(dispatchers)
local transport = lsp_transport.TransportRun.new()
local client = Client.new(dispatchers, transport)
local on_read = create_client_read_loop(client)
transport:run(cmd, extra_spawn_params, on_read, dispatchers.on_exit)
return client
local transport = lsp_transport.TransportRun.new(cmd, extra_spawn_params)
return Client.new(dispatchers, transport)
end
return M