Solved – the intuition behind a Long Short Term Memory (LSTM) recurrent neural network

intuition, neural-networks, predictive-models, recurrent-neural-network, time-series

The idea behind a Recurrent Neural Network (RNN) is clear to me. I understand it in the following way:
We have a sequence of observations ($\vec o_1, \vec o_2, \dots, \vec o_n$) (in other words, a multivariate time series). Each single observation $\vec o_i$ is an $N$-dimensional numeric vector. Within the RNN model, we assume that the next observation $\vec o_{i+1}$ is a function of the previous observation $\vec o_{i}$ as well as the previous "hidden state" $\vec h_i$, where the hidden states are also represented by numeric vectors (the dimensions of the observed and hidden states can differ). The hidden states themselves are likewise assumed to depend on the previous observation and hidden state:

$\vec o_i, \vec h_i = F (\vec o_{i-1}, \vec h_{i-1})$

Finally, in the RNN model, the function $F$ is assumed to be a neural network. We train (fit) the neural network using the available data (a sequence of observations). Our goal in the training is to be able to predict the next observation as accurately as possible using the previous observations.
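To make the recurrence concrete, here is a minimal NumPy sketch of one step $F(\vec o_{i-1}, \vec h_{i-1}) \to (\vec o_i, \vec h_i)$. The weight names, shapes and the $\tanh$ nonlinearity are illustrative assumptions of mine, not something fixed by the model above:

```python
import numpy as np

rng = np.random.default_rng(0)
N, H = 3, 4                         # observation size, hidden-state size (toy values)
W_h = rng.normal(size=(H, N + H))   # hypothetical weights producing the next hidden state
W_o = rng.normal(size=(N, H))       # hypothetical weights producing the next observation
b_h, b_o = np.zeros(H), np.zeros(N)

def rnn_step(o_prev, h_prev):
    """One step of the recurrence: F(o_{i-1}, h_{i-1}) -> (o_i, h_i)."""
    x = np.concatenate([o_prev, h_prev])
    h_next = np.tanh(W_h @ x + b_h)   # new hidden state from previous observation + hidden state
    o_next = W_o @ h_next + b_o       # prediction of the next observation
    return o_next, h_next

o, h = rng.normal(size=N), np.zeros(H)
for _ in range(5):                    # unroll the recurrence over a short sequence
    o, h = rnn_step(o, h)
```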

Now, an LSTM network is a modification of an RNN. As far as I understand, the motivation behind the LSTM is to resolve the problem of short memory that is peculiar to RNNs (conventional RNNs have trouble relating events that are too far separated in time).

I understand how LSTM networks work. Here is the best explanation of LSTM that I have found. The basic idea is as follows:

In addition to the hidden state vector, we introduce a so-called "cell state" vector $\vec c_i$ that has the same size (dimensionality) as the hidden state vector. I think that the "cell state" vector is introduced to model long-term memory. As in the case of a conventional RNN, the LSTM network takes the observed and hidden states as input. Using this input, we calculate a new "cell state" in the following way:

$\vec c_{i+1} = \vec \omega_1 (\vec o_i, \vec h_i) \cdot \vec c_i + \vec \omega_2 (\vec o_i, \vec h_i) \cdot \vec c_{int} (\vec o_i, \vec h_i),$

where the functions $\vec \omega_1$, $\vec \omega_2$ and $\vec c_{int}$ are modeled by neural networks. To make the expression simpler, I just drop the arguments:

$\vec c_{i+1} = \vec \omega_1 \cdot \vec c_i + \vec \omega_2 \cdot \vec c_{int}$

So, we can see that the new "cell state" vector ($\vec c_{i+1}$) is a weighted sum of the old cell state vector ($\vec c_i$) and an "intermediate" cell state vector ($\vec c_{int}$). The multiplication between the vectors is component-wise (we multiply two vectors of the same dimension and get, as a result, another vector of that dimension). In other words, we mix the two cell state vectors (the old one and the intermediate one) using component-specific weights.

Here is the intuition behind the described operations. The cell state vector can be interpreted as a memory vector. The first weights vector $\vec \omega_1$ (calculated by a neural network) multiplies the old cell state, so it is the "keep" (or forget) gate: its values decide whether we keep or forget (erase) the corresponding value of the cell state vector (the long-term memory vector). The second weights vector ($\vec \omega_2$), which is calculated by another neural network, is the "write" (or memorize) gate: it decides whether a new memory (the "intermediate" cell state vector) should be saved (or, more precisely, whether a particular component of it should be saved / written). The "intermediate" cell state is the new memory that is either ignored or stored in the cell state (depending on the values in the $\vec \omega_2$ vector). Actually, it would be more accurate to say that with the two weights vectors ($\vec \omega_1$ and $\vec \omega_2$) we "mix" the old and new memory.
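For concreteness, here is how this gated mixing could look in NumPy. The choice of sigmoids for the two gates and $\tanh$ for the intermediate state follows the standard LSTM; the weight names are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(1)
N, H = 3, 4
# Hypothetical gate parameters; in a real LSTM each gate has its own
# learned weights over the concatenated (o_i, h_i) input.
W_keep, W_write, W_cand = (rng.normal(size=(H, N + H)) for _ in range(3))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cell_update(o_i, h_i, c_i):
    """c_{i+1} = omega_1 * c_i + omega_2 * c_int, all products component-wise."""
    x = np.concatenate([o_i, h_i])
    omega_1 = sigmoid(W_keep @ x)     # "keep"/forget gate, values in (0, 1)
    omega_2 = sigmoid(W_write @ x)    # "write"/memorize gate
    c_int = np.tanh(W_cand @ x)       # "intermediate" (candidate) cell state
    return omega_1 * c_i + omega_2 * c_int

c = cell_update(rng.normal(size=N), np.zeros(H), np.zeros(H))
```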

So, after the mixing described above (or the forgetting and memorizing), we have a new cell state vector. Then we calculate an "intermediate" hidden state using another neural network (as before, taking the observed state $\vec o_i$ and hidden state $\vec h_i$ as inputs). Finally, we combine the new cell state (memory) with the "intermediate" hidden state ($\vec h_{int}$) to get the new (or "final") hidden state that we actually output:

$\vec h_{i+1} = \vec h_{int} \cdot S(\vec c_{i+1}),$

where $S$ is a sigmoid function applied to each component of the cell state vector (the standard LSTM applies $\tanh$ here instead, but either way it is a component-wise squashing of the memory).
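As a sketch in the same hypothetical notation, this output step amounts to one more gate and a component-wise squashing of the memory:

```python
import numpy as np

rng = np.random.default_rng(2)
N, H = 3, 4
W_out = rng.normal(size=(H, N + H))   # hypothetical weights for the "intermediate" hidden state

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def hidden_update(o_i, h_i, c_next):
    """h_{i+1} = h_int * S(c_{i+1}); the standard LSTM uses tanh for S."""
    x = np.concatenate([o_i, h_i])
    h_int = sigmoid(W_out @ x)        # the "intermediate" hidden state (output gate)
    return h_int * sigmoid(c_next)    # component-wise gating of the squashed memory

h_next = hidden_update(rng.normal(size=N), np.zeros(H), rng.normal(size=H))
```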

So, my question is: Why (or how exactly) does this architecture solve the problem?

In particular I do not understand the following:

  1. We use a neural network to generate an "intermediate" memory (cell state vector) that is mixed with the "old" memory (cell state) to get a "new" memory (cell state). The weighting factors for the mixing are also calculated by neural networks. But why can't we use just one neural network to calculate the "new" cell state (or memory)? In other words, why can't we use the observed state, the hidden state and the old memory as inputs to a single neural network that calculates the "new" memory?
  2. In the end, we use the observed and hidden states to calculate a new hidden state, and then we use the "new" cell state (or (long-term) memory) to correct the components of the newly calculated hidden state. In other words, the components of the cell state are used just as weights that merely scale down the corresponding components of the calculated hidden state. But why is the cell state vector used in this particular way? Why can't we calculate the new hidden state by feeding the cell state vector (long-term memory) into a neural network (one that also takes the observed and hidden states as input)?

Added:

Here is a video that might help to clarify how different gates ("keep", "write" and "read") are organized.

Best Answer

As I understand your questions, what you picture is basically concatenating the input, the previous hidden state, and the previous cell state, and passing them through one or several fully connected layers to compute the output hidden state and cell state, instead of independently computing "gated" updates that interact arithmetically with the cell state. This would basically create a regular RNN that only outputs part of its hidden state.
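Concretely, the alternative being described might look like this sketch (the names and the single $\tanh$ layer are my assumptions, chosen to mirror the question):

```python
import numpy as np

rng = np.random.default_rng(3)
N, H = 3, 4
W = rng.normal(size=(2 * H, N + 2 * H))   # one dense layer producing h and c together

def proposed_step(o_i, h_i, c_i):
    """No gates: one fully connected layer over the concatenated
    (o_i, h_i, c_i) -- effectively a plain RNN whose state is split
    into an "output" part h and a hidden "cell" part c."""
    x = np.concatenate([o_i, h_i, c_i])
    out = np.tanh(W @ x)
    return out[:H], out[H:]               # (h_{i+1}, c_{i+1})

h, c = proposed_step(rng.normal(size=N), np.zeros(H), np.zeros(H))
```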

The main reason not to do this is that the structure of the LSTM's cell state computations ensures a constant flow of error through long sequences (the "constant error carousel" of Hochreiter and Schmidhuber's original paper). If you used weights to compute the cell state directly, you would need to backpropagate through them at every time step, and repeatedly multiplying the gradient by those weights is exactly what makes it vanish or explode. Avoiding such operations largely solves the vanishing/exploding gradient problem that otherwise plagues RNNs.

Plus, the ability to retain information easily over longer time spans is a nice bonus. Intuitively, it would be much more difficult for the network to learn from scratch to preserve cell state over longer time spans.
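To see the gradient argument numerically, here is a toy comparison (with assumed random weights) of the product of Jacobians along a vanilla RNN's hidden-state path versus the LSTM's cell-state path, where $\partial \vec c_{i+1} / \partial \vec c_i$ is just $\mathrm{diag}(\vec \omega_1)$ and contains no weight matrix:

```python
import numpy as np

rng = np.random.default_rng(4)
H, T = 4, 50

# Vanilla RNN: the gradient w.r.t. an early hidden state is a product of
# Jacobians diag(1 - h^2) @ W -- the weight matrix enters at every step.
W = rng.normal(size=(H, H)) * 0.5
grad_rnn = np.eye(H)
h = np.zeros(H)
for _ in range(T):
    h = np.tanh(W @ h + rng.normal(size=H))
    grad_rnn = np.diag(1 - h**2) @ W @ grad_rnn

# LSTM cell-state path: dc_{i+1}/dc_i = diag(omega_1) -- no weight matrix.
# With the keep gate near 1, the product of Jacobians barely shrinks.
omega_1 = np.full(H, 0.95)            # a gate value the network can learn to output
grad_lstm = np.eye(H)
for _ in range(T):
    grad_lstm = np.diag(omega_1) @ grad_lstm

print(np.linalg.norm(grad_rnn))       # typically vanishing (or exploding) with T
print(np.linalg.norm(grad_lstm))      # ~ 0.95**T * sqrt(H): decays gently
```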

It's worth noting that the most common alternative to the LSTM, the GRU, similarly computes its hidden state update as a gated, component-wise mix, with no learned weights acting on the direct path from one hidden state to the next.
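For comparison, a minimal sketch of the GRU update (standard equations, hypothetical weight names): the direct path from $\vec h_i$ to $\vec h_{i+1}$ is a weight-free, gated convex mix, and learned weights touch the previous hidden state only inside the reset-gated candidate:

```python
import numpy as np

rng = np.random.default_rng(5)
N, H = 3, 4
W_z, W_r, W_c = (rng.normal(size=(H, N + H)) for _ in range(3))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(o_i, h_i):
    """h_{i+1} = (1 - z) * h_i + z * h_cand: no weight matrix multiplies
    h_i on this direct path, only the gate z (component-wise)."""
    x = np.concatenate([o_i, h_i])
    z = sigmoid(W_z @ x)                              # update gate
    r = sigmoid(W_r @ x)                              # reset gate
    h_cand = np.tanh(W_c @ np.concatenate([o_i, r * h_i]))
    return (1 - z) * h_i + z * h_cand

h = gru_step(rng.normal(size=N), np.zeros(H))
```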