Gradient Descent for Linear Regression: Why Use It When a Closed-Form Solution Exists?

Tags: gradient descent, machine learning, regression

I am taking the Machine Learning courses online and learnt about Gradient Descent for calculating the optimal values of the parameters in the hypothesis.

h(x) = B0 + B1X

Why do we need to use Gradient Descent if we can easily find the values with the formulas below? They look straightforward and easy too, whereas GD needs multiple iterations to get the values.

B1 = Correlation * (Std. Dev. of y/ Std. Dev. of x)

B0 = Mean(y) - B1 * Mean(x)

NOTE: Formulas taken from https://www.dezyre.com/data-science-in-r-programming-tutorial/linear-regression-tutorial
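For concreteness, here is a minimal Python sketch of the two formulas above. NumPy, the helper name `simple_linear_fit`, and the toy data are assumptions made for illustration and are not from the linked tutorial. The result is compared with `np.polyfit`, which fits the same straight line by least squares.

```python
import numpy as np

def simple_linear_fit(x, y):
    """Closed-form intercept and slope for a single predictor (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    r = np.corrcoef(x, y)[0, 1]                        # Pearson correlation of x and y
    b1 = r * (np.std(y, ddof=1) / np.std(x, ddof=1))   # B1 = Correlation * (sd(y) / sd(x))
    b0 = np.mean(y) - b1 * np.mean(x)                  # B0 = Mean(y) - B1 * Mean(x)
    return b0, b1

# Made-up data, roughly y = 2x
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([2.1, 3.9, 6.2, 7.8, 10.1])

b0, b1 = simple_linear_fit(x, y)

# np.polyfit returns coefficients from the highest degree down: [slope, intercept]
slope, intercept = np.polyfit(x, y, 1)
print(b0, b1, intercept, slope)
```

No iteration is involved here: a handful of means, standard deviations and one correlation give the answer directly, which is what makes the one-variable case look so cheap.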

I did check the questions below, but the answers were not clear to me.

Why is gradient descent required?

Why is optimisation solved with gradient descent rather than with an analytical solution?

The answers above compare GD with using derivatives.

Best Answer

The main reason why gradient descent is used for linear regression is computational complexity: in some cases it's computationally cheaper (faster) to find the solution using gradient descent.

The formula you wrote looks very simple, even computationally, because it only works for the univariate case, i.e. when you have only one variable. In the multivariate case, when you have many variables, the formula is slightly more complicated on paper and requires many more calculations when you implement it in software: $$\beta=(X'X)^{-1}X'Y$$ Here, you need to calculate the matrix $X'X$ and then invert it (see the note below). It's an expensive calculation. For your reference, the (design) matrix $X$ has N rows of observations and K+1 columns, where K is the number of predictors. In a machine learning application you can end up with K>1000 and N>1,000,000. The $X'X$ matrix itself takes a while to calculate (on the order of $NK^2$ operations), and then you have to invert a $(K+1)\times(K+1)$ matrix (on the order of $K^3$ operations), which is expensive.
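As a rough illustration of the formula above, here is a Python sketch that applies it literally. NumPy, the simulated data, and all variable names are assumptions for this example. It uses an explicit inverse only to mirror the formula, and checks the result against NumPy's own least-squares solver.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 500, 3                                    # observations, predictors (toy sizes)

predictors = rng.normal(size=(N, K))
X = np.column_stack([np.ones(N), predictors])    # design matrix: N rows, K+1 columns
true_beta = np.array([1.0, 2.0, -3.0, 0.5])      # made-up coefficients
y = X @ true_beta + 0.1 * rng.normal(size=N)

XtX = X.T @ X                                    # (K+1) x (K+1), about N*K^2 operations
Xty = X.T @ y
beta_inv = np.linalg.inv(XtX) @ Xty              # literal (X'X)^{-1} X'Y

# Reference solution from NumPy's least-squares routine
beta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(beta_inv)
print(beta_lstsq)
```

With K = 3 this is instantaneous; the cost only starts to matter when K and N reach the sizes mentioned above.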

So, gradient descent can save a lot of time on calculations. Moreover, the way it's done allows for trivial parallelization, i.e. distributing the calculations across multiple processors or machines. The linear algebra solution can also be parallelized, but it's more complicated and still expensive.
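To make the comparison concrete, here is a minimal sketch of plain (batch) gradient descent for least squares. The function name `gradient_descent`, the learning rate, the iteration count and the simulated data are assumptions chosen for this toy problem, not tuned recommendations. Each iteration needs only two matrix-vector products, and the gradient is a sum over rows, which is what makes it easy to split across workers.

```python
import numpy as np

def gradient_descent(X, y, lr=0.1, n_iters=2000):
    """Minimise (1/2n) * ||X @ beta - y||^2 by batch gradient descent."""
    n = X.shape[0]
    beta = np.zeros(X.shape[1])
    for _ in range(n_iters):
        residual = X @ beta - y          # n-vector of prediction errors
        grad = (X.T @ residual) / n      # gradient of the loss with respect to beta
        beta -= lr * grad                # step against the gradient
    return beta

rng = np.random.default_rng(0)
N, K = 500, 3
X = np.column_stack([np.ones(N), rng.normal(size=(N, K))])
true_beta = np.array([1.0, 2.0, -3.0, 0.5])      # made-up coefficients
y = X @ true_beta + 0.1 * rng.normal(size=N)

beta_gd = gradient_descent(X, y)
beta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(beta_gd)
print(beta_lstsq)
```

On this well-conditioned toy problem the iterates end up at essentially the same coefficients as the direct solver. The matrix $X'X$ is never formed or inverted.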

Additionally, there are versions of gradient descent in which you keep only a piece of your data in memory at a time, lowering the memory requirements. Overall, for extra-large problems it's more efficient than the linear algebra solution.
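One such version is mini-batch (stochastic) gradient descent, sketched below under stated assumptions: the names `batches` and `minibatch_sgd`, the batch size, the learning rate and the simulated data are all hypothetical. The generator slices an in-memory array purely for simplicity; it stands in for reading chunks from disk or a database.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 1000, 3
X = np.column_stack([np.ones(N), rng.normal(size=(N, K))])
true_beta = np.array([1.0, 2.0, -3.0, 0.5])      # made-up coefficients
y = X @ true_beta + 0.1 * rng.normal(size=N)

def batches(batch_size=50):
    """Yield (X_chunk, y_chunk) pairs; a stand-in for streaming chunks from storage."""
    for start in range(0, N, batch_size):
        stop = start + batch_size
        yield X[start:stop], y[start:stop]

def minibatch_sgd(batch_source, n_features, lr=0.05, n_epochs=50):
    """Update beta from one chunk at a time, so only a chunk is needed in memory."""
    beta = np.zeros(n_features)
    for _ in range(n_epochs):
        for X_b, y_b in batch_source():
            grad = X_b.T @ (X_b @ beta - y_b) / len(y_b)   # gradient on this chunk only
            beta -= lr * grad
    return beta

beta_sgd = minibatch_sgd(batches, n_features=K + 1)
print(beta_sgd)
```

With a constant learning rate the estimate hovers near the least-squares solution rather than hitting it exactly, which is usually an acceptable trade for the reduced memory footprint.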

This becomes even more important as the dimensionality increases, when you have thousands of variables like in machine learning.

Remark. I was surprised by how much attention is given to gradient descent in Ng's lectures. He spends a nontrivial amount of time talking about it, maybe 20% of the entire course. To me it's just an implementation detail: it's how exactly you find the optimum. The key is in formulating the optimization problem, and how exactly you find the optimum is nonessential. I wouldn't worry about it too much. Leave it to computer science people, and focus on what's important to you as a statistician.

Having said this, I must qualify it by saying that it is indeed important to understand the computational complexity and numerical stability of the solution algorithms. I still don't think you must know the details of the implementation and code of the algorithms. It's usually not the best use of your time as a statistician.

Note 1. I wrote that you have to invert the matrix for didactic purposes; it's not how you usually solve the equation. In practice, linear algebra problems are solved by using some kind of factorization such as QR, where you don't directly invert the matrix but do some other mathematically equivalent manipulations to get an answer. You do this because matrix inversion is an expensive and, in many cases, numerically unstable operation.
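A minimal sketch of the QR route, assuming NumPy and simulated data (all names here are illustrative): factor $X = QR$, then solve the small triangular system $R\beta = Q'Y$. It also prints the condition numbers of $X$ and $X'X$, since forming $X'X$ squares the condition number, which is one source of the instability mentioned above.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 500, 3
X = np.column_stack([np.ones(N), rng.normal(size=(N, K))])
true_beta = np.array([1.0, 2.0, -3.0, 0.5])      # made-up coefficients
y = X @ true_beta + 0.1 * rng.normal(size=N)

# Reduced QR: Q is N x (K+1) with orthonormal columns, R is (K+1) x (K+1) upper triangular
Q, R = np.linalg.qr(X)

# Solve R @ beta = Q' y. np.linalg.solve is a general solver; a dedicated
# triangular solver would exploit the structure of R, but the result is the same.
beta_qr = np.linalg.solve(R, Q.T @ y)

beta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

cond_X = np.linalg.cond(X)           # 2-norm condition number of X
cond_XtX = np.linalg.cond(X.T @ X)   # approximately cond_X squared
print(beta_qr, cond_X, cond_XtX)
```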

This brings up another little advantage of the gradient descent algorithm as a side effect: it works even when the design matrix has collinearity issues. The usual linear algebra path (inverting $X'X$) would blow up, whereas gradient descent keeps going even for collinear predictors.
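A small sketch of the collinear case, with an exactly duplicated predictor. NumPy, the data, and the `gradient_descent` helper are assumptions for illustration. Here $X'X$ is rank deficient, so its inverse does not exist, yet gradient descent still returns finite coefficients that fit the data. The coefficients are no longer unique in this situation; starting from zero, the sketch lands on the same minimum-norm solution that `np.linalg.lstsq` reports.

```python
import numpy as np

def gradient_descent(X, y, lr=0.1, n_iters=5000):
    """Batch gradient descent on (1/2n) * ||X @ beta - y||^2, started from zero."""
    n = X.shape[0]
    beta = np.zeros(X.shape[1])
    for _ in range(n_iters):
        beta -= lr * (X.T @ (X @ beta - y)) / n
    return beta

rng = np.random.default_rng(0)
N = 500
x1 = rng.normal(size=N)
x2 = rng.normal(size=N)
X = np.column_stack([np.ones(N), x1, x2, x2])    # last column duplicates x2 exactly
y = 1.0 + 2.0 * x1 - 3.0 * x2 + 0.1 * rng.normal(size=N)

rank = np.linalg.matrix_rank(X.T @ X)            # 3, although X'X is 4 x 4
beta_gd = gradient_descent(X, y)

# SVD-based least squares returns the minimum-norm solution for rank-deficient X
beta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(rank, beta_gd, beta_lstsq)
```

The weight on the duplicated predictor is simply split between the two identical columns, so the fitted values are unaffected.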