Computational infographic of vector-wise inference in a decoder-only transformer
With annotations based on Anthropic's A Mathematical Framework for Transformer Circuits and Neel Nanda's Comprehensive Mechanistic Interpretability Explainer & Glossary
The goal for this graphic is to aid others spinning up on mechanistic interpretability work with transformer models. The idea is to blend...
Neural nets are typically implemented as a series of tensor operations because current tooling is highly optimized for such operations. But the most computationally efficient way to code a neural network isn't necessarily the best way to understand what's going on inside of one. Vectors (embeddings) are the fundamental information-bearing units in transformers, and are, with few exceptions, operated on completely independently. An explanation cast in terms of batch × head × position × d_head tensors, with thousands of high-dimensional vectors packed into them, loses focus on how information flows through a transformer model.
Implementations sometimes even obscure the computational structure of transformers in order to improve performance. For example, the original Transformer paper describes multi-head attention as involving a “concatenation” of the attention-weighted result vectors from each head, which is then projected back to the residual stream. Implementations and discussion since have largely followed this precedent. But the concatenation is unprincipled: it is an artifact of how the computation is batched, not part of its logic. It obscures the more intuitive way that information can be seen as flowing through attention heads. Each head's result vector can be projected back to the residual stream directly and independently through that head's own slice of the output matrix, and the projected vectors summed (as depicted in the present diagram). This is mathematically equivalent to concatenating and then projecting, and it needs no concatenation operation.
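The equivalence between the two views can be checked numerically. The following is a minimal sketch for a single token position, not taken from any particular implementation: the sizes and random values are arbitrary, the helpers `rand_vec` and `vec_mat` are hypothetical names introduced here, and the output matrix `w_o` is assumed to have its rows ordered head by head, matching the concatenation order.

```python
import random

random.seed(0)

# Toy sizes, chosen arbitrarily for illustration.
n_heads, d_head, d_model = 4, 3, 8


def rand_vec(n):
    """Hypothetical helper: a random vector of length n."""
    return [random.gauss(0.0, 1.0) for _ in range(n)]


def vec_mat(v, m):
    """Hypothetical helper: row vector v (length r) times matrix m (r x c)."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]


# One attention-weighted result vector per head, for a single position.
results = [rand_vec(d_head) for _ in range(n_heads)]

# Output projection with shape (n_heads * d_head, d_model).
# Assumption: rows are ordered head by head.
w_o = [rand_vec(d_model) for _ in range(n_heads * d_head)]

# Concatenation view: join the head results end to end, then project once.
concatenated = [x for r in results for x in r]
out_concat = vec_mat(concatenated, w_o)

# Per-head view: take one (d_head, d_model) slice of w_o per head,
# project each result vector independently, and sum the contributions.
out_per_head = [0.0] * d_model
for h, r in enumerate(results):
    w_o_h = w_o[h * d_head:(h + 1) * d_head]
    contribution = vec_mat(r, w_o_h)
    out_per_head = [a + b for a, b in zip(out_per_head, contribution)]

# The two results should agree up to floating-point rounding.
max_diff = max(abs(a - b) for a, b in zip(out_concat, out_per_head))
print("largest difference between the two views:", max_diff)
```

In the per-head view, each `contribution` is a separate vector written into the residual stream, so one head's effect can be inspected in isolation. The concatenated form hides that decomposition inside a single matrix product.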
Existing work (such as Anthropic's excellent Transformer Circuits thread) is weighty and our understanding is rapidly evolving. A good primer illustration might help people bootstrap into this important research program.
This is a draft with some content still missing but shortly forthcoming. It was created for the project phase of BlueDot Impact's AI Safety Fundamentals' AI Alignment Course. Source available on GitHub for adaptation.
The target audience has a rough understanding of the transformer architecture but is fuzzy on the specifics, and is interested in becoming familiar with some of the core findings of the mechanistic interpretability research program. This diagram may be a good follow-up to Jay Alammar's introductory piece, The Illustrated Transformer. I tried to use similar color coding where possible. As compared to that piece, this diagram:
Thanks to my BlueDot colleague Hannes Whittingham for feedback and encouragement. Check out his final project on Reinforcement Learning from LLM Feedback!
Main sources
Or if you'd like some more introductory resources