Solved – Constructing a model with SMOTE and sklearn pipeline

classification, python, scikit-learn, unbalanced-classes

I have a very imbalanced dataset on which I'm trying to build a LinearSVC model with SMOTE and standardization, using a Pipeline. I first applied SMOTE and sklearn's StandardScaler with LinearSVC step by step, and then built the same model with imblearn's make_pipeline. After training both, I expected the same accuracy scores on the test set, but that didn't happen.

SMOTE + StandardScaler + LinearSVC : 0.7647058823529411
SMOTE + StandardScaler + LinearSVC + make_pipeline : 0.7058823529411765

This is my code (I'll leave the imports and the values for X and y at the end of the question):

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=27)

pipe = make_pipeline(SMOTE(random_state=42), StandardScaler(), LinearSVC(dual=False, random_state=13))
pipe = pipe.fit(X_train, np.array(y_train))
y_pred = pipe.predict(X_test)
accuracy_1 = accuracy_score(y_test, y_pred)

# Apply SMOTE to training data and keep original test data
sm = SMOTE(random_state=27)
X_train, y_train = sm.fit_resample(X_train, np.array(y_train))

# Apply standardization after SMOTE
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

model_fitted = LinearSVC(dual=False, random_state=13).fit(X_train, np.array(y_train))
svc_pred = model_fitted.predict(X_test)
accuracy_2 = accuracy_score(y_test, svc_pred)

print(f'SMOTE + StandardScaler + LinearSVC : {accuracy_1}')
print(f'SMOTE + StandardScaler + LinearSVC + make_pipeline : {accuracy_2}')

I've read here that applying a transformer and an estimator separately can bias model validation, because the test fold then already contains information (i.e. the mean and std) about the training set, since X_train was used to fit the standardization. Is that the reason why the test metrics differ with and without make_pipeline? Am I doing something wrong in the data preprocessing? Furthermore, which would be the "correct" implementation?
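For reference, my understanding is that the leak-free way to validate would be to cross-validate the whole pipeline, so that SMOTE and the scaler are refit on each training fold and never see the corresponding validation fold. A minimal sketch, assuming imblearn's pipeline is compatible with sklearn's cross_val_score:

from sklearn.model_selection import cross_val_score

# Cross-validate the full pipeline: SMOTE and StandardScaler are refit
# on every training fold, so no fold statistics leak into validation.
cv_scores = cross_val_score(pipe, X, y, cv=5, scoring='accuracy')
print(cv_scores.mean())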
Note: this also happens with other classification models I have tested, such as sklearn's RandomForestClassifier, sklearn's LogisticRegression, xgboost's XGBClassifier, etc. Below is a sample of my dataset. It is worth noting that the class distribution in y is roughly the same as in the original data, which is approximately normally distributed.

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score

X = pd.DataFrame([[282.54489323989515, 4, 68.16, 282.1480611960005, -6.8, 7.0403401020090985], [276.54339631417474, 4, 45.98, 342.12807470000047, -8.0, 7.152493798701296], [270.24417189647465, 4, 58.87000000000001, 343.12332366800047, -8.6, 7.137774871767122], [243.40109776254047, 4, 58.43, 309.0735966040004, -7.0, 7.064051762848258], [257.17042467369515, 4, 41.99, 276.1262631320004, -6.9, 6.311938468753469], [275.3731012246145, 7, 41.99, 330.0979976960004, -7.7, 6.7877044796869805], [265.9606513742565, 5, 51.22, 292.12117775200045, -6.9, 6.259273323565326], [326.58334774194606, 6, 103.96, 373.1096270880005, -7.5, 6.5930025859140855], [324.3768142343405, 7, 101.16, 374.0936426760005, -7.3, 6.719960972749474], [383.1075778787854, 8, 61.88000000000001, 427.14294708400075, -7.5, 7.269062418081917], [209.64125624379497, 2, 90.7, 239.080709908, -5.7, 5.592893958208465], [226.93724086966162, 3, 90.7, 253.09635997200002, -5.8, 5.802737996059506], [324.27768288769073, 6, 58.44000000000001, 372.13530359200064, -7.8, 7.227844120934621], [287.6044968433291, 5, 64.15, 374.07422332800047, -7.9, 7.139390063492067], [301.5965126238026, 5, 90.38, 370.12636044800047, -8.5, 7.351183041236545], [253.37623458851155, 7, 92.47, 325.98611209600017, -7.2, 6.186930161560662], [244.5860078879504, 6, 72.24, 309.99119747600025, -7.2, 6.725594845987346], [238.16516681807818, 6, 92.47, 292.0250844480002, -6.9, 6.042819706071707], [227.4354796800226, 5, 49.33, 281.00103388800017, -6.6, 5.568803884060376], [212.22441190958924, 4, 49.33, 247.04000624, -6.7, 5.853461533522028], [221.13625526232255, 5, 62.22, 281.99628285600016, -6.0, 5.552044152236652], [205.92518749188923, 4, 62.22, 248.035255208, -5.8, 5.520270148240644], [197.13496079132813, 3, 41.99, 232.040340588, -6.3, 5.712922288211787], [269.7450337433286, 4, 41.99, 308.0716407160004, -6.8, 5.673227140581641], [212.34602856176144, 4, 41.99, 266.0013682360002, -6.4, 5.793003591963598], [242.72369025122825, 6, 68.00999999999999, 275.0825397480004, -6.2, 5.860529551226553], [216.5158622726282, 4, 41.99, 226.110613068, -6.1, 6.271971700077697], [231.72693004306151, 5, 41.99, 260.0716407160001, -6.4, 6.3772530531690546], [225.4277056253615, 4, 54.02, 261.06688968400016, -5.3, 5.689922825396825], [252.62122300089482, 5, 41.99, 246.17321332400002, -6.1, 6.360644730824734], [249.02291466892822, 6, 41.99, 274.08729078000044, -5.8, 6.064690346597848], [250.53244991322836, 5, 91.79999999999998, 266.1167610680003, -7.1, 7.095551672050173], [239.53568970506169, 4, 65.78, 251.105862036, -6.8, 6.781502936951936], [256.83167433092837, 5, 65.78, 265.12151210000025, -6.4, 6.822023072039075], [233.23646528736165, 3, 77.81, 252.10111100400002, -5.6, 5.84349677811078], [379.1868982309576, 15, 113.58, 420.15326572800063, -5.8, 5.914399715839718], [251.78262455188363, 3, 58.64, 254.16304256400002, -6.2, 6.047892115162612], [232.97324844948372, 3, 58.64, 248.116092372, -6.6, 6.3525069249639285], [250.26923307535037, 4, 58.64, 262.1317424360002, -6.6, 5.438963752636251], [233.66639048055603, 4, 71.09, 255.100776656, -6.7, 6.572156237651242], [322.5948372520963, 6, 23.550000000000004, 320.18886338800064, -6.4, 6.457320252747259], [342.38182416082435, 7, 52.65000000000001, 351.19467704000067, -6.4, 6.904352297091798], [348.53559216058534, 6, 83.71, 365.17394159600076, -7.3, 6.91475121103896], [263.9706724272623, 5, 74.71000000000002, 316.04485458800036, -7.1, 6.2596209758574775], [273.34334207529537, 3, 23.550000000000004, 280.15756326000053, -6.7, 6.275560233877237], [279.9624427709684, 5, 
23.550000000000004, 320.1091690960005, -6.9, 6.543782477633478], [242.60087095609484, 2, 61.44, 272.09833413200045, -6.8, 6.604806006160513], [249.21997165176785, 4, 61.44, 312.0499399680004, -6.9, 5.996064206404703], [375.6840430670745, 8, 61.83000000000001, 376.22497412400094, -6.6, 6.136790912698414], [356.48434271240757, 8, 57.90000000000001, 388.13107374000055, -7.8, 6.79856678421578], [296.7516255586903, 6, 68.13, 322.1555518800005, -7.4, 6.5142416255966245], [294.11516685789024, 5, 68.13, 320.13990181600053, -7.9, 7.204179028027528], [330.3981313859795, 6, 48.67, 358.1205090560005, -7.9, 6.80603791341991], [266.1537188075787, 2, 74.60000000000002, 314.0579088000003, -8.0, 6.292021111388609], [330.3981313859795, 6, 48.67, 358.1205090560005, -8.1, 6.791316292929299], [373.9773576657134, 8, 63.6, 358.21440944000085, -6.2, 6.684677804417808], [304.3119200595515, 4, 39.44, 328.10994437200037, -8.0, 6.449420691419692], [402.7000276929414, 11, 61.83000000000001, 390.2406241880009, -6.0, 6.694095905150409], [333.23162041421836, 7, 54.37, 316.2038447560008, -6.1, 6.405040869685871], [263.85065158331184, 3, 74.6, 304.07355886400035, -8.4, 6.290557231601733], [424.84863706529114, 6, 111.52, 510.13146766400047, -9.0, 6.586497894494393], [436.6704002902747, 7, 61.83000000000001, 478.1780239320007, -9.1, 7.549792955322462], [432.910874365208, 7, 72.83, 466.1780239320007, -8.1, 7.263302052170057], [314.5666155353682, 4, 111.9, 362.0790381680004, -8.2, 5.895165988566992], [338.9461414604349, 5, 108.73999999999998, 374.07903816800035, -8.2, 6.136152998556996], [411.7984310385411, 8, 68.9, 450.1467238040006, -9.1, 7.72519632156733]])
y = np.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0])

Best Answer

You switched accuracy_1 and accuracy_2 in your print statements: accuracy_1 comes from the make_pipeline model, while accuracy_2 comes from the manual steps. Beyond that, the random states should all match; the pipeline uses SMOTE(random_state=42) while the manual version uses SMOTE(random_state=27), so the two models are trained on different synthetic samples and their scores need not agree. This answer on Stack Overflow might be helpful.
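As a minimal sketch (reusing the imports and the X and y from the question): once the manual steps use the pipeline's SMOTE seed and the prints are labelled correctly, the two approaches produce identical scores.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=27)

# Pipeline version: SMOTE -> scale -> classify, all seeds fixed.
pipe = make_pipeline(SMOTE(random_state=42), StandardScaler(),
                     LinearSVC(dual=False, random_state=13))
pipe.fit(X_train, y_train)
acc_pipeline = accuracy_score(y_test, pipe.predict(X_test))

# Manual version: same steps, same order, and the SAME SMOTE seed (42, not 27).
X_res, y_res = SMOTE(random_state=42).fit_resample(X_train, y_train)
scaler = StandardScaler().fit(X_res)
clf = LinearSVC(dual=False, random_state=13).fit(scaler.transform(X_res), y_res)
acc_manual = accuracy_score(y_test, clf.predict(scaler.transform(X_test)))

print(f'make_pipeline : {acc_pipeline}')
print(f'manual steps  : {acc_manual}')  # identical once the seeds agree

Both variants then train on exactly the same oversampled, scaled data, so the fitted LinearSVC models, and hence their test scores, are identical.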
