TorchTitan

Automate and integrate TorchTitan for large-scale distributed model training workflows

Category: productivity
Source: Orchestra-Research/AI-Research-SKILLs

TorchTitan is a community skill for training large-scale transformer models with PyTorch, covering model parallelism, tensor sharding, pipeline parallelism, distributed training orchestration, and memory optimization for billion-parameter language models.

What Is This?

Overview

TorchTitan provides guidance on training large transformer models across multiple GPUs and nodes using PyTorch-native distributed primitives. It covers tensor parallelism, which splits individual layers across devices for models too large for single-GPU memory; pipeline parallelism, which distributes sequential model stages across GPUs with micro-batch scheduling; fully sharded data parallel (FSDP) training, which distributes optimizer states and gradients across all workers to reduce per-device memory; activation checkpointing, which trades compute for memory by recomputing activations during the backward pass; and mixed precision training, which uses lower-precision arithmetic to accelerate computation while maintaining convergence. The skill helps teams train billion-parameter models efficiently across GPU clusters with high hardware utilization and controlled memory use.

Who Should Use This

This skill serves AI research teams training large language models, infrastructure engineers building distributed training platforms, and ML engineers scaling model architectures beyond single-GPU capacity. Teams working with models in the 7B to 70B parameter range will find the parallelism strategies particularly relevant.

Why Use It?

Problems It Solves

Large transformer models exceed single-GPU memory capacity, making naive single-device training impossible. Distributing model layers across devices requires careful partitioning to minimize communication overhead. Optimizer states for billion-parameter models consume more memory than the model weights themselves, often by a factor of two or more when using Adam-class optimizers. Achieving high GPU utilization across many devices requires carefully balanced workload distribution, strategic overlap of computation with communication, and proper micro-batch sizing to minimize pipeline bubbles.
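As a rough illustration of the memory pressure from optimizer state, the back-of-the-envelope estimate below assumes a 7B-parameter model, bf16 weights, and an Adam-class optimizer with fp32 moments plus an fp32 master copy of the weights; all of these numbers are illustrative assumptions rather than TorchTitan defaults.

params = 7e9                            # illustrative 7B-parameter model

weights_gib = params * 2 / 2**30        # bf16 weights: 2 bytes per parameter
adam_states_gib = params * 8 / 2**30    # two fp32 Adam moments: 8 bytes per parameter
master_copy_gib = params * 4 / 2**30    # optional fp32 master weights: 4 bytes per parameter

print(f"weights:           {weights_gib:.0f} GiB")      # ~13 GiB
print(f"optimizer moments: {adam_states_gib:.0f} GiB")  # ~52 GiB
print(f"fp32 master copy:  {master_copy_gib:.0f} GiB")  # ~26 GiB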

Core Highlights

Tensor parallelism splits individual layers across multiple GPUs. A pipeline scheduler distributes sequential stages with micro-batch interleaving. FSDP shards model parameters and optimizer states across distributed workers. A memory optimizer enables selective activation checkpointing and automatic mixed precision.

How to Use It?

Basic Usage

import torch
from torch.distributed import init_process_group
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)

# Initialize the default process group over NCCL (one process per GPU).
init_process_group(backend='nccl')

# Keep parameters, gradient reductions, and buffers in bf16.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16)

# build_transformer is a placeholder for your own model constructor.
model = build_transformer(
    vocab_size=32000,
    dim=4096,
    n_layers=32,
    n_heads=32)

# Wrap the model so parameters, gradients, and optimizer state are sharded.
model = FSDP(
    model,
    mixed_precision=mp_policy,
    use_orig_params=True)
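The snippet above assumes one process per GPU, with init_process_group reading its rank and world size from the environment. When launching with torchrun, a typical per-process device binding looks roughly like this (the script name in the comment is hypothetical):

import os
import torch

# Example launch: torchrun --nnodes=1 --nproc_per_node=8 train.py
# torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each spawned process.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)  # bind this process to its own GPU before FSDP wrapping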

Real-World Examples

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    apply_activation_checkpointing,
)

# Recompute activations for every transformer block during the backward pass.
# TransformerBlock is a placeholder for your model's block class.
def checkpoint_policy(module):
    return isinstance(module, TransformerBlock)

apply_activation_checkpointing(
    model,
    check_fn=checkpoint_policy)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.1)

# Standard training loop with gradient clipping to guard against loss spikes.
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch['input_ids'], batch['labels'])
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

Advanced Tips

Combine FSDP with tensor parallelism for 2D parallel training that handles both parameter sharding and layer splitting. Overlap gradient all-reduce with backward computation to hide communication latency. Profile memory usage per device using tools such as PyTorch's memory snapshot utility to find the optimal checkpoint granularity and identify unexpected allocation spikes before scaling to larger clusters.
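A minimal sketch of that 2D combination, assuming PyTorch 2.3 or newer, 8 GPUs, and a hypothetical model whose blocks expose attention and feed-forward projections under the submodule names shown; the mesh shape and module paths are assumptions to adapt to your architecture:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

# 2 data-parallel replicas x 4 tensor-parallel shards across 8 GPUs.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# Shard each block's projections over the tensor-parallel sub-mesh.
for block in model.layers:
    parallelize_module(
        block,
        mesh_2d["tp"],
        {
            "attention.wq": ColwiseParallel(),
            "attention.wk": ColwiseParallel(),
            "attention.wv": ColwiseParallel(),
            "attention.wo": RowwiseParallel(),
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(),
        },
    )

# Shard parameters and optimizer state over the data-parallel sub-mesh.
model = FSDP(model, device_mesh=mesh_2d["dp"], use_orig_params=True)

In this layout the tensor-parallel collectives stay within the faster "tp" dimension, which is why that dimension is usually kept inside a single node, while FSDP all-gathers travel over the "dp" dimension.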

When to Use It?

Use Cases

Train a multi-billion parameter large language model across a distributed GPU cluster. Scale a large vision transformer to handle extremely large image datasets with distributed data loading. Fine-tune a foundation model with parameter-efficient methods on limited GPU memory.

Related Topics

PyTorch, distributed training, FSDP, tensor parallelism, pipeline parallelism, large language models, and GPU clusters.

Important Notes

Requirements

PyTorch with the NCCL backend is required for multi-GPU communication across distributed processes, along with multiple GPUs connected by high-bandwidth interconnects such as NVLink or InfiniBand for efficient tensor sharding. A distributed launcher such as torchrun spawns training processes across nodes with proper rank assignment, world-size configuration, and fault tolerance through elastic training support.

Usage Recommendations

Do: start with FSDP before adding tensor or pipeline parallelism since it provides the simplest scaling path. Monitor per-device GPU memory utilization to ensure balanced and efficient sharding across all devices. Use gradient accumulation to simulate larger batch sizes when memory is constrained.
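For the memory-monitoring recommendation, a small per-rank report like the sketch below (the helper name is illustrative) is often enough to spot imbalanced sharding:

import torch
import torch.distributed as dist

def log_peak_memory(tag):
    # Report this rank's peak CUDA memory so imbalances across devices stand out.
    alloc_gib = torch.cuda.max_memory_allocated() / 2**30
    reserved_gib = torch.cuda.max_memory_reserved() / 2**30
    print(f"[rank {dist.get_rank()}] {tag}: "
          f"peak allocated {alloc_gib:.2f} GiB, reserved {reserved_gib:.2f} GiB")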

Don't: mix parallelism strategies without understanding their interaction effects on communication patterns; skip gradient clipping, since large models are prone to loss spikes during training; or assume linear scaling of throughput with GPU count, since communication overhead grows with cluster size.

Limitations

Communication overhead between GPUs reduces scaling efficiency as the number of devices increases beyond optimal ratios. Pipeline parallelism introduces bubble time where some GPUs idle during micro-batch transitions. Debugging distributed training failures is significantly harder than single-GPU training due to non-deterministic communication timing, multi-process log coordination, and race conditions in gradient synchronization across devices. Structured logging with per-rank output and deterministic data seeding can reduce the time spent isolating these failures.
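For example, a small helper along these lines (names are illustrative) makes per-rank logs and data order easier to compare across runs:

import logging
import os
import random

import torch

def configure_debug_run(base_seed=1234):
    # Prefix every log record with the rank so interleaved output stays readable.
    rank = int(os.environ.get("RANK", "0"))
    logging.basicConfig(
        level=logging.INFO,
        format=f"[rank {rank}] %(asctime)s %(levelname)s %(message)s",
    )
    # Identical seeds on every rank keep the data pipeline deterministic;
    # add a per-rank offset only where divergence is intended.
    random.seed(base_seed)
    torch.manual_seed(base_seed)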