Hey everyone! In this tutorial, I’ll walk you through writing and optimizing a RoPE (Rotary Position Embedding) kernel in CUDA that significantly outperforms a naive PyTorch implementation. If you’re working with long sequences, this optimization can make a big difference in performance.
Let’s dive in!
First, we define our CUDA kernel:
__global__ void rope_kernel(float* x, float* out, int N){
- x: Input tensor (float array)
- out: Output tensor (float array)
- N: The working size, which is the hidden dimension divided by 4 (since we load values in groups of four with float4 for efficiency)
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float4 x_v = reinterpret_cast<float4*>(&(x[idx * 4]))[0];
Each thread computes its global index and loads four consecutive float values at once as a float4, reducing the number of memory transactions and improving throughput.
int token_pos = idx / N;
int token_idx = idx % N;
- token_pos: The position of the token in the sequence
- token_idx: The index within the token’s hidden dimension (in float4 units)

float exp1 = 1.0f / powf(theta, (float)(token_idx * 2) / (N * 4));     // float cast: plain integer division would truncate the exponent to 0
float exp2 = 1.0f / powf(theta, (float)(token_idx * 2 + 1) / (N * 4));
These values determine the scaling factors for rotation, ensuring correct positional encoding.
float sin1 = sinf(token_pos / exp1);
float cos1 = cosf(token_pos / exp1);
float sin2 = sinf(token_pos / exp2);
float cos2 = cosf(token_pos / exp2);
Using these precomputed sine and cosine values, we can efficiently rotate embeddings.
float4 out_v;
out_v.x = x_v.x * cos1 - x_v.y * sin1;
out_v.y = x_v.x * sin1 + x_v.y * cos1;
out_v.z = x_v.z * cos2 - x_v.w * sin2;
out_v.w = x_v.z * sin2 + x_v.w * cos2;
reinterpret_cast<float4*>(&(out[idx * 4]))[0] = out_v;
}
We apply RoPE transformations pairwise and store the results back into the output tensor.
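To make the per-thread arithmetic concrete, here is a deliberately slow host-side sketch in Python that mirrors the kernel’s indexing and rotation for a [seq_len, hidden_size] float32 tensor. The helper name rope_reference and the double loop are purely illustrative; this is only meant for spot-checking the CUDA output, not for speed.

import math
import torch

def rope_reference(x: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # Mirrors rope_kernel: each (token_pos, token_idx) pair covers four consecutive floats.
    seq_len, hidden_size = x.shape
    N = hidden_size // 4
    out = torch.empty_like(x)
    for token_pos in range(seq_len):
        for token_idx in range(N):
            x0, x1, x2, x3 = x[token_pos, token_idx * 4: token_idx * 4 + 4].tolist()
            exp1 = 1.0 / theta ** (token_idx * 2 / (N * 4))
            exp2 = 1.0 / theta ** ((token_idx * 2 + 1) / (N * 4))
            a1, a2 = token_pos / exp1, token_pos / exp2
            out[token_pos, token_idx * 4 + 0] = x0 * math.cos(a1) - x1 * math.sin(a1)
            out[token_pos, token_idx * 4 + 1] = x0 * math.sin(a1) + x1 * math.cos(a1)
            out[token_pos, token_idx * 4 + 2] = x2 * math.cos(a2) - x3 * math.sin(a2)
            out[token_pos, token_idx * 4 + 3] = x2 * math.sin(a2) + x3 * math.cos(a2)
    return out

Comparing its output (on a CPU copy of x) against the kernel’s result with torch.allclose is a quick correctness check.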
Next, we bind our CUDA kernel to PyTorch.
void rope(torch::Tensor x, torch::Tensor out) {
int seq_len = x.size(0);
int hidden_size = x.size(1);
- seq_len: Number of tokens in the sequence
- hidden_size: Number of features per token
int N = (int)(hidden_size / 4);
Since we process four values at a time using float4, we divide hidden_size by 4.
dim3 grid((seq_len * N + BLOCK_SIZE - 1) / BLOCK_SIZE);
dim3 block(BLOCK_SIZE);
rope_kernel<<<grid, block>>>(x.data_ptr<float>(), out.data_ptr<float>(), N);
}
This sets up our CUDA grid and calls the kernel with the appropriate tensor pointers.
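As a quick sanity check on the launch configuration, assume an illustrative shape of seq_len=4096 and hidden_size=512 (these numbers are only an example):

BLOCK_SIZE = 256
seq_len, hidden_size = 4096, 512
N = hidden_size // 4                                    # 128 float4 elements per token
total_threads = seq_len * N                             # 524288 threads, one per float4
grid_blocks = (total_threads + BLOCK_SIZE - 1) // BLOCK_SIZE
print(grid_blocks)                                      # 2048 blocks of 256 threads each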
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(rope)
}
The supporting macros used above are defined near the top of the file:

#define BLOCK_SIZE 256
#define theta 10000.0f
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
  m.def(STRINGFY(func), &func, STRINGFY(func));
TORCH_BINDING_COMMON_EXTENSION(rope) expands to m.def("rope", &rope, "rope"), registering the function with pybind11 so PyTorch can call it from Python.
On the Python side, we JIT-compile the extension with torch.utils.cpp_extension.load:

from torch.utils.cpp_extension import load

lib = load(
    name="rope",
    sources=["rope.cu"],
    extra_cuda_cflags=[
        "-O3",
        "--use_fast_math",
    ],
    extra_cflags=["-std=c++17"],
)
Compiling with -O3 and --use_fast_math enables aggressive compiler optimizations and fast approximate math intrinsics (used here for sinf, cosf, and powf).
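Once compiled, lib.rope can be called like any other CUDA op. A minimal usage sketch, with illustrative shapes:

import torch

x = torch.randn(4096, 512, device="cuda", dtype=torch.float32)  # [seq_len, hidden_size]
out = torch.empty_like(x)
lib.rope(x, out)  # writes the rotated embeddings into `out`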
Here are some benchmark results comparing a naive implementation to our optimized CUDA kernel (M is the sequence length and N the hidden size):
Testing M=4096, N=512
Naive: 2.8655ms, CUDA f32: 0.0244ms
------------------------------------------------------------
Testing M=4096, N=1024
Naive: 0.4080ms, CUDA f32: 0.0508ms
------------------------------------------------------------
Testing M=8192, N=512
Naive: 0.3888ms, CUDA f32: 0.0509ms
------------------------------------------------------------
Testing M=8192, N=1024
Naive: 0.9925ms, CUDA f32: 0.3385ms
------------------------------------------------------------
The speedup is clear: depending on the shape, our CUDA kernel runs roughly 3x to over 100x faster than the naive implementation.
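Per-call timings like these are typically collected with CUDA events after a warm-up. Here is a minimal sketch of one way to measure them, reusing lib, x, and out from the usage example above (the iteration counts are arbitrary):

def bench(fn, iters=1000, warmup=10):
    # Warm up, then time with CUDA events so asynchronous kernel launches are measured correctly.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call

print(f"CUDA f32: {bench(lambda: lib.rope(x, out)):.4f}ms")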
By leveraging CUDA’s parallel processing and float4 vectorized memory access, we’ve drastically improved RoPE computation speed. This is a great example of how a small hand-written kernel can unlock major efficiency gains in deep learning workloads.
Try it out and let me know if you have any questions. Happy coding!