Solved – Why is logistic regression particularly prone to overfitting in high dimensions

Tags: logistic, overfitting, regularization

Why does "the asymptotic nature of logistic regression" make it particularly prone to overfitting in high dimensions? (source):

[Screenshot of the quoted passage from the source]

I understand that log loss (cross entropy) grows quickly as the predicted probability $y'$ approaches $1-y$, i.e. the opposite of the true label $y$. But why does that imply that "the asymptotic nature of logistic regression would keep driving the loss towards 0 in high dimensions without regularization"?

In my mind, just because the loss can grow quickly when we predict something very close to the complete opposite of the true answer, it doesn't follow that the model would therefore try to fully interpolate the data. If anything, the optimizer would avoid entering the asymptotic (fast-growing) part of the loss as aggressively as it can.
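
For concreteness, here is a small numeric sketch (added for illustration, not part of the original question) of how quickly the per-observation log loss $-[y \log y' + (1-y)\log(1-y')]$ grows as the prediction moves towards the wrong label:

# Log loss for a single observation with true label y = 1:
# it grows without bound as the predicted probability y' approaches 0
y_hat = c(0.9, 0.5, 0.1, 0.01, 0.001, 1e-6)
data.frame(y_hat, loss = -log(y_hat))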

Best Answer

The existing answers aren't wrong, but I think the explanation could be a little more intuitive. There are three key ideas here.

1. Asymptotic Predictions

In logistic regression we use a linear model to predict $\mu$, the log-odds that $y=1$

$$ \mu = \beta X $$

We then use the logistic/inverse logit function to convert this into a probability

$$ P(y=1) = \frac{1}{1 + e^{-\mu}} $$

[Plot of the logistic function: $P(y=1)$ as a function of $\mu$]

Importantly, this function never actually reaches values of $0$ or $1$. Instead, $P(y=1)$ gets closer and closer to $0$ as $\mu$ becomes more negative, and closer to $1$ as $\mu$ becomes more positive.

[Zoomed-in view of the lower tail of the curve, approaching but never reaching 0]
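
A quick numeric check (an added illustration; it uses base R's plogis, which computes exactly this inverse-logit function) makes the same point: the probabilities get arbitrarily close to $0$ and $1$ but never reach them.

# The inverse logit flattens out but never reaches its asymptotes
mu = c(-20, -10, 0, 10, 20)
plogis(mu)        # the smallest value is about 2e-09, not 0
1 - plogis(20)    # about 2e-09: still short of 1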

2. Perfect Separation

Sometimes, you end up with situations where the model wants to predict exactly $y=1$ or $y=0$. This happens when it's possible to draw a straight line through your data so that every point with $y=1$ falls on one side of the line, and every point with $y=0$ on the other. This is called perfect separation.

[Plot: perfect separation in 1D, where a single cutoff on $x$ splits $y=0$ from $y=1$]

[Plot: perfect separation in 2D, where a straight line splits the two classes]

When this happens, the model tries to predict as close to $0$ and $1$ as possible, by making $\mu$ as extreme, negative or positive, as it can. To do this, it must make the regression weights $\beta$ as large as possible. Crucially, with perfectly separated data the loss keeps falling towards $0$ as the weights grow, so nothing stops the optimizer from pushing them towards infinity.
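
To see this numerically, here is a small illustrative addition (not part of the original answer) using the same 1D separable data as in the code below: the more iterations the optimiser is allowed, the larger the fitted slope becomes, and glm eventually warns that fitted probabilities numerically 0 or 1 occurred.

# With perfectly separated data the fitted slope keeps growing as the
# optimiser is allowed more iterations (convergence/separation warnings
# are expected here)
x = c(1, 2, 3, 4, 5, 6)
y = c(0, 0, 0, 1, 1, 1)
for (iters in c(5, 10, 25)) {
  fit = glm(y ~ x, family = binomial, control = glm.control(maxit = iters))
  print(coef(fit)['x'])
}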

Regularisation is a way of counteracting this: the model isn't allowed to make $\beta$ arbitrarily large, so $\mu$ can't be arbitrarily high or low, and the predicted probability can't get so close to $0$ or $1$.
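
As a sketch of that idea (an addition, not from the original answer: it hand-rolls an L2 penalty with optim rather than using a dedicated package), a penalised fit settles on a finite slope even though the data are perfectly separated:

# Ridge-penalised logistic regression on the same separable data:
# minimise the negative log-likelihood plus lambda * slope^2
x = c(1, 2, 3, 4, 5, 6)
y = c(0, 0, 0, 1, 1, 1)
lambda = 1  # penalty strength, chosen arbitrarily for illustration

penalised_nll = function(beta) {
  p = plogis(beta[1] + beta[2] * x)
  p = pmin(pmax(p, 1e-12), 1 - 1e-12)  # guard against log(0)
  -sum(y * log(p) + (1 - y) * log(1 - p)) + lambda * beta[2]^2
}

optim(c(0, 0), penalised_nll)$par  # both coefficients stay finite and modest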

3. Perfect Separation is more likely with more dimensions

With more predictors there are more dimensions in which to find a separating hyperplane, so perfect separation becomes more likely; once there are more predictors than observations, it is all but guaranteed. As a result, regularisation becomes more important when you have many predictors.

To illustrate, here's the previously plotted data again, but with the second predictor, $x_2$, dropped. Using $x_1$ alone, it's no longer possible to draw a boundary that perfectly separates $y=0$ from $y=1$.

[Plot: the same data using only $x_1$, with the fitted logistic curve; the two classes now overlap]
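
To put a rough number on "more likely", here is an added simulation (not from the original answer; it uses glm's separation warning as a crude proxy for detecting perfect separation): with random labels, the chance of perfect separation climbs towards 1 as the number of predictors approaches the number of observations.

# Rough simulation: proportion of datasets with n = 20 observations and
# random 50/50 labels that glm flags as (quasi-)separated, for an
# increasing number of pure-noise predictors. Warnings are the signal here.
set.seed(1)
n = 20
separation_rate = sapply(c(1, 5, 10, 19), function(n_predictors) {
  mean(replicate(200, {
    X = matrix(rnorm(n * n_predictors), n, n_predictors)
    y = rbinom(n, 1, 0.5)
    separated = FALSE
    withCallingHandlers(
      glm(y ~ X, family = binomial),
      warning = function(w) {
        separated <<- TRUE
        invokeRestart("muffleWarning")
      }
    )
    separated
  }))
})
separation_rate  # climbs towards 1 as the predictor count approaches n - 1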


Code

# https://stats.stackexchange.com/questions/469799/why-is-logistic-regression-particularly-prone-to-overfitting

library(tidyverse)
theme_set(theme_classic(base_size = 20))

# Asymptotes
mu = seq(-10, 10, .1)
p = 1 / (1 + exp(-mu))
g = ggplot(data.frame(mu, p), aes(mu, p)) + 
  geom_path() +
  geom_hline(yintercept=c(0, 1), linetype='dotted') +
  labs(x=expression(mu), y='P(y=1)')
g

g + coord_cartesian(xlim=c(-10, -9), ylim=c(0, .001))  # zoom in near the lower asymptote

# Perfect separation
x = c(1, 2, 3, 4, 5, 6)
y = c(0, 0, 0, 1, 1, 1)
df = data.frame(x, y)
ggplot(df, aes(x, y)) +
  geom_hline(yintercept=c(0, 1), linetype='dotted') +
  geom_smooth(method='glm', 
              method.args=list(family=binomial), se=F) +
  geom_point(size=5) +
  geom_vline(xintercept=3.5, color='red', size=2, linetype='dashed')

## In 2D
set.seed(1)  # for reproducibility
x1 = c(rnorm(100, -2, 1), rnorm(100, 2, 1))
x2 = c(rnorm(100, -2, 1), rnorm(100, 2, 1))
y = ifelse(x1 + x2 > 0, 1, 0)  # labels determined by the sign of x1 + x2: perfectly separable
df = data.frame(x1, x2, y)
ggplot(df, aes(x1, x2, color=factor(y))) +
  geom_point() +
  geom_abline(intercept=1, slope=-1,
              color='red', linetype='dashed') +
  scale_color_manual(values=c('blue', 'black')) +
  coord_equal(xlim=c(-5, 5), ylim=c(-5, 5)) +
  labs(color='y')

## Same data, but ignoring x2
ggplot(df, aes(x1, y)) +
  geom_hline(yintercept=c(0, 1), linetype='dotted') +
  geom_smooth(method='glm', 
              method.args=list(family=binomial), se=T) +
  geom_point()