When I try to plot the mini-batch loss against the iteration number while training a CNN, an error occurs. I modified the example output function from the training documentation to plot the loss instead of the accuracy. My modified version is below:
function plotTrainingLoss(info)
persistent plotObj
if info.State == "start"
    plotObj = animatedline('Color','r');
    xlabel("Iteration")
    ylabel("Loss")
    title("Training loss evolution")
elseif info.State == "iteration"
    addpoints(plotObj,info.Iteration,info.TrainingLoss)
    drawnow limitrate nocallbacks
    fprintf('%d \n',info.TrainingLoss)
end
end
Since info.TrainingLoss is one of the documented fields of the info structure, this should simply work. I'm also printing the value to the screen, and that works fine. The plot, however, fails with the following error:
Initializing image normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (seconds)    |     Loss     |   Accuracy   |      Rate       |
|========================================================================================|
|       1 |           1 |           0.33 |       2.0795 |       14.84% |          0.0010 |
Error using trainNetwork (line 133)
Invalid type for argument Y. Type should be double.

Error in TrainOwnCNN_Advanced (line 108)
CNN = trainNetwork(trainingData,layers,options);
While debugging, I found that TrainingLoss is a single-precision gpuArray. info.BaseLearnRate also prints to the screen without issues, but it gives the same error when I try to plot it. Switching to info.TrainingAccuracy, as in the help documentation, works as if by magic. The function call itself must be correct, otherwise plotting TrainingAccuracy would fail as well, which is not the case. So does the plot expect a double, which info.TrainingAccuracy appears to be while the other fields are not?
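To check which fields of info are not plain doubles, a small debugging output function along these lines could print the class of each field at every iteration. This is a minimal sketch: inspectInfoTypes is a hypothetical name, not part of any toolbox, and the list of fields is just the three discussed here.

```matlab
function inspectInfoTypes(info)
% Hypothetical debugging output function: prints the class of a few
% fields of the training info structure at each iteration.
if info.State == "iteration"
    fields = {'TrainingLoss', 'TrainingAccuracy', 'BaseLearnRate'};
    for k = 1:numel(fields)
        value = info.(fields{k});
        if isa(value, 'gpuArray')
            % classUnderlying reports the element type stored on the GPU
            fprintf('%s: gpuArray (%s)\n', fields{k}, classUnderlying(value));
        else
            fprintf('%s: %s\n', fields{k}, class(value));
        end
    end
end
end
```

For a single-precision gpuArray loss, this would print a line such as "TrainingLoss: gpuArray (single)".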
Can anyone shed some light on this?
The screen output while plotting info.TrainingAccuracy (and printing info.TrainingLoss) looks like this:
Initializing image normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (seconds)    |     Loss     |   Accuracy   |      Rate       |
|========================================================================================|
|       1 |           1 |           0.23 |       2.0832 |        7.81% |          0.0010 |
2.083199e+00
2.083350e+00
2.082194e+00
2.081159e+00
2.080728e+00
2.080156e+00
2.079782e+00
2.080018e+00
2.077905e+00
2.078161e+00
2.078137e+00
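Going by the debugging observation that TrainingLoss arrives as a single-precision gpuArray, while the error message says addpoints wants a double, one plausible workaround is to convert the value before plotting. The sketch below assumes that conversion is all addpoints needs. gather copies a gpuArray into host memory (ordinary in-memory arrays come back unchanged), and double converts the result to double precision. lossValue is a local name chosen for illustration.

```matlab
function plotTrainingLoss(info)
% Output function that plots the mini-batch loss against the iteration.
persistent plotObj

if info.State == "start"
    plotObj = animatedline('Color','r');
    xlabel("Iteration")
    ylabel("Loss")
    title("Training loss evolution")
elseif info.State == "iteration"
    % TrainingLoss may be a single-precision gpuArray: bring it to the
    % host with gather, then convert it to double for addpoints.
    lossValue = double(gather(info.TrainingLoss));
    addpoints(plotObj, info.Iteration, lossValue)
    drawnow limitrate nocallbacks
    fprintf('%.4f\n', lossValue)
end
end
```

The function would be attached in the usual way, for example `options = trainingOptions('sgdm','OutputFcn',@plotTrainingLoss);` (the solver name here is only an assumption). Presumably the same double(gather(...)) conversion would also let info.BaseLearnRate be plotted.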
Best Answer