Learning initial state in RNNs

lstm, recurrent neural network

I was reading Hinton's slides about recurrent networks (link), and he says that the initial state of the network should be learned just like the weights (slide 14). If that's the case, how would we handle the unknown initial state during test time?

Additionally, in the "Learning Precise Timing with LSTM" paper by Gers, Schraudolph, and Schmidhuber (link), they say that "the initial state of the network should be learned as well." However, I have not seen this approach used in most papers that work with RNNs.
Has anybody learned the initial state as a parameter and gotten better results from it? If so, how did you deal with the unknown initial state at test time?

Best Answer

I am assuming you understand how to learn all the other weights in the RNN.

All states after the initial one need to be computed: at each time step we take the current input and the previous state, multiply each by a weight matrix, add the products, and pass the result through an activation function to obtain the new state.

Initial states are odd because we do not compute them by multiplying some initial input by a weight matrix and applying an activation function; we simply "give" them a value, often zeros or random numbers. In that sense they are essentially weights.
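For concreteness, here is the recurrence of a simple (Elman-style) RNN with a $\tanh$ activation; the symbols $W$, $U$, $b$ and $h_t$ are illustrative choices, not notation taken from the slides or the paper:

$$h_t = \tanh\left(W x_t + U h_{t-1} + b\right), \qquad t = 1, \dots, T$$

Every state from $h_1$ onward is computed this way, but computing $h_1$ already requires $h_0$, and nothing in the recurrence produces $h_0$. That is why it behaves like one more parameter alongside $W$, $U$ and $b$.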

Instead of assigning them a value "manually", why not learn them like any other weights? This means we take the gradient of the loss with respect to the initial state and update it with gradient descent, just like the other weights.
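A minimal sketch of learning the initial state by gradient descent, in plain Python. Assumptions (not from the slides or the paper): a scalar tanh RNN $h_t = \tanh(w x_t + u h_{t-1})$, a squared-error loss on the last state, and fixed weights $w$ and $u$ so that only $h_0$ is trained; in real training $h_0$ would be updated together with all the other weights. The gradient $\partial L / \partial h_0$ comes from backpropagation through time, and the update $h_0 \leftarrow h_0 - \eta \, \partial L / \partial h_0$ is the same rule used for any weight. All function names are hypothetical.

```python
import math

# Sketch only: scalar tanh RNN h_t = tanh(w*x_t + u*h_{t-1}), loss (h_T - y)^2.
# The weights w and u are held fixed here so that only h0 is learned.

def forward(h0, xs, w, u):
    """Return the list of states [h0, h1, ..., hT]."""
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * x + u * hs[-1]))
    return hs

def loss_and_grad_h0(h0, xs, y, w, u):
    """Loss (h_T - y)^2 and its gradient dL/dh0 via backpropagation through time."""
    hs = forward(h0, xs, w, u)
    loss = (hs[-1] - y) ** 2
    grad = 2.0 * (hs[-1] - y)            # dL/dh_T
    for t in range(len(xs), 0, -1):      # walk back t = T, ..., 1
        grad *= (1.0 - hs[t] ** 2) * u   # chain rule: dh_t/dh_{t-1} = (1 - h_t^2) * u
    return loss, grad                    # grad now holds dL/dh0

def train_h0(xs, y, w=0.5, u=0.9, lr=0.5, steps=200):
    """Gradient descent on the initial state, starting from the usual zeros."""
    h0 = 0.0
    for _ in range(steps):
        _, g = loss_and_grad_h0(h0, xs, y, w, u)
        h0 -= lr * g                     # same update rule as for any weight
    return h0
```

For example, `train_h0([1.0, -1.0, 0.5], 0.2)` returns a learned starting value. After training it is just a stored number, so at test time every new sequence starts from it; nothing about it is unknown.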

An alternative way to think about it is the following scenario. Suppose we add an initial input, fix it to $1$, and learn an initial weight matrix (a single column with one row per state unit). The product of the initial input ($1$) and the initial weight matrix gives the initial state. At test time we keep using the initial input $1$ and multiply it by the same initial weight matrix, which yields the same initial state. Since the weights no longer change after training, the initial state does not change either, so instead of multiplying $1$ by the initial weight matrix every time, you can simply save the initial state and reuse it for every test sequence. Now you have essentially learned an initial state!
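The constant-input view can be sketched the same way. Assumptions: plain Python lists instead of a tensor library, a tanh RNN without biases, and made-up weight values standing in for trained ones; all names are hypothetical. Because $\partial (W_0 \cdot 1) / \partial W_0 = 1$ elementwise, the gradient reaching $W_0$ during training equals the gradient with respect to $h_0$, so learning $W_0$ and learning $h_0$ directly amount to the same thing.

```python
import math

def matvec(M, v):
    """Multiply a matrix (a list of rows) by a vector (a list)."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def rnn_step(h, x, W, U):
    """One tanh RNN step: h_new = tanh(W x + U h)."""
    return [math.tanh(a + b) for a, b in zip(matvec(W, x), matvec(U, h))]

def run(xs, h0, W, U):
    """Feed a sequence of input vectors through the RNN, starting from h0."""
    h = h0
    for x in xs:
        h = rnn_step(h, x, W, U)
    return h

def initial_state(W0):
    """Initial state as the H x 1 matrix W0 times the constant input [1.0]."""
    return matvec(W0, [1.0])

# Hypothetical "trained" values (H = 2 state units, D = 2 input features).
W0 = [[0.4], [-0.2]]
W = [[0.5, -0.1], [0.3, 0.8]]
U = [[0.9, 0.0], [0.1, 0.7]]

# After training the weights are frozen, so compute h0 once and store it.
SAVED_H0 = initial_state(W0)   # [0.4, -0.2]: multiplying by 1 returns W0's column

# Every test sequence then starts from the same stored state.
for seq in ([[1.0, 0.0], [0.5, -0.5]], [[0.0, 1.0]]):
    assert run(seq, SAVED_H0, W, U) == run(seq, initial_state(W0), W, U)
```

Storing `SAVED_H0` is exactly the shortcut described above: the multiplication by $1$ is only needed to frame $h_0$ as a weight during training.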
