Solved – What do eps and tol do in LassoCV (scikit-learn)

Tags: lasso, python, regression, regularization, scikit-learn

Using scikit-learn:

http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LassoCV.html

Specifically, I am interested in:

1) If eps grows, does the accuracy (precision) increase or decrease?

2) If tol grows, does the accuracy (precision) increase or decrease?

Best Answer

Here is an example of the effect of eps and tol on LassoCV's MSE (using the diabetes dataset), across the range of $\alpha$ values on the regularization path. Note that this is the MSE averaged over the CV folds (each fold has a different MSE):

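[Figure: mean cross-validated MSE against -log(alpha), one curve per eps value]

[Figure: mean cross-validated MSE against -log(alpha), one curve per tol value]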

It appears that eps has a significant impact for some penalty parameters, but with a large enough penalty it doesn't matter. tol doesn't seem to play a large role (at least in the way scikit-learn has implemented LassoCV).
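
For context on why the two parameters behave so differently, here is a minimal sketch of what each one controls. According to the scikit-learn documentation, eps is the length of the regularization path (alpha_min / alpha_max), so it changes which $\alpha$ values are searched rather than how well any single model is fitted, while tol is the stopping tolerance of the coordinate descent solver. The tol part uses a plain Lasso with a fixed alpha=0.1 to isolate the solver; that alpha value and the specific eps and tol values are arbitrary choices for illustration, not something taken from the plots above.

```python
import numpy as np
from sklearn import datasets
from sklearn.linear_model import Lasso, LassoCV

X, y = datasets.load_diabetes(return_X_y=True)

# eps sets the length of the alpha path: alpha_min / alpha_max == eps.
eps_ratios = {}
for eps in (1e-3, 1e-1):
    model = LassoCV(eps=eps, cv=5).fit(X, y)
    eps_ratios[eps] = model.alphas_.min() / model.alphas_.max()
    print(f"eps={eps:g}: alpha_min/alpha_max = {eps_ratios[eps]:.4g}")

# tol is the stopping tolerance of the coordinate descent solver.
# A plain Lasso with a fixed, arbitrarily chosen alpha (an assumption
# made for this sketch) isolates its effect from the alpha search.
n_iters = {}
for tol in (1e-8, 1e-2):
    lasso = Lasso(alpha=0.1, tol=tol, max_iter=10000).fit(X, y)
    n_iters[tol] = lasso.n_iter_
    print(f"tol={tol:g}: {lasso.n_iter_} coordinate descent iterations")
```

A smaller eps extends the path toward weaker penalties, which is consistent with eps only mattering at the small-penalty end of the plots. A larger tol lets the solver stop earlier, so it trades a little precision in the coefficients for speed.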

The code that produced the plots follows.

# In a Jupyter notebook, first run: %matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from sklearn import datasets
from sklearn.linear_model import LassoCV

# load data
diabetes = datasets.load_diabetes()
X = diabetes.data
y = diabetes.target

# Plot of epsilons
epss = [0.0001, 0.001, 0.01, 0.1]

plt.figure(figsize=(10,6))
color = cm.rainbow(np.linspace(0,1,len(epss)))

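for eps, c in zip(epss, color):
    model = LassoCV(eps=eps).fit(X, y)

    # x-axis: -log10 of the alphas on this model's path (the path depends on eps)
    m_log_alphas = -np.log10(model.alphas_)
    # mse_path_ has shape (n_alphas, n_folds); average over the folds
    plt.plot(m_log_alphas, model.mse_path_.mean(axis=-1), color=c,
             label='eps = {}'.format(eps), linewidth=2)

ymin, ymax = 2300, 3800
plt.legend()
plt.xlabel('-log10(alpha)')
plt.ylabel('Mean square error')
plt.axis('tight')
plt.ylim(ymin, ymax)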


# Plot of tols
plt.figure(figsize=(10,6))
tols = [0.0001, 0.001, 0.01, 0.1, 1]

color = cm.rainbow(np.linspace(0,1,len(tols)))

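for tol, c in zip(tols, color):
    model = LassoCV(tol=tol).fit(X, y)

    # x-axis: -log10 of the alphas on this model's path
    m_log_alphas = -np.log10(model.alphas_)
    # mse_path_ has shape (n_alphas, n_folds); average over the folds
    plt.plot(m_log_alphas, model.mse_path_.mean(axis=-1), color=c,
             label='tol = {}'.format(tol), linewidth=2)

ymin, ymax = 2300, 3800
plt.legend()
plt.xlabel('-log10(alpha)')
plt.ylabel('Mean square error')
plt.axis('tight')
plt.ylim(ymin, ymax)

# Needed when running as a script; harmless in a notebook
plt.show()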