MATLAB: Classification Learner app. Cross-validation, scatter plot and confusion matrix.

I have a question about this app; hopefully some app experts can help me 🙂
I read the following on the website: "If you use k-fold cross-validation, then the app computes the accuracy scores using the observations in the k validation folds and reports the average cross-validation error. It also makes predictions on the observations in these validation folds and computes the confusion matrix and ROC curve based on these predictions".
That's fine for the accuracy, but if you look at the confusion matrix generated after selecting "k-fold validation", all the values are integers. How are they determined? They are not an average of the confusion matrices obtained from each of the k validation folds, nor are they summed, since the sum of all the elements equals the number of trials in the learning set. So how are they computed?
The same question applies to the scatter plot after training: the figure marks trials as correct or incorrect. Are they labeled correct/incorrect based on the average result over all k validation folds, or does the plot show the classification from a single representative fold?
Thanks in advance.

Best Answer

Hi Giansu,
Let's look at how the Classification Learner app generates the scatter plot and confusion matrix for k-fold cross-validation, using the iris dataset (150 samples) with 5-fold cross-validation as an example.
With 5 folds, the app partitions the data into 5 disjoint sets, or folds. For each fold, the app trains a model using the other 4 folds as training data and the remaining fold (the held-out fold) as validation data.
This means that with k-fold cross-validation, each of the 150 samples is used as validation data (that is, is in the held-out fold) exactly once. For example, in the first iteration the 1st fold is the validation data and the remaining 4 folds are the training data; in the second iteration the 2nd fold is the validation data and the remaining 4 folds are the training data; and so on.
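The partitioning step can be sketched in a few lines. This is a minimal illustration in Python, not the app's MATLAB implementation; it assumes plain random, non-stratified folds (the app's own partitioning scheme may differ), and the helper names `k_fold_indices` and `train_validation_splits` are hypothetical.

```python
import random

def k_fold_indices(n, k, seed=0):
    """Split indices 0..n-1 into k disjoint folds of (nearly) equal size.

    Assumption: simple random, non-stratified folds; the app may partition differently.
    """
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    # Fold j takes every k-th shuffled index, starting at position j.
    return [idx[j::k] for j in range(k)]

def train_validation_splits(n, k, seed=0):
    """Yield (training indices, validation indices) for each of the k iterations."""
    folds = k_fold_indices(n, k, seed)
    for j, held_out in enumerate(folds):
        # Training data = every fold except the held-out one.
        training = [i for f_id, fold in enumerate(folds) if f_id != j for i in fold]
        yield training, held_out

# Iris-sized example: 150 observations, 5 folds.
splits = list(train_validation_splits(150, 5))
for training, held_out in splits:
    print(len(training), len(held_out))  # 120 30 for every iteration

# Every observation lands in exactly one held-out fold.
all_validated = sorted(i for _, held_out in splits for i in held_out)
print(all_validated == list(range(150)))  # True
```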
Scatter plot: Each prediction shown in the scatter plot is the one made for that observation when it was in the held-out (validation) fold, i.e., by the model trained on the other 4 folds.
Confusion Matrix: The confusion matrix shows how the model classified each observation when that observation was in the held-out (validation) fold. Because every observation is validated exactly once, each cell is a count of observations, so the values are integers and they add up to the total number of observations (150 here).
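To see why the confusion matrix holds integers, here is a minimal sketch that collects one out-of-fold prediction per observation and builds the confusion matrix from those predictions. It is written in Python purely for illustration; the toy nearest-centroid classifier, the synthetic 1-D data (3 classes of 50 observations, standing in for iris), and all helper names are hypothetical. It also checks that adding up the k per-fold confusion matrices gives exactly the same matrix, whose entries total the number of observations, because each observation is validated exactly once.

```python
import random
from collections import defaultdict

def fit_centroids(xs, ys):
    """Hypothetical toy classifier: store the mean feature value of each class."""
    sums, counts = defaultdict(float), defaultdict(int)
    for x, y in zip(xs, ys):
        sums[y] += x
        counts[y] += 1
    return {c: sums[c] / counts[c] for c in counts}

def predict(centroids, x):
    """Assign x to the class whose centroid is nearest."""
    return min(centroids, key=lambda c: abs(x - centroids[c]))

def confusion(true, pred, classes):
    """Rows = true class, columns = predicted class; each cell counts observations."""
    pos = {c: i for i, c in enumerate(classes)}
    m = [[0] * len(classes) for _ in classes]
    for t, p in zip(true, pred):
        m[pos[t]][pos[p]] += 1
    return m

# Synthetic 1-D data (hypothetical stand-in for iris): 3 classes x 50 observations.
rng = random.Random(1)
classes = ["a", "b", "c"]
xs, ys = [], []
for label, mu in zip(classes, [0.0, 1.5, 3.0]):
    for _ in range(50):
        xs.append(rng.gauss(mu, 0.8))
        ys.append(label)

n, k = len(xs), 5
idx = list(range(n))
rng.shuffle(idx)
folds = [idx[j::k] for j in range(k)]  # 5 disjoint folds of 30

oof_pred = [None] * n  # exactly one out-of-fold prediction per observation
per_fold = []          # confusion matrix of each validation fold
for j, held_out in enumerate(folds):
    train = [i for f_id, fold in enumerate(folds) if f_id != j for i in fold]
    model = fit_centroids([xs[i] for i in train], [ys[i] for i in train])
    preds = [predict(model, xs[i]) for i in held_out]
    for i, p in zip(held_out, preds):
        oof_pred[i] = p
    per_fold.append(confusion([ys[i] for i in held_out], preds, classes))

pooled = confusion(ys, oof_pred, classes)
summed = [[sum(cm[r][col] for cm in per_fold) for col in range(len(classes))]
          for r in range(len(classes))]
print(pooled == summed)       # True
print(sum(map(sum, pooled)))  # 150
```

In other words, the pooled matrix is the sum of the per-fold validation matrices, and its total is N (not k times N) because each fold's matrix only counts that fold's held-out observations.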
Accuracy: The accuracy is calculated for each fold, and the model's accuracy is the average of these per-fold accuracies.
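The averaging step can be written out as follows (Python, illustration only; the per-fold counts are hypothetical, not results from the app, and the function names are made up for this sketch). When all folds have the same size, as with 150 observations and 5 folds, the mean of the per-fold accuracies equals the accuracy read off the diagonal of the pooled confusion matrix; with unequal fold sizes the two can differ.

```python
import math

def mean_fold_accuracy(correct, sizes):
    """Average of the per-fold accuracies (correct / fold size)."""
    return sum(c / s for c, s in zip(correct, sizes)) / len(sizes)

def pooled_accuracy(correct, sizes):
    """Accuracy over all out-of-fold predictions at once (diagonal / total)."""
    return sum(correct) / sum(sizes)

# Hypothetical counts: 5 equal folds of 30 observations.
correct, sizes = [29, 28, 30, 27, 29], [30] * 5
print(round(mean_fold_accuracy(correct, sizes), 4))  # 0.9533
print(math.isclose(mean_fold_accuracy(correct, sizes),
                   pooled_accuracy(correct, sizes)))  # True

# Hypothetical unequal folds: the two definitions no longer agree.
correct_u, sizes_u = [10, 10, 0], [10, 10, 5]
print(round(mean_fold_accuracy(correct_u, sizes_u), 4))  # 0.6667
print(pooled_accuracy(correct_u, sizes_u))               # 0.8
```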
Here are the scatter plot and confusion matrix I got on the iris data with 5-fold cross-validation:
Hope it helps!