Solved – Nested cross-validation for classification in MATLAB

classification, cross-validation, MATLAB, svm

I am trying to tackle a classification problem with a Support Vector Machine (SVM) in MATLAB. Using the sample code in the Bioinformatics Toolbox documentation (SVM Classification with Cross Validation), I am able to train an SVM and find its optimal parameters.

But from "Inner loop overfitting in nested cross-validation" and "How does one appropriately apply cross-validation in the context of selecting learning parameters for support vector machines?", I understand that I need an additional outer (leave-one-out) cross-validation loop to ensure that the final performance estimate isn't biased.

My question is: how do I implement this nested cross-validation? I understand that the inner cross-validation is used to choose the optimal parameters, but what is the purpose of the outer cross-validation, and how should I implement it in code?

Other related material:

Nested CV with shrunken centroids and SVM

On Over-fitting in Model Selection and Subsequent Selection Bias in Performance Evaluation

How Wrong Can We Get? A Review of Machine Learning Approaches and Error Bars

Best Answer

The purpose of the outer cross-validation (CV) is to get an estimate of the classifier's performance on genuinely unseen data. If the hyperparameters are tuned based on a cross-validation statistic, that same statistic becomes a biased estimate of generalization performance, so an outer loop that was not involved in any aspect of feature or model selection is needed to estimate performance. Conversely, if you do not tune the hyperparameters (and use the default hyperparameters in svmtrain and svmclassify), you do not need an outer cross-validation loop.
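
To make that contrast concrete, here is a minimal sketch of plain (non-nested) k-fold CV with untuned hyperparameters, using the same crossvalind, classperf, svmtrain and svmclassify functions as the nested example below. The fold count of 10 and the default (linear) kernel are assumptions, not requirements; since nothing is tuned inside the loop, no outer loop is needed:

%************** Plain k-fold CV, no tuning ******************
num_folds = 10;                                      % assumed fold count
indices = crossvalind('Kfold', Labels, num_folds);   % stratified fold assignment
cp = classperf(Labels, 'Positive', 1, 'Negative', 0);
for k = 1:num_folds
    test = (indices == k); train = ~test;
    svmStruct = svmtrain(Data(train,:), Labels(train));  % default (linear) kernel, default boxconstraint
    class = svmclassify(svmStruct, Data(test,:));
    classperf(cp, class, test);                      % accumulate fold results
end
fprintf('CV accuracy: %2.2f%%\n', 100*cp.CorrectRate);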

Here is an example of code that implements nested CV. This implementation uses leave-one-out CV in the outer loop, and Nelder-Mead optimization (NMO) together with sequential forward feature selection in the inner loop to find the optimal feature set and hyperparameters (box constraint C and RBF sigma).

Data are the data to be classified (Dimension: Cases x Features)

Labels are the class labels for each case

%************** Nested cross-validation ******************
num_folds = 10;     % inner-CV folds (value assumed; not specified in the original)
Results = classperf(Labels, 'Positive', 1, 'Negative', 0);      % initialize the classifier performance (CP) object
for i = 1:length(Labels)                                        % outer loop: leave-one-out
    test = zeros(size(Labels));
    test(i) = 1; test = logical(test); train = ~test;
    fprintf('Fold: %d of %d.\n', i, length(Labels));

    %************** Perform feature selection ************
    z0 = [0,0];    % z = log([rbf_sigma, boxconstraint]), so z0 = [0,0] starts the search at exp(z0) = [1,1]
    [rbf_sigma_Acc(i), boxconstraint_Acc(i), maxAcc, Features{i}] = SVM_NMO(z0, Data(train,:), Labels(train), num_folds);

    %***************** Outer loop evaluation *********************
    svmStruct = svmtrain(Data(train,Features{i}), Labels(train), 'Kernel_Function','rbf', 'rbf_sigma',rbf_sigma_Acc(i), 'boxconstraint',boxconstraint_Acc(i));
    class = svmclassify(svmStruct, Data(test,Features{i}));   % classify the held-out case
    classperf(Results, class, test);                          % update the CP object with the current classification result
    Acc_fold(i) = Results.LastCorrectRate;
    fprintf('Test set Accuracy (Fold %d): %2.2f\n', i, 100*Acc_fold(i));
    fprintf('Test set Accuracy (running mean): %2.2f\n\n', 100*Results.CorrectRate);
end
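
Once the loop finishes, the CP object has accumulated the results over all left-out cases, so the overall leave-one-out performance can be read directly from its properties:

%************** Report aggregate LOOCV performance **************
fprintf('LOOCV accuracy: %2.2f%%\n', 100*Results.CorrectRate);
fprintf('Sensitivity:    %2.2f%%\n', 100*Results.Sensitivity);
fprintf('Specificity:    %2.2f%%\n', 100*Results.Specificity);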

function [rbf_sigma, boxconstraint, Acc, Features_opt] = SVM_NMO(z0, Data, Labels, num_folds)
% Nelder-Mead search over z = log([rbf_sigma, boxconstraint]);
% optimizing in log-space keeps both hyperparameters positive.
opts = optimset('TolX',1e-1,'TolFun',1e-1);
fun = @(z) SVM_min_fn(Data, Labels, exp(z(1)), exp(z(2)), num_folds);
[z_opt, Crit] = fminsearch(fun, z0, opts);
[~, Features_opt] = fun(z_opt);       % re-run once at the optimum to recover the selected feature set

%************ Get optimal results **************
Acc = 1 - Crit;                       % accuracy of the best model (Crit is the CV error)
rbf_sigma = exp(z_opt(1));
boxconstraint = exp(z_opt(2));
fprintf('Max Acc: %2.2f, RBF sigma: %1.2f, Boxconstraint: %1.2f\n', Acc, rbf_sigma, boxconstraint);


function [Crit, Features] = SVM_min_fn(Data, Labels, rbf_sigma, boxconstraint, num_folds)
direction = 'forward';
kernel = 'rbf';

fprintf('RBF sigma: %1.4f. Boxconstraint: %1.4f\n', rbf_sigma, boxconstraint);
c = cvpartition(Labels,'k',num_folds);     % stratified inner-CV folds
opts = statset('display','iter','TolFun',1e-3);
fun = @(x_train,y_train,x_test,y_test) SVM_class_fun(x_train,y_train,x_test,y_test,kernel,rbf_sigma,boxconstraint);
[fs, history] = sequentialfs(fun, Data, Labels, 'cv', c, 'direction', direction, 'options', opts);

Features = find(fs);          % features selected for the given sigma and C
Crit = min(history.Crit);     % lowest mean classification error along the selection path
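
Note that SVM_class_fun is not defined above. sequentialfs expects a criterion function that returns the number of misclassified test observations (it sums this over the folds and divides by the total number of test cases to produce the mean error in history.Crit). A minimal sketch, assuming the same svmtrain/svmclassify interface used above:

function err = SVM_class_fun(x_train, y_train, x_test, y_test, kernel, rbf_sigma, boxconstraint)
% Criterion for sequentialfs: count of misclassified test observations
svmStruct = svmtrain(x_train, y_train, 'Kernel_Function',kernel, ...
    'rbf_sigma',rbf_sigma, 'boxconstraint',boxconstraint);
class = svmclassify(svmStruct, x_test);
err = sum(class ~= y_test);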

Hope this helps
