Solved – Most common method for deciding when to stop training a neural net on a batch

gradient descent, machine learning, neural networks

I have created my own neural net that uses batch gradient descent. In other words, it trains on whole batches of examples at once.

My issue is figuring out when to stop training on a batch. I'll try to make things as understandable as possible since there are so many options, but please comment if you need clarification.

Each batch allows me to compute a list of output errors, one per example, corresponding to the error the network makes on that example. Each error is a vector with as many elements as there are output neurons.

Let each error vector be notated by e1, e2, e3, ... so the list of output errors is [e1, e2, ...]. Let an error vector in general be notated by e.

Now there are 3 things we can do to a vector or a list (a short code sketch of these operations follows the list):

  • Average: average all of the elements in a vector or list. This will be notated by the function a(). For example, to find the average error of an example we would compute a(e).
  • Check: make sure that each element in the vector or list is below a certain value. If any of the elements are too high, we continue to train on the entire batch. (This function can also be applied to a single value.) This will be notated by the function c(). For example, if we wanted to make sure the average error in e was low enough, we would compute c(a(e)).
  • Sum: add up all of the elements in the vector or list. This will be notated by the function s(). For example, s(e) would total the error found in error vector e.
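
For concreteness, here is a minimal NumPy sketch of how I picture these three operations. The function names a/c/s and the 0.01 threshold are just my own illustration, not any standard API:

```python
import numpy as np

def a(x):
    """Average: mean of all elements in a vector or list."""
    return float(np.mean(x))

def s(x):
    """Sum: total of all elements in a vector or list."""
    return float(np.sum(x))

def c(x, threshold=0.01):
    """Check: True if every element (or a single value) is below the threshold."""
    return bool(np.all(np.asarray(x) < threshold))

# Example: check that the average error of one example is low enough.
e = np.array([0.004, 0.012, 0.007])  # one error vector (3 output neurons)
print(c(a(e)))                       # c(a(e)) -> True, since the mean is ~0.0077 < 0.01
```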

Here are the methods I have determined for deciding when to stop gradient descent (a few of them are sketched in code after the list):

  • Iterative: just run the training a certain number of times
  • Average of averages: Find the average error of each example, then find the average of those and make sure it's below a certain amount. This is given by the formula c(a([a(e1), a(e2), ...]))
  • Specific error: Check to make sure the error in each output neuron is below a certain amount. This is given by the formula c(e1), c(e2), ...
  • Specific averages: Check to make sure the average error in each example is below a certain amount. This is given by the formula c([a(e1), a(e2), ...])
  • Sum of sums: Total up all the error in the entire batch and then make sure it's low enough. This is given by the formula c(s([s(e1), s(e2), ...]))
  • Sum of averages: Find the average error of each example, then add up all those and make sure it's below a certain amount. This is given by the formula c(s([a(e1), a(e2), ...]))
  • Average of sums: Total up the error in each example, then average those totals and make sure the result is below a certain amount. This is given by the formula c(a([s(e1), s(e2), ...]))
  • Specific sums: Check to make sure the total of the error in each example is below a certain amount. This is given by the formula c([s(e1), s(e2), ...])
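
To make a few of these criteria concrete, here is a short sketch reusing the a() and c() operations from above. The two-example batch and the 0.05 threshold are made up purely for illustration:

```python
import numpy as np

def a(x):
    return float(np.mean(x))                        # average of a vector or list

def c(x, threshold=0.05):
    return bool(np.all(np.asarray(x) < threshold))  # every element below threshold?

# A hypothetical batch of two error vectors (two output neurons each).
errors = [np.array([0.02, 0.01]), np.array([0.03, 0.005])]

# Average of averages: c(a([a(e1), a(e2), ...]))
stop_avg_of_avgs = c(a([a(e) for e in errors]))

# Specific averages: c([a(e1), a(e2), ...])
stop_specific_avgs = c([a(e) for e in errors])

# Specific error: c(e1), c(e2), ... -- every output neuron checked directly
stop_specific_error = all(c(e) for e in errors)

print(stop_avg_of_avgs, stop_specific_avgs, stop_specific_error)
```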

So my question is this:

Which one of these options is most commonly used when training neural nets? Why? (Is there another option that I've missed?)

(Edit: I have found a very general version of this question here, where the answers seem to say there is no specifically good way. If that is still true, please say so in your answer. Note that this is not a duplicate, as I explain the specific methods I would like to use.)

Edit: as A.D pointed out, summing is pretty useless because the total depends on how many values are in each list/array/batch. And if you know how many examples there are, you can just use averaging instead. So, the methods involving "summing" have been cut.

Best Answer

Seems like there are 2 questions here.

  1. What is the "loss" for a batch?

  2. How do I know when to stop training?

Regarding question 1, the most common batch loss/error is the mean of the individual sample errors in the batch. While you can use the sum of the errors as well, it is confusing (IMO) as it depends on the number of instances per batch.
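
As a minimal sketch, assuming squared error per output neuron (just one possible choice of individual error), the batch loss would be the mean over examples:

```python
import numpy as np

def batch_loss(outputs, targets):
    """Mean over the batch of each example's mean squared error."""
    per_example = np.mean((outputs - targets) ** 2, axis=1)  # one scalar per example
    return float(np.mean(per_example))                       # scale is independent of batch size

outputs = np.array([[0.9, 0.1], [0.2, 0.8]])  # made-up network outputs
targets = np.array([[1.0, 0.0], [0.0, 1.0]])  # made-up targets
print(batch_loss(outputs, targets))           # 0.025, whether the batch has 2 or 200 examples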

For question 2, a common method is early stopping. Note that early stopping requires a held-out validation set. Also, note that this method does not really care about the loss/error on the training data: it simply stops training when there is no more improvement on the held-out set (measured by, e.g., F1 for classification tasks).
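
Here is a hedged sketch of early stopping with a patience counter; train_one_epoch and validation_score are hypothetical stand-ins for your own training loop and held-out metric:

```python
import random

def train_one_epoch(model):
    """Hypothetical stand-in: run batch gradient descent over the training set."""
    pass

def validation_score(model):
    """Hypothetical stand-in: return a held-out metric (e.g. F1; higher is better)."""
    return random.random()

model = object()                       # placeholder for your network
best_score, bad_epochs, patience = float("-inf"), 0, 5

for epoch in range(1000):              # hard cap on epochs
    train_one_epoch(model)
    score = validation_score(model)
    if score > best_score:             # validation improved: keep going
        best_score, bad_epochs = score, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:     # no improvement for `patience` epochs in a row
            break                      # early stop
```

The patience counter is what keeps a single noisy validation epoch from ending training prematurely; in practice you would also checkpoint the weights at each new best score and restore them after stopping.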

As you mentioned there are many ways to go about this task, but the steps I outlined are pretty standard in my opinion.
