Machine Learning Classification – Why Classifiers Need Same Prevalence in Training and Testing Sets

classification, machine learning, prevalence

Here are some commonly seen statements about the importance of prevalence in the train and test sets when developing a classifier:

"Another reason not to rebalance datasets is that models should be
trained on datasets whose distributions will reflect the future,
real-world test cases for which they will ultimately be applied."
https://towardsdatascience.com/why-balancing-classes-is-over-hyped-e382a8a410f7

"It is simply the case that a classifier trained to a 1/2 prevalence
situation will not be applicable to a population with a 1/1000
prevalence." https://www.fharrell.com/post/classification/

"fans of “classifiers” sometimes subsample from observations in the
most frequent outcome category (here Y=1) to get an artificial 50/50
balance of Y=0 and Y=1 when developing their classifier. Fans of such
deficient notions of accuracy fail to realize that their classifier
will not apply to a population with a much different prevalence of Y=1
than 0.5." https://www.fharrell.com/post/classification/

These statements are presented as if they were obvious, so no explanation is given for them. But I struggle to understand: why is that the case?

In the typical situation where we use a deterministic model plus a decision boundary to perform classification (e.g. logistic regression, or a neural network), how does class prevalence in the test set really affect the model results?

Prevalence might affect the training of a model if there are too few samples from which to learn the relevant features of a given class. A common example is a naive classifier trained on a heavily imbalanced training set and evaluated on accuracy, which will learn to predict only the majority class.
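As a concrete illustration of that failure mode, here is a minimal sketch using scikit-learn's DummyClassifier; the 99:1 split and the random features are purely illustrative assumptions:

```python
# Minimal sketch: on a 99:1 dataset, always predicting the majority class
# already achieves 99% accuracy, so accuracy alone rewards a useless model.
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score

rng = np.random.default_rng(0)
y = np.array([0] * 990 + [1] * 10)      # 99:1 class prevalence
X = rng.normal(size=(1000, 5))          # features are irrelevant here

clf = DummyClassifier(strategy="most_frequent").fit(X, y)
pred = clf.predict(X)

print(accuracy_score(y, pred))   # 0.99 -- looks great
print(recall_score(y, pred))     # 0.0  -- never finds the minority class
```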

But once the classifier has learnt the relevant features, its predictions are deterministic. It does not matter whether we feed it a single example at test time (prevalence = 100%), two examples of each class (prevalence = 50%), or examples in a 99:1 ratio (prevalence = 1%); the prediction for any given example will always be the same.
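A minimal sketch of this point, assuming a fitted scikit-learn logistic regression on synthetic data (all names here are illustrative):

```python
# Minimal sketch: a fitted model's prediction for a given example does not
# depend on which other examples happen to be in the test batch.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1000, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, y)

x = X[:1]                               # one fixed example
alone = model.predict_proba(x)
in_batch = model.predict_proba(np.vstack([x, X[1:100]]))[:1]

print(np.allclose(alone, in_batch))     # True -- test-set composition is irrelevant
```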

If we have a very imbalanced data set (say 99:1 for two classes), I don't see why balancing the training set would introduce any problems. The training set would have an artificial prevalence of 50/50 to make sure the model learns relevant features for both classes (given there are still enough training examples for each one), and this model could then be deployed on a natural test set. If the test set has a prevalence of 99:1, the model has no awareness of it. Is this thinking wrong?

As pointed out here, Different number of samples (observations) per class (one vs. all classification), a classifier is the likelihood function in a Bayesian model, and it is independent of the prior probabilities (prevalence). Specifically, in machine learning we don't use probabilistic models but deterministic ones, where the likelihood just gives a point estimate for the model parameters. If we are interested in modeling a posterior probability, then we can combine the likelihood with priors, but that is a different statement and a different approach from the one usually taken in classification algorithms.

Are the initial statements about prevalence really applicable to deterministic models, or is there something else I am missing? Does a deterministic model really learn to mimic the prevalence of a training population?

Best Answer

"Why exactly does a classifier need the same prevalence in the train and test sets?"

Perhaps my answer to a related question on the DS SE might help

Doesn't over(/under)sampling an imbalanced dataset cause issues?

Yes, the classifier will expect the relative class frequencies in operation to be the same as those in the training set. This means that if you over-sample the minority class in the training set, the classifier is likely to over-predict that class in operational use.
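A quick way to see this empirically is a small simulation; a minimal sketch, assuming a logistic regression on synthetic 99:1 data (the figures and helper variables are illustrative only):

```python
# Minimal sketch: a model trained on a rebalanced (50/50) sample over-predicts
# the minority class when evaluated on data with the original ~99:1 prevalence.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=20000, weights=[0.99, 0.01], random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=0)

# Rebalance the training set by oversampling the minority class (class 1).
idx_min = np.where(y_tr == 1)[0]
idx_maj = np.where(y_tr == 0)[0]
idx_min_up = np.random.default_rng(0).choice(idx_min, size=len(idx_maj), replace=True)
idx_bal = np.concatenate([idx_maj, idx_min_up])

balanced = LogisticRegression(max_iter=1000).fit(X_tr[idx_bal], y_tr[idx_bal])
natural = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)

print("true prevalence:", y_te.mean())
print("predicted positive rate (balanced training):", balanced.predict(X_te).mean())
print("predicted positive rate (natural training):", natural.predict(X_te).mean())
```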

To see why, it is best to consider probabilistic classifiers, where the decision is based on the posterior probability of class membership, $p(C_i|x)$, which can be written using Bayes' rule as

$p(C_i|x) = \frac{p(x|C_i)p(C_i)}{p(x)}$ where $p(x) = \sum_j p(x|C_j)p(C_j)$,

so we can see that the decision depends on the prior probabilities of the classes, $p(C_i)$. If the prior probabilities in the training set differ from those in operation, the operational performance of our classifier will be suboptimal, even if it is optimal for the training-set conditions.
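For a concrete illustration with made-up numbers: suppose the likelihood ratio for a particular $x$ is $p(x|C_1)/p(x|C_0) = 10$. With training priors of 0.5/0.5 the posterior is $p(C_1|x) = 10/(10+1) \approx 0.91$, so we predict $C_1$; with operational priors of 0.01/0.99 the same likelihoods give $p(C_1|x) = (10 \times 0.01)/(10 \times 0.01 + 1 \times 0.99) \approx 0.09$, so the optimal decision flips to $C_0$.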

Some classifiers have a problem learning from imbalanced datasets, so one solution is to oversample the minority class to ameliorate this bias in the classifier. There are two approaches. The first is to oversample by just the right amount to overcome this (usually unknown) bias and no more, but that is really difficult. The other approach is to balance the training set and then post-process the output to compensate for the difference between training-set and operational priors. We take the output of the classifier trained on the oversampled dataset and multiply by the ratio of operational and training-set prior probabilities,

$q_o(C_i|x) \propto p_t(x|C_i)p_t(C_i) \times \frac{p_o(C_i)}{p_t(C_i)} = p_t(x|C_i)p_o(C_i).$

Quantities with the $o$ subscript relate to operational conditions and those with the $t$ subscript relate to training-set conditions. I have written this as $q_o(C_i|x)$ because it is an un-normalised probability, but it is straightforward to renormalise by dividing by the sum of $q_o(C_i|x)$ over all classes. For some problems it may be better to use cross-validation to choose the correction factor, rather than the theoretical value used here, as it depends on the bias in the classifier due to the imbalance.
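A minimal sketch of this re-weighting, assuming a probabilistic classifier trained on a balanced (50/50) set; the prior values, the example probabilities, and the helper function `adjust_priors` are all illustrative assumptions:

```python
# Minimal sketch: correct class probabilities from a classifier trained at
# 50/50 priors so they reflect the operational priors (assumed 99%/1% here).
import numpy as np

def adjust_priors(p_train, train_priors, op_priors):
    """Multiply each class probability by the ratio of operational to
    training-set priors, then renormalise over classes."""
    q = p_train * (np.asarray(op_priors) / np.asarray(train_priors))
    return q / q.sum(axis=1, keepdims=True)

# Probabilities produced by a model trained on a balanced (50/50) dataset.
p_train = np.array([[0.30, 0.70],    # looks like a confident positive...
                    [0.80, 0.20]])

p_op = adjust_priors(p_train, train_priors=[0.5, 0.5], op_priors=[0.99, 0.01])
print(p_op)   # the "positive" case is far less convincing at 1% prevalence
```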

So in short, for imbalanced datasets, use a probabilistic classifier and oversample (or reweight) to get a balanced dataset, in order to overcome the bias a classifier may have for imbalanced datasets. Then post-process the output of the classifier so that it doesn't over-predict the minority class in operation.

Specific issues:

If we have a very imbalanced data set (say 99:1 for two classes), I don't see why balancing the training set would introduce any problems.

It doesn't present a problem, provided you post-process the output of the model to compensate for the difference in training set and operational class frequencies. If you don't perform that adjustment (or you use a discrete yes-no classifier) you will over-predict the minority class for the reason given above.

"fans of “classifiers” sometimes subsample from observations in the most frequent outcome category (here Y=1) to get an artificial 50/50 balance of Y=0 and Y=1 when developing their classifier. Fans of such deficient notions of accuracy fail to realize that their classifier will not apply to a population when a much different prevalence of Y=1 than 0.5."

I don't think this accurately represents the situation. The reason for balancing is usually that the minority class is "more important" in some sense than the majority class, and the rebalancing is an attempt to incorporate those misclassification costs so that the classifier works better in operational conditions. However, a lot of blogs don't explain that properly, so a lot of practitioners are rather misinformed about it.
