Several methods have been devised for assessing or quantifying variable importance (even if only relative to one another) in MLP neural network models:
- Connection weights
- Garson’s algorithm
- Partial derivatives
- Input perturbation
- Sensitivity analysis
- Forward stepwise addition
- Backward stepwise elimination
- Improved stepwise selection 1
- Improved stepwise selection 2
(These methods are described in http://dx.doi.org/10.1016/j.ecolmodel.2004.03.013.)
Is there any method that can be applied to RNN or LSTM neural networks?
Best Answer
In short, yes: you can get some measure of variable importance for RNN-based models. I won't iterate through all of the methods listed in the question, but I will walk through an example of sensitivity analysis in depth.
The data
The input data for my RNN will be a time-series with three features, $x_1$, $x_2$, $x_3$, each drawn independently from the uniform distribution. The target variable for my RNN will also be a time-series (one prediction for each time step in my input):
$$ y = \left\{\begin{array}{lr} 0, & \text{if } x_1 x_2 \geq 0.25\\ 1, & \text{if } x_1 x_2 < 0.25 \end{array}\right. $$
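The data described above can be sketched in NumPy. This is a minimal sketch, not the original code: the function name `make_data`, the series length, and the sample count are my own illustrative choices.

```python
import numpy as np

def make_data(n_series, n_steps, seed=0):
    """Toy time-series: 3 uniform features per step; the target
    depends only on the first two features."""
    rng = np.random.default_rng(seed)
    # Shape: (samples, time steps, features)
    X = rng.uniform(size=(n_series, n_steps, 3))
    # y = 0 where x1 * x2 >= 0.25, else 1, one label per time step
    y = (X[:, :, 0] * X[:, :, 1] < 0.25).astype(np.float32)
    return X, y

X, y = make_data(n_series=1000, n_steps=10)
```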
As we can see, the target is dependent on only the first two features. Thus, a good variable importance metric should show the first two variables being important, and the third variable being unimportant.
The model
The model is a simple three-layer LSTM with a sigmoid activation in the final layer. The model will be trained for 5 epochs with 1000 batches per epoch.
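One way to sketch such a model, assuming Keras as the framework (the layer width of 32 units and the exact stacking are my guesses, not taken from the original post):

```python
from tensorflow import keras
from tensorflow.keras import layers

def build_model(n_steps, n_features=3):
    # Two recurrent layers plus a time-distributed output layer make
    # up the "three-layer" stack; the sigmoid yields one probability
    # per time step.
    model = keras.Sequential([
        layers.Input(shape=(n_steps, n_features)),
        layers.LSTM(32, return_sequences=True),
        layers.LSTM(32, return_sequences=True),
        layers.TimeDistributed(layers.Dense(1, activation="sigmoid")),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy")
    return model
```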
Variable importance
To measure the variable importance, we'll take a large sample (250 time-series) of our data $\hat{x}$ and compute the model's predictions $\hat{y}$. Then, for each variable $x_i$, we'll perturb that variable (and only that variable) by adding noise drawn from a normal distribution with mean 0 and standard deviation 0.2, and compute the perturbed predictions $\hat{y}_i$. We'll measure the effect of this perturbation as the root-mean-square difference between the original $\hat{y}$ and the perturbed $\hat{y}_i$. A larger root-mean-square difference means that variable is "more important".
Obviously, the exact mechanism you use to perturb your data, and how you measure the difference between perturbed and unperturbed outputs, will be highly dependent on your particular dataset.
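As a concrete illustration of the perturbation loop, here is a sketch in NumPy. `predict_fn` stands in for a trained model's `predict` method; to keep the example self-contained, I use the true labeling rule itself as the stand-in, so the third feature should come out with zero importance.

```python
import numpy as np

def perturbation_importance(predict_fn, X, scale=0.2, seed=0):
    """Sensitivity analysis: perturb one feature at a time with
    N(0, scale) noise and measure the RMS change in predictions."""
    rng = np.random.default_rng(seed)
    y_hat = predict_fn(X)
    importances = []
    for i in range(X.shape[-1]):
        X_pert = X.copy()
        X_pert[..., i] += rng.normal(0.0, scale, size=X.shape[:-1])
        y_pert = predict_fn(X_pert)
        importances.append(np.sqrt(np.mean((y_hat - y_pert) ** 2)))
    return np.array(importances)

# Stand-in for model.predict: the true rule, which ignores feature 3.
rule = lambda X: (X[..., 0] * X[..., 1] < 0.25).astype(float)
rng = np.random.default_rng(1)
X = rng.uniform(size=(250, 10, 3))
imp = perturbation_importance(rule, X)
```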
Results
After doing all of the above, we see the following importances:
As we expected, variables 1 and 2 are found to be much more important (about 15x more) than variable 3!
Python code to reproduce
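The original script is not shown, so the following is a self-contained sketch of the full procedure under stated assumptions: Keras as the framework, illustrative layer sizes, and a much smaller training budget than the 5 epochs x 1000 batches described above (so the exact numbers will differ).

```python
# End-to-end sketch of the procedure described in this answer.
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

rng = np.random.default_rng(0)
n_steps, n_features = 10, 3

def make_data(n_series):
    X = rng.uniform(size=(n_series, n_steps, n_features))
    # Target depends only on the first two features.
    y = (X[:, :, 0] * X[:, :, 1] < 0.25).astype(np.float32)[..., None]
    return X, y

model = keras.Sequential([
    layers.Input(shape=(n_steps, n_features)),
    layers.LSTM(32, return_sequences=True),
    layers.LSTM(32, return_sequences=True),
    layers.TimeDistributed(layers.Dense(1, activation="sigmoid")),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

X_train, y_train = make_data(512)
model.fit(X_train, y_train, epochs=2, batch_size=64, verbose=0)

# Sensitivity analysis: perturb one feature at a time and measure the
# RMS change in the model's predictions.
X_test, _ = make_data(250)
y_hat = model.predict(X_test, verbose=0)
importances = []
for i in range(n_features):
    X_pert = X_test.copy()
    X_pert[:, :, i] += rng.normal(0.0, 0.2, size=(250, n_steps))
    y_pert = model.predict(X_pert, verbose=0)
    importances.append(float(np.sqrt(np.mean((y_hat - y_pert) ** 2))))

print(importances)
```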