Attention


Introduction

Until 2014, recurrent neural networks (RNNs) were the default choice for modeling sequential tasks using deep learning. Proposed in 2014, attention models have quickly become an important alternative. Originally developed for machine translation, they have been widely applied in diverse application domains such as speech, computer vision, and natural language processing.

In this article, we provide an intuition and a mathematical formulation for the implementation of the attention mechanism. Over the years, several variants of this basic attention idea have been proposed; we cover the various types of attention in a separate article.

Prerequisites

To understand the attention model, we recommend familiarity with the concepts in

Follow the above links to first get acquainted with the corresponding concepts.

Sequence-to-sequence models: A recap

Many machine learning tasks involve the transformation of one sequence into another. For example, machine translation transforms text in one language into text in another language. A popular deep learning strategy for modeling such sequence-to-sequence tasks is the sequence-to-sequence model, also known as the seq2seq model. A seq2seq model consists of two subcomponents: an encoder network and a decoder network.

The encoder encodes a \(\tau\)-length sequence of input tokens \( \seq{\vx^{ (1)}, \ldots, \vx^{(\tau)}} \) into a sequence of codes (vectors) \( \seq{\ve^{(1)},\ldots,\ve^{(\tau)}} \) such that each subsequent encoding is a function of the previous encoding and the input at that step.

$$ \ve^{(t)} = f_e(\ve^{(t-1)}, \vx^{(t)}), ~~\forall t=1,\ldots,\tau $$

where \( f_e \) is the encoding function learned by the encoder. For example, the encoder can be an LSTM, as used by CITE[sutskever-2014].
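To make the recurrence concrete, here is a minimal sketch of \( f_e \). It assumes a plain tanh RNN cell, e(t) = tanh(W e(t-1) + U x(t)), rather than the LSTM mentioned above; the function rnn_encoder, the matrices W and U, and the toy inputs are hypothetical illustrations, not code from any paper.

```python
import math


def matvec(M, v):
    """Multiply a matrix M (a list of rows) by a vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def rnn_encoder(xs, W, U, e0):
    """Apply e(t) = tanh(W e(t-1) + U x(t)) over the input sequence xs.

    Returns the list of encodings e(1), ..., e(tau).
    """
    encodings = []
    e = e0
    for x in xs:
        # f_e combines the previous encoding with the current input token.
        e = [math.tanh(a + b) for a, b in zip(matvec(W, e), matvec(U, x))]
        encodings.append(e)
    return encodings


# Hypothetical toy parameters: 2-dimensional inputs and encodings.
W = [[0.5, 0.0], [0.0, 0.5]]
U = [[1.0, 0.0], [0.0, 1.0]]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
encodings = rnn_encoder(xs, W, U, e0=[0.0, 0.0])  # one encoding per input token
```

Each encoding depends on every earlier input only through the previous encoding, which is exactly the recurrence in the equation above.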

The encodings are passed on to the decoder, summarized as a context vector, for output sequence generation. The context vector is some function \( f_c \) of the sequence of encodings,

$$ \vc = f_c(\set{\ve^{(1)},\ldots,\ve^{(\tau)}}) $$

For example, a simple strategy, used in CITE[sutskever-2014], is to use the final encoding \( \ve^{(\tau)} \) as the context vector.

The decoder utilizes the context vector \( \vc \) to generate the output sequence \( \seq{\vy^{(1)}, \ldots, \vy^{(\dash{\tau})}} \). The decoder achieves this transformation by learning the following function.

\begin{equation} \vy^{(t)} = g_d(\vc, \vy^{(t-1)}, \vd^{(t)}), ~~\forall t=1,\ldots,\dash{\tau} \label{eqn:decoder-output} \end{equation}

where \( g_d \) is the decoding function learned by the model and \( \vd^{(t)} \) is the internal state of the decoder. The decoder state itself follows a recurrence relation with its previous state:

\begin{equation} \vd^{(t)} = f_d(\vc, \vy^{(t-1)}, \vd^{(t-1)}), ~~\forall t=1,\ldots,\dash{\tau} \label{eqn:decoder-state} \end{equation}

For example, the decoder could be modeled as an LSTM-based RNN, with \( \vd^{(t)} \) as the hidden state of that RNN.
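The two decoder equations can be sketched the same way. This is a minimal illustration under simplifying assumptions: \( f_d \) is a plain tanh cell, \( g_d \) is a linear readout of the concatenation of c, y(t-1) and d(t), and the continuous outputs are fed straight back in (a real translation decoder would emit a token and feed back its embedding). The function decode and the parameter names A, B, D and G are hypothetical. Note that the same c is reused at every step.

```python
import math


def matvec(M, v):
    """Multiply a matrix M (a list of rows) by a vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def decode(c, y0, d0, A, B, D, G, steps):
    """Generate `steps` outputs from a single, fixed context vector c.

    State update (f_d): d(t) = tanh(A c + B y(t-1) + D d(t-1))
    Output (g_d):       y(t) = G [c; y(t-1); d(t)]
    """
    outputs = []
    y, d = y0, d0
    for _ in range(steps):
        d = [math.tanh(p + q + r)
             for p, q, r in zip(matvec(A, c), matvec(B, y), matvec(D, d))]
        # c + y + d is list concatenation, i.e. the vector [c; y(t-1); d(t)].
        y = matvec(G, c + y + d)
        outputs.append(y)
    return outputs


# Hypothetical 1-dimensional toy example; the output length (3) is chosen
# independently of any input length.
ys = decode(c=[1.0], y0=[0.0], d0=[0.0],
            A=[[1.0]], B=[[0.0]], D=[[0.0]], G=[[0.0, 0.0, 1.0]], steps=3)
```

In practice the loop usually stops when an end-of-sequence token is produced rather than after a fixed number of steps.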

The length of the output sequence, \( \dash{\tau} \), may differ from the length of the input sequence, \( \tau \).

The encoder and decoder networks are commonly implemented as recurrent neural networks (RNNs) with specialized cells such as LSTM or GRU. The seq2seq model is trained on supervised examples, pairs of input and target sequences, to jointly fit the encoder and decoder.

Challenges with the encoder-decoder architecture

The seq2seq encoder-decoder architecture may seem intuitive and straightforward, but it faces two main challenges.

  • All the information from the input sequence is compressed into a single context vector \( \vc \). The vector \( \vc \) is a fixed-length vector and may not be expressive enough to capture all the context from the input sequence, leading to information loss and consequently poor performance of the decoder.
  • There is no inherent modeling of the alignment between the input and the output sequence since the encoder and decoder only communicate through the context vector \( \vc \). In most sequential mapping tasks such as automatic translation, the output sequence elements depend on specific subsequences of the input sequence. This correspondence may not be position aligned but might be semantically and contextually dependent. The encoder-decoder architecture is unable to model this since the communication uses the intermediate language of the encoding \( \vc \), losing this direct alignment and correspondence between input and output elements.

The design of attention models is motivated by the need to overcome these challenges with traditional encoder-decoder networks.

Attention models: Intuition

The challenges mentioned earlier can be overcome by providing the decoder with a custom context vector for each output of the sequence. A good custom context vector should provide information from the relevant steps of the encoding sequence, and the relevance of an encoding should depend on the current state of the decoder.

To achieve such selective behavior for each output token, the model can instead focus its attention on specific elements of the sequence of encodings. This is achieved by inducing attention weights for each step of the sequence of encodings.

The attention weights are used to build a linear combination of the encoding sequence — a vector known as the context vector. The context vector is the mechanism by which the decoder can have a complete, but focused, overview of the entire sequence of encodings.

The attention weights are not fixed parameters of the model; they are computed by an additional feedforward neural network whose parameters are learned jointly with the rest of the encoder-decoder architecture.

This is the central idea behind the attention mechanism. Now let's study the specifics.

Attentive decoder

As motivated in the previous section, instead of a single context vector, we need a custom context vector for each state of the decoder. In other words, we change the decoder function from Equation \eqref{eqn:decoder-output} to include a decoder-state-specific context vector \( \vc^{(t)} \), that we shall call the attentive context vector, as follows

\begin{equation} \vy^{(t)} = g_d(\vc^{(t)}, \vy^{(t-1)}, \vd^{(t)}), ~~\forall t=1,\ldots,\dash{\tau} \label{eqn:decoder-output-attention} \end{equation}

Similarly, the decoder state recurrence relation from Equation \eqref{eqn:decoder-state} should also be updated to incorporate the attentive context vector as follows

\begin{equation} \vd^{(t)} = f_d(\vc^{(t)}, \vy^{(t-1)}, \vd^{(t-1)}), ~~\forall t=1,\ldots,\dash{\tau} \label{eqn:decoder-state-attention} \end{equation}

How do we compute this attentive context vector? Let's see next.

Attentive context

The attentive context vector is a linear combination of the sequence of encodings, weighted by state-specific weights known as attention weights. Mathematically, we express this as

$$ \vc^{(t)} = \sum_{\dash{t}=1}^{\tau} \alpha_{t\dash{t}} \ve^{(\dash{t})} $$

Note that the weights \( \alpha_{t\dash{t}} \) are specific to the decoder time-step \( t \) and the encoder time-step \( \dash{t} \).

The next question: how do we compute the attention weights? In deep learning, there is always another network for that!

The alignment model

Attention weights are computed by normalizing the scores of a feedforward network known as the alignment model. The alignment model learns to predict these scores as a function of the previous decoder state and the encodings. Let's denote the function learned by the alignment model as \( f_{\text{AM}} \).

$$ a_{t\dash{t}} = f_{\text{AM}}(\vd^{(t-1)}, \ve^{(\dash{t})}) $$

The alignment model is also known as a compatibility function because it infers the compatibility of the decoder state with each position in the input sequence. The alignment scores are then normalized with a softmax over the input positions.

$$ \alpha_{t\dash{t}} = \frac{\textexp{a_{t\dash{t}}}}{\sum_{k=1}^{\tau} \textexp{a_{tk}}} $$
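The softmax normalization is short enough to write out directly. The sketch below is an illustration rather than code from the article: it subtracts the largest score before exponentiating, a standard trick that avoids overflow (math.exp(1002.0) on its own raises OverflowError) and leaves the weights unchanged, because softmax is invariant to adding the same constant to every score. The name softmax and the sample scores are hypothetical.

```python
import math


def softmax(scores):
    """Turn alignment scores a_t1, ..., a_t,tau into attention weights alpha_t."""
    m = max(scores)  # shift for numerical stability; does not change the result
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [x / total for x in exps]


alpha_t = softmax([2.0, 1.0, 0.1])           # made-up alignment scores
shifted = softmax([1002.0, 1001.0, 1000.1])  # the same scores plus 1000
```

The weights are positive, sum to one, and preserve the ordering of the scores, so the highest-scoring encoder position receives the most attention.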

Thus, we can also succinctly write the collection of weights \( \alpha_{t\dash{t}} \) as a vector \( \vec{\alpha}_t = [\alpha_{t1}, \ldots, \alpha_{t\tau}] \) and express it as a softmax over the outputs of the function \( f_{\text{AM}} \).

$$ \vec{\alpha}_{t} = \text{softmax} \left( f_{\text{AM}}\left(\vd^{(t-1)}, \seq{\ve^{(1)}, \ldots, \ve^{(\tau)}} \right) \right) $$

With this representation, the context vector \( \vc^{(t)} \) is the product of the row vector \( \vec{\alpha}_t \) and the matrix \( \mE = [\ve^{(1)}, \ldots, \ve^{(\tau)}]^T \), whose rows are the encodings.

$$ \vc^{(t)} = \vec{\alpha}_t \mE $$
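As a quick check of the matrix form, the sketch below stores \( \mE \) as a list of rows (one encoding per row) and computes \( \vec{\alpha}_t \mE \). Each component of the result is a weighted sum down one column of \( \mE \), which is exactly the weighted sum of encodings. The function context_vector and the numbers are hypothetical.

```python
def context_vector(alpha_t, E):
    """Compute c(t) = alpha_t E.

    alpha_t: attention weights, one per encoder step (length tau).
    E: tau x h matrix stored as a list of rows; row t' is the encoding e(t').
    """
    h = len(E[0])
    # Component j is the alpha-weighted sum of column j of E.
    return [sum(a * row[j] for a, row in zip(alpha_t, E)) for j in range(h)]


E = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]
c_t = context_vector([0.5, 0.25, 0.25], E)    # -> [0.75, 0.5]
focused = context_vector([0.0, 1.0, 0.0], E)  # all attention on e(2) -> [0.0, 1.0]
```

A one-hot weight vector simply selects a single encoding, while spread-out weights blend several of them.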

The attention function

As another perspective on the attention mechanism, we can view attention as a specialized block, the attention function, which maps an input query and a set of input key-value pairs to an output.

In the case of the seq2seq model,

  • the query is the decoder state vector \( \vd^{(t)} \)
  • for the key-value pairs, consider the keys and values as the sequence of encodings. For example, the keys could be position \( \dash{t} \) of the input sequence (represented as a one-hot vector) and the values could be the encoding \( \ve^{(\dash{t})} \). In the original paper that proposed attention, both the key and value are \( \ve^{(\dash{t})} \) CITE[bahdanau-2014].
  • and finally, the output of the attention function is the context vector \( \vc^{(t)} \).

Put succinctly, the output of the attention function \( f_a \) is a weighted sum of the values, where the weight assigned to each value is computed by the compatibility function of the query with the corresponding key.
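The query/key/value view can be sketched as a function that takes the compatibility function as an argument. In this hypothetical sketch, attention, dot and the toy encodings are illustrative names; dot is only a placeholder compatibility function, and the keys and values are both set to the encodings, as in the original attention paper.

```python
import math


def attention(query, keys, values, compatibility):
    """Weighted sum of the values; the weights are the softmax of
    compatibility(query, key) computed for every key."""
    scores = [compatibility(query, k) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [x / total for x in exps]
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]


def dot(q, k):
    """Placeholder compatibility function: plain dot product."""
    return sum(a * b for a, b in zip(q, k))


encodings = [[1.0, 0.0], [0.0, 1.0]]  # keys and values are both the encodings
uniform = attention([0.0, 0.0], encodings, encodings, dot)   # equal scores -> [0.5, 0.5]
aligned = attention([10.0, 0.0], encodings, encodings, dot)  # query matches the first key
```

Swapping in a different compatibility function changes how the weights are scored but leaves the weighted-sum structure untouched.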

\begin{equation} \vc^{(t)} = \text{Attention}\left(\vq^{(t)}, \set{\vk^{(\dash{t})}}_{\dash{t}=1}^{\tau}, \set{\vv^{(\dash{t})}}_{\dash{t}=1}^{\tau}\right) \label{eqn:attention-function} \end{equation}

For a succinct representation, the attention function is often implemented as a function over matrices,

\begin{equation} \mC = \text{Attention}\left(\mQ, \mK, \mV\right) \label{eqn:attention-function-matrix} \end{equation}

where the rows of \( \mQ \) are the decoder states \( \seq{\vd^{(1)}, \ldots, \vd^{(\dash{\tau})}} \) and the rows of \( \mC \) are the corresponding context vectors. In our current model, the matrices \( \mK \) and \( \mV \) both hold the sequence of encodings \( \seq{\ve^{(1)}, \ldots, \ve^{(\tau)}} \), one per row. That said, some variants of attention use different \( \mK \) and \( \mV \).

Types of compatibility functions

The above query and key-value based formulation allows for a lot of flexibility in defining the compatibility functions. We list some of the popular implementations of the compatibility functions here.

Additive attention

The seminal attention paper CITE[bahdanau-2014] used a feedforward neural network with one hidden layer as the compatibility function. The inputs to this network were the decoder state (the query) and the encoding (the key), and its outputs, after a softmax, were the attention weights.

Let us write this function in detail. Let \( \mW_q \) and \( \mW_k \) represent the weights in the hidden layer of this network for the query \( \mQ \) and key \( \mK \) input terms, respectively. Let \( \mW_o \) denote the weights of the output layer of this network. The network used a hyperbolic tangent activation function for the hidden layer, and a softmax as the output activation function.

For this network, the overall function being computed is

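$$ \text{Attention}(\mQ, \mK, \mV) = \text{softmax}\left( \left[ \mW_o^T \tanh\left(\mW_q\vq^{(i)} + \mW_k\vk^{(j)}\right) \right]_{ij} \right) \mV $$

where \( \vq^{(i)} \) and \( \vk^{(j)} \) denote the \( i \)-th row of \( \mQ \) and the \( j \)-th row of \( \mK \), the bracket builds the matrix of scores for every query-key pair, and the softmax is applied to each row of that matrix.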

Since the query and the key appear in an addition, this form of attention is now known as additive attention.
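A per-pair sketch of additive attention follows. It assumes queries and keys are given as lists of vectors (one per row of \( \mQ \) and \( \mK \)) and treats the output-layer weights as a single vector w_o, since each query-key pair yields one scalar score. The names additive_attention, W_q, W_k, w_o and the toy values are hypothetical.

```python
import math


def matvec(M, v):
    """Multiply a matrix M (a list of rows) by a vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def additive_attention(Q, K, V, W_q, W_k, w_o):
    """Row i of the output is sum_j softmax_j(A_ij) V[j], where
    A_ij = w_o . tanh(W_q q_i + W_k k_j): one tanh hidden layer per pair."""
    out = []
    for q in Q:
        hq = matvec(W_q, q)  # query projection, reused for every key
        scores = []
        for k in K:
            hidden = [math.tanh(a + b) for a, b in zip(hq, matvec(W_k, k))]
            scores.append(sum(w * h for w, h in zip(w_o, hidden)))
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, V))
                    for j in range(len(V[0]))])
    return out


I2 = [[1.0, 0.0], [0.0, 1.0]]
C = additive_attention(Q=[[1.0, 0.0]], K=I2, V=I2, W_q=I2, W_k=I2, w_o=[1.0, -1.0])
```

Because the query and key are each projected into the same hidden space before being added, they may have different dimensions; only W_q and W_k need matching output sizes.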

Dot product attention

Additive attention requires several parameters \( (\mW_k, \mW_q, \mW_o) \) to be trained. We can do something simpler by just computing a dot-product of the query and key terms.

$$ \text{Attention}(\mQ, \mK, \mV) = \text{softmax}\left( \mQ\mK^T \right) \mV $$

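There are no additional parameters to learn, and because it reduces to matrix multiplication, it is very efficient to compute. It does, however, require the queries and keys to have the same dimension. This form of attention is known as dot-product attention. A variant that divides the scores by \( \sqrt{d_k} \), where \( d_k \) is the dimension of the keys, is aptly called scaled dot-product attention and is used in the popular Transformer model.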
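A matrix-form sketch of dot-product attention and its scaled variant, with matrices stored as lists of rows. The helper names and toy matrices are hypothetical; the optional division by the square root of d_k (the key dimension) follows the scaled dot-product attention used in the Transformer.

```python
import math


def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul(A, B):
    """Product of A (n x m) and B (m x p), both stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax_rows(S):
    """Apply a numerically stable softmax to each row of S."""
    out = []
    for row in S:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def dot_product_attention(Q, K, V, scaled=False):
    """softmax(Q K^T) V, optionally dividing the scores by sqrt(d_k)."""
    scores = matmul(Q, transpose(K))  # queries and keys must share a dimension
    if scaled:
        d_k = len(K[0])
        scores = [[s / math.sqrt(d_k) for s in row] for row in scores]
    return matmul(softmax_rows(scores), V)


I2 = [[1.0, 0.0], [0.0, 1.0]]
C = dot_product_attention(I2, I2, I2)                      # C[0][0] is about 0.73
C_scaled = dot_product_attention(I2, I2, I2, scaled=True)  # C_scaled[0][0] is about 0.67
```

Dividing by the square root of d_k shrinks the scores, which makes the softmax less peaked; in this toy example the weight on the matching key drops from about 0.73 to about 0.67.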

Multiplicative attention

Additive attention has too many parameters. Dot-product attention has none! A middle ground is multiplicative attention, with just a single weight matrix \( \mW \). It is computed as

$$ \text{Attention}(\mQ, \mK, \mV) = \text{softmax}\left( \mQ\mW\mK^T \right) \mV $$
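The same list-of-rows helpers give a short sketch of multiplicative attention. The function multiplicative_attention and the toy shapes are hypothetical; W is the single weight matrix from the formula, of shape d_q by d_k.

```python
import math


def transpose(M):
    return [list(col) for col in zip(*M)]


def matmul(A, B):
    """Product of A (n x m) and B (m x p), both stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax_rows(S):
    """Apply a numerically stable softmax to each row of S."""
    out = []
    for row in S:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def multiplicative_attention(Q, K, V, W):
    """softmax(Q W K^T) V with a single learned matrix W of shape d_q x d_k."""
    scores = matmul(matmul(Q, W), transpose(K))
    return matmul(softmax_rows(scores), V)


# Hypothetical shapes: a 3-dimensional query, 2-dimensional keys and values.
Q = [[1.0, 0.0, 0.0]]
W = [[1.0, 0.0],
     [0.0, 1.0],
     [0.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0]]
C = multiplicative_attention(Q, K, V, W)
```

With W set to the identity (and equal query and key dimensions), this reduces to plain dot-product attention; a learned W also lets queries and keys have different dimensions, as the 3-by-2 W above does.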

Training

Thus, a basic attention model has three components: the encoder, the decoder, and the alignment model. Training involves learning the parameters of all these components. In the case of dot-product attention, the alignment model does not have any parameters to learn. The recipe for training the encoder, decoder, and alignment model jointly is straightforward and follows the usual process of training a deep neural network.

First, we define a task-dependent loss for the predictions of the model. To minimize this loss, we use a gradient-based optimization strategy such as stochastic gradient descent (SGD), or one of its variants, to fit the model parameters to the available training data. The gradients are computed using backpropagation.

Attention beyond RNNs

Attention is not limited to RNN architectures. In fact, one of the most influential deep learning ideas of recent years is the Transformer, a neural network that dispenses with recurrence entirely and relies on attention to perform sequence-based tasks. The Transformer is the precursor to many successful models that are currently popular for sequential tasks, such as BERT and GPT. We have covered the Transformer model in detail in a separate article.
