Solved – Speed of prediction: neural network vs. random forest

machine learning, MATLAB, neural networks, random forest

I'm currently trying to improve on a classifier. The current method is a neural network, and the method I've found to be better is a random forest (or even just a single tree). With 40 trees, the classification is much better than with the neural network. However, it takes 40 minutes (using 4 parallel workers, due to running out of memory) to classify a large block of data, whereas the neural network takes ~5 minutes (using 8 parallel workers). Is there a way to improve the speed of prediction? And does anyone know the reason for this huge slowdown? I'm guessing it is due to the number of trees, and also the number of workers I can use.

MATLAB was used to create and run both the network and the forest.

40 features, 13 outputs, training set size: ~800,000, individual block size: ~500×500, whole file to be classified: 1+ GB, along with other files containing more information.

The data is not sparse.

Best Answer

The comments are quite accurate. To summarize (calling $p$ the number of simultaneous workers you have), the complexities should be (depending on the implementation):

  • Random forest: $O(n_{\text{trees}} \cdot n \log(n) / p)$
  • Neural network: $O(n_{\text{neurons}} \cdot \text{size}_{\text{neurons}} \cdot n / p)$
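For intuition, it is easy to plug rough numbers into these formulas. The tree depth ($\log_2 n$) and the hidden-layer size of 50 below are illustrative assumptions, not figures from the question:

```python
import math

n = 800_000            # samples to classify (assumed on the order of the training set)
n_trees = 40           # trees in the forest
depth = math.log2(n)   # approximate depth of a fully grown tree, ~log(n)

# Assumed network shape: 40 inputs -> one hidden layer -> 13 outputs.
# The hidden size of 50 is a pure guess for illustration.
n_neurons, size_neurons = 50, 40

rf_ops = n_trees * n * depth           # forest: n_trees root-to-leaf paths of ~log(n) tests each
nn_ops = n_neurons * size_neurons * n  # network: one matrix-vector product per sample

p_rf, p_nn = 4, 8                      # worker counts reported in the question
ratio = (rf_ops / p_rf) / (nn_ops / p_nn)
print(f"RF/NN predicted cost ratio: {ratio:.1f}x")
```

With these guesses the raw operation counts come out comparable, so the 8× gap observed in practice must come mostly from constant factors: tree traversal is branchy and cache-unfriendly, while a network's forward pass is a few dense matrix products that vectorize well.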

The speed will also depend on the implementation; the $O$ only describes how the prediction cost scales. The constant term hidden by the $O$ notation can be critical.

Indeed, you should expect random forests to be slower than neural networks.
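One way to see that constant-factor gap directly is to time `predict` for both model families on the same data. This sketch uses scikit-learn rather than MATLAB (the question's environment) purely to illustrate the measurement; all sizes are toy-scale assumptions:

```python
import time
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
X = rng.standard_normal((5000, 40))   # 40 features, as in the question
y = rng.integers(0, 13, size=5000)    # 13 classes

rf = RandomForestClassifier(n_estimators=40, random_state=0).fit(X, y)
nn = MLPClassifier(hidden_layer_sizes=(50,), max_iter=50, random_state=0).fit(X, y)

X_test = rng.standard_normal((20000, 40))
for name, model in [("forest", rf), ("network", nn)]:
    t0 = time.perf_counter()
    model.predict(X_test)
    print(f"{name}: {time.perf_counter() - t0:.3f} s")
```

The absolute timings depend on the machine; the point is to measure prediction in isolation, at a fixed number of workers, before drawing conclusions.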

To speed things up, you can try:

  • using other libraries (I have never used MATLAB's random forest, though)

  • reducing the depth of the trees (which replaces the $\log(n)$ with a constant term and lets you use more workers, but may harm the accuracy of the classifier)

  • checking for duplicate features and constant columns in your data set and removing them (they do not improve accuracy and increase memory usage)

  • [Edit] Is your data sparse? I have observed huge speed-ups from using sparse representations of the data (as long as the learning algorithm supports them)
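The last three suggestions can be sketched together in scikit-learn (again an illustration, not the question's MATLAB code; the column indices and sizes are assumptions): drop constant and duplicate columns, cap the tree depth, and pass a sparse matrix when the data allows it.

```python
import numpy as np
from scipy import sparse
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(1)
X = rng.standard_normal((2000, 40))
X[:, 5] = 0.0          # simulate a constant column
X[:, 7] = X[:, 3]      # simulate a duplicate column
y = rng.integers(0, 13, size=2000)

# Remove constant columns, then duplicate columns.
keep = X.std(axis=0) > 0
X = X[:, keep]
_, unique_idx = np.unique(X.T, axis=0, return_index=True)
X = X[:, np.sort(unique_idx)]

# Capping max_depth replaces the log(n) factor with a constant
# (at a possible cost in accuracy).
rf = RandomForestClassifier(n_estimators=40, max_depth=8, n_jobs=-1,
                            random_state=0).fit(X, y)

# scikit-learn's tree ensembles also accept CSR sparse input at
# predict time, which helps when the data is mostly zeros.
preds = rf.predict(sparse.csr_matrix(X))
print(preds.shape)
```

Note that sparse input only pays off if the data really is mostly zeros, which the asker says is not the case here.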