TransformerLens

Apply TransformerLens to deep mechanistic interpretability research on transformer language models

TransformerLens is a community skill for mechanistic interpretability research on transformer models, covering activation inspection, attention pattern analysis, residual stream decomposition, and circuit-level investigation of model behavior.

What Is This?

Overview

TransformerLens provides patterns for analyzing the internal mechanisms of transformer language models. It covers hooking into model activations at every layer, attention head output inspection, residual stream contribution analysis, logit attribution to specific components, activation patching for causal intervention experiments, and visualization of internal model computations. The skill enables researchers to understand how transformer models process information and make predictions.
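
As a quick orientation, a minimal sketch of the canonical TransformerLens workflow might look like the following. It assumes the transformer_lens package is installed and uses the hook-name conventions of current releases; the prompt is arbitrary.

from transformer_lens import HookedTransformer

# Load a small pre-converted model (weights are downloaded on first use).
model = HookedTransformer.from_pretrained("gpt2")

prompt = "The Eiffel Tower is located in the city of"
tokens = model.to_tokens(prompt)

# Run a forward pass while caching every intermediate activation.
logits, cache = model.run_with_cache(tokens)

# Attention pattern of layer 0: shape [batch, head, query_pos, key_pos].
pattern = cache["blocks.0.attn.hook_pattern"]
print(pattern.shape)

# Top predicted next token at the final position.
next_token = logits[0, -1].argmax().item()
print(model.to_string(next_token))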

Who Should Use This

This skill serves interpretability researchers investigating how transformer models compute outputs, safety researchers analyzing model behavior to identify potential failure modes, and ML engineers debugging unexpected model predictions by tracing them to specific internal components.

Why Use It?

Problems It Solves

Language models are opaque, making it difficult to understand why they produce specific outputs. Standard debugging approaches treat models as black boxes without revealing internal computation pathways. Identifying which attention heads or layers contribute to a particular behavior requires specialized tooling. Testing causal hypotheses about model mechanisms needs activation intervention capabilities that standard frameworks lack.

Core Highlights

Hook-based activation access captures intermediate values at every layer, head, and MLP component. Logit attribution traces final output probabilities back to contributions from individual model components. Activation patching enables causal experiments by replacing activations from one input with another. Pre-loaded model support provides clean access to GPT-2, Pythia, and other popular architectures.
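
To make the logit attribution idea concrete, here is a small dependency-free sketch in the style of the examples below; the component names and numbers are illustrative, not taken from a real model. Direct logit attribution scores a component by the dot product of its residual stream contribution with the unembedding direction of the target token.

def direct_logit_attribution(
        component_outputs: dict[str, list[float]],
        unembed_direction: list[float]) -> dict[str, float]:
    # Each component's contribution to the target logit is the dot product of
    # its residual stream output with that token's unembedding direction.
    return {
        name: sum(o * u for o, u in zip(output, unembed_direction))
        for name, output in component_outputs.items()
    }

# Illustrative values for a 4-dimensional residual stream.
contributions = direct_logit_attribution(
    {"blocks.0.attn": [0.25, -0.5, 0.5, 0.0],
     "blocks.0.mlp": [0.5, 0.25, -0.25, 1.0]},
    unembed_direction=[1.0, 0.0, -1.0, 0.5])
print(contributions)  # {'blocks.0.attn': -0.25, 'blocks.0.mlp': 1.25}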

How to Use It?

Basic Usage

from dataclasses import dataclass, field

@dataclass
class ActivationCache:
    # Maps hook names (e.g. "blocks.0.attn.hook_pattern") to cached values.
    # Values are nested float lists whose shape depends on the hook; attention
    # patterns are stored as [head][query_pos][key_pos].
    activations: dict[str, list] = field(default_factory=dict)

    def store(self, hook_name: str, values: list) -> None:
        self.activations[hook_name] = values

    def get(self, hook_name: str) -> list:
        return self.activations.get(hook_name, [])

    def list_hooks(self) -> list[str]:
        return list(self.activations.keys())

class TransformerAnalyzer:
    def __init__(self, num_layers: int, num_heads: int):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.cache = ActivationCache()

    def register_hooks(self) -> list[str]:
        # Hook names follow TransformerLens conventions: attention head
        # outputs, post-activation MLP outputs, and the residual stream
        # after each block.
        hooks = []
        for layer in range(self.num_layers):
            hooks.append(f"blocks.{layer}.attn.hook_result")
            hooks.append(f"blocks.{layer}.mlp.hook_post")
            hooks.append(f"blocks.{layer}.hook_resid_post")
        return hooks

    def get_attention_pattern(self, layer: int,
                               head: int) -> list[list[float]]:
        # hook_pattern stores one query-by-key matrix per head, indexed as
        # [head][query_pos][key_pos] in this sketch.
        hook = f"blocks.{layer}.attn.hook_pattern"
        patterns = self.cache.get(hook)
        if patterns and head < len(patterns):
            return patterns[head]
        return []
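
The sketch can be exercised with toy values. The stored pattern below is one 2x2 query-by-key matrix per head for a hypothetical one-layer, two-head model; the numbers are illustrative.

analyzer = TransformerAnalyzer(num_layers=1, num_heads=2)
print(analyzer.register_hooks())
# ['blocks.0.attn.hook_result', 'blocks.0.mlp.hook_post', 'blocks.0.hook_resid_post']

# Store an illustrative attention pattern: one 2x2 matrix per head.
analyzer.cache.store("blocks.0.attn.hook_pattern",
                     [[[1.0, 0.0], [0.5, 0.5]],   # head 0
                      [[1.0, 0.0], [0.9, 0.1]]])  # head 1
print(analyzer.get_attention_pattern(layer=0, head=1))
# [[1.0, 0.0], [0.9, 0.1]]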

Real-World Examples

from dataclasses import dataclass, field

@dataclass
class LogitAttribution:
    token: str
    components: dict[str, float] = field(default_factory=dict)

    @property
    def top_contributors(self) -> list[tuple[str, float]]:
        # Rank components by absolute contribution and return the five largest.
        sorted_items = sorted(
            self.components.items(),
            key=lambda x: abs(x[1]), reverse=True)
        return sorted_items[:5]

class CircuitAnalyzer:
    def __init__(self, analyzer: TransformerAnalyzer):
        self.analyzer = analyzer

    def compute_logit_attribution(self, token: str,
                                   residual_contributions: dict[str, float]
                                   ) -> LogitAttribution:
        return LogitAttribution(
            token=token, components=residual_contributions)

    def activation_patch(self, clean_cache: ActivationCache,
                          corrupt_cache: ActivationCache,
                          hook_name: str) -> dict:
        # Compare cached activations from a clean and a corrupted run at one
        # hook; the magnitude of the difference indicates how much a patch at
        # this location would change the forward pass.
        clean_act = clean_cache.get(hook_name)
        corrupt_act = corrupt_cache.get(hook_name)
        if not clean_act or not corrupt_act:
            return {"error": "Missing activations"}
        diff = [[c - k for c, k in zip(clean_row, corrupt_row)]
                for clean_row, corrupt_row in zip(clean_act, corrupt_act)]
        magnitude = sum(sum(abs(v) for v in row) for row in diff)
        return {"hook": hook_name,
                "patch_magnitude": round(magnitude, 4)}

    def find_important_heads(self,
                              attributions: list[LogitAttribution]
                              ) -> list[str]:
        # Accumulate absolute attribution scores for attention components.
        head_scores: dict[str, float] = {}
        for attr in attributions:
            for comp, score in attr.components.items():
                if "attn" in comp:
                    head_scores[comp] = head_scores.get(
                        comp, 0) + abs(score)
        return sorted(head_scores, key=head_scores.get,
                       reverse=True)[:5]
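
For illustration, the analyzer above can be driven with hand-made caches and attribution scores; all values and component labels here are synthetic.

analyzer = TransformerAnalyzer(num_layers=2, num_heads=4)
circuit = CircuitAnalyzer(analyzer)

# Hand-made clean and corrupted residual stream activations at one hook.
clean = ActivationCache()
corrupt = ActivationCache()
clean.store("blocks.0.hook_resid_post", [[1.0, 2.0], [0.0, 1.0]])
corrupt.store("blocks.0.hook_resid_post", [[0.0, 2.0], [0.0, -1.0]])
print(circuit.activation_patch(clean, corrupt, "blocks.0.hook_resid_post"))
# {'hook': 'blocks.0.hook_resid_post', 'patch_magnitude': 3.0}

# Synthetic attributions for one target token.
attr = circuit.compute_logit_attribution(
    " Paris", {"blocks.0.attn.head_7": 1.2,
               "blocks.1.attn.head_2": -0.4,
               "blocks.1.mlp": 0.3})
print(attr.top_contributors)
print(circuit.find_important_heads([attr]))
# ['blocks.0.attn.head_7', 'blocks.1.attn.head_2']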

Advanced Tips

Use activation patching to establish causal rather than correlational relationships between components and model behavior. Compare attention patterns across layers to identify information flow from input tokens to prediction positions. Run logit attribution on multiple examples to verify that observed circuits generalize beyond individual inputs.
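
As an illustration of the causal workflow, a hedged sketch of a single activation patch with the transformer_lens API might look like this. The fwd_hooks mechanism and hook names follow current releases; the prompts, patch location, and answer token are arbitrary choices for the example.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

clean_prompt = "When Mary and John went to the store, John gave a drink to"
corrupt_prompt = "When Mary and John went to the store, Mary gave a drink to"

clean_tokens = model.to_tokens(clean_prompt)
corrupt_tokens = model.to_tokens(corrupt_prompt)

# Cache the clean run so its activations can be spliced into the corrupted run.
_, clean_cache = model.run_with_cache(clean_tokens)

hook_name = "blocks.4.hook_resid_post"  # arbitrary patch location

def patch_resid(activation, hook):
    # Replace the corrupted activation with the cached clean one.
    return clean_cache[hook.name]

patched_logits = model.run_with_hooks(
    corrupt_tokens, fwd_hooks=[(hook_name, patch_resid)])

# Logit of the clean answer after patching (assuming " Mary" is a single
# token in the model's vocabulary).
answer_token = model.to_single_token(" Mary")
print(patched_logits[0, -1, answer_token].item())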

When to Use It?

Use Cases

Investigate which attention heads are responsible for a specific linguistic capability like subject-verb agreement. Debug a model prediction by tracing the logit contribution of each layer and head to the final output. Conduct activation patching experiments to identify the minimal circuit required for a particular model behavior.

Related Topics

Mechanistic interpretability research, attention visualization, circuit-level model analysis, probing classifiers, and activation intervention methods.

Important Notes

Requirements

The skill requires the transformer_lens Python package for model loading and hook management, a supported model architecture such as GPT-2 or Pythia, and GPU access for running models with activation caching enabled, which increases memory usage.

Usage Recommendations

Do: start with small models like GPT-2 Small for developing analysis techniques before scaling to larger models. Cache activations for the specific hooks you need rather than all hooks simultaneously. Verify findings with multiple input examples to distinguish general circuits from input-specific artifacts.

Don't: assume that attention weights alone explain model behavior without examining value vectors and MLP contributions; run full activation caching on large models without first checking memory requirements; or draw conclusions from single-example analyses that may not generalize.
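
In line with the recommendation to cache only the hooks you need, a hedged sketch using the names_filter argument of run_with_cache (available in current transformer_lens releases) might look like this; the prompt is arbitrary.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("Interpretability is fun")

# Cache only attention patterns rather than every intermediate activation,
# keeping memory usage proportional to the hooks of interest.
_, cache = model.run_with_cache(
    tokens, names_filter=lambda name: name.endswith("attn.hook_pattern"))

print(list(cache.keys())[:3])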

Limitations

Supported model architectures are limited to those explicitly implemented in the library. Activation caching increases memory usage proportionally to model size and sequence length. Interpretability findings on small models may not transfer to larger models with different internal structures.