Solved – Loss function for Logistic Regression

Tags: logistic, loss-functions, machine-learning

If we are doing binary classification using logistic regression, we often use the cross entropy function as our loss function. More specifically, suppose we have $T$ training examples of the form $(x^{(t)},y^{(t)})$, where $x^{(t)}\in\mathbb{R}^{n+1}$ and $y^{(t)}\in\{0,1\}$; then we use the following loss function
$$\mathcal{LF}(\theta)=-\dfrac{1}{T}\sum_{t}\left[y^{(t)}\log\bigl(\text{sigm}(\theta^T x^{(t)})\bigr)+(1-y^{(t)})\log\bigl(1-\text{sigm}(\theta^T x^{(t)})\bigr)\right],$$
where $\text{sigm}$ denotes the sigmoid function.
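
For concreteness, here is a minimal numerical sketch of that loss (assuming NumPy; the helper names `sigm` and `cross_entropy_loss` and the toy numbers are made up for illustration):

```python
import numpy as np

def sigm(z):
    # Sigmoid function.
    return 1.0 / (1.0 + np.exp(-z))

def cross_entropy_loss(theta, X, y, eps=1e-12):
    # X: (T, n+1) design matrix, y: (T,) labels in {0, 1}.
    p = sigm(X @ theta)
    p = np.clip(p, eps, 1 - eps)  # avoid log(0)
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

# Toy usage:
X = np.array([[1.0, 0.5], [1.0, -1.2], [1.0, 2.3]])
y = np.array([1, 0, 1])
theta = np.zeros(2)
print(cross_entropy_loss(theta, X, y))  # log(2) ≈ 0.6931 at theta = 0
```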

Question: However, if we are doing linear regression, we often use squared-error as our loss function. Are there any specific reasons for using the cross entropy function rather than squared-error or the classification error in logistic regression?

I read somewhere that, if we use squared-error for binary classification, the resulting loss function would be non-convex. Is this the only reason, or is there some deeper reason that I am missing?

Attempt: To get a sense of what the different loss functions look like, I have generated $50$ random datapoints on both sides of the line $y=x$. I have assigned the class $c=1$ to the datapoints on one side of the line $y=x$, and $c=0$ to those on the other side. After generating this data, I have computed the costs for different lines $\theta_1 x-\theta_2 y=0$ which pass through the origin using the following loss functions:

  1. squared-error function using the predicted labels and the actual labels.
  2. squared-error function using the continuous scores $\theta^Tx$ instead of thresholding by $0$.
  3. squared-error function using the continuous scores $\text{sigm}(\theta^T x)$.
  4. classification error, i.e., number of misclassified points.
  5. cross entropy loss function.

I have considered only lines which pass through the origin, rather than general lines such as $\theta_1 x-\theta_2 y+\theta_0=0$, so that I can plot the loss as a function of $(\theta_1,\theta_2)$. I have obtained the following plots.
[Plots of the loss surfaces for loss functions $1$–$5$ as functions of $(\theta_1,\theta_2)$]
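
For reference, a minimal sketch of how such an experiment could be set up (assuming NumPy; the random data, seed, and the particular $\theta$ values below are my own choices for illustration, not the exact ones behind the plots):

```python
import numpy as np

rng = np.random.default_rng(0)

# 50 random points in the plane, labelled by which side of y = x they fall on.
X = rng.uniform(-1, 1, size=(50, 2))
y = (X[:, 1] > X[:, 0]).astype(float)

def sigm(z):
    return 1.0 / (1.0 + np.exp(-z))

def losses(theta):
    # theta = (theta_1, theta_2) defines the line theta_1*x - theta_2*y = 0;
    # the continuous score for each point is s = theta_1*x - theta_2*y.
    s = X[:, 0] * theta[0] - X[:, 1] * theta[1]
    y_hat = (s > 0).astype(float)                 # thresholded labels
    p = np.clip(sigm(s), 1e-12, 1 - 1e-12)
    return {
        "1. squared error on predicted labels": np.mean((y_hat - y) ** 2),
        "2. squared error on raw scores":       np.mean((s - y) ** 2),
        "3. squared error on sigmoid scores":   np.mean((p - y) ** 2),
        "4. classification error":              np.mean(y_hat != y),
        "5. cross entropy": -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)),
    }

# Evaluate the five losses at a few directions; sweeping a grid of
# (theta_1, theta_2) values gives surfaces like the plots above.
for theta in [(1.0, 1.0), (-1.0, 1.0), (0.5, -0.5)]:
    print(theta, losses(np.array(theta)))
```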
From the above plots, we can infer the following:

  1. The plot corresponding to $1$ is neither smooth (in fact it is not even continuous) nor convex. This makes sense, since the cost can take only a finite number of values for any $\theta_1,\theta_2$.
  2. The plot corresponding to $2$ is smooth as well as convex.
  3. The plot corresponding to $3$ is smooth but is not convex.
  4. The plot corresponding to $4$ is neither smooth nor convex, similar to $1$.
  5. The plot corresponding to $5$ is smooth as well as convex, similar to $2$.
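
For what it's worth, the convexity of $(5)$ and the non-convexity of $(3)$ can also be checked analytically rather than just visually; here is a sketch of the standard argument, writing $z=\theta^T x$ and $p=\text{sigm}(z)$. For a single example, the cross entropy loss reduces to
$$\ell_{CE}(z)=-y\log p-(1-y)\log(1-p)=\log(1+e^{z})-yz,\qquad \frac{d^2\ell_{CE}}{dz^2}=p(1-p)\ge 0,$$
so it is convex in $z$, and hence convex in $\theta$, because $z$ is affine in $\theta$. In contrast, the squared error on the sigmoid score,
$$\ell_{SQ}(z)=(p-y)^2,\qquad \frac{d^2\ell_{SQ}}{dz^2}=2\,p(1-p)\bigl[p(1-p)+(p-y)(1-2p)\bigr],$$
has a second derivative that changes sign as $z$ varies, so it is not convex in general.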

If I am not mistaken, for the purpose of minimizing the loss function, the loss functions corresponding to $(2)$ and $(5)$ are equally good, since both are smooth and convex.

Is there any reason to use $(5)$ rather than $(2)$? Also, apart from smoothness or convexity, are there any reasons for preferring the cross entropy loss function over squared-error?

Best Answer

You got off on the wrong track, as detailed here. Just because you have a binary $Y$ does not mean that you should be interested in classification. You are really interested in a probability model, so logistic regression is a good choice. Get the nomenclature right or you will confuse everyone.

To the main point, the theory of statistical estimation shows that in the absence of outside information (which would make you use Bayesian logistic regression), maximum likelihood estimation is the gold standard for efficiency and bias. The log likelihood function provides the objective function.
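In fact, the cross entropy loss in the question is exactly the negative, averaged Bernoulli log likelihood; sketching the connection in the question's notation, with $p_t=\text{sigm}(\theta^T x^{(t)})$,
$$\log L(\theta)=\sum_t \log\Bigl(p_t^{\,y^{(t)}}(1-p_t)^{1-y^{(t)}}\Bigr)=\sum_t\Bigl[y^{(t)}\log p_t+(1-y^{(t)})\log(1-p_t)\Bigr],$$
so minimizing $\mathcal{LF}(\theta)=-\frac{1}{T}\log L(\theta)$ is the same as maximizing the likelihood.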

You may have confused a loss/cost/utility function with estimation optimization. Get the optimum estimates using maximum likelihood estimation or penalized maximum likelihood (or, better, Bayesian modeling if you have constraints or other information). A utility function comes in when you need to make an optimum decision that minimizes expected loss (maximizes expected utility). But I don't think you are asking about decision analysis. So stick with the gold standard objective function: the log likelihood.