Solved – Sudden accuracy drop when training LSTM or GRU in Keras

gru, lstm, neural-networks

My recurrent neural network (LSTM or GRU) behaves in a way I cannot explain. Training starts and the network learns well (the results look quite good), when suddenly the accuracy drops and the loss rapidly increases, on both the training and the testing metrics. Sometimes the network just goes crazy and returns random outputs, and sometimes (as in the last of the three examples below) it starts to return the same output for every input.

[figure: accuracy and loss curves showing the sudden drop]

Do you have any explanation for this behavior? Any opinion is welcome. Please see the task description and the figures below.

The task: From a word, predict its word2vec vector.
The input: We have our own word2vec model (normalized) and we feed the network with the word, letter by letter. We pad the words to a fixed length (see the example below).
Example: We take the word football and want to predict its word2vec vector, which is 100 dimensions wide. The input is then $football$$$$$$$$$$.
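For illustration, the padding looks roughly like this (max_len is a hypothetical constant standing in for the fixed input length we use):

# Rough sketch of the padding scheme (max_len is a placeholder)
max_len = 19
word = "football"
padded = "$" + word + "$" * (max_len - len(word) - 1)
# padded == "$football$$$$$$$$$$"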

Three examples of the behavior:

Single layer LSTM

from keras.models import Sequential
from keras.layers import LSTM, Dense

model = Sequential([
    LSTM(1024, input_shape=encoder.shape, return_sequences=False),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

[figure: training/validation curves for the single layer LSTM]

Single layer GRU

from keras.models import Sequential
from keras.layers import GRU, Dense

model = Sequential([
    GRU(1024, input_shape=encoder.shape, return_sequences=False),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

[figure: training/validation curves for the single layer GRU]

Double layer LSTM

from keras.models import Sequential
from keras.layers import LSTM, TimeDistributed, Dense

model = Sequential([
    LSTM(512, input_shape=encoder.shape, return_sequences=True),
    TimeDistributed(Dense(512, activation="sigmoid")),
    LSTM(512, return_sequences=False),
    Dense(256, activation="tanh"),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

[figure: training/validation curves for the double layer LSTM]

We have also experienced this kind of behavior before in another project that used a similar architecture, but with a different objective and different data. So the cause should not lie in the data or in the particular objective, but rather in the architecture.

Best Answer

Here are my suggestions to pinpoint the issue:

1) Look at the training learning curve: how does the learning curve on the training set look, and does the network learn the training set at all? If not, work on that first and make sure you can overfit the training set.
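A minimal sketch, assuming the standard Keras History object returned by model.fit and your own X_train, y_train, X_val, y_val arrays:

import matplotlib.pyplot as plt

# Keep the per-epoch history and plot training vs. validation loss
history = model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=50)

plt.plot(history.history["loss"], label="train loss")
plt.plot(history.history["val_loss"], label="val loss")
plt.xlabel("epoch")
plt.legend()
plt.show()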

2) Check your data to make sure there are no NaNs in it (training, validation, test).
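A quick sketch, assuming your splits are NumPy arrays named X_train, X_val, X_test and y_train:

import numpy as np

# Scan every split for NaN or Inf values
for name, arr in [("X_train", X_train), ("X_val", X_val),
                  ("X_test", X_test), ("y_train", y_train)]:
    print(name, "NaN:", np.isnan(arr).any(), "Inf:", np.isinf(arr).any())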

3) Check the gradients and the weights to make sure there are no NaNs there either.
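The weights are easy to inspect with a small callback (a sketch only; inspecting the gradients themselves takes more plumbing, e.g. a GradientTape in TensorFlow). keras.callbacks.TerminateOnNaN additionally stops training when the loss itself becomes NaN:

import numpy as np
from keras.callbacks import Callback, TerminateOnNaN

class WeightNaNCheck(Callback):
    # Stop training as soon as any weight turns into NaN
    def on_batch_end(self, batch, logs=None):
        for w in self.model.get_weights():
            if np.isnan(w).any():
                print("NaN weight detected at batch", batch)
                self.model.stop_training = True

model.fit(X_train, y_train, callbacks=[WeightNaNCheck(), TerminateOnNaN()])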

4) Decrease the learning rate as you train, to make sure the collapse is not caused by a single large update that pushes the weights into a sharp minimum and gets stuck there.
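One way to do that in Keras (a sketch, assuming a validation split is available) is ReduceLROnPlateau, which lowers the learning rate whenever the validation loss stops improving:

from keras.callbacks import ReduceLROnPlateau

# Halve the learning rate after 3 epochs without improvement
reduce_lr = ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, min_lr=1e-6)
model.fit(X_train, y_train, validation_data=(X_val, y_val), callbacks=[reduce_lr])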

5) To make sure everything is right, inspect the predictions of your network and check that it is not producing constant or repetitive outputs.
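A rough check, assuming a held-out set X_val: if the spread across samples is close to zero, the network is returning (nearly) the same vector for every input:

import numpy as np

preds = model.predict(X_val)
# Average per-dimension standard deviation across samples; ~0 means constant output
print("mean std across samples:", preds.std(axis=0).mean())
print("first few predictions:\n", preds[:3])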

6) Check whether the data in each batch is balanced with respect to all classes.
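For classification-style targets this can be checked per batch roughly like this (with a word2vec regression target it matters less; batch_labels is a hypothetical array of labels for one batch):

import numpy as np

# Count how often each class label appears in one batch
values, counts = np.unique(batch_labels, return_counts=True)
print(dict(zip(values, counts)))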

7) Normalize your data to zero mean and unit variance, and initialize the weights on a comparable scale. It will help the training.
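For the data part, a minimal sketch (computing the statistics on the training set only and reusing them for the other splits):

import numpy as np

mean = y_train.mean(axis=0)
std = y_train.std(axis=0) + 1e-8    # avoid division by zero
y_train_norm = (y_train - mean) / std
y_val_norm = (y_val - mean) / std   # reuse the training statistics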