# Tutorial #17: Transformers III: Training

In part I of this tutorial we introduced the self-attention mechanism and the transformer architecture. In part II, we discussed position encoding and how to extend the transformer to longer sequence lengths. We also discussed connections between the transformer and other machine learning models.

In this final part, we discuss challenges with transformer training dynamics and introduce some of the tricks that practitioners use to get transformers to converge. This discussion will be suitable for researchers who already understand the transformer architecture, and who are interested in training transformers and similar models from scratch.

## Tricks for training transformers

Despite their broad applications, transformers are surprisingly difficult to train from scratch. One of the contributions of the original transformer paper was to use four tricks that collectively allow stable training:

**1. Residual connections:** Each transformer layer takes the $I\times D$ data matrix $\mathbf{X}$, where $I$ is the number of inputs and $D$ is the dimensionality of those inputs, and returns an object of the same size. It performs the following operations:

\begin{eqnarray}

\mathbf{X} &\leftarrow& \mathbf{X} + \bf{MhSa}[\mathbf{X}] \nonumber \\

\mathbf{X} &\leftarrow& \bf{Layernorm}[\mathbf{X}] \hspace{3cm}\nonumber\\

\mathbf{x}_{i} &\leftarrow& \mathbf{x}_{i}+\bf{mlp}[\mathbf{x}_{i}] \hspace{3.6cm}\forall\; i\in\{1\ldots I\}\nonumber\\

\mathbf{X} &\leftarrow& \bf{Layernorm}[\mathbf{X}], \tag{1}

\end{eqnarray}

which include two residual connections around the multi-head self-attention $\bf{MhSa}[\bullet]$ and multi-layer perceptron $\bf{mlp}[\bullet]$ components (figure 1).

**2. Layer normalization:** After each residual connection, a layer normalization procedure is applied:

\begin{equation}

\bf Layernorm[\mathbf{X}] = \gamma\cdot \frac{\mathbf{X}-\mu}{\sigma}+\beta, \tag{2}

\end{equation}

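where $\mu$ and $\sigma$ are the mean and standard deviation of the elements being normalized, and $\gamma$ and $\beta$ are learned parameters. In the standard transformer, the statistics are computed over the $D$ elements of each input $\mathbf{x}_{i}$ (each row of $\mathbf{X}$), separately for each member of the batch, and $\gamma$ and $\beta$ are $D$-dimensional vectors.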
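A minimal plain-Python sketch of equation 2 follows. It assumes the common convention of normalizing each row $\mathbf{x}_{i}$ of $\mathbf{X}$ over its $D$ features with per-feature vectors $\gamma$ and $\beta$; the small constant `eps` added to the variance is an implementation detail that does not appear in equation 2.

```python
import math

def layer_norm(X, gamma, beta, eps=1e-5):
    """Apply equation 2 to each row x_i of an I x D matrix X (a list of rows).

    gamma and beta are length-D lists of learned scale and shift values;
    eps keeps the division safe (an implementation detail, not in equation 2).
    """
    out = []
    for x in X:
        d = len(x)
        mu = sum(x) / d                            # mean of this row
        var = sum((v - mu) ** 2 for v in x) / d    # variance of this row
        sigma = math.sqrt(var + eps)
        out.append([g * (v - mu) / sigma + b for v, g, b in zip(x, gamma, beta)])
    return out

X = [[1.0, 2.0, 3.0, 4.0],
     [10.0, 0.0, -10.0, 0.0]]
Y = layer_norm(X, gamma=[1.0] * 4, beta=[0.0] * 4)
for row in Y:
    print([round(v, 3) for v in row])
```

With $\gamma=1$ and $\beta=0$, every output row has (approximately) zero mean and unit variance, whatever the scale of the input row.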

**3. Learning rate warm-up:** The learning rate is increased linearly from $0$ to $R$ over the first $T_{R}$ time steps so that, for $t\leq T_{R}$:

\begin{equation}

\mbox{lr}[t] = R \cdot \frac{t}{T_{R}}. \tag{3}

\end{equation}
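Equation 3 only specifies the warm-up phase. The sketch below assumes, hypothetically, that the learning rate is simply held at $R$ afterwards; many schedules instead follow warm-up with some form of decay, which is omitted here.

```python
def warmup_lr(t, peak_lr, warmup_steps):
    """Linear warm-up (equation 3): lr[t] = R * t / T_R for t <= T_R.

    After warm-up the rate is held at R (an assumption; real schedules
    often decay it afterwards).
    """
    if warmup_steps <= 0:
        return peak_lr
    return peak_lr * min(1.0, t / warmup_steps)

R, T_R = 1e-3, 4000
for t in (0, 1000, 2000, 4000, 8000):
    print(t, warmup_lr(t, R, T_R))
```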

**4. Adaptive optimizers:** Transformers need to be trained with adaptive optimizers like *Adam*, which maintains exponentially decaying averages of each parameter's gradient (the momentum) and squared gradient, and uses them to set the step size separately for each parameter at each time step. In practice, relatively large batch sizes of more than 1,000 are usually employed.

Removing any of these tricks makes training unstable and often leads to complete training failures. However, they have been employed without a full understanding of why they are required.

As transformers are applied more widely, it is increasingly important that we have a better understanding of transformer training. To this end, a number of recent papers have been devoted to demystifying this topic and exploring better training methods. In the rest of this blog post, we will connect these separate efforts to form a comprehensive overview of this topic.

## Why are these tricks required?

In this section, we review these tricks and show that there are complex dependencies between them: some tricks cause problems, which are in turn resolved by others. In the subsequent section, we discuss improvements to the training process that follow from this understanding.

## Learning rate warm-up

Learning rate warm-up (in which the learning rate is gradually increased during the early stages of training) is particularly puzzling. This is not required for most deep learning architectures. However, training fails for transformers if we just start with a typical learning rate. If we start with a very small learning rate, then the training is stable, but then it takes an impractically long time.

Xiong *et* al., 2020 explored this phenomenon by conducting experiments on a machine translation task with different optimizers and learning rate schedules. Their results (figure 2) show that learning rate warm-up is essential for both Adam and SGD, and that the training process is sensitive to the number of warm-up steps.

Although learning rate warm-up works, it has some obvious disadvantages. It introduces an extra hyper-parameter (the number of warm-up steps), and it starts the learning rate at zero, which slows training down. Hence, it's important that we understand why it is necessary.

To help answer this question, Huang *et* al., 2020 visualized the gradient of the loss $\mathcal{L}$ with respect to the input embeddings $\mathbf{X}$, and the size of the Adam updates during the first 100 steps of training (figure 3). They found that without warm-up, the gradients vanish very quickly, and the Adam updates also rapidly become much smaller. Diminishing gradients at lower layers in the transformer model without warm-up have also been observed by Liu *et* al., 2020.

## Residual connections, layer normalization and Adam

To understand why learning rate warm-up is required, and why the gradients vanish without it, we first need to understand the reasons for, and the consequences of, using residual connections, layer normalization, and Adam.

Residual networks were developed in computer vision; they make networks easier to optimize and allow deeper networks to be trained. In computer vision, the additive residual connections are usually placed around convolutional layers and combined with batch normalization. In the transformer, they are placed around the self-attention and feed-forward networks and combined with layer normalization (figure 1). From this perspective, the transformer architecture could be considered a "deep residual self-attention network".

Zhang *et* al., 2019 show that the output variance of residual networks grows exponentially with depth. Hence, normalization is used to prevent *gradient explosion* for deep residual networks. Layer normalization is used in the transformer because the statistics of language data exhibit large fluctuations across the batch dimension, and this leads to instability in batch normalization.
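The growth that Zhang *et* al. analyze can be illustrated with a toy experiment. The sketch below stacks residual blocks whose branch is a random linear map (an assumption standing in for real attention or MLP sublayers, with weights drawn so that the branch alone roughly preserves the signal scale) and tracks the mean square activation with and without layer normalization after each block.

```python
import math
import random

def random_matrix(d, rng):
    """d x d Gaussian weights with variance 1/d, so W x roughly preserves
    the scale of x (an assumed initialization for this toy example)."""
    return [[rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)] for _ in range(d)]

def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def normalize(x):
    """Layer normalization of a single vector, without gamma/beta."""
    d = len(x)
    mu = sum(x) / d
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / d)
    return [(v - mu) / sigma for v in x]

def mean_square(x):
    return sum(v * v for v in x) / len(x)

rng = random.Random(0)
D, N = 64, 8
layers = [random_matrix(D, rng) for _ in range(N)]
x0 = [rng.gauss(0.0, 1.0) for _ in range(D)]

plain, normed = list(x0), list(x0)
plain_history, normed_history = [], []
for W in layers:
    # Residual block x <- x + f(x) with a linear branch f(x) = W x.
    plain = [a + b for a, b in zip(plain, matvec(W, plain))]
    # The same block followed by layer normalization (post-LN style).
    normed = normalize([a + b for a, b in zip(normed, matvec(W, normed))])
    plain_history.append(mean_square(plain))
    normed_history.append(mean_square(normed))

for n, (p, q) in enumerate(zip(plain_history, normed_history), start=1):
    print(f"block {n}: without LN {p:10.2f}   with LN {q:.2f}")
```

Without normalization, the skip path and the branch each contribute about the same variance, so the mean square roughly doubles per block; with layer normalization after each block, it stays at one.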

Transformers also differ from convolutional networks in that stochastic gradient descent does not work well for training (figure 2) and adaptive optimizers like Adam are required. Liu *et* al., 2020 observed that differentiating through the self-attention mechanism creates *unbalanced gradients*. In particular, the gradients for the query $\boldsymbol\Phi_{q}$ and key $\boldsymbol\Phi_{k}$ parameters were much smaller than those for the value parameters $\boldsymbol\Phi_{v}$, and so the former parameters change much more slowly. This is a direct consequence of the mathematical expression for self-attention. The Adam optimizer fixes this problem by essentially having different learning rates for each parameter.
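To make the last point concrete, the sketch below compares the first update that plain SGD and Adam (with its usual default hyper-parameters, assumed here) make for two parameters whose gradients differ by three orders of magnitude. The gradient values are hypothetical stand-ins for small query/key gradients and larger value gradients.

```python
import math

def adam_first_step(grad, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Size of Adam's very first update for a parameter with gradient `grad`.

    Moment estimates start at zero; after bias correction, the first-step
    estimates are m_hat = grad and v_hat = grad**2.
    """
    m = (1 - beta1) * grad
    v = (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1)
    v_hat = v / (1 - beta2)
    return lr * m_hat / (math.sqrt(v_hat) + eps)

def sgd_step(grad, lr=1e-3):
    return lr * grad

# Hypothetical gradient sizes: tiny for a query/key weight, larger for a value weight.
grads = {"query/key weight": 1e-4, "value weight": 1e-1}
for name, g in grads.items():
    print(f"{name:17s} SGD step {sgd_step(g):.1e}   Adam step {adam_first_step(g):.1e}")
```

SGD's step inherits the 1000-fold imbalance, whereas Adam divides by a running root-mean-square of the gradient, so both parameters move by roughly the learning rate.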

To conclude, we've seen that residual connections are needed to allow us to train deep networks. These cause gradient explosion, which is resolved by using layer normalization. The self-attention computation causes unbalanced gradients, which necessitates the use of Adam (figure 4). In the next section, we'll see that layer normalization and Adam themselves cause more problems, which ultimately result in the need for learning rate warm-up.

## Gradient shrinkage effect

Xiong *et* al., 2020 found that the magnitude of the gradients through layer normalization is inversely proportional to the magnitude of the input. Specifically, the gradient has the following property:

\begin{equation}

\left\lVert \frac{\partial \bf Layernorm[\mathbf{X}]}{\partial \mathbf{X}} \right\rVert=\mathcal{O}\left(\frac{\sqrt{D}}{\lVert\mathbf{X}\rVert}\right), \tag{4}

\end{equation}

where $\mathbf{X}$ is the input to layer normalization and $D$ is the embedding dimension. If the input norm $\lVert \mathbf{X} \rVert$ is larger than $\sqrt{D}$, then back-propagating through layer normalization reduces the gradient magnitude in lower layers. As this effect compounds through multiple layers, it causes the gradient to vanish at lower layers for deep models. We will term this the *gradient shrinkage effect*.
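For a single vector $\mathbf{x}$ (ignoring $\gamma$, $\beta$ and any stabilizing constant), the Jacobian of layer normalization is $\frac{1}{\sigma}\left(\mathbf{I}-\frac{1}{D}\mathbf{1}\mathbf{1}^{T}-\frac{1}{D}\mathbf{y}\mathbf{y}^{T}\right)$, where $\mathbf{y}$ is the normalized output. The bracketed matrix is a projection of rank $D-2$, so the spectral norm is $1/\sigma=\sqrt{D}/\lVert\mathbf{x}-\mu\rVert$, which is the scaling in equation 4. The sketch below checks this numerically with finite differences on a random vector (the size $D=16$ and the random input are arbitrary choices): scaling the input by $c$ divides the Jacobian norm by $c$.

```python
import math
import random

def layer_norm(x):
    """Layer normalization of one vector, without gamma/beta and without
    the small epsilon that real implementations add."""
    d = len(x)
    mu = sum(x) / d
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / d)
    return [(v - mu) / sigma for v in x]

def jacobian_fro_norm(f, x, h=1e-6):
    """Frobenius norm of the Jacobian of f at x, via central differences."""
    total = 0.0
    for j in range(len(x)):
        xp = list(x)
        xp[j] += h
        xm = list(x)
        xm[j] -= h
        fp, fm = f(xp), f(xm)
        total += sum(((a - b) / (2 * h)) ** 2 for a, b in zip(fp, fm))
    return math.sqrt(total)

rng = random.Random(0)
D = 16
x = [rng.gauss(0.0, 1.0) for _ in range(D)]
for scale in (1.0, 10.0, 100.0):
    xs = [scale * v for v in x]
    norm_x = math.sqrt(sum(v * v for v in xs))
    print(f"||x|| = {norm_x:8.2f}   ||J||_F = {jacobian_fro_norm(layer_norm, xs):.5f}")
```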

## Unbalanced dependencies and amplified output perturbations

Layer normalization also causes *unbalanced dependencies* between the two branches of the residual connection around the self-attention module. In other words, the output of $\bf LayerNorm[\mathbf{X}+\bf Sa[\mathbf{X}]]$ depends much more on the self-attention computation $\bf Sa[\mathbf{X}]$ than the skip connection $\mathbf{X}$. This means that the outputs depend much more on later layers than earlier layers. Liu *et* al., 2019 show that this happens empirically in practice.

Moreover, they show that this leads to *amplified output perturbations*; small changes to the network parameters cause large output fluctuations. More precisely, they proved that for a transformer network $\bf T_{N}[\mathbf{X},\boldsymbol\theta]$ with parameters $\boldsymbol\theta$, the output variance scales with the number of layers $N$ when we randomly perturb the parameters to $\boldsymbol\theta^{*} = \boldsymbol\theta+\boldsymbol\delta$:

\begin{equation}

\mbox{Var}\left[\bf T_{N}[\mathbf{X},\boldsymbol\theta] - \bf T_{N}[\mathbf{X},\boldsymbol\theta^{*}]\right] = \mathcal{O}(N). \tag{5}

\end{equation}

They also show that this happens empirically with both random parameter changes and Adam updates (figure 5). The result is that the deeper the network, the more its output changes when we update the parameters, which destabilizes transformer training.

## Aggravating effect of Adam

Furthermore, using adaptive optimizers like Adam aggravates both the gradient shrinkage effect and the amplified output perturbations. Liu *et* al., 2019 show that the variance of Adam's adaptive learning rate is unbounded at the very start of training, so the updates exhibit high variance during the early stages of training.

This can lead to problematically large updates early on, which can make the input norm $\lVert \mathbf{X} \rVert$ to each layer increase as we move through the network, and thus increase the gradient shrinkage predicted by equation 4. Moreover, the output fluctuation, which is already amplified by the network structure, will be even greater for these large parameter updates.
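The sketch below illustrates the variance problem numerically under a toy assumption: the gradients of a single parameter are i.i.d. standard normal samples, and we look at Adam's adaptive factor $1/(\sqrt{\hat{v}_t}+\epsilon)$ across many independent runs. After one step, $\hat{v}_1 = g_1^2$, so (ignoring $\epsilon$) the factor is $1/|g_1|$, whose variance is infinite; after many steps, $\hat{v}_t$ averages many squared gradients and the factor concentrates.

```python
import math
import random

def adaptive_lr_samples(t, trials, beta2=0.999, eps=1e-8, seed=0):
    """Sample Adam's adaptive factor 1 / (sqrt(v_hat_t) + eps) after t steps.

    Gradients are assumed i.i.d. standard normal (a toy assumption), and
    v_hat_t is Adam's bias-corrected second-moment estimate.
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        v = 0.0
        for _ in range(t):
            g = rng.gauss(0.0, 1.0)
            v = beta2 * v + (1 - beta2) * g * g
        v_hat = v / (1 - beta2 ** t)
        samples.append(1.0 / (math.sqrt(v_hat) + eps))
    return samples

def variance(xs):
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

for t in (1, 2, 10, 100):
    v = variance(adaptive_lr_samples(t, 5000))
    print(f"t = {t:3d}: variance of adaptive factor = {v:.4g}")
```

The sample variance at $t=1$ is dominated by a few runs with tiny first gradients, which is exactly the kind of unexpectedly large early update described above.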

## Why learning rate warm-up helps

To summarize, residual connections are required in the transformer architecture to ease optimization, which further requires layer normalization to avoid gradient explosion and adaptive optimizers like Adam to address unbalanced gradients in the self-attention blocks. On the flip side, the use of layer normalization causes the gradients to shrink in the early layers and also amplifies the output perturbations. Moreover, the instability of Adam in the early stages of training exacerbates both of these effects (figure 6).

This is where learning rate warm-up comes in: it effectively stabilizes the Adam updates during the early stages of training by making the parameter changes much smaller. Consequently, Adam no longer aggravates gradient shrinkage and amplification of output perturbations and training becomes relatively stable.

## Better methods for training transformers

In the previous section, we argued that the transformer architecture and the statistics of language data require us to use layer normalization and train with adaptive optimizers like Adam. These choices in turn cause other problems that are resolved by using learning rate warm-up. In this section, we consider alternative methods for training deep transformers that don't require learning rate warm-up.

We'll consider three approaches that respectively remove the normalization from the network, attempt to re-balance the dependency on the two paths of the residual networks, and reduce the variance of the optimizer updates.

## Removing layer normalization

Since both the problems of gradient shrinkage and unbalanced dependencies are directly connected to layer normalization, it is natural to question whether we can train deep transformer models without this step. It has indeed been demonstrated that this is possible, and that we can achieve even better generalization without layer normalization.

Recall that the normalization mechanism is introduced to prevent gradients exploding in deep residual networks. It follows that if we can stabilize the gradient updates $\Delta \boldsymbol\theta$, then we can remove layer normalization. Zhang *et* al., 2019 demonstrated that the gradient updates $\Delta \boldsymbol\theta$ can be bounded when using the SGD optimizer to train residual MLP or convolution blocks by appropriately initializing the weights. Based on this work, Huang *et* al., 2020 derived an analogous initialization scheme for residual self-attention blocks.

Although the theoretical derivations are for SGD updates, these results hold well for adaptive optimizers like Adam in practice. Furthermore, it follows from the Taylor expansion:

\begin{equation}\label{eq:taylor}

\Delta\bf T[\mathbf{X},\boldsymbol\theta] \approx \frac{\partial \bf T[\mathbf{X},\boldsymbol\theta]}{\partial \boldsymbol\theta} \Delta \boldsymbol\theta, \tag{6}

\end{equation}

that the output fluctuation $\Delta{\bf T}[\mathbf{X},\boldsymbol\theta]$ is also bounded when the gradient updates $\Delta\boldsymbol\theta$ are bounded. As a result, both gradient vanishing and amplified output perturbations are resolved by stabilizing the gradient updates.
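Taking norms on both sides of equation 6 makes this step explicit. To first order, and provided the Jacobian itself stays bounded, the size of the output change is controlled by the size of the parameter update:

\begin{equation*}
\lVert \Delta{\bf T}[\mathbf{X},\boldsymbol\theta] \rVert \lesssim \left\lVert \frac{\partial {\bf T}[\mathbf{X},\boldsymbol\theta]}{\partial \boldsymbol\theta} \right\rVert \, \lVert \Delta\boldsymbol\theta \rVert.
\end{equation*}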

The proposed initialization scheme is known as *T-Fixup* and is easy to implement. Consider a multi-head self-attention block where the $h^{th}$ head computes

\begin{equation}

{\bf Sa}_{h}[\mathbf{X}] =\bf Softmax\left[\frac{(\mathbf{X}\boldsymbol\Phi_{qh})(\mathbf{X}\boldsymbol\Phi_{kh})^{T}}{\sqrt{d_{q}}}\right]\mathbf{X}\boldsymbol\Phi_{vh}, \tag{7}

\end{equation}

where $\mathbf{X}$ is the input data matrix containing word embeddings in its rows, $\boldsymbol\Phi_{qh}$, $\boldsymbol\Phi_{kh}$ and $\boldsymbol\Phi_{vh}$ are the weight parameters for the queries, keys, and values respectively, and $d_{q}$ is the dimension of the queries (and keys). The outputs of these self-attention mechanisms are concatenated, and another linear transform $\boldsymbol\Phi_{c}$ is applied to combine them:

\begin{equation}

{\bf MhSa}[\mathbf{X}] = \left[{\bf Sa}_{1}[\mathbf{X}]\;{\bf Sa}_{2}[\mathbf{X}]\;\ldots\;{\bf Sa}_{H}[\mathbf{X}] \right]\boldsymbol\Phi_{c}. \tag{8}

\end{equation}
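Equations 7 and 8 can be sketched in plain Python as below. Lists of rows stand in for matrices, the sizes and random weights are arbitrary choices for illustration, and we take $d_{q}$ to be the number of columns of $\boldsymbol\Phi_{qh}$.

```python
import math
import random

def matmul(A, B):
    """Product of an n x k matrix A and a k x m matrix B, both lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_rows(A):
    """Row-wise softmax; subtracting each row's max avoids overflow."""
    out = []
    for row in A:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        s = sum(e)
        out.append([v / s for v in e])
    return out

def attention_weights(X, Phi_q, Phi_k):
    """Softmax[(X Phi_q)(X Phi_k)^T / sqrt(d_q)]: an I x I matrix for one head."""
    Q, K = matmul(X, Phi_q), matmul(X, Phi_k)
    d_q = len(Phi_q[0])  # number of columns of Phi_q
    scores = matmul(Q, transpose(K))
    return softmax_rows([[v / math.sqrt(d_q) for v in row] for row in scores])

def self_attention(X, Phi_q, Phi_k, Phi_v):
    """Equation 7 for a single head h."""
    return matmul(attention_weights(X, Phi_q, Phi_k), matmul(X, Phi_v))

def multi_head_self_attention(X, heads, Phi_c):
    """Equation 8: concatenate head outputs column-wise, then apply Phi_c.

    `heads` is a list of (Phi_q, Phi_k, Phi_v) triples, one per head.
    """
    outputs = [self_attention(X, *h) for h in heads]
    concatenated = [sum((o[i] for o in outputs), []) for i in range(len(X))]
    return matmul(concatenated, Phi_c)

rng = random.Random(0)

def random_matrix(rows, cols):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]

I, D, H, d_head = 3, 4, 2, 2  # inputs, embedding size, heads, per-head size
X = random_matrix(I, D)
heads = [(random_matrix(D, d_head), random_matrix(D, d_head), random_matrix(D, d_head))
         for _ in range(H)]
Phi_c = random_matrix(H * d_head, D)
out = multi_head_self_attention(X, heads, Phi_c)
print(len(out), len(out[0]))  # prints: 3 4
```

As in the text, the output has the same $I\times D$ shape as the input, and each row of every head's attention matrix sums to one.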

The T-Fixup scheme for an encoder-decoder transformer is then as follows:

- Apply Xavier initialization for all parameters excluding input embeddings. Use Gaussian initialization $\mbox{Norm}_{\theta}[0, D^{-\frac12}]$ for input embeddings where $D$ is the embedding dimension.
- Scale $\boldsymbol\Phi_{vh}$ and $\boldsymbol\Phi_{c}$ in each encoder attention block and weight matrices in each encoder MLP block by $0.67N_{e}^{-\frac14}$, where $N_{e}$ is the number of transformer blocks (i.e., self-attention + MLP) in the encoder. Scale the input embeddings to the encoder by $(9N_{e})^{-\frac14}$.
- Scale parameters $\boldsymbol\Phi_{vh}$ and $\boldsymbol\Phi_{c}$ in each decoder attention block, weight matrices in each decoder MLP block and input embeddings in the decoder by $(9N_d)^{-\frac14}$ where $N_d$ is the number of transformer blocks in the decoder.
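The bookkeeping in these steps can be sketched as follows. This is not the authors' reference implementation: it assumes plain-Python matrices, the uniform variant of Xavier initialization, hypothetical sizes (a 5-token vocabulary, $D=8$, six blocks in each stack), and it reads the second argument of $\mbox{Norm}_{\theta}[0, D^{-\frac12}]$ as a standard deviation.

```python
import math
import random

def xavier_uniform(fan_in, fan_out, rng):
    """Xavier/Glorot uniform init: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))."""
    a = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-a, a) for _ in range(fan_out)] for _ in range(fan_in)]

def tfixup_scales(n_enc, n_dec):
    """Scale factors from the T-Fixup steps listed above.

    n_enc, n_dec: number of transformer blocks in the encoder and decoder.
    """
    return {
        # Phi_vh, Phi_c and MLP weights in every encoder block.
        "encoder_weights": 0.67 * n_enc ** -0.25,
        "encoder_embeddings": (9 * n_enc) ** -0.25,
        # Phi_vh, Phi_c, MLP weights and input embeddings in the decoder.
        "decoder_weights": (9 * n_dec) ** -0.25,
        "decoder_embeddings": (9 * n_dec) ** -0.25,
    }

def scaled(W, s):
    return [[s * w for w in row] for row in W]

rng = random.Random(0)
D, vocab = 8, 5
scales = tfixup_scales(n_enc=6, n_dec=6)

# Input embeddings: Gaussian with standard deviation D**-0.5 (assumed reading
# of Norm[0, D^-1/2]), then rescaled for the encoder.
enc_embed = [[rng.gauss(0.0, D ** -0.5) for _ in range(D)] for _ in range(vocab)]
enc_embed = scaled(enc_embed, scales["encoder_embeddings"])

# A value projection in one encoder attention block: Xavier, then rescale.
Phi_v = scaled(xavier_uniform(D, D, rng), scales["encoder_weights"])

for name, s in scales.items():
    print(f"{name:20s} {s:.4f}")
```

All of the work happens at initialization; the forward pass of the network is unchanged apart from having no layer normalization.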

In practice, *T-Fixup* is able to train significantly deeper transformer models with improved performance on the task of machine translation. For the detailed derivation of this method, we refer readers to the original paper.

## Balancing the residual dependencies

An alternative approach is to balance the residual dependencies, which in turn will limit the output perturbations $\Delta \bf T[\mathbf{X}]$. Equation 6 shows that controlling the magnitude of the output fluctuation $\Delta \bf T[\mathbf{X}]$ also bounds the magnitude of the gradient updates $\Delta \boldsymbol\theta$, which in turn mitigates the problem of gradient vanishing. Here we'll consider three possible approaches.

**Pre-LN transformers:** One simple solution is to change the location of layer normalization inside the transformer layer so that it occurs inside the residual blocks and before the self-attention or MLP (figure 7). This is known as the *pre-LN transformer*. This simple change can help control the gradient magnitude and balance the residual dependencies.

Pre-LN transformer models can be trained without learning rate warm-up. However, they also lead to inferior empirical performance. It has been speculated that this is because the models are now restricted from depending too heavily on the contents of their residual branches (Liu *et* al., 2020).
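The difference between the two orderings can be sketched as follows, with the sublayers passed in as functions. The zero-output sublayers in the usage example are hypothetical stand-ins for the attention and MLP components, chosen to isolate what the skip path does.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one D-dimensional row (gamma = 1, beta = 0 for simplicity)."""
    d = len(x)
    mu = sum(x) / d
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / d + eps)
    return [(v - mu) / sigma for v in x]

def add(X, Y):
    return [[a + b for a, b in zip(x, y)] for x, y in zip(X, Y)]

def ln_rows(X):
    return [layer_norm(x) for x in X]

def post_ln_layer(X, mhsa, mlp):
    """Equation 1: normalize after each residual addition."""
    X = ln_rows(add(X, mhsa(X)))
    X = ln_rows(add(X, [mlp(x) for x in X]))
    return X

def pre_ln_layer(X, mhsa, mlp):
    """Pre-LN: normalize the input of each sublayer, inside the residual branch."""
    X = add(X, mhsa(ln_rows(X)))
    X = add(X, [mlp(x) for x in ln_rows(X)])
    return X

zero_mhsa = lambda X: [[0.0] * len(x) for x in X]  # hypothetical sublayer
zero_mlp = lambda x: [0.0] * len(x)                # hypothetical sublayer
X = [[1.0, 2.0, 3.0, 6.0], [0.0, -4.0, 4.0, 8.0]]
print(pre_ln_layer(X, zero_mhsa, zero_mlp))   # identical to X
print(post_ln_layer(X, zero_mhsa, zero_mlp))  # every row renormalized
```

With zero-output sublayers the pre-LN layer reduces to the identity, because the skip path never passes through a normalization; in the post-LN layer every path, including the skip connection, is renormalized.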

**Admin:** To bridge this performance gap, *adaptive model initialization* or *Admin* aims to bound the output fluctuation $\Delta \bf T[\mathbf{X}]$ by controlling the residual dependencies while retaining the original architecture.

Admin adds a new $1\times D$ parameter vector $\boldsymbol\psi$ to each residual block. The self-attention block is then constructed as $\bf LayerNorm[\mathbf{X} \odot \boldsymbol\Psi + \bf MhSa[\mathbf{X}]]$, where $\odot$ is the element-wise product and $\boldsymbol\Psi$ is an $I\times D$ matrix where each row is a copy of $\boldsymbol\psi$. The residual connection around the MLP layer is treated in the same way (figure 8a).

The new parameters at the $n^{th}$ layer are initialized to be the output standard deviation at that layer before this intervention. This can be estimated by setting all elements of $\boldsymbol\psi$ to one and forward propagating on a few training instances.
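A sketch of one Admin-style sublayer follows, with $\boldsymbol\psi$ as a plain list and a hypothetical branch that always returns the same matrix. It covers only the forward computation, not the profiling-based initialization of $\boldsymbol\psi$ described above.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one D-dimensional row (gamma = 1, beta = 0 for simplicity)."""
    d = len(x)
    mu = sum(x) / d
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / d + eps)
    return [(v - mu) / sigma for v in x]

def admin_sublayer(X, psi, sublayer):
    """LayerNorm[X (*) Psi + F(X)], where every row of Psi is a copy of psi.

    X is an I x D list of rows, psi a length-D list of learnable scales, and
    `sublayer` computes F(X) as another I x D list of rows.
    """
    F = sublayer(X)
    return [layer_norm([x_d * p_d + f_d for x_d, p_d, f_d in zip(x, psi, f)])
            for x, f in zip(X, F)]

branch = lambda X: [[4.0, -1.0, 0.5, 2.0]]  # hypothetical constant branch output
X = [[1.0, 2.0, 3.0, 4.0]]

post_ln = admin_sublayer(X, [1.0] * 4, branch)     # psi = 1: plain post-LN block
skip_heavy = admin_sublayer(X, [1e6] * 4, branch)  # large psi: skip path dominates
print(post_ln)
print(skip_heavy)
print(layer_norm(X[0]))  # skip_heavy is close to normalizing X alone
```

With $\boldsymbol\psi$ set to ones this is exactly the post-LN residual block; larger entries of $\boldsymbol\psi$ shift the output's dependence toward the skip connection, which is how Admin rebalances the two branches.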

**ReZero:** In a similar vein, *ReZero* removes the layer normalization and introduces a single trainable parameter $\alpha$ per residual layer, so that the self-attention residual layer becomes $\mathbf{X} + \alpha\bf MhSa[\mathbf{X}]$, where $\alpha$ is initialized to zero (figure 8b). The result is that the entire network is initialized to compute the identity function, and the contributions of the self-attention and MLP layers are gradually and adaptively introduced.
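A sketch of the ReZero residual update, using a hypothetical sublayer that simply doubles its input:

```python
def rezero_sublayer(X, alpha, sublayer):
    """ReZero residual: X + alpha * F(X), with no layer normalization.

    alpha is a single learnable scalar per residual layer, initialized to 0.
    """
    F = sublayer(X)
    return [[x + alpha * f for x, f in zip(xr, fr)] for xr, fr in zip(X, F)]

double = lambda X: [[2.0 * v for v in row] for row in X]  # hypothetical sublayer
X = [[1.0, -2.0], [0.5, 3.0]]
print(rezero_sublayer(X, 0.0, double))   # identity at initialization
print(rezero_sublayer(X, 0.25, double))  # X + 0.25 * 2X = 1.5 X

# A stack of freshly initialized ReZero layers is still the identity map.
Y = X
for _ in range(12):
    Y = rezero_sublayer(Y, 0.0, double)
print(Y == X)  # prints: True
```

Because every $\alpha$ starts at zero, depth does not change the function computed at initialization; training then grows each sublayer's contribution from zero.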

Empirically, both Admin and ReZero work well for training deeper transformer models with better generalization performance, which demonstrates the effectiveness of balancing the residual dependencies.

## Reducing the optimizer variance

We noted before that the high variance of the adaptive learning rates in the Adam optimizer at the early stages of training exacerbates the problems of amplified output perturbations and gradient vanishing. Liu *et* al., 2019 argue that this is due to the lack of samples in the early stages of learning. They base their argument on an experiment in which they do not change the model parameters or the momentum term of Adam for the first 2000 learning steps, but only adapt the learning rate. With this modification, warm-up is no longer required.

Based on these observations, they propose *Rectified Adam* or *RAdam*, which multiplies the adaptive learning rate by a rectification term that grows over time, in a way that keeps its variance under control. One way to think of this is that we have effectively incorporated learning rate warm-up into the Adam algorithm, but in a principled way.
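The rectification term can be sketched as below, following the formulas in the RAdam paper with Adam's default $\beta_2 = 0.999$ (assumed here). RAdam multiplies Adam's adaptive step $\hat{m}_t/\sqrt{\hat{v}_t}$ by a factor $r_t$ that starts small and approaches one. For the first few steps, when the variance of the adaptive learning rate is intractable ($\rho_t \leq 4$), it falls back to a momentum-only update.

```python
import math

def radam_rectifier(t, beta2=0.999):
    """RAdam's variance-rectification term r_t at step t (t >= 1).

    Returns None while the approximated SMA length rho_t is <= 4; in that
    regime RAdam skips the adaptive learning rate and takes an un-adapted
    momentum step instead.
    """
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t <= 4.0:
        return None
    return math.sqrt(((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
                     / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))

for t in (1, 3, 10, 100, 1000, 10000):
    print(t, radam_rectifier(t))
```

The rising factor plays the role of a built-in warm-up whose length is set by $\beta_2$ rather than by a separate hyper-parameter.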

## How to train deeper transformers on small datasets

In the previous sections, we have seen that great progress has been made towards understanding transformer training. Several solutions have been proposed that allow the training of significantly deeper transformer models with improved empirical performance.

However, they have only been applied to tasks with sufficient training data such as machine translation and language modelling. This is possibly due to the commonly-held belief that training deep transformers from scratch requires large datasets. For small datasets, it is typical just to add shallow and simple additional layers (*e.g.*, a classifier head) to pre-trained models and then fine-tune.

So, what prevents practitioners from training deep transformers on small datasets? It turns out that the final missing piece of the puzzle is the batch size. For small datasets, it's necessary to leverage large pre-trained models and then fine-tune. However, the size of these models limits the batch size, and when the batch size is small, the variance of the updates is even larger, which makes training even harder. Even when a larger batch size is feasible, it usually results in poorer generalization, especially on small datasets.
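The dependence on batch size follows from averaging: if per-example gradients are independent with variance $\sigma^2$, the minibatch mean over $B$ examples has variance $\sigma^2/B$. A quick simulation under that toy assumption (per-example gradients drawn from a normal distribution with mean 1 and variance 1):

```python
import random

def minibatch_gradient_variance(batch_size, trials=2000, seed=0):
    """Empirical variance of a minibatch-averaged gradient.

    Per-example gradients are modeled as i.i.d. N(1, 1) draws (a toy
    assumption), so the true variance of the mean is 1 / batch_size.
    """
    rng = random.Random(seed)
    means = []
    for _ in range(trials):
        batch = [rng.gauss(1.0, 1.0) for _ in range(batch_size)]
        means.append(sum(batch) / batch_size)
    mu = sum(means) / trials
    return sum((m - mu) ** 2 for m in means) / trials

for b in (4, 16, 64, 256):
    print(f"batch size {b:3d}: variance {minibatch_gradient_variance(b):.5f}")
```

Quartering the batch size roughly quadruples the variance of each update.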

In short, small datasets require pre-trained models and small batch sizes to perform well, but these two requirements make training additional transformer layers challenging. To resolve the high variance of training updates in small batch sizes, the three ideas from the previous section can all be applied. However, these approaches all assume that the inputs to the transformers are randomly initialized embeddings, but this is not true if we are adding yet-to-be-trained transformers on top of pre-trained models (figure 9).

DT-Fixup is a data-dependent initialization strategy developed by Borealis AI. It adapts the T-Fixup method for this type of mixed setting. DT-Fixup allows significantly deeper transformers to be trained with small datasets for challenging tasks such as Text-to-SQL semantic parsing and logical reading comprehension. This demonstrates that training deep transformers with small datasets is feasible with the correct optimization procedure.

## Conclusion

In the first two parts of this blog, we introduced the transformer, and discussed extensions and relations to other models. In this final part, we have discussed the complex topic of how to train deep transformer models effectively. We hope that this discussion will help practitioners to train deeper transformers for different applications and help identify potential directions for further improving transformer training.