Solved – How to threshold multiclass probability prediction to get confusion matrix

accuracy, confusion matrix, logistic, multi-class, multinomial-distribution

Let's say my multinomial logistic regression predicts that the probabilities of a sample belonging to each class are A=0.6, B=0.3, C=0.1. How do I threshold these values to get a single class prediction for the sample, taking class imbalance into account? I know what I would do for a binary decision (threshold based on class prevalence) or for balanced classes (assign the class with the highest probability). My end goal is to get a 3×3 confusion matrix.

Best Answer

Following @cangrejo's answer (https://stats.stackexchange.com/a/310956/194535), suppose the original output probability of your model is the vector $v$. You can then define the prior distribution:

$\pi=\left(\frac{1}{\theta_1}, \frac{1}{\theta_2}, \ldots, \frac{1}{\theta_N}\right)$, for $\theta_i \in (0,1)$ and $\sum_i\theta_i = 1$, where $N$ is the total number of labeled classes and $i$ is the class index.

Take $v' = v \odot \pi$ as the new output probability of your model, where $\odot$ denotes an element-wise product.
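
As a small worked example of this reweighting, here is a sketch using the probabilities from the question. The values in `theta` are made up for illustration (they are not fitted to anything), and the renormalization step is the same one the code further down applies.

```python
import numpy as np

# Model output for one sample, classes A, B, C (values from the question).
v = np.array([0.6, 0.3, 0.1])

# Hypothetical thresholds theta (illustrative values only, they sum to 1).
theta = np.array([0.6, 0.2, 0.2])

# v' = v ⊙ pi with pi_i = 1 / theta_i, then renormalize so v' sums to 1.
v_new = v / theta
v_new = v_new / v_new.sum()

plain_choice = int(np.argmax(v))         # class picked from the raw output
adjusted_choice = int(np.argmax(v_new))  # class picked after reweighting
print(plain_choice, adjusted_choice, v_new)
```

With these particular numbers the raw output favors A, while the reweighted output favors B, because B's small threshold boosts its share.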

Your question can now be reformulated as follows: find the $\pi$ that optimizes the metric you have specified (e.g. roc_auc_score), computed from the new output probabilities $v'$. Once you find it, the $\theta$s $(\theta_1, \theta_2, \ldots, \theta_N)$ are your optimal thresholds, one per class.
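
To make the search idea concrete, here is a minimal sketch that uses a coarse grid search over the simplex with accuracy as the metric. This is a deliberately simpler stand-in, not the differential evolution used in the steps below. The function names, the toy probabilities, and the labels are all hypothetical.

```python
import itertools
import numpy as np

def predict_with_thresholds(proba, theta):
    """Pick argmax of proba / theta (renormalizing would not change the argmax)."""
    return np.argmax(proba / theta, axis=1)

def grid_search_thresholds(proba, y_true, n_steps=10):
    """Try every theta on a grid of the simplex interior; keep the most accurate."""
    n_classes = proba.shape[1]
    best_theta, best_acc = None, -1.0
    for counts in itertools.product(range(1, n_steps), repeat=n_classes - 1):
        last = n_steps - sum(counts)
        if last < 1:  # the remaining threshold must stay positive
            continue
        theta = np.array(counts + (last,)) / n_steps
        acc = np.mean(predict_with_thresholds(proba, theta) == y_true)
        if acc > best_acc:
            best_theta, best_acc = theta, acc
    return best_theta, best_acc

# Toy data (made up): the model tends to under-predict class 2.
proba = np.array([[0.60, 0.30, 0.10],
                  [0.50, 0.20, 0.30],
                  [0.20, 0.60, 0.20],
                  [0.45, 0.15, 0.40]])
y_true = np.array([0, 2, 1, 2])

baseline_acc = np.mean(np.argmax(proba, axis=1) == y_true)
theta, acc = grid_search_thresholds(proba, y_true)
print(baseline_acc, theta, acc)
```

A grid like this grows quickly with the number of classes, which is one reason to hand the search to a general-purpose optimizer instead.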

The Code part:


  1. Create a proxyModel class that takes your original model object as an argument and returns a proxyModel object. When you call predict_proba() through the proxyModel object, it automatically calculates the new probabilities based on the thresholds you specify:

    import numpy as np

    class proxyModel():
        def __init__(self, origin_model):
            self.origin_model = origin_model

        def predict_proba(self, x, threshold_list=None):
            # get the original probabilities, shape (n_samples, n_classes)
            ori_proba = self.origin_model.predict_proba(x)

            # default threshold of 1 for every class leaves the probabilities unchanged
            if threshold_list is None:
                threshold_list = np.ones(ori_proba.shape[1])

            # element-wise divide by the threshold of each class
            new_proba = np.divide(ori_proba, threshold_list)

            # row sums (the L1 norm of non-negative values), kept as a column
            # so that they broadcast over the classes
            norm = np.sum(new_proba, axis=1, keepdims=True)

            # renormalize so that each row sums to 1 again
            return np.divide(new_proba, norm)

        def predict(self, x, threshold_list=None):
            return np.argmax(self.predict_proba(x, threshold_list), axis=1)
    
  2. Implement a score function:

    import numpy as np
    from sklearn.metrics import (accuracy_score, average_precision_score,
                                 f1_score, roc_auc_score)

    def scoreFunc(model, X, y_true, threshold_list):
        y_pred = model.predict(X, threshold_list=threshold_list)
        y_pred_proba = model.predict_proba(X, threshold_list=threshold_list)

        # y_true is assumed to be one-hot encoded (step 4 reads the number of
        # classes from Y_test.shape[1]); accuracy and F1 need class indices
        # so that they match the format of y_pred
        y_true_label = np.argmax(y_true, axis=1)

        accuracy = accuracy_score(y_true_label, y_pred)
        roc_auc = roc_auc_score(y_true, y_pred_proba, average='macro')
        pr_auc = average_precision_score(y_true, y_pred_proba, average='macro')
        f1_value = f1_score(y_true_label, y_pred, average='macro')

        return accuracy, roc_auc, pr_auc, f1_value
    
    
  3. Define a weighted_score_with_threshold() function, which takes the thresholds as input and returns a weighted score:

    def weighted_score_with_threshold(threshold, model, X_test, Y_test, metrics='accuracy', delta=5e-5):
        threshold_sum = np.sum(threshold)

        # guard against the optimizer proposing a nan candidate
        if np.isnan(threshold_sum):
            print("threshold_sum is nan")
            return np.inf

        # if the sum of thresholds is not within [1-delta, 1+delta], return infinity
        # (this only reduces the search space of the minimization algorithm,
        # because the sum of thresholds should be as close to 1 as possible)
        if threshold_sum > 1+delta or threshold_sum < 1-delta:
            return np.inf

        # renormalize: the sum of thresholds should be exactly 1
        normalized_threshold = threshold/threshold_sum

        # calculate scores based on the thresholds;
        # scoreFunc returns 4 scores in a tuple: (accuracy, roc_auc, pr_auc, f1)
        scores = np.array(scoreFunc(model, X_test, Y_test, threshold_list=normalized_threshold))

        # give the metric you want to maximize a bigger weight
        if metrics == 'accuracy':
            weight = np.array([10,1,1,1])
        elif metrics == 'roc_auc':
            weight = np.array([1,10,1,1])
        elif metrics == 'pr_auc':
            weight = np.array([1,1,10,1])
        elif metrics == 'f1':
            weight = np.array([1,1,1,10])
        else:
            # 'all' (or any other value): weight every metric equally
            weight = np.array([1,1,1,1])

        # return the negative weighted sum (maximizing the sum
        # is equivalent to minimizing the negative sum)
        return -np.dot(weight, scores)
    
  4. Use the optimization algorithm differential_evolution() (better than fmin) to find the optimal thresholds:

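    import numpy as np
    from scipy import optimize

    # Y_test is one-hot encoded, so its second dimension is the number of classes
    output_class_num = Y_test.shape[1]
    # each threshold is searched within [1e-5, 1]
    bounds = optimize.Bounds([1e-5]*output_class_num,[1]*output_class_num)

    # wrap the fitted model so that its predictions accept a threshold_list
    pmodel = proxyModel(model)

    result = optimize.differential_evolution(weighted_score_with_threshold, bounds, args=(pmodel, X_test, Y_test, 'accuracy'))

    # calculate threshold
    threshold = result.x/np.sum(result.x)

    # print the optimized scores; pass pmodel, because the original model's
    # predict() does not accept threshold_list
    print(scoreFunc(pmodel, X_test, Y_test, threshold_list=threshold))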
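
Since the end goal in the question is a 3×3 confusion matrix, here is a minimal sketch of that last step. It uses plain NumPy on made-up probabilities and labels. The helper name `confusion_matrix_from_proba` and the `theta` values are hypothetical, and in practice you would plug in the optimized `threshold` and integer class labels instead.

```python
import numpy as np

def confusion_matrix_from_proba(proba, y_true, theta):
    """Rows are true classes, columns are predicted classes."""
    n_classes = proba.shape[1]
    # dividing by the thresholds and taking the argmax gives the adjusted prediction
    y_pred = np.argmax(proba / theta, axis=1)
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # count each (true, predicted) pair
    return cm

# Toy data (made up): integer labels 0, 1, 2 stand for classes A, B, C.
proba = np.array([[0.60, 0.30, 0.10],
                  [0.50, 0.20, 0.30],
                  [0.20, 0.60, 0.20],
                  [0.45, 0.15, 0.40]])
y_true = np.array([0, 2, 1, 2])

theta = np.array([0.5, 0.3, 0.2])  # hypothetical thresholds, not optimized

cm_plain = confusion_matrix_from_proba(proba, y_true, np.ones(3))
cm_adjusted = confusion_matrix_from_proba(proba, y_true, theta)
print(cm_plain)
print(cm_adjusted)
```

If scikit-learn is available, `sklearn.metrics.confusion_matrix(y_true, y_pred)` uses the same layout (true classes in rows, predicted classes in columns) and can replace the manual counting.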