diff --git a/core/thread/thread_pool.odin b/core/thread/thread_pool.odin index d7a03d04c..ea772c725 100644 --- a/core/thread/thread_pool.odin +++ b/core/thread/thread_pool.odin @@ -20,6 +20,9 @@ Task :: struct { allocator: mem.Allocator, } +Thread_Init_Proc :: #type proc(thread_index: int, user_data: rawptr) +Thread_Fini_Proc :: #type proc(thread_index: int, user_data: rawptr) + // Do not access the pool's members directly while the pool threads are running, // since they use different kinds of locking and mutual exclusion devices. // Careless access can and will lead to nasty bugs. Once initialized, the @@ -36,6 +39,13 @@ Pool :: struct { num_done: int, // end of atomics + // called once per thread at startup + thread_init_proc: Thread_Init_Proc, + thread_init_data: rawptr, + // called once per thread at shutdown + thread_fini_proc: Thread_Fini_Proc, + thread_fini_data: rawptr, + is_running: bool, threads: []^Thread, @@ -55,6 +65,10 @@ pool_thread_runner :: proc(t: ^Thread) { data := cast(^Pool_Thread_Data)t.data pool := data.pool + if pool.thread_init_proc != nil { + pool.thread_init_proc(t.user_index, pool.thread_init_data) + } + for intrinsics.atomic_load(&pool.is_running) { sync.wait(&pool.sem_available) @@ -66,6 +80,10 @@ pool_thread_runner :: proc(t: ^Thread) { } } + if pool.thread_fini_proc != nil { + pool.thread_fini_proc(t.user_index, pool.thread_fini_data) + } + sync.post(&pool.sem_available, 1) } @@ -73,13 +91,26 @@ pool_thread_runner :: proc(t: ^Thread) { // it is destroyed. // // The thread pool requires an allocator which it either owns, or which is thread safe. -pool_init :: proc(pool: ^Pool, allocator: mem.Allocator, thread_count: int) { +pool_init :: proc( + pool: ^Pool, + allocator: mem.Allocator, + thread_count: int, + init_proc: Thread_Init_Proc = nil, + init_data: rawptr = nil, + fini_proc: Thread_Init_Proc = nil, + fini_data: rawptr = nil, +){ context.allocator = allocator pool.allocator = allocator queue.init(&pool.tasks) pool.tasks_done = make([dynamic]Task) pool.threads = make([]^Thread, max(thread_count, 1)) + pool.thread_init_proc = init_proc + pool.thread_fini_proc = fini_proc + pool.thread_init_data = init_data + pool.thread_fini_data = fini_data + pool.is_running = true for _, i in pool.threads {