Misconceptions I had about transformers

Transformers. Since the wildly influential paper "Attention Is All You Need" was published in 2017, they've become an enigmatic silver bullet for all your AI problems. They can translate. They can write high school essays. They can interpret images, or generate new ones. They can replace all our old RL algorithms. What can't they do?

I'm not going to write another blog post explaining transformers. You can find lots of smarter people than me writing about this all over the Internet. (Try https://nostalgebraist.tumblr.com/post/185326092369/1-classic-fully-connected-neural-networks-these or https://jalammar.github.io/illustrated-transformer/ if you need a starting point.) If you've never read about how transformers work, this probably isn't the blog post to start with. Go read those, then come back here.

As most of these posts will tell you, the most important part of transformer models is "multi-headed self-attention." If you understand self-attention, you understand the most important part. I found the basic concept of attention reasonably understandable - but I found a lot of the terminology really confusing. For my own benefit, and hopefully for someone else's too, I've taken the liberty of writing up the confusions I had and the answers I found.

I've had the great fortune of spending the past few weeks at Redwood Research's MLAB, where I had a lot of help from others in understanding the state of deep learning research, which informed a lot of these answers.

Why are they called transformers?

Transformers were introduced as a way to improve translation. Translation is a "transduction" task - basically, a sequence-to-sequence task. Transformers were a new approach to transduction - they're clearly quite distinct from RNNs - so their inventors gave them a new name. Today, we use transformers for plenty of non-transduction tasks, like text classification, but the name has stuck anyway.

RNNs can handle arbitrarily long inputs, while transformers have a maximum length. Doesn't that make RNNs better than transformers?

Here's another place where my previous blog post was pretty unclear: transformers perform better than RNNs because they can look farther back into the text, even though they have a limited input and output size. RNNs have to pass the context they know along token by token, while transformers can use attention to look anywhere in the text for more information. An RNN passes information along in a "hidden state" as it reads - sort of an internal memory. A transformer, on the other hand, can look backward at any token in the text based on its position and its content. Because of the complex logic structures transformers can build using many layers of attention (see "A Mathematical Framework for Transformer Circuits" from Anthropic for examples), they retain context better than RNNs.

For example, a transformer trying to figure out the author of a blog post might infer, "the last time I saw 'By' at the beginning of this post, it was followed by 'Tim Bauman', so the author must be 'Tim Bauman.'" To perform the same feat, an RNN would have to pass the author's name along for every token in the whole blog post until the end and hope that no other important information came along and overwrote the author's name in the meantime. It's as if, as you read this, you couldn't scroll back up to the top to re-read the passages you cared about.

So, even though transformers have a maximum memory size, they still remember longer than RNNs do.
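
To make this concrete, here's a minimal numpy sketch of the two kinds of memory (the vectors and weight matrices are just random toys, not a real model): the RNN folds everything into one hidden state, while attention lets the last token read directly from any earlier position.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 6, 4                      # toy sequence of 6 tokens, 4-dim vectors
x = rng.normal(size=(seq_len, d))      # pretend token embeddings ("By", "Tim", "Bauman", ...)

# RNN-style memory: every token is folded into one hidden state, in order.
# Information about token 0 survives only if nothing later overwrites it.
W_h, W_x = rng.normal(size=(d, d)), rng.normal(size=(d, d))
h = np.zeros(d)
for t in range(seq_len):
    h = np.tanh(W_h @ h + W_x @ x[t])  # all history compressed into h

# Attention-style memory: the last token builds a query and scores *every*
# earlier token directly, so position 0 is one lookup away no matter how far back it is.
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
q = x[-1] @ W_q                        # query from the current (last) token
K, V = x @ W_k, x @ W_v                # keys/values for every position
scores = K @ q / np.sqrt(d)
weights = np.exp(scores) / np.exp(scores).sum()
context = weights @ V                  # weighted read over the whole sequence
print(weights)                         # the weight on position 0 is direct, not relayed
```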

One last interesting point: outside of attention, the transformers I've seen actually work a lot like RNNs. Each token is processed in parallel and on its own. If the transformer is generating text, you only look at the output for the last token, just like with an RNN. But since every token can look at every other token during the attention phase, the whole text does need to be processed together.
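
In code, the generation loop looks roughly like this sketch - `model` here is a hypothetical stand-in for anything that maps a list of token ids to one row of logits per position:

```python
def generate(model, prompt_ids, n_new_tokens):
    """Greedy decoding sketch: the whole sequence is re-processed each step,
    but only the logits for the last position are used to pick the next token."""
    ids = list(prompt_ids)
    for _ in range(n_new_tokens):
        logits = model(ids)                 # shape: (len(ids), vocab_size)
        next_id = int(logits[-1].argmax())  # only the last token's output matters
        ids.append(next_id)
    return ids
```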

What are encoders and decoders in a transformer?

This is one of the most confusing things about transformers, to me. For example, take the image below from "Attention Is All You Need." The left-hand side is the encoder and the right-hand side is the decoder.

Intuitively, I assumed that "encoders" would take in the input and "decoders" would output the translations. But translation is actually an iterative process of generating one token at a time until you have a complete sentence. So, in order to generate a translation, you take in both the untranslated string and the prefix of the translation, and then output the next token. Even more confusingly, "Outputs (shifted right)" and "Output probabilities" mean slightly different things in this diagram. "Outputs (shifted right)" are the previous outputs, while "Output probabilities" are the probabilities of all the possible tokens you could choose as the next token.

To give an example of translating English to Spanish: the inputs might be "My name is Tim," the outputs (shifted right) would be "Me", and the output probabilities would pick "llamo" as the next token. Then, to get the token after that, you'd run the new outputs through the transformer again.
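
Putting that together, greedy decoding for translation looks something like this sketch (the `transformer` function and the special token ids are hypothetical stand-ins):

```python
def translate(transformer, source_ids, bos_id, eos_id, max_len=50):
    """Encoder-decoder decoding sketch: the source sentence stays fixed,
    and the translation grows by one token per pass through the model."""
    output_ids = [bos_id]                            # "outputs (shifted right)" start with a begin token
    for _ in range(max_len):
        probs = transformer(source_ids, output_ids)  # "output probabilities" for each position
        next_id = int(probs[-1].argmax())            # pick the most likely next token
        if next_id == eos_id:
            break
        output_ids.append(next_id)
    return output_ids
```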

But, I sneakily haven't actually answered this question yet. An encoder's attention can look forward and backward in the string, while a decoder's can only look backward. Why does this make sense? Well, the sentence you're translating from is already complete, so you can look at the end of the sentence when figuring out the meaning of the first few words. But when producing an output, you only know about the words you've already written, so you can't look forward in time to the words you haven't written yet. In the diagram above, "masked multi-headed attention" means attention isn't allowed to look at anything past the current position when it's processing a given token.
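
Mechanically, the mask just sets the attention scores for future positions to negative infinity before the softmax, so their weights come out as zero. A small numpy sketch of the idea:

```python
import numpy as np

def causal_mask(seq_len):
    # True where attention is allowed: each position may look at itself and earlier positions only.
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def masked_softmax(scores, mask):
    scores = np.where(mask, scores, -np.inf)   # forbid looking forward
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

scores = np.random.default_rng(0).normal(size=(5, 5))   # toy attention scores
print(masked_softmax(scores, causal_mask(5)).round(2))  # upper triangle is all zeros
```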

Okay, that kind of makes sense for the translation case. But what does it mean for GPT to be "decoder-only" and BERT to be "encoder-only"?

BERT's innovation was that you could train a really good language model using only the encoder part of the transformer architecture. This means BERT generally processes the whole text together, and its attention mechanism can look forward and backward in the text to understand what it means. To train it, you "mask" certain tokens and BERT predicts what goes there. You can generate text using an encoder model by masking the last token and seeing what it predicts.
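
If you have the Hugging Face transformers library installed, you can poke at this masked-token prediction directly - a quick sketch (the exact output format may differ between library versions):

```python
from transformers import pipeline

# BERT-style masked language modeling: hide a token and ask the model what belongs there.
fill_mask = pipeline("fill-mask", model="bert-base-uncased")
for prediction in fill_mask("My name is [MASK]."):
    print(prediction["token_str"], round(prediction["score"], 3))
```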

GPT, on the other hand, is only a decoder. That means that it's trained only by generating new text, and the model can only look backwards when processing any given token. This is not for accuracy - it gives worse results, since it can't think as holistically about the text - but for speed. If you can guarantee that you only look backwards, you can actually save the internal state of the model for past tokens you've generated and reuse it when generating the next token. This greatly reduces the compute needed to generate lots of text, meaning you can make bigger models more easily.
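
That saved internal state is usually called a key/value cache. Here's a rough numpy sketch of just the caching pattern, with a single attention head and random toy weights standing in for a real model:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

k_cache, v_cache = [], []              # keys/values for tokens we've already processed

def attend_to_next_token(x_new):
    """Process one new token: only its own key/value are computed;
    everything earlier is reused from the cache."""
    k_cache.append(x_new @ W_k)
    v_cache.append(x_new @ W_v)
    q = x_new @ W_q
    K, V = np.stack(k_cache), np.stack(v_cache)
    scores = K @ q / np.sqrt(d)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V                 # attention output for the new token only

for token_embedding in rng.normal(size=(6, d)):   # feed 6 toy tokens one at a time
    out = attend_to_next_token(token_embedding)
```

Because a decoder token never attends forward, the cached keys and values never need to be recomputed - which is exactly why the backwards-only restriction buys you speed.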

But the names "encoder" and "decoder" are totally misleading, and neither is really an encoder or decoder as you'd think of them in other contexts.