Variable importance in RNN or LSTM

Tags: importance, lstm, neural networks, recurrent neural network

Several methods have been devised for assessing or quantifying variable importance (even if only relative to each other) in MLP neural network models:

  • Connection weights
  • Garson’s algorithm
  • Partial derivatives
  • Input perturbation
  • Sensitivity analysis
  • Forward stepwise addition
  • Backward stepwise elimination
  • Improved stepwise selection 1
  • Improved stepwise selection 2

(these were described in http://dx.doi.org/10.1016/j.ecolmodel.2004.03.013)

Can any of these methods be applied to RNN or LSTM neural networks?

Best Answer

In short, yes, you can get some measure of variable importance for RNN-based models. I won't go through every method listed in the question, but I will walk through an example of sensitivity analysis in depth.

The data

The input data for my RNN will be a time-series with three features, $x_1$, $x_2$, $x_3$. Each feature will be drawn independently from a uniform distribution on $[0, 1)$. The target variable for my RNN will also be a time-series (one prediction for each time-step in my input):

$$ y = \left\{\begin{array}{lr} 0, & \text{if } x_1 x_2 \geq 0.25\\ 1, & \text{if } x_1 x_2 < 0.25 \end{array}\right. $$

As we can see, the target is dependent on only the first two features. Thus, a good variable importance metric should show the first two variables being important, and the third variable being unimportant.
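As a quick sanity check on the data described above, here is a minimal NumPy sketch that builds one sample and confirms that the target ignores $x_3$. The array names and the sample size of 250 series are my own choices for illustration.

```python
import numpy as np

rng = np.random.RandomState(2019)

# 250 series, 10 time-steps, 3 features, all uniform on [0, 1)
x = rng.rand(250, 10, 3)

# Target per time-step: 1 where x1 * x2 < 0.25, else 0
y = (x[:, :, 0] * x[:, :, 1] < 0.25).astype(np.float32)

# Redraw x3 only; the target should not change
x_alt = x.copy()
x_alt[:, :, 2] = rng.rand(250, 10)
y_alt = (x_alt[:, :, 0] * x_alt[:, :, 1] < 0.25).astype(np.float32)

print(x.shape, y.shape)          # (250, 10, 3) (250, 10)
print(np.array_equal(y, y_alt))  # True
print(y.mean())                  # fraction of time-steps with y = 1
```

For two independent uniform variables, $P(x_1 x_2 < a) = a - a \ln a$, which is about 0.597 for $a = 0.25$, so the classes are only mildly imbalanced.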

The model

The model is a simple three-layer LSTM with a sigmoid activation in the final layer. The model will be trained for 5 epochs with 1000 batches per epoch.

Variable importance

To measure the variable importance, we'll take a large sample (250 time-series) of our data $\hat{x}$ and compute the model's predictions $\hat{y}$. Then, for each variable $x_i$ we'll perturb that variable (and only that variable) by adding noise drawn from a normal distribution centered at 0 with standard deviation 0.2, and compute a new prediction $\hat{y}_i$. We'll measure the effect of this perturbation by computing the root mean square (RMS) difference between the original $\hat{y}$ and the perturbed $\hat{y}_i$. A larger RMS difference means that variable is "more important".
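The procedure can be sketched independently of Keras. In the sketch below, `perturbation_importance` and `toy_predict` are hypothetical names of mine; `toy_predict` is only a stand-in for a trained model's `predict` method, built so that it depends on $x_1 x_2$ and ignores $x_3$.

```python
import numpy as np

def perturbation_importance(predict, x, scale=0.2, seed=0):
    """RMS change in predictions when one feature at a time gets Gaussian noise."""
    rng = np.random.RandomState(seed)
    baseline = predict(x)
    effects = []
    for i in range(x.shape[-1]):
        noisy = x.copy()
        # Perturb feature i only, at every series and time-step
        noisy[..., i] += rng.normal(0.0, scale, size=x.shape[:-1])
        diff = baseline - predict(noisy)
        effects.append(float(np.sqrt(np.mean(diff ** 2))))
    return effects

def toy_predict(x):
    # Hypothetical stand-in for model.predict: a smooth function of x1 * x2 only
    return 1.0 / (1.0 + np.exp(20.0 * (x[:, :, 0] * x[:, :, 1] - 0.25)))

x = np.random.RandomState(2019).rand(250, 10, 3)
effects = perturbation_importance(toy_predict, x)
for i, e in enumerate(effects, start=1):
    print(f"Variable {i}, perturbation effect: {e:.4f}")
```

Because the stand-in never reads the third feature, its effect is exactly zero here; a trained network will usually show a small nonzero value instead.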

Obviously, the exact mechanism you use to perturb your data, and how you measure the difference between perturbed and unperturbed outputs, will be highly dependent on your particular dataset.
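As one illustration of a different choice, the sketch below shuffles a feature across series instead of adding Gaussian noise, and reports the mean absolute difference instead of the RMS difference. This is a swapped-in variant, not what I used for the results in this answer; `permutation_effect` and `toy_predict` are hypothetical names, and `toy_predict` stands in for a trained model.

```python
import numpy as np

def permutation_effect(predict, x, feature, seed=0):
    """Mean absolute change in predictions after shuffling one feature across series."""
    rng = np.random.RandomState(seed)
    shuffled = x.copy()
    order = rng.permutation(x.shape[0])
    # Each series receives the chosen feature's trajectory from another series
    shuffled[:, :, feature] = x[order, :, feature]
    return float(np.mean(np.abs(predict(x) - predict(shuffled))))

def toy_predict(x):
    # Hypothetical stand-in for model.predict; depends on x1 and x2 only
    return (x[:, :, 0] * x[:, :, 1] < 0.25).astype(float)

x = np.random.RandomState(2019).rand(250, 10, 3)
effects = [permutation_effect(toy_predict, x, i) for i in range(3)]
print(effects)
```

Shuffling keeps every perturbed value inside the range the model saw during training, which additive noise does not guarantee (uniform inputs plus noise can fall outside $[0, 1)$).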

Results

After doing all of the above, we see the following importances:

Variable 1, perturbation effect: 0.1162
Variable 2, perturbation effect: 0.1185
Variable 3, perturbation effect: 0.0077

As we expected, variables 1 and 2 are found to be much more important (about 15x more) than variable 3!
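The "about 15x" figure is just the ratio of the reported effects, which can be checked directly:

```python
# Perturbation effects as reported in the results
effects = {1: 0.1162, 2: 0.1185, 3: 0.0077}

ratios = {var: effects[var] / effects[3] for var in (1, 2)}
for var, ratio in ratios.items():
    print(f"Variable {var}: {ratio:.1f}x variable 3")
# Variable 1: 15.1x variable 3
# Variable 2: 15.4x variable 3
```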

Python code to reproduce

from tensorflow import keras  # tensorflow v1.14.0 was used
import numpy as np            # numpy v1.17.1 was used

np.random.seed(2019)

def make_model():
    inp = keras.layers.Input(shape=(10, 3))
    x = keras.layers.LSTM(10, activation='relu', return_sequences=True)(inp)
    x = keras.layers.LSTM(5, activation='relu', return_sequences=True)(x)
    x = keras.layers.LSTM(1, activation='sigmoid', return_sequences=True)(x)
    out = keras.layers.Flatten()(x)
    return keras.models.Model(inp, out)

def data_gen():
    while True:
        x = np.random.rand(5, 10, 3)  # batch x time x features
        yield x, x[:, :, 0] * x[:, :, 1] < 0.25

def var_importance(model):
    g = data_gen()
    x = np.concatenate([next(g)[0] for _ in range(50)]) # Get a sample of data
    orig_out = model.predict(x)
    for i in range(3):  # iterate over the three features
        new_x = x.copy()
        perturbation = np.random.normal(0.0, 0.2, size=new_x.shape[:2])
        new_x[:, :, i] = new_x[:, :, i] + perturbation
        perturbed_out = model.predict(new_x)
        effect = ((orig_out - perturbed_out) ** 2).mean() ** 0.5
        print(f'Variable {i+1}, perturbation effect: {effect:.4f}')

def main():
    model = make_model()
    model.compile('adam', 'binary_crossentropy')
    model.summary()  # prints the summary itself; wrapping it in print() would also print "None"
    model.fit_generator(data_gen(), steps_per_epoch=1000, epochs=5)  # 5 epochs of 1000 batches, matching the output below
    var_importance(model)

if __name__ == "__main__":
    main()

The output of the code:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 10, 3)]           0
_________________________________________________________________
lstm (LSTM)                  (None, 10, 10)            560
_________________________________________________________________
lstm_1 (LSTM)                (None, 10, 5)             320
_________________________________________________________________
lstm_2 (LSTM)                (None, 10, 1)             28
_________________________________________________________________
flatten (Flatten)            (None, 10)                0
=================================================================
Total params: 908
Trainable params: 908
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
1000/1000 [==============================] - 14s 14ms/step - loss: 0.6261
Epoch 2/5
1000/1000 [==============================] - 12s 12ms/step - loss: 0.4901
Epoch 3/5
1000/1000 [==============================] - 13s 13ms/step - loss: 0.4631
Epoch 4/5
1000/1000 [==============================] - 14s 14ms/step - loss: 0.4480
Epoch 5/5
1000/1000 [==============================] - 14s 14ms/step - loss: 0.4440
Variable 1, perturbation effect: 0.1162
Variable 2, perturbation effect: 0.1185
Variable 3, perturbation effect: 0.0077
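
The parameter counts in the model summary can be checked by hand. A standard Keras LSTM layer with bias has four gates, each with an input kernel, a recurrent kernel, and a bias vector. The helper name `lstm_params` below is mine.

```python
def lstm_params(input_dim, units):
    # 4 gates, each with an input kernel, a recurrent kernel, and a bias
    return 4 * (units * (input_dim + units) + units)

sizes = [(3, 10), (10, 5), (5, 1)]  # (input_dim, units) for each LSTM layer
counts = [lstm_params(d, u) for d, u in sizes]
print(counts, sum(counts))  # [560, 320, 28] 908
```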