A Vector-Level View of GPT-2

Computational infographic of vector-wise inference in a decoder-only transformer

With annotations based on Anthropic's A Mathematical Framework for Transformer Circuits
and Neel Nanda's Comprehensive Mechanistic Interpretability Explainer & Glossary

View the draft
Feedback welcome!

The goal of this graphic is to help others spin up on mechanistic interpretability work with transformer models. The idea is to blend...

  1. A vector-level computational graph

    Neural nets are typically implemented as a series of tensor operations because current tooling is highly optimized for them. But the most computationally efficient way to code a neural network isn't necessarily the best way to understand what's going on inside one. Vectors (embeddings) are the fundamental information-bearing units in transformers, and with few exceptions they are operated on independently. An explanation cast in terms of batch × head × position × d_head tensors, with thousands of high-dimensional vectors packed into them, loses sight of how information flows through a transformer model.

    Implementations sometimes even obscure the computational structure of transformers in order to improve performance. For example, the original transformer paper describes multi-head attention as "concatenating" the attention-weighted result vectors from each head and then projecting the concatenation back to the residual stream, and implementations and discussion since have largely followed that precedent. But the concatenation is mathematically superfluous, and it hides a more intuitive way of seeing information flow through attention heads: because the output projection can be split into per-head blocks, each head's result vectors can be projected directly and independently back to the residual stream (as depicted in the present diagram), with no concatenation step. A short sketch demonstrating this equivalence appears after this list.

  with

  2. A mechanistic interpretability infographic

    Existing work (such as Anthropic's excellent Transformer Circuits thread) is weighty and our understanding is rapidly evolving. A good primer illustration might help people bootstrap into this important research program.
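Here is the sketch referenced in point 1: a minimal demonstration in PyTorch, assuming GPT-2 small's head and model dimensions and using random tensors in place of real activations (the variable names are illustrative, not taken from any particular codebase). It shows that the conventional "concatenate heads, then project with W_O" computation is identical to projecting each head's result vectors independently through a per-head slice of W_O and summing the contributions into the residual stream.

```python
import torch

# A minimal sketch: GPT-2 small dimensions, random stand-ins for real activations.
n_heads, d_head, d_model, n_pos = 12, 64, 768, 5

# Per-head attention-weighted result vectors: z[h, i] is head h's result at position i.
z = torch.randn(n_heads, n_pos, d_head)
# The standard output projection applied after concatenating the heads.
W_O = torch.randn(n_heads * d_head, d_model)

# Conventional view: concatenate the heads' results at each position, then project.
concat = z.permute(1, 0, 2).reshape(n_pos, n_heads * d_head)
out_concat = concat @ W_O

# Vector-level view: split W_O into per-head blocks of shape (d_head, d_model);
# each head projects its result vectors directly to the residual stream, and the
# per-head contributions simply add.
W_O_blocks = W_O.reshape(n_heads, d_head, d_model)
out_per_head = sum(z[h] @ W_O_blocks[h] for h in range(n_heads))

# The two computations agree up to floating-point error.
assert torch.allclose(out_concat, out_per_head, atol=1e-4)
```

Nothing in the math depends on the concatenation; it is purely a packing convenience, which is why the diagram draws each head writing its own projected result vectors into the residual stream.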

This is a draft; some content is still missing but forthcoming shortly. It was created for the project phase of BlueDot Impact's AI Safety Fundamentals AI Alignment Course. The source is available on GitHub for adaptation.

The target audience has a rough understanding of the transformer architecture but is fuzzy on the specifics, and is interested in becoming familiar with some of the core findings of the mechanistic interpretability research program. This diagram may be a good follow-up to Jay Alammar's introductory piece, The Illustrated Transformer; I tried to use similar color coding where possible. Compared to that piece, this diagram:

  1. depicts a decoder-only model (GPT-2 124M, a common reference model for interpretability work) rather than an encoder-decoder,
  2. assumes greater familiarity with the basic operation of transformers,
  3. presents additional technical details and conceptualizations relevant to mechanistic interpretability.

Thanks to my BlueDot colleague Hannes Whittingham for feedback and encouragement. Check out his final project on Reinforcement Learning from LLM Feedback!