Teaching attention, part 4 / N: Soft Attention
There is more to memory than compression. A reduction of information can have structure – for example, in a Gaussian mixture model, the sufficient statistics are the means within each cluster. Depending on which cluster you are visiting, you might focus on a different one of the means. A good compression should induce an effortless way of searching through different parts of a summary. In essence,
Good models let us navigate information.
In deep learning, this process of navigating information is called “attending to” different inputs. We’ll illustrate a canonical attention mechanism in the context of sentence translation, but it’s worth noting that the principle applies to many settings with structured outputs (like graphs), where the influences from input to output are somehow localized.
In the abstract, sentence translation is a prediction problem where the inputs are sequences and the outputs are also sequences. The input and output are the same sentence in a source and a target language, respectively. Each element of the sequence is a word in the sentence. Words can be represented either by one-hot encodings from some vocabulary (usually, the words in a corpus above some minimum frequency), or by embeddings derived from some previously trained model on another corpus. This is illustrated in the figure below.
Once you make the leap into representing words as numerical vectors, it’s natural to think of sentences as curves in some high-dimensional space. The two sentences are just two curves in two spaces, and you need to learn to predict the second from the first. The reason this problem is doable is that, after seeing enough training examples, you might associate certain subsequences (mini curve shapes) in the source space with subsequences in the target space.
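To make the two word representations and the "sentence as a curve" picture concrete, here is a minimal sketch. The vocabulary, the `<unk>` convention, and the embedding values are all made up for illustration; in practice the embedding table would come from a previously trained model.

```python
# Toy vocabulary; index 0 is reserved for out-of-vocabulary words (an assumption).
vocab = ["<unk>", "the", "cat", "sat"]
word_to_index = {w: i for i, w in enumerate(vocab)}


def one_hot(word):
    """Return a one-hot list with a 1 at the word's vocabulary index."""
    vec = [0.0] * len(vocab)
    vec[word_to_index.get(word, 0)] = 1.0  # unknown words map to "<unk>"
    return vec


# Hypothetical 2-dimensional embedding table, one row per vocabulary entry.
embeddings = [[0.0, 0.0], [0.1, 0.3], [0.7, -0.2], [-0.4, 0.5]]


def embed(word):
    """Look up the embedding row for a word (row 0 for unknown words)."""
    return embeddings[word_to_index.get(word, 0)]


# A sentence becomes a sequence of points, i.e. a discrete curve.
sentence = ["the", "cat", "sat"]
one_hot_curve = [one_hot(w) for w in sentence]    # 3 points in a 4-d space
embedded_curve = [embed(w) for w in sentence]     # 3 points in a 2-d space
```

Either list of vectors can serve as the input sequence \(x_1, \dots, x_T\) in what follows.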
Sequence-to-Sequence Architecture
I’m tempted to call the first approach to this problem the “naive approach,” except that it doesn’t feel right to call something naive that was state-of-the-art only a couple of years ago [1].
The idea is to build a link between the two sequence shapes, using an appropriate learned representation. You can think of it as a two-step process:
(1) Build a memory of the first sequence, which you’ll be able to refer to later on. (2) Learn to associate the memory of the first sequence with a reasonable shape for the second.
Step (1) is done with an RNN encoder, just like what we’ve been exploring in the previous posts. The final state \(h_{T}\) (the final green bar) gives a succinct representation of the entire sequence (the red dots). More than being an unsupervised representation, it’s one that’s been tailored exactly for the problem of decoding into the target language.
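As a rough sketch of step (1), the encoder can be written as a loop that keeps every intermediate summary. This assumes a plain tanh RNN cell with randomly initialized weights; the helper names (`matvec`, `rnn_step`, `encode`) are hypothetical, and a real system would more likely use a gated, trained cell.

```python
import math
import random


def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def rnn_step(x, h_prev, W_x, W_h, b):
    """Plain RNN update: h = tanh(W_x x + W_h h_prev + b)."""
    a, r = matvec(W_x, x), matvec(W_h, h_prev)
    return [math.tanh(ai + ri + bi) for ai, ri, bi in zip(a, r, b)]


def encode(xs, W_x, W_h, b):
    """Run the RNN over the inputs and return all summaries h_1, ..., h_T."""
    h, all_h = [0.0] * len(b), []  # start from a zero state
    for x in xs:
        h = rnn_step(x, h, W_x, W_h, b)
        all_h.append(h)
    return all_h


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_in, d_h, T = 3, 4, 5                     # toy sizes
W_x, W_h = rand_matrix(d_h, d_in, rng), rand_matrix(d_h, d_h, rng)
b = [0.0] * d_h
xs = rand_matrix(T, d_in, rng)             # stand-in for T word vectors

all_h = encode(xs, W_x, W_h, b)
h_T = all_h[-1]                            # the summary handed to the decoder
```

Only `h_T` is used by the plain sequence-to-sequence decoder; the full list `all_h` becomes relevant once attention enters.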
The second step is also accomplished by an RNN, but with a couple of twists. As in usual RNNs, you build up a summary, let’s call it \(\tilde{h}_t\), based on the previous timepoint’s summary, and also a new input, the previous target word \(y_t\). The idea is that the summary is giving a sense of what the sequence means so far, which is useful for figuring out how it will evolve (certain shapes are much more common than others). The first twist is that you are also allowed to look at the summary of the source sequence, the original \(h_T\) (practically, this is just another linear input to the RNN unit).
The second twist is that the inputs are not properly specified, because they aren’t actually available at test time. If you knew \(y_t\), you wouldn’t need to bother with prediction in the first place! Instead, you should build a separate prediction model from \(\tilde{h}_t\) to \(\hat{y}_t\) and use the predictions as the actual inputs (an alternative, called “teacher forcing,” uses the real \(y_t\) during training; it can have advantages, but comes with its own issues). This process is illustrated below. The dashed lines are the predictions that you have access to, and the length of the yellow curves gives a sense of the loss.
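The difference between feeding back predictions and teacher forcing is easiest to see in the loop itself. The sketch below is only a schematic: `step` and `predict` are hypothetical stand-ins that work on integers so the control flow is easy to trace, not real RNN or prediction layers.

```python
def decode(h_T, T_prime, step, predict, targets=None):
    """Unroll the decoder for T_prime steps.

    If `targets` is given, use teacher forcing (feed the true word);
    otherwise feed the model's own prediction back in as the next input.
    """
    h_tilde, y_hat = h_T, []
    for t in range(T_prime):
        y_hat.append(predict(h_tilde))
        fed = targets[t] if targets is not None else y_hat[-1]
        h_tilde = step(fed, h_tilde, h_T)  # h_T is an extra input at every step
    return y_hat


# Toy integer stand-ins (assumptions, for tracing only).
def step(y, h_tilde, h_T):
    return h_tilde + y + h_T


def predict(h_tilde):
    return h_tilde % 5


free_running = decode(1, 3, step, predict)                      # [1, 3, 2]
teacher_forced = decode(1, 3, step, predict, targets=[0, 2, 0])  # [1, 2, 0]
```

The two runs share the first prediction and then diverge, since the decoder state depends on which input it was fed.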
Attention
The decoding process above puts a lot of burden on the final encoded \(h_T\) from the first sequence. This seems easily avoidable – it’s not like we need to throw away all the previous summaries because of computational constraints. Moreover, you would expect that the summaries \(h_t\) and \(\tilde{h}_t\) might be closely related in timepoints that are close together. The beginning of the source and target sentences should mean essentially the same thing, even though they are in different languages.
Soft attention is the name of a type of learning module designed directly around this observation [2]. The idea is to use a different summary at each step of decoding. Instead of only ever using \(h_T\), we’re allowed to use some combination \(\sum_{i} \alpha_i h_i\) of the encoder summaries, where the weights \(\alpha_i\) depend on the timepoint of the decoder. These weights are illustrated by the pink distributions in the figure below.
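The combination \(\sum_i \alpha_i h_i\) is just a weighted average of the encoder summaries. A minimal sketch, with made-up summaries and weights (the function name `context_vector` is an assumption):

```python
def context_vector(alpha, all_h):
    """Return sum_i alpha[i] * all_h[i], computed coordinate by coordinate."""
    d = len(all_h[0])
    return [sum(a * h[j] for a, h in zip(alpha, all_h)) for j in range(d)]


# Three made-up encoder summaries h_1, h_2, h_3 in two dimensions.
all_h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Putting all the mass on the last position recovers the plain decoder's h_T.
only_last = context_vector([0.0, 0.0, 1.0], all_h)

# Splitting the mass over the first two positions blends their summaries.
blend = context_vector([0.5, 0.5, 0.0], all_h)
```

In this sense the ordinary decoder is the special case where the weights never move off the final position.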
At the next timepoint, the attention distribution is free to change, according to the new position of the decoder. This extra degree of flexibility – the fact that the RNN decoder has access to different encoder summaries at different timepoints, not just the unchanging \(h_T\) – is the property that distinguishes a decoder with an attention mechanism from an ordinary decoder. An updated attention distribution, associated with the next decoder timepoint, is shown in the figure below.
Exercise: Write out pseudocode for the forward pass of a sequence-to-sequence model with an attention mechanism in the decoder.
How does the model know what \(\alpha_i\) distributions are most appropriate for the different sentences? After all, it must be learned from the data, somehow. There is no single consensus on how to define these weights, and many papers have been written specifically around novel proposals for guiding attention effectively. That said, two approaches are representative of a wide class of mechanisms [3],
- Local attention: At step \(t\) in decoding, define \(\alpha_i \propto \exp\left(w_i^{T} \tilde{h}_t\right)\) for \(i = 1, \dots, T\), where the \(w_i\) are learned weight vectors, one per source position. Or, in pseudocode, the \(\alpha\) vector can be computed as
    a = softmax(l(h_tilde_t))
  where l = nn.Linear(dim(h_tilde_t), T). This corresponds to the “location” score in [3].
- Content-based attention: At step \(t\) in decoding, define \(\alpha_i \propto \exp\left(h_i^{T} \tilde{h}_t\right)\) for \(i = 1, \dots, T\). In pseudocode, this is
    a = softmax(H_enc @ h_tilde_t)
  where H_enc is formed by stacking the summaries \(h_1, \dots, h_T\) as rows. This corresponds to the “dot” score in [3].
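The two definitions above can be sketched side by side. This is an illustration under stated assumptions: the function names are hypothetical, the local-attention weight matrix `W` would normally be learned (here it is set to zeros), and the encoder summaries are made up.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def local_attention(h_tilde, W):
    """alpha_i proportional to exp(w_i . h_tilde); W has one row per source position."""
    return softmax([dot(w_i, h_tilde) for w_i in W])


def content_attention(h_tilde, H_enc):
    """alpha_i proportional to exp(h_i . h_tilde); H_enc stacks encoder summaries as rows."""
    return softmax([dot(h_i, h_tilde) for h_i in H_enc])


H_enc = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]]   # made-up encoder summaries
h_tilde = [1.0, 0.0]                           # made-up decoder summary

alpha_content = content_attention(h_tilde, H_enc)

# With all-zero weights every score ties, so local attention is uniform.
W = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
alpha_local = local_attention(h_tilde, W)
```

Here the content-based weights are largest at the third source position, whose summary has the largest inner product with `h_tilde`, while the local weights never look at `H_enc` at all.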
The main difference is that local attention does not refer to the encoder summaries \(h_t\). It puts high mass on some parts of the source sentence depending solely on the state of the decoder \(\tilde{h}_t\). The decoder summaries tend to reflect, in some of their coordinates, how far along into the target sentence they are. So, these attention distributions tend to just place high mass on timepoints in the encoder close to the current timepoint in the decoder.
The content-based attention mechanism, on the other hand, allots weight \(\alpha_i\) according to how similar the source and target encodings are (for vectors of similar length, a large inner product corresponds to a small angle). This means that if there are words in the source sentence whose summaries \(h_i\) are similar to the current decoder summary \(\tilde{h}_t\), then those words will be upweighted, no matter their position in the sentence. This is qualitatively illustrated in the figure below. The content-based approach places high mass near the end of the encoder sentence, because it is close (in the embedding space) to the meaning at the decoder timepoint. This is in spite of the fact that those words occur at a timepoint very different from the current decoder timepoint.
Discussion
No matter how we define the \(\alpha_i\)’s, the fact that we learn them has a nice consequence: for free, we’ve learned a soft-alignment between the source and target sentence. At any timepoint in decoding, the encodings with high \(\alpha_i\) are said to “softly-align” to that timepoint. In this way, you might notice “inversions,” where the decoder doesn’t just proceed linearly down the encoder, placing all mass on \(h_t\) when decoding \(\tilde{h}_t\). This is common in translation between languages that put adjectives before / after nouns, for example.
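One way to read off the soft alignment is to stack the attention weights from every decoder step into a \(T' \times T\) matrix and look at where each row concentrates. The sketch below uses content-based weights and made-up two-dimensional summaries for a noun-adjective source and an adjective-noun target; none of the numbers come from a trained model.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Made-up summaries: coordinate 0 ~ "noun-like", coordinate 1 ~ "adjective-like".
H_enc = [[2.0, 0.0], [0.0, 2.0]]   # source order: noun, adjective
H_dec = [[0.0, 2.0], [2.0, 0.0]]   # target order: adjective, noun

# Row t holds the attention weights used at decoder step t.
alignment = [softmax([dot(h_i, h_tilde) for h_i in H_enc]) for h_tilde in H_dec]

# Source position with the most mass at each decoder step.
hard_alignment = [max(range(len(row)), key=row.__getitem__) for row in alignment]
```

The resulting `hard_alignment` is `[1, 0]`: the first target word aligns to the second source word and vice versa, which is exactly the kind of inversion described above.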
A further observation – including this mechanism is especially beneficial when working with long sentences. This phenomenon is documented in the literature [2], and it makes sense in retrospect: long sentences are more susceptible to the vanishing gradient problem, and even with gating, the final \(h_T\) tends to mostly reflect contributions from the end of the source sentence. An approach with an attention mechanism, in contrast, can always draw on summaries taken near the start of the sentence, by putting most of the mass of the \(\alpha_i\) there.
Finally, while we’ve focused on attention in a sequence to sequence context, it’s worth noting that these principles extend much more generally. Whenever you have a structured output or collection of tasks, it’s possible to refer to summaries selectively, attending to only those that are most relevant for the current element in the more complex structure / collection. For example, in a graph problem, only the nearby nodes might be worth attending to. The general takeaway is that, whenever you might be able to exploit specific structure in the relationship between summaries \(h_t\), you can design an attention mechanism that makes learning from them simpler.
Exercise Solution:
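# Encode: run the encoder RNN over the source, keeping every summary
h, all_h = h_0, []  # h_0 is an initial state, e.g. all zeros
for t in range(T):
    h = rnn_encoder(x[t], h)
    all_h.append(h)
H_enc = stack(all_h)  # shape (T, d): one encoder summary per row

# Decode with attention, feeding predictions back in as inputs
h_tilde, y_hat = all_h[-1], []
for t in range(T_prime):
    y_hat.append(pred(h_tilde))
    alpha = attn_weights(y_hat[-1], h_tilde, H_enc)  # shape (T,), sums to 1
    context = alpha @ H_enc  # the weighted combination sum_i alpha_i h_i
    h_tilde = rnn_decoder(y_hat[-1], h_tilde, context)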
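For readers who want something they can actually run, here is one possible concrete version of the pseudocode. It is a sketch under several assumptions that are not part of the exercise: plain tanh RNN cells, content-based attention that ignores the previous prediction, greedy argmax prediction, random untrained weights, and hypothetical parameter names (`W_xh`, `U_c`, and so on).

```python
import math
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(W, v):
    return [dot(row, v) for row in W]


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def forward(xs, T_prime, p):
    """Forward pass: encode xs, then decode T_prime steps with attention."""
    d = len(p["W_hh"])
    # Encode: keep every summary h_1, ..., h_T.
    h, all_h = [0.0] * d, []
    for x in xs:
        pre = zip(matvec(p["W_xh"], x), matvec(p["W_hh"], h))
        h = [math.tanh(a + b) for a, b in pre]
        all_h.append(h)
    # Decode: predict, attend, then update the decoder summary.
    h_tilde, y_hat, alphas = all_h[-1], [], []
    for _ in range(T_prime):
        logits = matvec(p["W_out"], h_tilde)
        y = max(range(len(logits)), key=logits.__getitem__)  # greedy prediction
        y_hat.append(y)
        alpha = softmax([dot(h_i, h_tilde) for h_i in all_h])
        context = [sum(a * h_i[j] for a, h_i in zip(alpha, all_h)) for j in range(d)]
        pre = zip(matvec(p["U_y"], p["E"][y]), matvec(p["U_h"], h_tilde),
                  matvec(p["U_c"], context))
        h_tilde = [math.tanh(a + b + c) for a, b, c in pre]
        alphas.append(alpha)
    return y_hat, alphas


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_in, d, d_emb, vocab_size, T, T_prime = 3, 4, 3, 6, 5, 4   # toy sizes
params = {
    "W_xh": rand_matrix(d, d_in, rng), "W_hh": rand_matrix(d, d, rng),
    "W_out": rand_matrix(vocab_size, d, rng), "E": rand_matrix(vocab_size, d_emb, rng),
    "U_y": rand_matrix(d, d_emb, rng), "U_h": rand_matrix(d, d, rng),
    "U_c": rand_matrix(d, d, rng),
}
xs = rand_matrix(T, d_in, rng)   # stand-in for T source word vectors
y_hat, alphas = forward(xs, T_prime, params)
```

Since the weights are untrained, the predicted word indices in `y_hat` are meaningless; the point is only that each row of `alphas` is a distribution over the `T` source positions and that it is recomputed at every decoder step.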
[1] Sutskever, I., Vinyals, O., & Le, Q. V. (2014). Sequence to sequence learning with neural networks. In Advances in neural information processing systems (pp. 3104-3112).
[2] Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
[3] Luong, M. T., Pham, H., & Manning, C. D. (2015). Effective approaches to attention-based neural machine translation. arXiv preprint arXiv:1508.04025.
Kris at 09:04