In the first section, we'll discuss position embeddings. The transformer operates on unordered sets of embeddings, but often we are processing ordered sequences (e.g., words in NLP). We will describe the ways that the architecture has been adapted to take into account the position of each element in the sequence. In the second section, we'll discuss efficiency. The attention computation grows quadratically with the sequence length and in practice this limits the maximum length we can use. We'll describe work that allows the transformer to work efficiently with longer sequences. We will conclude by describing how the self-attention mechanism relates to other models, including RNNs, graph neural networks, capsule networks, Hopfield networks, CNNs, gating networks, and hypernetworks.
In part I, we discussed how the core component of the transformer is dot-product self-attention $\bf Sa[\mathbf{X}]$. In this section, we'll provide a brief review of this mechanism. Self-attention takes a set of vectors $\{\mathbf{x}_{i}\}_{i=1}^{I}$ (which form the $I$ rows of $\mathbf{X}$) and modifies them based on the degree to which they attend to each other:
\begin{equation}
{\bf Sa}[\mathbf{X}] =\bf Softmax\left[\frac{(\mathbf{X}\boldsymbol\Phi_{q})(\mathbf{X}\boldsymbol\Phi_{k})^{T}}{\sqrt{d_{q}}}\right]\mathbf{X}\boldsymbol\Phi_{v}, \tag{1}
\end{equation}
where the function $\bf Softmax[\bullet]$ performs a separate softmax operation on each row of the input. The terms $\boldsymbol\Phi_{q}, \boldsymbol\Phi_{k}$ and $\boldsymbol\Phi_{v}$ are known as the query matrices, key matrices and value matrices respectively, and when applied to the data they form the queries $\mathbf{X}\boldsymbol\Phi_{q}$, keys $\mathbf{X}\boldsymbol\Phi_{k}$, and values $\mathbf{X}\boldsymbol\Phi_{v}$.
In simple terms, for each input $\mathbf{x}_{i}$ the self-attention mechanism returns a weighted sum of the values for every input $\mathbf{x}_{j}$, where the weight depends on the dot-product similarity between the query for $\mathbf{x}_{i}$ and the key for $\mathbf{x}_{j}$. These similarities are normalized by the softmax function so that they are positive and sum to one; after normalization they are referred to as attention weights. The term $\bf Softmax\left[(\mathbf{X}\boldsymbol\Phi_{q})(\mathbf{X}\boldsymbol\Phi_{k})^{T}/\sqrt{d_{q}}\right]$ is of size $I\times I$ and is known as the attention matrix.
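As a concrete reference, equation 1 can be sketched in a few lines of NumPy (the function and variable names here are our own, not from any particular library):

```python
import numpy as np

def softmax_rows(A):
    # Softmax applied independently to each row of A
    e = np.exp(A - A.max(axis=1, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=1, keepdims=True)

def self_attention(X, Phi_q, Phi_k, Phi_v):
    # Dot-product self-attention Sa[X] from equation 1
    Q, K, V = X @ Phi_q, X @ Phi_k, X @ Phi_v      # queries, keys, values
    A = softmax_rows(Q @ K.T / np.sqrt(Phi_q.shape[1]))  # I x I attention matrix
    return A @ V
```

Each row of the attention matrix is non-negative and sums to one, so each output is a convex combination of the value vectors.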
The self-attention mechanism is equivariant to permutations of the input. In other words, if we apply a permutation matrix $\mathbf{P}$ to the rows of the matrix $\mathbf{X}$, the output will also be permuted, but will otherwise stay the same:
\begin{eqnarray}
{\bf Sa}[\mathbf{P}\mathbf{X}] &=&\bf Softmax\left[\frac{(\mathbf{P}\mathbf{X}\boldsymbol\Phi_{q})(\mathbf{P}\mathbf{X}\boldsymbol\Phi_{k})^{T}}{\sqrt{d_{q}}}\right]\mathbf{P}\mathbf{X}\boldsymbol\Phi_{v}\nonumber\\
&=&\mathbf{P}\cdot \bf Softmax\left[\frac{(\mathbf{X}\boldsymbol\Phi_{q})(\mathbf{X}\boldsymbol\Phi_{k})^{T}}{\sqrt{d_{q}}}\right]\mathbf{P}^{T}\mathbf{P}\mathbf{X}\boldsymbol\Phi_{v}\nonumber \\
&=&\mathbf{P}\cdot\bf Softmax\left[\frac{(\mathbf{X}\boldsymbol\Phi_{q})(\mathbf{X}\boldsymbol\Phi_{k})^{T}}{\sqrt{d_{q}}}\right]\mathbf{X}\boldsymbol\Phi_{v}\nonumber \\
&=&\mathbf{P}\cdot {\bf Sa}[\mathbf{X}] . \tag{2}
\end{eqnarray}
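This equivariance is easy to verify numerically; the following sketch (a minimal implementation of our own) checks equation 2 for a random permutation matrix:

```python
import numpy as np

def sa(X, Pq, Pk, Pv):
    # minimal dot-product self-attention (equation 1)
    A = X @ Pq @ (X @ Pk).T / np.sqrt(Pq.shape[1])
    A = np.exp(A - A.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return A @ (X @ Pv)

rng = np.random.default_rng(1)
X = rng.standard_normal((6, 4))
Pq, Pk, Pv = (rng.standard_normal((4, 4)) for _ in range(3))
P = np.eye(6)[rng.permutation(6)]   # random permutation matrix
# Sa[P X] == P Sa[X]  (equation 2)
assert np.allclose(sa(P @ X, Pq, Pk, Pv), P @ sa(X, Pq, Pk, Pv))
```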
This is not desirable when the vectors $\mathbf{x}_{i}$ represent words in a sentence, as the order of the inputs is important; the sentences "The man ate the fish" and "The fish ate the man" have different meanings, and we hope that any neural processing will take this into account.
Before discussing how to encode positional information, it is worth thinking about what properties we would like this encoding to have. First, we need to know the relative position of two words rather than their absolute position. Transformers are trained with spans of text that may contain multiple sentences, and the start of the span may be mid-way through a sentence. Consequently, the absolute position does not contain much useful information.
Second, word embeddings that are far from one another in the sequence might be expected to interact with one another less than those that are closer. For example, when we disambiguate a pronoun (e.g., understanding who he is in a sentence like He ate the sandwich), it's likely that the answer is close at hand, not several thousand words away. Finally, we might expect that we need the relative position with less and less accuracy as the distance between tokens increases. For small distances, the relative word position directly affects the meaning of the sentence, but for larger distances the words are probably in different sentences and the exact distance between them matters much less.
In the original transformer paper, position was encoded by adding a pre-determined matrix $\boldsymbol\Pi$ to the input embedding matrix $\mathbf{X}$ where the position embeddings are pre-defined as:
\begin{eqnarray}
\Pi_{i, 2f} &=& \sin[\omega_f i] \nonumber\\
\Pi_{i, 2f+1} &=& \cos[\omega_f i] \tag{3}
\end{eqnarray}
where $i$ indexes the position in the sequence and $f$ indexes pairs of adjacent embedding dimensions. The angular frequencies $\omega_f$ of adjacent dimensions $d = 2f$ and $d+1 = 2f+1$ are the same and take the value $\omega_f = 10000^{-2f/D}$ (figure 1).
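A sketch of how the matrix $\boldsymbol\Pi$ of equation 3 can be constructed (the function name is our own; $D$ is assumed even):

```python
import numpy as np

def sinusoidal_position_embeddings(I, D):
    # Matrix Pi of equation 3: row i holds the embedding for position i.
    # Assumes D is even, so dimensions pair up as (2f, 2f+1).
    f = np.arange(D // 2)                   # index over dimension pairs
    omega = 10000.0 ** (-2.0 * f / D)       # angular frequencies
    pos = np.arange(I)[:, None]
    Pi = np.empty((I, D))
    Pi[:, 0::2] = np.sin(omega * pos)       # even dimensions: sin
    Pi[:, 1::2] = np.cos(omega * pos)       # odd dimensions:  cos
    return Pi
```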
One way to think about adding the matrix $\boldsymbol\Pi$ is that we are adding a different vector to the embedding $\mathbf{x}_{i}$ where this vector encodes the absolute position $i$. So if the same word occurs at different positions in the sequence, it would have two different embeddings. For this reason, this sinusoidal encoding is considered an absolute position embedding.
This scheme is worth examining closely. In the self-attention mechanism we apply linear transformations $\boldsymbol\Phi_{q}$ and $\boldsymbol\Phi_{k}$ to $\mathbf{X}+\boldsymbol\Pi$ and then compute dot products between every pair of rows in the resulting matrices. We'll now consider several interesting properties that emerge when we apply linear transformations to this sinusoidal embedding and take dot products.
Separating position and word embeddings: At first sight, adding the position embeddings to the data seems like a bad idea; we probably need access to both the word embedding and the position embedding without having them hopelessly entangled. However, this is not necessarily a problem. Since the embedding dimension $D$ is usually greater than the maximum sequence length $I$ (e.g., BERT used $D=1024$, $I=512$), it is possible for the system to learn word embeddings that lie outside the subspace of the position embeddings. If this were the case, the system could recover the word embeddings by learning linear transformations $\boldsymbol\Phi_{q}$ and $\boldsymbol\Phi_{k}$ whose null-space spans the position embeddings. Similarly, it could recover the position embeddings.
Down-weighting distant elements: The dot product between the position encodings $\boldsymbol\pi_{i}$ and $\boldsymbol\pi_{j}$ at different positions $i$ and $j$ (i.e., rows of $\boldsymbol\Pi$) gets smaller as the relative distance $|i-j|$ increases (figure 2). So if the system were to retrieve the position embeddings using a linear transform as described above, it could create an attention matrix that increasingly down-weights attention between elements as they become more distant.
Relative vs. absolute positions: We have added a unique embedding $\boldsymbol\pi_{i}$ at each absolute position $i$. However, it's possible to transform the embedding $\boldsymbol\pi_{i}$ at position $i$ to that at position $i+j$ using a linear operation. To see this, consider the entries $\left(\sin[\omega_{f}i]\;\;\cos[\omega_{f} i]\right)^{T}$ of the embedding at word position $i$ for two adjacent dimensions $d$ and $d+1$. Applying the following linear transformation, we get:
\begin{eqnarray}
\begin{pmatrix}\cos[\omega_{f} j]&\sin[\omega_{f} j]\\-\sin[\omega_{f} j]&\cos[\omega_{f} j]\end{pmatrix}
\begin{pmatrix}\sin[\omega_{f} i]\\\cos[\omega_{f} i]\end{pmatrix} &=&\begin{pmatrix}
\cos[\omega_{f} j]\sin[\omega_{f} i]+ \sin[\omega_{f} j]\cos[\omega_{f} i]\\
-\sin[\omega_{f} j]\sin[\omega_{f} i]+\cos[\omega_{f} j]\cos[\omega_{f} i]\end{pmatrix}\nonumber \\ &=&
\begin{pmatrix}\sin[\omega_{f} (i+j)]\\\cos[\omega_{f} (i+j)]\end{pmatrix} \tag{4}
\end{eqnarray}
where we have used the trigonometric addition identities. So by applying the appropriate linear transformation, the system can transform the position encoding at position $i$ to that at position $i+j$. If it did this for just the queries, then the dot products between position vectors would take a maximum value at a relative offset of $j$ rather than 0.
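This identity can be checked numerically; the sketch below picks an arbitrary frequency and offset and verifies that the rotation matrix of equation 4 shifts the encoding from position $i$ to position $i+j$:

```python
import numpy as np

# Numerical check of equation 4: a 2x2 rotation maps the sinusoidal
# encoding at position i to the encoding at position i + j.
omega, i, j = 0.01, 7, 5
R = np.array([[ np.cos(omega * j), np.sin(omega * j)],
              [-np.sin(omega * j), np.cos(omega * j)]])
p_i       = np.array([np.sin(omega * i),       np.cos(omega * i)])
p_shifted = np.array([np.sin(omega * (i + j)), np.cos(omega * (i + j))])
assert np.allclose(R @ p_i, p_shifted)
```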
Note that all of the above is supposition; the trained network does not necessarily do any of these things. The point is that these capabilities are available to it if it chooses to use them.
We've seen that it's possible to use sinusoidal embeddings for which the linear projections and dot products have useful properties. An obvious next step is to learn the position embedding matrix $\boldsymbol\Pi$ during training. This approach was also tried in the original transformer paper and adopted by subsequent models like BERT and GPT-2.
The advantage of learning the position embeddings is that we can potentially capture more complex properties. The disadvantage is that it adds a lot of extra parameters to the model, and once learned, the model cannot be extended to longer sequence lengths.
It's interesting however, to test if the learned position embeddings capture the desirable properties of the sinusoidal embeddings. Wang and Chen (2020) compared the cosine similarities (closely related to dot products) between embeddings at different relative distances (figure 3). For GPT-2 the similarity of the embeddings decreases as a function of distance for small distances with a periodic component at larger distances. For BERT, the results are more noisy and complicated.
They also examined if it is possible to predict the absolute positions by applying linear regression to the learned embedding. For the BERT embeddings, the error in these predictions is large, for the GPT-2 embeddings very small, and for the sinusoidal embeddings zero. The same experiment can be done by regressing pairs of position embeddings to predict relative position. Here, the error is again greatest for the BERT embeddings, but this time, the GPT-2 embeddings outperform the pre-defined sinusoidal embeddings.
Adding position embeddings modifies the self-attention calculation to:
\begin{equation}
\bf Sa [\mathbf{X}] = \bf Softmax\left[\frac{((\mathbf{X}+\boldsymbol\Pi)\boldsymbol\Phi_{q})((\mathbf{X}+\boldsymbol\Pi)\boldsymbol\Phi_{k})^{T}}{\sqrt{d_{q}}}\right](\mathbf{X}+\boldsymbol\Pi)\boldsymbol\Phi_{v}. \tag{5}
\end{equation}
The position matrix modifies both the attention matrix (the softmax term) and the computation of the values. There have been a number of studies in which the latter modification is dropped so that just the attention matrix is changed:
\begin{equation}
\bf Sa [\mathbf{X}] = \bf Softmax\left[\frac{((\mathbf{X}+\boldsymbol\Pi)\boldsymbol\Phi_{q})((\mathbf{X}+\boldsymbol\Pi)\boldsymbol\Phi_{k})^{T}}{\sqrt{d_{q}}}\right]\mathbf{X}\boldsymbol\Phi_{v}. \tag{6}
\end{equation}
In these circumstances, the position information is usually added at every layer as it is only represented very implicitly in the output of the computation.
Let's consider the un-normalized and pre-softmax attention matrix:
\begin{equation}
\tilde{\mathbf{A}} = ((\mathbf{X}+\boldsymbol\Pi)\boldsymbol\Phi_{q})((\mathbf{X}+\boldsymbol\Pi)\boldsymbol\Phi_{k})^{T}, \tag{7}
\end{equation}
which has elements:
\begin{eqnarray}
\tilde{a}_{i,j} &=& ((\mathbf{x}_{i}+\boldsymbol\pi_{i})\boldsymbol\Phi_{q})((\mathbf{x}_{j}+\boldsymbol\pi_{j})\boldsymbol\Phi_{k})^{T}\nonumber \\
&=& \underbrace{\mathbf{x}_{i}\boldsymbol\Phi_q\boldsymbol\Phi_{k}^{T}\mathbf{x}_{j}^{T}}_\text{content-content}+\underbrace{\mathbf{x}_{i}\boldsymbol\Phi_q\boldsymbol\Phi_{k}^{T}\boldsymbol\pi_{j}^{T}}_{\text{content-position}}+\underbrace{\boldsymbol\pi_{i}\boldsymbol\Phi_q\boldsymbol\Phi_{k}^{T}\mathbf{x}_{j}^{T}}_{\text{position-content}}+\underbrace{\boldsymbol\pi_{i}\boldsymbol\Phi_q\boldsymbol\Phi_{k}^{T}\boldsymbol\pi_{j}^{T}}_{\text{position-position}},\label{eq:attention_breakdown} \tag{8}
\end{eqnarray}
where we can see that each element has four contributions in which the position embedding $\boldsymbol\pi$ and the content vector $\mathbf{x}$ interact differently. This expression has been modified in various ways, as we describe next.
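The four-way decomposition of equation 8 can be verified numerically with a short sketch (random matrices, our own notation):

```python
import numpy as np

rng = np.random.default_rng(2)
I, D, d = 5, 8, 4
X  = rng.standard_normal((I, D))    # content embeddings
Pi = rng.standard_normal((I, D))    # position embeddings
Pq, Pk = rng.standard_normal((D, d)), rng.standard_normal((D, d))

# Left-hand side of equation 8: attend with X + Pi
lhs = ((X + Pi) @ Pq) @ ((X + Pi) @ Pk).T

# Right-hand side: content-content + content-position
#                  + position-content + position-position
M = Pq @ Pk.T
rhs = X @ M @ X.T + X @ M @ Pi.T + Pi @ M @ X.T + Pi @ M @ Pi.T
assert np.allclose(lhs, rhs)
```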
Untied embeddings: One simple modification is to decouple or untie the content and position components rather than add them together before projection. A simple way to do this is to remove the terms where they interact and to use a separate linear transform for each to give:
\begin{equation}
\tilde{a}_{i,j} = \underbrace{\mathbf{x}_{i}\boldsymbol\Phi_q\boldsymbol\Phi_{k}^{T}\mathbf{x}_{j}^{T}}_\text{content-content}+\underbrace{\boldsymbol\pi_{i}\boldsymbol\Psi_q\boldsymbol\Psi_{k}^{T}\boldsymbol\pi_{j}^{T}}_{\text{position-position}}. \tag{9}
\end{equation}
Relative embeddings: Another modification is to directly inject information about the relative position. For example, Shaw et al. (2018) add a term $\boldsymbol\pi_{i-j}$ which depends on the position difference:
\begin{equation}\label{eq:rel_pos_shaw}
\tilde{a}_{i,j} = \underbrace{\mathbf{x}_{i}\boldsymbol\Phi_q\boldsymbol\Phi_{k}^{T}\mathbf{x}_{j}^{T}}_\text{content-content}+\underbrace{\mathbf{x}_{i}\boldsymbol\Phi_q\boldsymbol\pi_{i-j}^{T}}_{\text{content-position}}, \tag{10}
\end{equation}
where a different position vector $\boldsymbol\pi_{i-j}$ is learned for each signed position offset $i-j$; this offset is usually clipped so that beyond a certain distance all terms are the same. Note that this position vector is defined directly in the space of the keys rather than projected into it.
Raffel et al. (2019) simplified this further by simply adding a learnable scalar $\pi_{i-j}$ to the attention matrix:
\begin{equation}
\tilde{a}_{i,j} = \underbrace{\left(\mathbf{x}_{i}\boldsymbol\Phi_q\boldsymbol\Phi_{k}^{T}\mathbf{x}_{j}^{T}\right)}_\text{content-content} + \pi_{i-j}, \tag{11}
\end{equation}
where $\pi_{i-j}$ is a different scalar for each signed offset $i-j$. Relative position information has also been incorporated in various other ways, such as simply multiplying the attention values by a modifying factor $\pi_{|i-j|}$:
\begin{equation}
\tilde{a}_{i,j} = \underbrace{\left(\mathbf{x}_{i}\boldsymbol\Phi_q\boldsymbol\Phi_{k}^{T}\mathbf{x}_{j}^{T}\right)}_\text{content-content}\cdot \pi_{|i-j|}, \tag{12}
\end{equation}
where $\pi_{|i-j|}$ is a different scalar for each absolute offset $|i-j|$.
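Whether added (equation 11) or multiplied (equation 12), such schemes need a table of scalars indexed by offset. A minimal sketch of turning that table into an $I\times I$ bias matrix (the indexing convention and function name are our own):

```python
import numpy as np

def relative_bias_matrix(b, I):
    # b holds one learnable scalar per signed offset i - j, which ranges
    # over -(I-1) ... I-1; entry [i, j] of the result is b[(i - j) + I - 1]
    offsets = np.arange(I)[:, None] - np.arange(I)[None, :]
    return b[offsets + I - 1]
```

The resulting matrix would be added to the content-content logits before the softmax, as in equation 11.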
Finally, we note that pre-defined sinusoidal embeddings have also been used in a system based on equation 10 (where $\boldsymbol\pi_{i-j}$ now contains sinusoidal terms in the relative position $i-j$) and also in more complex ways.
Combining ideas: Many schemes have proposed position embeddings that combine the ideas of (i) only retaining certain terms from equation 8, (ii) using different projection matrices for the content and position embeddings, and (iii) using relative embeddings. For example, in DeBERTa they use:
\begin{equation}
\tilde{a}_{i,j} =
\underbrace{\mathbf{x}_{i}\boldsymbol\Phi_q\boldsymbol\Phi_{k}^{T}\mathbf{x}_{j}^{T}}_\text{content-content}+
\underbrace{\mathbf{x}_{i}\boldsymbol\Phi_q\boldsymbol\Psi_{k}^{T}\boldsymbol\pi_{i-j}^{T}}_{\text{content-position}}+
\underbrace{\boldsymbol\pi_{j-i}\boldsymbol\Psi_q\boldsymbol\Phi_{k}^{T}\mathbf{x}_{j}^{T}}_{\text{position-content}}. \tag{13}
\end{equation}
where they drop the position-position term and have a different relative embedding $\boldsymbol\pi_{i-j}$ for each signed offset $i-j$ between the positions.
In this section we have provided a brief overview of how position information is added into transformers. At the time of writing, it is not clear which of these position embedding schemes is empirically superior. For downstream tasks on BERT, relative position embeddings generally perform better than absolute position embeddings, but there does not seem to be much difference between sinusoidal embeddings and learned embeddings. To learn more about position embeddings, consult this survey paper.
In the second part of this blog, we consider modifications to the self-attention mechanism that make it more efficient as the sequence length increases. The self-attention mechanism takes $I$ inputs $\mathbf{x}_{i}$ and returns $I$ modified outputs. In this process, every input interacts with every other; each output is a weighted sum of the values corresponding to all of the inputs, where the weights depend on how much the corresponding input attends to each of the others. As such, the transformer naturally has quadratic complexity in the length $I$ of the input sequence.
However, there are some situations in which we might expect this input set to be extremely large. In NLP, we may wish to summarize long documents or answer questions about a body of documents. In other modalities like vision or audio processing, the data can also be of extremely high dimension. In these circumstances, the quadratic complexity of the attention mechanism can become the limiting factor and a sub-field has emerged that tries to address this bottleneck.
In this section, we review three lines of work. First, we discuss methods that aim to reduce the size of the attention matrix. Second, we review approaches that introduce sparsity into the attention matrix. Finally, we present methods that treat the self-attention computation as a kernel function and try to approximate this to create algorithms with linear complexity in the sequence length.
One simple idea to make self-attention more efficient is to reduce the size of the attention matrix. In memory compressed attention, a strided convolution is applied to the keys and values so the self-attention operation becomes:
\begin{equation}
\bf Sa[\mathbf{X}] = \bf Softmax\left[\mathbf{X}\boldsymbol\Phi_{q}(\boldsymbol\theta_{k}\circledast\mathbf{X}\boldsymbol\Phi_{k})^{T} \right](\boldsymbol\theta_{v}\circledast\mathbf{X}\boldsymbol\Phi_{v}), \tag{14}
\end{equation}
where $\boldsymbol\theta_{k}$ and $\boldsymbol\theta_{v}$ are the convolution kernels. If the stride $s$ is the same as the kernel size, then the effect is to take a learned weighted average of nearby key/value vectors and the resulting attention matrix reduces to size $I\times I/s$ (figure 5).
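A rough sketch of this idea, replacing the learned strided convolution with a simple non-overlapping average (stride equal to kernel size) so the shapes are easy to follow; all names here are our own:

```python
import numpy as np

def compress(Y, s):
    # Average non-overlapping groups of s rows. In memory compressed
    # attention these weights are learned; a plain mean stands in here.
    I, d = Y.shape
    return Y[: I - I % s].reshape(-1, s, d).mean(axis=1)

def softmax_rows(A):
    e = np.exp(A - A.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def compressed_attention(X, Pq, Pk, Pv, s=4):
    Q, K, V = X @ Pq, X @ Pk, X @ Pv
    Kc, Vc = compress(K, s), compress(V, s)   # roughly I/s rows each
    A = softmax_rows(Q @ Kc.T)                # I x I/s attention matrix
    return A @ Vc
```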
The Linformer applies a very similar trick, motivated by the observation that the self-attention matrix is often low-rank in practice. Consequently, we can reduce the complexity of the calculation by projecting the keys and values into a learned subspace:
\begin{equation}
\bf Sa[\mathbf{X}] = \bf Softmax\left[\mathbf{X}\boldsymbol\Phi_{q}(\boldsymbol\Psi_{k}\mathbf{X}\boldsymbol\Phi_{k})^{T} \right](\boldsymbol\Psi_{v}\mathbf{X}\boldsymbol\Phi_{v}), \tag{15}
\end{equation}
where $\boldsymbol\Psi_{k}$ and $\boldsymbol\Psi_{v}$ are the $I/s\times I$ projection matrices for the keys and values respectively.
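Equation 15 can be sketched as follows (the names are our own, and random matrices stand in for the learned projections):

```python
import numpy as np

def softmax_rows(A):
    e = np.exp(A - A.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def linformer_attention(X, Pq, Pk, Pv, Psi_k, Psi_v):
    # Psi_k and Psi_v are the (I/s x I) learned projections of equation 15
    Q = X @ Pq
    K = Psi_k @ (X @ Pk)          # compressed keys:   I/s x d
    V = Psi_v @ (X @ Pv)          # compressed values: I/s x d
    A = softmax_rows(Q @ K.T)     # I x I/s attention matrix
    return A @ V
```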
Another approach to making attention more computationally efficient is to constrain the attention computation so that every input does not attend to every other input. In local attention the inputs are divided into disjoint groups of neighbours and each block is passed through a separate self-attention mechanism before recombining (figure 6). In this way, inputs within the same block only attend to one another. Of course, this has the disadvantage that elements far from each other in the sequence never interact, but alternating transformer layers that use local and full attention solves this problem.
Local attention can be visualized by plotting a matrix showing interaction of the queries and keys (figure 6). Note that for the decoder version, we also employ masked self-attention so each query can only attend to keys that have the same index or less and there are no interactions in the upper triangular portion.
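These interaction patterns can be expressed as boolean masks over the attention matrix; a minimal sketch for the encoder (block-diagonal) and decoder (block-diagonal and causal) cases, with our own function names:

```python
import numpy as np

def local_attention_mask(I, block):
    # True where query i may attend to key j: same block only (encoder)
    b = np.arange(I) // block
    return b[:, None] == b[None, :]

def local_causal_mask(I, block):
    # Decoder version: additionally forbid attending ahead (j > i), so
    # the upper triangular portion is masked out
    causal = np.arange(I)[:, None] >= np.arange(I)[None, :]
    return local_attention_mask(I, block) & causal
```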
Visualizing attention in this way leads naturally to the idea of using a convolutional structure (figure 7), in which each input only interacts with the nearest few inputs (or nearest preceding inputs for decoders). When used alone, this means that it may take many layers for information to propagate along the sequence. Again, this drawback can be remedied by alternating layers with convolutional attention patterns and layers with full attention. Indeed, this is what is done in GPT-3. A different approach that maintains the overall sparsity is to use dilated convolutions with different dilation rates in different layers (figure 7b-c), or to introduce layers where a few of the queries interact with every key (figure 7d). Collectively, these methods are referred to as sparse transformers.
The Longformer also used a convolutional structure, which is sometimes dilated, but simultaneously allowed some keys and queries to interact with all of the others (figure 9a). This was referred to as global attention; the global positions correspond to special tokens such as the $<$cls$>$ token in BERT, or the special tokens in question answering tasks that delimit the question and answer. Note that global attention can only be used in encoder models since these elements attend to every other element and hence see ahead in the sequence.
A natural extension of this method is to define some new content embeddings which attend to all of the keys and queries, but do not themselves correspond to any individual tokens in the input (figure 9). This is known as the extended transformer construction (ETC). These additional global content embeddings act as a kind of memory, which can both receive and broadcast information from all of the elements, and are combined with a sparse convolutional pattern which ensures strong interactions between nearby inputs. The BigBird model took this idea one step further by also adding sparse random connections between elements to help ensure the rapid mixing of information from different parts of the sequence.
One notable complication of using global content embeddings arises when they are combined with relative attention; there is no relative offset between the global and regular elements, so special relative position embeddings are learned for mapping to, from, and between the global content embeddings.
In this section we have reviewed approaches that make self-attention more efficient, by limiting the interaction between different inputs. Note that all of these methods use pre-defined sparsity patterns. There is also another line of research that attempts to learn the sparsity pattern. This includes the routing transformer, reformer and Sinkhorn transformer.
A third approach to making self-attention more efficient is to approximate the attention computation using kernel methods. The premise is that the dot-product attention for the $i^{th}$ query can be thought of as a special case of the following computation:
\begin{equation}
\mathbf{x}_{i}^{\prime} = \frac{\sum_{j=1}^{I}\mbox{sim}[\mathbf{x}_{i}\boldsymbol\Phi_{q}, \mathbf{x}_{j}\boldsymbol\Phi_{k}]\mathbf{x}_{j}\boldsymbol\Phi_{v}}{\sum_{j=1}^{I}\mbox{sim}[\mathbf{x}_{i}\boldsymbol\Phi_{q}, \mathbf{x}_{j}\boldsymbol\Phi_{k}]} \tag{16}
\end{equation}
where $\mbox{sim}[\bullet,\bullet]$ returns a measure of similarity between the two arguments. For dot-product self-attention, this is defined as $\mbox{sim}[\mathbf{x}_{i}\boldsymbol\Phi_{q}, \mathbf{x}_{j}\boldsymbol\Phi_{k}] = \exp[\mathbf{x}_{i}\boldsymbol\Phi_{q}(\mathbf{x}_{j}\boldsymbol\Phi_{k})^{T}]$.
We now treat this similarity as a kernel function, and as such it can be expressed as the dot product of non-linear transformations $\bf z[\bullet]$ of the inputs
\begin{equation}
\mbox{sim}[\mathbf{x}_{i}\boldsymbol\Phi_{q}, \mathbf{x}_{j}\boldsymbol\Phi_{k}] = \bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\bf z[\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T}, \tag{17}
\end{equation}
which means that the output becomes:
\begin{eqnarray}
\mathbf{x}_{i}^{\prime} &=& \frac{\sum_{j=1}^{I}\bf z [\mathbf{x}_{i}\boldsymbol\Phi_{q}]\bf z [\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T}\mathbf{x}_{j}\boldsymbol\Phi_{v}}{\sum_{j=1}^{I}\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\bf z[\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T}}\nonumber \\
&=&\frac{\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\sum_{j=1}^{I}\bf z[\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T}\mathbf{x}_{j}\boldsymbol\Phi_{v}}{\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\sum_{j=1}^{I}\bf z[\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T}}, \tag{18}
\end{eqnarray}
where we have used the associativity property of matrix multiplication between the first and second lines.
If we could find $\bf z[\bullet]$ such that $\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\bf z[\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T} = \exp[\mathbf{x}_{i}\boldsymbol\Phi_{q}(\mathbf{x}_{j}\boldsymbol\Phi_{k})^{T}]$, then this computation would be much more efficient: we compute the sums once and then compute each output $\mathbf{x}_{i}^{\prime}$ separately with a matrix multiplication. It turns out that such a non-linear transform $\bf z[\bullet]$ does indeed exist, but unfortunately it maps its argument to an infinite-dimensional space. From a computational viewpoint, this is not very helpful!
We'll describe two approaches that sidestep this problem. First, the linear transformer implicitly uses a different measure of similarity $\bf sim[\mathbf{a},\mathbf{b}] = \bf z[\mathbf{a}]\bf z[\mathbf{b}]^{T}$ by defining a function $\bf z[\bullet]$ which is more tractable. In particular, they use $\bf z[\mathbf{a}] = \bf elu[\mathbf{a}]+1$ where $\bf elu[\bullet]$ is the exponential linear unit which is a pointwise non-linearity. Second, the performer attempts to approximate the standard dot-product similarity using a finite dimensional mapping $\bf z[\bullet]$. The latter approach is empirically more successful, but this may be because the tricks for training transformers (see part III of this blog) do not transfer effectively to using a different similarity measure.
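A sketch of the first approach: equation 18 with the $\bf elu[\bullet]+1$ feature map, in a minimal non-causal form (function names are our own). The test of equivalence against the explicit quadratic computation is the point of the exercise:

```python
import numpy as np

def z(a):
    # elu(a) + 1: the positive-valued feature map used by the linear
    # transformer (elu(a) = a for a > 0, exp(a) - 1 otherwise)
    return np.where(a > 0, a + 1.0, np.exp(a))

def linear_attention(Q, K, V):
    # Equation 18: the sums over keys/values are computed once, so the
    # cost is linear rather than quadratic in the sequence length I
    zQ, zK = z(Q), z(K)
    S = zK.T @ V                    # d x d_v summary of keys and values
    n = zK.sum(axis=0)              # d-vector for the normalizer
    return (zQ @ S) / (zQ @ n)[:, None]
```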
These approaches can be also adapted to decoders. Here, when we calculate the output corresponding to input $\mathbf{x}_{i}$ we only use the partial sums up to index $i$:
\begin{eqnarray}
\mathbf{x}_{i}^{\prime} &=& \frac{\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\sum_{j=1}^{i}\bf z[\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T}\mathbf{x}_{j}\boldsymbol\Phi_{v}}{\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\sum_{j=1}^{i}\bf z[\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T}} \nonumber \\
&=& \frac{\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\mathbf{A}_{i}}{\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\mathbf{b}_i}, \tag{19}
\end{eqnarray}
where $\mathbf{A}_{i}$ and $\mathbf{b}_{i}$ represent the partial sums in the numerator and denominator respectively. If we initialize $\mathbf{A}_{0}$ and $\mathbf{b}_{0}$ to zero, then we can compute all the terms efficiently by iterating:
\begin{eqnarray}\label{eq:transformer_rnn}
\mathbf{A}_{i}&\leftarrow&\mathbf{A}_{i-1}+ \bf z[\mathbf{x}_{i}\boldsymbol\Phi_{k}]^{T}\mathbf{x}_{i}\boldsymbol\Phi_{v}\nonumber \\
\mathbf{b}_{i}&\leftarrow&\mathbf{b}_{i-1}+ \bf z[\mathbf{x}_{i}\boldsymbol\Phi_{k}]^{T}\nonumber \\
\mathbf{x}_{i}^{\prime}&\leftarrow& \frac{\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\mathbf{A}_{i}}{\bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\mathbf{b}_i}. \tag{20}
\end{eqnarray}
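The recurrence of equation 20 can be sketched directly, here using the $\bf elu[\bullet]+1$ feature map as an example choice of $\bf z[\bullet]$ (names are our own):

```python
import numpy as np

def z(a):
    # elu(a) + 1 feature map (one possible choice of z)
    return np.where(a > 0, a + 1.0, np.exp(a))

def causal_linear_attention(Q, K, V):
    # Equation 20: the running sums A and b play the role of an RNN state
    d_k, d_v = K.shape[1], V.shape[1]
    A = np.zeros((d_k, d_v))
    b = np.zeros(d_k)
    out = np.empty((Q.shape[0], d_v))
    for i in range(Q.shape[0]):
        zk = z(K[i])                 # feature-mapped key for position i
        A += np.outer(zk, V[i])      # A_i <- A_{i-1} + z[x_i Phi_k]^T x_i Phi_v
        b += zk                      # b_i <- b_{i-1} + z[x_i Phi_k]^T
        zq = z(Q[i])
        out[i] = (zq @ A) / (zq @ b)
    return out
```

Each position is processed with constant work, so the whole sequence costs linear time and constant memory in $I$.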
In conclusion, if we consider the interaction between the queries and keys to be a kernel function, we can replace this by the dot product of non-linear functions of the key and query. This leads naturally to a very efficient implementation for both encoder and decoder architectures.
In this section, we have reviewed three families of modifications that allow the self-attention mechanism to be extended to longer sequences without a quadratic increase in computation. To learn more about this area, consult this review paper.
In the previous sections, we have addressed the questions of how to encode position, and how to extend the transformer to longer sequence lengths. In this section, we shift gears and consider the relationship between the self-attention mechanism and other models. We'll also consider alternatives to the self-attention mechanism.
The first connection that we will draw is between the self-attention decoder and recurrent neural networks (RNNs). In the final part of the previous section, we re-interpreted the dot-product self-attention mechanism as a kernel function $\mbox{k}[\bullet, \bullet]$:
\begin{equation}
\mathbf{x}_{i}^{\prime} = \frac{\sum_{j=1}^{i}\mbox{k}[\mathbf{x}_{i}\boldsymbol\Phi_{q}, \mathbf{x}_{j}\boldsymbol\Phi_{k}]\mathbf{x}_{j}\boldsymbol\Phi_{v}}{\sum_{j=1}^{i}\mbox{k}[\mathbf{x}_{i}\boldsymbol\Phi_{q}, \mathbf{x}_{j}\boldsymbol\Phi_{k}]} = \frac{\sum_{j=1}^{i} \bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\bf z[\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T}\mathbf{x}_{j}\boldsymbol\Phi_{v}}{\sum_{j=1}^{i} \bf z[\mathbf{x}_{i}\boldsymbol\Phi_{q}]\bf z[\mathbf{x}_{j}\boldsymbol\Phi_{k}]^{T}}. \tag{21}
\end{equation}
This means that the kernel function can be replaced by the dot product of non-linear functions $\bf z[\bullet]$ of the queries and keys and this led to the iterative computation in equation 20.
Viewed in this light, the decoder has an obvious mapping to an RNN. Each state is processed sequentially and the quantities $\mathbf{A}_{i}$ and $\mathbf{b}_{i}$ from equation 20 form the hidden state (figure 10). However, it turns out that to exactly replicate dot-product self-attention requires the function $\bf z[\bullet]$ to map its arguments to an infinite dimensional space. Hence, it is perhaps unsurprising that the transformer architecture out-performs the RNN in practice.
A hypernetwork is a network that is used to predict the parameters of a second network that then performs the main task at hand. In part I of this tutorial, we already saw that the attention matrix can be interpreted as forming the weights of a network that maps the values to the outputs (figure 11). These weights are (i) non-negative, (ii) sparse (there is no interaction between the different dimensions of the values) and (iii) shared (the same weight is used for every dimension of the interaction between the $i^{th}$ value and the $j^{th}$ output). As such, they form a hypernetwork with a particular structure.
Viewed from this perspective, we might consider other mechanisms than dot-product self attention to create these weights (figure 12). The synthesizer uses a multi-layer perceptron $\bf MLP[\bullet]$ to create each row of the $I\times I$ matrix from input $\mathbf{x}_{i}$. This row is then passed through the softmax function to create the attention weights:
\begin{eqnarray}
\mbox{Synthesizer}\left[\mathbf{X} \right] &=&\bf Softmax\left[\bf MLP[\mathbf{X}]\right] \mathbf{X}\boldsymbol\Phi_{v}\nonumber \\
&=&\bf Softmax\left[\bf Relu[\mathbf{X}\boldsymbol\Phi_{1}]\boldsymbol\Phi_{2}\right] \mathbf{X}\boldsymbol\Phi_{v}. \nonumber
\end{eqnarray}
This is interesting since the rows of the attention matrix are no longer computed based on similarities between pairs of tokens, but just from each individual token alone. Surprisingly, it seems to work comparably well to the original dot-product self-attention mechanism.
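A minimal sketch of the synthesizer computation (the MLP sizes and all names here are our own choices); note that the second projection must output $I$ columns, one per row of the attention matrix:

```python
import numpy as np

def softmax_rows(A):
    e = np.exp(A - A.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def synthesizer(X, P1, P2, Pv):
    # Each row of the attention matrix comes from an MLP applied to a
    # single token; no query-key dot products are computed.
    logits = np.maximum(X @ P1, 0.0) @ P2     # ReLU MLP; output has I columns
    return softmax_rows(logits) @ (X @ Pv)
```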
A similar idea can be used to generate an attention matrix with convolutional structure. This belongs to the family of dynamic convolutions in which the convolution weights are themselves determined by the data. Part of the network block in the paper Pay less attention uses this approach. One advantage of this scheme is that there is no need for a position encoding; the convolution weights are determined by all of the inputs, and if we permute them, the result will be different.
Finally, it should be noted that linear transformers are also closely related to fast weight memory systems which are intellectual forerunners of hypernetworks.
A different way to think about self-attention is as a routing network. The attention matrix distributes (routes) each of the $I$ computed value vectors to the $I$ outputs. From this viewpoint, there is a connection between self-attention and capsule networks. Roughly speaking, a capsule network is intended to capture hierarchical relations in images, so lower network levels might detect facial parts (noses, mouths), which are then combined (routed) in higher level capsules that represent a face. One major difference is that capsule networks use routing by agreement. In self-attention, the elements $\mathbf{x}_{i}$ compete with each other for how much they contribute to output $j$ (via the softmax operation). In capsule networks, the higher levels of the network compete with each other for inputs from the lower levels.
Once we consider self-attention as a routing network, we can ask the question of whether it is necessary to make this routing dynamic (i.e., dependent on the data). Another variant of the synthesizer removed the dependence of the attention matrix on the inputs entirely and either used pre-determined random values or learned values (figure 13a). This performed surprisingly well across a variety of tasks.
Graph convolutional networks consider each input vector $\mathbf{x}_{i}$ to be associated with a node on a known graph, and process these nodes through a series of layers in which each node interacts with its neighbours. As such they have a close relationship to self-attention; they can be viewed as routing networks, but here the routing is determined by the adjacency matrix of the graph (figure 13b) and not the data.
Graph attention networks (figure 13c) combine both mechanisms; the routing depends both on the data (although using additive attention, not dot-product attention) and the graph structure (which is used to mask the attention matrix in a similar way to in masked self-attention in decoders).
Returning to the original self-attention mechanism, it is now clear that it can be viewed as a graph neural network on the complete graph, where the query tokens are the destination nodes and the key and value tokens are the source nodes.
Linear convolutions of the neighboring inputs in the sequence can be considered a special case of multi-head dot-product self attention with relative position embeddings. For example, consider using additive position embeddings so that the overall self-attention mechanism is given by:
\begin{equation}
{\bf Sa}[\mathbf{X}] =\bf Softmax\left[(\mathbf{X}\boldsymbol\Phi_{q})(\mathbf{X}\boldsymbol\Phi_{k})^{T}+\boldsymbol\Pi\right]\mathbf{X}\boldsymbol\Phi_{v}, \tag{22}
\end{equation}
where the matrix $\boldsymbol\Pi$ contains a different learned value $\pi_{i-j}$ for each offset $i-j$. Now consider setting $\boldsymbol\Phi_{q}=\boldsymbol\Phi_k = \mathbf{0}$ and $\boldsymbol\Phi_{v}=\mathbf{I}$ to yield:
\begin{equation}
{\bf Sa}[\mathbf{X}] =\bf Softmax\left[\boldsymbol\Pi\right]\mathbf{X}.\nonumber
\end{equation}
If we now choose the relative position contributions $\pi_{i-j}$ to be very large for one offset $i-j$ and small for all of the others, the overall effect will be to create an attention matrix with zeros everywhere except on a single diagonal, offset by $i-j$ from the main diagonal, where the values will be one. When applied to the data $\mathbf{X}$, this has the effect of shifting the rows of the value matrix by this offset. In a multi-head attention context, each head could learn a different offset. When the outputs of these heads are recombined using:
\begin{equation}
{\bf MhSa}[\mathbf{X}] = \left[{\bf Sa}_{1}[\mathbf{X}]\;{\bf Sa}_{2}[\mathbf{X}]\;\ldots\;{\bf Sa}_{H}[\mathbf{X}] \right]\boldsymbol\Phi_{c}, \tag{23}
\end{equation}
it is possible to choose $\boldsymbol\Phi_{c}$ so that all of the outputs from the $h^{th}$ self-attention mechanism receive the same weight, and so we have effectively performed a convolution on the rows of $\mathbf{X}$.
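This argument can be checked numerically. In the NumPy sketch below (sizes and the logit scale are illustrative), each head's attention matrix is built purely from a relative-position term with one large diagonal, and recombining three heads with fixed weights reproduces a 1D convolution with kernel [0.25, 0.5, 0.25] away from the sequence boundaries:

```python
import numpy as np

def softmax_rows(Z):
    E = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

I = 6
X = np.arange(I, dtype=float)[:, None]    # a 1D signal treated as I "tokens"

def shift_head(offset, scale=50.0):
    # Position-only attention: Phi_q = Phi_k = 0 and Pi has one large
    # diagonal, so Softmax[Pi] approximates a shift matrix. Boundary rows
    # where the diagonal is empty fall back to uniform attention.
    idx = np.arange(I)
    Pi = scale * (idx[None, :] - idx[:, None] == offset)
    return softmax_rows(Pi) @ X           # Sa[X] = Softmax[Pi] X with Phi_v = I

# Recombine three heads with fixed weights: the convolution kernel.
kernel = {-1: 0.25, 0: 0.5, 1: 0.25}
out = sum(w * shift_head(o) for o, w in kernel.items())
```

For an interior position such as row 2, the output is 0.25·1 + 0.5·2 + 0.25·3 = 2.0, exactly the convolution of the kernel with the input.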
To summarize, it is possible for a multi-head self attention with relative position embeddings to simulate convolution. This is particularly interesting when the transformer is applied to vision problems where convolutional networks are the standard. Indeed, there is some evidence that this is exactly what transformers are doing in vision tasks.
A notable characteristic of the self attention mechanism and related models is that the processing divides into two paths, one of which is later used to modify the other. In attention, this modification takes the form of pre-multiplication by the attention matrix. However, there is another family of models which use one path to just modulate the magnitude of the other.
The gated linear unit (figure 14a) is an example of such a gating mechanism. The input $\mathbf{X}$ has a linear transformation $\boldsymbol\Phi_{u1}$ applied to it, and the result is passed through a pointwise sigmoid function $\bf Sig[\bullet]$. This maps the results to between zero and one so that they can be used to modulate the magnitude of the data $\mathbf{X}\boldsymbol\Phi_{u2}$ flowing down the other path, which has been subjected to a different linear transformation. The whole function is hence:
\begin{equation}
\bf GLU[\mathbf{X}] = \bf Sig[\mathbf{X}\boldsymbol\Phi_{u1}]\odot \mathbf{X}\boldsymbol\Phi_{u2}. \tag{24}
\end{equation}
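A minimal NumPy sketch of equation 24 (all sizes illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def glu(X, Phi_u1, Phi_u2):
    # Eq. 24: one linear path is squashed to (0,1) and gates the other
    # path pointwise via the elementwise (Hadamard) product.
    return sigmoid(X @ Phi_u1) * (X @ Phi_u2)

rng = np.random.default_rng(1)
X = rng.standard_normal((4, 8))           # I=4 tokens, d=8 dimensions
out = glu(X, rng.standard_normal((8, 8)), rng.standard_normal((8, 8)))
```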
Although the architecture is superficially similar, this is not really equivalent to a transformer, as each input $\mathbf{x}_{i}$ (row of $\mathbf{X}$) is treated independently. The gated MLP addresses this by modifying the architecture to incorporate a learned linear transformation $\boldsymbol\Psi$ that combines together the different inputs:
\begin{equation}
\bf GMLP[\mathbf{X}] = (\bf Sig[\mathbf{X}\boldsymbol\Phi_{u1}]\odot \boldsymbol\Psi\mathbf{X}\boldsymbol\Phi_{u2})\boldsymbol\Phi_{v}, \tag{25}
\end{equation}
as well as a final linear transform $\boldsymbol\Phi_{v}$ that remaps to the original dimensionality. This model again has the advantage that it does not need a position encoding; the inputs are mixed using $\boldsymbol\Psi$ and if we permute their order, the output will not just be a permutation of the input.
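A minimal NumPy sketch of equation 25 (sizes illustrative); note that $\boldsymbol\Psi$ multiplies from the left, so it mixes rows (tokens) rather than feature dimensions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gmlp(X, Phi_u1, Phi_u2, Psi, Phi_v):
    # Eq. 25: Psi is I x I and mixes information across tokens, so
    # permuting the inputs changes more than just the order of the outputs.
    return (sigmoid(X @ Phi_u1) * (Psi @ (X @ Phi_u2))) @ Phi_v

rng = np.random.default_rng(2)
I, d, m = 4, 8, 16                        # illustrative sizes
X = rng.standard_normal((I, d))
out = gmlp(X, rng.standard_normal((d, m)), rng.standard_normal((d, m)),
           rng.standard_normal((I, I)), rng.standard_normal((m, d)))
```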
Finally, we'll consider the relationship between Hopfield networks and the attention mechanism. A Hopfield network can retrieve a stored memory based on a query via an iterative procedure in which the query is updated after interaction with the system. They were originally defined for binary vectors, but the modern Hopfield network extends the idea to continuous values.
Ramsauer et al. (2020) show that for a carefully defined Hopfield energy function, the update rule is equivalent to the self-attention mechanism. The most natural way to think of this is in terms of encoder-decoder attention. The decoder queries memories from the encoder network. If viewed as a Hopfield network, the query-key attention computes a single iteration of the memory retrieval. To complete the process, the output of the attention network should be fed back in as a new query until a stable state is reached (figure 15).
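The iterated retrieval can be sketched in a few lines of NumPy. In the toy example below (the patterns, query, and inverse temperature beta are all illustrative), the stored patterns act as both keys and values, and repeatedly applying the attention update drives a corrupted query towards the closest stored pattern:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Stored patterns as rows of K; here they serve as both keys and values.
K = np.array([[1.0, 0.0],
              [0.0, 1.0]])
beta = 8.0                       # inverse temperature (illustrative value)
q = np.array([0.8, 0.2])         # corrupted query

# Iterate the attention update; for large beta the query converges to the
# stored pattern closest to the initial query.
for _ in range(10):
    q = softmax(beta * K @ q) @ K
```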
In this blog, we have discussed extensions to the basic self-attention mechanism. First, we discussed how to incorporate positional information, and then how to extend the self-attention mechanism to longer sequences. Finally, we have discussed the relationship between self-attention and a number of other models, including RNNs, CNNs, graph convolutional networks and Hopfield networks. We note that some caution is required here. Recent work has suggested that many of the variations of the original model do not necessarily yield consistent performance benefits.
In part III of this blog, we discuss how to train transformers in practice. To make training stable, a number of tricks are required, including unusual learning rate schedules, various forms of normalization, and careful initialization.
1 In fact, they also modified the value terms in a similar way, although their ablation study suggested that this did not contribute much.
In this blog, we will introduce weighted context-free grammars or WCFGs. These assign a non-negative weight to each rule in the grammar. From here, we can assign a weight to any parse tree by multiplying the weights of its component rules together. We present two variations of the CYK algorithm that apply to WCFGs: (i) the inside algorithm computes the sum of the weights of all possible analyses (parse trees) for a sentence, and (ii) the weighted parsing algorithm finds the parse tree with the highest weight.
In Part III of this tutorial, we introduce probabilistic context-free grammars. These are a special case of WCFGs where the weights of all rules with the same left-hand side sum to one. We then discuss how to learn these weights from a corpus of text. We will see that the inside algorithm is a critical part of this process.
Before we start our discussion, let's briefly review what we learned about context-free grammars and the CYK recognition algorithm in part I of this tutorial. Recall that we defined a context-free grammar as the tuple $\langle S, \mathcal{V}, \Sigma, \mathcal{R}\rangle$ with a start symbol $S$, non-terminals $\mathcal{V}$, terminals $\Sigma$ and finally the rules $\mathcal{R}$.
In our examples, the non-terminals are a set $\mathcal{V}=\{\mbox{VP, PP, NP, DT, NN, }\ldots\}$ containing sub-clauses (e.g., verb-phrase $\mbox{VP}$) and parts of speech (e.g., noun $\mbox{NN}$). The terminals contain the words. We will consider grammars in Chomsky Normal Form, where the rules either map one non-terminal to two other non-terminals (e.g., $\text{VP} \rightarrow \text{V} \; \text{NP}$) or to a single terminal symbol (e.g., $\text{V} \rightarrow \text{eats}$).
The CYK recognition algorithm takes a sentence and a grammar in Chomsky Normal Form and determines if the sentence is valid under the grammar. With minor changes, it can also return the set of valid parse trees. It constructs a chart where each position in the chart corresponds to a sub-sequence of words (figure 2). At each position, there is a binary array with one entry per rule, where this entry is set to true if this rule can be applied validly to the associated sub-sequence.
The CYK algorithm works by first finding valid unary rules that map pre-terminals representing parts of speech to terminals representing words (e.g., DT$\rightarrow$ the). Then it considers sub-sequences of increasing length and identifies applicable binary non-terminal rules (e.g., $\mbox{NP}\rightarrow \mbox{DT NN})$. The rule is applicable if there are two sub-trees lower down in the chart whose roots match its right hand side. If the algorithm can place the start symbol in the top-left of the chart, then the overall sentence is valid. The pseudo-code is given by:
0 # Initialize data structure
1 chart[1...n, 1...n, 1...|V|] := FALSE
2
3 # Use unary rules to find possible parts of speech at pre-terminals
4 for p := 1 to n # start position
5 for each unary rule A -> w_p
6 chart[1, p, A] := TRUE
7
8 # Main parsing loop
9 for l := 2 to n # sub-sequence length
10 for p := 1 to n-l+1 # start position
11 for s := 1 to l-1 # split width
12 for each binary rule A -> B C
13 chart[l, p, A] = chart[l, p, A] OR
(chart[s, p, B] AND chart[l-s,p+s, C])
14
15 return chart[n, 1, S]
For a much more detailed discussion of this algorithm, consult Part I of this blog.
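To make the pseudo-code concrete, here is a direct Python transcription with a toy grammar in Chomsky Normal Form; the grammar rules and sentence below are illustrative, not taken from the post's figures:

```python
def cyk_recognize(words, unary, binary, start="S"):
    """CYK recognition. chart[l-1][p][A] is True when non-terminal A can
    derive the length-l sub-sequence starting at position p (0-indexed)."""
    n = len(words)
    chart = [[{} for _ in range(n)] for _ in range(n)]
    # Unary rules: pre-terminals (parts of speech) over single words.
    for p in range(n):
        for A, w in unary:
            if w == words[p]:
                chart[0][p][A] = True
    # Main loop over sub-sequence lengths, start positions, split widths.
    for l in range(2, n + 1):
        for p in range(n - l + 1):
            for s in range(1, l):
                for A, B, C in binary:
                    if chart[s - 1][p].get(B) and chart[l - s - 1][p + s].get(C):
                        chart[l - 1][p][A] = True
    return bool(chart[n - 1][0].get(start))

# Toy grammar (illustrative).
unary = [("DT", "the"), ("NN", "dog"), ("V", "eats")]
binary = [("NP", "DT", "NN"), ("VP", "V", "NP"), ("S", "NP", "VP")]
print(cyk_recognize("the dog eats the dog".split(), unary, binary))
```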
Weighted context-free grammars (WCFGs) are context-free grammars which have a non-negative weight associated with each rule. More precisely, we add the function $g: \mathcal{R} \mapsto \mathbb{R}_{\geq 0}$ that maps each rule to a non-negative number. The weight of a full derivation tree $T$ is then the product of the weights of each rule $T_t$:
\begin{equation}\label{eq:weighted_tree_from_rules}
\mbox{G}[T] = \prod_{t \in T} g[T_t]. \tag{1}
\end{equation}
Context-free grammars generate strings, whereas weighted context free grammars generate strings with an associated weight.
We will interpret the weight $g[T_t]$ as the degree to which we favor a rule, and so, we "prefer" parse trees $T$ with higher overall weights $\mbox{G}[T]$. Ultimately, we will learn these weights in such a way that real observed sentences have high weights and ungrammatical sentences have lower weights. From this viewpoint, the weights can be viewed as parameters of the model.
Since the tree weights $G[T]$ are non-negative, they can be interpreted as un-normalized probabilities. To create a valid probability distribution over possible parse trees, we must normalize by the total weight $Z$ of all tree derivations:
\begin{eqnarray}
Z &=& \sum_{T \in \mathcal{T}[\mathbf{w}]} \mbox{G}[T] \nonumber \\
&=& \sum_{T \in \mathcal{T}[\mathbf{w}]} \prod_{t \in T} \mbox{g}[T_t], \tag{2}
\end{eqnarray}
where $\mathcal{T}[\mathbf{w}]$ represents the set of all possible parse trees from which the observed words $\mathbf{w}=[w_{1},w_{2},\ldots, w_{L}]$ can be derived. We'll refer to the normalizing constant $Z$ as the partition function. The conditional distribution of a possible derivation $T$ given the observed words $\mathbf{w}$ is then:
\begin{equation}
Pr(T|\mathbf{w}) = \frac{\mbox{G}[T]}{Z}. \tag{3}
\end{equation}
We defined the partition function $Z$ as the sum of the weights of all the trees $\mathcal{T}[\mathbf{w}]$ from which the observed words $\mathbf{w}$ can be derived. However, in Part I of this tutorial we saw that the number of possible binary parse trees increases very rapidly with the sentence length.
The CYK recognition algorithm used dynamic programming to search this huge space of possible trees in polynomial time and determine whether there is at least one valid tree. To compute the partition function, we will use a similar trick to search through all possible trees and sum their weights simultaneously. This is known as the inside algorithm.
Before we present the inside algorithm, we need to introduce the semiring. This abstract algebraic structure will help us adapt the CYK algorithm to compute different quantities. A semiring is a set $\mathbb{A}$ on which we have defined two binary operators:
1. $\oplus$ is a commutative operation with identity element 0, which behaves like addition $+$.
2. $\otimes$ is an associative operation that (right) distributes over $\oplus$, just like multiplication $\times$. It has identity element 1 and absorbing element 0.
Similarly to grammars we will just denote semirings as tuples: $\langle\mathbb{A}, \oplus, \otimes, 0, 1\rangle$. You can think of the semiring as generalizing the notions of addition and multiplication.^{1}
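One lightweight way to encode a semiring in Python (an illustrative encoding, not from the original post) is as a named tuple holding the two operators and their identity elements:

```python
from collections import namedtuple

# A semiring <A, plus, times, zero, one>, with the carrier set A implicit.
Semiring = namedtuple("Semiring", ["plus", "times", "zero", "one"])

boolean     = Semiring(lambda a, b: a or b, lambda a, b: a and b, False, True)
sum_product = Semiring(lambda a, b: a + b,  lambda a, b: a * b,   0.0,   1.0)
max_product = Semiring(max,                 lambda a, b: a * b,   0.0,   1.0)
```

Swapping the boolean semiring for the sum-product one is exactly the change that turns CYK recognition into the inside algorithm described next.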
Computing the partition function $Z$ for the conditional distribution $Pr(T|\mathbf{w})$ might appear difficult because it sums over the large space of possible derivations for the sentence $\mathbf{w}$. However, we've already seen how the CYK recognition algorithm accepts or rejects a sentence in polynomial time while sweeping through all possible derivations. The inside algorithm uses a variation of the same trick to compute the partition function.
When used for recognition, the $\texttt{chart}$ holds values of $\texttt{TRUE}$ and $\texttt{FALSE}$, and the computation is based on the two logical operators OR and AND; we can think of these as being part of the semiring $\langle\{\texttt{TRUE}, \texttt{FALSE}\}, OR, AND, \texttt{FALSE}, \texttt{TRUE}\rangle$.
The inside algorithm replaces this semiring with the sum-product semiring $\langle\mathbb{R}_{\geq 0} \cup \{+\infty\} , +, \times, 0, 1\rangle$ to get the following procedure:
0 # Initialize data structure
1 chart[1...n, 1...n, 1...|V|] := 0
2
3 # Use unary rules to find possible parts of speech at pre-terminals
4 for p := 1 to n # start position
5 for each unary rule A -> w_p
6 chart[1, p, A] := g[A -> w_p]
7
8 # Main parsing loop
9 for l := 2 to n # sub-sequence length
10 for p := 1 to n-l+1 # start position
11 for s := 1 to l-1 # split width
12 for each binary rule A -> B C
13 chart[l, p, A] = chart[l, p, A] +
(g[A -> B C] x chart[s, p, B] x chart[l-s,p+s, C] )
14
15 return chart[n, 1, S]
where we have highlighted the differences from the recognition algorithm in green.
As in the CYK recognition algorithm, each position $(p,l)$ in the $\texttt{chart}$ represents the sub-sequence that starts at position $p$ and is of length $l$ (figure 2). In the inside algorithm, every position in the chart holds a length $|V|$ vector where the $v^{th}$ entry corresponds to the $v^{th}$ non-terminal. The value held in this vector is the sum of the weights of all sub-trees for which the $v^{th}$ non-terminal is the root.
The intuition for the update rule in line 13 is simple. The additional weight for adding rule $A\rightarrow BC$ into the chart is the weight $g[A\rightarrow BC]$ for this rule times the sum of weights of all possible left sub-trees rooted in B times the sum of weights of all possible right sub-trees rooted in C. As before, there may be multiple possible rules that place non-terminal $A$ in a position corresponding to different splits of the sub-sequence and here we perform this computation for each rule and sum the results together.
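As a sketch of this procedure, here is the chart recursion in Python with the sum-product semiring; the grammar encoding and all rule weights below are illustrative:

```python
def inside(words, unary, binary, start="S"):
    """Inside algorithm. unary maps (A, word) -> weight and binary maps
    (A, B, C) -> weight. chart[l-1][p][A] accumulates the summed weight of
    all sub-trees rooted in A over the length-l sub-sequence starting at p."""
    n = len(words)
    chart = [[{} for _ in range(n)] for _ in range(n)]
    for p in range(n):
        for (A, w), g in unary.items():
            if w == words[p]:
                chart[0][p][A] = chart[0][p].get(A, 0.0) + g
    for l in range(2, n + 1):
        for p in range(n - l + 1):
            for s in range(1, l):
                for (A, B, C), g in binary.items():
                    wB = chart[s - 1][p].get(B, 0.0)
                    wC = chart[l - s - 1][p + s].get(C, 0.0)
                    chart[l - 1][p][A] = chart[l - 1][p].get(A, 0.0) + g * wB * wC
    return chart[n - 1][0].get(start, 0.0)   # the partition function Z

# Toy weighted grammar (weights chosen arbitrarily for illustration).
unary = {("DT", "the"): 0.9, ("NN", "dog"): 0.5, ("V", "eats"): 0.8}
binary = {("NP", "DT", "NN"): 2.0, ("VP", "V", "NP"): 1.5, ("S", "NP", "VP"): 1.0}
Z = inside("the dog eats the dog".split(), unary, binary)
```

For this sentence there is a single parse, so $Z$ is simply the product of its rule weights: each NP contributes $2.0\times0.9\times0.5=0.9$, the VP contributes $1.5\times0.8\times0.9=1.08$, and the S rule gives $0.9\times1.08=0.972$.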
In figures 3 and 4, we show a worked example of the inside algorithm for the same sentence as we used for the CYK recognition algorithm. Figure 3a corresponds to lines 4-6 of the algorithm, where we initialize the first row of the chart based on the unary rule weights. Figure 3b corresponds to the main loop in lines 9-13 for sub-sequence length $l=2$. Here we assign binary non-terminal rules and compute their weights as (weight of rule $\times$ weight of left branch $\times$ weight of right branch).
Figure 4a corresponds to the main loop in lines 9-13 for sub-sequence length $l=5$. At position (5,2), there are two possible rules that apply, both of which result in the same non-terminal. We calculate the weights for each rule as before, and add the results so that the final weight at this position sums over all sub-trees. Figure 4b shows the final result of the algorithm. The weight associated with the start symbol $S$ at position (6,1) is the partition function.
Our discussion so far does not make it clear why the method for computing the partition function is known as the inside algorithm. This is because the $\texttt{chart}$ holds the inside weights for each anchored non-terminal. By "anchored" we mean that a non-terminal $A_i^k$ (pronounced "A from i to k") is anchored to a span in the sentence (i.e., a sub-string). It yields the string $A_i^k \Rightarrow w_i, \ldots, w_k$.
An anchored rule then has the form $A_i^k \rightarrow B_i^j C_j^k$. With this notation in hand, we can give a recursive definition of the inside weight of an anchored non-terminal:
\begin{equation}\label{eq:inside_update}
\alpha[A_i^k] = \sum_{B, C}\sum_{j=i+1}^k \mbox{g}[A \rightarrow B C] \times \alpha[B_i^j] \times \alpha[C_j^k]. \tag{4}
\end{equation}
The inside weight $\alpha[A_i^k]$ sums the weights over all combinations of left and right sub-trees, considering all possible split points $j$ and all possible non-terminals $B$ and $C$ (figure 5).
In the previous section, we saw that we could transform the CYK recognition algorithm into the inside algorithm by just changing the underlying semiring. With this small adjustment, we showed that we can compute the partition function (the sum of the weights of all tree derivations) in polynomial time. In this section, we apply a similar trick to weighted parsing.
Recall that the partition function $Z$ was defined as the sum of the weights of all possible derivations:
\begin{eqnarray}
Z &=& \sum_{T \in \mathcal{T}[\mathbf{w}]} \mbox{G}[T] \nonumber \\
&=& \sum_{T \in \mathcal{T}[\mathbf{w}]} \prod_{t \in T} \mbox{g}[T_t]. \tag{5}
\end{eqnarray}
In contrast, weighted parsing aims to find the derivation $T^{*}$ with the highest weight among all possible derivations:
\begin{eqnarray}
T^{*} &=& \underset{T \in \mathcal{T}[\mathbf{w}]}{\text{arg} \, \text{max}} \; \left[\mbox{G}[T]\right] \nonumber \\
&=& \underset{T \in \mathcal{T}[\mathbf{w}]}{\text{arg} \, \text{max}} \left[\prod_{t \in T} \mbox{g}[T_t]\right], \tag{6}
\end{eqnarray}
where $\mbox{G}[T]$ is the weight of a derivation tree which is computed by taking the product of the weights $\mbox{g}[T_t]$ of the rules.
Once again we will modify the semiring in the CYK algorithm to perform the task. Let us replace the sum-product semiring $\langle\mathbb{R}_{\geq 0} \cup \{+\infty\} , +, \times, 0, 1\rangle$ with the max-product semiring $\langle\mathbb{R}_{\geq 0} \cup \{+\infty\} , \max[\bullet], \times, 0, 1\rangle$ to find the score of the "best" derivation. This gives us the following algorithm:
0 # Initialize data structure
1 chart[1...n, 1...n, 1...|V|] := 0
2
3 # Use unary rules to find possible parts of speech at pre-terminals
4 for p := 1 to n # start position
5 for each unary rule A -> w_p
6 chart[1, p, A] := g[A -> w_p]
7
8 # Main parsing loop
9 for l := 2 to n # sub-sequence length
10 for p := 1 to n-l+1 # start position
11 for s := 1 to l-1 # split width
12 for each binary rule A -> B C
13 chart[l, p, A] = max[chart[l, p, A],
g[A -> B C] x chart[s, p, B] x chart[l-s, p+s, C]]
14
15 return chart[n, 1, S]
The differences from the CYK recognition algorithm are colored in green, and the single difference from both the inside algorithm and the CYK recognition algorithm is colored in orange.
Once more, each position $(p,l)$ in the $\texttt{chart}$ represents the sub-sequence that starts at position $p$ and is of length $l$. In the inside algorithm, each position contained a vector with one entry for each of the $|V|$ non-terminals. Each element of this vector contained the sum of the weights of all of the sub-trees which feed into this anchored non-terminal. In this variation, each element contains the maximum weight among all the sub-trees that feed into this anchored non-terminal. Position $(n,1)$ represents the whole string, and so the value $\texttt{chart[n, 1, S]}$ is the maximum weight among all valid parse trees. If this is zero, then there is no valid derivation.
The update rule at line 13 for the weight at $\texttt{chart[l, p, A]}$ now has the following interpretation. For each rule $\texttt{A -> B C}$ and for each possible split $\texttt{s}$ of the data, we multiply the rule weight $\texttt{g[A -> B C]}$ by the two weights $\texttt{chart[s, p, B]}$ and $\texttt{chart[l-s, p+s, C]}$ associated with the two child sub-sequences. If the result is larger than the current highest value, then we update it. If we are interested in the parse tree itself, then we can store back-pointers indicating which split yielded the maximum value at each position, and traverse backwards to retrieve the best tree.
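The back-pointer idea can be sketched in Python as follows, using an illustrative toy grammar (the same dictionary encoding as the inside-algorithm sketch: unary maps (A, word) to a weight, binary maps (A, B, C) to a weight); the best tree is returned as nested tuples:

```python
def weighted_parse(words, unary, binary, start="S"):
    """Max-product chart with back-pointers for recovering the best tree."""
    n = len(words)
    best = [[{} for _ in range(n)] for _ in range(n)]
    back = [[{} for _ in range(n)] for _ in range(n)]
    for p in range(n):
        for (A, w), g in unary.items():
            if w == words[p] and g > best[0][p].get(A, 0.0):
                best[0][p][A], back[0][p][A] = g, words[p]
    for l in range(2, n + 1):
        for p in range(n - l + 1):
            for s in range(1, l):
                for (A, B, C), g in binary.items():
                    score = (g * best[s - 1][p].get(B, 0.0)
                               * best[l - s - 1][p + s].get(C, 0.0))
                    if score > best[l - 1][p].get(A, 0.0):
                        best[l - 1][p][A] = score
                        back[l - 1][p][A] = (s, B, C)   # remember the best split
    def build(l, p, A):                                 # follow back-pointers
        if l == 1:
            return (A, back[0][p][A])
        s, B, C = back[l - 1][p][A]
        return (A, build(s, p, B), build(l - s, p + s, C))
    w_best = best[n - 1][0].get(start, 0.0)
    return w_best, (build(n, 0, start) if w_best > 0.0 else None)

unary = {("DT", "the"): 0.9, ("NN", "dog"): 0.5, ("V", "eats"): 0.8}
binary = {("NP", "DT", "NN"): 2.0, ("VP", "V", "NP"): 1.5, ("S", "NP", "VP"): 1.0}
w, tree = weighted_parse("the dog eats the dog".split(), unary, binary)
```

Because this toy sentence has only one parse, the best weight equals the partition function; on ambiguous sentences the two quantities differ.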
In figure 6 we illustrate a worked example of weighted parsing. The algorithm starts by assigning weights to pre-terminals exactly as in figure 3a. The computation of the weights for sub-sequences of length $l=2$ is also exactly as in figure 3b, and the algorithm also proceeds identically for $l=3$ and $l=4$.
The sole difference occurs for the sub-sequence of length $l=5$ at position $p=2$ (figure 6). There are two possible rules that both assign the non-terminal VP to the chart at this position. In the inside algorithm, we calculated the weights of these rules and summed them. In weighted parsing, we store the largest of these weights, and this operation corresponds to the $\mbox{max}[\bullet,\bullet]$ function on line 13 of the algorithm.
At the end of the procedure, the weight associated with the start symbol at position (6,1) corresponds to the tree with the maximum weight and so is considered the "best". By keeping track of which sub-tree yielded the maximum weight at each split, we can retrieve this tree which corresponds to our best guess at parsing the sentence.
We've seen that we can add weights to CFGs and replace the $OR, AND$ semiring operations with $+, \times$ to find the total weight of all possible derivations (i.e., compute the partition function with the inside algorithm). Furthermore, we can use $\max, \times$ instead to find the parse tree with the highest weight.
The semirings allow us to unify the CYK recognition, inside, and weighted parsing algorithms by recursively defining the chart entries as:
\begin{equation}
\texttt{chart}[A_i^k] = \bigoplus_{B, C, j} \mbox{g}[A \rightarrow B C] \otimes \texttt{chart}[B_i^j] \otimes \texttt{chart}[C_j^k], \tag{7}
\end{equation}
where for recognition $\mbox{g}[A \rightarrow B C]$ just returns $\texttt{TRUE}$ for all existing rules.
Readers familiar with graphical models will no doubt have noticed the similarity between these methods and sum-product and max-product belief propagation. Indeed, we could alternatively have presented this entire argument in terms of graphical models, but the semiring formulation is more concise.
In the final part of this blog, we will consider probabilistic context-free grammars, which are a special case of weighted-context free grammars. We'll develop algorithms to learn the weights from (i) a corpus of sentences with known parse trees and (ii) just the sentences. The latter case will lead to a discussion of the famous inside-outside algorithm.
^{1. }If you are wondering why it is only a "semi"-ring: full rings additionally have an additive inverse for each element: $x \oplus (-x) = 0$.
In the technical demo, users get to see this interactive system at work.
The value proposition of a project like this is democratizing data-driven insights by enabling non-technical users to interact with structured data using natural language.
"Today, a lot of potentially useful knowledge and insights is trapped in databases, and only technical users can access that information, typically by using SQL. Turing by Borealis AI’s database interface unlocks these insights for non-technical users, who can query the multitude of databases using natural language and get the results and insights they need."
- Yanshuai Cao, Senior Research Team Lead at Borealis AI
Turing by Borealis AI comes closer than most technology available today, achieving and holding state-of-the-art performance levels while reducing accuracy issues. Cross-domain text-to-SQL semantic parsers generally have serious accuracy and usability problems, making practical applications a challenge. Unlike in online search, where approximate answers can be good enough, when users query relational databases to glean specific insights, a high degree of accuracy is needed to provide value. With Turing by Borealis AI's technology, a user can look at multiple hypotheses and, with the help of the explanations Turing by Borealis AI provides, can figure out which of the SQL queries comes closest to the search intent.
Here’s a sample use case: Let's say a non-technical user is in the business of delivering supplies to gas stations. The user wants to query available databases and find out which stations to contact next, in order to grow the business. How would the user get these business insights, without relying on SQL to do the search across available databases? With Turing by Borealis AI, users can start the search by picking the ‘gas station domain’ and ask: "What are the locations with gas stations owned by companies making over 100 billion in sales?" Under the hood, there is a deep learning model that treats the text-to-SQL problem as graph-to-tree mapping and produces a SQL query, executing it against the database to return the results.
Turing by Borealis AI generates SQL and uses a synchronous context-free grammar system to provide a high-precision explanation, so that users can make sure the results are trustworthy and match the intent.
Learn more about cross-database text-to-SQL in this blog, with further details on Turing by Borealis AI in this paper and here.
The team is presenting Turing by Borealis AI and related works: two main conference papers, one demo paper and one workshop paper at the joint conference of the Association for Computational Linguistics and the International Joint Conference on Natural Language Processing (ACL-IJCNLP 2021) on August 1-6, 2021. The team is also aiming to release the core of its semantic parsing technology at that time.
People communicate in natural language, which is flexible but often vague, whereas computer languages have no room for ambiguity. For a computer to respond to users' questions or commands in natural language, it needs to extract meaning, resolve ambiguity, and translate to executable programs. This is the task of semantic parsing (SP), whose applications include voice assistants, code generation, natural language interfaces to databases (NLDB), and many more. Our Turing by Borealis AI system is an NLDB, a software system enabling users to interact with databases in natural language, as illustrated in Figure 1.
The semantic parsing model powering an NLDB needs to be trained with questions and their corresponding SQL queries. If the model only generalizes to new questions on the training domain, the NLDB cannot be quickly adapted to new databases, so it would not be very useful in practice. Hence, the model somehow needs to generalize to new databases with unseen schema and unseen questions. This is cross-domain or cross-database text-to-SQL semantic parsing.
The goal of this blog post is to glimpse into how models (like our Turing by Borealis AI system) for this task work without popping the hood. It is suitable for any reader with basic knowledge of machine learning and natural language processing.
We will first give a brief review of SQL that readers can skip if already familiar, then introduce two running examples of text-to-SQL prediction. The examples will illustrate some challenges involved in cross-domain semantic parsing and illustrate why simple methods would not succeed. Afterwards, we will describe a high-level framework that treats cross-database text-to-SQL as a graph-to-tree mapping. We will use the two running examples, to show how the framework tackles the challenges that we identified. Finally, we will provide some pointers for interested readers to learn more, including our recent ACL papers (Xu et al., 2021a,b; Norouzi et al., 2021) that respectively set the new state-of-the-art accuracy on the Spider text-to-SQL benchmark and some code generation problems.
Before showing the examples, let us review some SQL basics. SQL stands for Structured Query Language and is used for storing, manipulating and retrieving data in relational databases. We will just focus on the retrieval here.
Relational databases store information records in tables. The schema of a database describes the structure of the domain: what are the tables, what columns does each table contain, the data type of each column, as well as special roles that some columns play. The first type of special role is a primary key. This is a column or a combination of columns that has to be unique for each data record. The second type of special role is a foreign key, which is a column or combination of columns whose values match the primary key records of another table. Foreign key relations link tables together.
SELECT Query

A basic SQL query looks like the following: SELECT * FROM my_table, where * is a reserved token meaning "all columns". This query will return all rows of the table my_table. The star can be replaced by one or more column names, in which case the query will only return the mentioned attributes in each row. Slightly more advanced queries involve a filtering condition, expressed using a WHERE clause: SELECT * FROM my_table WHERE condition. This query will only return records for which the condition holds true. The SQL syntax for the actual condition is generally self-explanatory.
GROUP BY and HAVING

Some columns correspond to categorical attributes like "sector". Here, an interesting class of questions involves aggregating some property over each categorical value of the column. For this purpose, we need the GROUP BY clause: SELECT MAX(salary), sector FROM my_table GROUP BY sector, which finds the highest salary in each sector. If we want to filter the categories, we can use the HAVING clause, which is similar to WHERE but operates on the grouped categories instead. For example, we might want to filter out sectors based on their associated statistics: SELECT MAX(income), sector FROM my_table GROUP BY sector HAVING AVG(salary) < 50000.
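Continuing the sqlite3 sketch (invented data again), we can see GROUP BY computing the aggregate per sector and HAVING filtering whole groups:

```python
import sqlite3

# GROUP BY collapses rows sharing a sector; MAX is computed per group.
# HAVING then filters entire groups by an aggregate condition.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE my_table (sector TEXT, salary INTEGER)")
conn.executemany(
    "INSERT INTO my_table VALUES (?, ?)",
    [("tech", 90000), ("tech", 70000), ("retail", 40000), ("retail", 42000)],
)
rows = conn.execute(
    "SELECT MAX(salary), sector FROM my_table "
    "GROUP BY sector HAVING AVG(salary) < 50000"
).fetchall()
print(rows)  # only the retail group survives: [(42000, 'retail')]
```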
JOIN

Last but not least, the concept of JOIN needs some explanation. As SQL databases store records in tables, sometimes we need to "merge" corresponding rows of two or more tables, either as the final result or as an intermediate step to compute something else. This requires joining tables with the syntax: SELECT * FROM table1 JOIN table2 ON table1.col_fkey = table2.col_pkey. The ON part introduces a condition, usually an equality between foreign key and primary key columns as in this example, but it can also involve other columns. This query returns the combinations of rows in table1 and rows in table2 for which the value in column col_fkey of table1 equals the value of col_pkey of table2.
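One more sqlite3 sketch (tables and rows invented) shows the merge behaviour of JOIN ... ON:

```python
import sqlite3

# Each row of table1 is merged with the table2 row whose primary key
# matches table1's foreign key.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE table1 (col_fkey INTEGER, item TEXT)")
conn.execute("CREATE TABLE table2 (col_pkey INTEGER PRIMARY KEY, owner TEXT)")
conn.executemany("INSERT INTO table1 VALUES (?, ?)",
                 [(1, "pen"), (2, "book"), (2, "mug")])
conn.executemany("INSERT INTO table2 VALUES (?, ?)",
                 [(1, "Ana"), (2, "Bo")])
rows = conn.execute(
    "SELECT table1.item, table2.owner FROM table1 "
    "JOIN table2 ON table1.col_fkey = table2.col_pkey"
).fetchall()
print(rows)  # each item paired with its owner
```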
To predict the correct SQL from a natural language question, the model needs to correctly interpret each input word in the context of both the sentence and the schema. Furthermore, it needs to generate a syntactically correct SQL query as the output; otherwise, the database cannot execute it. To illustrate the challenges more concretely, let's consider two examples for the "Employee_hire_evaluation" database of the Spider benchmark. This database is a development-set domain that models would not have seen during training.
The database has the following tables: employee, shop, hiring, evaluation. Each table has a number of columns:

- employee: employee_id, name, age, city, with employee_id being the primary key.
- shop: shop_id, name, location, district, number_products, manager_name, with shop_id being the primary key.
- hiring: shop_id, employee_id, start_from, is_full_time, with employee_id being the primary key and also a foreign key to the employee table's employee_id, and shop_id being a foreign key to the shop table's shop_id.
- evaluation: employee_id, year_awarded, bonus, with employee_id and year_awarded together forming the primary key, and employee_id being a foreign key referencing the employee table's employee_id.

Question: Which cities have more than one employee under 30?
Correct SQL:
SELECT employee.city
FROM employee
WHERE employee.age < 30
GROUP BY employee.city
HAVING COUNT (*) > 1
Analysis: Besides the general logic of the SQL query, a model needs to infer two conditions from the question: employee.age < 30 and COUNT(*) > 1. The entities involved in the conditions (tables, columns, or the star) are not explicitly mentioned in the text and have to be inferred. The model needs to deduce that "employee under 30" requires the column age by leveraging two pieces of information. First, it can have some prior common sense knowledge that the expression "employee under [NUMBER]" refers to employee age rather than some other attribute. Second, it could exclude other columns because the value "$30$" is too different from other columns' values in type or range. For the second condition, the model needs to infer from the entire phrase "Which cities have more than one employee ..." that the condition is on the number of employees in each city, hence requiring GROUP BY [$\ldots$] HAVING [$\ldots$]. Finally, it needs to piece the two conditions together, along with the rest of the query, using the correct syntax.
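As a sanity check, the gold SQL above can be executed against a toy instance of the employee table (the rows below are invented for illustration):

```python
import sqlite3

# A toy employee table matching the schema described above.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE employee "
             "(employee_id INTEGER PRIMARY KEY, name TEXT, age INTEGER, city TEXT)")
conn.executemany(
    "INSERT INTO employee VALUES (?, ?, ?, ?)",
    [(1, "A", 25, "Toronto"), (2, "B", 28, "Toronto"),
     (3, "C", 45, "Toronto"), (4, "D", 22, "Montreal")],
)
# The gold SQL for Example One.
rows = conn.execute(
    "SELECT employee.city FROM employee WHERE employee.age < 30 "
    "GROUP BY employee.city HAVING COUNT(*) > 1"
).fetchall()
print(rows)  # only Toronto has more than one employee under 30
```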
Question: What's the average age in each shop?
Correct SQL:
SELECT AVG(employee.age), shop.shop_id
FROM employee
JOIN hiring ON employee.employee_id = hiring.employee_id
JOIN shop ON hiring.shop_id = shop.shop_id
GROUP BY shop.shop_id
Analysis: To correctly predict this SQL, not only does the semantic parsing model need to infer from "in each shop" that the output contains GROUP BY shop.shop_id, it also needs to infer the involvement of the tables employee and hiring, which, unlike shop, are not explicitly mentioned. The table employee can be inferred from the need for its age column. The hiring table, on the other hand, can only be inferred from the need to link employee.age to shop.shop_id.
You might wonder whether some generic or simple approach can already solve this cross-database text-to-SQL problem. For example, let's consider the sequence-to-sequence model often used in machine translation. Text-to-SQL semantic parsing bears some similarity to machine translation if we view SQL as a foreign language to translate into. However, some crucial differences exist. First, typical training datasets for machine translation are larger than those for SQL semantic parsing by two orders of magnitude or more. Second, in machine translation, partially correct results can still provide partial utility, but for an NLDB, any small mistake in the predicted SQL query could invalidate the result. Third, as we have seen from the examples, the database schema is crucial for correct translation to SQL, which sequence-to-sequence machine translation models do not consider. For these reasons, typical neural sequence-to-sequence models do not work well.
Another baseline is shallow semantic parsing in which we simplify the problem and assume that there are a fixed number of user intentions. An intent classifier could then select the template that best corresponds to the user question from a pre-defined list. Then a model extracts the relevant information from the user question to fill in the template slots. For instance, we can turn the first example into a template whose SQL would have some slots to be filled:
SELECT employee.city
FROM employee
WHERE employee.age [COMP_A] [A]
GROUP BY employee.city
HAVING COUNT (*) [COMP_C] [C]
Given enough training examples of questions tagged with their corresponding template ID and slot values, a model could potentially answer questions like "show me the cities with less than 5 employees over twenty five." by identifying this template out of many, then predicting that COMP_A $:=$ >, A $:=$ 25, COMP_C $:=$ <, C $:=$ 5. This approach is commonly used in voice-assistant and task-oriented dialogue systems. The main drawback is that the templates need to be pre-defined, so the system cannot generalize to new queries on the fly. Hence this approach is also unsuitable for cross-database NLDB in general.
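The intent-plus-slot-filling baseline can be sketched as follows. The template string, word-to-number map, and regular expressions below are toy stand-ins for a learned intent classifier and slot tagger:

```python
import re

# Hypothetical slot-filled template derived from the first running example.
TEMPLATE = ("SELECT employee.city FROM employee WHERE employee.age {COMP_A} {A} "
            "GROUP BY employee.city HAVING COUNT(*) {COMP_C} {C}")

# Multi-word entries must come first so they are replaced before their parts.
WORD2NUM = {"twenty five": "25", "one": "1", "five": "5", "thirty": "30"}

def fill_template(question):
    # Toy rule-based slot extraction; real systems use learned taggers.
    q = question.lower()
    for word, num in WORD2NUM.items():
        q = q.replace(word, num)
    count_m = re.search(r"(less|more) than (\d+) employees?", q)
    age_m = re.search(r"employees? (under|over) (\d+)", q)
    return TEMPLATE.format(
        COMP_A="<" if age_m.group(1) == "under" else ">", A=age_m.group(2),
        COMP_C="<" if count_m.group(1) == "less" else ">", C=count_m.group(2))

sql = fill_template("show me the cities with less than 5 employees over twenty five.")
print(sql)
```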
As shown by the two running examples, successful cross-database SQL semantic parsing requires the model to reason jointly over at least three sets of knowledge: the natural language question (together with common sense about how people phrase things), the database schema, and the SQL grammar.
We now describe a general framework for cross-database text-to-SQL that leverages all of this knowledge. The backbone of the overall system is a neural network with encoder-decoder architecture, which is adapted in various ways to leverage explicit symbolic knowledge.
Motivated by the examples, we see that the model needs to jointly encode the question and schema, considering how words relate to each other within and across the question and the schema. So the input for cross-database semantic parsing has an inherent graph structure; the nodes are the tokens in the question and schema, linked by different kinds of edges. On the output side, to produce grammatically correct SQL queries and leverage a programming-language-specific inductive prior, we treat the prediction problem as generation of the abstract syntax tree (AST) of the program. Hence, we can characterize this task as a graph-to-tree mapping.
Figure 2 illustrates the overall framework for Example One: an encoder consumes the input graph, and a decoder produces the output AST. Joint modelling of question and schema as a graph was popularized by the relation-aware transformer (RAT) work (Wang et al., 2019) while TranX (Yin and Neubig, 2018) provides a unified framework for modelling output programs as ASTs. Our Turing by Borealis AI system also follows this overall approach, with many additional innovations that we will not cover here.
As mentioned above, we view each token in the question and schema as a node in a graph. The most basic edge type among the nodes is a generic link between any pair of tokens, reflecting the assumption that a-priori any token could provide relevant context to any other token, so a link cannot be ruled out. This essentially yields a fully connected graph. For visual simplicity, we omit these edges from Figure 2.
However, other types of relations carry special meanings and are sparse. These include (i) foreign key relations that link a column in one table to the primary key of another table, (ii) exact string match and partial string match between words in the questions and words in column or table names and (iii) implicit links between a table and its columns. Some of these edges are illustrated in different colours on the input side in Figure 2. Because there can be more than one type of edge between two tokens to be modelled, this input is technically a multi-graph.
How do these edges help predict the correct SQL? Let's return to the examples.
In Example One (Figure 2), the word "employee" in the question exactly matches the table name employee, so a special edge for an exact match is created in this input graph during preprocessing. For a graph neural network or relation-aware transformer that encodes a graph by propagating information along edges, this link creates a potential pathway for information from the representations of the columns (employee_id, name, age, city) of table employee to contextualize the representation of the question token "employee", and vice versa. This makes it more likely for employee_id, name, age, city to be selected over columns of other tables when predicting a column for the condition "employee under 30".
The second example is more interesting. The question mentions the table name shop explicitly, while the table employee can easily be inferred from the column mention age. However, for hiring there is no textual evidence from the question, direct or indirect, that the SQL query should involve it. The only way to infer its involvement is through the foreign key links and the fact that otherwise shop and employee are disconnected and cannot be joined. This potential reasoning process is illustrated in Figure 3.
Now that we understand how the (multi-)graph structure would help the semantic parsing model, let's formalize what the encoder does at a high level. Let $\mathcal{S}=\{s_1,\dots, s_{\lvert \mathcal{S} \rvert}\}$ denote the schema elements, consisting of tables and their columns, and use $Q=q_1\dots q_{\lvert Q \rvert}$ to denote the sequence of words in the question. Let $\mathcal{G}=\langle\mathcal{V}, \mathcal{E}\rangle$ denote the multi-graph with edge sets $\mathcal{E}$. The encoder, $f_{\text{enc}}$, maps $\mathcal{G}$ to a joint representation $ \mathcal{H} =$ $\{\phi^q_1, \ldots,\phi^q_{\lvert Q \rvert} \} \cup \{\phi^s_1, \ldots,\phi^s_{\lvert \mathcal{S} \rvert} \}$. The fully connected portion of the multi-graph can be well modelled by a transformer (see [link] for our blog series on Transformers). Indeed, one can flatten the schema into a linear string, with the tokens belonging to different column or table names separated by a special token like "[SEP]" and concatenate this string with the question string before feeding into a pre-trained model such as BERT. The use of pre-trained BERT (or other variants) here is how implicit common sense knowledge is embodied in the semantic parser. To model information propagation along the special sparse edges of the multi-graph, we can then feed the BERT output embeddings into a relation-aware transformer (Wang et al., 2019). There are a few subtle details omitted here, which we will give some pointers for at the end of this article.
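A minimal sketch of the flattening step follows. The "[SEP]"-separated layout is one plausible serialization; actual systems differ in the exact token layout:

```python
# Serialize question + schema into a single string for a pre-trained encoder.
question = "Which cities have more than one employee under 30?"
schema = {
    "employee": ["employee_id", "name", "age", "city"],
    "shop": ["shop_id", "name", "location", "district",
             "number_products", "manager_name"],
}

parts = [question]
for table, columns in schema.items():
    parts.append(table)      # the table name ...
    parts.extend(columns)    # ... followed by its column names
flat_input = " [SEP] ".join(parts)

# flat_input is what a BERT-style encoder would consume; relation-aware
# transformer layers on top then handle the sparse special edges.
print(flat_input)
```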
If we model SQL queries as linear sequences of text tokens on the output side, it is not easy to leverage the SQL grammar knowledge. During inference, one could use a grammar validator program to check if a generated sequence is legal; however, the neural network is still not using this information during training for better generalization. Furthermore, the grammar not only captures what is illegal but also how SQL expressions can be composed. Leveraging this prior knowledge will significantly improve the learning efficiency from a small number of examples. Therefore, we want to cast the problem as generating the abstract syntax tree of SQL queries.
A common approach to predicting an abstract syntax tree (AST) is to use a grammar-based transition system like TranX (Yin and Neubig, 2018), which decomposes the generation of the AST into a sequence of actions. The neural model learns to predict the action sequence, and the transition system then constructs the AST from the predicted actions. Finally, another deterministic routine maps the AST into the linear string format of SQL, aka the surface code (Figure 4).
Figure 5 shows a snippet of the SQL grammar for TranX used by our Turing by Borealis AI system. It is specified in an abstract syntax description language (ASDL). It is similar to a context-free grammar, but more powerful, with each production rule's right-hand side being a function call signature with strongly-typed arguments. The type names are non-terminal symbols in the grammar, for which there are further production rules. This grammar is specific to the programming language of interest, or a subset of features in a programming language, and needs to be developed by a human expert.
One production rule, for example, reads stmt = Intersect(query_expr lbody, query_expr rbody) | ..., where the remaining alternatives are shown in Figure 5.
The transition system converts between an AST and its AST-constructing action sequence, leveraging a grammar like the one in Figure 5. The transition system starts at the root of the AST and derives the action sequence by a top-down, left-to-right depth-first traversal of the tree. At each step, it generates one of the possible parametrized action types.
For cross-domain text-to-SQL parsing, the action types can include: (1) ApplyRule[$r$] which applies a production rule $r$ of the grammar to the latest generated node in the AST; (2) Reduce which marks the complete generation of a subtree corresponding to a function call (in the ASDL grammar); (3-4) SelectTable[$t$] and SelectColumn[$c$] which, respectively, choose a table $t$ and a column $c$ from the database schema $\mathcal{S}$; (5) CopyToken[$k$] which copies a token $q_k$ from the user question $Q$; (6) GenToken[$l$] which generates a token $w_l$ from a vocabulary. In practice, with careful design, it is possible to simplify and avoid SelectTable and GenToken, which is part of the technical novelties in our Turing by Borealis AI system.
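To make the transition-system idea concrete, here is a toy replay of an action sequence into an AST. The grammar, rule names, and reduced action set are invented for illustration and are far smaller than the real TranX/Turing grammar:

```python
# Each rule maps a non-terminal head to an ordered list of child types.
GRAMMAR = {
    "r_select": ("query", ["agg", "table"]),  # query -> Select(agg, table)
    "r_count":  ("agg",   []),                # agg   -> Count()
}

def execute(actions):
    """Rebuild the AST by a top-down, left-to-right depth-first replay."""
    it = iter(actions)
    def build():
        kind, arg = next(it)
        if kind == "ApplyRule":
            head, children = GRAMMAR[arg]
            return (head, arg, [build() for _ in children])
        if kind == "SelectTable":
            return ("table", arg, [])
        raise ValueError(f"unknown action type: {kind}")
    return build()

ast = execute([("ApplyRule", "r_select"),
               ("ApplyRule", "r_count"),
               ("SelectTable", "employee")])
print(ast)
```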
Before training, the TranX system first converts the surface SQL code to the AST representation using a deterministic domain-specific routine. Then, leveraging the grammar, it converts the AST into the action sequence (Figure 4). The actual training is then standard maximum likelihood with teacher-forcing, which you can read about in this tutorial. At each step, the model predicts the correct action conditioned on the ground-truth partial action sequence up to that point, as well as the encoder representation $\mathcal{H}$ of the question and schema. Most of the action types are parameterized by some argument, for example, production rule $r$ for ApplyRule, column $c$ for SelectColumn. The model first predicts the action type, then conditioned on the ground-truth action type (regardless of the predicted one), predicts the argument.
The inference process builds upon beam-search, which you can learn more about in this tutorial. The difference here is that the beam-search is guided by the grammar and the transition system. This grammar-guided beam-search decoding sounds complex and indeed has many tedious implementation details, but it is conceptually simple: at each step of decoding, for each partial sequence in the beam, the transition system tracks all action types and arguments that are legal according to the grammar; the neural net can only select from those options. Once beam-search produces multiple action sequences, the transition system converts them to ASTs, then converts them to surface SQL code strings using the domain-specific post-processing routine as illustrated in Figure 5.
Besides the neural attention over the encoder representation, some other weak reasoning using the grammar happens here during beam-search inference. By tracking multiple partial trees (implicitly, via partial action sequences), a hypothesis scored high at the beginning could drop sharply because its high-probability continuation could violate the grammar. As a result, another partial tree that is less likely at first, could become more plausible and eventually be the top prediction.
Inferring and encoding special edges in the multi-graph: we saw some examples of special edges between a question token and schema word, but there could be other types of links. For example, suppose a question word happens to match a database value in some column. In that case, this is evidence that this question word has an implicit relationship to the corresponding column. More generally, these edges are inferred using heuristic pre-processing rules, in a process known as schema linking. The relation-aware transformer layers can learn to deal with some degree of noise in the links. For more details, please see the original RAT paper (Wang et al., 2019).
We also discussed using a pre-trained transformer to encode the implicit fully-connected part of the multi-graph, in conjunction with RAT-based modelling of the sparse special edges. But the pre-trained transformer builds contextualized representation for subword tokens, whereas table and column names are usually phrases. The Appendix section of Xu et al. (2021a) contains more information about how these models can be pieced together.
Modelling tables implicitly through columns: as mentioned previously, it is possible to drop the SelectTable action altogether. The idea is to globally uniquely identify the columns rather than using the column names only. We can add the table representation to all of its column representations on the input encoding side before feeding into the RAT layers. On the output side, we can give each column a globally unique ID for SelectColumn. Then the table can be inferred deterministically from the predicted columns during post-processing. This design choice simplifies the relation learning for encoding and makes the output action sequences shorter. On some rare occasions, this becomes an over-simplification causing failures for some complex queries, for instance, when there are multiple self-joins. Please see Xu et al. (2021b) for more details.
TranX transition system and leveraging tree structures in the neural decoder: so far, we have only shown how TranX works at a high level, but readers interested in using the framework for semantic parsing should consult Yin and Neubig (2018) for more details. In particular, the TranX transition system exposes the topology of the AST to the linear action-sequence decoding process via the parent frontier field. The parent does not always correspond to the immediately preceding step in the action sequence. Yet, it is important to condition directly on its representation during decoding, which is known as parent feeding.
Handling values in the question: in Example One, the value $30$ from the question is exactly the token needed in the condition part of the SQL statement, so it can be just copied over. However, in general, this might not always be the case. Most models use a combination of generation and copy attention. But as mentioned earlier, Turing (Xu et al., 2021b) simplifies away the generation and only performs the copy action. The idea is that during training, the model learns to identify the question text span providing evidence for the value, which significantly simplifies the learning problem and reduces overfitting. A heuristic search-based post-processor is responsible for producing the actual value to be used in the SQL at inference time.
Training and generalization when the model is deep and the dataset small: using relation-aware transformer layers on top of pre-trained transformers like BERT or RoBERTa can quickly make the overall model very deep and hard to train. The usual rules-of-thumb for optimizing transformers are to use a large batch size, make the model shallower, or both. However, our recent work finds a way to train ultra-deep transformers ($48$ layers) using a small batch size and this improves the model generalization, especially for hard cases. This technique allowed us to place No. $1$ on the Spider Leaderboard (Exact Set Match without Values) ^{1}.
Beyond teacher-forcing maximum likelihood: other sequence learning methods could also be used in theory, such as scheduled sampling or beam-search optimization (BSO). See our work on training a globally normalized semantic parsing model using a method similar to BSO (Huang et al., 2021), which works on some simple datasets, but not yet on complex ones like Spider.
Other approaches for semantic parsing: there are other promising approaches that do not follow the framework presented in this blog. For cross-database semantic parsing, Rubin and Berant (2021) abandon autoregressive decoding and instead perform semi-autoregressive bottom-up semantic parsing. The advantage is that at each step of decoding, the model both conditions on and predicts semantically meaningful sub-programs, instead of semantically vacuous partial trees. The method performs competitively on Spider, which is impressive; moreover, it potentially has better compositional or out-of-distribution generalization. On the other end of the spectrum, if our goal is not cross-domain text-to-SQL but generic code generation, then our recent ACL work (Norouzi et al., 2021) shows that leveraging a large monolingual corpus of programming-language source code enables a simple transformer-based seq-to-seq baseline to perform competitively. Note that this does not contradict our earlier discussion of simple seq-to-seq baselines being unable to perform well in cross-database semantic parsing.
Explaining the queries: an essential feature of Turing by Borealis AI is the ability to explain the predicted queries to non-technical users. This allows people to use their own judgment to pick out which of the top hypotheses is more likely to be correct. Please check out our paper (Xu et al., 2021b) for more information about the explanation system.
1 As of June-02-2021, the time of publication of this blog. Our entry is "DT-Fixup SQL-SP + RoBERTa (DB content used) Borealis AI".
The current dominant paradigm in natural language processing is to build enormous language models based on the transformer architecture. Models such as GPT3 contain billions of parameters, which collectively describe joint statistics of spans of text and have been extremely successful over a wide range of tasks.
However, these models do not explicitly take advantage of the structure of language; native speakers understand that a sentence is syntactically valid, even if it is meaningless. Consider how Colorless green ideas sleep furiously feels like valid English, whereas Furiously sleep ideas green colorless does not ^{1}. This structure is formally described by a grammar, which is a set of rules that can generate an infinite number of sentences, all of which sound right, even if they mean nothing.
In this blog, we review earlier work that models grammatical structure. We introduce the CYK algorithm which finds the underlying syntactic structure of sentences and forms the basis of many algorithms for linguistic analysis. The algorithms are elegant and interesting for their own sake. However, we also believe that this topic remains important in the age of large transformers. We hypothesize that the future of NLP will consist of merging flexible transformers with linguistically informed algorithms to achieve systematic and compositional generalization in language processing.
Our discussion will focus on context-free grammars or CFGs. These provide a mathematically precise framework in which sentences are constructed by recursively combining smaller phrases usually referred to as constituents.^{2} Sentences under a CFG are analyzed through a tree-structured derivation in which the sentence is recursively generated phrase by phrase (figure 1).
The problem of recovering the underlying structure of a sentence is known as parsing. Unfortunately, natural language is ambiguous and so there may not be a single possible meaning; consider the sentence I saw him with the binoculars. Here, it is unclear whether the subject or the object of the sentence holds the binoculars (figure 2). To cope with this ambiguity, we will need weighted and probabilistic extensions to the context free grammar (referred to as WCFGs and PCFGs respectively). These allow us to compute a number that indicates how "good" each possible interpretation of a sentence is.
In Part I of this series, we introduce the notion of a context-free grammar and consider how to parse sentences using this grammar. We then describe the CYK recognition algorithm, which identifies whether a sentence can be parsed under a given grammar. In Part II, we introduce the aforementioned weighted context-free grammars and show how the CYK algorithm can be adapted to compute different quantities, including the most likely sentence structure. In Part III, we introduce probabilistic context-free grammars and present the inside-outside algorithm, which efficiently computes the expected counts of the rules in the grammar over all possible analyses of a sentence. These expected counts are used in the E-step of an expectation-maximization procedure for learning the rule weights.
Before tackling these problems, we'll first discuss the properties of a parse tree (figure 3). The root of the tree is labelled as "sentence" or "start". The leaves or terminals of the tree contain the words of the sentence. The parents of these leaves are called pre-terminals and contain the part-of-speech (POS) categories of the words (e.g., verb, noun, adjective, preposition). Words are considered to be from the same category if a sentence is still syntactically valid when they are substituted. For example: The {sad, happy, excited, bored} person in the coffee shop. This is known as the substitution test. Above the pre-terminals, the word categories are collected together into phrases.
There are three more important things to notice. First, the verb phrase highlighted in magenta has three children. However, there is no theoretical limit to this number. We could easily add the prepositional phrases in the garden and under a tree and so on. The complexity of the sentence is limited in practice by human memory and not by the grammar itself.
Second, the grammatical structure allows for recursion. In this example, a verb phrase is embedded within a second verb phrase, which itself is embedded in a third verb phrase. Finally, we note that the parse tree disambiguates the meaning of the sentence. From a grammatical point of view, it could be that it was the bone that was enjoying every moment. However, it is clear that this is not the case, since the verb phrase corresponding to enjoying is attached to the verb phrase corresponding to eating and not the bone (see also figure 2).
In this section, we present a more formal treatment of context-free grammars. In the following section, we'll elucidate the main ideas with an example.
A language is a set of strings. Each string is a sequence of terminal symbols. In figure 3 these correspond to individual words, but more generally they may be abstract tokens. The set of terminals $\Sigma=\{\mbox{a,b,c},\ldots\}$ is called an alphabet or lexicon. There is also a set $\mathcal{V}=\{\mbox{A,B,C},\ldots\}$ of non-terminals, one of which is the special start symbol $S$.
Finally, there is a set $\mathcal{R}$ of production or re-write rules. These relate the non-terminal symbols to each other and to the terminals. Formally, these grammar rules form a finite relation $\mathcal{R}\subseteq \mathcal{V} \times (\Sigma \cup \mathcal{V})^*$ where $*$ denotes the Kleene star. Informally, this means that each grammar rule is an ordered pair where the first element is a non-terminal from $\mathcal{V}$ and the second is any possible string containing terminals from $\Sigma$ and non-terminals from $\mathcal{V}$. For example, B$\rightarrow$ab, C$\rightarrow$Baa and A$\rightarrow$AbCa are all production rules.
A context-free grammar is the tuple $G=\{\mathcal{V}, \Sigma, \mathcal{R}, S\}$ consisting of the non-terminals $\mathcal{V}$, terminals $\Sigma$, production rules $\mathcal{R}$, and start symbol $S$. The associated context-free language consists of all possible strings of terminals that are derivable from the grammar.
Informally, the term context-free means that each production rule starts with a single non-terminal symbol. Context-free grammars are part of the Chomsky hierarchy of languages which contains (in order of increasing expressiveness) regular, context-free, context-sensitive, and recursively enumerable grammars. Each differs in the family of production rules that are permitted and the complexity of the associated parsing algorithms (table 1). As we shall see, context-free languages can be parsed in $O(n^{3})$ time where $n$ is the number of observed terminals. Parsing more expressive grammars in the Chomsky hierarchy has exponential complexity. In fact, context-free grammars are not considered to be expressive enough to model real languages. Many other types of grammar have been invented that are both more expressive and parseable in polynomial time, but these are beyond the scope of this post.
| Language | Recognizer | Parsing Complexity |
|---|---|---|
| Recursively enumerable | Turing machine | undecidable |
| Context-sensitive | Linear-bounded automaton | PSPACE |
| Context-free | Pushdown automaton | $O(n^3)$ |
| Regular | Finite-state automaton | $O(n)$ |
Table 1. The Chomsky hierarchy of languages. As the grammar-type becomes simpler, the required computation model (recognizer) becomes less general and the parsing complexity decreases.
Consider the context free grammar that generated the example in figure 4. Here, the set of non-terminals $\mathcal{V}=\{\mbox{VP, PP, NP, DT, NN, VBZ, IN,}\ldots\}$ contains the start symbol, phrases, and pre-terminals. The set of terminals $\Sigma=\{$The, dog, is, in, the, garden, $\ldots \}$ contains the words. The production rules in the grammar associated with this example include:
Of course, a full model of English grammar contains many more non-terminals, terminals, and rules than we observed in this single example. The main point is that the tree structure in figure 4 can be created by the repeated application of a finite set of rules.
Later on, we will describe the CYK recognition algorithm. This takes a sentence and a context-free grammar and determines whether there is a valid parse tree that can explain the sentence in terms of the production rules of the CFG. However, the CYK algorithm assumes that the context free grammar is in Chomsky Normal Form (CNF). A grammar is in CNF if it only contains the following types of rules:
\begin{align}
\tag{binary non-terminal}
\text{A} &\rightarrow \text{B} \; \text{C} \\
\tag{unary terminal}
\text{A} &\rightarrow \text{a} \\
\tag{delete sentence}
\text{S} &\rightarrow \epsilon
\end{align}
where A,B, and C are non-terminals, a is a token, S is the start symbol and $\epsilon$ represents the empty string.
The binary non-terminal rule means that a non-terminal can create exactly two other non-terminals. An example is the rule $S \rightarrow \text{NP} \; \text{VP}$ in figure 4. The unary terminal rule means that a non-terminal can create a single terminal. The rule $\text{NN} \rightarrow$ $\text{dog}$ in figure 4 is an example. The delete sentence rule allows the grammar to create empty strings, but in practice we avoid $\epsilon$-productions.
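These three rule shapes can be checked mechanically. In the toy representation below we adopt the convention (an assumption of this sketch, not part of CNF itself) that non-terminal symbols are uppercase and terminals are lowercase:

```python
# Represent a grammar as (lhs, rhs) rules and test whether every rule has
# one of the three Chomsky-Normal-Form shapes.
def is_cnf(rules, start="S"):
    for lhs, rhs in rules:
        binary = len(rhs) == 2 and all(sym.isupper() for sym in rhs)
        unary_terminal = len(rhs) == 1 and rhs[0].islower()
        delete_sentence = lhs == start and rhs == []  # S -> epsilon
        if not (binary or unary_terminal or delete_sentence):
            return False
    return True

cnf_rules = [("S", ["NP", "VP"]), ("NN", ["dog"]), ("VP", ["VBZ", "PP"])]
ternary   = [("VP", ["VBG", "NP", "VP"])]  # the ternary rule from figure 3

print(is_cnf(cnf_rules), is_cnf(ternary))  # True False
```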
Notice that the parse tree in figure 3 is not in Chomsky Normal Form because it contains the rule $\text{VP} \rightarrow \text{VBG} \; \text{NP} \; \text{VP}$. For the case of natural language processing, there are two main tasks to convert a grammar to CNF:
Both of these operations introduce new non-terminals into the grammar. Indeed, in the former case, we may introduce different numbers of new non-terminals depending on which children we choose to combine. It can be shown that in the worst-case scenario, converting a CFG into an equivalent grammar in Chomsky Normal Form results in a quadratic increase in the number of rules. Note also that although the CNF transformation is the most popular, it is not the only, or even the most efficient, option.
Given a grammar in Chomsky Normal Form, we can turn our attention to parsing a sentence. The parsing algorithm will return a valid parse tree like the one in figure 6 if the sentence has a valid analysis, or indicate that there is no such valid parse tree.
It follows that one way to characterize a parsing algorithm is that it searches over the set of all possible parse trees. A naive approach might be to exhaustively search through these trees until we find one that obeys all of the rules in the grammar and yields the sentence. In the next section, we'll consider the size of this search space, find that it is very large, and draw the conclusion that this brute-force approach is intractable.
The parse tree of a sentence of length $n$ consists of a binary tree with $n-1$ internal nodes, plus another $n$ nodes connecting the pre-terminals to the terminals. The number of binary trees with $n$ internal nodes can be calculated via the recursion:
\begin{equation}
C_{n} = \sum_{i=0}^{n-1}C_{i}C_{n-1-i}. \tag{1}
\end{equation}
The intuition for this recursion is illustrated in figure 7. This series of integers is known as the Catalan numbers and can be written out explicitly as:
\begin{equation}
C_n = \frac{(2n)!}{(n+1)!n!}. \tag{2}
\end{equation}
Needless to say the series grows extremely fast:
\begin{equation}
1, 1, 2, 5, 14, 42, 132, 429, 1430, 4862, 16796, 58786, \ldots \tag{3}
\end{equation}
Consider the example sentence I saw him with the binoculars. Here there are only $C_5=42$ possible trees, but these must be combined with the non-terminals in the grammar (figure 8). In this example, for each of the 42 trees, each of the six leaves must contain one of four possible parts of speech (DT, NN, P, VBD) and each of the five non-leaves must contain one of four possible clause types (S, NP, VP, PP), and so there are $42 \times 4^{6} \times 4^{5} = 176{,}160{,}768$ possible parse trees.
Even this minimal example had a very large number of possible explanations. Now consider that (i) the average sentence length written by Charles Dickens was 20 words, with an associated $C_{20}=6,564,120,420$ possible binary trees and (ii) that there are many more parts of speech and clause types in a realistic model of the English language. It's clear that there are an enormous number of possible parses and it is not practical to employ exhaustive search to find the valid ones.
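The recursion and the closed form are easy to check numerically. Here is a minimal sketch in Python (the function names are our own):

```python
from math import factorial

def catalan_rec(n, memo={0: 1}):
    # Recursion from equation 1: C_n = sum_{i=0}^{n-1} C_i * C_{n-1-i}
    if n not in memo:
        memo[n] = sum(catalan_rec(i) * catalan_rec(n - 1 - i) for i in range(n))
    return memo[n]

def catalan_closed(n):
    # Closed form from equation 2: C_n = (2n)! / ((n+1)! n!)
    return factorial(2 * n) // (factorial(n + 1) * factorial(n))

print([catalan_closed(n) for n in range(12)])  # 1, 1, 2, 5, 14, 42, 132, ...
print(42 * 4**6 * 4**5)                        # 176160768 labelled trees for the example
print(catalan_closed(20))                      # 6564120420
```

Memoizing the recursion keeps it linear in the number of distinct sub-problems, which is exactly the trick that the CYK algorithm below exploits at a larger scale.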
The CYK algorithm (named after inventors John Cocke, Daniel Younger, and Tadao Kasami) was the first polynomial time parsing algorithm that could be applied to ambiguous CFGs (i.e., CFGs that allow multiple derivations for the same string). In its simplest form, the CYK algorithm solves the recognition problem; it determines whether a string $\mathbf{w}$ can be derived from a grammar $G$. In other words, the algorithm takes a sentence and a context-free grammar and returns TRUE if there is a valid parse tree or FALSE otherwise.
This algorithm sidesteps the need to try every possible tree by exploiting the fact that a complete sentence is made by combining sub-clauses, or equivalently, a parse tree is made by combining sub-trees. A tree is only valid if its sub-trees are also valid. The algorithm works from the bottom of the tree upwards, storing possible valid sub-trees as it goes and building larger sub-trees from these components without the need to re-calculate them. As such, CYK is a dynamic programming algorithm.
The CYK algorithm is just a few lines of pseudo-code:
 0  # Initialize data structure
 1  chart[1...n, 1...n, 1...V] := FALSE
 2
 3  # Use unary rules to find possible parts of speech at pre-terminals
 4  for p := 1 to n  # start position
 5      for each unary rule A -> w_p
 6          chart[1, p, A] := TRUE
 7
 8  # Main parsing loop
 9  for l := 2 to n  # sub-string length
10      for p := 1 to n-l+1  # start position
11          for s := 1 to l-1  # split width
12              for each binary rule A -> B C
13                  chart[l, p, A] := chart[l, p, A] OR
                                      (chart[s, p, B] AND chart[l-s, p+s, C])
14
15  return chart[n, 1, S]
The algorithm is simple, but is hard to understand from the code alone. In the next section, we will present a worked example which makes this much easier to comprehend. Before we do that though, let's make some high level observations. The algorithm consists of four sections:
The complexity of the algorithm is easy to discern. Lines 9-13 contain three for loops depending on the sentence length $n$ (lines 9-11) and one more depending on the number of grammar rules $|R|$ (line 12). This gives us a complexity of $\mathcal{O}(n^3 \cdot |R|)$.
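The pseudocode translates almost line for line into Python. The sketch below uses a small grammar of our own devising (roughly modelled on figure 8, but not the exact grammar from the figure):

```python
def cyk_recognize(words, unary, binary, start="S"):
    """CYK recognition: returns True iff `words` can be derived from `start`.
    unary:  dict mapping each word to the set of pre-terminals A with A -> word
    binary: list of rules (A, B, C) meaning A -> B C
    """
    n = len(words)
    # chart[l][p] holds the non-terminals that derive the sub-string of
    # length l+1 starting at position p (a 0-indexed version of chart[l, p, A])
    chart = [[set() for _ in range(n)] for _ in range(n)]
    for p, w in enumerate(words):                       # unary rules (lines 4-6)
        chart[0][p] = set(unary.get(w, ()))
    for l in range(2, n + 1):                           # sub-string length (line 9)
        for p in range(n - l + 1):                      # start position (line 10)
            for s in range(1, l):                       # split width (line 11)
                for A, B, C in binary:                  # binary rules (lines 12-13)
                    if B in chart[s - 1][p] and C in chart[l - s - 1][p + s]:
                        chart[l - 1][p].add(A)
    return start in chart[n - 1][0]

# Illustrative minimal grammar, assumed sufficient for the example sentence
unary = {"I": {"NP"}, "saw": {"VBD", "NN"}, "him": {"NP"},
         "with": {"P"}, "the": {"DT"}, "binoculars": {"NN"}}
binary = [("S", "NP", "VP"), ("VP", "VBD", "NP"), ("VP", "VP", "PP"),
          ("PP", "P", "NP"), ("NP", "DT", "NN"), ("NP", "NP", "PP")]

print(cyk_recognize("I saw him with the binoculars".split(), unary, binary))   # True
```

The four nested loops mirror lines 9-13 of the pseudocode, making the $\mathcal{O}(n^3 \cdot |R|)$ complexity visible at a glance.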
To make the CYK algorithm easier to understand, we'll use the worked example of parsing the sentence I saw him with the binoculars. We already saw in figure 2 that this sentence has two possible meanings. We'll assume the minimal grammar from figure 8 that is sufficient to parse the sentence. In the next four subsections we'll consider the four parts of the algorithm in turn.
Figure 9 shows the chart for our example sentence, which is itself shown in an extra row under the chart. Each element in the chart corresponds to a sub-string of the sentence. The first index of the chart $l$ represents the length of that sub-string and the second index $p$ is the starting position. So, the element of the chart at position (4,2) represents the sub-string of length four that starts at word two, which is saw him with the. We do not use the upper triangular portion of the chart.
The CYK algorithm runs through each of the elements of the chart, starting with strings of length 1 and working through each position and then moving to strings of length 2, and so on, until we finally consider the whole sentence. This explains the loops in lines 9 and 10. The third loop considers possible binary splits of the strings and is indexed by $s$. For position (4,2), the string can be split into saw $|$ him with the ($s=1$, blue boxes), saw him $|$ with the ($s=2$, green boxes), or saw him with $|$ the ($s=3$, red boxes).
Now that we understand the meaning of the chart and how it is indexed, let's run through the algorithm step by step. First we deal with strings of length $l=1$ (i.e., the individual words). We run through each unary rule $A \rightarrow w_p$ in the grammar and set these elements to TRUE in the chart (figure 10). There is only one ambiguity here, which is the word saw which could be a past tense verb or a noun. This process corresponds to lines 5-6 of the algorithm.
In the main loop, we consider sub-strings of increasing length starting with pairs of words and working up to the full length of the sentence. For each sub-string, we determine if there is a rule of the form $\text{A}\rightarrow \text{B}\;\text{C}$ that can derive it.
We start with strings of length $l=2$. These can obviously only be split in one way. For each position, we note in the chart all the non-terminals A that can be expanded to generate the parts of speech B and C in the boxes corresponding to the individual words (figure 11).
In the next outer loop, we consider sub-strings of length $l=3$ (figure 12). For each position, we search for a rule that can derive the three words. However, now we must also consider two possible ways to split the length 3 sub-string. For example, for position $(3,2)$ we attempt to derive the sub-string saw him with. This can be split as saw him $|$ with corresponding to positions (2,2)$|$(1,4) which contain VP and P respectively. However, there is no rule of the form $\text{A}\rightarrow\text{VP}\;\text{P}$. Likewise, there is no rule that can derive the split saw $|$ him with since there was no rule that could derive him with. Consequently, we leave position $(3,2)$ empty. However, at position $(3,4)$, the rule $\text{PP}\rightarrow \text{P}\;\text{NP}$ can be applied as discussed in the legend of figure 12.
We continue this process, working upwards through the chart for longer and longer sub-strings (figure 13). For each sub-string length, we consider each position and each possible split and add non-terminals to the chart where we find an applicable rule. We note that position $(5,2)$ in figure 13b corresponding to the sub-string saw him with the binoculars is particularly interesting. Here there are two possible rules $\text{VP}\rightarrow\text{VP}\;\text{PP}$ and $\text{VP}\rightarrow\text{VBD}\;\text{NP}$ that both come to the conclusion that the sub-string can be derived by the non-terminal VP. This corresponds to the original ambiguity in the sentence.
When we reach the top-most row of the chart ($l=6$), we are considering the whole sentence. At this point, we discover whether the start symbol $S$ can be used to derive the entire string. If there is such a rule, the sentence is valid under the grammar, and if there isn't, then it is not. This corresponds to the final line of the CYK algorithm pseudocode. For this example, we use the rule $S\rightarrow \text{NP}\;\text{VP}$ to explain the entire string with the noun phrase I and the verb phrase saw him with the binoculars and conclude that the sentence is valid under this context free grammar.
The basic CYK algorithm just returns a binary variable indicating whether the sentence can be parsed or not under a grammar $G$. Often we are interested in retrieving the parse tree(s). Figure 14 superimposes the paths that led to the start symbol in the top left from figures 11-13. These paths form a shared parse forest; two trees share the black paths, but the red paths are only in the first tree and the blue paths are only in the second tree. These two trees correspond to the two possible meanings of the sentence (figure 15).
These two figures show that it is trivial to reconstruct the parse tree once we have run the CYK algorithm, as long as we cache the inputs to each position in the chart. We simply start from the start symbol at position (6,1) and work back down through the tree. At any point where there are two inputs into a cell, there is an ambiguity, and we must enumerate all combinations of these ambiguities to find all the valid parses. This technique is similar to other dynamic programming problems (e.g., the canonical implementation of the longest common subsequence algorithm computes only the length of the subsequence, but backpointers allow us to retrieve the subsequence itself).
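Caching these inputs amounts to storing backpointers in the chart. The sketch below is a hypothetical extension of the recognizer: each cell maps a non-terminal to the list of ways it was derived, so all parse trees can be read back. The grammar is again an illustrative one of our own, not the exact grammar of figure 8:

```python
from itertools import product

def cyk_parse(words, unary, binary, start="S"):
    """CYK with backpointers: each chart cell maps a non-terminal to the
    list of ways it can be derived, so every parse tree can be read back."""
    n = len(words)
    chart = [[{} for _ in range(n)] for _ in range(n)]
    for p, w in enumerate(words):
        for A in unary.get(w, ()):
            chart[0][p][A] = [(w,)]                     # leaf derivation
    for l in range(2, n + 1):
        for p in range(n - l + 1):
            for s in range(1, l):
                for A, B, C in binary:
                    if B in chart[s - 1][p] and C in chart[l - s - 1][p + s]:
                        # Backpointer: remember which two cells produced A
                        chart[l - 1][p].setdefault(A, []).append(
                            ((s - 1, p, B), (l - s - 1, p + s, C)))

    def trees(l, p, A):
        # Enumerate all trees for A over this cell by following backpointers;
        # each ambiguity multiplies the number of trees
        out = []
        for deriv in chart[l][p][A]:
            if len(deriv) == 1:                         # leaf: (A, word)
                out.append((A, deriv[0]))
            else:
                (lb, pb, B), (lc, pc, C) = deriv
                for tb, tc in product(trees(lb, pb, B), trees(lc, pc, C)):
                    out.append((A, tb, tc))
        return out

    return trees(n - 1, 0, start) if start in chart[n - 1][0] else []

# Illustrative minimal grammar (assumed, roughly modelled on figure 8)
unary = {"I": {"NP"}, "saw": {"VBD", "NN"}, "him": {"NP"},
         "with": {"P"}, "the": {"DT"}, "binoculars": {"NN"}}
binary = [("S", "NP", "VP"), ("VP", "VBD", "NP"), ("VP", "VP", "PP"),
          ("PP", "P", "NP"), ("NP", "DT", "NN"), ("NP", "NP", "PP")]

parses = cyk_parse("I saw him with the binoculars".split(), unary, binary)
print(len(parses))   # 2: one tree per reading of the ambiguous sentence
```

The two returned trees correspond to the two readings of the sentence, matching the shared parse forest described above.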
The previous example was relatively unambiguous. For a bit of fun, we'll also show the results on the famously difficult-to-understand sentence Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo. Surprisingly, this is a valid English sentence. To comprehend it, you need to know that (i) buffalo is a plural noun describing animals that are also known as bison, (ii) Buffalo is a city, and (iii) buffalo is a verb that means "to intimidate". The meaning of the sentence is thus:
Bison from the city Buffalo that are intimidated by other bison from the city Buffalo, themselves intimidate yet other bison from the city Buffalo.
To make things even harder, we'll assume that the text is written in all lower case, and so each instance of buffalo could correspond to any of the three meanings. Can you come up with a grammar that assigns an intuitive analysis to this sentence? In figure 16 we provide a minimal but sufficient grammar that allows the CYK algorithm to find a single, reasonable parse tree for this strange sentence.
In this part of the blog, we have described the CYK algorithm for the recognition problem; the algorithm determines whether a string can be generated by a given grammar. It is a classic example of a dynamic programming algorithm that explores an exponential search space in polynomial time by storing intermediate results. Another way of thinking about the CYK algorithm, from a less procedural and more declarative perspective, is that it is performing logical deduction. The axioms are the grammar rules, and we are presented with facts, which are the words. For a given sub-string length, we deduce new facts by applying the rules of the grammar $G$ to the facts (or axioms) that we had previously deduced about shorter sub-strings. We keep applying the rules to derive new facts about which sub-strings are derivable from $G$, with the goal of proving that $S$ derives the sentence.
Note that we have used an unconventional indexing for the chart in our description. For a more typical presentation, consult these slides.
In part II, we will consider assigning probabilities to the production rules, so when the parse is ambiguous, we can assign probabilities to the different meanings. We will also consider the inside-outside algorithm which helps learn these probabilities.
^{1} This famous example was used in Syntactic Structures by Noam Chomsky in 1957 to motivate the independence of syntax and semantics.
^{2} The idea that sentences are recursively built up from smaller coherent parts dates back at least to a Sanskrit sutra of around 4000 verses known as Aṣṭādhyāyī written by Pāṇini probably around the 6th-4th century BC.
Each of the three blogs in this series focuses on a different aspect of the transformer. In Part I, we introduce self-attention, which is the core mechanism that underpins the transformer architecture. We then describe transformers themselves and how they can be used as encoders, decoders, or encoder-decoders using well-known examples such as BERT and GPT3. This discussion will be suitable for someone who knows machine learning, but who is not familiar with the transformer.
Part II considers how to adapt the transformer to cope with longer sequences, different methods for encoding the positions of elements in the sequence, and other modifications to the basic architecture. We also discuss the relationship between the transformer and other models. This will be suitable for a reader who knows the basics about transformers and wants to learn more.
Transformer models are difficult to train from scratch in practice. Part III details the tricks that are required to ensure that training does not fail. We conclude with a discussion of our recent work on how to modify the training procedure to fine-tune deep transformers when only sparse training data is available. This discussion will be suitable for practitioners who want to learn more about how to work effectively with transformers.
To motivate the transformer, consider the following passage:
The restaurant refused to serve me a ham sandwich, because it only cooks vegetarian food. In the end, they just gave me two slices of bread. Their ambience was just as good as the food and service.
We would like to build a network that can process this passage into a representation that is suitable for downstream tasks. For example, we might want to classify the review as positive or negative, or answer questions such as "Does the restaurant serve steak?". Two problems immediately present themselves:
First, the input representation will be large. Typically, we might describe each of the 37 words with an embedding vector of length 1024 and so the network input will be of length $37 \times 1024 = 37{,}888$ even for this small passage. A more realistically sized input might have hundreds or even thousands of words. It's not clear that a standard fully-connected network would be practical here; it would need a very large number of parameters, and it's not obvious how to adapt such a network to inputs containing different numbers of words. This suggests the need for some kind of parameter sharing that is analogous to the use of convolutions in image processing.
Second, language is fundamentally ambiguous; it is not clear from the syntax alone that the pronoun it refers to the restaurant and not the ham sandwich. To fully understand the text, the word it should somehow be connected to the word restaurant. In the parlance of transformers, the former word should pay attention to the latter. This implies that there must be connections between the words, and that the strength of these connections will depend on the words themselves. Moreover, these connections need to extend across large spans of the text; the word their in the last sentence also refers to the restaurant.
In conclusion, we have argued that a model that can process real world text (i) will use parameter sharing so that it can cope with long input passages of differing lengths, and (ii) will contain connections between word representations that depend on the words themselves. The transformer acquires both of these properties by using dot-product self-attention.
A standard neural network layer $\bf nn[\bullet]$ takes a $D\times 1$ input $\mathbf{x}$ and applies a linear transformation followed by a static non-linearity like a rectified linear unit (ReLU)
\begin{equation}
\bf nn[\mathbf{x}] = \bf ReLU[\boldsymbol\Phi\tilde{\mathbf{x}}], \tag{1}
\end{equation}
to return a modified output vector. Here, the notation $\tilde{\mathbf{x}}$ indicates that we have appended the constant value 1 to the end of $\mathbf{x}$ so that the parameter matrix $\boldsymbol\Phi$ can also represent the offsets in the linear transformation. For simplicity, we'll assume that we use this trick every time we apply a linear transformation and just write $\boldsymbol\Phi\mathbf{x}$ from now on.
In contrast, a self-attention block $\bf sa[\bullet]$ takes $I$ inputs $\mathbf{x}_{i}$, each of dimension $D\times 1$ and returns $I$ output vectors. In the context of NLP, each of the inputs $\mathbf{x}_{i}$ will represent a word or part of a word. For input $\mathbf{x}_{i}$, the self-attention block returns the weighted sum:
\begin{equation}
\mbox{sa}[\mathbf{x}_{i}] = \sum_{j=1}^{I}a[\mathbf{x}_{i}, \mathbf{x}_{j}]\boldsymbol\Phi_v \mathbf{x}_{j}. \tag{2}
\end{equation}
The sum is over all of the inputs $\{\mathbf{x}_{i}\}_{i=1}^{I}$ after applying the same linear transformation $\boldsymbol\Phi_{v}$ to each. We will refer to the parameters $\boldsymbol\Phi_{v}$ as value weights and the product $\boldsymbol\Phi_v \mathbf{x}_{i}$ as computing the values for the $i^{th}$ input. These values are weighted by the terms $a[\mathbf{x}_{i}, \mathbf{x}_{j}]$ which are scalars that represent the attention of input $\mathbf{x}_{i}$ to input $\mathbf{x}_{j}$.
In the following sections, we will look at this in more detail by breaking this computation down into two parts. First we'll consider the computation of the values and their subsequent weighting as described in equation 2. Then we'll describe how to compute the attention weights $a[\mathbf{x}_{i}, \mathbf{x}_{j}]$.
The same value weights $\boldsymbol\Phi_{v}$ are applied to each input $\mathbf{x}_{i}$ and because of this parameter sharing, far fewer parameters are required than if we had used a fully-connected network (figure 1). Moreover, this part of the computation is easy to extend to different sequence lengths.
The attention weights $a[\mathbf{x}_{i}, \mathbf{x}_{j}]$ combine the values from different inputs. They are also sparse in a sense, since there is only one weight for each ordered pair of inputs $(\mathbf{x}_{i},\mathbf{x}_{j})$, regardless of the size of these inputs. It follows that the number of attention weights increases with the square of the sequence length $I$, but is independent of the length $D$ of each input $\mathbf{x}_{i}$.
In the previous section, we saw that the outputs are the result of two chained linear transformations; the values $\boldsymbol\Phi_{v}\mathbf{x}_{i}$ are computed independently for each input $\mathbf{x}_{i}$ and these vectors are combined linearly by the attention weights $a[\mathbf{x}_{i},\mathbf{x}_{j}]$. However, the overall self-attention computation is non-linear because the attention weights are themselves non-linear functions of the input.
More specifically, the attention weight $a[\mathbf{x}_{i},\mathbf{x}_{j}]$ depends on the dot-product $(\boldsymbol\Phi_{q}\mathbf{x}_{i})^{T}\boldsymbol\Phi_{k}\mathbf{x}_{j}$ between $\mathbf{x}_{i}$ and $\mathbf{x}_{j}$ after each has been transformed by different linear transformations $\boldsymbol\Phi_{q}$ and $\boldsymbol\Phi_{k}$ respectively. To complete the computation of the attention weight, these dot-product similarities are passed through a softmax function:
\begin{eqnarray}\label{eq:sattention2}
a[\mathbf{x}_{i},\mathbf{x}_{j}] &=& \mbox{softmax}_{j}\left[(\boldsymbol\Phi_{q}\mathbf{x}_{i})^{T}\boldsymbol\Phi_{k}\mathbf{x}_{j} \right]\nonumber\\
&=& \frac{\exp\left[(\boldsymbol\Phi_{q}\mathbf{x}_{i})^{T}\boldsymbol\Phi_{k}\mathbf{x}_{j} \right]}{\sum_{j=1}^{I}\exp\left[(\boldsymbol\Phi_{q}\mathbf{x}_{i})^{T}\boldsymbol\Phi_{k}\mathbf{x}_{j} \right]} \tag{3}
\end{eqnarray}
and so for each $\mathbf{x}_{i}$ they are positive and sum to one (figure 2). For obvious reasons, this is known as dot-product self-attention.
The vectors $\boldsymbol\Phi_{q}\mathbf{x}_{i}$ and $\boldsymbol\Phi_{k}\mathbf{x}_{i}$ are known as the queries and keys respectively. These names were inherited from the field of information retrieval and have the following interpretation: the output for input $\mathbf{x}_{i}$ receives a weighted sum of values $\boldsymbol\Phi_v \mathbf{x}_{j}$, where the weights $a[\mathbf{x}_{i}, \mathbf{x}_{j}]$ depend on the similarity between the query vector $\boldsymbol\Phi_q \mathbf{x}_{i}$ and the key vector $\boldsymbol\Phi_k \mathbf{x}_{j}$.
To summarize, we see that for input $\mathbf{x}_{i}$, the output is a weighted sum of the same linear transformation $\boldsymbol\Phi_{v}$ of all of the inputs, where these weights are positive and sum to one. The weights depend on a measure of similarity between input $\mathbf{x}_{i}$ and the other inputs. The computation as a whole is non-linear due to the dot-product and softmax operation used to compute these weights. Consequently, there is no need for a pointwise non-linearity like a ReLU.
Note that this mechanism fulfils the requirements that we laid out earlier. First, there is a single shared set of parameters $\boldsymbol\Phi_{v},\boldsymbol\Phi_{q},\boldsymbol\Phi_{k}$. This is independent of the number of inputs $I$ and so the network can be applied to different sequence lengths. Second, the connections between the inputs (words) depend on the input representations themselves via the computed attention values.
The above computation can be written in a more compact form if we let the $I$ inputs $\mathbf{x}_{i}$ form the rows of the $I\times D$ matrix $\mathbf{X}$:
\begin{equation}
\mbox{Sa}[\mathbf{X}] = \mbox{Softmax}[(\mathbf{X}\boldsymbol\Phi_{q})(\mathbf{X}\boldsymbol\Phi_{k})^{T}]\mathbf{X}\boldsymbol\Phi_{v}, \tag{4}
\end{equation}
where the function $\mbox{Softmax}[\bullet]$ takes a matrix and performs the softmax operation independently on each of its rows (figure 3). Note that here the matrices $\boldsymbol\Phi_{v}, \boldsymbol\Phi_{q}$ and $\boldsymbol\Phi_{k}$ are the transposes of those in the original formulation.
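The matrix form of equation 4 is only a few lines of NumPy. The sketch below uses small, illustrative dimensions and random parameter matrices; the function names are our own:

```python
import numpy as np

def softmax_rows(A):
    # Row-wise softmax; subtracting the row max improves numerical stability
    A = A - A.max(axis=1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=1, keepdims=True)

def self_attention(X, Phi_q, Phi_k, Phi_v):
    """Equation 4: Sa[X] = Softmax[(X Phi_q)(X Phi_k)^T] X Phi_v."""
    attention = softmax_rows((X @ Phi_q) @ (X @ Phi_k).T)   # I x I attention matrix
    return attention @ (X @ Phi_v)                          # weighted sums of the values

rng = np.random.default_rng(0)
I, D, d = 5, 8, 4            # illustrative sizes: 5 inputs of dimension 8
X = rng.standard_normal((I, D))
Phi_q, Phi_k, Phi_v = (rng.standard_normal((D, d)) for _ in range(3))
print(self_attention(X, Phi_q, Phi_k, Phi_v).shape)   # (5, 4): one output row per input
```

Each row of the attention matrix is positive and sums to one, so each output is a convex combination of the values.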
In the previous section, we described the dot-product self-attention mechanism. Here, we introduce three extensions that are all almost always used in practice.
Observant readers will have noticed that the above mechanism loses some important information; the computation will be the same, regardless of the order of the inputs $\mathbf{x}_{i}$. However, if the inputs correspond to the words in a sentence, it's clear that the order matters. To incorporate information about position, we add a matrix $\boldsymbol\Pi$ which is the same size as the input matrix that encodes this information.
The position matrix $\boldsymbol\Pi$ may either be chosen manually or learned. It may be added to the initial word embeddings only, or it may be added at every layer of the network. Sometimes it is only added to $\mathbf{X}$ in the computation of the queries and keys. The contents of this matrix and other variations will be discussed in detail in part II of this blog; however, the main idea is that there is a unique vector added to each input $\mathbf{x}_{i}$ that lets the system know its position in the sequence.
The dot products in the attention computation may have very large magnitudes. This can move the arguments to the softmax function into a region where the largest value dominates to a large degree and consequently, the associated gradients are very small and the model becomes hard to train. To resolve this issue, it is typical to scale the computed attention values by the square root of dimension $d_{q}$ of the queries and keys (i.e., the number of columns in $\boldsymbol\Phi_{q}$ and $\boldsymbol\Phi_{k}$ which must be the same). This gives:
\begin{equation}
\mbox{Sa}[\mathbf{X}] =\mbox{Softmax}\left[\frac{(\mathbf{X}\boldsymbol\Phi_{q})(\mathbf{X}\boldsymbol\Phi_{k})^{T}}{\sqrt{d_{q}}}\right]\mathbf{X}\boldsymbol\Phi_{v}. \tag{5}
\end{equation}
This is known as scaled dot-product self-attention.
Practitioners usually apply multiple self-attention mechanisms in parallel, and this is known as multi-head self attention. The $h^{th}$ self-attention mechanism or head can be written as:
\begin{equation}
\mbox{Sa}_{h}[\mathbf{X}] =\mbox{Softmax}\left[\frac{(\mathbf{X}\boldsymbol\Phi_{qh})(\mathbf{X}\boldsymbol\Phi_{kh})^{T}}{\sqrt{d_{q}}}\right]\mathbf{X}\boldsymbol\Phi_{vh}, \tag{6}
\end{equation}
where we have different parameters $\boldsymbol\Phi_{qh}$, $\boldsymbol\Phi_{kh}$ and $\boldsymbol\Phi_{vh}$ for each head. The outputs of these self-attention mechanisms are concatenated and another linear transform $\boldsymbol\Phi_{c}$ is applied to combine them (figure 4):
\begin{equation}
\mbox{MhSa}[\mathbf{X}] = \left[\mbox{Sa}_{1}[\mathbf{X}]\;\mbox{Sa}_{2}[\mathbf{X}]\;\ldots\;\mbox{Sa}_{H}[\mathbf{X}] \right]\boldsymbol\Phi_{c}. \tag{7}
\end{equation}
This appears to be necessary to make the transformer work well in practice. It has been speculated that multiple heads make the self-attention network more robust to bad initializations. The fact that trained models only seem to depend on a subset of the heads lends credence to this speculation.
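Equations 6 and 7 can be sketched directly in NumPy. As before, the dimensions are illustrative and the parameter matrices are random stand-ins:

```python
import numpy as np

def softmax_rows(A):
    # Row-wise softmax; subtracting the row max improves numerical stability
    A = A - A.max(axis=1, keepdims=True)
    e = np.exp(A)
    return e / e.sum(axis=1, keepdims=True)

def multihead_self_attention(X, Phi_q, Phi_k, Phi_v, Phi_c):
    """Equations 6 and 7: H scaled dot-product heads, concatenated and
    combined by the linear transform Phi_c."""
    d_q = Phi_q[0].shape[1]
    heads = []
    for Pq, Pk, Pv in zip(Phi_q, Phi_k, Phi_v):     # one pass per head (equation 6)
        A = softmax_rows((X @ Pq) @ (X @ Pk).T / np.sqrt(d_q))
        heads.append(A @ (X @ Pv))
    return np.concatenate(heads, axis=1) @ Phi_c    # concatenate and mix (equation 7)

rng = np.random.default_rng(1)
I, D, H, d = 5, 16, 4, 4        # illustrative sizes: 5 tokens, D=16, 4 heads of dimension 4
X = rng.standard_normal((I, D))
Phi_q = [rng.standard_normal((D, d)) for _ in range(H)]
Phi_k = [rng.standard_normal((D, d)) for _ in range(H)]
Phi_v = [rng.standard_normal((D, d)) for _ in range(H)]
Phi_c = rng.standard_normal((H * d, D))
print(multihead_self_attention(X, Phi_q, Phi_k, Phi_v, Phi_c).shape)   # (5, 16)
```

A common design choice, used here, is to set the per-head dimension $d$ to $D/H$ so that the concatenated output has the same width as the input.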
Self-attention is just one part of a larger transformer layer. This layer consists of a multi-head self-attention unit (which allows the word representations to interact with each other) followed by a fully connected network $\mbox{mlp}[\mathbf{x}_{i}]$ (that operates separately on each word representation). Both of these units are residual networks (i.e., their output is added back to the original input). In addition, it is typical to add a LayerNorm operation after both the self-attention and fully connected networks. The complete layer can be described by the following series of operations:
\begin{eqnarray}
\mathbf{X} &\leftarrow& \mathbf{X} + \mbox{MhSa}[\mathbf{X}] \nonumber \\
\mathbf{X} &\leftarrow& \mbox{Layernorm}[\mathbf{X}] \hspace{3cm}\nonumber\\
\mathbf{x}_{i} &\leftarrow& \mathbf{x}_{i}+\mbox{mlp}[\mathbf{x}_{i}] \hspace{3.6cm}\forall\; i\in\{1\ldots I\}\nonumber\\
\mathbf{X} &\leftarrow& \mbox{Layernorm}[\mathbf{X}], \tag{8}
\end{eqnarray}
where the column vectors $\mathbf{x}_{i}$ are transposed to form the rows of the full data matrix $\mathbf{X}$ in the first stage. In a real system, the data would pass through a series of these layers.
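The four operations in equation 8 can be sketched as follows. To keep the example short, we use a single attention head in place of the full multi-head unit, omit the learned LayerNorm scale and offset, and choose small illustrative dimensions:

```python
import numpy as np

def layernorm(X, eps=1e-5):
    # Normalize each row to zero mean and unit variance
    # (the learned scale and offset are omitted for brevity)
    return (X - X.mean(axis=1, keepdims=True)) / (X.std(axis=1, keepdims=True) + eps)

def transformer_layer(X, Phi_q, Phi_k, Phi_v, Phi_1, Phi_2):
    """Equation 8, with a single attention head standing in for MhSa."""
    d_q = Phi_q.shape[1]
    A = (X @ Phi_q) @ (X @ Phi_k).T / np.sqrt(d_q)
    A = np.exp(A - A.max(axis=1, keepdims=True))
    A = A / A.sum(axis=1, keepdims=True)                 # attention matrix, rows sum to one
    X = layernorm(X + A @ (X @ Phi_v))                   # self-attention, residual, LayerNorm
    X = layernorm(X + np.maximum(X @ Phi_1, 0) @ Phi_2)  # per-token MLP, residual, LayerNorm
    return X

rng = np.random.default_rng(2)
I, D = 5, 16
X = rng.standard_normal((I, D))
out = transformer_layer(X,
                        rng.standard_normal((D, D)), rng.standard_normal((D, D)),
                        rng.standard_normal((D, D)), rng.standard_normal((D, 4 * D)),
                        rng.standard_normal((4 * D, D)))
print(out.shape)   # (5, 16)
```

Note that the MLP is applied row by row (each token independently), so a single matrix multiplication implements the $\forall\, i$ in equation 8.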
Now that we have a good understanding of self-attention and the transformer layer, let's walk through a typical modern NLP processing pipeline.
A text processing pipeline begins with a tokenizer. This splits the text into smaller constituent units (tokens) from a fixed vocabulary that can be processed by the subsequent network. In the discussion above, we have implied that these tokens are words, but there are several difficulties with this.
One approach would be just to use letters and punctuation marks as the vocabulary, but this would mean splitting text into a large number of very small parts and requiring the subsequent network to re-learn the relations between them.
In practice, a compromise between using letters and full words is used, and the final vocabulary will include both common words and short parts of words from which larger and less frequent words can be composed. The vocabulary is computed using a method such as byte pair encoding that uses ideas from text compression methods; essentially it greedily merges commonly-occurring sub-strings based on their frequency. This type of approach is known as a sub-word tokenizer.
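The greedy merging idea can be illustrated with a toy sketch. This is our own simplification: a real sub-word tokenizer also handles word boundaries, byte-level fallback, and frequency thresholds:

```python
from collections import Counter

def bpe_merges(corpus, num_merges):
    """Toy sketch of byte pair encoding: repeatedly merge the most
    frequent adjacent pair of symbols across the corpus."""
    words = [list(w) for w in corpus]          # start from individual characters
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for w in words:
            pairs.update(zip(w, w[1:]))        # count adjacent symbol pairs
        if not pairs:
            break
        a, b = pairs.most_common(1)[0][0]      # most frequent pair
        merges.append(a + b)
        for w in words:                        # apply the merge left to right
            i = 0
            while i < len(w) - 1:
                if w[i] == a and w[i + 1] == b:
                    w[i:i + 2] = [a + b]
                else:
                    i += 1
    return merges, words

merges, words = bpe_merges(["low", "low", "lot"], 2)
print(merges)   # ['lo', 'low']
print(words)    # [['low'], ['low'], ['lo', 't']]
```

After two merges, the common word low has become a single token while the rarer lot is still composed of two sub-word units, which is exactly the behaviour we want from a sub-word vocabulary.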
Each different token within the vocabulary is mapped to a word embedding. Importantly, the same token always maps to the same embedding. These embeddings are learned along with the rest of the unknown parameters in the network. A typical embedding size is 1024 and a typical total vocabulary size is 30,000, and so even before the main network, there are a lot of parameters to learn.
These embeddings are then collected to form the rows of the input matrix $\mathbf{X}$ and the positional encoding $\boldsymbol\Pi$ may be added at this stage.
Finally, the input embedding matrix $\mathbf{X}$ is passed to a series of transformer layers, which we'll refer to as a transformer network from now on. There are three main types of transformer network. First, a transformer network can be used as an encoder. Here, the goal is to transform the text into a representation that can support a variety of language tasks, such as sentiment analysis or question answering. An example of an encoder model is the BERT model.
Second, a transformer network can be used as a decoder. Here, the goal of the network is to generate a new token that continues the input text. An example of a decoder model is GPT3.
Finally, transformer networks can be used to build encoder-decoder models. These are used in sequence-to-sequence tasks, which take one text string and convert it to another. For example, in machine translation, an input sentence in English might be processed by the encoder, and the decoder then generates the translated sentence in French. An example of an encoder-decoder model can be found in the paper where transformers were first introduced.
We'll now consider each of these three variations in turn.
BERT is an encoder model that uses a vocabulary of 30,000 tokens. The tokens are converted to 1024-dimensional word embeddings and passed through 24 transformer layers. Each of these contains a self-attention layer with 16 heads, and for each head the queries, keys, and values are of dimension 64 (i.e., the matrices $\boldsymbol\Phi_{vh},\boldsymbol\Phi_{qh},\boldsymbol\Phi_{kh}$ are of size $1024\times 64$). The dimension of the hidden layer in the neural network layer of the transformer is 4096. The total number of parameters is $\sim 340$ million. This sounds like a lot, but is tiny by modern standards.
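We can sanity-check the quoted parameter count with a rough back-of-envelope calculation. This is our own estimate and deliberately ignores biases, LayerNorm parameters, and positional embeddings:

```python
# Rough parameter count for the BERT configuration described above
vocab, D, layers, heads, d, D_mlp = 30_000, 1024, 24, 16, 64, 4096

embeddings = vocab * D                      # token embedding table
attention = heads * 3 * D * d + D * D       # queries, keys, values + output projection
mlp = D * D_mlp + D_mlp * D                 # two linear layers in the fully connected network
total = embeddings + layers * (attention + mlp)
print(f"{total:,}")                         # 332,709,888, close to the quoted ~340M
```

The remaining gap to the published figure of roughly 340 million is accounted for by the terms we ignored.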
Encoder models are trained in two stages. During pre-training, the parameters of the transformer architecture are learned using self-supervision from a large corpus of text. The goal here is for the model to learn general information about the statistics of language. In the fine-tuning stage, the resulting network is adapted to solve a particular task, using a smaller body of supervised training data. We'll now discuss each of these stages in turn for the BERT model.
In the pre-training stage, the network is trained using self-supervision. This allows the use of enormous amounts of data without the need for manual labels. For BERT, the self-supervision task consists of predicting missing words from sentences from a large internet corpus (figure 7)^{1}. During training, the maximum input length is 512 tokens and the batch size is 256. The system is trained for 1,000,000 steps, which is roughly 50 epochs of the 3.3-billion-word corpus.
Trying to predict missing words forces the transformer network to understand something of the syntax of the language. For example, it might learn that the adjective red is often found before nouns like house or car but never before a verb like shout. It also allows the model to learn some superficial common sense about the world. For example, after training, the model will assign a higher probability to the missing word train in the sentence The <mask> pulled into the station than it would to the word peanut. However, there are persuasive arguments that the degree of "understanding" that this type of model can ever have is limited.
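To make the objective concrete, here is a minimal sketch of scoring a masked-word prediction; the tiny vocabulary and the probability values are invented purely for illustration:

```python
import math

# The model's predicted distribution over the vocabulary at a masked
# position is scored by the negative log-probability of the true word.
# Vocabulary and probabilities are made up for this example.

vocab = ["train", "peanut", "house"]

def masked_word_loss(predicted_probs, true_word):
    return -math.log(predicted_probs[vocab.index(true_word)])

# A model that has learned something about the world should assign a
# lower loss to "train" in "The <mask> pulled into the station".
loss_train = masked_word_loss([0.7, 0.1, 0.2], "train")
loss_peanut = masked_word_loss([0.7, 0.1, 0.2], "peanut")
```

Minimizing this loss over millions of masked positions is what drives the model to absorb the statistics of language.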
In the fine-tuning stage, the parameters of the model are adjusted to specialize it to a particular task. This usually involves adding an extra layer on top of the transformer network, to convert the collection of vectors $\mathbf{x}_{1},\ldots \mathbf{x}_{I}$ associated with the input tokens to the desired format of the output. Examples include:
Text classification: In BERT, a special token known as the $<$cls$>$ token (short for classification token) is placed at the start of each string during pre-training. For text classification tasks like sentiment analysis, the vector associated with this token is mapped to a single number and passed through a logistic sigmoid. This creates a number between 0 and 1 that can be interpreted as the probability that the sentiment is positive, and the system is fine-tuned to maximize the probability of the correct label (figure 8a).
Word classification: In named entity recognition, the goal is to classify each individual word as an entity type (e.g., person, place, organization, or no-entity). To this end, the vector $\mathbf{x}_{i}$ associated with each token in the input sequence is mapped to a $K\times 1$ vector, where $K$ is the number of entity types, and passed through a softmax function to produce class probabilities; the system is fine-tuned to maximize the probability of the correct entity type at each position (figure 8b).
Text span prediction: In the SQuAD 1.1 question answering task, both the question and a passage from Wikipedia containing the answer are input into the system. BERT is then used to predict the text span in the passage that contains the answer. Each token associated with the Wikipedia passage maps to two numbers that indicate how likely it is that the text span begins and ends at this location. The two resulting sets of numbers are put through two softmax functions, and the probability of any text span being the answer can then be derived by combining the probabilities of starting and ending at the appropriate places.
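The span-scoring step can be sketched as follows; the logit values are invented, and the brute-force search over start/end pairs is illustrative rather than how BERT is actually implemented:

```python
import math

# Start and end logits (one per passage token) are each put through a
# softmax; a candidate span's probability is the product of starting at
# position i and ending at position j, with i <= j.

def softmax(logits):
    m = max(logits)  # subtract max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def best_span(start_logits, end_logits):
    p_start = softmax(start_logits)
    p_end = softmax(end_logits)
    best, best_p = (0, 0), 0.0
    for i, ps in enumerate(p_start):
        for j in range(i, len(p_end)):  # span must end at or after start
            if ps * p_end[j] > best_p:
                best, best_p = (i, j), ps * p_end[j]
    return best

span = best_span([0.1, 2.0, 0.3, -1.0], [-0.5, 0.0, 3.0, 0.2])
```

With these made-up logits, the most probable span starts at token 1 and ends at token 2.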
In this section, we present a high-level description of GPT3, which is an example of a transformer decoder model. The basic architecture is extremely similar to the encoder model in that it consists of a series of transformer layers that operate on learned word embeddings. However, the goal is different. The encoder aimed to build a representation of the text that could be fine-tuned to solve a more specific NLP task. In contrast, the decoder has a single purpose: to generate the next token in a provided sequence. By iterating this procedure, the model can produce a body of coherent text.
More specifically, GPT3 constructs a language model. For any sentence it aims to model the joint probability $Pr(t_1,t_2,\ldots t_{N})$ of the $N$ observed tokens and it does this by factorizing this joint probability into an auto-regressive sequence:
\begin{equation}
Pr(t_{1},t_{2},\ldots t_{N}) = \prod_{n=1}^{N}Pr(t_{n}|t_{1}\ldots t_{n-1}). \tag{9}
\end{equation}
This is easiest to understand with a concrete example. Consider the sentence It takes great personal courage to let yourself appear weak. For simplicity, let's assume that the tokens are the full words. The probability of the full sentence is:
$Pr$(It takes great personal courage to let yourself appear weak) $=$
$Pr$(It) $\cdot$ $Pr$(takes$|$It) $\cdot$ $Pr$(great$|$It takes) $\cdot$ $Pr$(personal$|$It takes great) $\cdot$
$Pr$(courage$|$It takes great personal) $\cdot$ $Pr$(to$|$It takes great personal courage) $\cdot$
$Pr$(let$|$It takes great personal courage to) $\cdot$
$Pr$(yourself$|$It takes great personal courage to let) $\cdot$
$Pr$(appear$|$It takes great personal courage to let yourself) $\cdot$
$Pr$(weak$|$It takes great personal courage to let yourself appear). (10)
This demonstrates the connection between the probabilistic formulation of the cost function and the next token prediction task.
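In code, this factorization simply means summing per-token log-probabilities; the conditional probability values below are invented purely for illustration:

```python
import math

# The log of the joint probability of a sequence is the sum of the
# conditional log-probabilities of each token given its predecessors.

def sequence_log_prob(conditional_probs):
    return sum(math.log(p) for p in conditional_probs)

# One made-up conditional probability per token in the sentence.
lp = sequence_log_prob([0.1, 0.3, 0.2, 0.4, 0.5, 0.3, 0.6, 0.2, 0.7])
```

Working in log space avoids the numerical underflow that would result from multiplying many small probabilities directly.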
When we train a decoder model, we aim to maximize the log-probability of the input text under the auto-regressive language model. Ideally, we would like to pass in the whole sentence and compute all of the log probabilities and their gradients simultaneously. However, this poses a problem: if we pass in the full sentence, then the term computing $\log$ $[$ $Pr$(great$|$It takes) $]$ will have access to both the answer great and the right context personal courage to let yourself appear weak.
To see how to avoid this problem, recall that in a transformer network, the tokens only interact in the self-attention layers. This implies that the problem can be resolved by ensuring that the attention to the answer and the right context are zero. This can be achieved by setting the appropriate dot products to negative infinity before they are passed through the $\mbox{softmax}[\bullet]$ function. This idea is known as masked self-attention.
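A minimal sketch of the masking step: scores for positions after the current token are set to $-\infty$ before the softmax, so their attention weights come out exactly zero. This is an illustrative implementation, not production code:

```python
import math

# Apply a causal mask to an I x I matrix of query-key dot products,
# then softmax each row. Position i can only attend to positions <= i.

def causal_softmax(scores):
    I = len(scores)
    out = []
    for i in range(I):
        # Future positions get -inf, which exp() maps to exactly 0.
        row = [scores[i][j] if j <= i else -math.inf for j in range(I)]
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out

A = causal_softmax([[0.0, 1.0, 2.0],
                    [1.0, 0.0, 1.0],
                    [2.0, 1.0, 0.0]])
```

The first row attends only to the first token; every row still sums to one, so the result remains a valid attention distribution.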
The overall decoder transformer network operates as follows. The input text is tokenized and the tokens are converted to embeddings. The embeddings are passed into the transformer network, but now the transformer layers use masked self-attention so that they can only attend to the current and previous tokens. You can think of each of the output embeddings as representing a partial sentence, and for each the goal is to predict the next token in the sequence. Consequently, after the transformer layers, a linear layer maps each word embedding to the size of the vocabulary, followed by a $\mbox{softmax}[\bullet]$ function that converts these values to probabilities. We aim to maximize the sum of the log probabilities of the next token in the ground truth sequence at every position (figure 9).
To generate from the model, we start with an input sequence of text (which might be just the special $<$start$>$ token) and feed this into the network which then outputs the probability of the next token. We can then either pick the most likely token or sample from this probability distribution. The new extended sequence can be fed back into the decoder network which outputs the probability distribution over the next token and in this way, we can generate large bodies of text. The computation can be made quite efficient as prior embeddings do not interact with subsequent ones due to the masked self-attention and so a lot of the earlier computation can be recycled as we generate subsequent tokens.
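The generation loop can be sketched as follows; the model here is a made-up stub that stands in for the trained decoder network, and greedy selection is just one of the sampling strategies mentioned below:

```python
# A hypothetical stand-in for the decoder: returns a distribution over
# the next token given the tokens so far. Invented for illustration.
def fake_model(tokens):
    return {"token": 0.9, "<end>": 0.1} if len(tokens) < 4 else {"<end>": 1.0}

def generate(model, prompt, max_len=10):
    tokens = list(prompt)
    while len(tokens) < max_len:
        probs = model(tokens)
        next_token = max(probs, key=probs.get)  # greedy: pick most likely
        tokens.append(next_token)
        if next_token == "<end>":
            break
    return tokens

out = generate(fake_model, ["<start>"])
```

Replacing the greedy `max` with a draw from the distribution gives sampling; a real system would also cache the earlier attention computations rather than re-running the model from scratch each step.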
In practice, there are many strategies such as beam-search and top-K sampling that can be added to help make the output text more coherent. These are discussed in detail in our previous blog on natural language generation. Here's an example of completing text from the GPT2 model:
Borealis AI is a great place to work because there are a lot of people there that are passionate about this kind of technology, like me. There are some challenges for developers but it also gives us a great opportunity to work on different problems.
where the text provided to the model is in green and the generated text is in blue.
GPT3 applies these ideas on a massive scale. The sequence lengths are 2048 tokens long and, since multiple spans of 2048 tokens are processed at once, the total batch size is 3.2 million tokens. There are 96 transformer layers (some of which implement a sparse version of attention), each of which processes a word embedding of size 12288. There are 96 heads in the self-attention layers and the value, query, and key dimension is 128. It was trained on 300 billion tokens and contains a total of 175 billion parameters. It is interesting to note that since text is extremely compressible, the model actually contains more than enough capacity to memorize the entire data set.
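The quoted parameter count can be roughly reproduced using the common approximation of about $12d^{2}$ parameters per transformer layer (the four attention projections plus the two MLP matrices); this is a back-of-the-envelope estimate, not the exact published count:

```python
# Rough parameter estimate for a decoder-only transformer:
# ~12 * d_model^2 parameters per layer (attention + MLP), ignoring
# embeddings, biases, and layer norms.

def approx_params(n_layers, d_model):
    return 12 * d_model**2 * n_layers

total = approx_params(96, 12288)
print(round(total / 1e9))  # prints 174
```

The estimate of roughly 174 billion lands close to the quoted 175 billion; the token embedding matrix and other small terms account for most of the remainder.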
One surprising property of learning a model on this scale is that it can perform many tasks without the need for fine-tuning. If we provide several examples of correct question answer pairs, and then another question, it can often answer the final question correctly just by completing the sequence. One example of this is correcting English grammar:
Poor English input: I eated the purple berries.
Good English output: I ate the purple berries.
Poor English input: Thank you for picking me as your designer. I’d appreciate it.
Good English output: Thank you for choosing me as your designer. I appreciate it.
Poor English input: The mentioned changes have done. or I did the alteration that you requested. or I changed things you wanted and did the modifications.
Good English output: The requested changes have been made. or I made the alteration that you requested. or I changed things you wanted and made the modifications.
Poor English input: I’d be more than happy to work with you in another project.
Good English output: I'd be more than happy to work with you on another project.
The text containing the paired examples in green was provided as context for GPT3, and the system then generated the correct answer in blue. This phenomenon extends to many situations, including generating code snippets based on natural language descriptions, simple arithmetic, translating between languages, answering questions about text passages, and many more. Consequently, it has been argued that enormous language models are few-shot learners; they can learn to do novel tasks based on just a few examples. However, in practice the performance is erratic, and it is not clear to what extent the model is extrapolating from learned examples rather than merely interpolating or even copying verbatim.
The original transformer paper focused on translation between languages, which is an example of a sequence-to-sequence task. Their original architecture was an encoder-decoder model that (as the name suggests) combines both encoder and decoder models.
Consider the example of translating from English to French. The encoder receives the sentence in English and processes it through a series of transformer layers to create an output representation for each token. The decoder receives the sentence in French and processes it through a series of transformer layers that use masked self-attention. However, these transformer layers also attend to the output of the encoder. Consequently, each French output word is conditioned not only on the previous output words, but also on the entire English sentence that is being translated (figure 10).
In practice this is achieved by modifying the transformer layer. The original transformer layer in the decoder (figure 5) consisted of a masked self-attention layer followed by a multi-layer perceptron applied individually to each embedding. In between these we now introduce a second attention layer, in which the embeddings attend to the output embeddings from the encoder. This uses a version of self-attention where the queries $\mathbf{X}_{d}\boldsymbol\Phi_{q}$ are computed from the decoder embeddings $\mathbf{X}_{d}$, and the keys $\mathbf{X}_{e}\boldsymbol\Phi_{k}$ and values $\mathbf{X}_{e}\boldsymbol\Phi_{v}$ are generated from the encoder embeddings $\mathbf{X}_{e}$:
\begin{equation}
{\bf Sa}[\mathbf{X}_{d},\mathbf{X}_{e}] = {\bf Softmax}\left[\frac{(\mathbf{X}_{d}\boldsymbol\Phi_{q})(\mathbf{X}_{e}\boldsymbol\Phi_{k})^{T}}{\sqrt{d_{q}}}\right]\mathbf{X}_{e}\boldsymbol\Phi_{v}. \tag{11}
\end{equation}
This is known as encoder-decoder attention (figure 11).
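Encoder-decoder attention can be sketched as follows; for brevity the projection matrices $\boldsymbol\Phi$ are treated as identity and the scaling is omitted, which are simplifications for illustration only:

```python
import math

# Queries come from the decoder embeddings X_d; keys and values come
# from the encoder embeddings X_e. Projections omitted for brevity.

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def cross_attention(X_d, X_e):
    out = []
    for q in X_d:  # one decoder query per row
        scores = [sum(a * b for a, b in zip(q, k)) for k in X_e]
        weights = softmax(scores)
        # Weighted sum of the encoder rows (acting as values here).
        out.append([sum(w * v[j] for w, v in zip(weights, X_e))
                    for j in range(len(X_e[0]))])
    return out

Y = cross_attention([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]])
```

Each decoder position thus receives a mixture of encoder representations, weighted by how well its query matches each encoder key.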
In this blog, we introduced the idea of self-attention and then described how this fits into the transformer architecture. We then presented the encoder, decoder, and encoder-decoder versions of this architecture. We've seen that the transformer operates on sets of high-dimensional embeddings. It has a low computational complexity per layer, and much of the computation can be performed in parallel using the matrix form. Since every input embedding interacts with every other, it can describe long-range dependencies in text. It is these characteristics that have allowed transformers to be applied in massive systems like GPT3.
In the second part of the blog we will discuss extensions of the basic transformer model. In particular, we will expand on methods to encode the position of tokens and methods to extend transformers to process very long sequences. We'll also discuss how the transformer architecture relates to other models. Finally, in the third part of this series, we will discuss the details of how to train transformer models successfully.
^{1} BERT also used a secondary task which involved predicting whether two sentences were originally adjacent in the text or not, but this only marginally improved performance.
For many of the government and industry partners in attendance at the AI4Good Lab’s Industry Night organised by CIFAR this year, the plan was to provide support, mentorship and guidance to the participants of the AI4Good Lab – a Canadian AI training initiative for women-identified STEM students. Industry experts would spend time with participants, exploring their fields of study, their goals and ambitions, and future career opportunities in the field.
While COVID-19 may have prevented attendees from being together in person, the organizers ensured everyone felt relaxed and comfortable in the virtual booths. The Borealis AI/RBC Amplify booth, for example, featured a ‘virtually comfortable’ L-shape couch, two square stools and a rectangle coffee table. Non-virtual drinks were, sadly, ‘BYOB.’
The two hostesses of the Borealis AI/RBC Amplify booth, Dr. Eirene Seiradaki, Director of Research Partnerships at Borealis AI and Rachael Rishworth from RBC Amplify, talked with the AI4Good Lab participants about the wide range of learning and career opportunities available as students continue their journey of lifelong learning – from the Borealis AI Fellowships (which support AI researchers at Canadian Universities) and Borealis AI Internships through to the RBC Amplify program, which provides interns with hands-on prototype development opportunities at the bank.
The AI4Good Lab participants certainly seemed enthusiastic to learn and share. The booth was full for the entire 2 hours – a testament to the quality of the discussions (and, perhaps, the comfort of the virtual furniture?).
Alongside other big names in AI such as CIFAR, AMII, Vector Institute, Google Canada, DeepMind, Accenture and Manulife, the Borealis AI/RBC Amplify team offered participants a view into the wide array of initiatives and opportunities available in the AI space. They also spent time answering the participants’ questions about careers in the field of AI.
Yet it wasn’t just the students of the AI4Good Lab that were learning that night; so, too, were the industry partners and booth hosts and hostesses. Virtual networking lounges were placed in between booths, creating unique spaces that encouraged fruitful discussions among all participants – students, partners and organizers. Hosts and hostesses also visited each other’s booths to talk with ecosystem partners; in fact, rumor has it that Eirene was spotted on one of the ‘virtually hipster’ stools at the Vector Institute booth, taking a few minutes to chat with good friends and their guests at the end of the event.
More importantly, perhaps, the event highlighted the future impact the participants can make in the field and in the world around them. Since the start of the AI4Good Lab program in early May, the female-identified students participating in the Lab’s two cohorts in Montreal and Edmonton have been building their AI skills and capabilities, in order to conceptualize, design and develop a prototype of an AI application for social good. It is their ideas, research and development that will shape the debate around the value and ethics of AI in the future.
Ultimately, the AI4Good Industry Night demonstrated that learning is a life-long and collaborative journey. Industry participants shared their experience and insights; the students and the organizers of the AI4Good Lab shared theirs. Everybody left the event with a renewed sense of optimism, new ideas and new network connections.
On behalf of the attendees of the AI4Good Lab Industry Night, we would like to thank Maya Marcus-Sells and Yosra Kazemi for organizing a fantastic event in the face of the continued disruption of COVID-19.
Below are just a few photos of the event. We are confident the ideas generated there will emerge into view over the coming months and years.
Canada’s AI research ecosystem has a long history of producing cutting-edge work, leading to the highest concentration of deep learning researchers and students in the world (Invest in Canada, 2017).
“At Borealis AI, we believe Canada’s continued leadership as a global destination for the study of AI requires ongoing support and investment from the business community. As one of the leading voices on AI in Canada, we are committed to helping grow the ecosystem – supporting those researchers, universities, startups and companies that are driving the next wave of exploration and innovation,” noted Dr. Kathryn Hume, Interim Head of Borealis AI.
This year’s Fellowships were awarded to students at nine Canadian universities, from Dalhousie University on the Atlantic to UBC on the Pacific. The ten Fellows – five women and five men – reflect diverse backgrounds and research areas, focusing their skills on problems that range from measuring the level of privacy in anonymous databases through to uncovering new ways to screen for prostate cancer.
“We admire the great Machine Learning research being conducted within Canada’s academic programs and research institutes like AMII, MILA and Vector Institute. And we are keen to support the young research talent flowing out of our universities. By investing in cutting-edge deep learning researchers, their universities and advisors, our goal is to build and strengthen the broader Machine Learning research ecosystem in Canada,” added Dr. Eirene Seiradaki, Director of Research Partnerships at Borealis AI.
Faculty: Dr. Martha White
Borealis AI 2021 Fellow: Vincent Liu
Research topic: Developing batch reinforcement learning algorithms with theoretical guarantees
Faculty: Dr. Purang Abolmaesumi
Borealis AI 2021 Fellow: Golara Javadi
Research topic: Applying Machine Learning to create novel techniques for prostate cancer detection
Faculty: Dr. Arash Mohammadi
Borealis AI 2021 Fellow: Parnian Afshar
Research topic: Deep learning-based radiomics for disease diagnosis
Faculty: Dr. Sageev Oore
Borealis AI 2021 Fellow: Chandramouli Shama Sastry
Research topic: Applying generative models to the identification of distribution shifts and the learning of robust representations
Faculty: Dr. Joelle Pineau
Borealis AI 2021 Fellow: Lucas Page-Caccia
Research topic: The development of neural representations that adapt to new data
Faculty: Dr. Doina Precup
Borealis AI 2021 Fellow: Veronica Chelu
Research topic: Temporal credit assignment problems in reinforcement learning
Faculty: Dr. Hans U. Boden
Borealis AI 2021 Fellow: Lindsay White
Research topic: Applying algebraic topology to measure privacy in anonymous databases
Faculty: Dr. Xiaodan Zhu
Borealis AI 2021 Fellow: Xiaoyu Yang
Research topic: Natural language reasoning and incorporating external knowledge into neural networks
Faculty: Dr. Yasutaka Furukawa
Borealis AI 2021 Fellow: Nelson Nauata
Research topic: Structured reasoning, structured generative models, geometry generation, and geometry reconstruction
Faculty: Dr. Kimon Fountoulakis
Borealis AI 2021 Fellow: Shenghao Yang
Research topic: Combining discrete and continuous optimization methods for graph-based Machine Learning
“These Borealis AI Fellowships are a strong endorsement of the hard work being done at Canada’s Universities and Machine Learning Research Institutes. More importantly, they directly support Canadian research and research teams – like those at Dalhousie University – as they strive to advance the field of Machine Learning,” added Dr. Sageev Oore, Faculty at Dalhousie University and Vector Institute and Advisor to the Dalhousie University Borealis AI 2021 Fellow.
The new cycle of Fellowship applications for the next academic year will open this fall. Please refer to our site for details and information about applying to our Graduate Fellowship program.
These fellowships are part of Borealis AI’s commitment to support Canadian academic excellence in AI and Machine Learning. They provide financial assistance for exceptional domestic and international graduate students to carry out fundamental research, as they pursue their Masters and PhDs in various fields of AI. The program is one of a number of Borealis AI initiatives designed to strengthen the partnership between academia and industry and advance the momentum of Canada’s leadership in the AI space.
To learn more visit: https://www.borealisai.com/en/about/fellowships/
“Artificial Intelligence and Machine Learning hold massive potential for people and businesses across the country and around the world. But it must also be reflective of the society we live in today and the society we hope to create in the future,” noted Dr. Kathryn Hume, Interim Head of Borealis AI. “Programs like the AI4Good Lab are instrumental in helping encourage diversity in the field, developing young talent and uncovering new ideas for applying AI to social good.”
The AI4Good Lab is an AI training initiative for women-identified STEM students that is unique in its focus on not only developing AI skills, but also on tackling diversity and inclusion in AI and inspiring the next generation of AI leaders to develop AI as a force for social good. The AI4Good Lab is partnered with Canadian non-profit research institute CIFAR and Montreal-based start-up foundation OSMO.
“Demand for this type of initiative has been amazing. Women are deeply interested in exploring careers in AI. Companies are just as interested in diversifying their AI workforce. And society is keen to ensure AI is being used to benefit everyone,” noted Dr. Doina Precup, co-founder of AI4Good Lab and a researcher at Mila (McGill University) and Deepmind. “In fact, demand is so high that we have brought together a second cohort in Edmonton in order to help more women access the skills and capabilities of the AI4Good Lab.”
Borealis AI is pleased to be expanding its participation in the AI4Good Lab program. Building on a strong foundation of collaboration throughout last year’s program, Borealis AI will be providing valuable mentorship, advice and skills transfer to participants, as well as operational, strategic and marketing support to program organizers. The Early Talent Team from RBC Technology and Operations will be providing financial support.
“We're excited to be part of a leading-edge program that helps young women develop and grow their knowledge and skills in Machine Learning and AI," added Bruce Ross, Group Head, Technology & Operations, RBC. "These capabilities hold massive potential for every industry but we need diverse talent in order to unlock their full possibilities. Supporting and mentoring women technologists through the AI4Good Lab is one way we're doing that.”
The AI4Good Lab program consists of two parts. In the first, participants attend virtual workshops and lectures focused on providing intensive machine learning training. In the second, participants use their new skills to prototype AI products to tackle a social good problem of their choosing. Last year’s projects included prototypes for applications focused on helping victims of domestic abuse, identifying bias in literature, mental health tools, a cervical cancer risk predictor, and a waste sorting app.
“Our support for this important program is not just about skills transfer. We want participants to be able to interact with female researchers who can serve as role models and mentors, providing advice on career opportunities and development in the field,” noted Dr. Eirene Seiradaki, Director of Research Partnerships at Borealis AI. “We also want to encourage and cultivate the development of the valuable ideas and applications that these women are creating through the program. Our goal is to help create the next generation of female AI developers and their ideas.”
It is exactly these types of partnerships with industry and National Research Institutes that help enrich and diversify the AI ecosystem across Canada. Indeed, besides Borealis AI, RBC, CIFAR and OSMO, the program also enjoys collaboration from AMII, Vector Institute, DeepMind, IBM and others.
"It has been encouraging to see many of Canada’s leading national and regional AI businesses embrace the program by providing not just financial support, but also mentorship and career opportunities,” added Maya Marcus-Sells, Executive Director at the AI4Good Lab. “It takes an ecosystem approach to achieve the objectives we have set. And, with leaders such as Borealis AI and RBC, Canada’s AI ecosystem is clearly stepping up to the challenge.”
Borealis AI and RBC are proud to support this year’s AI4Good Lab expanded program. Initiatives like these form an important part of the life-long learning and development we encourage in our people and our communities. To continue helping the participants of the program upgrade their skills after they graduate from the Lab, Borealis AI and RBC are excited to welcome applications from the AI4Good lab’s present and past students to the Borealis AI internships and the RBC Amplify programs for opportunities to access subject-matter experts, while working on meaningful problems and building their future careers.
“The partnership between Borealis AI, RBC and the AI4Good Lab is a win-win for us, for the AI ecosystem and for society,” added Dr. Eirene Seiradaki. “We are delighted to be playing an important role in supporting women in AI through this program and throughout their careers. We look forward to working with the great participants in this year’s cohorts, exploring the next big ideas in social good, and growing the AI4Good Lab’s alumni network.”
In the world of stable software releases, the launch day was often an event not to be missed. Development teams would gather (pre-COVID) for a ‘shipping party’ to celebrate the end of a long and arduous, yet ultimately successful journey (rumour has it that, at the Microsoft Windows 95 event, one developer rode his motorcycle through the office halls in celebration… much to the chagrin of the facilities manager).
The advent of web-based systems put an end to the ‘shipping parties’. The world has discovered agile. However, for the most part, the PM’s work did not naturally extend to include support and maintenance of a software product, especially in more standard SaaS products. Traditionally, this responsibility sat with DevOps/SRE/Engineering teams. In the context of machine learning, it is much more important for PMs to get more involved in support & maintenance, and put the plans in place early in the product lifecycle.
In ML, product managers tend to consider support and maintenance much more deeply, and here's why. For the vast majority of ML models, there will never be a state in which the model isn’t changing, even if no new lines of code are added. ML models are best considered as living, evolving organisms that require continuous support and maintenance (often right until the end of their lives). To complicate matters further, they are highly susceptible to forces outside of their control – changes in user patterns, malicious actors, or business needs for example – which can affect input data patterns. That, in turn, can impact the model’s accuracy in production.
It all comes down to developing a continuous support and maintenance mindset. This requires PMs and their research teams to start thinking about the challenges that may come up throughout the active lifespan of the product as early as possible. For example, if you can calculate how quickly the model accuracy will likely degrade in production, you can start building the right pipelines and infrastructure to trigger retraining at the appropriate intervals. Similarly, if you can test for bias, adversarial robustness, and fairness earlier in the lifecycle, you can reduce the risk and complexity of monitoring and alerting on the appropriate metrics.
In this – our final post in our Field Guide to ML Product Development – we’ll explore some of the unusual complexities of planning for and managing support and maintenance in an ML project environment. We’ll pick up on some consistent themes that have been raised throughout this series: the uncertainty of dealing with evolving models; the complexity of aligning technology and business demands; and the critical importance of understanding the data.
We’ll also offer a few tips and insights from our work at Borealis AI where, through our relationship with RBC, we have learned a lot about deploying, supporting and maintaining ML and AI models within large enterprise governance, control and technology environments.
ML systems tend to require more attention than your typical SaaS, especially when in production. Think all the usual maintenance problems that come with traditional software systems, plus new ML-specific issues that are continuously evolving. Here are five key areas that should be top of mind for PMs as they plan for the support and maintenance stage of the ML product lifecycle:
These are just some of the considerations that PMs should be thinking about early as they plan for their ML product lifecycle. But, depending on your situation, model and context, the range of issues could be much broader.
While there is much a PM can be doing to prepare for this support & maintenance phase, it helps to have these two key elements figured out very early in the process – the monitoring process and the stakeholder expectations.
Given the pace of change and the ongoing maintenance costs associated with ML products, PMs need a comprehensive live monitoring framework that enhances responsiveness and ensures long-term system reliability. It is best to think about this early on, while a PM can still influence the data streams, production impact, and stakeholder expectations.
Start thinking about stakeholder expectations early in the process. It may be helpful to build your plan around SLAs, SLOs and SLIs. SLAs (service level agreements) represent your agreement with your users. SLOs (service level objectives) are the goals you aim to achieve in pursuit of the SLA. Service level indicators (SLIs), on the other hand, allow you to measure how well your team is performing against those SLOs.
In conclusion, as this Field Guide to Machine Learning Product Development has made clear, ML products are complex, unpredictable and extraordinarily difficult to manage. This is as true in the Support and Maintenance stage as it is throughout the lifecycle. It is an exciting challenge, however, and one of the reasons behind us sharing more of what we know about doing this right.
Looking back across the articles in this Field Guide, the key takeaways for PMs would be to start by thinking about the end-to-end lifecycle, identify where early planning and research can help reduce some of that uncertainty, and start planning to get your arms around the complexity.
Those who miss the old camaraderie of the ‘shipping party’ may also want to schedule some time to celebrate milestones with their teams (motorcycles to remain in the parking lots, please).
We're Hiring!
We are scientists, engineers and product experts dedicated to forging new frontiers of AI in finance, backed by RBC. The team is growing! Please view the open roles, including Product Manager, Business Development Lead, Research Engineer - Data Systems, and more.
The analysts at Gartner think there is an 85% chance that an average AI project will fail [1]. A few clicks over, Venturebeat puts the chance of failure even higher [2]. The risk of failure can be substantial as you get to the production stage of the AI product development lifecycle. Yet it doesn’t have to be that way.
One of the key factors behind that failure rate is that product development teams tend to put off thinking about production until far too late in the process. Product managers, especially those used to more traditional methodologies (such as agile or waterfall), tend to take their time to properly scope the problem and make sure the model works. They don't tend to think through production in the same way, since in an agile or waterfall context, production is relatively straightforward.
That’s a dangerous assumption to make in a machine learning context. Indeed, PMs should allocate more time and effort to think through how their models will actually be put into production. Doing this right can make or break projects.
In this article, we’ll explore how to set your product up for success for the production stage of the ML product development lifecycle. We’ll look at what ‘good’ looks like, what challenges you might face, and how to overcome those challenges and build the right mindset to avoid becoming one of those 85% of AI projects that fail.
Let’s say your team has come up with a great model to predict which customers are going to make a purchase in the next 3 months. You proudly hand all that data over to the sales team, only to find that it will take them at least 6 months to act on that data. Simply put, the solution doesn’t match the need. The project is a failure.
Production considerations should be addressed early on in the project lifecycle. We've seen this happen in the industry time and again: You have teams working diligently with their heads down building something, only to find out in production that it doesn’t ‘work’ the way the business or user needs it to work. Perhaps they need to get the output in a different format. Or maybe they need it to integrate into a particular system. Or have it work on a different timeline (maybe they need something in real-time, but you planned for batch). Finding these things out at the very end of the lifecycle can be fatal to your product, and must be addressed right from the start.
At its most basic, the primary objective of the production stage of the ML product lifecycle is to get the output of your model where it needs to be, in front of who needs it, at the right time. In a traditional DevOps scenario, that is a fairly linear process. Not so for ML products. As I explain a little later in this article, MLOps (the steroid-induced cousin to DevOps) is much more complicated, inherently more uncertain and extraordinarily iterative.
That is why it is so important that production activities start right at the very beginning of the product lifecycle – from the moment you start thinking about what your model is going to do. If you have read the previous articles in this Field Guide, this won’t come as a surprise (all PMs at Borealis AI recognize and respect the importance of thinking about production at the start). But if you only realize this fact during the production stage, you will likely find yourself in trouble.
What are some of the common considerations that PMs can start thinking about early in the process? The most obvious is alignment to the business needs and processes. That means understanding how the outputs are going to be used – how the scope of your project advances the overall strategy for the user or business.
It also means understanding the environment in which the model will be used – how will you build your data pipeline? What systems will you pull from? What systems will you write to? Do you need to provide an API? These questions need to be answered at the start to ensure that development aligns to the needs of production.
You may want to start thinking early about the privacy, compliance and legal reviews that may need to be conducted on the model. When building models within strict compliance frameworks, for example, it is critically important to conduct the proper privacy impact assessments and legal reviews. Potential challenges or additional requirements that may arise at this stage can often be dealt with early (with much less complexity and stress).
Similarly, all of Borealis AI’s models pass through a formal model validation process where we demonstrate that the model is fair, robust, and meets all of the qualifications necessary to allow it to interact with users (whether customers or employees). Again, if we only learn in the production stage that the business requires a high level of explainability for the model, we may find ourselves losing time to additional experiments, or in the worst case, more model development.
Just as important is understanding how the product will need to integrate into the existing technology and business environment. Were we to build a new model to support customer-facing bank employees, for example, it would need to integrate into existing systems and processes across call centres and the more than 1,000 physical branches across Canada. And you'd need to train every single person on what that model does and how to use its output. Early planning can make this kind of integration much less uncertain.
It’s also worth remembering that with ML, unlike with traditional software development, production models can change very quickly. Inputs can change unexpectedly (for example, the business might decide to change vendors for a critical external data flow that feeds the model). Performance can degrade very quickly. PMs will need to think about how it will be monitored while it is in place.
ML models are constantly changing. They need performance monitoring. They need updates. They will need retraining. Taken as a whole, these barriers and considerations add a lot of uncertainty and complexity to the production stage of the ML product lifecycle. This way of working requires evolutionary changes for PMs used to the traditional DevOps principles – you’ll need to develop an MLOps mindset to specifically fit AI systems and ML models.
An MLOps mindset covers the entire pipeline from how the data flows from its origination point through to the model outputs consumption by the end-user. It means having the foresight to account for all of the considerations outlined above and iterate as the lifecycle progresses. It means incorporating CI/CD (continuous integration, continuous delivery) to ensure changes can be made and pushed into production quickly and easily, as well as establishing a strong process for bug checking and sign-off once the project goes into production. It should also include automatic triggers that respond to changes in performance while in production.
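The "automatic triggers" idea can be sketched in a few lines. The function name, the AUC metric, and the 0.05 tolerance below are assumptions chosen for illustration, not the API of any specific MLOps tool.

```python
# Illustrative sketch of an automatic trigger: if a live metric degrades
# past a tolerance relative to the offline baseline, flag the model for
# retraining rather than waiting for a human to notice.

def check_model(live_auc, baseline_auc, tolerance=0.05):
    """Return the action an automated pipeline might take."""
    if live_auc < baseline_auc - tolerance:
        return "trigger-retraining"
    return "ok"

print(check_model(live_auc=0.71, baseline_auc=0.80))  # degraded beyond tolerance
print(check_model(live_auc=0.78, baseline_auc=0.80))  # within tolerance
```

In a real pipeline this check would run on a schedule against fresh labelled data, and "trigger-retraining" would kick off a CI/CD job rather than just print a string.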
As this illustration shows (taken from this manifesto of MLOps principles), the MLOps environment is dynamic, iterative and complex.
One important indicator of success for an ML project is how early the PM starts planning for production.
My advice would be to ensure you have close collaboration right from the discovery stage to ensure you have all of your teams aligned – there is more you can do at this stage to prepare for production than you might think. And make sure that collaboration isn’t just with your business partner; include privacy, legal, compliance, validation and technology owners.
Once you are in production, it is also critically important to have the right engineering team with the right attitude, culture and capabilities to work through these challenges (like the technology integration) and to ensure that model performance is sufficient to meet the demands of the business.
Ultimately, it all comes back to the purpose of your product. It’s often easy to get caught up in the excitement of building an ML product. But if you keep your eye on the prize and always remember the specific need that you are trying to solve, you should be able to keep out of the weeds that often slow progress through production. At a macro level, every win, small and big, will contribute to the overall success rate of AI projects. As PMs, we can be the change we want to see.
[1] https://www.gartner.com/en/newsroom/press-releases/2018-02-13-gartner-says-nearly-half-of-cios-are-planning-to-deploy-artificial-intelligence
[2] https://venturebeat.com/2019/07/19/why-do-87-of-data-science-projects-never-make-it-into-production/
Imagine dropping in on two different team standup meetings. At the first standup, the team is building an email feature served with an API. The developers are familiar with scrums or Kanban, and are comfortable breaking their work down into one or two-day chunks. You hear things like “I implemented an API that will talk to the email system. We’re all set, and it’s in testing.”
At the second standup, the team is building a machine learning model to predict future stock market returns. The machine learning scientists are chunking their experiment cadence into short sprints. But these team members are more likely to say things like “I’m still training the model, and something is not working. I’m going to look into the initialization parameters and maybe try to optimize that. I’ll report back next time and then repeat.”
In both examples, standups give the team a space to share progress against the tasks in the workplan. But in the machine learning standup, uncertainty often creeps into the discussion, as teams share ideas on different experiments that could lead to a new outcome. Yet – outside of the standup routine – stakeholders expect progress updates.
The big challenge facing product managers (PMs) in the model development phase of the machine learning product lifecycle is simply to keep things moving in the face of uncertainty.
Progress might require careful management of three main stakeholder groups: the development teams; the executives you are problem-solving with; and the designers and engineers with interdependencies on final model performance.
In this post, we’ll explore the model development phase from the viewpoints of these three stakeholder groups, identifying lessons we have learned from our work at Borealis AI.
Armed with a clear set of baseline performance metrics for the task (that you hopefully developed during the feasibility study phase as suggested in this article), machine learning teams can come together around a unified objective and sense of progress. The benchmark essentially provides direction to the team as they explore candidate model architectures that can improve performance above that baseline.
For many, the challenge then becomes breaking down the tasks in order to structure progress against your goal. During the build phase, machine learning teams tend to follow an iterative process that moves through training, validating and testing their models in order to confirm that the statistical properties they learn during the training phase will generalize well to new data sets during the testing phase.
Yet that does not mean that progress cannot be measured. Many teams share their research results and improvements using evaluation metrics like AUC (area under the ROC curve), or precision and recall. However, finding the right balance between ambiguity and specificity is often challenging: goals like “improve 5% this week” can be difficult to achieve with machine learning sprints. It’s not uncommon to go weeks without making progress and then suddenly have a breakthrough that leaps past the benchmark. Yet setting a goal of simply “reporting back progress” can be too loose to drive the team towards a productive outcome.
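As a minimal sketch of two of the metrics mentioned above, precision and recall can be computed directly from confusion-matrix counts. The counts are made up for illustration; real projects would typically use a library such as scikit-learn.

```python
# Precision: of everything we flagged positive, how much was right?
# Recall: of everything that was actually positive, how much did we catch?

def precision(tp, fp):
    return tp / (tp + fp) if (tp + fp) else 0.0

def recall(tp, fn):
    return tp / (tp + fn) if (tp + fn) else 0.0

# Hypothetical evaluation run: 80 true positives, 20 false positives,
# 40 false negatives.
p, r = precision(80, 20), recall(80, 40)
print(f"precision={p:.2f} recall={r:.2f}")  # precision=0.80 recall=0.67
```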
It is also worth considering the relationship between technical value and business value. An improvement in model performance does not always linearly translate into improvements in business outcomes. For example, if you are building something where additional precision leads to monetary gains (like a model that predicts stock market performance), extra model performance is valuable to the business. For a product recommendation engine, on the other hand, you can quickly reach a point of diminishing returns. In these cases, it may be more worthwhile getting your model out and learning in production than it would be to keep working on finding an architecture that would only deliver incremental performance improvements. As a PM, it is your role to spot when there is a mismatch in value.
This is also the right time to be thinking about the ethical and governance considerations of your model in production. Remember that any requirements for explainability around fair treatment of individuals or groups should be established up front in the design stage (or earlier) as they may act as constraints on the algorithmic design.
Models can change even after they go into production. Sometimes the model can pick up inductive biases from its training set that result in different (and sometimes undesired) behaviors when the data distribution changes slightly in production – a phenomenon a recent paper calls underspecification. So while the team may be excited when a model performs well on a hold-out test set (as they should be), it might be a good idea to remain skeptical.
To manage this risk, PMs may want to work though various hypothetical scenarios with distribution shifts in order to get ahead of future challenges, while the design team keeps experimenting. Ultimately, the best approach might be to work through these tests in simulated scenarios and be prepared to monitor performance, retrain and adapt the model as needed, once it goes into production.
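One common way to watch for the distribution shifts described above is the population stability index (PSI), which compares a feature's histogram at training time against what is observed live. The bucket counts below are invented, and the 0.2 alert threshold is a widely used rule of thumb rather than a universal constant.

```python
import math

def psi(expected_frac, actual_frac, eps=1e-6):
    """Population stability index over matched histogram buckets
    (each list of fractions should sum to 1)."""
    total = 0.0
    for e, a in zip(expected_frac, actual_frac):
        e, a = max(e, eps), max(a, eps)  # guard against empty buckets
        total += (a - e) * math.log(a / e)
    return total

train_dist = [0.25, 0.25, 0.25, 0.25]   # feature histogram at training time
live_dist  = [0.10, 0.20, 0.30, 0.40]   # same feature, observed in production
score = psi(train_dist, live_dist)
print(f"PSI = {score:.3f} -> {'investigate shift' if score > 0.2 else 'stable'}")
```

A check like this, run per-feature on a schedule, gives the team an early warning well before business metrics degrade.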
Executives and business leaders want to see high-level overviews of what will be delivered, when it will be delivered, and what benefits or cost savings the model will create. And as PM, it's on you to communicate progress to these stakeholders, typically more used to clear, linear development timelines – and educate them on the uncertain and non-linear nature of machine learning projects.
As PM, you may find yourself saying things like “We hope to deliver this in six months, but it’s possible our timelines will push out as far as 12 months. If we don’t see evidence of progress by then, we’ll cut the project. As far as value created, that will depend on where we land with model performance and we won’t know that until it’s in production.” Likely not the definite, linear answers that most executives are looking for.
There are a few ways to help manage this uncertainty and ‘bring executives along’ on the development journey. One is through transparent and consistent communication.
Don’t be afraid of frequent touchpoints with executives; it’s better to have them travel with you than it is to surprise them with bad news after going silent for six months.
When communicating with the business side, get ready to translate early and often; phrase progress in terms of business outcomes rather than technical benchmarks or milestones. For instance, instead of discussing a precision measure for a classification task, frame the discussion as an effort to articulate the impact of a low false positive rate on outcomes. Help your business partners envision what the future business value might look like.
Another way to manage uncertainty is to hedge risk by working on a couple of candidate projects at once – treating development more like a venture capitalist investing into multiple startups, knowing that not every startup or project will be a success. It’s an approach that is common within machine learning organisations and it’s often embedded in the DNA of their project managers. To ensure projects don’t turn into sunk costs, be sure to incorporate touchpoints with go/no-go decision points across the product lifecycle.
Get a sense of what error rates the business is comfortable accepting for a task. Then reverse engineer that understanding into a minimum viable product to kickstart the flywheel learning effect and improve the algorithm’s performance over time.
As in any project, there are clear interdependencies between model performance and the product features that can be designed and developed to work with the model. In a traditional software development process, design will spec out the requirements for a product or feature, creating the blueprint for development teams. In a machine learning project, development can often be more nuanced. As the team works its way towards a candidate solution using the spec, they may find out they need to change the criteria in some way, which could have implications on downstream design.
When building a cash flow forecasting model, for example, you may have design spec out a feature that encompasses all types of business payments, only to decide that, in the interest of time and pace to market, the first version of the model will only work on one type of payment. This shift in model scope requires design to redo the specs, creating downstream implications for supporting engineering development teams.
So why not just have design and engineering wait until after the machine learning team has found a candidate model for production? The answer is ‘time lost’.
The reality is that eventual product deployment can be greatly accelerated by having engineering work on front-end features and back-end data pipelines in parallel with machine learning experimentation. But that will require engineering teams with the right attitude and culture, keen skills for identifying ‘no regret’ development work, a level of comfort discarding code that isn’t valuable, and a sense of excitement around the challenge of iterating back and forth with peers in design and machine learning. PMs can help foster this type of culture by taking ownership of all the moving parts and iterating with teams throughout the process – showing empathy when needed, and constantly communicating with stakeholders to ensure everyone is on the same page as the specs for the production solution evolve.
Navigating the uncertainties of the machine learning model development phase takes a PM skill set that goes beyond fluency with scrum practices, clarity of vision, and clear communication. It requires techniques for creating a sense of team unity and momentum (even when it's unclear when and how things will work).
While it can feel daunting in the moment, the payoffs of doing this well are huge. As long as communication is tight and the team is grounded in trust, you'll start to see different functions rely upon one another more deeply as they coalesce around a candidate model and product design that can go to production.
We'll dive into what that looks like in our next post!
A good user experience (UX) has always been at the heart of successful products. Ensuring you understand how users will engage with your product – their needs, problems and experiences – has always been key to good product design.
Traditionally, that meant doing lots of primary and secondary research. But with machine learning (ML), that approach is no longer fit-for-purpose. Instead, we need to acknowledge that our users are extremely dynamic – research can only tell us so much.
I believe that – in a machine learning context – great UX is created by watching users in action. That means getting something in front of your user as quickly as possible. And that, in turn, is about speed of development.
That is why, in this blog, I will share a framework and scoring system that I use to think about that intersection between UX and ML solution fit, with a clear prioritization on rapid prototype development, getting in front of your user and creating a great experience.
In previous blog posts, my colleagues have explored the nuances of machine learning in the discovery and the feasibility stages of the product development lifecycle. The design requirements stage is an extension of these initial phases. The goal is to cultivate a deep understanding of our end-users’ problems and determine if a machine learning solution is the right fit.
It’s not always easy – particularly when you rely on primary and secondary research – to ascertain what your users’ needs and problems will be. So, instead, the mindset should be around understanding how best to learn about your users in the fastest way possible. For some, that may mean developing quick and dirty prototypes of your application. Or it may mean thinking about possible outcomes that users may experience when facing a given context.
Ultimately, the design stage is about setting your team up to iterate, learn and incorporate. And that raises some key questions for the product manager: What are our users’ problems? Is machine learning the appropriate solution? What are the prediction trade-offs?
As with any technology we build, our goal is to solve a user’s problem. To do that, we need to develop a fundamental understanding of the user, their needs and motivations, and the context surrounding the problem. The insights we uncover will start to form the foundation of any solution we create. Depending on the intersection between the human problem and the advantages ML can provide, we can then form an opinion on whether an ML solution is the right fit.
There are numerous resources on the topic – I am partial to methods by IDEO and UX Collective, for example. But based on my experience, here are four questions you can use to help kick-start your user research. It is not an exhaustive list. And it should be continuously iterated; throughout the product development process, you will learn more about your users which, in turn, will change the way you think and build for them.
Now that we understand our users’ problems, we can begin to explore whether machine learning is the appropriate solution for solving them.
That may sound like a daunting task; machine learning can be used to do many things from personalization and recommendations through to natural language understanding and outlier detection. But, at the root of all of those solutions is a prediction. Whether it is predicting what a user will like based on who they are, or predicting what text to display when prompted by a user, it all comes down to predictions (for more on the economics of ML, check out Prediction Machines by Agrawal, Gans, and Goldfarb).
Therefore, our primary question at this stage should be: what would we need to predict in order to help solve the user's problem?
That’s a great start – at this point we understand the user’s problems and we have an idea of what we might want to predict in order to help solve those problems. Unfortunately, that does not necessarily mean we have a clear-cut ML problem. Machine learning is more nuanced. There are a range of other considerations that could shift the decision towards a more traditional solution set. More questions will need to be asked. And in the following sections, I’ll explore these nuances and help develop some more questions to help you decide if ML is the right fit.
In 2016, Andrew Ng argued that “if a typical person can do a mental task with less than one second of thought, we can probably automate it using AI either now or in the near future.” Ng’s adage reminds us that it’s important to develop a human baseline for performance on the task you are looking to solve.
A human baseline serves as both a feasibility test and as a method for understanding requirements moving forward. Just as importantly, it sets an expectation for what metrics your ML model would require in order to be considered ‘good enough’.
Consider, for example, an application that aims to label dog breeds from photos. And let’s assume humans have a moderately high baseline accuracy of 60%. That baseline can then be written into the project as a requirement which, in turn, will help solidify the approximate timelines for the project. The baseline can be used to indicate if the model is over-fitting or under-fitting – something that can often be difficult to assess.
Contrast that example against an ML application that attempts to properly value a piece of art. Since the price would likely vary widely depending on who you ask, establishing a human baseline would be extraordinarily difficult. Knowing this, a PM may decide to re-think the solution or re-phase the user problem.
Ultimately, we should assume that – from a user’s perspective – the human baseline represents the absolute minimum performance the application must achieve before release. If people can identify a dog accurately 60% of the time, they won’t want to use a machine that only gets it right 50% of the time.
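Treating the human baseline as a release gate, as described above, can be sketched in a few lines. The 60% baseline carries over from the dog-breed example; the model accuracies are illustrative.

```python
# Sketch: the human baseline is the minimum bar a model must clear
# before it is worth putting in front of users.

HUMAN_BASELINE = 0.60  # assumed accuracy of a typical person on the task

def ready_for_release(model_accuracy, baseline=HUMAN_BASELINE):
    """True only if the model at least matches human performance."""
    return model_accuracy >= baseline

print(ready_for_release(0.50))  # below the human baseline -> False
print(ready_for_release(0.72))  # clears the baseline -> True
```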
The new question we need to add, therefore, is: can we establish a reliable human baseline for this task?
Throughout this Field Guide article series, a common theme has been the uncertainty that machine learning applications create. ML is all about predictions. And anyone who makes predictions knows that sometimes they are going to be wrong. What is important is to understand how these mistakes will impact your end-user’s experience. In ML terms, that means deciding with your team the right balance between the recall and precision of the application.
To start, you may want to host a brainstorming session to create a confusion matrix with your cross-disciplinary team. This will help them understand the realm of possible outcomes a user may experience. It will also allow the team to reflect on not only what the prediction will be, but also how that prediction might affect the end-user’s experience. Your team can then use this information to optimize for the appropriate user outcomes which, in turn, will help solidify your design requirements.
Let’s say that we are designing a news recommender system where a false positive leads to a recommendation the user will not engage with. The impact on the user is a negative experience which, over time, would be detrimental to the application as users start to disengage entirely. Knowing this, a decision can be made to optimize for precision and minimize false positives.
This may lead to entirely new conversations about how the team can achieve the desired metric, perhaps through data labelling, application design, system design or other improvements.
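The confusion-matrix brainstorming described above can start from something as simple as tallying hypothetical (actual, predicted) outcomes so the team can reason about the user impact of each cell. The labels and outcomes below are invented for illustration.

```python
from collections import Counter

def confusion_matrix(actual, predicted):
    """Tally outcomes into a Counter keyed by (actual, predicted) pairs."""
    return Counter(zip(actual, predicted))

# Hypothetical recommender outcomes: did the user engage with each item?
actual    = ["engage", "engage", "skip", "skip", "engage", "skip"]
predicted = ["engage", "skip",   "skip", "engage", "engage", "skip"]

cm = confusion_matrix(actual, predicted)
for cell, count in sorted(cm.items()):
    print(cell, count)
# ('engage', 'skip') is a false negative; ('skip', 'engage') a false positive
```

Walking through each cell with the cross-disciplinary team – "what does the user experience when this happens?" – is the point of the exercise, not the arithmetic.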
With this in mind, the next question we need to add is: how will prediction errors affect the end-user's experience?
One of the great things about machine learning systems is that they evolve as our users’ mental models change over time. Knowing and planning how a system may change over time can be incredibly important as applications that are created to generalize their predictions will require more data than those created for a specific use case.
In our earlier example where users were identifying breeds of dogs, the data sample will be fairly contained. If the task were to be generalized to identify all animals of all types in photos, an entirely new set of data requirements will be needed (one that contains photos of all target animals we want to classify).
While these more expansive data sets can often be tricky to obtain, it is possible to design interaction patterns that make it easy (even rewarding) for users to give feedback to the system. Even so, PMs will find that the more generalized the requirements the more complex it is to build an application that not only outperforms a human baseline, but also controls for the experience of uncertainty.
The final question we need to add, therefore, is: how generalized will the prediction need to be?
So, let’s quantify all of this for you. Here are the four questions I believe PMs should be asking themselves as they consider the ML design stage, aligned to a scoring system.
Now for each of these options, answer the following questions and assign the appropriate score:
The prediction with the lowest score will likely be the easiest to tackle and will probably make the best candidate to quickly prototype and learn from.
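Picking the lowest-scoring candidate can be done mechanically once the team has assigned scores. The candidate predictions and per-question scores below are invented purely for illustration of the selection step.

```python
def best_candidate(scores):
    """Return the candidate prediction with the lowest total score
    (lower = easier to tackle, per the scoring system above)."""
    return min(scores, key=lambda name: sum(scores[name]))

# Hypothetical candidates, each with one score per question
candidate_scores = {
    "predict churn":          [1, 2, 3, 2],  # total 8
    "predict next purchase":  [1, 1, 2, 1],  # total 5
    "predict lifetime value": [3, 3, 3, 2],  # total 11
}
print(best_candidate(candidate_scores))  # -> predict next purchase
```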
History teaches us that the ability to iterate and learn is the primary driver of successful products. And in the ML world, that means iterating rapidly, assessing user needs and problems, understanding the value that ML could deliver and properly identifying the downstream risks and challenges. It also means assessing whether ML would be a good fit.
Ultimately, the goal is to start with your users’ needs and experience in mind and then quickly move into identifying what you need in order to rapidly test and learn. To be sure, adding machine learning to your projects can increase the layers of complexity in your application. I hope this article helps you cut through some of that complexity and hasten your product development cycle.
The world is full of 2x2 matrices designed to help product development teams prioritize work and decide what to build next. IBM Design’s matrix exercises, for example, plot user value against technical feasibility. Others take different approaches.
But are these simple frameworks also useful for prioritizing machine learning features? The short answer is “kind of”. But certainly not within a traditional 2-hour design session.
The challenge with machine learning systems is that it’s very hard for teams – including machine learning scientists – to estimate how difficult it will be to build a model or feature without first experimenting with data and building first-pass, baseline models for a task. Put simply, feasibility can’t be estimated through a one-time judgement call from a seasoned technical team: it’s an explicit, time-boxed phase in the product development lifecycle.
It's also iterative, as the team may conclude that the original problem articulated during Discovery isn’t solvable, but a variation (perhaps a down-scoped version of the problem) is. Then the focus shifts to whether this newly-formulated task still delivers impact for the business and end-users.
Clearly, impact-feasibility analyses are still very useful for machine learning projects. But they must be viewed as living documents that evolve and solidify as the team learns more about the data, algorithms, and technical constraints of their task and problem domain.
At Borealis AI, we call this phase a feasibility study and we use it at the beginning of our projects to help us decide whether or not to pursue a particular project and to manage the expectations of our business partners.
In this post, we’ll explore some of the more common questions faced during the feasibility study phase. And we’ll place particular focus on how you can communicate progress throughout the feasibility phase to ensure business stakeholders’ expectations are properly managed.
For the machine learning research team, the big questions during the feasibility study phase are how quickly they can get to a baseline model performance for the task and how quickly they can show measurable progress on that baseline.
The key to doing this well is to develop a practice of rapid experimentation, where the machine learning team can iteratively cycle through experiments using different model architectures and parameters for the task. As we discussed in a prior post about using config files to run machine learning experiments, it helps to have the right platform and tools already set up to enable reproducibility, parallelization, and collaboration between different members of the team.
At this phase in the process, the team isn’t trying to build the best architecture for the task; rather, they are trying to decide if there is a solution to the problem and if they can deliver measurable results or measurable lift above the baseline. A review of existing literature can help drive thoughtful choices about experimental design before the iteration process begins.
Naturally, no two feasibility phases will be identical. The steps and activities – even how long the phase will last – depends on the nature of the task. And there are many nuances and decisions to be made.
For example, is the task to build a brand-new ML feature? Or is it to improve on an existing algorithm? The differences related to this particular nuance (one we encounter frequently at Borealis AI) can be significant.
When creating a brand-new ML feature, delivering value with an algorithm can often be easier, as long as there's actually a machine learning problem to solve (as Wendy noted in her blog on discovery, sometimes the team will find that a rules-based solution – or a mix of rules-based and ML – is actually better suited to the problem). Confirming that the problem actually requires ML is the first challenge on a new build.
The second challenge is that the business won't yet have a clear answer to many of the quantitative performance requirements the ML team needs to create a viable solution. They'll likely have an intuitive sense for what error rate they can accept, but probably won't be able to immediately stipulate what accuracy, precision, or recall they need from a model in order to pass it into production.
This is one of those points where great communications and collaboration comes into play. Determining the baseline becomes a dialogue between the machine learning team and the business team. The business team may become frustrated by questions they can’t answer; the lack of precision up front can frustrate the machine learning team. Product management needs to be constantly sharing results back and forth between business and technical teams to serve as a communication bridge during this process.
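That dialogue eventually has to land on concrete numbers. As an illustrative sketch of what those quantitative requirements mean (the labels and predictions below are made up, not real model output):

```python
def accuracy(y_true, y_pred):
    # Fraction of all predictions that were correct.
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def precision(y_true, y_pred):
    # Of the items the model flagged positive, how many truly were?
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    return tp / (tp + fp) if (tp + fp) else 0.0

def recall(y_true, y_pred):
    # Of the truly positive items, how many did the model catch?
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    return tp / (tp + fn) if (tp + fn) else 0.0

# Illustrative ground-truth labels and model predictions.
y_true = [1, 0, 1, 1, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0]
```

Walking the business team through small examples like this can make the trade-off tangible: here the model misses one true positive (hurting recall) and raises one false alarm (hurting precision), and the business must decide which error is costlier.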
When improving upon an existing problem, by contrast, teams often face less ambiguity around task definition: an algorithm with baseline performance already exists. The task is to make enough measurable improvement upon this baseline to merit production. That means the real question is how to timebox initial explorations in a way that provides reasonable confidence that we can deliver a better solution.
Quick tip: If the potential delta in business impact is very high, that timebox may be quite long, sometimes up to 6-12 months. If the potential delta in business impact is lower, it's best to keep the timebox shorter and experiment quickly with newer algorithms.
Another nuance worth noting arises when the machine learning team uses off-the-shelf trained models they can tune to the specifics of the domain and environment. With accessible (and high quality) pre-trained models, teams already know the solution will work; what they are trying to find out now is how well the models adapt to a specific domain.
As you progress through the feasibility study phase and explore baseline model performance and architectures, you may find that you need to reframe the task. Consider, for example, an initial task of creating a cashflow prediction solution for a bank. Suppose that, based on variations in data distributions, it becomes clear that it is not possible to accurately predict cash flow for all industries nine months into the future. What is possible, however, is to predict cash flow for retailers and restaurants four months into the future.
The business team then needs to decide if that down-scoped version will add sufficient value and, if so, that version may end up becoming the first minimum viable product (MVP) that is developed further in the future. If the business team decides the scaled down version does not deliver sufficient value, however, it should probably be halted so that efforts can be prioritized towards another task.
As these discussions are happening, product managers will want to take the time to work with design to understand the cost of errors in the production solution. Essentially, you want to know how frequently the system needs to be right in order to be useful. The acceptable error rate can change: a product recommendation engine, for example, will certainly have a different acceptable error rate from a banking or healthcare solution.
This is also a perfect time to be engaging with ethics and compliance to understand any governance restrictions on the model. For example, does the use case require an explainable model? If so, different architectures permit different degrees of explainability, which would impact machine learning architecture down the line. Does the model have constraints on fairness, bias, or privacy? This would shift the objective function of the task and, if constraints are high, could even impact feasibility.
The key thing to remember is that ethical questions should be asked up front, at the very beginning of the development process. They should never be left until the governance and compliance review at the end of the cycle.
While the research team explores baselines, the production engineering team should be assessing the viability of integrating the model into the downstream systems and sketching the system architecture.
Borealis AI is part of the complex enterprise environment of RBC. And that means our engineering team must always be thinking about how they can adapt our production approach to integrate into the various data pipelines, API architectures, and downstream applications and business processes across business lines in the bank.
While not all machine learning teams will face the same challenges, here are some common questions all ML engineering teams will want to consider during the feasibility and planning phases:
Product management will want to incorporate answers to these questions into the final feasibility analysis, and use them to inform a decision on whether to proceed to the next step in the product lifecycle: designing production system requirements.
Unlike the phases of a traditional development workflow, the machine learning feasibility study phase is used to dig into the data and quickly conduct experiments to establish baseline performance on a task.
This is not just about populating a 2x2 matrix: it requires constant communication between the technical teams and the business teams. And it requires project managers to iteratively reassess impact and feasibility as the teams learn more about the problem space and possibilities.
Feasibility studies also provide an excellent opportunity to manage stakeholder expectations and risk: from the very start, executives should understand that the project may not actually be feasible, or the details may morph based on data and context. But with continuous communication and transparency through the feasibility study phase, they can cut their investments short if needed, and ensure development teams are prioritizing projects that will lead to high impact and success.
We're Hiring!We are scientists, engineers and product experts dedicated to forging new frontiers of AI in finance, backed by RBC. The team is growing! Please view the open roles, including Product Manager, Business Development Lead, Research Engineer - Data Systems, and more.
The development of a machine learning product should always start with discovery. As a product manager, I am constantly asking myself “are we building the right thing?". It’s a question that overrides all other concerns I might have.
I’m sure I’m not alone: we product managers (PMs) are responsible for ensuring the products our teams build deliver value to our users and our business. We are responsible for ensuring that our clients not only use but choose our product. We are responsible for delivering a quality solution in a timely manner.
Ultimately, answering that question often boils down to ensuring we’ve identified a problem that is not only worth the effort of solving, but also one that our team can solve best. Identifying such a problem is at the core of product discovery. In this post, I share my perspective on product discovery, drawing from experiences as a PM at an AI-first organization that emphasizes product innovation.
Just as you wouldn’t want to set off down a long highway in the wrong direction, product discovery is critical to setting the right direction through the product development lifecycle. Head off in the wrong direction and you could find yourself far from your destination as weeks or even months are lost while the team solves for the wrong problem. Nobody wants that. So, it is critical for product discovery to focus on clearly identifying the right problem to solve to create real value for the users.
But what is the ‘right problem’ and how does a product team find the right problem?
There are already a lot of good resources out there that explore product discovery in a more traditional setting (i.e., without the complexity of incorporating ML, let alone bespoke, novel ML solutions, into products).
For example, I like how Silicon Valley Product Group’s Marty Cagan lays out 5 areas of risk to explore during product discovery: user value risk, usability risk, feasibility risk, business viability risk, and ethics risk. Product discovery coach Teresa Torres also has a great framework to evaluate opportunities and other useful tips around how to organize findings from the discovery process.
There is also a lot of good content available on how to talk with users and customers to uncover pain points, how to test assumptions, and how to build POCs and prototypes quickly. Indeed, much of the traditional product discovery process, which focuses on the problem space (versus solutioning), remains applicable to a machine learning PM in the innovation space.
For those PMs who are focused on applying ML techniques in novel and innovative ways, there are two additional considerations from the solution space that need to come into the mix when identifying the ‘right problem’ during discovery.
The first is to consider whether ML is actually the best solution to solve the problem. It is far too easy to get caught up in the excitement of ML and want to apply it to every problem we come across (I’ve been in those conversations before). But an ML solution is relatively complex and comes with its own set of risks and challenges; it is something you apply only when non-ML techniques are insufficient for solving the problem. Simply put, just because ML can solve the problem does not mean it’s the best solution for the problem.
When evaluating the applicability of ML to a given problem, I often look at three main axes:
The second consideration in identifying the ‘right problem’ is the opportunity for unique ML contributions. You may find it is possible to use ML in your product without any novel ML research. But for PMs at organizations with teams of researchers aiming to develop novel ML techniques (like those at Borealis AI), we must also evaluate whether there is an opportunity for unique ML contributions in the products we develop.
It is a tricky consideration for PMs to balance. On the one hand, we know that ‘cool research’ should never be the rationale for creating a product; discovery is about understanding the user and business problems. But on the other hand, PMs should also avoid taking too narrow a lens by focusing solely on the immediate problem, and think more broadly about how our users, businesses and problem space will evolve. This can often uncover even bigger problems that can lead to greater innovation and space for novel research to be applied, to deliver better returns for the users and for the business down the road.
At Borealis AI, we believe that one of the keys to product discovery comes down to the ‘people’ side of the equation. While Borealis AI has expertise in product, engineering, and ML research, we work closely with partners in the specific lines of businesses at RBC to leverage their business domain expertise. Successful product discovery (recognizing the considerations above) often relies on establishing trust with our business partners and a focus on long-term relationships.
One of the ways Borealis has successfully built trust with our business partners is by starting with a focus on delivering shorter-term value for our business partners and their users, even if that means setting aside novel, higher-risk technology at first, but with the clear intention to collaborate with partners on more innovative high-risk, high-reward solutions over time. We also often create POCs and prototypes that can be helpful in building trust with our partners.
Over time, and as trust with our business partners builds, the ability to discover more innovative product opportunities grows. We understand more about the pain points and trends of the business domain and the users. Our partners understand more about ML, its limitations and its risks. They gain confidence in ML solutions. And that leads to greater mutual trust as we work together to incorporate novel ML research into products that require longer timelines and carry greater risk (though those are also great opportunities for greater reward!).
At Borealis AI, our culture is centered around novel research and innovation, deeply intertwined with an ambitious product spirit. As product managers, our challenge is to discover the right problem: one that addresses a meaningful pain point for users and the businesses we collaborate with; that requires a complex technology like ML; and where there is opportunity to incorporate novel ML techniques and research over time.
Discovery of the ‘right problem’ to solve relies on the strength of our business partnerships, where trust needs to be established. As the business partnerships develop and evolve, continuous product discovery leads to more opportunities for innovation and novel ML contributions.
So, product managers out there: are we building the right thing? It’s hard to know for sure. But a good product discovery process is an important piece to answering that question.
A century and a quarter later, data remains at the very heart of innovation. Today, data systems are much more sophisticated than Hollerith’s basic counting machine. Modern day systems are driven by sophisticated technologies – like machine learning (ML) – designed to uncover new insights and value from data.
Most businesses around the world now recognize that data is the foundation for future competitive advantage. But there's so much hype around data being the ‘new oil’ that business leaders often miss the nuances of managing data in a way that can actually create value from algorithms.
The adage ‘garbage in, garbage out’ is as true for basic analytics as it is for ML. As teams build ML products, it's important they focus not only on the aptness of the algorithm, but also on the quality of the data throughout the development lifecycle. At Borealis AI, we often say that data must be a first-class citizen.
Understanding the limitations, requirements and complexities of dealing with data is critical for ML product managers, given the cost, time and resource requirements involved.
In this article, we walk through each step of the typical ML project, (as outlined in this article by Sam Chow), to reflect on the nuances of managing data in the ML product lifecycle. We start by exploring the properties of real data and reviewing the more common data design patterns of data-driven systems.
Real data is messy. It’s inconsistent. And it’s poorly organized. In large part, that is because most data is collected to serve a distinct application and not necessarily intended for future data-driven analytics uses.
Consider, for example, how a simple concept such as counting ‘customer orders’ might result in different data across various company functions. For a sales department, the data would represent the number of contracts signed, regardless of whether or not the invoices have been paid. For an accounting function, the data would reflect the number of paid sales, whether or not the product has been shipped. Supply chain, on the other hand, is concerned with the number of customer orders delivered. A simple query therefore – something like number_of_orders_per_day – could result in a number of different data points, adding to the uncertain nature of ML systems.
Part of the challenge is that ML algorithms are generally downstream consumers of real data. Data is typically stored for one purpose, and then reformatted for an ML algorithm. Even in applications designed to feed data to ML models, the data often has a transactional role beyond its potential use in ML models.
For example, Twitter stores tweets and social interaction in a format that makes it responsive and interesting to its users. But behind the scenes, the company then transforms this data to a format appropriate for their Who-To-Follow recommendation system.
The process of converting data is known as the Extract-Transform-Load (ETL) process. At its most basic, ETL is used to extract the right data from one (or more) system(s), transform its format into something usable by another system, and then load it into that other system.
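As a minimal sketch of those three steps (the table and field names here are hypothetical, and a real pipeline would read from databases or data lakes rather than in-memory lists):

```python
def extract(source):
    # In practice: query a source database or read files from a data lake.
    return source

def transform(rows):
    # Unify formats for the downstream consumer, e.g. rename fields
    # and parse string amounts into numbers.
    return [
        {"order_id": r["id"], "amount": float(r["amount"])}
        for r in rows
    ]

def load(rows, target):
    # In practice: write into a warehouse or feature store.
    target.extend(rows)
    return target

# Illustrative source records with the kind of quirks real data has
# (numeric amounts stored as strings).
source = [{"id": 1, "amount": "19.99"}, {"id": 2, "amount": "5.00"}]
warehouse = load(transform(extract(source)), [])
```

Even in this toy version, the transform step is where most of the judgment (and most of the cost) lives: deciding what "clean" means for each field.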
While it may seem like a fairly simple process, it is widely recognized to be the most expensive and time-consuming step in the development of an ML system. It is also critically important, as it ends up influencing the models’ success, fairness, scalability, and reliability (2).
Given the costs and importance of gathering the right data, it is critical that there are enough controls to ensure it is done correctly and that the limitations of the data are well understood. Figure (1) shows a simple heuristic to break down data preparation into three stages: raw data, clean data, and aggregated data.
Raw data is the dirtiest (or at least has the greatest potential to be dirty). It's the original data collected through a data-generating business process – most often from business transactions – with all its potential quality issues (duplication, transcription errors, etc.).
Clean data is a verified and unified version of the real data. The data types and business definition of these data are known and documented, making them legible to ML scientists and engineers. However, it does not include new information from data aggregation or embedded representations (such as vectorized, mathematical representations of latent properties in the data that can be used as inputs to a machine learning algorithm).
Aggregate data includes any derived representation of the clean data. For example, it may include total_number_orders for each day or month. It may also include text or customer transaction embeddings, which are often used for deep learning applications.
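A derived aggregate like total_number_orders per day can be sketched from clean order records as follows (the records below are illustrative):

```python
from collections import defaultdict

# Clean, per-order records (illustrative sample).
clean_orders = [
    {"order_id": 1, "date": "2021-03-01"},
    {"order_id": 2, "date": "2021-03-01"},
    {"order_id": 3, "date": "2021-03-02"},
]

# Aggregate: count of orders per day, a typical derived feature.
total_number_orders = defaultdict(int)
for order in clean_orders:
    total_number_orders[order["date"]] += 1
```

Note that the aggregate adds no new raw information; it is a derived view of the clean data, which is why it sits in its own stage of the heuristic.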
A data scientist might say that clean and aggregate data together form the main components of the feature store used as inputs to an ML algorithm. For those without a data science background, a ‘feature’ is a piece of information that the system uses to identify a data entry (so, for example, a customer can be viewed as a group of features including name, date of birth, residency, and so on).
As teams advance through the ML product development lifecycle (as aptly described by Sam Chow in this introductory post), it is critical that data be a first-class citizen that is carefully considered at each stage.
There's no machine learning without data. When exploring business problems and impact with executive stakeholders, it's important that teams also connect with data stewards to understand what data is available to use as inputs to the algorithmic system and solve the problem. The following aspects related to the data should be identified:
During the feasibility stage, it is important to zoom out to reassess project assumptions, review data discovery activities and think through any new data sources that may need to be added. This is the time to mitigate any potential risk from a data perspective. Here is what you can do at this stage:
Keep in mind that a data sample will probably come with just a sample of the data problems that will need to be resolved. Yet it could be expensive to clean the entire data set at this stage. There is an obvious trade-off decision that needs to be made. Our upcoming blog about the Feasibility stage will talk in further detail about data preparation techniques, such as normalization, clipping outliers, data encoding, and so on.
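Two of the preparation techniques mentioned above, clipping outliers and normalization, can be sketched on a numeric column as follows (the sample values and clipping bounds are illustrative assumptions):

```python
# Illustrative numeric column with one obvious outlier.
values = [2.0, 3.0, 100.0, 4.0, 5.0]

# Clip outliers into a chosen valid range (bounds here are assumed,
# not derived from any real data set).
clipped = [min(max(v, 0.0), 10.0) for v in values]

# Min-max normalize the clipped values into [0, 1] so features on
# different scales become comparable inputs to a model.
lo, hi = min(clipped), max(clipped)
normalized = [(v - lo) / (hi - lo) for v in clipped]
```

Decisions like the clipping bounds are exactly the kind of judgment call that should be documented at this stage, since the same transformation must be reapplied to new data in production.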
Here, we need to start looking at entire data sets, in addition to investigating any new machine learning models and application integrations. This is your chance to:
The importance of the data does not diminish at this stage: the data as a first-class citizen rule applies here too, as you prepare the data for feedback loops and any future retraining. ML systems in production need two things: efficient access to data features to be used in the downstream task; and continuous data preparation for new real data that should be fed to the system's feature store. Business applications that use ML systems continuously receive new data from different sources, and therefore the data preparation pipelines used during training must be deployed with the ML system. The following considerations are generally important at this stage:
Experienced product managers help their organizations view their data-generating processes with a view towards future AI and machine learning innovation. Many of the common data challenges product managers face could be solved by ensuring data is captured and treated appropriately from the very beginning.
Yet this does not absolve project and product managers from understanding the nuances of their data. In fact, as machine learning systems become increasingly important parts of the business problem-solving equation, a sound understanding of data and data quality will become an increasingly valuable capability.
Expect to see more organizations start looking at building systems that are designed to democratize data, including unified data governance, data catalogues to simplify the process of discovering new data sources, data preparation pipelines, continuous data integration, and building knowledge graphs.
This post is part of an ongoing series developed by the product managers at Borealis AI. In future posts, my colleagues will delve deeper into other aspects of the PM process. But don’t be surprised to see a continuous return to the theme of data as a first class citizen. That is a fundamental of the ML PM process.
This is part 2 in our Field Guide to Machine Learning Product Development: What Product Managers Need to Know series. Read the Introduction here, and stay tuned for deep dives into the six main stages of the machine learning PM process.
(1) Tabulation and Processing history at United States Census: https://www.census.gov/history/www/innovations/technology/tabulation_and_processing.html - accessed on March 12th 2021
(2) Sambasivan, Nithya, et al. " Everyone wants to do the model work, not the data work": Data Cascades in High-Stakes AI. (CHI 2021).
The Agile Manifesto, published in 2001, has changed the way most software products are developed today: in iterative, focused short loops and sprints. However, applying agile to machine learning systems is far from straightforward. Traditional software implements a known set of computations to come up with an output — developers have a specific output in mind, making product development relatively predictable. Experienced developers get good at estimating how long it will take to build a component or feature.
As Andrei Karpathy explained in his Software 2.0 post in 2017, machine learning is a different kind of programming that provides much more flexibility around the kinds of problems we can automate using computer systems, as developers provide “rough skeletons” of code and use data to fill in the details. The challenge is, this same flexibility is accompanied by uncertainty in the development process. Machine learning systems are inductive; they learn a suitable program (called a model) based on prior examples of inputs and outputs (called training data), and then apply what they’ve learned to make inferences on new data in production. Scientists choose the model architectures carefully, but getting it right can take time—even an experienced machine learning scientist doesn’t know if they’ll find a solution, let alone how long it will take to find one.
The key difference between building traditional software and machine learning systems is accounting for that uncertainty.
With machine learning, uncertainty is rife across the product development lifecycle – right down to the level of individual algorithms. For example, you can reach a certain level of performance and find that you need to make changes, which would require weeks of code refactoring, and then more time just to return to the same level of baseline performance. And that’s assuming that an entirely different approach isn’t required. Particularly in the early stages of a project, it’s simply unclear what might work.
Uncertainty can also surface in the problem statement and data set. For example, building a recommendation engine for TikTok (which carries live, user-generated content) would be substantially different from building a recommendation engine for Netflix (where the content set is tightly controlled). They are both recommendation engines, but their problem statements and data sets are very different.
Uncertainty could also mean that your mathematical model may not generalize well to an unseen data set. Predicting the accuracy of a machine learning model is difficult. Even seemingly good machine learning models can sometimes make errors, necessitating further tweaks to balance the levels of bias and variance, or precision and recall.
In order to manage uncertainty during the product development lifecycle, machine learning projects need to be highly iterative from the beginning.
At Borealis AI, we simultaneously work on each phase of our lifecycle in parallel, carefully iterating on each section until we reach a satisfactory level before circling back, then forward to ensure we find a fit between problem and solution.
Given the complexity, it is imperative for product managers new to the machine learning context to develop a nuanced, operational understanding of machine learning systems, their capabilities and limitations — and invest the time upfront to tie the effort more tightly to real business problems.
As the product management discipline evolves, we aim to provide a framework for approaching machine learning projects that can be referenced by practitioners or teams that are building machine learning systems, covering the six main stages: Discovery, Feasibility, Design and Requirements, Model and Product Development, Production, Support and Maintenance. We have also dedicated a separate section to the topic of managing data throughout the ML product lifecycle.
The development of a machine learning product should always start with discovery. It’s tempting to skip this step — agile has made it too easy to build first and think (and test) later—but in an ML context this step is critical, and will help product teams tackle and tame uncertainty early in the product lifecycle. Discovery begins by defining the problem space, understanding and framing the business problem, the desired outcome and mapping them to a solution space. This is not always an easy task: technical teams are looking for a quantitative baseline to measure feasibility against while the business is struggling to achieve clarity; the business is looking for clear cost-benefit analysis at a time when the outcomes are unclear. Balancing stakeholder needs, business problem framing, defining goals and outcomes, and building relationships with the business will be key. Here's our take on a new age of discovery: Balancing short-term value with long-term innovation, or how to successfully navigate the discovery stage.
The feasibility stage is used to determine if the machine learning system can achieve the outcomes initially proposed. An initial Minimum Viable Product (MVP) solution may be created at this stage to identify immediate limitations which, in turn, helps refine and reframe the problem statement. Given the adaptive nature of machine learning systems, a feasibility study will help product teams understand how the system might have to change, adapt and interact with users when in production, develop a baseline for the problem and a more informed perspective on business ROI. Read more about conducting your Feasibility study for ML projects.
The design of machine learning requirements takes a careful combination of design and business goals. PMs need to understand the end user to drive value, focusing on the end-user and developing a deep understanding of a specific problem your product can solve. Otherwise, you run the risk of developing an extremely powerful system only to address a small (or possibly non-existent) problem. In the design & requirements section, we explore the fundamental differences of user experience with machine learning systems, and what goes into developing a dynamic interface that adapts to a user’s needs given the uncertainty when it comes to predictions.
It’s time to take the blueprints and a clear set of goals and targets developed through the discovery, feasibility and design stages and start the build. This will be a collaborative endeavour with design, engineering, ML researchers, business and PMs working together, and PMs can drive a lot of value by continuing deliberate stakeholder management. The build will go from simple to complex – PMs should prepare to drive the product through train, test, and validate cycles; manage standups with scientists who may or may not be making progress; and navigate the iterative dependence between design, science and engineering teams. For a deeper dive into this stage, please read How to navigate the model and product development phase of your machine learning project.
The production stage is where the ML system is introduced to the infrastructure where it will function (and hopefully thrive). In the upcoming blog post on this topic, we will discuss the challenges around production environments in large enterprise environments – from business and legal, to technological issues like MLOps. Read our overview of what PMs can do here to set projects up for success: The (never-ending) production stage.
ML projects are not complete upon shipping your first draft. In fact, shipping your MVP is an important first step towards a larger iteration. Developing and deploying ML systems can be relatively fast and affordable, but maintaining them over time can be more difficult and expensive than commonly assumed. Machine learning systems are tightly coupled — a change to one component (the feature space, hyperparameters, or learning rate, for example) can affect model performance in unexpected ways. In the upcoming blog post, we will walk PM teams through what to expect at this stage and how to plan and prepare for it.
The views expressed in this article are those of the interviewee and do not necessarily reflect the position of RBC or Borealis AI.
Prasad Chalasani (PC):
Our focus is on enabling business analysts to use automated ML and explainability to make better decisions. What we found is that there is a world of non-coder business analysts out there who want to use business data and machine learning to deliver actionable predictions. But they also want to understand the drivers behind those predictions, so they can trust and act on them. Our solutions are aimed at helping them achieve those goals quickly, getting the insights they need in minutes, rather than waiting for days or even weeks for data scientists to build ML models.
(PC):
Certainly. And we also offer a model training and explanation API for data scientists and ML engineers. But we have found that data scientists are generally quite comfortable using open source explainability tools that enable model developers to show factors influencing predictions. Business analysts, on the other hand, can’t adapt tools to their business contexts as easily, since these adjustments would require ML modeling skills, and tend to be more interested in the drivers behind the predictions that influence business outcomes. That means their capabilities and needs are somewhat different.
(PC):
We have learned a lot over the past year and a half. Initially, our plan was to bring our solutions to technical teams as a pure explainability tool. But we quickly realized we needed to pivot. Now we are getting lots of traction in the market by packaging AutoML with explainability tools focused on non-technical teams. The conversation isn’t about explaining their models; it’s about understanding their data and improving the business.
(PC):
Real-world applications have been very exciting. We recently used our event-sequence models approach to identify which government actions were having the greatest impact on COVID-19 death rates. We can potentially also use these models to figure out if a certain type of drug or therapy leads to an increase in adverse events.
At the other end of the spectrum, we work with clients to drive marketing analytics. For example, we show which accounts are most likely to convert, and help identify campaigns and channels with the most impact on predicted conversion rates. This helps businesses see the real drivers of their business, faster. These drivers can then be adjusted to transform the business and improve the outcomes.
(PC):
For that research, we focused on the following two desirable properties of feature attributions: sparsity and stability. Sparsity means making the attributions human-friendly by focusing on only the truly relevant features. Stability ensures that you don’t get wildly different explanations if the input changes slightly.
It turns out that, when you train a model to be robust against adversarial perturbations, you also tend to get sparsity and stability as side effects. If you picture the adversary strength as a knob you can adjust, you want to find that point where attributions become sparse without impacting the natural accuracy of the model. When it comes to stability, we found that training an ML model with an extra regularizer penalizing instability is actually equivalent to adversarial training.
(PC):
In some domains, such as adverse drug event prediction, the precise probability is extremely important. But in the marketing use case, for example, it’s the relative predictions that are the key. When the marketing team provides the sales team with a list of accounts that, based on their models, are most likely to convert, the sales team isn’t worried about the exact probability of conversion but rather the relative probability versus other customers.
(PC):
I believe the opportunity for these types of augmented analytics solutions is massive. I think there is significant demand amongst the business analyst community for solutions that can augment their current business intelligence tools with AutoML that come with the rationale built-in. We have a general-purpose business intelligence engine called XBI and a specialized engine focused on the marketing use case. But we see growing demand for similar tools across the business analyst space.
We’ve also been putting a lot of focus into asynchronous event sequence data which we believe is being largely ignored in the open source solutions for AutoML and explainability. It’s one thing to take nice looking tabular data and generate explanations; it’s another thing to work with messier data sets like event sequence data. I suspect that’s where the innovation is going to come from in the near future. One of our ICML 2020 papers proposes a flexible model for such data, and shows a way to infer causality among different types of events.
(PC):
I think it’s a golden time for startups in this space. Many great tools are being developed in the data-engineering and ML ecosystem. Machine learning is still a hot topic and there is lots of investor interest in new companies in this area.
If I had a word of caution, however, it would be to think carefully about the plan for commercialization. When you have an idea, it’s very tempting to get to work building things. But, as we learned, it took us some time to understand how our idea created value for businesses. Being thoughtful about the go-to-market strategy can save valuable time otherwise wasted talking to the wrong audience.
(PC):
We’ve been exploring some exciting partnerships with different marketing analytics and Customer Data Platforms. And we are currently running a number of private beta tests with customers using our XBI engine. In the longer term, we are focused on pushing these products out to the market, expanding our team, creating new partnerships and building new solutions. It’s going to be an exciting time.
Prasad Chalasani is the Co-Founder and CEO of XaiPient, a New York-based AI explainability startup. With 20 years of experience leading quant and ML teams at some of the world’s leading organizations – including Goldman Sachs, Yahoo and MediaMath – Prasad is active in the field of trustworthy AI research and has published several papers in the field of machine learning.
In part II of the tutorial, we present recent methods for making machine learning differentially private. We will also discuss differentially private methods for generative modelling which provide an enticing solution to a seemingly intractable problem: how can we release data for general use while still protecting privacy?
Most machine learning algorithms optimize an objective function representing how well the training data has been learned. This optimization is typically performed using a variant of stochastic gradient descent (SGD). The near universality of SGD makes it a natural starting point for differentially private machine learning. Moreover, the inherent randomness of SGD suggests a natural affinity with differential privacy.
DP-SGD (Algorithm 1) is similar to traditional stochastic gradient descent. We first randomly initialize the parameters, then proceed by randomly selecting a batch of data points, computing their gradients, and using those gradients to update the parameters. The key differences are the lines highlighted in blue in which the gradient norms are clipped, and those highlighted in red where Gaussian noise is added.
Algorithm 1: Differentially private stochastic gradient descent
Input: Training data $\mathcal{D} = \{\mathbf{x}_1,\dots, \mathbf{x}_I\}$; Loss function $\mathcal{L}[\boldsymbol\theta] = \frac{1}{I}\sum_i\mathcal{L}[\boldsymbol\theta,\mathbf{x}_i]$; Parameters: learning rate $\eta_t$, noise scale $\sigma$, batch size $B$, gradient norm bound $C$
Initialize $\boldsymbol\theta_0$ randomly
for $t=1,\dots,T$ do
Randomly sample a batch $\mathcal{D}_t \subset \mathcal{D}$
Compute the gradient $\mathbf{g}_t[\mathbf{x}_i] = \nabla_{\boldsymbol\theta} \mathcal{L}[\boldsymbol\theta_{t-1},\mathbf{x}_i]$ for each $\mathbf{x}_i \in \mathcal{D}_t$
Clip the magnitude of each gradient
\begin{equation}
\mathbf{g}_t[\mathbf{x}_i] \leftarrow \frac{\mathbf{g}_t[\mathbf{x}_i]}{\max\left[1,\Vert \mathbf{g}_t[\mathbf{x}_i] \Vert/C\right]}
\end{equation}
Add noise $\boldsymbol\xi_{i,t}\sim \mbox{Norm}_{\boldsymbol\xi_{i,t}}[\mathbf{0},\sigma^{2}C^{2}\mathbf{I}]$
\begin{equation}
\mathbf{g}_t[\mathbf{x}_i] \leftarrow \mathbf{g}_t[\mathbf{x}_{i}] + \boldsymbol\xi_{i,t}
\end{equation}
Average the (clipped, noisy) gradients
\begin{equation}
\bar{\mathbf{g}}_t = \frac{1}{B} \sum_{\mathbf{x}_i\in\mathcal{D}_{t}} \mathbf{g}_t[\mathbf{x}_i]
\end{equation}
Take a step in the descent direction
\begin{equation}
\boldsymbol\theta_t \leftarrow \boldsymbol\theta_{t-1} - \eta_t\bar{\mathbf{g}}_t
\end{equation}
end
Output: Final parameters $\boldsymbol\theta_T$
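The update inside the loop of Algorithm 1 can be sketched in a few lines of NumPy. This is a minimal illustration of the clip-noise-average-step pattern, not the implementation from Abadi et al.; the function name `dp_sgd_step` and the convention of passing precomputed per-example gradients are our own assumptions.

```python
import numpy as np

def dp_sgd_step(theta, per_example_grads, C, sigma, lr, rng):
    """One DP-SGD update: clip each per-example gradient to L2 norm at most C,
    add Gaussian noise with standard deviation sigma*C, average over the
    batch, and take a gradient descent step."""
    noisy = []
    for g in per_example_grads:
        g = g / max(1.0, np.linalg.norm(g) / C)       # clip: norm <= C
        g = g + rng.normal(0.0, sigma * C, g.shape)   # add Gaussian noise
        noisy.append(g)
    g_bar = np.mean(noisy, axis=0)                    # average clipped, noisy grads
    return theta - lr * g_bar                         # descent step
```

Note that the clipping line leaves gradients with norm below $C$ untouched, while larger gradients are rescaled to have norm exactly $C$, matching the $\max[1, \Vert\mathbf{g}\Vert/C]$ denominator in the algorithm.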
In part I, we saw that DP mechanisms are sensitive to worst-case scenarios: a single data point that produces an arbitrarily large gradient could violate privacy by having an outsized influence on the final model. Hence, the magnitude of the gradient contribution from each data point is limited to be no more than $C$. Gaussian noise is then added to each gradient. As we also saw in part I, for a fixed level of differential privacy there is a connection between the amount of noise we must add and the amount of clipping: a larger value of $C$ clips less aggressively, but requires more noise to retain the same degree of privacy. This trade-off is made explicit in the algorithm.
The process of clipping, adding Gaussian noise, and averaging together the gradients in the mini-batch is differentially private. As we train, we combine the gradients from different mini-batches, and this too is differentially private due to the composition properties of differential privacy. However, this can significantly over-estimate how much privacy is lost.
Consequently, Abadi et al. (2016) developed a method to keep better track of privacy loss, which was termed the moments accountant. While the details are beyond the scope of this tutorial, the empirical results are shown in figure 1. For a given noise level and number of training steps, the moments accountant results in a much smaller value of $\epsilon$ and hence a much better guarantee on privacy. The learned model is identical whichever way we track the privacy loss, but for a given privacy budget, the moments accountant allows DP-SGD to be run for many more iterations.
We have established how to learn a machine learning model using a differentially private mechanism. However, we have not yet addressed the question of how this impacts model performance. Abadi et al. (2016) found that smaller privacy budgets tended to produce lower accuracy on the test set (figure 2). However, an interesting upside was that differentially private SGD improved generalization performance; the gap between training and testing accuracy tended to be smaller, particularly for smaller privacy budgets. Although this may initially appear counter-intuitive, it is perhaps not entirely surprising given that there are strong theoretical parallels between generalization and differential privacy and it is known that adding noise during training can improve generalization.
In the previous section, we described a differentially private version of stochastic gradient descent. While stochastic gradient descent is the most popular ML training algorithm, there are many other useful techniques, including boosting, decision trees, nearest neighbours, and others. Papernot et al. (2017) introduced a framework for differentially private learning known as Private Aggregation of Teacher Ensembles or PATE that allows any model to be used during training.
PATE works by making the predictions of the machine learning model differentially private instead of making the model itself differentially private. Consider a machine learning model $\mbox{f}[\cdot;\boldsymbol\theta]$ with parameters $\boldsymbol\theta$. The DP-SGD method described above trains the model $\mbox{f}[\cdot;\boldsymbol\theta]$ such that the parameters $\boldsymbol\theta$ are private and by extension any predictions that are derived from them. Alternatively, we could train the model using a standard algorithm. We cannot release the parameters $\boldsymbol\theta$ but we can release the predictions of the model if we add noise using a randomized response mechanism. However, every time we make a prediction, we spend privacy and so we could only ever make a finite number of predictions for a fixed privacy budget.
PATE combines the advantages of both of these approaches. It allows any model to be used to make the predictions, and allows that model to be queried an arbitrary number of times without leaking privacy. However, to accomplish this, it assumes that we have an additional trove of unlabelled public (non-sensitive) data. The overall idea is to use the private data to label this public data such that the labels are differentially private and then use this new labelled data to train any new model we want (figure 3).
In more detail, PATE divides the sensitive data into $N$ disjoint subsets $\mathcal{D}_{1}\ldots \mathcal{D}_{N}$, each of which is used to train an independent machine learning model $\mbox{f}_{n}[\mathbf{x}]$ known as a teacher. Then, these teachers are combined to create an aggregate teacher using a noisy voting model:
\begin{equation}
\mbox{f}[\mathbf{x}] = \underset{j}{\operatorname{argmax}} \left[ \sum_{n=1}^{N}\mathbb{I}\bigl[\mbox{f}_{n}[\mathbf{x}]=j\bigr] + \xi_j \right]. \tag{1}
\end{equation}
Here, we count how many of the $N$ teachers give answer $j$ and add independent noise $\xi_{j}\sim\mbox{Lap}_{\xi}[\epsilon^{-1}]$ from a Laplace distribution with parameter $b=\epsilon^{-1}$ to each count to get a noisy vote for class $j$. Then we choose the class with the largest vote.
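The noisy-max vote above can be sketched directly. This is a minimal illustration of the aggregation mechanism only (no teacher training); the function name `pate_aggregate` and its signature are our own assumptions.

```python
import numpy as np

def pate_aggregate(teacher_preds, n_classes, epsilon, rng):
    """Noisy-max vote: count teacher predictions per class, add Laplace
    noise with scale 1/epsilon to each count, and return the argmax class."""
    counts = np.bincount(teacher_preds, minlength=n_classes).astype(float)
    counts += rng.laplace(scale=1.0 / epsilon, size=n_classes)
    return int(np.argmax(counts))
```

With strong consensus among the teachers, the Laplace noise rarely flips the winning class; the mechanism only becomes unreliable (and, as discussed below, expensive in privacy terms) when the vote is close.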
This differentially private model can only make a finite number of predictions before expending its privacy budget. PATE uses these predictions to label the public (non-sensitive) data. This creates a differentially private dataset from which any type of model, known as the student, can be trained. By the post-processing property of differential privacy, this student model will also be differentially private.
Papernot et al. (2017) applied PATE on MNIST and SVHN and in a followup paper introduced new aggregation techniques, tighter privacy bounds, and applied the technique on more datasets. They showed that the models produced from the PATE framework either significantly outperformed other differentially private learning approaches in terms of accuracy or privacy budget required. In many cases, PATE was better on both fronts, producing higher test set accuracy with lower privacy budgets.
The PATE framework is straightforward, but it is not obvious why it needs multiple teachers. Couldn't we set $N=1$, train a single teacher on the entire dataset, and use a different privacy mechanism to label the data? We could, but a significant amount of privacy is gained by splitting the dataset and aggregating the outputs of the teachers.
To understand why, recall that differential privacy is based on what happens when we change a single element in our dataset. With data splitting and teacher aggregation, that single element change can only affect a single teacher and because of the voting mechanism, this is less likely to change the result; a change in a single element can only change the result when the voting is close.
This insight is exploited in the followup paper to derive a data-dependent bound on the privacy budget spent when labelling the public data. Labelling inputs that produce near ties expends significantly more privacy than labelling inputs on which the teachers clearly agree. This argues for more teachers: near ties in the voting are less likely when there are more teachers, and so less noise needs to be added during aggregation. The same idea can be used to reduce the privacy budget by allowing the aggregate teacher to refuse to label inputs where there isn't sufficient consensus.
Of course, there is a limit on the benefit of adding more teachers. As the number of teachers increases, the number of data points used to train each teacher decreases, and so does its accuracy; while having more teachers may reduce the amount of noise added to the student's labels, it will eventually decrease the accuracy of those labels.
In the previous section we introduced methods for building machine learning models that make predictions that are differentially private. We now turn to methods for generating data that are differentially private.
The Netflix Prize competition demonstrated that there is significant power in releasing data for public machine learning challenges. Netflix's own algorithm was beaten within days of the 2006 launch and the top prize was claimed in 2009. However, Narayanan and Shmatikov (2006) showed that it was possible to identify individual Netflix users using a linkage attack in which the dataset was matched with film ratings from the Internet Movie Database.
In fact, there are many situations where companies may wish to share data with each other without compromising privacy. For example, pooling medical records between hospitals might be particularly useful when the data pertains to rare diseases. In such situations, organizations wish to protect the privacy of the data without preventing their researchers from doing their jobs. One answer to this conundrum is not to release the real data at all. Instead, we can release synthetic data which looks and acts like the real data, but which has been generated in a differentially private manner.
A generative model takes a dataset and learns how to synthesize realistic new examples. This area of machine learning has seen enormous success in recent years. It is now possible to generate compelling but purely synthetic images of people, cats, horses, artwork and even chemicals.
A successful type of generative model is the Generative Adversarial Network (GAN). This consists of two networks: a generator, $G[\mathbf{z};\boldsymbol\theta]$, and a discriminator, $D[\mathbf{x};\boldsymbol\phi]$. The generator is a network with parameters $\boldsymbol\theta$ that takes a sample of random noise $\mathbf{z}$, and transforms this into a data point $\mathbf{x}$. The discriminator is a second network with parameters $\boldsymbol\phi$ that takes a data point $\mathbf{x}$ and seeks to classify that data point as being either real data or fake (i.e., created by the generator).
Training a GAN can be thought of as a two-player game in which the discriminator tries to identify which data points are fake, and the generator tries to generate data points which fool the discriminator. The players take turns updating their parameters by taking steps of stochastic gradient descent based on the current state of the other player. The generator takes a few steps of SGD trying to minimize
\begin{equation}
\min_{\boldsymbol\theta} \left[-\mathbb{E}_{\mathbf{z}}[\log D[G[\mathbf{z};\boldsymbol\theta];\boldsymbol\phi]]\right],\tag{2}
\end{equation}
followed by the discriminator taking steps of SGD trying to maximize
\begin{equation}
\max_{\boldsymbol\phi} \left[\mathbb{E}_{\mathbf{x}}[\log D[\mathbf{x};\boldsymbol\phi]] + \mathbb{E}_{\mathbf{z}}[\log\left(1 - D[G[\mathbf{z};\boldsymbol\theta];\boldsymbol\phi]\right)]\right]. \label{eq:GAN_D}\tag{3}
\end{equation}
Here, the expectations $\mathbb{E}_\mathbf{x}$ and $\mathbb{E}_\mathbf{z}$ are approximated by taking samples from the training data and noise distributions respectively. Once this game has converged, samples can be generated by sampling a noise vector $\mathbf{z}$ and feeding it through the generator $G[\mathbf{z};\boldsymbol\theta]$.
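The two objectives can be sketched as Monte Carlo estimates over a batch. This is a minimal illustration, not a full training loop: `D` and `G` are assumed to be callables mapping arrays to arrays, and the discriminator objective is written in its common bounded form $\mathbb{E}[\log D(\mathbf{x})] + \mathbb{E}[\log(1 - D(G(\mathbf{z})))]$.

```python
import numpy as np

def generator_loss(D, G, z_batch):
    """Generator objective: minimize -E_z[ log D(G(z)) ]."""
    return -np.mean(np.log(D(G(z_batch))))

def discriminator_objective(D, G, x_batch, z_batch):
    """Discriminator objective: maximize
    E_x[ log D(x) ] + E_z[ log(1 - D(G(z))) ]."""
    return np.mean(np.log(D(x_batch))) + np.mean(np.log(1.0 - D(G(z_batch))))
```

In practice each player takes a few SGD steps on its own objective while the other player's parameters are held fixed.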
In the previous section we saw that GAN training is based on stochastic gradient descent. DP-GAN uses the differentially private SGD method of Abadi et al. (2016) to learn a GAN: it applies gradient clipping, adds Gaussian noise as described previously, and uses the moments accountant to track the privacy budget.
Note that the real data (whose privacy we want to protect) is only present in the discriminator where real samples are compared to the generated ones. Consequently, we update the discriminator with DP-SGD. The generator is updated only based on the discriminator and hence can be updated using standard SGD. This will be differentially private based on the post-processing theorem and any generated samples, or even the entire generator itself, can be released without further compromising the privacy of the data.
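A single DP-GAN round can be sketched as follows, with the private treatment confined to the discriminator. This is an illustrative sketch under our own assumptions (the function name `dp_gan_step` and the convention of passing precomputed gradients are not from the DP-GAN paper).

```python
import numpy as np

def dp_gan_step(phi, theta, disc_grads, gen_grad, C, sigma, lr, rng):
    """One DP-GAN round. The discriminator (phi) touches real data, so its
    per-example gradients are clipped and noised as in DP-SGD. The generator
    (theta) only sees the discriminator, so a plain SGD step suffices by
    the post-processing property."""
    noisy = [g / max(1.0, np.linalg.norm(g) / C)
             + rng.normal(0.0, sigma * C, g.shape) for g in disc_grads]
    phi = phi - lr * np.mean(noisy, axis=0)   # private discriminator update
    theta = theta - lr * gen_grad             # standard generator update
    return phi, theta
```

The key design point is the asymmetry: privacy accounting only needs to track the discriminator updates, since everything downstream of them is post-processing.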
Unfortunately, the performance of DP-GAN is relatively poor in practice. Large privacy budgets were required to produce plausible results even on relatively simple datasets like MNIST, and images generated under plausible privacy budgets were of low quality. Subsequent work has improved matters: Frigerio et al. (2019) developed better gradient clipping strategies while applying DP-GAN to a broader range of data domains, and Torkzadehmahani et al. (2019) introduced an improved privacy accounting method based on Rényi differential privacy and used their framework to train conditional GANs.
In the previous section, we argued that using DP-SGD to train GANs is possible but does not produce good results for reasonable privacy budgets. PATE-GAN is an alternative approach that leverages the PATE framework to train differentially private GANs.
As we saw above, it is only the discriminator which needs to be differentially private; the generator is updated based on the discriminator and so inherits the differential privacy due to the post-processing property. Since the discriminator is just a classifier which labels its input as either "real" or "fake", we could replace this with the student model from a PATE-trained classifier.
Recall that PATE requires a non-sensitive, unlabelled dataset which is used to query the differentially private aggregated teacher model to produce labels. The dataset and its labels are then used to train the student classifier. In PATE-GAN, we use the generator in place of this unlabelled dataset. The samples from the generator are used to query the teachers and the resulting labels are used to train the student.
In a GAN, the generator and discriminator must be trained in concert. However, training the discriminator is no longer as simple as a few SGD steps. Instead, PATE-GAN alternates between (i) using the generator to produce samples which are then combined with real data to update the teachers, (ii) using the generator to produce samples which are then used to query the (updated) teachers to produce labelled samples which are used to update the student model, and then finally (iii) using the student model to update the generator (figure 4).
The performance of PATE-GAN was not primarily evaluated in terms of the sample quality. Instead, the authors measure performance for the scenario where a real dataset is replaced with a synthetic one. To this end, the generator was used to produce a training set, which was used to train a new model, and finally tested against real data.
Using this criterion, performance with PATE-GAN was significantly better than with DP-GAN given the same privacy budget, although it was worse than training directly on the real data for some datasets. However, performance was only slightly worse than using the same method with a non-differentially private GAN. This implies that the performance gap is not specific to differential privacy, but rather a limitation of current generative models.
Private data generation is known to be theoretically hard in the worst case. However, the properties of non-convex optimization suggest that theoretically deep learning should be nearly impossible and yet we do it in practice every day. Likewise, it seems that private data generation may also be practically feasible. There is much research to be done in this area. There are a plethora of other generative models beyond GANs (VAEs, normalizing flows, implicit likelihood models, etc.) which are ripe for DP analysis and many other open questions. Nonetheless, algorithms like PATE-GAN already provide practically useful differentially private data generation which may be sufficient in some cases. One useful resource in this area is our open-source toolbox that implements PATE-GAN and several other methods for private data generation.
Part I of this tutorial introduced differential privacy generally and showed basic examples of differentially private data analysis. In part II we have discussed state-of-the-art techniques for applying modern machine learning in a differentially private manner. Whether the goal is to learn a single, private machine learning model from sensitive data, or to enable the release of new datasets, there are differentially private options. Machine learning with sensitive data does not have to come at the expense of privacy.