channel: Delay notifications to avoid client race conditions

It is currently possible for a client to send a response that doesn't match the
current server->client request(at the top of the stack). This commit fixes that
by delaying notifications to until the first `channel_send_call` invocation
returns.

Also remove the "call stack" size check, vim will already break if the call
stack goes too deep.
This commit is contained in:
Thiago de Arruda
2014-11-05 13:55:53 -03:00
parent 8979ede45d
commit d83868fe90
2 changed files with 109 additions and 23 deletions

View File

@@ -66,14 +66,23 @@ typedef struct {
uint64_t request_id;
} RequestEvent;
#define RequestEventFreer(x)
KMEMPOOL_INIT(RequestEventPool, RequestEvent, RequestEventFreer)
kmempool_t(RequestEventPool) *request_event_pool = NULL;
typedef struct {
Channel *channel;
String method;
Array args;
} DelayedNotification;
#define _noop(x)
KMEMPOOL_INIT(RequestEventPool, RequestEvent, _noop)
KLIST_INIT(DelayedNotification, DelayedNotification, _noop)
static kmempool_t(RequestEventPool) *request_event_pool = NULL;
static klist_t(DelayedNotification) *delayed_notifications = NULL;
static uint64_t next_id = 1;
static PMap(uint64_t) *channels = NULL;
static PMap(cstr_t) *event_strings = NULL;
static msgpack_sbuffer out_buffer;
static size_t pending_requests = 0;
#ifdef INCLUDE_GENERATED_DECLARATIONS
# include "msgpack_rpc/channel.c.generated.h"
@@ -83,6 +92,7 @@ static msgpack_sbuffer out_buffer;
void channel_init(void)
{
request_event_pool = kmp_init(RequestEventPool);
delayed_notifications = kl_init(DelayedNotification);
channels = pmap_new(uint64_t)();
event_strings = pmap_new(cstr_t)();
msgpack_sbuffer_init(&out_buffer);
@@ -173,14 +183,26 @@ bool channel_send_event(uint64_t id, char *name, Array args)
{
Channel *channel = NULL;
if (id > 0) {
if (!(channel = pmap_get(uint64_t)(channels, id)) || channel->closed) {
api_free_array(args);
return false;
}
send_event(channel, name, args);
if (id && (!(channel = pmap_get(uint64_t)(channels, id))
|| channel->closed)) {
api_free_array(args);
return false;
}
if (pending_requests) {
DelayedNotification p = {
.channel = channel,
.method = cstr_to_string(name),
.args = args
};
// Pending request, queue the notification for sending later
*kl_pushp(DelayedNotification, delayed_notifications) = p;
} else {
broadcast_event(name, args);
if (channel) {
send_event(channel, name, args);
} else {
broadcast_event(name, args);
}
}
return true;
@@ -206,16 +228,6 @@ Object channel_send_call(uint64_t id,
return NIL;
}
if (kv_size(channel->call_stack) > 20) {
// 20 stack depth is more than anyone should ever need for RPC calls
api_set_error(err,
Exception,
_("Channel %" PRIu64 " crossed maximum stack depth"),
channel->id);
api_free_array(args);
return NIL;
}
uint64_t request_id = channel->next_request_id++;
// Send the msgpack-rpc request
send_request(channel, request_id, method_name, args);
@@ -223,18 +235,24 @@ Object channel_send_call(uint64_t id,
// Push the frame
ChannelCallFrame frame = {request_id, false, false, NIL};
kv_push(ChannelCallFrame *, channel->call_stack, &frame);
pending_requests++;
event_poll_until(-1, frame.returned);
(void)kv_pop(channel->call_stack);
pending_requests--;
if (frame.errored) {
api_set_error(err, Exception, "%s", frame.result.data.string.data);
return NIL;
}
if (channel->closed && !kv_size(channel->call_stack)) {
if (!kv_size(channel->call_stack) && channel->closed) {
free_channel(channel);
}
if (!pending_requests) {
send_delayed_notifications();
}
return frame.result;
}
@@ -678,6 +696,7 @@ static void complete_call(msgpack_object *obj, Channel *channel)
static void call_set_error(Channel *channel, char *msg)
{
ELOG("Msgpack-RPC error: %s", msg);
for (size_t i = 0; i < kv_size(channel->call_stack); i++) {
ChannelCallFrame *frame = kv_A(channel->call_stack, i);
frame->returned = true;
@@ -727,6 +746,20 @@ static WBuffer *serialize_response(uint64_t channel_id,
return rv;
}
static void send_delayed_notifications(void)
{
DelayedNotification p;
while (kl_shift(DelayedNotification, delayed_notifications, &p) == 0) {
if (p.channel) {
send_event(p.channel, p.method.data, p.args);
} else {
broadcast_event(p.method.data, p.args);
}
free(p.method.data);
}
}
#if MIN_LOG_LEVEL <= DEBUG_LOG_LEVEL
#define REQ "[request] "
#define RES "[response] "
@@ -764,3 +797,4 @@ static void log_msg_close(FILE *f, msgpack_object msg)
fclose(f);
}
#endif