diff --git a/lib/system/threads.nim b/lib/system/threads.nim index 3291d8c733..039a9f2e19 100644 --- a/lib/system/threads.nim +++ b/lib/system/threads.nim @@ -311,6 +311,9 @@ var threadCreationHandlers: array[60, proc () {.nimcall, gcsafe.}] countThreadCreationHandlers: int + threadDestructionHandlers: array[60, proc () {.nimcall, gcsafe.}] + countThreadDestructionHandlers: int + proc onThreadCreation*(handler: proc () {.nimcall, gcsafe.}) = ## Registers a global handler that is called at thread creation. ## This can be used to initialize thread local variables properly. @@ -334,10 +337,22 @@ proc onThreadCreation*(handler: proc () {.nimcall, gcsafe.}) = threadCreationHandlers[countThreadCreationHandlers] = handler inc countThreadCreationHandlers +proc onThreadDestruction*(handler: proc () {.nimcall, gcsafe.}) = + ## Registers a global handler that is called at thread destruction. + ## Threads are destructed when the ``.thread`` proc returns + ## normally or raises an exception. Note that unhandled exceptions + ## in a thread nevertheless cause the whole process to die. + threadDestructionHandlers[countThreadDestructionHandlers] = handler + inc countThreadDestructionHandlers + template beforeThreadRuns() = for i in 0..countThreadCreationHandlers-1: threadCreationHandlers[i]() +template afterThreadRuns() = + for i in 0..countThreadDestructionHandlers-1: + threadDestructionHandlers[i]() + proc runOnThreadCreationHandlers*() = ## This runs every registered ``onThreadCreation`` handler and is usually ## used to initialize thread local storage for the main thread. Since the @@ -360,21 +375,27 @@ when defined(boehmgc): proc threadProcWrapDispatch[TArg](sb: pointer, thrd: pointer) {.noconv.} = boehmGC_register_my_thread(sb) beforeThreadRuns() - let thrd = cast[ptr Thread[TArg]](thrd) - when TArg is void: - thrd.dataFn() - else: - thrd.dataFn(thrd.data) + try: + let thrd = cast[ptr Thread[TArg]](thrd) + when TArg is void: + thrd.dataFn() + else: + thrd.dataFn(thrd.data) + finally: + afterThreadRuns() boehmGC_unregister_my_thread() else: proc threadProcWrapDispatch[TArg](thrd: ptr Thread[TArg]) = beforeThreadRuns() - when TArg is void: - thrd.dataFn() - else: - var x: TArg - deepCopy(x, thrd.data) - thrd.dataFn(x) + try: + when TArg is void: + thrd.dataFn() + else: + var x: TArg + deepCopy(x, thrd.data) + thrd.dataFn(x) + finally: + afterThreadRuns() proc threadProcWrapStackFrame[TArg](thrd: ptr Thread[TArg]) = when defined(boehmgc): diff --git a/tests/threads/tonthreadcreation.nim b/tests/threads/tonthreadcreation.nim index c96e86d4d7..5d9b777b83 100644 --- a/tests/threads/tonthreadcreation.nim +++ b/tests/threads/tonthreadcreation.nim @@ -1,5 +1,6 @@ discard """ - output: '''some string here''' + output: '''some string here +dying some string here''' """ var @@ -10,11 +11,15 @@ proc setPerThread() = {.gcsafe.}: deepCopy(perThread, someGlobal) +proc threadDied() {.gcsafe} = + echo "dying ", perThread + proc foo() {.thread.} = echo perThread proc main = onThreadCreation setPerThread + onThreadDestruction threadDied var t: Thread[void] createThread[void](t, foo) t.joinThread()