Authors: P. Xu, S. Prince

In part I of this tutorial we introduced the self-attention mechanism and the transformer architecture. In part II, we discussed position encoding and how to extend the transformer to longer sequence lengths. We also discussed connections between the transformer and other machine learning models.

In this final part, we discuss challenges with transformer training dynamics and introduce some of the tricks that practitioners use to get transformers to converge. This discussion will be suitable for researchers who already understand the transformer architecture, and who are interested in training transformers and similar models from scratch.

Tricks for training transformers

Despite their broad applications, transformers are surprisingly difficult to train from scratch. One of the contributions of the original transformer paper was to use four tricks that collectively allow stable training:

1. Residual connections: Each transformer layer takes the $I\times D$ data matrix $\mathbf{X}$ where $I$ is the number of inputs and $D$ the dimensionality of those inputs and returns an object of the same size. It performs the following operations:

\begin{eqnarray}
    \mathbf{X} &\leftarrow& \mathbf{X} + \bf{MhSa}[\mathbf{X}] \nonumber \\
    \mathbf{X} &\leftarrow& \bf{Layernorm}[\mathbf{X}] \hspace{3cm}\nonumber\\
    \mathbf{x}_{i} &\leftarrow& \mathbf{x}_{i}+\bf{mlp}[\mathbf{x}_{i}] \hspace{3.6cm}\forall\; i\in\{1\ldots I\}\nonumber\\
    \mathbf{X} &\leftarrow& \bf{Layernorm}[\mathbf{X}], \tag{1}
\end{eqnarray}

which include two residual connections around the multi-head self-attention $\bf{MhSa}[\bullet]$ and multi-layer perceptron $\bf{mlp}[\bullet]$ components (figure 1).

Figure 1. The transformer layer. The input consists of an $I\times D$ matrix containing the $D$ dimensional embeddings for each of the $I$ input tokens. The output is a matrix of the same size. The transformer layer consists of a series of operations. First, there is a multi-head attention block which allows the embeddings to interact with one another. This is in a residual network, so the original inputs are added back to the output. Second, a layer norm operation is applied. Third, there is a second residual network where the same two-layer fully-connected neural network is applied to each representation separately. Finally, layer norm is applied again.

2. Layer normalization: After each residual connection, a layer normalization procedure is applied:

\begin{equation}
    \bf Layernorm[\mathbf{X}] = \gamma\cdot \frac{\mathbf{X}-\mu}{\sigma}+\beta, \tag{2}
\end{equation}

where $\mu$ and $\sigma$ are the mean and standard deviation of the elements of $\mathbf{X}$ (but are separate for each member of the batch), and $\gamma$ and $\beta$ are learned parameters.

3. Learning rate warm-up: The learning rate is increased linearly from $0$ to $R$ over the first $T_{R}$ time steps so that:

\begin{equation}
    \mbox{lr}[t] = R \cdot \frac{t}{T_{R}}. \tag{3}
\end{equation}

4. Adaptive optimizers: Transformers need to be trained with adaptive optimizers like Adam, which recursively estimates the momentum and the learning rate separately for each parameter at each time-step. In practice, relatively large batch sizes of $>1,000$ are usually employed.
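As a minimal sketch of the layer normalization in equation 2, the following plain-Python function normalizes a single batch member. The function name `layer_norm`, the stabilizing constant `eps`, and the use of scalar `gamma` and `beta` are assumptions made for illustration; following the text, the statistics are taken over all elements of $\mathbf{X}$, whereas many library implementations normalize each embedding separately and learn one scale and offset per dimension.

```python
import math

def layer_norm(X, gamma=1.0, beta=0.0, eps=1e-5):
    """Equation 2 for one batch member X, given as a list of I rows of D floats."""
    values = [v for row in X for v in row]
    mu = sum(values) / len(values)
    var = sum((v - mu) ** 2 for v in values) / len(values)
    sigma = math.sqrt(var + eps)  # eps guards against division by zero (assumption)
    return [[gamma * (v - mu) / sigma + beta for v in row] for row in X]

# Toy input with I = 2 tokens and D = 3 dimensions.
X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
Y = layer_norm(X)
```

With the default `gamma` and `beta`, the elements of `Y` have approximately zero mean and unit variance whatever the scale of `X`.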

Removing any of these tricks makes training unstable and often leads to complete training failures. However, they have been employed without a full understanding of why they are required.

As transformers are applied more widely, it is increasingly important that we have a better understanding of transformer training. To this end, a number of recent papers have been devoted to demystifying this topic and exploring better training methods. In the rest of this blog post, we will connect these separate efforts to form a comprehensive overview of this topic.

Why are these tricks required?

In this section we will review these tricks and see that there are complex dependencies between them: some tricks cause problems, which are in turn resolved by others. In the subsequent section we will discuss improvements to the training process that follow from this understanding.

Learning rate warm-up

Learning rate warm-up (in which the learning rate is gradually increased during the early stages of training) is particularly puzzling. This is not required for most deep learning architectures. However, training fails for transformers if we just start with a typical learning rate. If we start with a very small learning rate, then the training is stable, but then it takes an impractically long time.

Xiong et al., 2020 explored this phenomenon by conducting experiments on a machine translation task with different optimizers and learning rate schedules. Their results (figure 2) show that learning rate warm-up is essential for both Adam and SGD, and that the training process is sensitive to the warm-up steps.

Figure 2. Importance of learning rate warm-up for transformers. a) Validation loss for Adam optimizer for IWSLT14 De-En task. Performance is poor (curves are higher) without warm-up or if learning rate increases too quickly. Performance is much better if learning rate is gradually increased. b) Validation loss for the SGD optimizer for IWSLT14 De-En task. A similar pattern is observed, but the overall results are much worse than for Adam. Adapted from Xiong et al., 2020.

Although learning rate warm-up works, it has some obvious disadvantages. It introduces an extra hyper-parameter (the number of warm-up steps) and it initializes the learning rate to zero which slows the training down. Hence, it's important that we understand why it is necessary.
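To make the extra hyper-parameter concrete, here is a small sketch of the linear warm-up schedule of equation 3. The function name `warmup_lr` is hypothetical, and holding the rate at $R$ once $t$ exceeds $T_{R}$ is an assumption: equation 3 only describes the warm-up phase, and in practice a decay schedule usually follows.

```python
def warmup_lr(t, R, T_R):
    """Equation 3 for t <= T_R; keeping the rate at R afterwards is an assumption."""
    if T_R <= 0:
        return R  # no warm-up requested
    return R * min(t, T_R) / T_R

# Learning rate at a few time steps for R = 1e-3 and T_R = 4000 (illustrative values).
schedule = [warmup_lr(t, R=1e-3, T_R=4000) for t in (0, 1000, 4000, 8000)]
```

The first entry of `schedule` is exactly zero, which is why the earliest updates contribute almost nothing to training.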

To help answer this question, Huang et al., 2020 visualized the gradient of the loss $\mathcal{L}$ with respect to the input embeddings $\mathbf{X}$, and the size of the Adam updates during the first 100 steps of training (figure 3). They found that without warm-up, the gradients vanish very quickly, and the Adam updates also rapidly become much smaller. Diminishing gradients at lower layers in the transformer model without warm-up have also been observed by Liu et al., 2020.

Figure 3. Removing learning rate warm-up makes gradients vanish. a) Histogram of magnitude of gradient of loss function with respect to inputs during first 100 steps of training. When we use learning rate warm-up, these gradients remain constant. b) Without warm-up the gradients vanish after just a few steps. c-d) We observe a similar pattern of behaviour in the histograms of Adam update sizes with and without warm-up. Adapted from Huang et al., 2020.

Residual connections, layer normalization and Adam

To understand why learning rate warm-up is required, and why the gradients vanish without it, we will first need to understand the reasons for, and the consequences of using residual connections, layer normalization, and Adam.

Residual networks were developed in computer vision; they make networks easier to optimize and allow deeper networks to be trained. In computer vision, the additive residual connections are usually placed around convolutional layers and combined with batch normalization. In the transformer, they are placed around the self-attention and feed-forward networks and combined with layer normalization (figure 1). From this perspective, the transformer architecture could be considered a "deep residual self-attention network".

Figure 4. Residual connections, layer normalization, and Adam. We start from the premise that we would like to train deep networks on NLP data that allow the embeddings to interact. To train deep networks, we need residual connections, and these cause gradient explosion, which must be resolved by normalization. Transformers use layer normalization (which normalizes each member of the batch independently) because the batch statistics of language data exhibit large fluctuations. To allow the embeddings to interact efficiently, we use self-attention, but the gradients of the parameters in the self-attention module are unbalanced (i.e., some are much larger than others). This necessitates the use of Adam, which uses a separate learning rate for each parameter.

Zhang et al., 2019 show that the output variance of residual networks grows exponentially with depth. Hence, normalization is used to prevent gradient explosion for deep residual networks. Layer normalization is used in the transformer because the statistics of language data exhibit large fluctuations across the batch dimension, and this leads to instability in batch normalization. 

Transformers also differ from convolutional networks in that stochastic gradient descent does not work well for training (figure 2) and adaptive optimizers like Adam are required. Liu et al., 2020 observed that differentiating through the self-attention mechanism creates unbalanced gradients. In particular, the gradients for the query $\boldsymbol\Phi_{q}$ and key $\boldsymbol\Phi_{k}$ parameters were much smaller than those for the value parameters $\boldsymbol\Phi_{v}$, and so the former parameters change much more slowly. This is a direct consequence of the mathematical expression for self-attention. The Adam optimizer fixes this problem by essentially having different learning rates for each parameter.
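The following sketch illustrates how Adam compensates for unbalanced gradients. It implements the standard Adam update with bias correction for a flat list of parameters and compares one step against plain SGD when two gradients differ by four orders of magnitude. The function name `adam_step`, the toy gradient values, and the default hyper-parameters are illustrative assumptions.

```python
import math

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a flat list of parameters; t counts from 1."""
    new_theta, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(theta, grad, m, v):
        m_i = beta1 * m_i + (1 - beta1) * g        # running mean of gradients
        v_i = beta2 * v_i + (1 - beta2) * g * g    # running mean of squared gradients
        m_hat = m_i / (1 - beta1 ** t)             # bias correction
        v_hat = v_i / (1 - beta2 ** t)
        new_theta.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_theta, new_m, new_v

# Two parameters whose gradients differ by four orders of magnitude.
theta = [0.0, 0.0]
grad = [1e-4, 1.0]
theta_adam, m, v = adam_step(theta, grad, m=[0.0, 0.0], v=[0.0, 0.0], t=1)
theta_sgd = [p - 1e-3 * g for p, g in zip(theta, grad)]
```

After this first step both entries of `theta_adam` have moved by roughly the learning rate, whereas the entries of `theta_sgd` differ by the same four orders of magnitude as the gradients.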

To conclude, we've seen that residual connections are needed to allow us to train deep networks. These cause gradient explosion, which is resolved by using layer normalization. The self-attention computation causes unbalanced gradients, which necessitates the use of Adam (figure 4). In the next section, we'll see that layer normalization and Adam themselves cause more problems, which ultimately result in the need for learning rate warm-up.

Gradient shrinkage effect

Xiong et al., 2020 found that the magnitude of the gradients through layer normalization is inversely proportional to the magnitude of the input. Specifically, the gradient has the following property:

\begin{equation}
    \left\lVert \frac{\partial \bf Layernorm[\mathbf{X}]}{\partial \mathbf{X}} \right\rVert=\mathcal{O}\left(\frac{\sqrt{D}}{\lVert\mathbf{X}\rVert}\right), \tag{4}
\end{equation}

where $\mathbf{X}$ is the input to layer normalization and $D$ is the embedding dimension. If the input norm $\lVert \mathbf{X} \rVert$ is larger than $\sqrt{D}$ then back-propagating through layer normalization reduces the gradient magnitude in lower layers. As this effect compounds through multiple layers, it causes the gradient to vanish at lower layers for deep models.  We will term this the gradient shrinkage effect.
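The inverse dependence on the input norm can be checked numerically. The sketch below estimates the Jacobian of layer normalization for a single embedding vector by central finite differences and compares its size for an input and for the same input scaled by ten. The helper names, the toy vector, the step size `h`, and the choice of the Frobenius norm are assumptions for illustration; the learned parameters are fixed at $\gamma=1$ and $\beta=0$.

```python
import math

def layer_norm(x):
    """Layer normalization of one embedding vector with gamma = 1 and beta = 0."""
    D = len(x)
    mu = sum(x) / D
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / D)
    return [(v - mu) / sigma for v in x]

def jacobian_norm(x, h=1e-6):
    """Frobenius norm of d layer_norm / d x, estimated by central differences."""
    total = 0.0
    for j in range(len(x)):
        up, down = list(x), list(x)
        up[j] += h
        down[j] -= h
        y_up, y_down = layer_norm(up), layer_norm(down)
        total += sum(((a - b) / (2 * h)) ** 2 for a, b in zip(y_up, y_down))
    return math.sqrt(total)

x = [0.5, -1.0, 2.0, 0.3, -0.7, 1.1, -1.6, 0.9]
small = jacobian_norm(x)                        # input with a moderate norm
large = jacobian_norm([10.0 * v for v in x])    # same direction, ten times the norm
ratio = small / large
```

Scaling the input by ten shrinks the Jacobian by approximately the same factor, so `ratio` comes out close to ten.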

Unbalanced dependencies and amplified output perturbations

Layer normalization also causes unbalanced dependencies between the two branches of the residual connection around the self-attention module. In other words, the output of $\bf LayerNorm[\mathbf{X}+\bf Sa[\mathbf{X}]]$ depends much more on the self-attention computation $\bf Sa[\mathbf{X}]$ than the skip connection $\mathbf{X}$. This means that the outputs depend much more on later layers than earlier layers. Liu et al., 2019 show that this happens empirically in practice.

Moreover, they show that this leads to amplified output perturbations; small changes to the network parameters cause large output fluctuations. More precisely, they proved that for a transformer network $\bf T_{N}[\mathbf{X},\boldsymbol\theta]$ with parameters $\boldsymbol\theta$, the output variance scales with the number of layers $N$ when we randomly perturb the parameters to $\boldsymbol\theta^{*} = \boldsymbol\theta+\boldsymbol\delta$:

\begin{equation}
\mbox{Var}\left[\bf T_{N}[\mathbf{X};\boldsymbol\theta] - \bf T_{N}[\mathbf{X};\boldsymbol\theta^{*}]\right] = \mathcal{O}(N). \tag{5}
\end{equation}

They also show that this happens empirically with both random parameter changes and Adam updates (figure 5). The result is that the output changes more and more when we update the parameters, which destabilizes transformer training.

Figure 5. Amplified output perturbations. Consider a transformer network $\bf T_{N}[\mathbf{X},\boldsymbol\theta]$ with $N$ layers and parameters $\boldsymbol\theta$. a) A random perturbation is added to the parameters $\boldsymbol\theta$ to create new parameters $\boldsymbol\theta^{*}$. The sum of squares of the change in output $\left|\bf T_{N}[\mathbf{X};\boldsymbol\theta] - \bf T_{N}[\mathbf{X};\boldsymbol\theta^{*}]\right|_{2}^{2}$ increases linearly with the number of layers. b) The same effect is observed when we apply Adam updates (note log scale). Adapted from Liu et al., 2019.

Aggravating effect of Adam

Furthermore, using adaptive optimizers like Adam aggravates both the gradient shrinkage effect and the amplified output perturbations. Liu et al., 2019 show that the variance of the Adam updates is unbounded at the start of training, and these updates are also known to exhibit high variance in the early stages of training.

This can lead to problematically large updates early on, which can make the input norm $\lVert \mathbf{X} \rVert$ to each layer increase as we move through the network and thus increase gradient shrinkage, as predicted by equation 4. Moreover, the output fluctuation, which is already amplified by the network structure, will be even greater for these large parameter updates.

Why learning rate warm-up helps

To summarize, residual connections are required in the transformer architecture for the ease of optimization, which further requires layer normalization to avoid gradient explosion and adaptive optimizers like Adam to address unbalanced gradients in the self-attention blocks. On the flip side, the use of layer normalization causes the gradients to shrink in the early layers and also amplifies the output perturbations. Moreover, the instability of Adam in the early stages of training exacerbates both of these effects (figure 6).

Figure 6. Consequences of using layer normalization and Adam. Layer normalization in transformer networks causes unbalanced dependencies between the two branches of the residual blocks. This in turn leads to amplified output perturbations; the output becomes increasingly sensitive to changes in the parameters as the depth of the network grows. Layer normalization also means that the gradient is scaled down in proportion to the output norm of each transformer block. This gradient shrinkage compounds as we back-propagate through the network and means the gradients become very small for the early layers. In the early stages of training the Adam updates are unstable and can be large; this in turn causes the output of the later layers to fluctuate more, causing large output norms, which in turn cause more gradient shrinkage. The solution is to use learning rate warm-up, which effectively damps the Adam updates during the unstable early phase.

This is where learning rate warm-up comes in: it effectively stabilizes the Adam updates during the early stages of training by making the parameter changes much smaller. Consequently, Adam no longer aggravates gradient shrinkage and amplification of output perturbations and training becomes relatively stable.

Better methods for training transformers

In the previous section, we argued that the transformer architecture, and the statistics of language data require us to use layer normalization and train with adaptive optimizers like Adam. These choices in turn cause other problems that are resolved by using learning rate warm-up. In this section, we consider alternative methods for training deep transformers that don't require learning rate warm-up.

We'll consider three approaches that respectively remove the normalization from the network, attempt to re-balance the dependency on the two paths of the residual networks, and reduce the variance of the optimizer updates.

Removing layer normalization

Since both the problems of gradient shrinkage and unbalanced dependencies are directly connected to layer normalization, it is natural to question whether we can train deep transformer models without this step. It has indeed been demonstrated that this is possible, and that we can achieve even better generalization without layer normalization.

Recall that the normalization mechanism is introduced to prevent gradients exploding in deep residual networks. It follows that if we can stabilize the gradient updates $\Delta \boldsymbol\theta$, then we can remove layer normalization. Zhang et al., 2019 demonstrated that the gradient updates $\Delta \pmb \theta$ can be bounded when using the SGD optimizer to train residual MLP or convolution blocks by appropriately initializing the weights. Based on this work, Huang et al., 2020 derived an analogous initialization scheme for residual self-attention blocks.

Although the theoretical derivations are for SGD updates, these results hold well for adaptive optimizers like Adam in practice. Furthermore, it follows from the Taylor expansion:

\begin{equation}\label{eq:taylor}
    \Delta\bf T[\mathbf{X},\boldsymbol\theta]  \approx \frac{\partial \bf T[\mathbf{X},\boldsymbol\theta]}{\partial \pmb\theta} \Delta \pmb\theta, \tag{6}
\end{equation}

that the output fluctuation $\Delta\bf T[\mathbf{X},\boldsymbol\theta]$ is also bounded by bounding the gradient updates $\Delta\boldsymbol\theta$. As a result, both the gradient vanishing and the amplified output perturbations are resolved with stable gradient updates.

The proposed initialization scheme is known as T-Fixup and is easy to implement. Consider a multi-head self-attention block where the $h^{th}$ head computes

\begin{equation}
    {\bf Sa}_{h}[\mathbf{X}] =\bf Softmax\left[\frac{(\mathbf{X}\boldsymbol\Phi_{qh})(\mathbf{X}\boldsymbol\Phi_{kh})^{T}}{\sqrt{d_{q}}}\right]\mathbf{X}\boldsymbol\Phi_{vh}, \tag{7}
\end{equation}

where $\mathbf{X}$ is the input data matrix containing word embeddings in its rows and $\boldsymbol\Phi_{qh}$, $\boldsymbol\Phi_{kh}$ and $\boldsymbol\Phi_{vh}$ are the weight parameters for the queries, keys, and values respectively. The outputs of these self-attention mechanisms are concatenated and another linear transform $\boldsymbol\Phi_{c}$ is applied to combine them:

\begin{equation}
    {\bf MhSa}[\mathbf{X}] = \left[{\bf Sa}_{1}[\mathbf{X}]\;{\bf Sa}_{2}[\mathbf{X}]\;\ldots\;{\bf Sa}_{H}[\mathbf{X}] \right]\boldsymbol\Phi_{c}. \tag{8}
\end{equation}
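Equations 7 and 8 can be sketched in a few lines of dependency-free Python. This is only an illustration on toy sizes: the helper names (`matmul`, `softmax_rows`, `random_matrix`), the random weights, and the dimensions are assumptions, and a practical implementation would use an array library instead of nested lists.

```python
import math
import random

def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax_rows(S):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    out = []
    for row in S:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        z = sum(e)
        out.append([v / z for v in e])
    return out

def self_attention(X, Phi_q, Phi_k, Phi_v):
    """Equation 7 for a single head."""
    Q, K, V = matmul(X, Phi_q), matmul(X, Phi_k), matmul(X, Phi_v)
    d_q = len(Phi_q[0])
    K_T = [list(col) for col in zip(*K)]
    scores = [[s / math.sqrt(d_q) for s in row] for row in matmul(Q, K_T)]
    return matmul(softmax_rows(scores), V)

def multi_head_self_attention(X, heads, Phi_c):
    """Equation 8: concatenate the head outputs column-wise, then apply Phi_c."""
    outputs = [self_attention(X, *head) for head in heads]
    concat = [sum((out[i] for out in outputs), []) for i in range(len(X))]
    return matmul(concat, Phi_c)

def random_matrix(rows, cols, rng):
    """Gaussian weights with standard deviation rows^(-1/2) (illustrative choice)."""
    return [[rng.gauss(0.0, 1.0 / math.sqrt(rows)) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
I, D, H = 3, 4, 2            # tokens, embedding size, heads (toy sizes)
d_head = D // H
X = random_matrix(I, D, rng)
heads = [tuple(random_matrix(D, d_head, rng) for _ in range(3)) for _ in range(H)]
Phi_c = random_matrix(H * d_head, D, rng)
Y = multi_head_self_attention(X, heads, Phi_c)
```

The output `Y` has the same $I\times D$ shape as the input `X`, which is what allows the block to sit inside a residual connection.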

The T-Fixup scheme for encoder-decoder attention is then as follows:

  1. Apply Xavier initialization for all parameters excluding input embeddings. Use Gaussian initialization $\mbox{Norm}_{\theta}[0, D^{-\frac12}]$ for input embeddings where $D$ is the embedding dimension.
  2. Scale $\boldsymbol\Phi_{vh}$ and $\boldsymbol\Phi_{c}$ in each encoder attention block and the weight matrices in each encoder MLP block by $0.67N_{e}^{-\frac14}$, where $N_{e}$ is the number of transformer blocks (i.e., self-attention + MLP) in the encoder. Scale the input embeddings to the encoder by $(9N_{e})^{-\frac14}$.
  3. Scale parameters $\boldsymbol\Phi_{vh}$ and $\boldsymbol\Phi_{c}$ in each decoder attention block, weight matrices in each decoder MLP block and input embeddings in the decoder by $(9N_d)^{-\frac14}$ where $N_d$ is the number of transformer blocks in the decoder.
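The scale factors in steps 2 and 3 above are simple functions of the encoder and decoder depths, as this small sketch shows. The function names `tfixup_scales` and `scale_matrix` and the dictionary keys are hypothetical, and the sketch covers only the scaling factors, not the Xavier and Gaussian initialization of step 1.

```python
def tfixup_scales(N_e, N_d):
    """Scale factors from steps 2 and 3 for N_e encoder and N_d decoder blocks."""
    return {
        "encoder_weights": 0.67 * N_e ** -0.25,          # Phi_v, Phi_c and MLP weights
        "encoder_embeddings": (9 * N_e) ** -0.25,
        "decoder_weights_and_embeddings": (9 * N_d) ** -0.25,
    }

def scale_matrix(W, factor):
    """Multiply every entry of an already initialized weight matrix by factor."""
    return [[factor * w for w in row] for row in W]

# Example depths chosen for illustration only.
scales = tfixup_scales(N_e=16, N_d=16)
Phi_v_scaled = scale_matrix([[0.2, -0.4], [0.6, 0.8]], scales["encoder_weights"])
```

Because the factors decay as the fourth root of the depth, they shrink slowly: moving from 16 to 256 blocks only halves them.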

In practice, T-Fixup is able to train significantly deeper transformer models with improved performance on the task of machine translation. For the detailed derivation of this method, we refer the readers to the original paper.

Balancing the residual dependencies

An alternative approach is to balance the residual dependencies, which in turn will limit the output perturbations $\Delta \bf T[\mathbf{X}]$. Equation 6 shows that controlling the magnitude of the output fluctuation $\Delta \bf T[\mathbf{X}]$ also bounds the magnitude of the gradient updates $\Delta \boldsymbol\theta$, which in turn mitigates the problem of gradient vanishing. Here we'll consider three possible approaches.

Pre-LN transformers: One simple solution is to change the location of layer normalization inside the transformer layer so that it occurs inside the residual blocks and before the self-attention or MLP (figure 7). This is known as the pre-LN transformer. This simple change can help control the gradient magnitude and balance the residual dependencies.

Pre-LN transformer models can be trained without learning rate warm-up. However, they also lead to inferior empirical performance. It has been speculated that this is because the models are now restricted not to depend too much on the contents of their residual layers (Liu et al., 2020).
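The difference between the two orderings is easiest to see side by side. The sketch below applies both to a single embedding vector, with a toy function standing in for the self-attention or MLP sublayer. The names `post_ln_block`, `pre_ln_block`, and `toy_sublayer` are assumptions for illustration, and the learned layer norm parameters are fixed at $\gamma=1$ and $\beta=0$.

```python
import math

def layer_norm(x, eps=1e-5):
    """Layer normalization of one embedding vector with gamma = 1 and beta = 0."""
    D = len(x)
    mu = sum(x) / D
    var = sum((v - mu) ** 2 for v in x) / D
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def post_ln_block(x, sublayer):
    """Original ordering: LayerNorm[x + sublayer(x)]."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_ln_block(x, sublayer):
    """Pre-LN ordering: x + sublayer(LayerNorm[x])."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

def toy_sublayer(x):
    """Stand-in for self-attention or the MLP (assumption for illustration)."""
    return [0.5 * v for v in x]

x = [1.0, 2.0, 3.0, 6.0]
y_post = post_ln_block(x, toy_sublayer)
y_pre = pre_ln_block(x, toy_sublayer)
```

In the pre-LN ordering the skip path carries `x` through unchanged, so the input always contributes directly to the output; in the original ordering both paths pass through the normalization.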

Figure 7. Moving the layer normalization to balance residual dependencies. (a) Original transformer layer (b) The pre-LN transformer layer moves the layer norm operation inside the residual block and before the self-attention / MLP blocks.

Admin: To bridge this performance gap, adaptive model initialization or Admin aims to bound the output fluctuation $\Delta \bf T[\mathbf{X}]$ by controlling the residual dependencies while retaining the original architecture.

Admin adds a new $1\times D$ parameter vector $\boldsymbol\psi$ to each residual block. The self-attention block is then constructed as $\bf LayerNorm[\mathbf{X} \odot \boldsymbol\Psi + \bf MhSa[\mathbf{X}]]$ where $\odot$ is the element-wise product and $\boldsymbol\Psi$ is an $I\times D$ matrix where each row is a copy of $\boldsymbol\psi$. The residual connection around the parallel MLP layer is treated in the same way (figure 8a).

The new parameters at the $n^{th}$ layer are initialized to be the output standard deviation at that layer before this intervention. This can be estimated by setting all elements of $\boldsymbol\psi$ to one and forward propagating on a few training instances.

ReZero: In a similar vein, ReZero removes the layer normalization and introduces a single trainable parameter $\alpha$ per residual layer so that the self-attention block residual layer becomes $\mathbf{X} + \alpha\bf MhSa[\mathbf{X}]$, where $\alpha$ is initialized to zero (figure 8b). The result of this is that the entire network is initialized just to compute the identity function, and the contributions of the self-attention and MLP layers are gradually and adaptively introduced.
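A minimal sketch of the ReZero residual layer shows why the network starts out as the identity. The function names and the toy sublayer are assumptions for illustration; in a real model each $\alpha$ would be a trainable scalar updated by the optimizer.

```python
def rezero_block(x, sublayer, alpha):
    """ReZero residual layer: x + alpha * sublayer(x), with no layer normalization."""
    return [a + alpha * b for a, b in zip(x, sublayer(x))]

def toy_sublayer(x):
    """Stand-in for self-attention or the MLP (assumption for illustration)."""
    return [v * v - 1.0 for v in x]

x = [0.2, -1.5, 3.0]
alphas = [0.0, 0.0, 0.0]   # one trainable scalar per residual layer, initialized to zero

h = x
for alpha in alphas:
    h = rezero_block(h, toy_sublayer, alpha)
```

With every `alpha` at zero, `h` equals `x` no matter how many layers are stacked; as training moves each `alpha` away from zero, the corresponding sublayer is gradually switched on.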

Empirically, both Admin and ReZero work well for training deeper transformer models with better generalization performance, which demonstrates the effectiveness of balancing the residual dependencies.

Figure 8. Balancing dependencies by adding extra parameters in the residual layers. a) The Admin method pointwise multiplies the input $\mathbf{X}$ being passed through the residual connection around the self-attention layer by a matrix $\boldsymbol\Psi_{sa}$ where each row of this matrix contains the same vector $\boldsymbol\psi_{sa}$. With suitable initialization, this balances the contributions of the two paths. The MLP residual block is treated similarly, but with distinct parameters $\boldsymbol\psi_{mlp}$. b) The ReZero method removes the LayerNorm altogether and multiplies the output of the self-attention layer and the output of the MLP layer by a shared learned parameter $\alpha$. The parameter $\alpha$ is initialized to zero, so this method initializes the entire network to compute the identity and allows the outputs of the two branches of the residual blocks to be adaptively balanced.

Reducing the optimizer variance

We noted before that the high variance of learning rates in the Adam optimizer at the early stages of training exacerbates the problems of amplified output perturbations and gradient vanishing. Liu et al., 2019 argue that this is due to the lack of samples in the early stages of learning. They base their argument on an experiment in which they do not change the model parameters or momentum term of Adam for the first 2000 learning steps, but only adapt the learning rate. After this, warm-up is no longer required.

Based on these observations, they propose Rectified Adam or RAdam which gradually changes the momentum term over time in a way that helps avoid high variance. One way to think of this is that we have effectively incorporated learning rate warm-up into the Adam algorithm, but in a principled way.
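To see why RAdam behaves like a built-in warm-up, it helps to look at its rectification factor. As we read the RAdam paper, the adaptive step is multiplied by a factor $r_t$ that depends only on the step count $t$ and on Adam's second-moment decay rate $\beta_2$, and the adaptive term is skipped entirely for the first few steps. The sketch below computes this factor; the function name, the returned `None` for the skipped steps, and the threshold of 4 reflect our reading of the paper and should be treated as assumptions.

```python
import math

def radam_rectification(t, beta2=0.999):
    """Rectification factor r_t at step t (counting from 1), or None if not yet defined."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        return None  # too few samples: the adaptive learning rate is not used yet
    num = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
    den = (rho_inf - 4.0) * (rho_inf - 2.0) * rho_t
    return math.sqrt(num / den)

# Rectification factor at a few illustrative steps.
factors = {t: radam_rectification(t) for t in (1, 4, 5, 50, 5000, 100000)}
```

For $\beta_2 = 0.999$ the factor is undefined for the first four steps, starts out small, and approaches one as training proceeds, which mirrors the shape of a hand-tuned warm-up schedule without introducing a separate warm-up length.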

How to train deeper transformers on small datasets

In the previous sections, we have seen that great progress has been made towards understanding transformer training. Several solutions have been proposed that allow the training of significantly deeper transformer models with improved empirical performance.

However, they have only been applied to tasks with sufficient training data such as machine translation and language modelling. This is possibly due to the commonly-held belief that training deep transformers from scratch requires large datasets. For small datasets, it is typical just to add shallow and simple additional layers (e.g., a classifier head) to pre-trained models and then fine-tune.

So, what prevents practitioners from training deep transformers on small datasets? It turns out that the final missing piece of the puzzle is the batch size. For small datasets, it's necessary to leverage large pre-trained models and then fine-tune. However, the size of these models limits the batch size and when the batch size is small, the variance of the updates is even larger, which makes training even harder. Even if we could use a larger batch size, it usually results in poorer generalization, especially on small datasets.

In short, small datasets require pre-trained models and small batch sizes to perform well, but these two requirements make training additional transformer layers challenging. To resolve the high variance of training updates in small batch sizes, the three ideas from the previous section can all be applied. However, these approaches all assume that the inputs to the transformers are randomly initialized embeddings, but this is not true if we are adding yet-to-be-trained transformers on top of pre-trained models (figure 9).

Figure 9. Architecture on which DT-Fixup can be applied. The pre- and post-transformer modules can be any architectures that can be stably trained with Adam, including MLP, RNN, CNN, or a pre-trained transformer model which can be stably fine-tuned with a small learning rate.

DT-Fixup is a data-dependent initialization strategy developed by Borealis AI. It adapts the T-Fixup method for this type of mixed setting. DT-Fixup allows significantly deeper transformers to be trained with small datasets for challenging tasks such as Text-to-SQL semantic parsing and logical reading comprehension. This demonstrates that training deep transformers with small datasets is feasible with the correct optimization procedure.

Conclusion

In the first two parts of this blog, we introduced the transformer, and discussed extensions and relations to other models. In this final part, we have discussed the complex topic of how to train deep transformer models effectively. We hope that this discussion will help practitioners to train deeper transformers for different applications and help identify potential directions for further improving transformer training.
