In this blog post, we’ll explore the intricacies of optimizing the Softmax function on CUDA GPUs. From efficient memory access to warp-level reductions, we’ll walk through advanced techniques that make Softmax computations blazing fast! Let’s dive straight into the code.
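As a baseline, recall what we are computing: for a row x of length w, softmax(x)[i] = exp(x[i] - max(x)) / sum_j exp(x[j] - max(x)), where subtracting the row maximum keeps the exponentials from overflowing. Here is a minimal CPU sketch of the same three passes the kernel performs over each row (softmax_row and its arguments are illustrative names, not part of the kernel below):

#include <math.h>

// Reference softmax for a single row of length w.
void softmax_row(const float *x, float *out, int w)
{
    float maxval = -INFINITY;
    for (int i = 0; i < w; i++)
        maxval = fmaxf(maxval, x[i]);           // pass 1: row maximum, for numerical stability

    float divisor = 0.f;
    for (int i = 0; i < w; i++)
        divisor += expf(x[i] - maxval);         // pass 2: sum of shifted exponentials

    for (int i = 0; i < w; i++)
        out[i] = expf(x[i] - maxval) / divisor; // pass 3: normalize
}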
__global__ void SoftMaxKernel(float *input, float *output, int w, int h)
{
int row = blockIdx.x; // row index
int ty = threadIdx.y; // thread index
int warp_id = ty/32; // warp id
int lane_id = ty%32; // lane id
Let's look at the memory we use. We optimize it with a small shared-memory buffer for cross-warp reductions and float4 registers for vectorized loads.
__shared__ float reduction[BLOCK_DIM_Y/32]; // one slot per warp for cross-warp reductions
float4 reg_array[CEILING((WIDTH/4),BLOCK_DIM_Y)]; // per-thread register cache of loaded values
int reg_array_idx = 0; // Index of the array
if (row < h) // check the limits
{
Now we enter the main loop that computes the maximum value of the row. Each thread scans a strided subset of the row, and #pragma unroll asks the compiler to unroll the loop body.
float maxval = -INFINITY; // start from -infinity so rows containing only negative values are handled correctly
#pragma unroll // let the compiler unroll this loop
// This is what #pragma unroll does:
//
// for ( int i = 0; i < 5; i++ )
// b[i] = i;
//
// becomes:
//
// b[0] = 0;
// b[1] = 1;
// b[2] = 2;
// b[3] = 3;
// b[4] = 4;
// Memory coalescing: consecutive threads read consecutive float4 values
for (int i = ty; i<WIDTH/4; i+=BLOCK_DIM_Y)
{
// Load four consecutive floats in a single vectorized (128-bit) access
float4 val = reinterpret_cast<float4*>(&input[row*WIDTH + i*4])[0];
maxval = fmaxf(maxval, val.x);
maxval = fmaxf(maxval, val.y);
maxval = fmaxf(maxval, val.z);
maxval = fmaxf(maxval, val.w);
// cache the loaded values in registers so the later passes don't re-read global memory
reg_array[reg_array_idx] = val;
reg_array_idx+=1;
}
The reg_array keeps the loaded float4 values in registers, so the later exponentiation and normalization passes can reuse them without reading global memory again.
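For example, with the hypothetical values WIDTH = 4096 and BLOCK_DIM_Y = 128 (not fixed by the post), each thread caches CEILING(1024, 128) = 8 float4 values, i.e. 32 floats or 128 bytes of register space per thread.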
Reduction. We will use warp-level primitives, which let threads in the same warp exchange data without going through shared memory. The _sync variants synchronize the participating lanes themselves, so no __syncthreads() is needed for this step.
You can think of __shfl_xor_sync(mask, val, offset) as every lane swapping its value with the lane whose id differs in the bit given by offset; all 32 lanes execute the exchange together in lock step, and combining the result with fmaxf at each step halves the number of distinct values until every lane holds the warp maximum.
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 16, 32));
// each lane combines with the lane whose id differs in bit 4 (lane_id ^ 16)
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 8, 32));
// each lane combines with lane_id ^ 8
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 4, 32));
// each lane combines with lane_id ^ 4
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 2, 32));
// each lane combines with lane_id ^ 2
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 1, 32));
// each lane combines with lane_id ^ 1
// every lane in the warp now holds the warp-level maximum
if (lane_id == 0) // lane 0 of each warp writes its maximum to shared memory
{
reduction[warp_id] = maxval;
}
__syncthreads();
if (warp_id == 0)
{
// Now we compute the maximum value across all warps
// Same logic applied
maxval = ty < BLOCK_DIM_Y/32 ? reduction[ty] : -INFINITY;
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 16, 32));
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 8, 32));
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 4, 32));
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 2, 32));
maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 1, 32));
}
if (ty == 0)
{
reduction[0] = maxval;
}
__syncthreads();
maxval = reduction[0];
Now let's compute the divisor: the sum of exp(x[i] - maxval) over the row. The reduction follows exactly the same pattern as the maximum, so we won't repeat the explanation.
float divisor = 0.f;
reg_array_idx=0;
// Same logic here
#pragma unroll URF // URF: compile-time unroll factor
for (int i = ty; i<WIDTH/4; i+=BLOCK_DIM_Y)
{
float4 val = reg_array[reg_array_idx];
// compute exp(x[i] - max) on the cached float4 values
val.x = __expf(val.x - maxval);
val.y = __expf(val.y - maxval);
val.z = __expf(val.z - maxval);
val.w = __expf(val.w - maxval);
// accumulate the per-thread partial sum of exponentials
divisor += val.x;
divisor += val.y;
divisor += val.z;
divisor += val.w;
// store the exponentiated values back into the register cache for the final pass
reg_array[reg_array_idx] = val;
reg_array_idx+=1;
}
// Warp-level sum reduction: same butterfly pattern as for the maximum
divisor += __shfl_xor_sync(0xffffffff, divisor, 16, 32);
divisor += __shfl_xor_sync(0xffffffff, divisor, 8, 32);
divisor += __shfl_xor_sync(0xffffffff, divisor, 4, 32);
divisor += __shfl_xor_sync(0xffffffff, divisor, 2, 32);
divisor += __shfl_xor_sync(0xffffffff, divisor, 1, 32);
if (lane_id == 0)
{
reduction[warp_id] = divisor;
}
__syncthreads();
if (warp_id == 0)
{
divisor = ty < BLOCK_DIM_Y/32 ? reduction[ty] : 0;
divisor += __shfl_xor_sync(0xffffffff, divisor, 16, 32);
divisor += __shfl_xor_sync(0xffffffff, divisor, 8, 32);
divisor += __shfl_xor_sync(0xffffffff, divisor, 4, 32);
divisor += __shfl_xor_sync(0xffffffff, divisor, 2, 32);
divisor += __shfl_xor_sync(0xffffffff, divisor, 1, 32);
}
if (ty == 0)
{
reduction[0] = divisor;
}
__syncthreads();
divisor = reduction[0];
Finally, the normalization part, where we divide the exponentiated values cached in reg_array by the divisor and write the results to global memory.
reg_array_idx = 0;
#pragma unroll URF
for (int i = ty; i<WIDTH/4; i+=BLOCK_DIM_Y)
{
float4 val = reg_array[reg_array_idx];
// divide each component of the cached float4 by the divisor
val.x = val.x/divisor;
val.y = val.y/divisor;
val.z = val.z/divisor;
val.w = val.w/divisor;
// write the normalized float4 back to global memory in a single 128-bit store
reinterpret_cast<float4*>(&output[row*WIDTH + i*4])[0] = val;
reg_array_idx+=1;
}
}
}
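The post doesn't show the host side, so here is a minimal launch sketch under assumed values. WIDTH, BLOCK_DIM_Y, URF, the CEILING helper, and launch_softmax are illustrative choices, not taken from the original; these definitions would have to appear before the kernel so the macros it references are visible. One block is launched per row, with BLOCK_DIM_Y threads along y to match threadIdx.y in the kernel, and WIDTH must be a multiple of 4 with 16-byte-aligned buffers for the float4 accesses.

#define WIDTH 4096                      // row length: compile-time constant, multiple of 4 (assumed)
#define BLOCK_DIM_Y 128                 // threads per block along y (assumed)
#define URF 4                           // unroll factor for "#pragma unroll URF" (assumed)
#define CEILING(a, b) (((a) + (b) - 1) / (b))

void launch_softmax(float *d_input, float *d_output, int h)
{
    dim3 grid(h);                       // one block per row
    dim3 block(1, BLOCK_DIM_Y);         // matches the kernel's use of threadIdx.y
    SoftMaxKernel<<<grid, block>>>(d_input, d_output, WIDTH, h);
}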
To recap, the key techniques in this kernel are:
float4: Optimized Data Representation
float4 allows processing four floating-point values as a single unit. This vectorized format improves memory coalescing and maps each access to a single wide 128-bit load or store, speeding up memory traffic.
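As a standalone illustration of the same load pattern the kernel uses (vectorized_copy is a made-up example, not from the post; it assumes 16-byte-aligned buffers and n a multiple of 4):

__global__ void vectorized_copy(const float *in, float *out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i * 4 + 3 < n)
    {
        // one 128-bit load replaces four 32-bit loads
        float4 v = reinterpret_cast<const float4*>(in)[i];
        reinterpret_cast<float4*>(out)[i] = v;   // one 128-bit store
    }
}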
Memory Coalescing
Threads in a warp access contiguous memory locations, so their reads and writes are combined into as few transactions as possible. This minimizes latency and maximizes memory-bandwidth utilization.
Warp-Level Primitives
Primitives like __shfl_xor_sync enable direct communication between threads in a warp, bypassing shared memory and improving reduction and aggregation performance.
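Here is the same butterfly pattern the kernel unrolls by hand, written as a compact helper (warp_sum is an illustrative sketch, not part of the kernel above):

__device__ float warp_sum(float val)
{
    // each XOR step halves the number of distinct partial sums, so after
    // log2(32) = 5 steps every lane of the warp holds the full sum
    for (int offset = 16; offset > 0; offset /= 2)
        val += __shfl_xor_sync(0xffffffff, val, offset, 32);
    return val;
}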
Register-Level Reduction
Reduction operations are performed directly in registers, avoiding shared-memory latency where possible. This approach leverages warp-level primitives for efficient summation and maximization.
Putting It Together
The kernel computes the softmax row by row, combining warp-level reductions for the maximum and the sum with an efficient normalization pass, and keeps intermediate values in registers instead of re-reading global memory.
https://github.com/SzymonOzog/FastSoftmax
https://www.youtube.com/watch?v=IpHjDoW4ffw
https://github.com/facebookincubator/AITemplate/wiki/How-to-write-a-fast-Softmax-CUDA-kernel%3F