Scikit-learn QuantileRegressor memory allocation error; no issue with statsmodels QuantReg on the same data

Tags: python, quantile-regression, scikit-learn, statsmodels

I'm trying to fit a quantile regression model to my input data. I would like to use sklearn, but I get a memory allocation error when I try to fit the model. The same data works fine with the equivalent statsmodels function.

The error I get is the following:

```
numpy.core._exceptions._ArrayMemoryError: Unable to allocate 55.9 GiB for an array with shape (86636, 86636) and data type float64
```

This doesn't make sense to me: my X and y have shapes (86636, 4) and (86636, 1), respectively.
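A quick check of the arithmetic shows that 55.9 GiB is exactly what a dense float64 array with one row and one column per training sample needs, i.e. memory that grows quadratically with the number of rows, while X itself is only a few MiB. The variable names below are illustrative only.

```python
# Memory for a dense float64 array of shape (n_samples, n_samples),
# the shape reported in the error, versus the (n_samples, n_features) input X.
n_samples = 86_636
n_features = 4
bytes_per_float64 = 8

square_gib = n_samples * n_samples * bytes_per_float64 / 2**30
design_mib = n_samples * n_features * bytes_per_float64 / 2**20

print(f"(n_samples, n_samples) float64: {square_gib:.1f} GiB")   # 55.9 GiB
print(f"(n_samples, n_features) float64: {design_mib:.1f} MiB")  # 2.6 MiB
```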

Here's my script:

```python
import pandas as pd
import statsmodels.api as sm
from sklearn.linear_model import QuantileRegressor

training_df = pd.read_csv("/path/to/training_df.csv")  # 86,000 rows

FEATURES = [
    "feature_1",
    "feature_2",
    "feature_3",
    "feature_4",
]

TARGET = "target"


# STATSMODELS WORKS FINE WITH 86,000, RUNS IN 2-3 SECONDS.
model_statsmodels = sm.QuantReg(training_df[TARGET], training_df[FEATURES]).fit(q=0.5)

# SKLEARN GIVES A MEMORY ALLOCATION ERROR, OR TAKES MINUTES TO RUN IF I SIGNIFICANTLY TRIM THE DATA TO < 1000 ROWS.
model_sklearn = QuantileRegressor(quantile=0.5, alpha=0)
model_sklearn.fit(training_df[FEATURES], training_df[TARGET])
```


I've checked the sklearn documentation and I'm fairly sure my inputs are fine as DataFrames; I get the same issue with NumPy ndarrays. So I'm not sure what the problem is. Could something be going wrong under the hood?

[Here][1] is the scikit-learn documentation for QuantileRegressor.

Many thanks for any help / ideas.

[1]: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.QuantileRegressor.html

Best Answer

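The sklearn QuantileRegressor class solves the quantile regression problem with linear programming, which is much more computationally expensive than the iteratively reweighted least squares (IRLS) used by the statsmodels QuantReg class. The linear program also explains the memory error: it introduces a pair of nonnegative slack variables for the residual of every sample, so its equality-constraint matrix contains two n_samples × n_samples identity blocks. When X is dense (a DataFrame or an ndarray), these blocks are allocated as dense arrays, which for 86,636 rows is the 55.9 GiB (86636, 86636) float64 array in the error.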
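To make the linear-programming formulation concrete, here is a minimal sketch written directly against `scipy.optimize.linprog`. It is not scikit-learn's actual code. It fits no intercept, matching the statsmodels call in the question, which passes `training_df[FEATURES]` without a constant column. `quantile_regression_lp` and `pinball_loss` are hypothetical helper names, and the data is a synthetic stand-in. The identity blocks are built with `scipy.sparse`, so memory stays linear in n_samples. Replacing `sparse.eye` with `np.eye` would reproduce the quadratic allocation.

```python
import numpy as np
from scipy import sparse
from scipy.optimize import linprog


def quantile_regression_lp(X, y, quantile=0.5):
    """Unpenalized linear quantile regression (no intercept) solved as an LP.

    Variables, all >= 0:  x = (beta_plus, beta_minus, u, v)
    Constraint:           X @ (beta_plus - beta_minus) + u - v = y
    Objective:            quantile * sum(u) + (1 - quantile) * sum(v)
    u and v are the positive and negative parts of the residuals, one pair
    per sample, hence the n_samples x n_samples identity blocks.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float).ravel()
    n_samples, n_features = X.shape

    # Sparse identity blocks cost O(n_samples) memory, not O(n_samples**2).
    X_csc = sparse.csc_matrix(X)
    eye = sparse.eye(n_samples, format="csc")
    A_eq = sparse.hstack([X_csc, -X_csc, eye, -eye], format="csc")

    c = np.concatenate([
        np.zeros(2 * n_features),            # coefficients are not penalized
        np.full(n_samples, quantile),        # weight on positive residuals u
        np.full(n_samples, 1.0 - quantile),  # weight on negative residuals v
    ])

    res = linprog(c, A_eq=A_eq, b_eq=y, bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    return res.x[:n_features] - res.x[n_features:2 * n_features]


def pinball_loss(X, y, beta, quantile=0.5):
    """Mean check (pinball) loss; the LP objective above is n_samples times this."""
    r = np.asarray(y, dtype=float).ravel() - np.asarray(X, dtype=float) @ beta
    return np.mean(np.maximum(quantile * r, (quantile - 1.0) * r))


# Usage on synthetic data (a hypothetical stand-in for the real training set).
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(2_000, 4))
y_demo = X_demo @ np.array([1.0, -2.0, 0.5, 3.0]) + rng.standard_t(df=3, size=2_000)
beta_demo = quantile_regression_lp(X_demo, y_demo, quantile=0.5)
print(beta_demo)
```

Even with sparse storage, the LP has 2 × n_features + 2 × n_samples variables, so it is a much larger problem than the weighted least-squares steps of IRLS.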

Here is a GitHub issue for the same problem: https://github.com/scikit-learn/scikit-learn/issues/22922

Thanks dipetkov for the link.
