Understanding the Forward Pass of FlashAttention

In this blog post, I break down the forward pass of FlashAttention and explain how it works step by step. Let’s start from the beginning. Link: https://arxiv.org/pdf/2501.12948

Setting Up the Kernel Template

First, we define a kernel template to allow flexibility with different data types:

template <typename T>
__global__ void flashKernel(const T *Q,         // Pointer to Q tensor
                            const T *K,         // Pointer to K tensor
                            const T *V,         // Pointer to V tensor
                            T *O,               // Pointer to the output tensor
                            T *m,               // Pointer to the m tensor for maximum values 
                            T *l,               // Pointer to the l tensor for row sums
                            const int seq_len,  // Sequence length for QKV
                            const int head_dim, // Dimension of the head
                            int Tc, int Tr, int Bc, int Br) {
    // Tc and Tr: number of K/V tiles and number of Q tiles
    // Bc and Br: rows per K/V tile and rows per Q tile

Determining Indexes

Now, let’s discuss how the thread and block indexes are calculated. For simplicity, I assume you’re already familiar with CUDA concepts (a dedicated tutorial will follow).

    int threadIndex = threadIdx.x;     // Thread index for rows
    int BatchIndex = blockIdx.x;       // Batch index (row of the grid)
    int HeadIndex = blockIdx.y;        // Head index
    int NrHeads = gridDim.y;           // Total number of heads

Offsets for QKV and LM

Next, we calculate the offsets for Q, K, V, and the auxiliary tensors (l and m):

    int qkvOffset = (BatchIndex * NrHeads * seq_len * head_dim) + 
                    (HeadIndex * seq_len * head_dim);
    // The QKV tensor has shape: [BatchSize, NrHeads, SeqLen, HeadDim]
    // To jump between batches: NrHeads * SeqLen * HeadDim
    // To jump between heads:   SeqLen * HeadDim

    int lmOffset = (BatchIndex * NrHeads * seq_len) + (HeadIndex * seq_len);
    // Offsets for auxiliary tensors (l and m) with shape: [BatchSize, NrHeads, SeqLen]
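
To make the offset arithmetic concrete, here is a small host-side C++ sketch of the same calculation. The helper names qkvOffsetFor and lmOffsetFor are hypothetical (they are not part of the kernel), and the sketch assumes the tensors are contiguous in the layouts given in the comments above.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not part of the kernel): index of the first element of
// the [SeqLen, HeadDim] slice that belongs to one (batch, head) pair in a
// contiguous [BatchSize, NrHeads, SeqLen, HeadDim] tensor.
std::size_t qkvOffsetFor(int batch, int head, int nrHeads, int seqLen, int headDim) {
    assert(batch >= 0 && head >= 0 && head < nrHeads);
    std::size_t headStride = static_cast<std::size_t>(seqLen) * headDim;  // elements per head
    std::size_t batchStride = static_cast<std::size_t>(nrHeads) * headStride;  // elements per batch
    return batch * batchStride + head * headStride;
}

// Hypothetical helper: same idea for the [BatchSize, NrHeads, SeqLen]
// auxiliary tensors l and m, which hold one value per query row.
std::size_t lmOffsetFor(int batch, int head, int nrHeads, int seqLen) {
    assert(batch >= 0 && head >= 0 && head < nrHeads);
    return (static_cast<std::size_t>(batch) * nrHeads + head) * seqLen;
}
```

For example, with 4 heads, a sequence length of 8, and a head dimension of 16, batch 1 and head 2 start at 1 * 512 + 2 * 128 = 768 in Q, K, V, and O, and at (1 * 4 + 2) * 8 = 48 in l and m.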

Utilizing Shared Memory

The key idea is to store a tile of size Bc x HeadDim in shared memory for efficient computation:

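    // Dynamic shared memory is declared as raw bytes and cast to T, so the
    // template also compiles when T is not float.
    extern __shared__ __align__(16) unsigned char sharedMemory[];
    int TileSize = Bc * head_dim;
    T *Qi = reinterpret_cast<T *>(sharedMemory); // Q tile (assumes Br == Bc)
    T *Ki = Qi + TileSize;                       // K tile
    T *Vi = Ki + TileSize;                       // V tile
    T *Si = Vi + TileSize;                       // Scores, one row per thread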
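
The layout above fixes how much dynamic shared memory the launch has to request. As a rough host-side sketch, the helpers below compute that size and the tile counts. Both names (sharedBytesFor, numTiles) are hypothetical, and the sketch assumes Br == Bc and a Br x Bc score buffer for Si, which the post implies but does not state outright.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: bytes of dynamic shared memory for three
// Bc x headDim tiles (Qi, Ki, Vi) plus a Br x Bc score buffer (Si).
// Assumes Br == Bc, so the Q tile has the same size as the K and V tiles.
template <typename T>
std::size_t sharedBytesFor(int Bc, int Br, int headDim) {
    assert(Bc > 0 && Bc == Br && headDim > 0);
    std::size_t tile = static_cast<std::size_t>(Bc) * headDim;
    std::size_t scores = static_cast<std::size_t>(Br) * Bc;
    return (3 * tile + scores) * sizeof(T);
}

// Hypothetical helper: number of tiles of tileRows rows needed to cover
// seqLen rows (rounded up). This is how Tc and Tr would be derived.
int numTiles(int seqLen, int tileRows) {
    assert(seqLen > 0 && tileRows > 0);
    return (seqLen + tileRows - 1) / tileRows;
}
```

With Bc = Br = 32 and head_dim = 64 in float, this gives (3 * 2048 + 1024) * 4 = 28672 bytes, which would be passed as the third launch parameter. Note that the kernel as written has no bounds checks, so it also assumes seq_len is a multiple of the tile size.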

Main Computation Loop

The main computation is performed in a loop that iterates over the Tc tiles of K and V:

    for (int j = 0; j < Tc; ++j) {
        // Load tiles of K and V into shared memory (one row per thread)
        for (int aux = 0; aux < head_dim; ++aux) {
            Ki[threadIndex * head_dim + aux] = 
                K[qkvOffset + j * TileSize + threadIndex * head_dim + aux];
            Vi[threadIndex * head_dim + aux] = 
                V[qkvOffset + j * TileSize + threadIndex * head_dim + aux];
        }
        __syncthreads(); // Wait until the whole K/V tile is loaded

Next, iterate over queries:

        for (int i = 0; i < Tr; ++i) {
            // Load queries into shared memory (one row per thread)
            for (int aux = 0; aux < head_dim; ++aux) {
                Qi[threadIndex * head_dim + aux] = 
                    Q[qkvOffset + i * TileSize + threadIndex * head_dim + aux];
            }

            // Load previous max and sum values
            // (m must start at -INFINITY and l at 0 before the first tile)
            float rowPrevMax = m[lmOffset + i * Br + threadIndex];
            float rowPrevSum = l[lmOffset + i * Br + threadIndex];

Next, compute Sij, the similarity scores, using the dot product of Q and K:

            float rowMax = -INFINITY;
            for (int y = 0; y < Bc; ++y) {
                float sum = 0;
                // Dot product between this thread's query row and key row y
                for (int x = 0; x < head_dim; ++x) {
                    sum += Qi[threadIndex * head_dim + x] * Ki[y * head_dim + x];
                }
                sum *= rsqrtf(head_dim);  // Scale by 1/√(head_dim)
                Si[threadIndex * Bc + y] = sum;
                rowMax = fmaxf(rowMax, sum); // Track the row maximum
            }

Applying Online Softmax

Perform the online softmax operation:

            float rowSum = 0;
            for (int aux = 0; aux < Bc; ++aux) {
                Si[threadIndex * Bc + aux] = 
                    expf(Si[threadIndex * Bc + aux] - rowMax);
                rowSum += Si[threadIndex * Bc + aux];
            }

            float newMax = fmaxf(rowPrevMax, rowMax); // Updating the new Maximum
            float newSum = rowPrevSum * expf(rowPrevMax - newMax) + 
                           rowSum * expf(rowMax - newMax);// Updating the new SumRow
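
To see why rescaling by expf(rowPrevMax - newMax) and expf(rowMax - newMax) gives the right denominator, here is a minimal host-side C++ sketch that folds blocks of scores into running statistics for a single row. RowStats and foldBlock are hypothetical names used only for this illustration. The sketch mirrors the rowMax, rowSum, newMax, and newSum steps, but it is not the kernel itself.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Running softmax statistics for one row: the maximum seen so far and the
// sum of exp(score - max) over every score seen so far.
struct RowStats {
    float max;
    float sum;
};

// Fold one block of scores into the running statistics.
// Start from {-INFINITY, 0.0f} before the first block.
RowStats foldBlock(RowStats prev, const std::vector<float>& scores) {
    assert(!scores.empty());
    float blockMax = *std::max_element(scores.begin(), scores.end());
    float blockSum = 0.0f;
    for (float s : scores) {
        blockSum += std::exp(s - blockMax);
    }
    float newMax = std::max(prev.max, blockMax);
    // Both partial sums are re-expressed relative to the new maximum.
    float newSum = prev.sum * std::exp(prev.max - newMax) +
                   blockSum * std::exp(blockMax - newMax);
    return {newMax, newSum};
}
```

Folding {1, 2} and then {3, 0.5} ends with the maximum 3 and the same sum as computing exp(s - 3) over all four scores at once, which is exactly the softmax denominator for the full row.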

Writing Outputs

Finally, update the output tensor and auxiliary values:

            for (int aux = 0; aux < head_dim; ++aux) {
                float value = 0.0f;
                // The product between Sij and V
                for (int y = 0; y < Bc; ++y) {
                    value += Si[threadIndex * Bc + y] * Vi[y * head_dim + aux];
                }
                // Rescale the previous output and add this tile's contribution
                int outIndex = qkvOffset + (TileSize * i) + 
                               (threadIndex * head_dim) + aux;
                O[outIndex] = 
                    (1 / newSum) * ((rowPrevSum * expf(rowPrevMax - newMax) * O[outIndex]) +
                                    (expf(rowMax - newMax) * value));
            }

            l[lmOffset + i * Br + threadIndex] = newSum;
            m[lmOffset + i * Br + threadIndex] = newMax;
        }
        __syncthreads(); // Finish this K/V tile before the next one is loaded
    }
}
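
A kernel like this is easy to get subtly wrong, so it helps to compare its output against a naive host-side implementation. The sketch below is one way to do that for a single (batch, head) slice. The function name referenceAttention is hypothetical, and it assumes row-major [SeqLen, HeadDim] inputs in float with no masking or dropout.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive attention for one (batch, head) slice:
// O = softmax(Q * K^T / sqrt(headDim)) * V, with Q, K, V, and O stored
// row-major as [seqLen, headDim].
std::vector<float> referenceAttention(const std::vector<float>& Q,
                                      const std::vector<float>& K,
                                      const std::vector<float>& V,
                                      int seqLen, int headDim) {
    const std::size_t n = static_cast<std::size_t>(seqLen) * headDim;
    assert(Q.size() == n && K.size() == n && V.size() == n);
    std::vector<float> O(n, 0.0f);
    std::vector<float> scores(seqLen);
    const float scale = 1.0f / std::sqrt(static_cast<float>(headDim));
    for (int i = 0; i < seqLen; ++i) {
        // Scaled dot products of query row i with every key row.
        float rowMax = -INFINITY;
        for (int j = 0; j < seqLen; ++j) {
            float dot = 0.0f;
            for (int x = 0; x < headDim; ++x) {
                dot += Q[i * headDim + x] * K[j * headDim + x];
            }
            scores[j] = dot * scale;
            rowMax = std::max(rowMax, scores[j]);
        }
        // Numerically stable softmax over the full row.
        float rowSum = 0.0f;
        for (int j = 0; j < seqLen; ++j) {
            scores[j] = std::exp(scores[j] - rowMax);
            rowSum += scores[j];
        }
        // Weighted sum of the value rows.
        for (int j = 0; j < seqLen; ++j) {
            const float w = scores[j] / rowSum;
            for (int x = 0; x < headDim; ++x) {
                O[i * headDim + x] += w * V[j * headDim + x];
            }
        }
    }
    return O;
}
```

Comparing the kernel's O against this result element by element, with a small tolerance for floating-point differences, catches most indexing and rescaling mistakes.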