Attention seeker: Unpacking the role of ‘attention’ in transformers like GPT
Introduction
I know, I know, this is a topic that has been absolutely exhausted on the internet… But for me to fully ingrain a mathematical concept, I need to understand why it works and all the mechanics of how. As such, I have coded my own large language models from scratch in Python, and this post is more or less a write-up of my learnings for future Trent to refer to as a quick, tuned (pun intended) resource. I hope some of you also find it useful if you think similarly!
To set the scene quickly, generative pre-trained transformers (GPTs) have taken the world by storm by pushing large language models (LLMs) to the forefront of the minds of everyone with access to the internet. However, despite their widespread use in business, healthcare, and other applications, most people who use one or more of the available GPTs do not understand how they work or how they are built. In fairness, anyone outside the companies building these large-scale models would not know exactly what the specific architecture is either; however, that does not mean we cannot understand how to build a GPT, nor the key concepts that make them such a powerful tool. In an age where data literacy is ever so critical, I think it’s incredibly valuable to understand more about the tools one is using. This can help users who might not even be writing their own code to hypothesise about the behaviour and outputs of the GPT they are entrusting with their task. Aside from that, I just love building stuff, and GPTs and deep neural networks are no exception! And that’s a good enough reason for me to sink countless hours into coding my own!
With that out of the way, let’s get into it. Basically, in 2017, Google humbly dropped this landmark paper which forever changed the modern technology world. In it, they introduce the idea of the Transformer – a new type of neural network architecture where each token of the tokenised input text (within some scoped context window) is weighted according to an attention mechanism, where attention is implemented in a multi-head fashion (i.e., several times in parallel). This mechanism essentially amplifies the signal of ‘important’ tokens for the context while diminishing others. The broad intuition behind attention is to place differential value on the information preceding the bit of information you are seeking to generate future values from. For example, in a sentence, the words preceding the final word often give the reader a pretty good idea as to what the most likely final word will be. More generally, the punctuation, sentence length, and other semantics of writing often suggest that the sentence might be drawing to a close and therefore a final word is likely needed. If you are imagining all of these possibilities being encoded as probabilities – please hold onto that intuition!
In the paper, all of this culminated in the following diagram of the transformer neural network architecture1:
The authors then went on to define a similarly seminal equation which described how their approach to attention is implemented:
\[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V \]
where \(Q\) is a matrix of queries, \(K\) is a matrix of keys, \(V\) is a matrix of values, and \(d_{k}\) is the head size. We will dive into what these things are shortly. For those unfamiliar with activation functions in machine learning, softmax converts a vector of real numbers into a probability distribution, which is invaluable for classification problems (both binary and multiclass). For an input vector \(\mathbf{x}\) of length \(n\), the \(i\)-th output is:
\[ \text{softmax}(\mathbf{x})_{i} = \frac{e^{x_{i}}}{\sum_{j=1}^{n}e^{x_{j}}} \]
Since softmax outputs values between \(0\) and \(1\) that sum to \(1\), it can serve a slightly different purpose than just producing class label probabilities. In the context of attention, it can be used to generate weights. That’s right – since the outputted values lie in the interval \([0,1]\) and sum to one, we can use them as weights. In plain English, these weights represent the affinity each token in a sequence has for every other token. We can use these affinities to construct much more realistic predictions about future tokens that resemble natural language.
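As a quick toy illustration (made-up scores, nothing to do with attention weights yet), softmax turns arbitrary real-valued scores into weights that sum to one:
import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 0.1])   # arbitrary real-valued scores
weights = F.softmax(scores, dim=-1)      # converted to weights in [0, 1]
print(weights)                           # roughly tensor([0.6590, 0.2424, 0.0986])
print(weights.sum())                     # tensor(1.)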
Back to the attention equation, scaling by \(\sqrt{d_{k}}\) is done for numerical stability, as the authors found that when \(d_{k}\) was large, the dot products grew in magnitude such that the softmax function had extremely small gradients which led to performance issues.
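To make that concrete, here is a small illustrative sketch (random made-up tensors, not the paper’s setup) comparing softmax over unscaled and scaled dot products:
import torch
import torch.nn.functional as F

torch.manual_seed(0)

d_k = 64
q = torch.randn(1, d_k)            # one query vector
k = torch.randn(8, d_k)            # eight key vectors

raw = q @ k.T                      # unscaled dot products: variance grows with d_k
scaled = raw / d_k**0.5            # scaled dot products: variance stays around 1

print(F.softmax(raw, dim=-1))      # typically very peaky, close to one-hot
print(F.softmax(scaled, dim=-1))   # more diffuse weights, so gradients stay useful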
Quick note on tokenisation of text
In order to understand more about attention in the context of transformers, let’s quickly touch on what it means to feed text as input into a neural network. Basically, text is converted into numbers – typically integers – which encode the textual information. By doing the reverse and decoding the encoded numbers (using the mapping), we get back the original text. The tricky part is that there is no ‘best’ way to do it, and each approach offers distinct advantages and disadvantages. For example, the simplest form of tokenisation would be to treat each character as its own token. Consider the following text input:
"But they were all of them deceived."
Character-level tokenisation would look like (ignoring the spaces which I can’t colour here, but are considered tokens):
B u t t h e y w e r e a l l o f t h e m d e c e i v e d .
When tokenised (here in Python), the above text would look like:
chars = sorted(list(set("But they were all of them deceived.")))
print(''.join(chars))
## .Bacdefhilmortuvwy
If this was the entirety of our input data, our vocabulary size (i.e., the number of possible tokens to generate from the LLM) would be:
vocab_size = len(chars)
print(vocab_size)
## 19
From here, we could create a mapping of these unique characters (or tokens) to integers which something like a neural network can then work with:
str_to_int = {ch:i for i,ch in enumerate(chars)} # Map characters to integers like a lookup table
int_to_str = {i:ch for i,ch in enumerate(chars)} # Map integers to characters like a lookup table
encode = lambda s: [str_to_int[c] for c in s] # Take a string and output a list of integers
decode = lambda l: ''.join([int_to_str[i] for i in l]) # Take a list of integers and output a string
For example:
print(encode("But they were all of them deceived."))
## [2, 15, 14, 0, 14, 8, 6, 18, 0, 17, 6, 13, 6, 0, 3, 10, 10, 0, 12, 7, 0, 14, 8, 6, 11, 0, 5, 6, 4, 6, 9, 16, 6, 5, 1]
print(decode(encode("But they were all of them deceived.")))
## But they were all of them deceived.
However, many of the current proprietary LLMs do not use character-level tokenisation. For example, OpenAI’s rule of thumb is that one token corresponds to roughly \(4\) characters of English text. For our above example, this might look something like:
But they were all of them dece ived .
This would instead produce a smaller vocabulary size of \(9\). However, you could imagine that tokenisation might not always just be inferred from scrolling across the input data from left to right. Perhaps with all the research conducted into natural language processing, some effective lexicons of tokens have been constructed? This is absolutely the case! For example, OpenAI has open-sourced tiktoken – a fast tokeniser that uses byte pair encoding to convert text into tokens. We won’t implement it here as we are primarily focused on how self-attention works, but I just thought it helpful to explain the data we are working with when considering attention in a transformer. We will just be considering the simple character-level tokenisation in this post, but the broader ideas generalise to any tokenisation scheme.
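For the curious, here is a minimal sketch of what using it looks like, assuming you have the tiktoken package installed (the encoding name below is just one of the available options):
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")                  # one of the available BPE encodings
tokens = enc.encode("But they were all of them deceived.")
print(len(tokens))                                          # far fewer tokens than characters
print(enc.decode(tokens))                                   # round-trips back to the original text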
Basic Python implementation
What follows is a code synthesis/exploration of the incredible content presented on self-attention in this video, this video, and the numerous papers and articles I read on the topic when I built my own GPTs from scratch in Python. First, let’s import PyTorch, since it’s the library that many of the major GPT implementations are written in:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(123)
Introducing batches
Batching is a crucial part of deep learning which essentially controls the number of training samples processed in a forward and backward pass through the network. This is important, as feeding in the entire dataset at once is very computationally intensive and a large burden on memory. While batching is an intuitive solution, unfortunately it makes our lives a little trickier when tracking the dimensions of our tensors through the network to ensure that we can always matrix multiply objects of the requisite dimensions. Basically, rather than just having objects of dimensions \(N \times M\) or something similar, we now have objects with dimensions \(B \times N \times M\). More correctly, in the syntax of PyTorch, it would be \(B \times T \times C\), where \(B\) is the batch size, \(T\) is time (i.e., the maximum context length for predictions, also known as ‘block size’), and \(C\) is the channel size (i.e., the number of embedding dimensions – ‘channels’ as a terminology is a bit less intuitive for transformers but makes a bit more sense for architectures like convolutional neural networks for image recognition).
Putting all of that into practice, we can produce a \(B \times T \times C\) tensor of random numbers for illustrative purposes:
B,T,C = 4,8,32
x = torch.rand(B,T,C)
print(x)
## tensor([[[0.2961, 0.5166, 0.2517, ..., 0.0390, 0.9268, 0.7388],
## [0.7179, 0.7058, 0.9156, ..., 0.9131, 0.0275, 0.1634],
## [0.3009, 0.5201, 0.3834, ..., 0.8459, 0.3033, 0.6060],
## ...,
## [0.6042, 0.9836, 0.1444, ..., 0.0149, 0.0757, 0.0131],
## [0.6886, 0.9024, 0.1123, ..., 0.1364, 0.6918, 0.3545],
## [0.7969, 0.0061, 0.2528, ..., 0.0890, 0.4759, 0.5104]],
##
## [[0.5840, 0.1227, 0.9587, ..., 0.9079, 0.6650, 0.3573],
## [0.0975, 0.2956, 0.9027, ..., 0.2588, 0.7239, 0.3604],
## [0.1829, 0.2956, 0.8646, ..., 0.6647, 0.9296, 0.3848],
## ...,
## [0.2590, 0.7162, 0.5689, ..., 0.1197, 0.7091, 0.1012],
## [0.1098, 0.6353, 0.3719, ..., 0.2206, 0.3352, 0.7797],
## [0.4196, 0.0050, 0.1368, ..., 0.9959, 0.6785, 0.3981]],
##
## [[0.5921, 0.0056, 0.5577, ..., 0.7036, 0.7429, 0.9616],
## [0.5214, 0.5024, 0.6241, ..., 0.1842, 0.1508, 0.6205],
## [0.8014, 0.3660, 0.3785, ..., 0.9848, 0.7145, 0.7961],
## ...,
## [0.8078, 0.5055, 0.2281, ..., 0.1461, 0.5924, 0.4857],
## [0.1711, 0.9303, 0.7285, ..., 0.0072, 0.7181, 0.1904],
## [0.0051, 0.0117, 0.6601, ..., 0.9541, 0.8567, 0.4604]],
##
## [[0.2238, 0.3047, 0.3019, ..., 0.8718, 0.5126, 0.0086],
## [0.8053, 0.0787, 0.6293, ..., 0.6941, 0.6661, 0.1499],
## [0.7697, 0.1543, 0.2570, ..., 0.7464, 0.1591, 0.7705],
## ...,
## [0.9514, 0.2032, 0.6429, ..., 0.7668, 0.9030, 0.7455],
## [0.9970, 0.7154, 0.0031, ..., 0.9812, 0.0329, 0.6061],
## [0.9745, 0.2383, 0.3850, ..., 0.9115, 0.5589, 0.8239]]])
Here, x plays the role of our data. If we employ a super basic implementation of the softmax idea from above, we can produce weights by instantiating a matrix of zeroes (i.e., just a blanket average with no contextual information learned from data, for now), applying softmax to it, and multiplying it by x. Here are the first two steps. Note that softmax produces a constant \(0.1250\) everywhere: our context length is \(8\) and every value in each row is zero, so each row becomes a uniform average and \(\frac{1}{8} = 0.1250\):
wei = torch.zeros((T,T))
wei = F.softmax(wei, dim=-1)
print(wei)
## tensor([[0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
## [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
## [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
## [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
## [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
## [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
## [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250],
## [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
Now we can matrix multiply wei and x; for now, each row of the result is identical because every row of wei is the same uniform average:
out = wei @ x
print(out)
## tensor([[[0.6291, 0.6708, 0.4074, ..., 0.3540, 0.4096, 0.4836],
## [0.6291, 0.6708, 0.4074, ..., 0.3540, 0.4096, 0.4836],
## [0.6291, 0.6708, 0.4074, ..., 0.3540, 0.4096, 0.4836],
## ...,
## [0.6291, 0.6708, 0.4074, ..., 0.3540, 0.4096, 0.4836],
## [0.6291, 0.6708, 0.4074, ..., 0.3540, 0.4096, 0.4836],
## [0.6291, 0.6708, 0.4074, ..., 0.3540, 0.4096, 0.4836]],
##
## [[0.3261, 0.3951, 0.5430, ..., 0.5956, 0.6181, 0.3962],
## [0.3261, 0.3951, 0.5430, ..., 0.5956, 0.6181, 0.3962],
## [0.3261, 0.3951, 0.5430, ..., 0.5956, 0.6181, 0.3962],
## ...,
## [0.3261, 0.3951, 0.5430, ..., 0.5956, 0.6181, 0.3962],
## [0.3261, 0.3951, 0.5430, ..., 0.5956, 0.6181, 0.3962],
## [0.3261, 0.3951, 0.5430, ..., 0.5956, 0.6181, 0.3962]],
##
## [[0.5663, 0.4306, 0.4145, ..., 0.5826, 0.5510, 0.5400],
## [0.5663, 0.4306, 0.4145, ..., 0.5826, 0.5510, 0.5400],
## [0.5663, 0.4306, 0.4145, ..., 0.5826, 0.5510, 0.5400],
## ...,
## [0.5663, 0.4306, 0.4145, ..., 0.5826, 0.5510, 0.5400],
## [0.5663, 0.4306, 0.4145, ..., 0.5826, 0.5510, 0.5400],
## [0.5663, 0.4306, 0.4145, ..., 0.5826, 0.5510, 0.5400]],
##
## [[0.6523, 0.3443, 0.4556, ..., 0.7442, 0.5753, 0.4855],
## [0.6523, 0.3443, 0.4556, ..., 0.7442, 0.5753, 0.4855],
## [0.6523, 0.3443, 0.4556, ..., 0.7442, 0.5753, 0.4855],
## ...,
## [0.6523, 0.3443, 0.4556, ..., 0.7442, 0.5753, 0.4855],
## [0.6523, 0.3443, 0.4556, ..., 0.7442, 0.5753, 0.4855],
## [0.6523, 0.3443, 0.4556, ..., 0.7442, 0.5753, 0.4855]]])
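As a quick sanity check on the tensors we just computed, each output row should equal the plain average of all eight token vectors in its batch element:
# Every row of `out` should equal the simple mean over all 8 tokens of that batch element
print(torch.allclose(out[0, 0], x[0].mean(dim=0)))  # True
print(torch.allclose(out[2, 5], x[2].mean(dim=0)))  # True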
Unfortunately, the above example is technically cheating. Even though our \(T\) dimension is \(8\), the network should not be allowed to see all \(8\) tokens at once: when predicting the token at position \(t\), it only has access to positions \(1\) through \(t\), because the future tokens have not been generated yet. If we implemented a network as rigid as this, it would be averaging over information from the future, and the outputs would likely not resemble human language. Instead, we want the context to vary – the network could see anywhere between one and eight tokens at any given time. This means, for example, that if the network saw just one token, it should not have weights assigned to the future seven tokens in the context sequence because it does not have access to this information. One way to solve this is by using the lower triangle of a matrix, such as the following:
tril = torch.tril(torch.ones(T,T))
print(tril)
## tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
## [1., 1., 0., 0., 0., 0., 0., 0.],
## [1., 1., 1., 0., 0., 0., 0., 0.],
## [1., 1., 1., 1., 0., 0., 0., 0.],
## [1., 1., 1., 1., 1., 0., 0., 0.],
## [1., 1., 1., 1., 1., 1., 0., 0.],
## [1., 1., 1., 1., 1., 1., 1., 0.],
## [1., 1., 1., 1., 1., 1., 1., 1.]])
See how the ones span across the T dimension in a way that scales with the amount of valid preceding information? We can now use this to our advantage. Before we re-implement wei, one other trick is to set the upper triangular values (i.e., the zeroes) to '-inf'. Softmax handles these appropriately: since \(e^{-\infty} = 0\), those ‘cheating’ positions are assigned a weight of exactly zero and do not count towards the normalisation of the other weights (whereas if we left them as true zeroes before applying softmax, each would still contribute \(e^{0} = 1\) to the denominator):
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
print(wei)
## tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
## [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
## [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
## [0., 0., 0., 0., -inf, -inf, -inf, -inf],
## [0., 0., 0., 0., 0., -inf, -inf, -inf],
## [0., 0., 0., 0., 0., 0., -inf, -inf],
## [0., 0., 0., 0., 0., 0., 0., -inf],
## [0., 0., 0., 0., 0., 0., 0., 0.]])
Finally we can compute the softmax function:
wei = F.softmax(wei, dim=-1)
print(wei)
## tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
## [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
## [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
## [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
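To see why we masked with '-inf' rather than leaving the upper triangle as zeroes, here is a quick comparison on a single made-up row:
row_zeros = torch.tensor([0., 0., 0., 0.])                          # future positions left at zero
row_masked = torch.tensor([0., 0., float('-inf'), float('-inf')])   # future positions masked out

print(F.softmax(row_zeros, dim=-1))   # tensor([0.2500, 0.2500, 0.2500, 0.2500]) - future tokens still get weight
print(F.softmax(row_masked, dim=-1))  # tensor([0.5000, 0.5000, 0.0000, 0.0000]) - masked positions get exactly zero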
See how each row of weights still sums to \(1\) nicely? Here is the final step, where we matrix multiply wei with x (i.e., our token data):
out = wei @ x
print(out)
## tensor([[[0.2961, 0.5166, 0.2517, ..., 0.0390, 0.9268, 0.7388],
## [0.5070, 0.6112, 0.5837, ..., 0.4761, 0.4772, 0.4511],
## [0.4383, 0.5808, 0.5169, ..., 0.5993, 0.4192, 0.5027],
## ...,
## [0.5912, 0.7430, 0.4823, ..., 0.4345, 0.3515, 0.5007],
## [0.6051, 0.7657, 0.4295, ..., 0.3919, 0.4001, 0.4798],
## [0.6291, 0.6708, 0.4074, ..., 0.3540, 0.4096, 0.4836]],
##
## [[0.5840, 0.1227, 0.9587, ..., 0.9079, 0.6650, 0.3573],
## [0.3407, 0.2091, 0.9307, ..., 0.5834, 0.6944, 0.3588],
## [0.2881, 0.2380, 0.9087, ..., 0.6105, 0.7728, 0.3675],
## ...,
## [0.3466, 0.4201, 0.6393, ..., 0.5914, 0.6552, 0.3319],
## [0.3127, 0.4509, 0.6011, ..., 0.5384, 0.6095, 0.3959],
## [0.3261, 0.3951, 0.5430, ..., 0.5956, 0.6181, 0.3962]],
##
## [[0.5921, 0.0056, 0.5577, ..., 0.7036, 0.7429, 0.9616],
## [0.5567, 0.2540, 0.5909, ..., 0.4439, 0.4468, 0.7911],
## [0.6383, 0.2913, 0.5201, ..., 0.6242, 0.5361, 0.7928],
## ...,
## [0.7257, 0.4172, 0.3213, ..., 0.6166, 0.4722, 0.6115],
## [0.6465, 0.4905, 0.3795, ..., 0.5295, 0.5073, 0.5514],
## [0.5663, 0.4306, 0.4145, ..., 0.5826, 0.5510, 0.5400]],
##
## [[0.2238, 0.3047, 0.3019, ..., 0.8718, 0.5126, 0.0086],
## [0.5146, 0.1917, 0.4656, ..., 0.7829, 0.5894, 0.0792],
## [0.5996, 0.1792, 0.3961, ..., 0.7708, 0.4459, 0.3097],
## ...,
## [0.5412, 0.3001, 0.5427, ..., 0.6769, 0.6684, 0.4091],
## [0.6063, 0.3595, 0.4657, ..., 0.7203, 0.5776, 0.4372],
## [0.6523, 0.3443, 0.4556, ..., 0.7442, 0.5753, 0.4855]]])
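As one last sanity check before moving on, row \(t\) of out should now be the running average of the current and all preceding token vectors:
# Row t of `out` should be the mean of tokens 0..t for that batch element (a running average)
print(torch.allclose(out[0, 2], x[0, :3].mean(dim=0)))  # True
print(torch.allclose(out[3, 7], x[3].mean(dim=0)))      # True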
Boom!
Adding sophistication: Keys and queries
Okay, so we have a simple attention mechanism working. In its current state, this approach is equivalent to just a simple averaging of the values preceding a given token. However, we don’t actually want this to be uniform. Instead, we would rather have these weights learned from data so that certain tokens find other tokens more or less interesting. For example, it’s likely that a vowel might be seeking consonants in its short-term past to understand its potential position in a word or broader sentence. The attention approach proposed by Google gets around this problem by introducing the notion of keys and queries. Basically, every token in every position has two vectors: a query vector and a key vector. The query vector indexes what the token is looking for while the key vector indexes what the token contains. This is a pretty elegant idea, because to then get relationships between tokens in a sequence we just need to compute a dot product between the queries and keys. The result of this dot product then becomes our new wei
. Beautiful.
NOTE: It might be helpful for some people to think of all of this as just a sophisticated Microsoft Excel lookup procedure where each token is looking up every other one in the sequence, with a softmax and dot product thrown in because they are sick.
Let’s explore in code to solidify the idea. Recall that our current implementation looks like the following:
B,T,C = 4,8,32
x = torch.randn(B,T,C)
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ x
We can now implement just a single head of attention2 through the addition of queries and keys. More specifically, we are going to add a self-attention head after the definition of x but before the masked_fill and softmax operations, essentially replacing our usage of wei = torch.zeros((T,T)).
The first change we are going to make is to create a hyperparameter head_size which, intuitively, governs the size of the attention head. We can then use this via a linear neural network layer to project x into a key and a query, each with final dimension head_size. Importantly, we are going to disable the bias term for these linear layers, as we just want a raw (learnable) matrix multiplication without an additive offset:
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
Now, to create the \(Q\) and \(K\) from the paper, we just call key and query on x, which maps the x tensor of shape (B,T,C) down to head_size in the final dimension:
Q = query(x)
K = key(x)
Importantly, no communication between tokens has happened yet! The query and key objects have not yet been combined, meaning that the information each token is seeking (query) has not been compared to the information it offers (key). To connect the two in the final step, we compute our new and improved wei as the dot product between \(Q\) and \(K\). However, they both currently have the same shape of (B,T,head_size) due to the linear mapping. In order to matrix multiply, we need \(K\) to have shape (B,head_size,T). We can do this by transposing the last two dimensions of \(K\):
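wei = Q @ K.transpose(-2,-1) # Produces (B,T,head_size) @ (B,head_size,T) = (B,T,T)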
This gives wei the shape (B,T,T), which is exactly what we need for the masking and softmax steps that follow. A useful way to think about wei is as containing, for every batch element, a \(T \times T\) matrix of affinities (i.e., relationships) between all the tokens.
Putting all of it together:
B,T,C = 4,8,32
x = torch.randn(B,T,C)
# Add a single head for self-attention
head_size = 16 # Hyperparameter
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
K = key(x) # Has shape (B,T,head_size) due to linear mapping above
Q = query(x) # Has shape (B,T,head_size) due to linear mapping above
wei = Q @ K.transpose(-2,-1) # Produces (B,T,head_size) @ (B,head_size,T) = (B,T,T)
# Business as usual
tril = torch.tril(torch.ones(T,T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ x
Maybe looking at the first entry of wei
might solidify the learning:
print(wei[0])
## tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.2461, 0.7539, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.1788, 0.0721, 0.7490, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.0198, 0.0119, 0.9663, 0.0021, 0.0000, 0.0000, 0.0000, 0.0000],
## [0.0868, 0.0224, 0.6254, 0.0866, 0.1788, 0.0000, 0.0000, 0.0000],
## [0.0674, 0.2265, 0.0070, 0.2752, 0.2119, 0.2120, 0.0000, 0.0000],
## [0.0844, 0.0189, 0.6506, 0.0281, 0.0320, 0.1514, 0.0346, 0.0000],
## [0.0332, 0.0784, 0.1998, 0.1428, 0.1057, 0.0981, 0.3181, 0.0240]],
## grad_fn=<SelectBackward0>)
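Before reading on, a quick check that each row of the matrix printed above behaves as a proper set of weights (it should sum to approximately one):
print(wei[0].sum(dim=-1)) # Each row should sum (approximately) to 1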
So we now have affinities between all the tokens in the sequence, learned from data, where each row is a set of weights between \(0\) and \(1\) that sums to \(1\). This is great, but the keen reader might notice that we are still missing one component from the equation in the paper: \(V\).
Incorporating V
The final part of this \(Q,K,V\) adventure is, instead of aggregating the raw tokens x, to aggregate the values (i.e., \(V\)) produced by passing x through another linear layer of size head_size. There is a good reason for this: our current approach produces a final out tensor of shape (4,8,32), but we are using head_size=16. We can verify this:
print(out.shape)
## torch.Size([4, 8, 32])
However, if we build a value
object using the same nn.Linear()
structure as we did for query
and key
and then right at the end multiply wei
by value
applied to x
(i.e., \(V\)), let’s see what dimensions we get:
B,T,C = 4,8,32
x = torch.randn(B,T,C)
head_size = 16 # Hyperparameter
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
K = key(x) # Has shape (B,T,head_size) due to linear mapping above
Q = query(x) # Has shape (B,T,head_size) due to linear mapping above
wei = Q @ K.transpose(-2,-1) # Produces (B,T,head_size) @ (B,head_size,T) = (B,T,T)
tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
V = value(x)
out = wei @ V
print(out.shape)
## torch.Size([4, 8, 16])
Perfect! And now, for the final part of the equation, we can add in the scaling factor \(\sqrt{d_{k}}\) (where \(d_{k}\) is the head size) to ensure that the values passed into softmax stay diffuse rather than growing in magnitude with head_size, which would cause softmax to drastically emphasise the largest values and saturate. This is a simple addition to the line that initialises wei (note that dividing by \(\sqrt{d_{k}}\) is equivalent to multiplying by \(\text{head\_size}^{-0.5}\)):
B,T,C = 4,8,32
x = torch.randn(B,T,C)
head_size = 16 # Hyperparameter
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
K = key(x) # Has shape (B,T,head_size) due to linear mapping above
Q = query(x) # Has shape (B,T,head_size) due to linear mapping above
wei = Q @ K.transpose(-2,-1) * head_size**-0.5 # Produces (B,T,head_size) @ (B,head_size,T) = (B,T,T)
tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
V = value(x)
out = wei @ V
And we are done! This approach to basic self-attention is now ready to be incorporated into a GPT coded in PyTorch.
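To make that final step easier, here is one way to package everything above into a reusable module. This is a minimal sketch with names of my own choosing (SelfAttentionHead, n_embd, block_size) rather than the exact class from any particular codebase:
class SelfAttentionHead(nn.Module):
    """A single head of causal self-attention, packaging up the steps above."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # The causal mask is fixed (not learned), so register it as a buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        K = self.key(x)                                      # (B,T,head_size)
        Q = self.query(x)                                    # (B,T,head_size)
        wei = Q @ K.transpose(-2, -1) * K.shape[-1]**-0.5    # (B,T,T), scaled by 1/sqrt(d_k)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        V = self.value(x)                                    # (B,T,head_size)
        return wei @ V                                       # (B,T,head_size)

# Quick usage check on a toy tensor
head = SelfAttentionHead(n_embd=32, head_size=16, block_size=8)
print(head(torch.randn(4, 8, 32)).shape) # torch.Size([4, 8, 16])
Registering tril as a buffer keeps the causal mask attached to the module (and on the right device) without treating it as a learnable parameter.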
Conclusion
Thank you for enduring yet another Trent waffle session! Hopefully at least some of this is mildly helpful to someone other than me :)