gpu: Rework driver name queries, add GetGPUShaderFormats

This commit is contained in:
Ethan Lee
2024-09-13 11:16:43 -04:00
parent 6d92de5d3a
commit 96e147b2b9
12 changed files with 129 additions and 67 deletions

View File

@@ -150,17 +150,28 @@ static const GPU_ShaderSources frag_shader_sources[NUM_FRAG_SHADERS] = {
static SDL_GPUShader *CompileShader(const GPU_ShaderSources *sources, SDL_GPUDevice *device, SDL_GPUShaderStage stage)
{
const GPU_ShaderModuleSource *sms = NULL;
SDL_GPUDriver driver = SDL_GetGPUDriver(device);
SDL_GPUShaderFormat formats = SDL_GetGPUShaderFormats(device);
switch (driver) {
// clang-format off
IF_VULKAN( case SDL_GPU_DRIVER_VULKAN: sms = &sources->spirv; break;)
IF_D3D11( case SDL_GPU_DRIVER_D3D11: sms = &sources->dxbc50; break;)
IF_D3D12( case SDL_GPU_DRIVER_D3D12: sms = &sources->dxil60; break;)
IF_METAL( case SDL_GPU_DRIVER_METAL: sms = &sources->msl; break;)
// clang-format on
default:
if (formats == SDL_GPU_SHADERFORMAT_INVALID) {
// SDL_GetGPUShaderFormats already set the error
return NULL;
#if HAVE_SPIRV_SHADERS
} else if (formats & SDL_GPU_SHADERFORMAT_SPIRV) {
sms = &sources->spirv;
#endif // HAVE_SPIRV_SHADERS
#if HAVE_DXBC50_SHADERS
} else if (formats & SDL_GPU_SHADERFORMAT_DXBC) {
sms = &sources->dxbc50;
#endif // HAVE_DXBC50_SHADERS
#if HAVE_DXIL60_SHADERS
} else if (formats & SDL_GPU_SHADERFORMAT_DXIL) {
sms = &sources->dxil60;
#endif // HAVE_DXIL60_SHADERS
#if HAVE_METAL_SHADERS
} else if (formats & SDL_GPU_SHADERFORMAT_MSL) {
sms = &sources->msl;
#endif // HAVE_METAL_SHADERS
} else {
SDL_SetError("Unsupported GPU backend");
return NULL;
}
@@ -170,7 +181,11 @@ static SDL_GPUShader *CompileShader(const GPU_ShaderSources *sources, SDL_GPUDev
sci.code_size = sms->code_len;
sci.format = sms->format;
// FIXME not sure if this is correct
sci.entrypoint = driver == SDL_GPU_DRIVER_METAL ? "main0" : "main";
sci.entrypoint =
#if HAVE_METAL_SHADERS
(sms == &sources->msl) ? "main0" :
#endif // HAVE_METAL_SHADERS
"main";
sci.num_samplers = sources->num_samplers;
sci.num_uniform_buffers = sources->num_uniform_buffers;
sci.stage = stage;