Removed the need for SDL_CreateTLS()

This eliminates the tap dancing needed for allocating TLS slots, we'll automatically allocate them as needed, in a thread-safe way.
This commit is contained in:
Sam Lantinga
2024-07-16 09:43:07 -07:00
parent 1592452cad
commit ec3bb4c029
12 changed files with 75 additions and 112 deletions

View File

@@ -29,34 +29,37 @@
/* The storage is local to the thread, but the IDs are global for the process */
static SDL_AtomicInt SDL_tls_allocated;
static SDL_AtomicInt SDL_tls_id;
void SDL_InitTLSData(void)
{
SDL_SYS_InitTLSData();
}
SDL_TLSID SDL_CreateTLS(void)
{
static SDL_AtomicInt SDL_tls_id;
return (SDL_TLSID)(SDL_AtomicIncRef(&SDL_tls_id) + 1);
}
void *SDL_GetTLS(SDL_TLSID id)
void *SDL_GetTLS(SDL_TLSID *id)
{
SDL_TLSData *storage;
int storage_index;
storage = SDL_SYS_GetTLSData();
if (!storage || id == 0 || id > storage->limit) {
if (id == NULL) {
SDL_InvalidParamError("id");
return NULL;
}
return storage->array[id - 1].data;
storage_index = SDL_AtomicGet(id) - 1;
storage = SDL_SYS_GetTLSData();
if (!storage || storage_index < 0 || storage_index >= storage->limit) {
return NULL;
}
return storage->array[storage_index].data;
}
int SDL_SetTLS(SDL_TLSID id, const void *value, SDL_TLSDestructorCallback destructor)
int SDL_SetTLS(SDL_TLSID *id, const void *value, SDL_TLSDestructorCallback destructor)
{
SDL_TLSData *storage;
int storage_index;
if (id == 0) {
if (id == NULL) {
return SDL_InvalidParamError("id");
}
@@ -66,14 +69,27 @@ int SDL_SetTLS(SDL_TLSID id, const void *value, SDL_TLSDestructorCallback destru
*/
SDL_InitTLSData();
/* Get the storage index associated with the ID in a thread-safe way */
storage_index = SDL_AtomicGet(id) - 1;
if (storage_index < 0) {
int new_id = (SDL_AtomicIncRef(&SDL_tls_id) + 1);
SDL_AtomicCompareAndSwap(id, 0, new_id);
/* If there was a race condition we'll have wasted an ID, but every thread
* will have the same storage index for this id.
*/
storage_index = SDL_AtomicGet(id) - 1;
}
/* Get the storage for the current thread */
storage = SDL_SYS_GetTLSData();
if (!storage || (id > storage->limit)) {
if (!storage || storage_index >= storage->limit) {
unsigned int i, oldlimit, newlimit;
SDL_TLSData *new_storage;
oldlimit = storage ? storage->limit : 0;
newlimit = (id + TLS_ALLOC_CHUNKSIZE);
newlimit = (storage_index + TLS_ALLOC_CHUNKSIZE);
new_storage = (SDL_TLSData *)SDL_realloc(storage, sizeof(*storage) + (newlimit - 1) * sizeof(storage->array[0]));
if (!new_storage) {
return -1;
@@ -91,8 +107,8 @@ int SDL_SetTLS(SDL_TLSID id, const void *value, SDL_TLSDestructorCallback destru
SDL_AtomicIncRef(&SDL_tls_allocated);
}
storage->array[id - 1].data = SDL_const_cast(void *, value);
storage->array[id - 1].destructor = destructor;
storage->array[storage_index].data = SDL_const_cast(void *, value);
storage->array[storage_index].destructor = destructor;
return 0;
}
@@ -103,7 +119,7 @@ void SDL_CleanupTLS(void)
/* Cleanup the storage for the current thread */
storage = SDL_SYS_GetTLSData();
if (storage) {
unsigned int i;
int i;
for (i = 0; i < storage->limit; ++i) {
if (storage->array[i].destructor) {
storage->array[i].destructor(storage->array[i].data);
@@ -261,42 +277,15 @@ SDL_error *SDL_GetErrBuf(SDL_bool create)
#ifdef SDL_THREADS_DISABLED
return SDL_GetStaticErrBuf();
#else
static SDL_SpinLock tls_lock;
static SDL_bool tls_being_created;
static SDL_TLSID tls_errbuf;
const SDL_error *ALLOCATION_IN_PROGRESS = (SDL_error *)-1;
SDL_error *errbuf;
if (!tls_errbuf && !create) {
return NULL;
}
/* tls_being_created is there simply to prevent recursion if SDL_CreateTLS() fails.
It also means it's possible for another thread to also use SDL_global_errbuf,
but that's very unlikely and hopefully won't cause issues.
*/
if (!tls_errbuf && !tls_being_created) {
SDL_LockSpinlock(&tls_lock);
if (!tls_errbuf) {
SDL_TLSID slot;
tls_being_created = SDL_TRUE;
slot = SDL_CreateTLS();
tls_being_created = SDL_FALSE;
SDL_MemoryBarrierRelease();
tls_errbuf = slot;
}
SDL_UnlockSpinlock(&tls_lock);
}
if (!tls_errbuf) {
return SDL_GetStaticErrBuf();
}
SDL_MemoryBarrierAcquire();
errbuf = (SDL_error *)SDL_GetTLS(tls_errbuf);
if (errbuf == ALLOCATION_IN_PROGRESS) {
return SDL_GetStaticErrBuf();
}
errbuf = (SDL_error *)SDL_GetTLS(&tls_errbuf);
if (!errbuf) {
if (!create) {
return NULL;
}
/* Get the original memory functions for this allocation because the lifetime
* of the error buffer may span calls to SDL_SetMemoryFunctions() by the app
*/
@@ -304,17 +293,14 @@ SDL_error *SDL_GetErrBuf(SDL_bool create)
SDL_free_func free_func;
SDL_GetOriginalMemoryFunctions(NULL, NULL, &realloc_func, &free_func);
/* Mark that we're in the middle of allocating our buffer */
SDL_SetTLS(tls_errbuf, ALLOCATION_IN_PROGRESS, NULL);
errbuf = (SDL_error *)realloc_func(NULL, sizeof(*errbuf));
if (!errbuf) {
SDL_SetTLS(tls_errbuf, NULL, NULL);
return SDL_GetStaticErrBuf();
}
SDL_zerop(errbuf);
errbuf->realloc_func = realloc_func;
errbuf->free_func = free_func;
SDL_SetTLS(tls_errbuf, errbuf, SDL_FreeErrBuf);
SDL_SetTLS(&tls_errbuf, errbuf, SDL_FreeErrBuf);
}
return errbuf;
#endif /* SDL_THREADS_DISABLED */