Deep Learning – Understanding Attention Mechanisms in Recurrent Neural Networks

attention, deep learning, lstm, recurrent neural network, time series

Attention mechanisms have been used in various deep learning papers over the last few years. Ilya Sutskever, head of research at OpenAI, has enthusiastically praised them.

Eugenio Culurciello at Purdue University has claimed that RNNs and LSTMs should be abandoned in favor of purely attention-based neural networks:

https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

This seems like an exaggeration, but it's undeniable that purely attention-based models have done quite well on sequence modeling tasks: we all know about the aptly named paper from Google, "Attention Is All You Need".

However, what exactly are attention-based models? I have yet to find a clear explanation of them. Suppose I want to forecast future values of a multivariate time series given its historical values. It's quite clear how to do that with an RNN with LSTM cells. How would I do the same with an attention-based model?

Best Answer

Attention is a method for aggregating a set of vectors $v_i$ into a single vector, often via a lookup vector $u$. Usually, the $v_i$ are either the inputs to the model, the hidden states of previous time-steps, or the hidden states one level down (in the case of stacked LSTMs).

The result is often called the context vector $c$, since it contains the context relevant to the current time-step.

This additional context vector $c$ is then fed into the RNN/LSTM as well (it can be simply concatenated with the original input). Therefore, the context can be used to help with prediction.

The simplest way to do this is to compute a probability vector $p = \text{softmax}(V^T u)$ and then $c = \sum_i p_i v_i$, where $V$ is the matrix whose columns are all previous $v_i$. A common choice of lookup vector $u$ is the current hidden state $h_t$.
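
As a minimal NumPy sketch of this simplest form (the names, shapes, and random values are mine, not from the original answer): $V$ holds the previous vectors as columns, $u$ plays the role of $h_t$, and the context is the probability-weighted sum of the columns.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())          # subtract the max for numerical stability
    return e / e.sum()

hidden_dim, num_steps = 4, 6
rng = np.random.default_rng(0)
V = rng.normal(size=(hidden_dim, num_steps))  # columns are the previous vectors v_i
u = rng.normal(size=hidden_dim)               # lookup vector, e.g. the current hidden state h_t

p = softmax(V.T @ u)   # attention weights, one per v_i, summing to 1
c = V @ p              # context vector c = sum_i p_i v_i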

There are many variations on this, and you can make things as complicated as you want. For example, instead of using $v_i^T u$ as the logits, one may choose $f(v_i, u)$, where $f$ is an arbitrary neural network.

A common attention mechanism for sequence-to-sequence models scores each encoder state as $q^\top \tanh(W_1 v_i + W_2 h_t)$ and takes a softmax over $i$ to get $p$, where the $v_i$ are the hidden states of the encoder and $h_t$ is the current hidden state of the decoder. $q$ and both $W$ matrices are learned parameters.
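
Here is a rough NumPy sketch of this additive scoring (all names, shapes, and random values are illustrative assumptions, not from the original): each encoder state $v_i$ gets the score $q^\top \tanh(W_1 v_i + W_2 h_t)$, and a softmax over these scores gives $p$.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

enc_dim, dec_dim, attn_dim, T = 5, 4, 3, 7
rng = np.random.default_rng(1)
V = rng.normal(size=(T, enc_dim))           # encoder hidden states v_1 .. v_T (one per row)
h_t = rng.normal(size=dec_dim)              # current decoder hidden state
W1 = rng.normal(size=(attn_dim, enc_dim))   # parameters of the scoring network
W2 = rng.normal(size=(attn_dim, dec_dim))
q = rng.normal(size=attn_dim)

scores = np.tanh(V @ W1.T + W2 @ h_t) @ q   # one logit per encoder state, shape (T,)
p = softmax(scores)                         # attention weights over the encoder states
context = p @ V                             # context vector fed to the decoder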

Some papers which show off different variations on the attention idea:

Pointer Networks use attention to reference inputs in order to solve combinatorial optimization problems.

Recurrent Entity Networks maintain separate memory states for different entities (people/objects) while reading text, and update the correct memory state using attention.

Transformer models also make extensive use of attention. Their formulation is slightly more general and also involves key vectors $k_i$: the attention weights $p$ are computed between the keys and the lookup (the query), and the context is then constructed as a weighted sum of the $v_i$.
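
As a small NumPy sketch of that key/value split (this uses the standard scaled dot-product form, with names and shapes of my own choosing rather than anything from the papers above): the weights are computed between a query vector and the keys $k_i$, and the context is a weighted sum of the values $v_i$.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

d_k, d_v, T = 8, 6, 5
rng = np.random.default_rng(2)
K = rng.normal(size=(T, d_k))      # keys k_1 .. k_T
V = rng.normal(size=(T, d_v))      # values v_1 .. v_T
q = rng.normal(size=d_k)           # query (the "lookup" vector)

p = softmax(K @ q / np.sqrt(d_k))  # weights between the query and the keys
context = p @ V                    # context constructed from the values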


Here is a quick implementation of one form of attention, although I can't guarantee correctness beyond the fact that it passed some simple tests.

Basic RNN:

import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable)

def rnn(inputs_split):
    # inputs_split: a list of tensors of shape (batch, in_dim), one per time-step.
    # batch, in_dim and hidden_dim are assumed to be defined elsewhere.
    bias = tf.get_variable('bias', shape = [hidden_dim, 1])
    weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
    weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])

    # Start from a zero hidden state and unroll the recurrence over time.
    hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
    for i, input in enumerate(inputs_split):
        input = tf.reshape(input, (batch, in_dim, 1))
        last_state = hidden_states[-1]
        # Vanilla RNN update: h_t = tanh(W_x x_t + W_h h_{t-1} + b)
        hidden = tf.nn.tanh(tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias)
        hidden_states.append(hidden)
    return hidden_states[-1]

With attention, we add only a few lines before the new hidden state is computed:

        if len(hidden_states) > 1:
            # Score each previous hidden state against the current one (the mean of the
            # element-wise product acts as a scaled dot product), then softmax the scores.
            logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
            probs = tf.nn.softmax(logits)
            probs = tf.reshape(probs, (batch, -1, 1, 1))
            # Context vector: probability-weighted sum of the previous hidden states.
            context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
        else:
            # No history yet at the first time-step, so use a zero context.
            context = tf.zeros_like(last_state)

        # Feed the context in by concatenating it with the previous hidden state.
        # Note: weight_hidden must then have shape [1, hidden_dim, 2 * hidden_dim].
        last_state = tf.concat([last_state, context], axis = 1)

        hidden = tf.nn.tanh(tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias)

the full code
