Solved – Why backpropagate through time in an RNN

backpropagation, neural networks, recurrent neural network, time series

In a recurrent neural network, you would usually forward propagate through several time steps, "unroll" the network, and then backpropagate across the sequence of inputs.
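A minimal sketch of that unroll-and-backpropagate loop, assuming a character-level PyTorch model; the names (`train_chunk`, the sizes, the one-hot input layout) are illustrative, not anything from the original post:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for a character-level model; chosen only for illustration.
vocab_size, hidden_size, k = 64, 128, 32

rnn = nn.RNN(vocab_size, hidden_size, batch_first=True)
readout = nn.Linear(hidden_size, vocab_size)
params = list(rnn.parameters()) + list(readout.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

def train_chunk(inputs, targets, hidden):
    """One truncated-BPTT update.

    inputs:  (1, steps, vocab_size) one-hot chunk of the sequence
    targets: (1, steps) class indices, already shifted one step ahead
    hidden:  previous hidden state, or None at the start of a sequence
    """
    optimizer.zero_grad()
    outputs, hidden = rnn(inputs, hidden)     # forward through every step in the chunk
    logits = readout(outputs)                 # (1, steps, vocab_size)
    loss = loss_fn(logits.reshape(-1, vocab_size), targets.reshape(-1))
    loss.backward()                           # gradient flows back through the whole chunk
    optimizer.step()
    # Detach so the next chunk's backward pass stops here (truncation boundary).
    return hidden.detach(), loss.item()
```

With a truncation length of `k`, each call backpropagates through `k` time steps before the hidden state is detached.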

Why would you not just update the weights after each individual step in the sequence (the equivalent of using a truncation length of 1, so there is nothing to unroll)? This completely eliminates the vanishing gradient problem, greatly simplifies the algorithm, probably reduces the chances of getting stuck in local minima, and, most importantly, seems to work fine. I trained a model this way to generate text, and the results seemed comparable to results I have seen from BPTT-trained models. I am only confused about this because every tutorial on RNNs I have seen says to use BPTT, almost as if it is required for proper learning, which is not the case.
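For comparison, the scheme described in the question corresponds to calling the same update with a chunk length of 1: the hidden state is still carried forward, but it is detached at every step, so no gradient ever flows across time steps. A hypothetical driver loop, reusing `train_chunk` from the sketch above and assuming `sequence` and `targets` tensors already exist:

```python
# The scheme from the question: truncation length 1, i.e. update after every step.
# `sequence` (1, T, vocab_size) and `targets` (1, T) are assumed placeholders.
hidden = None
for t in range(sequence.size(1)):
    x_t = sequence[:, t:t+1, :]               # a single time step of input
    y_t = targets[:, t:t+1]                   # its next-character target
    hidden, loss = train_chunk(x_t, y_t, hidden)
```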

Update: I added an answer

Best Answer

Edit: I made a big mistake when comparing the two methods and have to change my answer. It turns out that the way I was doing it, backpropagating only through the current time step, does start out learning faster: the quick updates pick up the most basic patterns very quickly. But on a larger data set, with longer training time, BPTT does in fact come out on top. I was testing on a small sample for just a few epochs and assumed that whoever started out winning the race would be the winner. This did lead me to an interesting finding, though: if you start training by backpropagating just a single time step, then switch to BPTT and slowly increase how far back you propagate, you get faster convergence.
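One way to express that schedule, again reusing the hypothetical `train_chunk` helper; the specific curriculum (start at 1, add one step of history per epoch up to a cap) is made up for illustration, not taken from the answer:

```python
# Illustrative curriculum: start with truncation length 1 and grow it over epochs.
max_k, num_epochs = 32, 50                     # made-up schedule parameters
for epoch in range(num_epochs):
    k = min(1 + epoch, max_k)                  # one extra step of history per epoch
    hidden = None
    for start in range(0, sequence.size(1) - k + 1, k):
        x = sequence[:, start:start+k, :]      # chunk of k input steps
        y = targets[:, start:start+k]          # matching targets
        hidden, loss = train_chunk(x, y, hidden)
```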
