In this blog post, I will break down the forward pass of FlashAttention and explain how it works step by step. Let’s start from the beginning. Paper: https://arxiv.org/pdf/2205.14135
First, we define a kernel template to allow flexibility with different data types:
template <typename T>
__global__ void flashKernel(const T *Q, // Pointer to Q tensor
const T *K, // Pointer to K tensor
const T *V, // Pointer to V tensor
T *O, // Pointer to the output tensor
T *m, // Pointer to the m tensor for maximum values
T *l, // Pointer to the l tensor for row sums
const int seq_len, // Sequence length for QKV
const int head_dim, // Dimension of the head
int Tc, int Tr, int Bc, int Br) {
// Tc and Tr: Number of tiles along the key/value (column) and query (row) dimensions
// Bc and Br: Number of rows per key/value tile and per query tile
Now, let’s discuss how the thread and block indices are calculated: each block processes one (batch, head) pair, and each thread handles one row of the current tile. For simplicity, I assume you’re already familiar with CUDA concepts (a dedicated tutorial will follow).
int threadIndex = threadIdx.x; // Each thread handles one row of the current tile
int BatchIndex = blockIdx.x; // Batch index (grid x dimension)
int HeadIndex = blockIdx.y; // Head index
int NrHeads = gridDim.y; // Total number of heads
Next, we calculate the offsets for Q, K, V, and the auxiliary tensors (l and m):
int qkvOffset = (BatchIndex * NrHeads * seq_len * head_dim) +
(HeadIndex * seq_len * head_dim);
// The Q, K, and V tensors have shape: [BatchSize, NrHeads, SeqLen, HeadDim]
// To jump to the next batch: NrHeads * SeqLen * HeadDim elements
// To jump to the next head: SeqLen * HeadDim elements
int lmOffset = (BatchIndex * NrHeads * seq_len) + (HeadIndex * seq_len);
// Offsets for auxiliary tensors (l and m) with shape: [BatchSize, NrHeads, SeqLen]
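For example, with BatchSize = 2, NrHeads = 8, SeqLen = 1024, and HeadDim = 64 (illustrative values), batch 1 / head 3 starts at element (1 * 8 + 3) * 1024 * 64 = 720,896 in Q, K, and V, and at element (1 * 8 + 3) * 1024 = 11,264 in l and m.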
The key idea is to keep one tile of each of Q, K, and V (Bc x HeadDim elements apiece), plus the tile of scores, in shared memory so they can be reused without repeatedly reading global memory:
extern __shared__ float sharedMemory[]; // Dynamic shared memory buffer
int TileSize = Bc * head_dim; // Elements per Q/K/V tile
T *Qi = reinterpret_cast<T *>(sharedMemory); // Q tile: Br x head_dim
T *Ki = Qi + TileSize; // K tile: Bc x head_dim
T *Vi = Ki + TileSize; // V tile: Bc x head_dim
T *Si = Vi + TileSize; // Score tile: Br x Bc
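With Bc = Br = 32 and head_dim = 64 in float32, for example, the three Q/K/V tiles plus the score tile occupy (3 * 32 * 64 + 32 * 32) * 4 bytes = 28 KB of shared memory, comfortably below the default 48 KB per-block limit.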
The main computation is a loop over the Tc tiles (blocks of Bc rows) of K and V:
for (int j = 0; j < Tc; ++j) {
// Load tiles of K and V into shared memory
for (int aux = 0; aux < head_dim; ++aux) {
Ki[threadIndex * head_dim + aux] =
K[qkvOffset + j * TileSize + threadIndex * head_dim + aux];
Vi[threadIndex * head_dim + aux] =
V[qkvOffset + j * TileSize + threadIndex * head_dim + aux];
}
__syncthreads(); // Wait until the K and V tiles are fully loaded
Next, we iterate over the Tr tiles of Q:
for (int i = 0; i < Tr; ++i) {
// Load queries into shared memory
for (int aux = 0; aux < head_dim; ++aux) {
Qi[threadIndex * head_dim + aux] =
Q[qkvOffset + i * TileSize + threadIndex * head_dim + aux];
}
// Load previous max and sum values
float rowPrevMax = m[lmOffset + i * Br + threadIndex];
float rowPrevSum = l[lmOffset + i * Br + threadIndex];
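(Before the kernel is launched, m is assumed to be initialised to -INFINITY and l and O to zero, so the very first tile behaves as if no previous statistics existed.)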
Next, we compute Sij, the attention scores, as the scaled dot product between the rows of the Q tile and the rows of the K tile:
float rowMax = -INFINITY;
for (int y = 0; y < Bc; ++y) {
float sum = 0;
// The Dot-Product between each Q and K
for (int x = 0; x < head_dim; ++x) {
sum += Qi[threadIndex * head_dim + x] * Ki[y * head_dim + x];
}
sum *= rsqrtf((float)head_dim); // Scale by 1/√(head_dim)
Si[Bc * threadIndex + y] = sum;
rowMax = fmaxf(rowMax, sum); // Finding the Maximum Value
}
Perform the online softmax operation:
float rowSum = 0;
for (int aux = 0; aux < Bc; ++aux) {
Si[threadIndex * Bc + aux] =
expf(Si[threadIndex * Bc + aux] - rowMax);
rowSum += Si[threadIndex * Bc + aux];
}
float newMax = fmaxf(rowPrevMax, rowMax); // Update the running maximum
float newSum = rowPrevSum * expf(rowPrevMax - newMax) +
rowSum * expf(rowMax - newMax); // Update the running row sum
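To spell out the update: rowPrevMax and rowPrevSum summarise all previously processed tiles, while rowMax and rowSum describe the current tile only. Merging them gives
newMax = max(rowPrevMax, rowMax)
newSum = rowPrevSum * e^(rowPrevMax - newMax) + rowSum * e^(rowMax - newMax)
and the output row must be rescaled with the same correction factors: the previously accumulated output is multiplied by rowPrevSum * e^(rowPrevMax - newMax), the current tile’s contribution Sij · V is multiplied by e^(rowMax - newMax), and the result is divided by newSum. This is exactly the update performed below.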
Finally, update the output tensor and auxiliary values:
for (int aux = 0; aux < head_dim; ++aux) {
float value = 0.0f;
// The dot product between the Sij row and column aux of V
for (int y = 0; y < Bc; ++y) {
value += Si[threadIndex * Bc + y] * Vi[y * head_dim + aux];
}
// Write the output: rescale the previously accumulated value and add this tile's contribution
int outIndex = qkvOffset + (TileSize * i) + (threadIndex * head_dim) + aux;
O[outIndex] = (1 / newSum) *
((rowPrevSum * expf(rowPrevMax - newMax) * O[outIndex]) +
(expf(rowMax - newMax) * value));
}
l[lmOffset + i * Br + threadIndex] = newSum;
m[lmOffset + i * Br + threadIndex] = newMax;
}
__syncthreads(); // Make sure all threads are done with Ki and Vi before the next tile is loaded
}
}
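To see how all of this is driven from the host, here is a minimal, hypothetical launcher. It is not part of the kernel above: the function name launchFlashKernel and the parameter names batchSize, nrHeads, seqLen, and headDim are my own, and it assumes float tensors already resident on the device, with m initialised to -INFINITY and l and O zero-initialised before the call.

#include <cuda_runtime.h>

void launchFlashKernel(const float *Q, const float *K, const float *V,
                       float *O, float *m, float *l,
                       int batchSize, int nrHeads,
                       int seqLen, int headDim) {
  const int Bc = 32, Br = 32;             // Tile sizes; the kernel assumes Bc == Br
  const int Tc = (seqLen + Bc - 1) / Bc;  // Number of K/V tiles
  const int Tr = (seqLen + Br - 1) / Br;  // Number of Q tiles
  // Shared memory: Qi, Ki, Vi (Bc * headDim floats each) plus Si (Br * Bc floats)
  const size_t sharedBytes = (3 * Bc * headDim + Br * Bc) * sizeof(float);
  dim3 grid(batchSize, nrHeads);          // One block per (batch, head) pair
  dim3 block(Bc);                         // One thread per tile row
  flashKernel<float><<<grid, block, sharedBytes>>>(
      Q, K, V, O, m, l, seqLen, headDim, Tc, Tr, Bc, Br);
}

The shared-memory size passed here has to match the layout we set up inside the kernel; if Bc * headDim grows too large for the available shared memory, the tile sizes have to shrink accordingly.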