Machine Learning – Why Linear Regression Can Perform Better Than Linear Neural Networks

Tags: convergence, machine learning, neural networks, python, regression

I looked through other questions on stats.exchange but could not quite find a similar one; this post probably comes the closest to it.

Since a (fully-connected) neural network with only one layer and no activation function is a linear regression, even though it is trained with some form of gradient descent, I expect its performance to be relatively similar to that of a linear regression fitted by ordinary least squares. But in my case the two differ by a factor of up to 2 or 3! Is this to be expected?

A bit of background:
I'm using neural networks to predict ocean wave heights from some points along the coast to other points. The further you go offshore (and if you don't include certain physical interactions in your calculations), the more linear the problem gets. That is why I started to compare my results with an (OLS) linear regression, and I found that it outperforms my best neural networks (by around a factor of 2). That was shocking at first, but somewhat explicable given the strong linear relationship over most of the domain. However, when analyzing the errors I can see that linear regression performs worse close to the shore, where non-linearity is strongest. Yet if I train a (non-linear) neural network only for this region, it still performs worse than a simple linear regression!

I also want to note that the neural networks don't perform badly at all. On the contrary, the results are very promising; it's just surprising that they get outperformed so easily, even in places where the setting is more non-linear than linear.

Data set:

The input has shape (2918, 69), of which 80% is used for training and 20% for validation. Each row consists of the wave height at certain locations in the domain.

The output has shape (2918, 3773), so the neural network and the linear regression are performing super-resolution. Likewise, each row consists of the wave height at (different) locations in the domain.
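For reference, an OLS baseline of this kind can be sketched as below. This is only an illustration under stated assumptions: the sizes are shrunk, the data is synthetic and exactly linear plus noise, and names such as `W_hat` and `val_mse` are made up for the example rather than taken from the actual setup.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical small sizes; the data set in the question is (2918, 69) -> (2918, 3773)
n_samples, n_in, n_out = 300, 6, 40
X = rng.standard_normal((n_samples, n_in))
W_true = rng.standard_normal((n_in, n_out))
Y = X @ W_true + 0.01 * rng.standard_normal((n_samples, n_out))

# 80% training, 20% validation
split = int(0.8 * n_samples)
X_train, X_val = X[:split], X[split:]
Y_train, Y_val = Y[:split], Y[split:]

# Append a column of ones so the fit includes an intercept
A_train = np.hstack([X_train, np.ones((len(X_train), 1))])
A_val = np.hstack([X_val, np.ones((len(X_val), 1))])

# lstsq accepts a 2-D right-hand side and solves for all output columns at once
W_hat = np.linalg.lstsq(A_train, Y_train, rcond=None)[0]  # shape (n_in + 1, n_out)

val_mse = np.mean((A_val @ W_hat - Y_val) ** 2)
print(f"validation MSE: {val_mse:.2e}")
```

Evaluating the linear network and the OLS fit on the same validation rows with the same metric keeps the comparison like for like.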

Neural Network Architecture
As mentioned above, I tried ordinary linear regression with OLS, and a simple neural network with one layer and no activation function (so again, a linear regression) trained with either SGD or ADAM, for which I tried different learning rates and many, many epochs (going up to 5000). The batch sizes that I tried are 16, 32, and 512 (just to go wildly different). Weights are updated after each batch.

The two other networks are not too important for the question, since I am wondering more about the difference between the linear approaches, but I use either:

  1. A fully-connected neural network with multiple hidden layers (ranging from 2 to 5) with varying unit sizes (128 to 2056) and ReLU activation.

  2. A graph neural network that performs multiple Chebyshev convolutions on the input, then upsamples it to the output size and performs another 2-3 convolutions.

Both non-linear networks perform well and are more or less equal, but the linear regression is still best.

Questions

  1. Is it to be expected that a linear neural network performs that much worse than a linear regression, even though it's theoretically the same?

  2. Is it to be expected that the linear regression outperforms the neural network even in cases where the setting is mostly non-linear (even though some linear regions exist)?

What I tried
Since it looked like a convergence problem or a bug, I tried different learning rates, different optimizers (SGD with and without momentum, ADAM), different loss functions (l1 and l2 loss) and implemented everything in 2 different libraries (pytorch and tensorflow), with different neural network architectures (fully-connected and graph neural networks). I tried both normalizing and not normalizing my data; there's still a large discrepancy. I also let everything run for multiple epochs.

Is there something obvious maybe that I missed or that I could try? Are the results surprising? Especially the discrepancy between OLS linear regression and gradient descent linear regression?

I know there's no code or images associated with this question, but I hope the overall problem setting is still clear. Thank you in advance for any answers!

Best Answer

It's not too surprising.

Essentially, SGD is a very bad optimiser (at least for small-data problems), whereas OLS (using e.g. QR decomposition) finds the minimum in a single step.

If you set up a linear neural network, then you should get the same results as OLS; it is just that your learning rate might need to be much smaller (than what you have), which will then require a lot more iterations.
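As a quick check of that claim, here is a minimal sketch on synthetic, well-conditioned data. The sizes, the learning rate of 0.1 and the 500 iterations are assumptions chosen for this toy problem, not recommendations for the wave data.

```python
import numpy as np

rng = np.random.default_rng(0)
N, p = 200, 3
X = rng.standard_normal((N, p))          # roughly uncorrelated, unit-variance inputs
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.standard_normal(N)

# Closed-form least-squares solution
w_ols = np.linalg.lstsq(X, y, rcond=None)[0]

# "Linear network" without bias: one weight vector, full-batch gradient descent on MSE
w_gd = np.zeros(p)
learning_rate = 0.1                      # assumed value, small enough for this X
for _ in range(500):
    grad = 2.0 / N * X.T @ (X @ w_gd - y)   # gradient of mean((X w - y)**2)
    w_gd -= learning_rate * grad

print("max |w_gd - w_ols| =", np.abs(w_gd - w_ols).max())
```

On data like this the two solutions agree to numerical precision; the gap only opens up once the inputs become correlated or unequally scaled.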

The difference between batch gradient descent and OLS is due to the curvature of your error surface (which in OLS corresponds to the covariance matrix). If your covariance matrix is a constant times the identity (i.e. each dimension has variance $k$ (e.g. 1) and there is no correlation), then batch gradient descent will perform as well as OLS with a suitably chosen learning rate. On the other hand, unequal variances or correlations will mean that your single learning rate will cause 'overshooting' in some directions and slow progress in other directions (the error surface will be elliptical rather than spherical).
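A toy quadratic makes this concrete. The sketch below assumes a loss of the form $f(w) = \frac{1}{2}\sum_i h_i w_i^2$ with made-up curvatures $h_i$; the function name `gradient_descent` and all numbers are illustrative only.

```python
import numpy as np

def gradient_descent(hessian_diag, lr, steps):
    """Plain gradient descent on f(w) = 0.5 * sum(h_i * w_i**2); returns the final loss."""
    h = np.asarray(hessian_diag, dtype=float)
    w = np.ones_like(h)                # start at (1, 1)
    for _ in range(steps):
        w = w - lr * h * w             # the gradient of f is h * w
    return 0.5 * np.sum(h * w ** 2)    # the minimum of f is 0

# Spherical surface: equal curvature in both directions
spherical = gradient_descent([1.0, 1.0], lr=1.0, steps=100)
# Elliptical surface: same learning rate, one direction is 100 times flatter
elliptical = gradient_descent([1.0, 0.01], lr=1.0, steps=100)
# Raising the learning rate to speed up the flat direction overshoots in the steep one
overshoot = gradient_descent([1.0, 0.01], lr=2.5, steps=100)

print(f"spherical:  {spherical:.3e}")
print(f"elliptical: {elliptical:.3e}")
print(f"overshoot:  {overshoot:.3e}")
```

With equal curvatures one learning rate suits every direction. With unequal curvatures the rate that is safe for the steep direction leaves the flat one barely moving, and a larger rate diverges.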

Another thing is that you should be using batch or minibatch (to remove noise in the optimisation procedure and match up to OLS).

EDIT: I recommend first normalising the data. If you calculate the covariance of the inputs and then the maximum eigenvalue of that covariance matrix, the learning rate should be smaller than 1/max_eigenvalue. Can you report what this value is in your case, together with your current learning rate? (Obviously, the smaller the learning rate, the more epochs you need to iterate.)
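A minimal sketch of that check, assuming plain full-batch gradient descent on a mean-squared-error loss (adaptive optimisers such as Adam rescale their steps, so the bound does not transfer directly). The array `X` here is a synthetic stand-in for the real inputs.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical inputs: 200 samples of 5 strongly correlated features
base = rng.standard_normal((200, 1))
X = base + 0.1 * rng.standard_normal((200, 5))

# Normalise each column to zero mean and unit variance
X_std = (X - X.mean(axis=0)) / X.std(axis=0)

cov = np.cov(X_std, rowvar=False)
eigvals = np.linalg.eigvalsh(cov)   # cov is symmetric: real eigenvalues, ascending order
max_eig = eigvals[-1]
lr_bound = 1.0 / max_eig

print(f"max eigenvalue {max_eig:.3f}, keep the learning rate below {lr_bound:.3f}")
print(f"condition number {eigvals[-1] / eigvals[0]:.1f}")
```

Note that normalising fixes the variances but not the correlations, so the largest eigenvalue can still be well above 1 and the condition number can stay large.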

Wikipedia on Newton's method

Standard OLS can be viewed as one step of Newton's method. Newton's method is essentially doing gradient descent with a matrix of learning rates (based on the second derivative of the error, which is the covariance matrix for OLS).
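This can be verified numerically. The sketch below uses synthetic data and the loss $\frac{1}{2}\lVert Xw - y\rVert^2$; the starting point `w0` is arbitrary and the variable names are only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 4))
y = X @ np.array([1.0, -2.0, 0.5, 3.0]) + 0.1 * rng.standard_normal(100)

w0 = rng.standard_normal(4)              # arbitrary starting point

grad = X.T @ (X @ w0 - y)                # gradient of 0.5 * ||X w - y||^2 at w0
hessian = X.T @ X                        # second derivative, constant in w

# One Newton step: the inverse Hessian acts as a "matrix of learning rates"
w_newton = w0 - np.linalg.solve(hessian, grad)

w_ols = np.linalg.lstsq(X, y, rcond=None)[0]
print("Newton step equals OLS:", np.allclose(w_newton, w_ols))
```

Because the loss is exactly quadratic, a single Newton step lands on the minimum from any starting point, which is why no learning rate or iteration count is needed.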

The fact that you have a 'wave' input makes your data highly correlated (neighbouring points are similar). This is precisely the type of situation that gradient descent is very bad at, because it creates a narrow valley in the error surface.
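To see how strongly smooth inputs affect conditioning, the sketch below compares independent inputs with wave-like ones. The sinusoid model is an illustrative assumption, not the actual wave data.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 500, 20

# Independent inputs: every column is its own noise source
independent = rng.standard_normal((n, p))

# Wave-like inputs (assumed model): each row is one sinusoid with random
# amplitude and phase, sampled at p neighbouring positions
positions = np.linspace(0.0, 1.0, p)
amplitude = rng.standard_normal((n, 1))
phase = rng.uniform(0.0, 2.0 * np.pi, (n, 1))
wave = amplitude * np.sin(2.0 * np.pi * positions[np.newaxis, :] + phase)

neighbour_corr = np.corrcoef(wave[:, 0], wave[:, 1])[0, 1]
cond_independent = np.linalg.cond(np.cov(independent, rowvar=False))
cond_wave = np.linalg.cond(np.cov(wave, rowvar=False))

print(f"correlation between neighbouring points: {neighbour_corr:.3f}")
print(f"condition number, independent inputs: {cond_independent:.1f}")
print(f"condition number, wave-like inputs:   {cond_wave:.3e}")
```

The wave-like columns are nearly copies of their neighbours, so the covariance matrix is close to singular and its condition number is many orders of magnitude larger.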

Here is code I have created to illustrate the problem. The design matrix (1000 x 70) is ill-conditioned (matrix rank 8 in my simulation). In my run, gradient descent reaches a loss of about 1.7, whereas OLS reaches 5e-14. Using the OLS coefficients in the linear net achieves the same zero loss (sanity check). [wandb is a logging tool and can be omitted.]

#%%
import numpy as np
import matplotlib.pyplot as plt  # needed for the plt.plot call further down
import torch
import torch.nn as nn
import wandb

# %%

N=1000
p = 70
rng = np.random.default_rng()
X=np.cos(rng.random((N,1))*np.linspace(0,3,p)[np.newaxis,:])
# The first column of X is constant (cos(0) = 1), so it is left out of the covariance
# %%
cov = np.cov(X[:,1:],rowvar=False)
var = np.diag(cov)
cond = np.linalg.cond(cov)
# cov is symmetric, so eigvalsh returns real eigenvalues
max_eig = np.linalg.eigvalsh(cov).max()

print(f"condition number is {cond}, max eigenvalue {max_eig:.2f}")
# %%
beta = rng.standard_normal((p,))
# %%
y = X @ beta
# %%
lsq = np.linalg.lstsq(X,y, rcond=None)
# %%
beta_hat = lsq[0]
# %%
plt.plot(X@beta,X@beta_hat,'.')
# %%
# https://www.deeplearningwizard.com/deep_learning/practical_pytorch/pytorch_linear_regression/
# Create class
class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim, bias=False)  

    def forward(self, x):
        out = self.linear(x)
        return out
#%%
model = LinearRegressionModel(p, 1).float()
criterion = nn.MSELoss()
learning_rate = 0.01
#momentum=0.9
#optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

inputs = torch.from_numpy(X.astype(np.float32)).requires_grad_()
labels = torch.from_numpy(y.astype(np.float32))

wandb.init(
    # set the wandb project where this run will be logged
    project="OLS",
    config = {
        "optimisation" : "batch_adam",
        "learning_rate": learning_rate,
        #"momentum": momentum,
    }
)
# %%
epochs = 100000
losses = []
#%%
for epoch in range(1, epochs + 1):
    # inputs and labels are already torch tensors (converted above)
    
    # Clear gradients w.r.t. parameters
    optimizer.zero_grad() 

    # Forward to get output
    outputs = model(inputs)

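    # Calculate loss; labels are reshaped to (N, 1) to match outputs, because
    # MSELoss would otherwise broadcast (N, 1) against (N,) to an (N, N) difference
    loss = criterion(outputs, labels.view(-1, 1))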

    # Getting gradients w.r.t. parameters
    loss.backward()

    # Updating parameters
    optimizer.step()
    loss_value = loss.item()
    print(f"epoch {epoch}, loss {loss_value}")
    losses.append(loss_value)
    wandb.log({"loss": loss_value})
# %%
# check ols weights give zero loss in linear net
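model1 = LinearRegressionModel(p, 1).float()
targets = labels.view(-1, 1)  # match the (N, 1) shape of the model output
outputs1 = model1(inputs)
loss = criterion(outputs1, targets)
print(f"loss at initialisation {loss.item()}")
# nn.Linear stores its weight with shape (output_dim, input_dim) = (1, p)
with torch.no_grad():
    model1.linear.weight.copy_(torch.from_numpy(beta_hat.astype(np.float32)).view(1, -1))
outputs1 = model1(inputs)
loss = criterion(outputs1, targets)
print(f"loss using OLS solution {loss.item()}")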
# %%
