Transformers

Introduction to Attention

Consider the following two sentences:

1. The dog did not cross the road as it was too tired.
2. The dog did not cross the road as it was too narrow.

In the first sentence, the pronoun it refers to the dog, whereas in the second sentence it refers to the road. We need our model to understand this relationship. To achieve this, instead of using a fixed embedding for each token, we can use the whole sequence to compute a weighted average of each embedding. Another way to formulate this is to say that given a sequence of token embeddings \(x_1, \ldots, x_n\), we produce a sequence of new embeddings \(x_1', \ldots, x_n'\) where each \(x_i'\) is a linear combination of all the \(x_j\):

\[x_i' = \sum_{j = 1}^{n} w_{ji}x_j\]

The coefficients \(w_{ji}\) are called attention weights and are normalized so that \(\sum_j w_{ji} = 1\). This weighted-averaging scheme would likely assign a higher weight to the word dog when creating the embedding for the word it in our first sentence, and a higher weight to road in the second. Embeddings generated in this way are called contextualized embeddings; they predate the invention of transformers and appear in earlier language models such as ELMo.
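
As a concrete illustration, here is a minimal NumPy sketch of this weighted-average idea, using toy random embeddings and hypothetical attention weights (the sizes and variable names are arbitrary, not from the text):

```python
import numpy as np

# Toy setup: n = 5 tokens, each with a d = 4 dimensional embedding.
np.random.seed(0)
n, d = 5, 4
x = np.random.randn(n, d)              # token embeddings x_1, ..., x_n (rows)

# Hypothetical attention weights: w[j, i] is the weight given to x_j when
# building the new embedding x'_i.  Normalize each column so that
# sum_j w[j, i] = 1.
w = np.random.rand(n, n)
w = w / w.sum(axis=0, keepdims=True)

# x'_i = sum_j w[j, i] * x_j, computed for all i at once.
x_new = w.T @ x                        # shape (n, d)
```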

Note: We use the terms token and word interchangeably here, even though they could be different. More on this later.

Computing the Attention Weights

Dot-Product (Multiplicative)

The dot product-based scoring function is the simplest one and has no parameters to tune. We compute the dot product between \(x_i\) and \(x_j\) to obtain \(w_{ji}\). The dot product indicates how similar \(x_i\) is to \(x_j\) and can hence be used as a weight.

\[w_{ji} = x_i \cdot x_j\]

Scaled Dot-Product Attention

The scaled dot product-based scoring function divides the dot product by the square root of the dimension. According to Vaswani et al., as the dimension increases, the dot products grow large in magnitude, which pushes the softmax function into regions where it has extremely small gradients.

\[w_{ji} = \frac{x_i \cdot x_j}{\sqrt{d}}\]

To illustrate why the dot products get large, assume that the components of \(x_i\) and \(x_j\) are independent random variables with mean $0$ and variance $1$. Then their dot product, \(x_i \cdot x_j = \sum_{r=1}^{d} x_{ir} x_{jr}\), has mean $0$ and variance \(d\).
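
A quick numerical check of this claim (a small simulation of our own, not part of the original derivation):

```python
import numpy as np

# Empirical check: dot products of vectors with zero-mean, unit-variance
# components have mean ~0 and variance ~d.
rng = np.random.default_rng(42)
for d in (16, 64, 256):
    xi = rng.standard_normal((10_000, d))
    xj = rng.standard_normal((10_000, d))
    dots = (xi * xj).sum(axis=1)
    print(d, round(dots.mean(), 2), round(dots.var(), 1))
# Dividing the scores by sqrt(d) brings them back to roughly unit variance,
# independent of the embedding dimension.
```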

In our simple example, we only used the embeddings “as is” to compute the attention scores and weights. Though this mechanism is simple, no weights are learned in it, and this is where the query and key matrices come into the picture. In practice, the self-attention layer applies three independent linear transformations to each embedding to generate the query, key, and value vectors. Each of these projections carries its own set of learnable parameters, which allows the self-attention layer to focus on different semantic aspects of the sequence.
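
The simple, parameter-free mechanism described above can be sketched in a few lines of NumPy (toy sizes and random values; the names are ours):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Parameter-free attention: the scores come straight from the raw
# embeddings, so nothing is learned.
np.random.seed(0)
n, d = 5, 8
x = np.random.randn(n, d)                 # raw token embeddings

scores = x @ x.T / np.sqrt(d)             # scaled dot products x_i . x_j
weights = softmax(scores, axis=-1)        # each row sums to 1
x_new = weights @ x                       # contextualized embeddings, (n, d)
```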

Self-Attention: Queries, Keys and Values

The concept of self-attention is based on three vector representations:

  1. Query
  2. Key
  3. Value

Instead of directly finding the dot product between \(x_i\) and \(x_j\), we project each token embedding into three vectors called query, key, and value. The query, key, and value vectors are obtained by projecting the input vector, $x_i$, at time $i$ onto the learnable weight matrices $W^{Q} \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad W^{K} \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad W^{V} \in \mathbb{R}^{d_{\text{model}} \times d_v}$ to get $q_i$, $k_i$, and $v_i$, respectively. These weight matrices $W^{Q}$, $W^{K}$, and $W^{V}$ are randomly initialized, and the weights are jointly learned during training. In the original transformer paper, the authors kept $d_k = d_v = 64$; however, we can choose them to be of different dimensions.
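
A sketch of these three projections for a single token, assuming $d_{\text{model}} = 512$ and $d_k = d_v = 64$ as in the text (the initialization scale and variable names are our own):

```python
import numpy as np

# The weight matrices are randomly initialized here; in a real model
# they are learned jointly during training.
np.random.seed(0)
d_model, d_k, d_v = 512, 64, 64

W_Q = np.random.randn(d_model, d_k) * 0.01
W_K = np.random.randn(d_model, d_k) * 0.01
W_V = np.random.randn(d_model, d_v) * 0.01

x_i = np.random.randn(d_model)     # embedding of a single token
q_i = x_i @ W_Q                    # query, shape (d_k,)
k_i = x_i @ W_K                    # key,   shape (d_k,)
v_i = x_i @ W_V                    # value, shape (d_v,)
```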

Consider sentence 1 again. Let’s focus on finding the embedding for the token it. Assume that the dimension of all the token embeddings is $d_{\text{model}}=512$, i.e., we represent each token by a dense vector of length $512$. All these embeddings are randomly initialized before training and are learned while training the model.

Instead of finding the dot product between it and all other tokens in the sentence and using them as the weights, we do the following:

  1. Project the embedding of it onto $W^{Q}$ to obtain its query vector $q$.
  2. Project the embeddings of all the tokens in the sentence onto $W^{K}$ to obtain their key vectors.
  3. Project the embeddings of all the tokens onto $W^{V}$ to obtain their value vectors.
  4. Compute the attention scores as the scaled dot products of $q$ with every key.
  5. Apply the softmax function to these scores to obtain the attention weights.
  6. Compute the new embedding of it as the sum of the value vectors weighted by the attention weights.

Similarly, we can find the embeddings of all other tokens by casting them into their query vectors. For this purpose, we pack the queries, keys and values into matrices $Q, K$ and $V$ as below:

\[Q = \begin{array}{c} \begin{pmatrix} \overrightarrow{q_0} \\ \overrightarrow{q_1} \\ \vdots \\ \overrightarrow{q_m} \end{pmatrix} \end{array}; \ K^T = \begin{array}{c} \begin{pmatrix} \overrightarrow{k_0} & \overrightarrow{k_1} & \ldots & \overrightarrow{k_n} \end{pmatrix} \end{array}; \ V = \begin{array}{c} \begin{pmatrix} \overrightarrow{v_0} \\ \overrightarrow{v_1} \\ \vdots \\ \overrightarrow{v_n} \end{pmatrix} \end{array}\]

Here, $Q \in \mathbb{R}^{m \times d_k}$, $K \in \mathbb{R}^{n \times d_k}$ and $V \in \mathbb{R}^{n \times d_v}$. Note that for the algebra to work out, the number of keys and values $n$ must be equal, but the number of queries $m$ can vary. Likewise, the dimensionality of the keys and queries must match, but that of the values can vary.

To compute the attention scores for each query with all the keys, as in step 4, we can do:

\[\frac{1}{\sqrt{d_k}}QK^T = \frac{1}{\sqrt{d_k}} \begin{array}{c} \begin{pmatrix} \overrightarrow{q}_0 \cdot \overrightarrow{k}_0 & \overrightarrow{q}_0 \cdot \overrightarrow{k}_1 & \ldots & \overrightarrow{q}_0 \cdot \overrightarrow{k}_n \\ \overrightarrow{q}_1 \cdot \overrightarrow{k}_0 & \overrightarrow{q}_1 \cdot \overrightarrow{k}_1 & \ldots & \overrightarrow{q}_1 \cdot \overrightarrow{k}_n \\ \vdots & \vdots & \ddots & \vdots \\ \overrightarrow{q}_m \cdot \overrightarrow{k}_0 & \overrightarrow{q}_m \cdot \overrightarrow{k}_1 & \ldots & \overrightarrow{q}_m \cdot \overrightarrow{k}_n \\ \end{pmatrix} \end{array}\]

To find the attention weights for each query with all the keys as in step 5, we can do:

\[\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) = \begin{array}{c} \begin{pmatrix} \text{softmax}\left(\frac{1}{\sqrt{d_k}} \langle \overrightarrow{q}_0 \cdot \overrightarrow{k}_0, \overrightarrow{q}_0 \cdot \overrightarrow{k}_1, \ldots, \overrightarrow{q}_0 \cdot \overrightarrow{k}_n \rangle \right) \\ \text{softmax}\left(\frac{1}{\sqrt{d_k}} \langle \overrightarrow{q}_1 \cdot \overrightarrow{k}_0, \overrightarrow{q}_1 \cdot \overrightarrow{k}_1, \ldots, \overrightarrow{q}_1 \cdot \overrightarrow{k}_n \rangle \right) \\ \vdots \\ \text{softmax}\left(\frac{1}{\sqrt{d_k}} \langle \overrightarrow{q}_m \cdot \overrightarrow{k}_0, \overrightarrow{q}_m \cdot \overrightarrow{k}_1, \ldots, \overrightarrow{q}_m \cdot \overrightarrow{k}_n \rangle \right) \\ \end{pmatrix} \end{array} = \begin{array}{c} \begin{pmatrix} s_{0,0} & s_{0,1} & \ldots & s_{0,n} \\ s_{1,0} & s_{1,1} & \ldots & s_{1,n} \\ \vdots & \vdots & \ddots & \vdots \\ s_{m,0} & s_{m,1} & \ldots & s_{m,n} \\ \end{pmatrix} \end{array}\]

where, for each row $i$, as a result of the softmax operation, \(\sum_{j=0}^{n} s_{i,j} = 1\).

The last step is to multiply this matrix by $V$:

\[\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V = \begin{array}{c} \begin{pmatrix} s_{0,0} & s_{0,1} & \ldots & s_{0,n} \\ s_{1,0} & s_{1,1} & \ldots & s_{1,n} \\ \vdots & \vdots & \ddots & \vdots \\ s_{m,0} & s_{m,1} & \ldots & s_{m,n} \\ \end{pmatrix} \end{array} \begin{array}{c} \begin{pmatrix} \overrightarrow{v_0} \\ \overrightarrow{v_1} \\ \vdots \\ \overrightarrow{v_n} \end{pmatrix} \end{array} = \begin{array}{c} \begin{pmatrix} \sum_{i=0}^{n} s_{0,i} \overrightarrow{v}_i \\ \sum_{i=0}^{n} s_{1,i} \overrightarrow{v}_i \\ \vdots \\ \sum_{i=0}^{n} s_{m,i} \overrightarrow{v}_i \\ \end{pmatrix} \end{array}\]

Instead of a separate vector computation for each token $i$, the input matrix $X \in \mathbb{R}^{l \times d}$, where $l$ is the maximum length of the sentence and $d$ is the dimension of the inputs, is multiplied by each of the query, key, and value weight matrices, and the whole operation is carried out as a single matrix computation:

\[\text{attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]

where \(d_k=64\) in our case. The weight matrices $W^{Q}$, $W^{K}$, and $W^{V}$ are randomly initialized, and the weights are jointly learned during training.
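
Putting this together, a minimal NumPy implementation of the attention function above might look as follows (the row-wise softmax is written out explicitly; the function and variable names are ours):

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention.

    Q: (m, d_k) queries, K: (n, d_k) keys, V: (n, d_v) values.
    Returns an (m, d_v) matrix: one weighted sum of the values per query.
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                          # (m, n)
    scores = scores - scores.max(axis=-1, keepdims=True)     # for stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # rows sum to 1
    return weights @ V

# Usage with an input matrix X of shape (l, d_model) and the projection
# matrices W_Q, W_K, W_V from the previous sketch:
#     Q, K, V = X @ W_Q, X @ W_K, X @ W_V
#     out = attention(Q, K, V)          # shape (l, d_v)
```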

This self-attention layer helps the model capture the context of each word in its representation, which may involve, for example, grammar, tense, or conjugation.

Multi-Headed Attention

Instead of a single self-attention head, there can be \(h\) parallel self-attention heads; this is known as multi-head attention. Multi-head attention uses multiple sets of query/key/value weight matrices, each producing different query/key/value matrices for the inputs. The output matrices from each head are concatenated and multiplied by an additional weight matrix, $W_O$.

In the original transformer paper, the authors used \(h = 8\) heads. Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W_O\]

where

\[\text{head}_i = \text{attention}(QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V})\]

The projections are parameter matrices:

\(W_{i}^{Q} \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad W_{i}^{K} \in \mathbb{R}^{d_{\text{model}} \times d_k}, \quad W_{i}^{V} \in \mathbb{R}^{d_{\text{model}} \times d_v}\) and \(W_O \in \mathbb{R}^{hd_v \times d_{\text{model}}}\)

For each self-attention head, the authors used \(d_k = d_v = \frac{d_{\text{model}}}{h} = 64\). Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.
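
A sketch of multi-head attention along these lines, assuming $h = 8$ and $d_{\text{model}} = 512$ and reusing the attention function from the earlier sketch:

```python
import numpy as np

# Requires attention() from the previous sketch.
np.random.seed(0)
h, d_model = 8, 512
d_k = d_v = d_model // h                     # 64 per head
l = 10                                       # sequence length
X = np.random.randn(l, d_model)              # input embeddings

# One set of projection matrices per head, plus the output projection W_O.
W_Q = [np.random.randn(d_model, d_k) * 0.01 for _ in range(h)]
W_K = [np.random.randn(d_model, d_k) * 0.01 for _ in range(h)]
W_V = [np.random.randn(d_model, d_v) * 0.01 for _ in range(h)]
W_O = np.random.randn(h * d_v, d_model) * 0.01

heads = [attention(X @ W_Q[i], X @ W_K[i], X @ W_V[i]) for i in range(h)]
out = np.concatenate(heads, axis=-1) @ W_O   # shape (l, d_model)
```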

Each head can learn something different, for example grammar, tense, or conjugation. For instance, one head can focus on subject-verb interaction, whereas another attends to nearby adjectives. Multiple heads also help the model spread its focus across different positions. Of course, we don’t handcraft these relations into the model; they are learned entirely from the data. If you are familiar with computer vision models, you might see the resemblance to filters in convolutional neural networks, where one filter can be responsible for detecting faces and another for finding the wheels of cars in images.

Positional Encoding

Word order and position play a crucial role in most NLP tasks. By processing one word at a time, recurrent neural networks inherently incorporate word order. In the transformer architecture, to gain speed and parallelism, recurrent layers are replaced by multi-head attention layers. The multi-head attention layer is effectively a weighted sum and does not take token positions into account, so information about word order has to be passed to the model explicitly. There are several ways to achieve this. We can use a learnable pattern, especially when the pretraining dataset is sufficiently large. This works exactly the same way as the token embeddings, but using the position index instead of the token ID as input. The input to the model is then simply the sum of the token embeddings and the positional embeddings. While learnable position embeddings are easy to implement and widely used, there are some alternatives.

  1. Absolute positional representations: Transformer models can use static patterns consisting of modulated sine and cosine signals to encode the positions of the tokens. This works especially well when there are not large volumes of data available. If the length of the sentence is given by $l$ and the embedding dimension/depth is given by $d$, the positional encoding $P$ is a 2-D matrix of the same dimension, i.e., $P \in \mathbb{R}^{l \times d}$. Every position can be represented with equations in terms of $i$ (along $l$) and $j$ (along $d$) as: $$ \begin{align} P_{i,2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right) \\ P_{i,2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right) \end{align} $$ for $i = 0, \ldots, l-1$ and $j = 0, \ldots, \left\lfloor \frac{d-1}{2} \right\rfloor$. The sine function is used for the even dimension indices ($2j$) and the cosine function for the odd dimension indices ($2j+1$). The definition above indicates that the frequencies decrease along the embedding dimension, with the wavelengths forming a geometric progression from $2\pi$ to $10000 \cdot 2\pi$. The two matrices, i.e., the word embeddings $W$ and the positional encoding $P$, are added to generate the input representation $X = W + P \in \mathbb{R}^{l \times d}$ (see the sketch after this list).

  2. Relative positional representations: Relative positional representations encode the relative positions between tokens. Models such as DeBERTa use such representations.
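
Here is a small NumPy sketch of the sinusoidal encoding from option 1, assuming an even embedding dimension $d$ (the function name is our own):

```python
import numpy as np

def sinusoidal_positional_encoding(l, d):
    """P[i, 2j] = sin(i / 10000**(2j/d)), P[i, 2j+1] = cos(i / 10000**(2j/d)).

    Assumes d is even.  Returns P with shape (l, d).
    """
    P = np.zeros((l, d))
    i = np.arange(l)[:, None]                     # positions, shape (l, 1)
    denom = 10000 ** (np.arange(0, d, 2) / d)     # 10000^(2j/d), shape (d/2,)
    P[:, 0::2] = np.sin(i / denom)
    P[:, 1::2] = np.cos(i / denom)
    return P

# The input representation is the sum of the word embeddings and the
# positional encoding: X = W + P.
P = sinusoidal_positional_encoding(l=10, d=512)
print(P.shape)                                    # (10, 512)
```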

References