Flash Attention

Automate and integrate Flash Attention for faster and more efficient transformer model training

Flash Attention is a community skill for implementing memory-efficient attention mechanisms in transformer models, covering Flash Attention algorithm integration, memory optimization, sequence length scaling, kernel configuration, and training throughput improvement for large language model development.

What Is This?

Overview

Flash Attention provides patterns for using the Flash Attention algorithm to speed up transformer attention computation while reducing memory usage. It covers algorithm integration (replacing standard attention with the tiled, IO-aware implementation in existing models), memory optimization (reducing attention memory from quadratic to linear in sequence length), sequence length scaling (training on longer contexts by eliminating the attention memory bottleneck), kernel configuration (tuning block sizes and precision for specific GPU architectures), and throughput improvement (increasing training speed by reducing memory transfers between GPU HBM and SRAM). The skill enables researchers to train transformer models faster and with longer sequences.

Who Should Use This

This skill serves machine learning engineers training large language models, researchers experimenting with long-context transformers, and MLOps engineers optimizing GPU utilization for training infrastructure.

Why Use It?

Problems It Solves

Standard attention materializes the full attention matrix, whose memory grows quadratically with sequence length and limits context size. Attention runtime is dominated by memory transfers between GPU HBM and SRAM rather than by floating-point operations. Sequence lengths that would improve model quality cannot fit in GPU memory with standard attention, and training throughput is bounded by the memory-bandwidth bottleneck in attention layers.
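To make the quadratic cost concrete, the back-of-the-envelope sketch below estimates the memory needed just to materialize the attention scores for one hypothetical configuration; the batch size, head count, and sequence length are illustrative assumptions.

# Rough estimate of standard attention's score-matrix footprint.
# All sizes below are illustrative assumptions, not requirements.
batch, heads, seq_len, bytes_per_elem = 4, 16, 8192, 2  # fp16

# Standard attention materializes one (seq_len x seq_len) matrix
# per head and per batch element.
attn_matrix_bytes = batch * heads * seq_len * seq_len * bytes_per_elem
print(f"{attn_matrix_bytes / 2**30:.1f} GiB")  # ~8.0 GiB at 8k tokens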

Core Highlights

Tiled computation processes attention in blocks that fit in GPU SRAM, avoiding materialization of the full attention matrix. IO-aware scheduling minimizes reads and writes between HBM and SRAM for each tile. Backward-pass recomputation trades a small amount of extra compute for significant memory savings. Multi-head support processes all attention heads with fused kernels.
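For intuition about the tiled computation, here is a minimal, unoptimized PyTorch sketch of the block-wise (online-softmax) idea for a single head; it is a reference illustration of the technique, not the fused CUDA kernel the library ships.

import torch

def blockwise_attention(q, k, v, block_size=128):
    """Reference tiled attention for one head: q, k, v are (seq, dim).

    Keys and values are processed in blocks while running softmax
    statistics are maintained, so the full (seq x seq) score matrix
    is never materialized.
    """
    seq, dim = q.shape
    scale = dim ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((seq, 1), float("-inf"), dtype=q.dtype, device=q.device)
    row_sum = torch.zeros((seq, 1), dtype=q.dtype, device=q.device)

    for start in range(0, seq, block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale                      # (seq, block)
        blk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(row_max, blk_max)
        # Rescale previously accumulated output and sum to the new max.
        correction = torch.exp(row_max - new_max)
        probs = torch.exp(scores - new_max)
        out = out * correction + probs @ v_blk
        row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
        row_max = new_max

    return out / row_sum

Comparing its output against torch.softmax(q @ k.T * dim ** -0.5, dim=-1) @ v is a quick sanity check that the running statistics are handled correctly.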

How to Use It?

Basic Usage

import torch
from flash_attn import flash_attn_func


class FlashAttention(torch.nn.Module):
    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.out = torch.nn.Linear(dim, dim)
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # flash_attn_func expects fp16/bf16 CUDA tensors shaped
        # (batch, seq, num_heads, head_dim).
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        out = flash_attn_func(
            q, k, v,
            dropout_p=self.dropout if self.training else 0.0,
        )
        return self.out(out.reshape(B, S, D))
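A quick usage sketch for the module above (sizes are illustrative); note that the flash_attn kernels only accept fp16 or bf16 CUDA tensors, so the module and input are cast accordingly.

attn = FlashAttention(dim=1024, num_heads=16).cuda().half()
x = torch.randn(2, 4096, 1024, device="cuda", dtype=torch.float16)
y = attn(x)  # same shape as the input: (2, 4096, 1024)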

Real-World Examples

import torch


# Reuses the FlashAttention module defined in Basic Usage above.
class TransformerBlock(torch.nn.Module):
    def __init__(self, dim: int, heads: int, ff_dim: int):
        super().__init__()
        self.attn = FlashAttention(dim, heads)
        self.norm1 = torch.nn.LayerNorm(dim)
        self.norm2 = torch.nn.LayerNorm(dim)
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(dim, ff_dim),
            torch.nn.GELU(),
            torch.nn.Linear(ff_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual block: Flash Attention, then a feed-forward MLP.
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x


# The Flash Attention kernels require fp16 or bf16 inputs on a CUDA device.
model = TransformerBlock(dim=1024, heads=16, ff_dim=4096).cuda().half()
x = torch.randn(4, 8192, 1024, device='cuda', dtype=torch.float16)
out = model(x)  # (4, 8192, 1024)

Advanced Tips

Use Flash Attention 2, which provides additional optimizations for non-contiguous key-value layouts and better parallelism across sequence length. Enable causal masking through the causal parameter rather than applying a separate mask tensor, which adds memory overhead. Combine Flash Attention with gradient checkpointing for maximum memory savings when training on very long sequences.
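A sketch of the last two tips, using flash_attn_func's causal flag and PyTorch's activation checkpointing utility; the shapes and dtypes are illustrative assumptions.

import torch
from torch.utils.checkpoint import checkpoint
from flash_attn import flash_attn_func

# Illustrative shapes: (batch, seq, heads, head_dim) in fp16 on CUDA.
q = torch.randn(2, 4096, 16, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Causal masking via the kernel flag instead of a separate mask tensor.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# Gradient checkpointing around a whole block trades recomputation for
# activation memory, which compounds with Flash Attention's savings.
block = TransformerBlock(dim=1024, heads=16, ff_dim=4096).cuda().half()
x = torch.randn(2, 4096, 1024, device="cuda", dtype=torch.float16,
                requires_grad=True)
x = checkpoint(block, x, use_reentrant=False)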

When to Use It?

Use Cases

Replace standard attention in a transformer model to reduce memory usage and enable longer context training. Speed up language model fine-tuning by reducing the attention computation bottleneck. Scale sequence length beyond standard attention limits for document-level understanding tasks.

Related Topics

Flash Attention, transformer optimization, GPU memory, attention mechanisms, long context, and training throughput.

Important Notes

Requirements

NVIDIA GPU with Ampere architecture or later for optimal performance. Flash Attention library installed with CUDA toolkit compatibility. PyTorch version compatible with the Flash Attention kernel build.
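A minimal preflight check along these lines can catch an unsupported setup before a training run; the compute-capability threshold of 8.0 corresponds to the Ampere generation.

import torch

assert torch.cuda.is_available(), "Flash Attention needs an NVIDIA CUDA GPU"
# Ampere and newer GPUs report compute capability 8.0 or higher.
assert torch.cuda.get_device_capability() >= (8, 0), \
    "Ampere (compute capability 8.0) or newer is required"

try:
    import flash_attn  # noqa: F401
except ImportError as exc:
    raise SystemExit(
        "flash-attn is missing or was built against an incompatible CUDA/PyTorch"
    ) from exc

print("Flash Attention environment looks OK:", torch.cuda.get_device_name())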

Usage Recommendations

Do: use Flash Attention 2 for the latest optimizations and the broadest hardware support. Verify that the numerical output matches standard attention within the expected floating-point tolerance after integration, and profile memory usage before and after integration to quantify the improvement.
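One way to run both checks, comparing against PyTorch's scaled_dot_product_attention as an fp32 reference and reading the allocator's peak statistics; the shapes and tolerances here are assumptions to adapt to your model.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Reference attention in fp32; scaled_dot_product_attention expects
# (batch, heads, seq, head_dim), so transpose around the call.
ref = F.scaled_dot_product_attention(
    q.float().transpose(1, 2),
    k.float().transpose(1, 2),
    v.float().transpose(1, 2),
).transpose(1, 2)

out = flash_attn_func(q, k, v)
# Tolerance is an assumption appropriate for fp16; tune it for your model.
assert torch.allclose(out.float(), ref, atol=1e-2, rtol=1e-2)

# Peak memory before and after integration can be compared the same way.
torch.cuda.reset_peak_memory_stats()
_ = flash_attn_func(q, k, v)
print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")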

Don't: assume Flash Attention benefits all sequence lengths, since very short sequences may not see a significant speedup due to kernel launch overhead. Don't mix Flash Attention with custom attention masks that require materializing the full attention matrix, and don't skip gradient verification after replacing attention, since numerical differences can affect training stability.

Limitations

Flash Attention requires NVIDIA GPUs with Ampere or later architecture and does not run on older GPU generations. Custom attention patterns such as sparse or sliding-window attention need specialized kernel implementations. The library must be compiled against a specific CUDA version, which adds build-dependency complexity.