Solved – Why is gradient descent inefficient for large data sets

gradient descent, large data, machine learning

Let's say our data set contains 1 million examples, i.e., $x_1, \ldots, x_{10^6}$, and we wish to use gradient descent to perform a logistic or linear regression on this data set.

What is it with the gradient descent method that makes it inefficient?

Recall that the gradient descent step at time $t$ is given by:

$$w_{t+1} = w_{t} - \eta_t \nabla f(w_t)$$

where $f$ is the loss function and $\eta_t$ is the step size.

I am not seeing anything out of the ordinary in the above step that would make the algorithm inefficient. Is it the computation of $\nabla f(w)$? Couldn't this operation be pre-computed, i.e., each partial derivative $\frac{\partial f}{\partial w_j}$ derived in advance, so that we simply evaluate it at each data point $x_i$?
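For concreteness, suppose we are doing linear regression with squared loss on targets $y_i$. Even if each partial derivative is derived in closed form in advance,

$$\frac{\partial f}{\partial w_j} = \sum_{i=1}^{n} \left( w^\top x_i - y_i \right) x_{ij},$$

evaluating it still sums over all $n = 10^6$ examples, so each iteration costs $O(nd)$ time.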

Best Answer

It would help if you provided some context for the claim that gradient descent is inefficient. Inefficient relative to what?

I guess that the missing context here is the comparison to stochastic or mini-batch gradient descent in machine learning. Here's how to answer the question in that context. You are optimizing the parameters of the model, $\Theta$ (possibly even hyperparameters). So, you have the cost function $\sum_{i=1}^n L(\Theta \mid x_i)$, where $x_i$ is your data, $\Theta$ is the vector of parameters, and $L(\cdot)$ is the loss function. To minimize this cost you use gradient descent over the parameters $\theta_j$: $$ \frac{\partial}{\partial \theta_j}\sum_{i=1}^n L(\Theta \mid x_i)$$
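Here is a minimal sketch of what one full-batch step costs, assuming a logistic loss with binary labels $y_i$ (the function name and synthetic data are illustrative, not from the question):

```python
import numpy as np

def full_batch_gradient(theta, X, y):
    """Gradient of the summed logistic loss over ALL n examples.

    Even with the closed-form partial derivatives known in advance,
    evaluating them touches every row of X: one step costs O(n * d).
    """
    p = 1.0 / (1.0 + np.exp(-(X @ theta)))  # sigmoid of n dot products
    return X.T @ (p - y)                    # sum of n per-example gradients

# One gradient-descent step on a synthetic set of n = 10**6 examples.
rng = np.random.default_rng(0)
n, d = 10**6, 20
X = rng.normal(size=(n, d))
y = (rng.random(n) < 0.5).astype(float)

theta = np.zeros(d)
eta = 1e-6
theta = theta - eta * full_batch_gradient(theta, X, y)  # O(n * d) per step
```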

So, you see that you need the sum over all the data $x_{i=1,\dots,n}$ at every step. This is unfortunate, because it means that you keep looping through the entire data set for each step of your gradient descent. That's where mini-batch and stochastic gradient descent come in: what if we sampled from the data set, and calculated the gradient on a sample, not the full set? $$ \frac{\partial}{\partial \theta_j}\sum_{k=1}^{n_s} L(\Theta \mid x_k)$$ Here, $n_s$ is the number of observations in the sample $s$. So, if your sample is 1/100th of the total set, you speed up each step by a factor of 100! Obviously, this introduces noise, which lengthens the learning, but the noise decreases at the rate of $\sqrt{n}$ while the amount of calculation grows linearly in $n$, so this trick may work.
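Continuing the sketch above, a sampled-gradient step might look like this (the $n / n_s$ rescaling, which makes the sample sum an unbiased estimate of the full-data sum, is an addition of mine, not spelled out above):

```python
def sampled_gradient(theta, X, y, n_s, rng):
    """Noisy gradient estimate from a random sample of n_s rows.

    Costs O(n_s * d) instead of O(n * d). The n / n_s factor rescales
    the sample sum to estimate the full-data sum (my assumption, not
    stated in the answer text).
    """
    idx = rng.integers(0, X.shape[0], size=n_s)   # sample with replacement
    p = 1.0 / (1.0 + np.exp(-(X[idx] @ theta)))
    return (X.shape[0] / n_s) * (X[idx].T @ (p - y[idx]))

# A sample of 1/100th of the data gives a roughly 100x cheaper step.
theta = theta - eta * sampled_gradient(theta, X, y, n_s=n // 100, rng=rng)
```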

Alternatively, instead of waiting until the full sum $\sum_{i=1}^n$ is calculated, you could split it into $M$ batches and take a step for each batch, $\sum_{s=1}^{M}\sum_{i_s=1}^{n_s}$. This way you will have taken $M$ steps by the time the sum over the entire data set is calculated. These steps are noisier, but the noise cancels out over time.
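In code, the same idea is one pass over the data split into $M$ batches, with an update after each batch (again a sketch reusing the objects defined above; `np.array_split` just handles uneven batch sizes):

```python
def minibatch_epoch(theta, X, y, M, eta):
    """One pass over the data in M batches: M noisy parameter updates
    in roughly the time one full-batch step would take."""
    for batch in np.array_split(np.arange(X.shape[0]), M):
        p = 1.0 / (1.0 + np.exp(-(X[batch] @ theta)))
        theta = theta - eta * (X[batch].T @ (p - y[batch]))  # one step per batch
    return theta

theta = minibatch_epoch(theta, X, y, M=100, eta=eta)
```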