ReLU in Neural Networks – Why a Single ReLU Can’t Learn Another ReLU

Tags: keras, machine-learning, neural-networks, optimization

As a follow-up to My neural network can't even learn Euclidean distance, I simplified even further and tried to train a single ReLU (with a random initial weight) to fit a single ReLU. This is the simplest network there is, and yet half the time it fails to converge.

If the initial guess is in the same orientation as the target, it learns quickly and converges to the correct weight of 1:

animation of ReLU learning ReLU

loss curve showing convergence points

If the initial guess is "backwards", it gets stuck at a weight of zero and never goes through it to the region of lower loss:

animation of ReLU failing to learn ReLU

loss curve of ReLU failing to learn ReLU

closeup of loss curve at 0

I don't understand why. Shouldn't gradient descent easily follow the loss curve to the global minimum?

Example code:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, ReLU
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

batch = 1000


def tests():
    while True:
        test = np.random.randn(batch)

        # Generate ReLU test case
        X = test
        Y = test.copy()
        Y[Y < 0] = 0

        yield X, Y


# A single weight (no bias) followed by a ReLU: the model computes max(0, w*x)
model = Sequential([Dense(1, input_dim=1, activation=None, use_bias=False)])
model.add(ReLU())
model.set_weights([np.array([[-10.0]])])  # start "backwards" at w = -10

model.compile(loss='mean_squared_error', optimizer='sgd')


class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []
        self.weights = []
        self.n = 0
        self.n += 1

    def on_epoch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
        w = model.get_weights()
        self.weights.append([x.flatten()[0] for x in w])
        self.n += 1


history = LossHistory()

model.fit_generator(tests(), steps_per_epoch=100, epochs=20,
                    callbacks=[history])

fig, (ax1, ax2) = plt.subplots(2, 1, True, num='Learning')

ax1.set_title('ReLU learning ReLU')
ax1.semilogy(history.losses)
ax1.set_ylabel('Loss')
ax1.grid(True, which="both")
ax1.margins(0, 0.05)

ax2.plot(history.weights)
ax2.set_ylabel('Weight')
ax2.set_xlabel('Epoch')
ax2.grid(True, which="both")
ax2.margins(0, 0.05)

plt.tight_layout()
plt.show()

plot of loss and weight versus epoch produced by the code above

Similar things happen if I add a bias: the 2D loss function is smooth and simple, but if the ReLU starts upside down, it circles around and gets stuck (red starting points) instead of following the gradient down to the minimum (as it does for the blue starting points):

2D loss surface with optimizer trajectories from red and blue starting points

Similar things happen if I add an output weight and bias, too. (It will flip left-to-right or up-to-down, but not both.)
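For reference, a sketch of how the bias variant could be set up (the exact initial values below are just an illustration, not the ones used for the plots):

# Bias variant (illustrative): a single Dense unit with a bias, followed by a ReLU,
# started "upside down". The specific starting weight and bias are arbitrary examples.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, ReLU
import numpy as np

model_bias = Sequential([Dense(1, input_dim=1, activation=None, use_bias=True)])
model_bias.add(ReLU())
model_bias.set_weights([np.array([[-1.0]]), np.array([0.5])])  # [kernel, bias]
model_bias.compile(loss='mean_squared_error', optimizer='sgd')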

Best Answer

There's a hint in your plots of the loss as a function of $w$. These plots have a "kink" near $w=0$: that's because to the left of 0 the gradient of the loss vanishes (and yet $w=0$ is a suboptimal solution, since the loss there is higher than at $w=1$). Moreover, the plot shows that the loss function is non-convex (you can draw a line that crosses the loss curve in 3 or more places), which signals that we should be cautious with local optimizers such as SGD. Indeed, the analysis below shows that when $w$ is initialized to a negative value, it is possible to converge to a suboptimal solution.
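To see the kink directly, here is a small numpy sketch (an addition, not part of the original code) that plots the empirical loss of the bias-free model as a function of $w$, using the same standard-normal inputs as the question's generator:

# Empirical loss L(w) = E[(max(0, w*x) - max(0, x))^2] for x ~ N(0, 1).
# Analytically: L(w) = 0.5*(w**2 + 1) for w < 0 and 0.5*(w - 1)**2 for w >= 0,
# so there is a kink at w = 0 and the minimum is at w = 1.
import numpy as np
import matplotlib.pyplot as plt

x = np.random.randn(100000)   # same input distribution as the question's generator
y = np.maximum(0, x)          # target: ReLU(x)

ws = np.linspace(-2, 2, 401)
losses = [np.mean((np.maximum(0, w * x) - y) ** 2) for w in ws]

plt.plot(ws, losses)
plt.xlabel('w')
plt.ylabel('MSE loss')
plt.title('Loss of a bias-free ReLU unit vs. its weight')
plt.show()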

The optimization problem is $$ \begin{align} \min_{w,b} &\|f(x)-y\|_2^2 \\ f(x) &= \max(0, wx+b) \end{align} $$

and you're using first-order optimization to do so. A problem with this approach is that $f$ has gradient

$$ f^\prime(x)= \begin{cases} w, & \text{if $wx+b>0$} \\ 0, & \text{if $wx+b<0$} \end{cases} $$

When you start with $w<0$, you'll have to move to the other side of $0$ to come closer to the correct answer, which is $w=1$. This is hard to do, because when $|w|$ is very small, the gradient of the loss with respect to $w$ is likewise vanishingly small. Moreover, the closer you get to 0 from the left, the slower your progress will be!

This is why, in your plots, the trajectories for negative initializations $w^{(0)} < 0$ all stall out near $w^{(i)}=0$. This is also what your second animation shows.
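To make this concrete, here is a small numerical sketch (an addition, not from the original answer) that estimates $\partial L/\partial w$ for the bias-free model at several negative values of $w$; each estimate comes out roughly equal to $w$ itself, so the updates shrink along with $|w|$:

# Estimate dL/dw for f(x) = max(0, w*x) trained against y = max(0, x), x ~ N(0, 1).
# For w < 0, only samples with w*x > 0 (i.e. x < 0) pass a gradient through the
# ReLU, and their target is 0, so dL/dw = mean(2*w*x**2 * [x < 0]) ≈ w.
import numpy as np

x = np.random.randn(100000)
y = np.maximum(0, x)

for w in [-10.0, -1.0, -0.1, -0.01, -0.001]:
    pred = np.maximum(0.0, w * x)
    active = (w * x) > 0                       # where the ReLU lets gradient through
    grad = np.mean(2.0 * (pred - y) * x * active)
    print(f"w = {w:8.3f}    dL/dw ~ {grad:+.5f}")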

This is related to the dying ReLU phenomenon; for some discussion, see My ReLU network fails to launch.

An approach that might be more successful is to use a different nonlinearity, such as the leaky ReLU, which does not have this "vanishing gradient" issue. The leaky ReLU function is

$$ g(x)= \begin{cases} x, & \text{if $x>0$} \\ cx, & \text{otherwise} \end{cases} $$ where $c$ is a constant chosen so that $|c|$ is small and positive. The reason this works is that the derivative isn't 0 "on the left":

$$ g^\prime(x)= \begin{cases} 1, & \text{if $x>0$} \\ c, & \text{if $x < 0$} \end{cases} $$

Setting $c=0$ recovers the ordinary ReLU. Most people choose $c$ to be something like $0.1$ or $0.3$. I haven't seen $c<0$ used, though I'd be interested to see a study of what effect, if any, it has on such networks. (Note that for $c=1$ this reduces to the identity function; for $|c|>1$, compositions of many such layers may cause exploding gradients because the gradients grow in successive layers.)
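For concreteness, here is a tiny numpy sketch (an addition) of $g$ and $g^\prime$ as written above; $c$ corresponds to the alpha argument of Keras's LeakyReLU layer:

# Leaky ReLU g(x) and its derivative g'(x), with c the negative-side slope
# (the `alpha` argument of keras.layers.LeakyReLU).
import numpy as np

def leaky_relu(x, c=0.3):
    return np.where(x > 0, x, c * x)

def leaky_relu_grad(x, c=0.3):
    return np.where(x > 0, 1.0, c)

x = np.array([-2.0, -0.5, 0.5, 2.0])
print(leaky_relu(x))       # [-0.6  -0.15  0.5   2.  ]
print(leaky_relu_grad(x))  # [0.3  0.3  1.   1. ]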

Slightly modifying the OP's code demonstrates that the issue lies with the choice of activation function. This version initializes $w$ to be negative and uses a LeakyReLU in place of the ordinary ReLU. The loss quickly decreases to a small value, and the weight correctly moves to $w=1$, which is optimal.

LeakyReLU fixes the problem

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

batch = 1000


def tests():
    while True:
        test = np.random.randn(batch)

        # Generate ReLU test case
        X = test
        Y = test.copy()
        Y[Y < 0] = 0

        yield X, Y


model = Sequential(
    [Dense(1, 
           input_dim=1, 
           activation=None, 
           use_bias=False)
    ])
model.add(LeakyReLU(alpha=0.3))  # leaky slope c = 0.3
model.set_weights([np.array([[-10.0]])])  # same "backwards" initialization, w = -10

model.compile(loss='mean_squared_error', optimizer='sgd')


class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []
        self.weights = []
        self.n = 0
        self.n += 1

    def on_epoch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
        w = model.get_weights()
        self.weights.append([x.flatten()[0] for x in w])
        self.n += 1


history = LossHistory()

model.fit_generator(tests(), steps_per_epoch=100, epochs=20,
                    callbacks=[history])

fig, (ax1, ax2) = plt.subplots(2, 1, True, num='Learning')

ax1.set_title('LeakyReLU learning ReLU')
ax1.semilogy(history.losses)
ax1.set_ylabel('Loss')
ax1.grid(True, which="both")
ax1.margins(0, 0.05)

ax2.plot(history.weights)
ax2.set_ylabel('Weight')
ax2.set_xlabel('Epoch')
ax2.grid(True, which="both")
ax2.margins(0, 0.05)

plt.tight_layout()
plt.show()

Another layer of complexity arises from the fact that we're not moving infinitesimally but in finite "jumps" from one iteration to the next. This means there are some circumstances where negative initial values of $w$ won't get stuck; these cases arise for particular combinations of $w^{(0)}$ and step sizes large enough to "jump" over the region of vanishing gradient.

I've played around with this code a bit and found that leaving the initialization at $w^{(0)}=-10$ and changing the optimizer from SGD to Adam, Adam + AMSGrad, or SGD with momentum does nothing to help. Moreover, changing from SGD to Adam actually slows progress in addition to not helping to overcome the vanishing gradient on this problem.

On the other hand, if you change the initialization to $w^{(0)}=-1$ and change the optimizer to Adam (step size 0.01), then you can actually overcome the vanishing gradient. It also works if you use $w^{(0)}=-1$ and SGD with momentum (step size 0.01). It even works if you use vanilla SGD (step size 0.01) and $w^{(0)}=-1$.

The relevant code is below; use opt_sgd or opt_adam.

opt_sgd = keras.optimizers.SGD(lr=1e-2, momentum=0.9)    # SGD with momentum
opt_adam = keras.optimizers.Adam(lr=1e-2, amsgrad=True)  # Adam with AMSGrad
model.compile(loss='mean_squared_error', optimizer=opt_sgd)
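The initialization change isn't shown above; as a sketch, it would look something like this (the weight array must match the shape of the single Dense kernel):

model.set_weights([np.array([[-1.0]])])  # start at w = -1 instead of w = -10
model.compile(loss='mean_squared_error', optimizer=opt_adam)  # or opt_sgd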