Optimizing Softmax on CUDA: A Dive into High-Performance Kernels

In this blog post, we’ll explore the intricacies of optimizing the Softmax function on CUDA GPUs. From efficient memory access to warp-level reductions, we’ll walk through the techniques that make Softmax computations blazing fast. Let’s dive straight into the code (see the references at the end for the sources this walkthrough draws on).

Declaration

__global__ void SoftMaxKernel(float* a, float* b, int w, int h)  // a: input rows, b: output rows (the body uses the compile-time WIDTH rather than w)
{
  int row = blockIdx.x;     // row index: one block per row
  int ty = threadIdx.y;     // thread index within the block (y dimension)
  int warp_id = ty/32;      // warp index within the block
  int lane_id = ty%32;      // lane index within the warp
  

Shared Memory and other variables

Let’s look at the memory we use. We optimize by combining shared memory for the cross-warp reductions with float4 registers for vectorized loads.

  __shared__ float reduction[BLOCK_DIM_Y/32];        // one slot per warp for cross-warp reductions
  float4 reg_array[CEILING((WIDTH/4),BLOCK_DIM_Y)];  // per-thread register cache of the row's float4 values
  int reg_array_idx = 0;                             // write index into reg_array
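
These declarations rely on a few compile-time constants that the post never defines: WIDTH (the row length), BLOCK_DIM_Y (threads per block along y), the CEILING helper, and the URF unroll factor used further down. A minimal sketch of what they could look like, with purely illustrative values:

#define WIDTH 4096                              // row length in floats (illustrative value)
#define BLOCK_DIM_Y 128                         // threads per block along y (illustrative value)
#define URF 4                                   // unroll factor for "#pragma unroll URF" (illustrative value)
#define CEILING(x, y) (((x) + (y) - 1) / (y))   // integer ceiling division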

Checks and unroll

  if (row < h)  // check the limits
  {
    float maxval = -INFINITY; // running maximum for this thread; start at -inf so any input can raise it
#pragma unroll   // unroll the loop below - a cheap but effective optimization

// This is what #pragma unroll does: it replicates the loop body so the
// compiler can schedule the iterations without loop overhead, e.g.
//
// for ( int i = 0; i < 5; i++ )
//      b[i] = i;
//
// becomes
//
// b[0] = 0;
// b[1] = 1;
// b[2] = 2;
// b[3] = 3;
// b[4] = 4;

Main loop

Now we enter the main loop, where each thread computes a partial maximum over its slice of the row.

    // Strided access: consecutive threads load consecutive float4 elements (memory coalescing)
    for (int i = ty; i<WIDTH/4; i+=BLOCK_DIM_Y)
    {
      // Load four consecutive floats of the row as a single 16-byte float4
      float4 val = reinterpret_cast<float4*>(&a[row*WIDTH + i*4])[0];
      maxval = fmaxf(maxval, val.x);
      maxval = fmaxf(maxval, val.y);
      maxval = fmaxf(maxval, val.z);
      maxval = fmaxf(maxval, val.w);

      // Cache the loaded value in the register array for the later passes
      reg_array[reg_array_idx] = val;
      reg_array_idx+=1;
    }

We cache the loaded float4 values in reg_array so the later passes (exponentiation and normalization) can reuse them straight from registers instead of re-reading the row from global memory.
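
To make the register budget concrete, using the illustrative WIDTH = 4096 and BLOCK_DIM_Y = 128 from the sketch above:

  // float4 elements per row    : WIDTH / 4                  = 1024
  // float4 elements per thread : CEILING(1024, BLOCK_DIM_Y) = 8
  // floats cached per thread   : 8 * 4                      = 32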

Reduction

Time for the reduction. We will use warp-level primitives: they let threads in the same warp exchange register values directly, without going through shared memory and without an explicit __syncthreads() (the _sync variants synchronize the participating lanes themselves).

You can picture a shuffle as one instruction issued to all 32 lanes of a warp at once: each lane hands its register value to a partner lane and receives that partner's value in return.
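
__shfl_xor_sync(mask, val, laneMask, width) hands each lane the value held by the lane whose index is lane_id XOR laneMask. Halving laneMask each step produces the butterfly pattern used below; as a generic sketch (not part of the original kernel), the same idea fits in a loop:

__device__ float warp_reduce_max(float v)   // hypothetical helper, equivalent to the unrolled lines below
{
    for (int offset = 16; offset > 0; offset /= 2)
        v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, offset, 32));
    return v;   // every lane now holds the maximum over all 32 lanes
}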

    maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 16, 32));
    // each lane compares with the lane at lane_id XOR 16

    maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 8, 32));
    // each lane compares with the lane at lane_id XOR 8

    maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 4, 32));
    // each lane compares with the lane at lane_id XOR 4

    maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 2, 32));
    // each lane compares with the lane at lane_id XOR 2

    maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 1, 32));
    // each lane compares with the lane at lane_id XOR 1

    // After these five steps, every lane of the warp holds the warp-wide maximum

    if (lane_id == 0) // lane 0 of each warp writes its warp's max value to shared memory
    {
      reduction[warp_id] = maxval; 
    }
    __syncthreads();
    if (warp_id == 0)
    {
        // Now we compute the maximum across all warps
        // Same shuffle logic applies
        maxval = ty < BLOCK_DIM_Y/32 ? reduction[ty] : -INFINITY;
        maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 16, 32));
        maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 8, 32));
        maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 4, 32));
        maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 2, 32));
        maxval = fmaxf(maxval, __shfl_xor_sync(0xffffffff, maxval, 1, 32));
    }
    if (ty == 0)
    {
        reduction[0] = maxval;
    }
    __syncthreads();
    maxval = reduction[0];

Calculating the divisor

Now let’s compute the divisor: the sum of exp(x[i] - max) over the row. The structure mirrors the max reduction, so we’ll move through it quickly.

    float divisor = 0.f;
    reg_array_idx=0;
    
    // Same strided access pattern as the max loop, but now reading from the register cache
#pragma unroll URF
    for (int i = ty; i<WIDTH/4; i+=BLOCK_DIM_Y)
    {
        float4 val = reg_array[reg_array_idx];
        // Compute exp(x[i] - max) element-wise on the cached float4
        val.x = __expf(val.x - maxval);
        val.y = __expf(val.y - maxval);
        val.z = __expf(val.z - maxval);
        val.w = __expf(val.w - maxval);
        // Accumulate this thread's partial sum of the exponentials
        divisor += val.x;
        divisor += val.y;
        divisor += val.z;
        divisor += val.w;
        // Store the exponentiated values back for the normalization pass
        reg_array[reg_array_idx] = val;
        reg_array_idx+=1;
    }

    // Warp-level sum reduction, same butterfly pattern as before
    divisor += __shfl_xor_sync(0xffffffff, divisor, 16, 32);
    divisor += __shfl_xor_sync(0xffffffff, divisor, 8, 32);
    divisor += __shfl_xor_sync(0xffffffff, divisor, 4, 32);
    divisor += __shfl_xor_sync(0xffffffff, divisor, 2, 32);
    divisor += __shfl_xor_sync(0xffffffff, divisor, 1, 32);

    if (lane_id == 0)
    {
      reduction[warp_id] = divisor;
    }

    __syncthreads();
    if (warp_id == 0)
    {
        divisor = ty < BLOCK_DIM_Y/32 ? reduction[ty] : 0;
        divisor += __shfl_xor_sync(0xffffffff, divisor, 16, 32);
        divisor += __shfl_xor_sync(0xffffffff, divisor, 8, 32);
        divisor += __shfl_xor_sync(0xffffffff, divisor, 4, 32);
        divisor += __shfl_xor_sync(0xffffffff, divisor, 2, 32);
        divisor += __shfl_xor_sync(0xffffffff, divisor, 1, 32);
    }

    if (ty == 0)
    {
        reduction[0] = divisor;
    }

    __syncthreads();
    divisor = reduction[0];

Normalizing the output

Let’s finish with the normalization pass, where we divide the values cached in reg_array by the divisor and write the results back to global memory.

    reg_array_idx = 0;
#pragma unroll URF
    for (int i = ty; i<WIDTH/4; i+=BLOCK_DIM_Y)
    {
        float4 val = reg_array[reg_array_idx];
        // Divide each of the four cached exponentials by the divisor
        val.x = val.x/divisor;
        val.y = val.y/divisor;
        val.z = val.z/divisor;
        val.w = val.w/divisor;

        // Write the normalized float4 back to the output row :D
        reinterpret_cast<float4*>(&b[row*WIDTH + i*4])[0] = val;
        reg_array_idx+=1;
    }

  }
}
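
For completeness, a minimal host-side launch consistent with the indexing above (blockIdx.x selects the row, threadIdx.y is the per-row worker index) might look like the sketch below; d_a and d_b are placeholder device buffers of h * WIDTH floats, allocation and error checking are omitted, and BLOCK_DIM_Y should be a multiple of 32 so the full-mask shuffles are valid:

  dim3 block(1, BLOCK_DIM_Y);                          // the kernel indexes its threads with threadIdx.y
  dim3 grid(h);                                        // one block per row, blockIdx.x == row
  SoftMaxKernel<<<grid, block>>>(d_a, d_b, WIDTH, h);  // WIDTH must match the compile-time row length
  cudaDeviceSynchronize();                             // wait for the kernel before reading d_b back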

Recap Of What We Learned

float4: Optimized Data Representation

float4 allows processing four floating-point values as a single unit. This vectorized format enhances memory coalescing and leverages GPU SIMD capabilities for faster computation.
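
As a standalone illustration (a and idx are placeholders, not names from the kernel), the vectorized load simply reinterprets four consecutive floats as one 16-byte value; the address must be 16-byte aligned for this to be valid:

  // Scalar: four separate 4-byte loads
  float x0 = a[idx], x1 = a[idx + 1], x2 = a[idx + 2], x3 = a[idx + 3];

  // Vectorized: one 16-byte load of the same data (requires &a[idx] to be 16-byte aligned)
  float4 v = reinterpret_cast<float4*>(&a[idx])[0];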

Memory Coalescing

Ensures threads in a warp access contiguous memory locations in a single transaction. It minimizes latency and maximizes GPU memory bandwidth efficiency.
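
That is exactly why the loops in the kernel stride by BLOCK_DIM_Y instead of handing each thread its own contiguous chunk. A stripped-down sketch of the pattern, reusing the kernel's names:

  // Coalesced: at each iteration the 32 threads of a warp read 32 consecutive
  // float4 elements, i.e. one contiguous 512-byte span of the row
  float sum = 0.f;
  for (int i = ty; i < WIDTH/4; i += BLOCK_DIM_Y)
  {
      float4 v = reinterpret_cast<float4*>(&a[row*WIDTH + i*4])[0];
      sum += v.x + v.y + v.z + v.w;
  }
  // For contrast, a per-thread chunked loop (i = ty*CHUNK .. (ty+1)*CHUNK) would put
  // neighbouring threads far apart in memory on every load, splitting each warp access
  // into many separate transactions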

Warp-Level Primitives

Primitives like __shfl_xor_sync enable direct communication between threads in a warp, bypassing shared memory and improving reduction and aggregation performance.
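
The divisor uses the same primitive with addition instead of fmaxf; as a reusable sketch:

__device__ float warp_reduce_sum(float v)   // hypothetical helper, mirrors the shuffle lines in the kernel
{
    for (int offset = 16; offset > 0; offset /= 2)
        v += __shfl_xor_sync(0xffffffff, v, offset, 32);
    return v;   // every lane ends up with the sum over all 32 lanes
}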

Register-Level Reduction

Reduction operations are performed directly in registers, avoiding shared memory latency. This approach leverages warp-level primitives for efficient summation or maximization.
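
For contrast, here is the classic shared-memory tree reduction that the shuffle-based version replaces; a sketch, where tid stands for the block's linear thread index (threadIdx.y in this kernel) and BLOCK_DIM_Y is assumed to be a power of two:

  __shared__ float smem[BLOCK_DIM_Y];
  smem[tid] = val;
  __syncthreads();
  for (int stride = BLOCK_DIM_Y/2; stride > 0; stride /= 2)
  {
      if (tid < stride)
          smem[tid] = fmaxf(smem[tid], smem[tid + stride]);
      __syncthreads();
  }
  // smem[0] now holds the block-wide maximum; every step round-trips through shared memory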

Online Softmax

The kernel shown here is the numerically safe three-pass formulation (row max, sum of exponentials, normalize) with the intermediates cached in registers rather than re-read from global memory. The related online-softmax trick goes one step further: it fuses the max and sum passes by keeping a running maximum and rescaling the running sum whenever that maximum grows.
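
A scalar sketch of that fused recurrence (this is not the kernel above; a, row, and WIDTH are reused purely for illustration):

  float m = -INFINITY;   // running maximum seen so far
  float d = 0.f;         // running sum of exp(x - m)
  for (int i = 0; i < WIDTH; i++)
  {
      float x = a[row*WIDTH + i];
      float m_new = fmaxf(m, x);
      d = d * __expf(m - m_new) + __expf(x - m_new);   // rescale the old sum, then add the new term
      m = m_new;
  }
  // a second pass then writes b[i] = __expf(x[i] - m) / d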

References

https://github.com/SzymonOzog/FastSoftmax

https://www.youtube.com/watch?v=IpHjDoW4ffw

https://github.com/karpathy/llm.c/blob/7ecd8906afe6ed7a2b2cdb731c042f26d525b820/dev/cuda/softmax_forward.cu

https://github.com/facebookincubator/AITemplate/wiki/How-to-write-a-fast-Softmax-CUDA-kernel%3F

Thanks for reading!