How Does a Simple Logistic Regression Model Achieve 92% Classification Accuracy on MNIST?

Tags: image-processing, logistic

Even though all the images in the MNIST dataset are centered, have a similar scale, and are face up with no rotations, they still show significant handwriting variation, and it puzzles me how a linear model achieves such high classification accuracy.

As far as I am able to visualize, given the significant handwriting variation, the digits should be linearly inseparable in a $784$-dimensional space, i.e., there should be a somewhat complex (though not very complex) non-linear boundary that separates the different digits, similar to the well-cited $XOR$ example where the positive and negative classes cannot be separated by any linear classifier. It seems baffling to me how multi-class logistic regression produces such high accuracy with entirely linear features (no polynomial features).
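For concreteness, here is a minimal sketch of the $XOR$ case (my example; the use of scikit-learn's LogisticRegression is an assumption, not part of the original question):

import numpy as np
from sklearn.linear_model import LogisticRegression

# XOR: no straight line separates the two classes, so at most 3 of the 4 points
# can ever be classified correctly by a linear model.
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 1, 1, 0])

clf = LogisticRegression().fit(X, y)
print(clf.score(X, y))  # never reaches 1.0: no linear boundary gets all four points right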

As an example, given any pixel in the image, different handwritten variations of the digits $2$ and $3$ can make that pixel illuminated or not. Therefore, with a single set of learned weights, each individual pixel can make a digit look like a $2$ as well as a $3$. Only a combination of pixel values should make it possible to say whether a digit is a $2$ or a $3$, and this is true for most digit pairs. So how is logistic regression, which bases its decision on each pixel value independently (without considering any inter-pixel dependencies at all), able to achieve such high accuracy?

I know that I am wrong somewhere, or am just over-estimating the variation in the images. However, it would be great if someone could help me with an intuition for how the digits are 'almost' linearly separable.

Best Answer

tl;dr Even though this is an image classification dataset, it remains a very easy task, for which one can easily find a direct mapping from inputs to predictions.


Answer:

This is a very interesting question, and thanks to the simplicity of logistic regression you can actually find out the answer.

What logistic regression does is, for each image, accept $784$ inputs and multiply them with weights to generate its prediction. The interesting thing is that, due to the direct mapping between input and output (i.e. no hidden layer), the value of each weight corresponds to how much each of the $784$ inputs is taken into account when computing the probability of each class. Now, by taking the weights for each class and reshaping them into $28 \times 28$ (i.e. the image resolution), we can tell which pixels are most important for the computation of each class.
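In symbols (my notation, not the original answer's), the model is

$$
P(y = k \mid x) = \operatorname{softmax}(W^\top x + b)_k = \frac{\exp(w_k^\top x + b_k)}{\sum_{j=0}^{9} \exp(w_j^\top x + b_j)},
$$

where $x \in \mathbb{R}^{784}$ is the flattened image, $W \in \mathbb{R}^{784 \times 10}$ and $b \in \mathbb{R}^{10}$ are learned, and $w_k$, the $k$-th column of $W$, holds the $784$ per-pixel weights for class $k$. That column is exactly what gets reshaped to $28 \times 28$ below.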

[Figure: the learned weights for each of the ten digit classes, reshaped to $28 \times 28$ and plotted as heatmaps; the code at the end of this answer reproduces it.]

Note, again, that these are the weights, not the pixel intensities of any particular input image.

Now take a look at the figure above and focus on the first two digits (i.e. the zero and the one). Blue weights mean that a high intensity at that pixel contributes positively to that class, and red weights mean that it contributes negatively.

Now imagine: how does a person draw a $0$? They draw a circular shape that is empty in the middle. That is exactly what the weights picked up on. In fact, if someone draws over the middle of the image, it counts against the digit being a zero. So to recognize zeros you don't need sophisticated filters or high-level features; you can just look at the drawn pixel locations and judge accordingly.

Same thing for the $1$. It always has a straight vertical line in the middle of the image. All else counts negatively.

The rest of the digits are a bit more complicated, but with a little imagination you can make out the $2$, the $3$, the $7$ and the $8$. The remaining digits are harder to tell apart, which is what keeps logistic regression from reaching accuracy in the high 90s.

Through this you can see that logistic regression has a very good chance of getting most images right, which is why it scores so high.


The code to reproduce the above figure is a bit dated (it uses the TensorFlow 1.x API), but here you go:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

plt.show()
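Since the TensorFlow 1.x code above will not run as-is on a current TensorFlow 2.x install, here is a minimal sketch of an equivalent model in TensorFlow 2.x / Keras (my adaptation, not the original answer's code); a single softmax Dense layer trained with cross-entropy is exactly multinomial logistic regression:

import tensorflow as tf
import matplotlib.pyplot as plt

# Load MNIST via Keras and flatten to 784-dimensional vectors in [0, 1].
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

# One Dense layer with softmax = multinomial logistic regression.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='sgd',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=64, epochs=5, verbose=0)
print('Test Accuracy =', model.evaluate(x_test, y_test, verbose=0)[1])

# Plot the learned weights per class, as in the figure above.
weights = model.layers[0].get_weights()[0]  # shape (784, 10)
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.imshow(weights[:, i].reshape(28, 28), cmap='RdBu')
    plt.title(i)
    plt.axis('off')
plt.show()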