diff --git a/core/sync/chan/chan.odin b/core/sync/chan/chan.odin index eca4c28d7..c5a4cf317 100644 --- a/core/sync/chan/chan.odin +++ b/core/sync/chan/chan.odin @@ -7,6 +7,14 @@ import "core:mem" import "core:sync" import "core:math/rand" +when ODIN_TEST { +/* +Hook for testing _try_select_raw allowing the test harness to manipulate the +channels prior to the select actually operating on them. +*/ +__try_select_raw_pause : proc() = nil +} + /* Determines what operations `Chan` supports. */ @@ -1105,15 +1113,27 @@ can_send :: proc "contextless" (c: ^Raw_Chan) -> bool { return c.w_waiting == 0 } +/* +Specifies the direction of the selected channel. +*/ +Select_Status :: enum { + None, + Recv, + Send, +} + /* -Attempts to either send or receive messages on the specified channels. +Attempts to either send or receive messages on the specified channels without blocking. -`select_raw` first identifies which channels have messages ready to be received +`try_select_raw` first identifies which channels have messages ready to be received and which are available for sending. It then randomly selects one operation (either a send or receive) to perform. +If no channels have messages ready, the procedure is a noop. + Note: Each message in `send_msgs` corresponds to the send channel at the same index in `sends`. +If the message is nil, corresponding send channel will be skipped. **Inputs** - `recv`: A slice of channels to read from @@ -1145,18 +1165,18 @@ Example: // where the value from the read should be stored received_value: int - idx, ok := chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value) + idx, ok := chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value) fmt.println("SELECT: ", idx, ok) fmt.println("RECEIVED VALUE ", received_value) - idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value) + idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value) fmt.println("SELECT: ", idx, ok) fmt.println("RECEIVED VALUE ", received_value) // closing of a channel also affects the select operation chan.close(c) - idx, ok = chan.select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value) + idx, ok = chan.try_select_raw(receive_chans[:], send_chans[:], msgs[:], &received_value) fmt.println("SELECT: ", idx, ok) } @@ -1170,7 +1190,7 @@ Output: */ @(require_results) -select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, ok: bool) #no_bounds_check { +try_select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check { Select_Op :: struct { idx: int, // local to the slice that was given is_recv: bool, @@ -1178,43 +1198,66 @@ select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: [] candidate_count := builtin.len(recvs)+builtin.len(sends) candidates := ([^]Select_Op)(intrinsics.alloca(candidate_count*size_of(Select_Op), align_of(Select_Op))) - count := 0 - for c, i in recvs { - if can_recv(c) { - candidates[count] = { - is_recv = true, - idx = i, + try_loop: for { + count := 0 + + for c, i in recvs { + if can_recv(c) { + candidates[count] = { + is_recv = true, + idx = i, + } + count += 1 } - count += 1 } - } - for c, i in sends { - if can_send(c) { - candidates[count] = { - is_recv = false, - idx = i, + for c, i in sends { + if i > builtin.len(send_msgs)-1 || send_msgs[i] == nil { + continue + } + if can_send(c) { + candidates[count] = { + is_recv = false, + idx = i, + } + count += 1 } - count += 1 } - } - if count == 0 { - return - } + if count == 0 { + return -1, .None + } - select_idx = rand.int_max(count) if count > 0 else 0 + when ODIN_TEST { + if __try_select_raw_pause != nil { + __try_select_raw_pause() + } + } - sel := candidates[select_idx] - if sel.is_recv { - ok = recv_raw(recvs[sel.idx], recv_out) - } else { - ok = send_raw(sends[sel.idx], send_msgs[sel.idx]) + candidate_idx := rand.int_max(count) if count > 0 else 0 + + sel := candidates[candidate_idx] + if sel.is_recv { + status = .Recv + if !try_recv_raw(recvs[sel.idx], recv_out) { + continue try_loop + } + } else { + status = .Send + if !try_send_raw(sends[sel.idx], send_msgs[sel.idx]) { + continue try_loop + } + } + + return sel.idx, status } - return } +@(require_results, deprecated = "use try_select_raw") +select_raw :: proc "odin" (recvs: []^Raw_Chan, sends: []^Raw_Chan, send_msgs: []rawptr, recv_out: rawptr) -> (select_idx: int, status: Select_Status) #no_bounds_check { + return try_select_raw(recvs, sends, send_msgs, recv_out) +} /* `Raw_Queue` is a non-thread-safe queue implementation designed to store messages diff --git a/tests/core/sync/chan/test_core_sync_chan.odin b/tests/core/sync/chan/test_core_sync_chan.odin index 9b8d9b354..e8bb553b1 100644 --- a/tests/core/sync/chan/test_core_sync_chan.odin +++ b/tests/core/sync/chan/test_core_sync_chan.odin @@ -272,3 +272,180 @@ test_accept_message_from_closed_buffered_chan :: proc(t: ^testing.T) { testing.expect_value(t, result, 64) testing.expect(t, ok) } + +// Ensures that if any input channel is eligible to receive or send, the try_select_raw +// operation will process it. +@test +test_try_select_raw_happy :: proc(t: ^testing.T) { + testing.set_fail_timeout(t, FAIL_TIME) + + recv1, recv1_err := chan.create(chan.Chan(int), context.allocator) + + assert(recv1_err == nil, "allocation failed") + defer chan.destroy(recv1) + + recv2, recv2_err := chan.create(chan.Chan(int), 1, context.allocator) + + assert(recv2_err == nil, "allocation failed") + defer chan.destroy(recv2) + + send1, send1_err := chan.create(chan.Chan(int), 1, context.allocator) + + assert(send1_err == nil, "allocation failed") + defer chan.destroy(send1) + + msg := 42 + + // Preload recv2 to make it eligible for selection. + testing.expect_value(t, chan.send(recv2, msg), true) + + recvs := [?]^chan.Raw_Chan{recv1, recv2} + sends := [?]^chan.Raw_Chan{send1} + msgs := [?]rawptr{&msg} + received_value: int + + iteration_count := 0 + did_none_count := 0 + did_send_count := 0 + did_receive_count := 0 + + // This loop is expected to iterate three times. Twice to do the receive and + // send operations, and a third time to exit. + receive_loop: for { + + iteration_count += 1 + + idx, status := chan.try_select_raw(recvs[:], sends[:], msgs[:], &received_value) + + switch status { + case .None: + did_none_count += 1 + break receive_loop + + case .Recv: + did_receive_count += 1 + testing.expect_value(t, idx, 1) + testing.expect_value(t, received_value, msg) + received_value = 0 + + case .Send: + did_send_count += 1 + testing.expect_value(t, idx, 0) + v, ok := chan.try_recv(send1) + testing.expect_value(t, ok, true) + testing.expect_value(t, v, msg) + msgs[0] = nil // nil out the message to avoid constantly resending the same value. + } + } + + testing.expect_value(t, iteration_count, 3) + testing.expect_value(t, did_none_count, 1) + testing.expect_value(t, did_receive_count, 1) + testing.expect_value(t, did_send_count, 1) +} + +// Ensures that if no input channels are eligible to receive or send, the +// try_select_raw operation does not block. +@test +test_try_select_raw_default_state :: proc(t: ^testing.T) { + testing.set_fail_timeout(t, FAIL_TIME) + + recv1, recv1_err := chan.create(chan.Chan(int), context.allocator) + + assert(recv1_err == nil, "allocation failed") + defer chan.destroy(recv1) + + recv2, recv2_err := chan.create(chan.Chan(int), context.allocator) + + assert(recv2_err == nil, "allocation failed") + defer chan.destroy(recv2) + + recvs := [?]^chan.Raw_Chan{recv1, recv2} + received_value: int + + idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value) + + testing.expect_value(t, idx, -1) + testing.expect_value(t, status, chan.Select_Status.None) +} + +// Ensures that the operation will not block even if the input channels are +// consumed by a competing thread; that is, a value is received from another +// thread between calls to can_{send,recv} and try_{send,recv}_raw. +@test +test_try_select_raw_no_toctou :: proc(t: ^testing.T) { + testing.set_fail_timeout(t, FAIL_TIME) + + // Trigger will be used to coordinate between the thief and the try_select. + trigger, trigger_err := chan.create(chan.Chan(any), context.allocator) + + assert(trigger_err == nil, "allocation failed") + defer chan.destroy(trigger) + + @(static) + __global_context_for_test: rawptr + + __global_context_for_test = &trigger + defer __global_context_for_test = nil + + // Setup the pause proc. This will be invoked after the input channels are + // checked for eligibility but before any channel operations are attempted. + chan.__try_select_raw_pause = proc() { + trigger := (cast(^chan.Chan(any))(__global_context_for_test))^ + + // Notify the thief that we are paused so that it can steal the value. + _ = chan.send(trigger, "signal") + + // Wait for comfirmation of the burglary. + _, _ = chan.recv(trigger) + } + + defer chan.__try_select_raw_pause = nil + + recv1, recv1_err := chan.create(chan.Chan(int), 1, context.allocator) + + assert(recv1_err == nil, "allocation failed") + defer chan.destroy(recv1) + + Context :: struct { + recv1: chan.Chan(int), + trigger: chan.Chan(any), + } + + ctx := Context{ + recv1 = recv1, + trigger = trigger, + } + + // Spin up a thread that will steal the value from the input channel after + // try_select has already considered it eligible for selection. + thief := thread.create_and_start_with_poly_data(ctx, proc(ctx: Context) { + // Wait for eligibility check. + _, _ = chan.recv(ctx.trigger) + + // Steal the value. + v, ok := chan.recv(ctx.recv1) + + assert(ok, "recv1: expected to receive a value") + assert(v == 42, "recv1: unexpected receive value") + + // Notify select that we have stolen the value and that it can proceed. + _ = chan.send(ctx.trigger, "signal") + }) + + recvs := [?]^chan.Raw_Chan{recv1} + received_value: int + + // Ensure channel is eligible prior to entering the select. + testing.expect_value(t, chan.send(recv1, 42), true) + + // Execute the try_select_raw, assert that we don't block, and that we receive + // .None status since the value was stolen by the other thread. + idx, status := chan.try_select_raw(recvs[:], nil, nil, &received_value) + + testing.expect_value(t, idx, -1) + testing.expect_value(t, status, chan.Select_Status.None) + + thread.join(thief) + thread.destroy(thief) +}