Ragged Paged Attention: A High-Performance and Flexible LLM Inference Kernel for TPU
Researchers introduced Ragged Paged Attention (RPA), a specialized inference kernel for Google's TPUs that enables efficient large language model serving. RPA addresses the GPU-centric design of existing LLM serving systems with fine-grained tiling and custom software pipelines, achieving up to 86% memory bandwidth utilization on TPU hardware.
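To make the "ragged" execution pattern concrete, here is a minimal sketch of how a mixed batch of long prefill and single-token decode requests can be flattened without padding. The variable names (`seq_lens`, `cu_seqlens`) follow a common varlen-attention convention and are illustrative, not RPA's actual interface:

```python
import numpy as np

# A "ragged" batch: sequences of very different lengths served together
# (e.g. long prefills mixed with single-token decodes). Instead of padding
# every sequence to the max length, tokens are flattened into one array and
# tracked with cumulative per-sequence offsets -- the layout that tiled
# attention kernels can iterate over efficiently.
seq_lens = np.array([7, 1, 3, 1])                        # prefill, decode, prefill, decode
cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)])  # [0, 7, 8, 11, 12]

total_tokens = int(cu_seqlens[-1])               # 12 tokens, zero padding
padded_tokens = len(seq_lens) * int(seq_lens.max())  # 28 tokens if padded to max

print(total_tokens, padded_tokens)  # 12 28
```

Even in this tiny example, the padded layout wastes more than half its compute and memory traffic, which is the inefficiency ragged layouts avoid.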
The shift toward TPU-based LLM inference represents a significant infrastructure evolution driven by cost optimization imperatives in the AI industry. While most LLM serving frameworks remain GPU-optimized, RPA bridges a critical technical gap by tailoring attention mechanisms—the computationally intensive core of transformer models—specifically for TPU architectures. This matters because TPUs offer superior total cost of ownership compared to GPUs for large-scale inference workloads, but their distinct memory hierarchies and architectural constraints had previously limited adoption.
The paper's contribution extends beyond a single optimization technique. The three-pronged approach—dynamic memory slicing through fine-grained tiling, fused KV cache updates, and workload-aware kernel compilation—demonstrates sophisticated systems thinking around production inference challenges. By achieving 73% model FLOPs utilization in prefill operations and 86% memory bandwidth utilization in decode phases on Llama 3 8B, RPA validates that TPUs can match or exceed GPU performance when properly optimized.
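The paged KV cache that these optimizations build on can be sketched in a few lines. The following is a simplified illustration under assumed names (`PAGE_SIZE`, `kv_pages`, `page_tables` are hypothetical, not RPA's actual data structures): each sequence maps its token positions through a small page table into a shared pool of fixed-size pages, and new tokens are appended in place, mirroring the fused KV cache update idea:

```python
import numpy as np

# Illustrative paged KV cache, not RPA's real implementation.
PAGE_SIZE = 4    # tokens per page (real kernels use larger pages)
HEAD_DIM = 8
NUM_PAGES = 16

# Physical KV store: a flat pool of fixed-size pages shared by all sequences.
kv_pages = np.zeros((NUM_PAGES, PAGE_SIZE, HEAD_DIM), dtype=np.float32)

# Per-sequence page tables map logical token positions to physical pages,
# so ragged-length sequences share one pool without padding.
page_tables = {0: [3, 7], 1: [5]}   # seq 0 owns 2 pages, seq 1 owns 1
seq_lens = {0: 6, 1: 3}             # actual token counts per sequence

def append_kv(seq_id, kv_vec):
    """Append a new token's KV vector into its page slot (fused-update style)."""
    pos = seq_lens[seq_id]
    page = page_tables[seq_id][pos // PAGE_SIZE]
    kv_pages[page, pos % PAGE_SIZE] = kv_vec
    seq_lens[seq_id] = pos + 1

def gather_kv(seq_id):
    """Gather a sequence's KV entries contiguously for the attention kernel."""
    pages = page_tables[seq_id]
    flat = np.concatenate([kv_pages[p] for p in pages], axis=0)
    return flat[: seq_lens[seq_id]]

append_kv(1, np.ones(HEAD_DIM, dtype=np.float32))
print(gather_kv(1).shape)  # (4, 8): 3 existing slots plus the appended token
```

In a production kernel the gather is not materialized as a copy; the kernel tiles directly over the page table, which is where fine-grained tiling and workload-aware compilation earn their utilization numbers.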
The integration into vLLM and SGLang, two widely adopted open-source LLM serving frameworks, provides immediate practical impact for enterprises deploying at scale. This standardization around TPU-optimized inference reduces switching costs and encourages broader adoption of Google's accelerator infrastructure, potentially reshaping vendor competition in the inference market. For cloud providers and enterprises, the kernel design establishes a clearer economic case for TPU deployments, particularly for cost-sensitive inference operations where margin compression pressures existing GPU-dominant architectures.
- RPA achieves production-grade performance on TPUs with 86% memory bandwidth utilization in decode operations.
- Fine-grained tiling and fused KV cache updates enable efficient handling of ragged execution patterns in dynamic LLM serving.
- Integration into vLLM and SGLang standardizes TPU inference optimization across major open-source serving frameworks.
- The kernel design demonstrates that TPUs can compete with GPUs on inference performance when architecture-specific optimizations are applied.
- TPU-optimized inference reduces total cost of ownership for large-scale LLM deployments compared to GPU alternatives.