(conceptual question) What kind of minima do we expect stochastic gradient descent to get stuck in, and why?

Tags: calculus, gradient-descent, machine-learning, multivariable-calculus, stochastic-processes

Suppose you want to find $k$ that minimises your cost function $J_D(k)$ for the whole dataset $D$. We may want to apply batch gradient descent or stochastic gradient descent. Let's deliberately initialise $k$ with the same number $k_0 = 1$ for both BGD and SGD to see the difference in their behavior.

If you apply BGD, the whole process may look like this:

[Figure: animation of batch gradient descent on $J_D(k)$, converging to the nearest minimum]

On the other hand if you apply SGD, this optimisation may look like this:

[Figure: animation of stochastic gradient descent on the same cost; dotted curves show the per-batch costs $J_B(k)$]

In both pictures the blue solid curve represents the cost function $J_D(k)$. But in the second picture there are also dotted curves. In my experiment I used a batch size that was $10\%$ of the whole dataset $D$. So each dotted curve represents the cost function $J_B(k)$ for the current batch $B$. From these dotted curves you can see that the gradient $\nabla J_B\left(k_0\right)$ happened to be large multiple times in a row for different batches $B$. That's why the point was "pushed" to the deeper minimum.

As I understand it, we use SGD hoping that for some batch $B$ the gradient $\nabla J_B\left(k_0\right)$ is large enough that the red point is "pushed" out of this local minimum, giving it another chance to arrive at a better minimum.

But I'm stuck here.

  1. Why does stochastic gradient descent lead us to a minimum at all? Why can't it escape all the minima?

  2. Why do we think that another local minimum is going to be deeper than the initial one? I don't believe that we just hope our new minimum is going to be good enough. With the same success we could simply choose a random value for $k$.

  3. If we don't think that another local minimum is going to be deeper, then how is SGD supposed to avoid the local minima problem? Our red point can end up in a minimum that is shallower (higher) than the initial one, e.g.:

[Figure: cost curve where the point jumps from the initial minimum into a shallower (higher) one]

  1. If BGD looks for the nearest minimum, then what kind of minimum does SGD look for?

  2. How do we know for certain that it's not going to escape a deeper minimum? How deep should it be?

  3. What kind of minima do we expect stochastic gradient descent to get stuck in, and why do we think they are going to be deeper than the minimum we can obtain with normal gradient descent?

  4. How is SGD supposed to avoid the local minima problem if all it can do is push us out of a local minimum? I mean, it doesn't look for a better one; it only wanders along the curve.


As a side note, $J_D(k) = \frac 1n\sum_{i=1}^n\left(\sin\left(kx_i\right) - y_i\right)^2$.
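For reference, here is a minimal sketch of both procedures on this cost (my own illustration, not the code behind the GIFs; the synthetic data, `k_true`, and the learning rate are all assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed synthetic data: y = sin(k_true * x) + noise; the actual data may differ.
k_true = 2.0
x = rng.uniform(-3, 3, size=100)
y = np.sin(k_true * x) + rng.normal(scale=0.1, size=x.size)

def grad_J(k, xb, yb):
    """d/dk of J_B(k) = mean((sin(k x) - y)^2) over the batch (xb, yb)."""
    residuals = np.sin(k * xb) - yb
    return np.mean(2.0 * residuals * np.cos(k * xb) * xb)

def descend(k0, lr=0.05, steps=500, batch_size=None):
    """batch_size=None gives BGD (full dataset); an integer gives SGD."""
    k = k0
    for _ in range(steps):
        if batch_size is None:
            xb, yb = x, y                                    # whole dataset D
        else:
            idx = rng.choice(x.size, size=batch_size, replace=False)
            xb, yb = x[idx], y[idx]                          # random batch B
        k -= lr * grad_J(k, xb, yb)
    return k

print("BGD:", descend(k0=1.0))
print("SGD:", descend(k0=1.0, batch_size=10))  # 10% of the dataset, as in the GIFs
```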


EDIT 1:

I need some clarification of @WhoDatBoy's answer.

Since we randomly select each batch, each batch's distribution is going to be similar to that of the whole dataset. And this distribution uniquely determines the distribution of residuals within each batch. That's why each batch's gradient is going to be similar to that of the whole dataset. Is that right?

Now, I perfectly understand why the red point can't usually escape from wide minima: it's very unlikely to select batches whose gradients differ from the whole dataset's multiple times in a row.

However, there is still a thing that confuses me. You said that SGD was not invented to be robust against local minima. But it's said to be likely to reach a better minimum than the initial one. And I can see why in the case where the red point is initialised in some local minimum near a wide minimum: there's a chance that some batch's gradient will push it from the shallow minimum towards the wider one. But what if our cost function looks like this:

[Figure: cost curve with a staircase of progressively shallower minima to the left of the starting point]

Some batch can push the point to the left, into the shallower minimum. The point can stay there for a while, and after that it can be pushed again towards an even shallower minimum.

Question 1: Is this a highly unlikely case, since there are always batches seeking to push the point to the right?

Or consider the following situation:

[Figure: cost curve where the initial minimum is the deepest but very narrow, surrounded by wider, shallower minima]

The initial minimum is deeper, but it's very narrow, so the point can easily be pushed out of it towards the shallower minimum. In both these cases SGD can fail.

Question 2: I'm not sure about the first case, but in the second plot we definitely can't say that the red point is likely to find a better minimum, right? I mean, all the minima are too narrow for the point to stay in them.

Question 3: Is it true that only wide minima can hold the point (no matter whether they're deep or not)? And does how wide they need to be depend on the batch size we choose?

Question 4: And since we don't know it in advance, we just try to guess its size, right?

Question 5: It turns out, then, that the depth of the minimum doesn't play much of a role in holding the red point. It's the width of the minimum that matters?

Question 6: Do we assume something before applying SGD? If yes, then what exactly? I mean, is there some kind of assumption of the form: "SGD is likely to find a better minimum if the cost curve does not have only narrow minima and is not too hilly"?


EDIT 2:

Throughout this edit I assume that we have the same batch size and the same learning rate, and, for the sake of simplicity, that all the minima $A$, $B$ and $C$ (introduced below) have the same width (but different depths).

CONFUSION 1: In your Question 5 answer you said that the depth is important. Doesn't that mean that the deeper the minimum is, the harder it is for the red point to escape it? Thus, we can conclude that the red point is likely to get stuck in relatively deep minima when using SGD. The word "relatively" is used because the depth that is able to hold the red point depends on the batch size and the learning rate: the smaller the batch size and the bigger the learning rate, the deeper the minima the red point is looking for. By "looking for" I mean that the red point is going to get stuck in such minima. However, we don't know how deep the minimum has to be in absolute terms.

CONFUSION 2: It's still unclear why the red point is likely to get stuck in deeper minima. Suppose, for the sake of example, that we have a dataset of $100$ observations and three minima in our cost curve: $A$, $B$ and $C$. The respective errors (values of our cost function) at those minima are: $\mathrm{error}(A) = 100$, $\mathrm{error}(B) = 10$ and $\mathrm{error}(C) = 0$.

  1. Now, when the red point gets into the minimum $C$, then each of $100$ observations has $0$ error and therefore $0$ gradient. So, whatever batch you choose it's going to have $0$ gradient, since its gradient is the sum of gradients of the observations the batch consists of. Consequently, it's impossible for the red point to escape from the minimum $C$.

  2. And here is my main confusion. Why is the red point less likely to escape the minimum $B$ (the deeper one) than the minimum $A$? It would be nice to explain it in the following way: "since $\mathrm{error}(A) > \mathrm{error}(B)$, the gradient of each observation is smaller at the minimum $B$, and therefore every batch has a smaller gradient, which gives it less ability to push the red point out of the minimum $B$". But the problem is that we can NOT claim such a thing, since the error reduction at the minimum $B$ compared to the minimum $A$ could be caused by a single observation. I mean, if the error of a single observation, say the first one out of our 100, dropped significantly, that alone would reduce the error of our cost function, while the rest of the observations would have the same errors as before and therefore the same gradients. And since we're randomly picking a batch on each iteration, our red point can still be pushed out of the minimum $B$ with the same probability (am I wrong here about the same probability?).
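To make the last point concrete, here is a quick numerical check (a sketch with made-up per-observation gradients, not my actual experiment): at a point where the full-dataset gradient is zero, per-observation gradients can still be nonzero and of mixed sign, so a random batch's gradient is generally nonzero:

```python
import numpy as np

rng = np.random.default_rng(1)

# Made-up per-observation gradients at a full-dataset minimum like B:
# individually nonzero, but averaging to exactly zero over the dataset.
g = rng.normal(size=100)
g -= g.mean()

idx = rng.choice(100, size=10, replace=False)    # one random batch of 10
print("full-dataset gradient:", g.mean())        # 0 by construction
print("this batch's gradient:", g[idx].mean())   # generally nonzero
```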

CONFUSION 3: It becomes even more confusing when it's not a single observation's error that shrinks, but rather the total error reduction across some observations exceeds the total error increase across the others. The minimum would be deeper in this case, but how do we show that the red point is now less likely to escape from this new, deeper minimum?

I want to note that, in my understanding, a reduction in the cost function implies neither a reduction in the gradient over the whole dataset nor a reduction in the gradient of any individual batch. Then how on earth can a reduction in the cost function (and this is exactly what a deeper minimum means) mean a smaller ability to escape the minimum?

Best Answer

EDITS BELOW OP

OP:

  1. You're assuming all data are sampled from the same distribution, so on average the gradient computed at any step will be roughly equivalent to the global gradient. SGD is just gradient descent except, rather than computing the cost on the whole set, you just do it on a random subset. This is for speed. But the crucial point lies in the fact that you're randomly selecting from a set whose distribution is equivalent to that of the entire set (a small sketch after these numbered answers illustrates this).

It can't escape the minima for the same reason that regular GD can't - the gradient always points "downhill", and if you're in a local minimum then, by definition, every direction is "uphill".

  1. In general we don't know that there isn't a better (or worse) minimum. Indeed, many methods utilize lots of random initializations to try to mitigate problems from bad initial guesses (look up simulated annealing)

  2. There is no definitive solution that works in every situation for avoiding local minima. Look up "gradient descent with momentum" - this is one method that adds a "momentum" term (inspired by physics) to help it be robust against local minima, but if they are deep enough or "wide" enough, it will not matter.

This is why people try so hard to cast problems in a form where convex optimization can be used.

  1. Both BGD and SGD look for the nearest minimum. SGD is just a "low resolution" version of BGD.

  2. We are not certain. I'm sure someone could do an analysis of the likelihood of leaving a local minimum of a certain description given the distribution the sample is pulled from, but I've not done this nor seen it done. It's not very important because in general we don't have a full description of the cost function. Remember, we are sampling the cost function at a point - in solving a real problem you would never have a graph like you do (otherwise, what's the point of SGD - just look at the graph).

  3. This is a little weirdly phrased, but I think you're asking why stochastic gradient descent is more robust to small local minima. It stems from the stochasticity introduced by only sampling a small portion of the data to get your cost function. If you just happen to subsample points which have a gradient that pushes you out of a "true" local minimum, then great. You would never know that you had been saved from a "true" local minimum, though.

  4. Again, SGD was not invented to be robust against local minima. It was a simple extension of regular GD for the sake of speed and incremental updates in e.g. neural networks. If you want robustness against local minima, any form of regular GD is not what you're looking for. Again, think about adding a momentum term or checking out some other optimizers
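To illustrate point 1 from the top of this answer, here is a sketch (under the same assumed synthetic setup as the OP's cost; not part of the original answer) showing that batch gradients scatter around the full-dataset gradient and match it on average:

```python
import numpy as np

rng = np.random.default_rng(2)

# Assumed synthetic data for the OP's cost J(k) = mean((sin(k x) - y)^2).
x = rng.uniform(-3, 3, size=100)
y = np.sin(2.0 * x) + rng.normal(scale=0.1, size=x.size)

def grad_J(k, xb, yb):
    residuals = np.sin(k * xb) - yb
    return np.mean(2.0 * residuals * np.cos(k * xb) * xb)

k0 = 1.0
full_grad = grad_J(k0, x, y)
batch_grads = []
for _ in range(1000):
    idx = rng.choice(x.size, size=10, replace=False)   # a random 10% batch
    batch_grads.append(grad_J(k0, x[idx], y[idx]))

print("full-dataset gradient:    ", full_grad)
print("mean of batch gradients:  ", np.mean(batch_grads))   # close to full_grad
print("spread of batch gradients:", np.std(batch_grads))
```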


Edit 1:

First, thanks to @Hyperplane for pointing out that SGD is not just used for speed, but also regularization. Check out the paper he linked below in the comments if interested.

Response to edits in the original post:

Since we randomly select each batch ... the whole dataset

This is essentially correct. We make the probabilistic assumption that all data are sampled from the same distribution. The random sample will exhibit different statistical attributes than the whole set (e.g. sample mean), but on average we expect that the gradient of the cost function computed from the subsample will be similar to that of the entire sample.

Note that in some cases in your GIF it is actually quite far off - the expected difference will shrink as your subsample grows, and determining how big to make your subsample is a decision you have to make.
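How fast that expected difference shrinks can be checked directly (a sketch with stand-in per-observation gradients; the dataset size and batch sizes are assumptions):

```python
import numpy as np

rng = np.random.default_rng(3)

# Stand-in per-observation gradients; the full-dataset gradient is pinned to 1.0.
n = 1000
g = rng.normal(size=n)
g += 1.0 - g.mean()

for m in (10, 100, 500):
    estimates = [g[rng.choice(n, size=m, replace=False)].mean() for _ in range(2000)]
    print(f"batch size {m:4d}: spread of batch gradients = {np.std(estimates):.4f}")
```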

However there is still a thing that confuses me ... pushed again towards a shallower minimum

There are a few things to say here.

  1. I'm not sure if I can say with mathematical certainty that SGD is more likely to reach a better minimum - I have no proof of this. We can see intuitively that there could be reason for this, so it is "worth" trying in practice.

  2. The step size relative to the "size" of the minimum matters (see the toy sketch after this list). Let's assume for the sake of simplicity that our minimum sits at the bottom of a surface which describes the bottom half of a sphere. If the step size is 10 times the radius of the sphere, we simply cannot mathematically get caught in this minimum (unless of course we happen to land exactly on the minimum).

    On the other hand, if the radius of the sphere is $10^6$ times greater than the step size, it is virtually impossible to leave this "bowl".

  3. To directly answer your question about progressively "jumping up" to higher and higher minima, the main answer is the importance of tuning your learning rate (step size). Having a sufficiently small learning rate will make this virtually impossible.
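Here is a toy version of points 2 and 3 (a sketch; the bowl shape $f(x) = -e^{-(x/w)^2}$ with half-width roughly $w$ and the fixed step size are my assumptions): the same learning rate escapes a narrow bowl in one step but converges inside a wide one:

```python
import numpy as np

def run_gd(w, lr=0.1, steps=200):
    """Plain GD on f(x) = -exp(-(x/w)^2), a 'bowl' of half-width ~w around x = 0."""
    x = w / 2                                   # start inside the bowl
    for _ in range(steps):
        grad = (2 * x / w**2) * np.exp(-((x / w) ** 2))
        x -= lr * grad
    return x

print("narrow bowl (w = 0.01):", run_gd(0.01))  # first step overshoots; x ends far outside
print("wide bowl   (w = 10):  ", run_gd(10))    # steps stay inside; x moves toward 0
```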

Question 1 answer: Again, you're highlighting the importance of the learning rate. With a sufficiently small learning rate, it is very unlikely that you would end up in a much higher minimum. If the learning rate is very high, it could happen, or you could essentially just end up bouncing around randomly. If the learning rate is "medium" then you could end up in the global minimum, or perhaps in one of the first two or three minima "up the bowl" on either side of the global minimum. It depends on the learning rate, the batch size, and the stochasticity introduced by random selection.

Despite the fact that ... In both these cases SGD can fail

I was actually going to answer something like this above, but I thought it was confusing. Either way, here it is: the steepness of the minimum makes this very unlikely given a reasonable step size. Let's say for the sake of example that the gradient is $-20$ on the left and $+20$ on the right of the minimum. Remember that we expect our small batches to have gradients (derivatives) similar to the whole dataset's. So we might expect a gradient range of, say, $15$ to $25$ in magnitude (negative on the left, positive on the right). We are still pointing in the right direction. Could you "hop out"? Sure. But it's very unlikely and would be a "fluke" of a random sample.

Question 2 answer: Depends on learning rate and initialization. If the learning rate is pretty low, then you could very well get stuck in any of the minima near x = -0.5, 0, or 2. This isn't really a bad thing in this particular case for either the 0 or 2 minima.

Question 3 answer: Again, "wide" is hard to define when you can have a function that is at any scale. Are we concerned with width-to-depth ratio (rhetorical)? In short the answer is no - if the learning rate is small enough then either will work. I can expand if you need.

Question 4 answer: I'm not sure what you're saying here. Guess the size of what - the "size" of the "bowls" around the minima? Recall that you can get much more complex behavior in higher dimensions (e.g. saddle points in 2D) which is where most applications lie (in higher dimensions I mean). In general, you don't know the "scale" of the relevant features of the function.

Take the 2nd function you added. It exists in the range $-5$ to $5$ with important features on that order. You could've just as easily made this $-5^{-10}$ to $5^{-10}$, or $-5^{10}$ to $5^{10}$. A learning rate of, say, $0.01$ could be useful for your function, but would be hopelessly useless for the other two scales (even impossible for the first).
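A concrete version of this scale argument (a sketch on an assumed quadratic $f(x) = (x/s)^2$ at three scales $s$; not the OP's function): the same learning rate diverges at the tiny scale, works at the middle one, and barely moves at the huge one:

```python
import numpy as np

lr = 0.01
for s in (1e-5, 1.0, 1e5):        # three rescaled copies of the "same" bowl
    x = s                         # start at a comparable relative position
    for _ in range(100):
        x -= lr * (2 * x / s**2)  # GD step on f(x) = (x/s)^2
        if not np.isfinite(x):
            break                 # diverged
    print(f"scale {s:g}: x after <=100 steps = {x:.3g}")
```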

Question 5 answer: As above, depth is important because even if the small batch gradients are not equal to the whole batch gradient, they will still generally be pointing in the right direction. The smaller the magnitude of the gradient, the easier it is for random sampling to go in a "bad" direction.

On the other hand, even though it is easier to go in a "bad" direction in a wider bowl, you have to take many "bad" steps in succession to leave the bowl. So there is a tradeoff, and we have seen again the importance of tuning your learning rate (and small batch size).

Look again at your GIF - around the big hills, the small-batch derivatives match pretty closely - at least they have the right direction. However, on the far left of your function, the derivatives don't match as closely.
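That sign argument can also be checked numerically (a sketch with stand-in per-observation gradients; the dataset size, batch size, and gradient magnitudes are assumptions): the larger the full-dataset gradient relative to the per-observation spread, the rarer a batch gradient with the "wrong" sign:

```python
import numpy as np

rng = np.random.default_rng(4)
n, m = 1000, 100                        # dataset size and batch size (assumed)

for full_grad in (0.05, 0.1, 0.3):      # stand-ins for shallower -> steeper walls
    g = rng.normal(size=n)              # per-observation gradients, unit spread
    g += full_grad - g.mean()           # pin the full-dataset gradient exactly
    trials, wrong = 5000, 0
    for _ in range(trials):
        idx = rng.choice(n, size=m, replace=False)
        wrong += g[idx].mean() < 0      # batch gradient points the "wrong" way
    print(f"full gradient {full_grad}: wrong-sign batches = {wrong / trials:.3f}")
```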

Question 6 answer: You assume nothing beyond the assumptions of regular GD (namely continuity and differentiability). In general, SGD is used because it is faster and, as mentioned by @Hyperplane, is a form of regularization. It may leave local minima (but so too might regular GD with a high enough learning rate). I cannot give you a definitive rule for when it might be better to use one or the other. I'm not saying such a rule doesn't exist, but I do not know it if it does.


Edit 2:

Confusion 1 response: It depends what the source of the escape is. If you are asking whether or not the stochasticity introduced from using a minibatch is less likely to be the cause of leaving a deep minimum, then yes. The deeper the minimum is (given a fixed width), the higher the magnitude of the derivative and the less likely a minibatch will have a derivative that points in the "wrong" direction (i.e. has a different sign than the true whole batch derivative, in the 1D case).

If the source of the escape is that your learning rate is not tuned properly, then having a deeper minimum could actually make you more likely to leave. Remember, you're not really stepping "downhill" - you're taking a step left or right (in the 1D case) whose size depends on your learning rate and the derivative at that point. If the derivative has higher magnitude (i.e. the minimum is deeper), then it means you are taking a bigger step to the right or left. If the learning rate is high enough, you could very well take a step big enough to leave the well.

the smaller the batch size and the bigger the learning rate, the deeper minima the red point is looking for

This is not true for the reasons given in my last paragraph. Higher learning rate does not mean you are walking "downhill" with bigger steps, it means you are walking to the left or right with bigger steps and then "ending up on the function" at whatever altitude happens to correspond to that x value.

Imagine this: I have a global minimum at $(0, 0)$, and a global maximum at $(1, 1000)$. Let's say I'm at $x = -0.5$, and the cost function happens to have a derivative of $-1$ there (i.e. "pointing" towards the minimum). Let's also assume that my learning rate is $1.5$. Then my update will take my current position and add $-\text{derivative} \times \text{learning rate} = -(-1) \times 1.5 = +1.5$. Therefore, my new $x$ becomes $x_{\text{new}} = x_{\text{old}} + 1.5 = -0.5 + 1.5 = 1$. Look at that! I've ended up increasing my cost from 0 to 1000; and, worse yet, the derivative of the cost function at my new $x$ value is 0, so I am stuck here.

This is obviously a very convenient and unlikely scenario, but it highlights the very important point that steps are taken on the x-axis, not on the "function surface". It also points out that you can leave very skinny, deep minima if your learning rate is on the order of their width.
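The arithmetic of that scenario, in code (just the numbers from the example above, nothing more):

```python
x_old = -0.5
derivative = -1.0       # slope of the cost at x_old, "pointing" toward the minimum at 0
learning_rate = 1.5

x_new = x_old - learning_rate * derivative
print(x_new)            # 1.0: we stepped straight past the minimum at x = 0 and
                        # landed on the maximum at (1, 1000), where the derivative
                        # is 0, so every later update moves us by exactly 0
```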

Confusion 2 response: We have been using the term "minimum" relatively loosely, and while it is convenient, there is a problem here. You have said

Now, when the red point gets into the minimum C, then each of 100 observations has 0 error and therefore 0 gradient. So, whatever batch you choose it's going to have 0 gradient, since its gradient is the sum of gradients of the observations the batch consists of. Consequently, it's impossible for the red point to escape from the minimum C.

We have been talking about being "in" a minimum as being in the "bowl" that surrounds it. The minimum itself is only the single point at which the derivative is 0. If you happen to land on a minimum exactly, then it is true you will never be able to leave: all updates will move you exactly 0. This is very unlikely. With the right learning rate, you will asymptotically converge to a minimum, but not reach it. Even if you happen to land on one (overwhelmingly unlikely for functions like those you show - the set of minima has measure zero), you could still be pushed off of it just because of the errors introduced by floating-point computations.

I think the above might've answered your number 2 bullet under this confusion. In general, if you end up in a steep minimum with a low learning rate, you will never leave (steep and low are both relative here).

Confusion 3 answer: I'm a bit confused by your wording here. Importantly, "reduction in cost function" is not the same thing as a "deeper minimum" for the reason I mentioned above. Deeper minimum just means bigger magnitude derivative means bigger steps left or right. This could be highly detrimental.

Ultimately I think you shouldn't worry too much about this. Just think of SGD as a fast, low-resolution version of GD which happens to be slightly robust against extremely weak minima - this is more of an oddity than anything. If you are concerned with robustness then again look at (S)GD + Momentum.
