Abstract
A common method to study deep learning systems is to use simplified model representations—for example, using singular value decomposition to visualize the model’s hidden states in a lower dimensional space. This approach assumes that the results of these simplifications are faithful to the original model. Here, we illustrate an important caveat to this assumption: even if the simplified representations can accurately approximate the full model on the training set, they may fail to accurately capture the model’s behavior out of distribution. We illustrate this by training Transformer models on controlled datasets with systematic generalization splits, including the Dyck balanced-parenthesis languages and a code completion task. We simplify these models using tools like dimensionality reduction and clustering, and then explicitly test how these simplified proxies match the behavior of the original model. We find consistent generalization gaps: cases in which the simplified proxies are more faithful to the original model on the in-distribution evaluations and less faithful on various tests of systematic generalization. This includes cases where the original model generalizes systematically but the simplified proxies fail, and cases where the simplified proxies generalize better. Together, our results raise questions about the extent to which mechanistic interpretations derived using tools like SVD can reliably predict what a model will do in novel situations.
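To make the idea of a systematic generalization split concrete, here is a minimal sketch of one such split for a Dyck language. It assumes, for illustration only, Dyck-2 strings of length 20 where the in-distribution set nests at most 4 deep and the out-of-distribution set nests 5 to 8 deep. The bracket alphabet, lengths, depth thresholds, and function names are hypothetical choices, not the datasets used in the paper.

```python
import random

# Hypothetical alphabet: Dyck-2, i.e. two bracket types.
PAIRS = {"(": ")", "[": "]"}


def random_dyck(length: int, max_depth: int, rng: random.Random) -> str:
    """Sample a balanced bracket string of `length` symbols with depth <= max_depth.

    Not uniform over Dyck words: whenever both moves keep the string
    completable, it opens or closes with equal probability.
    """
    assert length % 2 == 0 and max_depth >= 1
    stack, out = [], []
    for i in range(length):
        remaining = length - i
        can_open = len(stack) < max_depth and remaining > len(stack)
        can_close = bool(stack)
        if can_open and (not can_close or rng.random() < 0.5):
            opener = rng.choice(list(PAIRS))
            stack.append(opener)
            out.append(opener)
        else:
            out.append(PAIRS[stack.pop()])
    return "".join(out)


def nesting_depth(s: str) -> int:
    """Maximum number of simultaneously open brackets."""
    depth = best = 0
    for ch in s:
        depth += 1 if ch in PAIRS else -1
        best = max(best, depth)
    return best


def is_balanced(s: str) -> bool:
    stack = []
    for ch in s:
        if ch in PAIRS:
            stack.append(PAIRS[ch])
        elif not stack or stack.pop() != ch:
            return False
    return not stack


def make_split(n: int, length: int, min_depth: int, max_depth: int, seed: int) -> list:
    """Rejection-sample n strings whose maximum depth lies in [min_depth, max_depth]."""
    rng = random.Random(seed)
    samples = []
    while len(samples) < n:
        s = random_dyck(length, max_depth, rng)
        if nesting_depth(s) >= min_depth:
            samples.append(s)
    return samples


# In-distribution: shallow nesting. Out-of-distribution: strictly deeper nesting.
train = make_split(200, length=20, min_depth=1, max_depth=4, seed=0)
ood_test = make_split(50, length=20, min_depth=5, max_depth=8, seed=1)
```

Holding the length fixed and varying only the depth isolates one axis of generalization. A model, or a simplified proxy of it, that has only learned to track depths up to 4 can then be probed specifically on the deeper cases.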
- In-distribution: inputs that follow the same pattern as the data the model saw during training or the data used to build the simplification
- Out-of-distribution: inputs that differ in ways the training/simplification data didn’t cover (new structures, deeper nesting, new languages, unseen patterns)
- “Simplification” (as applied to an attention head):
  - Collect the key and query vectors that a trained attention head produces on ~1000 sequences from the training data
  - Stack these key/query vectors into one large matrix
  - Center the matrix and compute its top singular vectors (equivalently, the top eigenvectors of its covariance matrix, i.e. PCA)
  - For every subsequent input, replace each original key/query vector with its projection onto the span of those top components
  - Compute attention weights from the projected keys/queries instead of the originals
- Simplifications like this are used to identify what an attention head or MLP is “doing”, e.g. “This head tracks bracket depth,” “This head copies previous tokens,” “This circuit implements addition.”
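The following NumPy sketch illustrates the key/query simplification procedure and one way to check its faithfulness. It assumes a single head with causal softmax attention. The data are synthetic: a hypothetical 16-dimensional head whose in-distribution keys and queries lie near a 4-dimensional subspace, and whose out-of-distribution vectors also carry energy outside that subspace. None of this is activations from the models in the paper. Mean total-variation distance between original and simplified attention rows is one possible faithfulness measure, chosen here for illustration. Following the procedure as written, the sketch fits a single basis on the stacked key and query vectors. Fitting separate bases for keys and queries is an equally plausible reading.

```python
import numpy as np


def fit_pca_basis(vectors: np.ndarray, rank: int):
    """Return (mean, basis), where basis holds the top-`rank` principal directions as columns."""
    mean = vectors.mean(axis=0)
    # Right singular vectors of the centered matrix = eigenvectors of its covariance.
    _, _, vt = np.linalg.svd(vectors - mean, full_matrices=False)
    return mean, vt[:rank].T


def project(x: np.ndarray, mean: np.ndarray, basis: np.ndarray) -> np.ndarray:
    """Replace each row of x with its reconstruction from the top principal components."""
    return mean + (x - mean) @ basis @ basis.T


def causal_attention(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Causally masked softmax attention weights for one sequence, shape (T, T)."""
    t, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    scores = np.where(np.tril(np.ones((t, t), dtype=bool)), scores, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)
    w = np.exp(scores)
    return w / w.sum(axis=1, keepdims=True)


def attention_tv_distance(qs, ks, mean, basis) -> float:
    """Mean total-variation distance between original and simplified attention rows."""
    dists = []
    for q, k in zip(qs, ks):
        a = causal_attention(q, k)
        b = causal_attention(project(q, mean, basis), project(k, mean, basis))
        dists.append(0.5 * np.abs(a - b).sum(axis=1).mean())
    return float(np.mean(dists))


# Hypothetical stand-in for a trained head (synthetic data, not real activations).
rng = np.random.default_rng(0)
D_HEAD, RANK, T = 16, 4, 10
subspace = np.linalg.qr(rng.normal(size=(D_HEAD, RANK)))[0]  # orthonormal columns


def sample_vectors(n_seq: int, off_subspace_scale: float) -> np.ndarray:
    """Vectors of shape (n_seq, T, D_HEAD): a 4-d signal plus isotropic noise."""
    coeffs = rng.normal(size=(n_seq, T, RANK))
    noise = rng.normal(size=(n_seq, T, D_HEAD))
    return coeffs @ subspace.T + off_subspace_scale * noise


q_id, k_id = sample_vectors(1000, 0.01), sample_vectors(1000, 0.01)
q_ood, k_ood = sample_vectors(200, 1.0), sample_vectors(200, 1.0)

# Fit the basis on keys and queries from the ~1000 in-distribution sequences only.
mean, basis = fit_pca_basis(np.concatenate([q_id, k_id]).reshape(-1, D_HEAD), RANK)
tv_id = attention_tv_distance(q_id, k_id, mean, basis)
tv_ood = attention_tv_distance(q_ood, k_ood, mean, basis)
print(f"in-distribution TV: {tv_id:.4f}   out-of-distribution TV: {tv_ood:.4f}")
```

In this toy setup the in-distribution vectors are almost entirely captured by the fitted components, so the simplified attention barely changes. The out-of-distribution vectors have large components along directions the PCA discarded, so the simplified attention drifts further from the original. A proxy can therefore look faithful on the data it was built from and still diverge on shifted inputs.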