[Math] How to create a custom probability density function from a discrete distribution

probability-distributions, python, statistics

I have some data that follows some unknown probability function. I would like to roughly extract that probability function.

My approach is to plot the data in a histogram and smooth it with LOWESS, which I implemented following this post. I then interpolate the smoothed curve to get a continuous function, which I intended as my cumulative distribution function (although, since it is built from a density histogram, it is really an estimate of the probability density function). But here I am stuck. I know I have to normalize the distribution somehow, but simply dividing by the number of data points does not seem to do the trick (read: something is completely wrong). How do I do this? Is there a better approach?
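One likely source of the normalization trouble: a density has to integrate to 1, so dividing raw counts by the number of data points only gives a probability per bin, which still has to be divided by the bin width. Note that `np.histogram(..., density=True)` already performs both divisions, so dividing its output by the number of points again shrinks the curve. After smoothing, the simplest fix is to divide the curve by its own numerical integral. A minimal sketch, assuming the smoothed curve is available as values `y` on a sorted grid `x`; the helper name `normalize_density` and the synthetic normal data are illustrative only:

```python
import numpy as np


def normalize_density(x, y):
    """Rescale the non-negative curve y, sampled on the sorted grid x, so that
    its trapezoidal integral over x equals 1."""
    x = np.asarray(x, dtype=float)
    y = np.clip(np.asarray(y, dtype=float), 0, None)  # a smoother can dip below 0
    area = np.sum((y[1:] + y[:-1]) / 2 * np.diff(x))   # trapezoidal rule
    return y / area


# usage: raw histogram counts at the bin centers stand in for a LOWESS output
rng = np.random.default_rng(0)
data = rng.normal(size=10_000)
counts, edges = np.histogram(data, bins=100)             # raw counts, not densities
centers = (edges[:-1] + edges[1:]) / 2
pdf = normalize_density(centers, counts)
```

Accumulating the same trapezoids with a cumulative sum instead of a plain sum gives a tabulated cdf on the same grid.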

I would like to implement all of this in Python. Here is my code so far:

import numpy as np
from scipy import interpolate, integrate
from scipy.stats import rv_continuous
import statsmodels.api as sm
import matplotlib.pyplot as plt


class CustomDistribution(rv_continuous):
    def __init__(self, custom_pdf):
        super().__init__()
        # the interpolated fit below estimates the density, so use it as the pdf
        self.custom_pdf = custom_pdf

    def _pdf(self, x, *args):
        return self.custom_pdf(x)

    # def _ppf(self, q, **kwargs):
    #     return


# create some normally distributed values and make a histogram
a = np.random.normal(size=10000)
counts, bins = np.histogram(a, bins=100, density=True)
bin_widths = np.diff(bins)
bin_centers = bins[:-1] + bin_widths/2  # the center is half a bin width past the left edge

# roughly try to fit some distribution
lowess = sm.nonparametric.lowess(counts, bin_centers, frac=0.1)
fit = interpolate.interp1d(
    lowess[:, 0], lowess[:, 1], kind='cubic', fill_value=0,
    bounds_error=False
)

# integral = integrate.quad(fit, -np.inf, np.inf)[0]
# print(integral)
distr = CustomDistribution(fit)
# print(distr.rvs(size=1))

plt.hist(a, 100, density=True, label='original data')
plt.plot(bin_centers, fit(bin_centers), label='fit')
plt.hist(distr.rvs(size=100), 100, density=True, label='data from fit')
plt.legend()
plt.show()

In order to be able to draw random variates from this distribution later on, I also need to implement the inverse cdf / _ppf(), according to the rv_continuous documentation.
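Because the fitted pdf is only known on a grid, the inverse cdf can be tabulated instead of derived in closed form: accumulate the pdf into a cdf, then swap the roles of x and cdf in a linear interpolation. The following sketch of this inverse transform sampling assumes a non-negative pdf on a sorted grid. `tabulated_ppf` is a hypothetical helper name, and the Gaussian shape is only test input:

```python
import numpy as np


def tabulated_ppf(x, pdf):
    """Return a function q -> x that inverts the cumulative trapezoidal
    integral of a non-negative pdf sampled on the sorted grid x."""
    x = np.asarray(x, dtype=float)
    pdf = np.clip(np.asarray(pdf, dtype=float), 0, None)
    cdf = np.concatenate(([0.0], np.cumsum((pdf[1:] + pdf[:-1]) / 2 * np.diff(x))))
    cdf /= cdf[-1]                             # force the cdf to end at exactly 1
    return lambda q: np.interp(q, cdf, x)      # linear interpolation of the inverse


# inverse transform sampling: push uniform variates through the ppf
x = np.linspace(-5, 5, 1001)
pdf = np.exp(-x**2 / 2)                        # unnormalized standard normal shape
ppf = tabulated_ppf(x, pdf)
rng = np.random.default_rng(1)
samples = ppf(rng.uniform(size=50_000))
```

np.interp assumes its second argument is increasing. If the pdf is exactly zero over a stretch, the tabulated cdf is flat there, so it is safer to trim such stretches (or add a tiny floor to the pdf) before building the table.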

EDIT

So I have been able to draw some random variates from the fit, but it takes incredibly long. The main reason seems to be that in order to draw the variates, I need to have a properly implemented inverse cdf / _ppf(). From the documentation of rv_continuous:

The default method _rvs relies on the inverse of the cdf, _ppf, applied to a uniform random variate. In order to generate random variates efficiently, either the default _ppf needs to be overwritten (e.g. if the inverse cdf can expressed in an explicit form) or a sampling method needs to be implemented in a custom _rvs method.

Because I don't provide these, rv_continuous computes the cdf by numerically integrating the pdf (and the ppf by numerically inverting that cdf), but for some reason the integral does not seem to converge:

IntegrationWarning: The maximum number of subdivisions (50) has been achieved.
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.

So I guess I need a good way to create my own implementation of at least the cdf, but better yet of the inverse cdf.
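Putting a tabulated pdf and cdf into an rv_continuous subclass avoids the numerical integration entirely, because rv_continuous then uses the supplied _cdf and _ppf instead of its quad-based and root-finding defaults. This sketch assumes the fitted pdf has been evaluated on a finite grid (for example `fit(x)` from the code above, with `x = np.linspace(bins[0], bins[-1], 1000)`). The class name `TabulatedDistribution` is hypothetical:

```python
import numpy as np
from scipy.stats import rv_continuous


# Hypothetical helper: a continuous distribution defined by pdf values on a grid.
class TabulatedDistribution(rv_continuous):
    def __init__(self, x, pdf, **kwargs):
        x = np.asarray(x, dtype=float)
        pdf = np.clip(np.asarray(pdf, dtype=float), 0, None)
        # cumulative trapezoidal integral of the pdf, starting at 0
        cdf = np.concatenate(([0.0], np.cumsum((pdf[1:] + pdf[:-1]) / 2 * np.diff(x))))
        self._grid = x
        self._pdf_table = pdf / cdf[-1]   # normalized so the pdf integrates to 1
        self._cdf_table = cdf / cdf[-1]   # runs from 0 to exactly 1
        # a and b declare the finite support to scipy
        super().__init__(a=x[0], b=x[-1], **kwargs)

    def _pdf(self, x):
        return np.interp(x, self._grid, self._pdf_table)

    def _cdf(self, x):
        return np.interp(x, self._grid, self._cdf_table)

    def _ppf(self, q):
        # inverse cdf: interpolate with the roles of x and cdf swapped
        return np.interp(q, self._cdf_table, self._grid)


x = np.linspace(-5, 5, 1001)
dist = TabulatedDistribution(x, np.exp(-x**2 / 2), name='tabulated')
samples = dist.rvs(size=10_000, random_state=2)
```

Setting `a` and `b` to the ends of the grid tells scipy that the support is finite, so pdf and cdf return 0 and 1 outside it without calling the interpolants. Because `__init__` takes extra arguments, operations that rebuild the instance from its constructor parameters, such as freezing via `dist()`, may not work with this sketch.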

Best Answer

So after lots of research and many iterations of trial & error, I found a solution that is satisfying to me. I am going to list some of the posts and articles that I came across during my research in the hopes that it might be helpful to somebody else.

First of all, I would like to mention this answer, which led me down my original path in the first place. I still don't fully understand the solution presented there, which is mainly why I tried my own approach shown in the question. It also mentions the possibility of adding Gaussian tails, which I would have liked, but it does not explain how to do so.

Then there is this post, which asks about drawing random variates from a known discrete distribution. This is very easy to implement in Python using numpy.random.choice(). However, this is not exactly what I am looking for. One answer there does discuss drawing samples from a known continuous distribution, but since I don't know my distribution, that does not work for me either.
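For reference, the discrete approach with numpy.random.choice() can be turned into a rough continuous sampler: treat each histogram bin as one outcome, then draw uniformly within the chosen bin. A sketch, with synthetic exponential data standing in for a real data set:

```python
import numpy as np

rng = np.random.default_rng(3)
data = rng.exponential(scale=0.3, size=2_000)     # synthetic stand-in data

counts, edges = np.histogram(data, bins=36)
probs = counts / counts.sum()                      # discrete distribution over the bins

# 1) choose bins with probability proportional to their counts
idx = rng.choice(len(probs), size=5_000, p=probs)
# 2) place each variate uniformly inside its chosen bin
samples = rng.uniform(edges[idx], edges[idx + 1])
```

SciPy packages the same piecewise-uniform idea as scipy.stats.rv_histogram. It takes the tuple returned by np.histogram and provides pdf, cdf, ppf and rvs.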

Eventually, I found this answer, which explains the approach that works best for me. The technique is called kernel density estimation (KDE), which is apparently a well-known concept in statistics but was new to me. The author describes the general concept as follows:

Some modern computer programs have the ability to piece together curves of various shapes in such a way as to approximate the density function of the population from which a sample was chosen. (The result is sometimes called a 'spline'.)

Here is an article about working with splines in Python. That was still too complicated for me, but then I found this answer, which explains how to use a Gaussian KDE in Python. As far as I understand it, a Gaussian KDE places a Gaussian curve of fixed width (the kernel) on every data point and averages them. That is easy enough for me to understand and, in hindsight, a pretty obvious solution.
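More precisely, a Gaussian KDE with bandwidth h built from n data points x_1, ..., x_n is f(x) = 1/(n*h) * sum_i phi((x - x_i)/h), where phi is the standard normal density. In words, it places one normal curve with standard deviation h on each point and averages them. Because each curve integrates to 1, so does f. A direct NumPy sketch (the function name is made up, and the data are synthetic):

```python
import numpy as np


def gaussian_kde_manual(data, h):
    """Return f(x) = average over i of the normal density N(x; data[i], h**2)."""
    data = np.asarray(data, dtype=float)

    def pdf(x):
        # standardized distance of every evaluation point to every data point
        z = (np.asarray(x, dtype=float)[:, None] - data[None, :]) / h
        return np.exp(-0.5 * z**2).sum(axis=1) / (len(data) * h * np.sqrt(2 * np.pi))

    return pdf


rng = np.random.default_rng(4)
data = rng.normal(size=500)
f = gaussian_kde_manual(data, h=0.3)
grid = np.linspace(-6, 6, 2001)
density = f(grid)
```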

My Solution

I know of two libraries that offer Gaussian KDE fitting in Python. The one I came across first is the scipy.stats.gaussian_kde class. Its advantages are that it offers several ways to determine the bandwidth of the Gaussian kernels ('scott', 'silverman', a scalar factor or a callable) and that it can evaluate the pdf directly. It didn't work 100% for me at first, though, which is why I looked for standard ways to fit other types of kernels.
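A short sketch of the scipy.stats.gaussian_kde calls mentioned here: the instance is callable (equivalent to .pdf), bw_method accepts 'scott' (the default), 'silverman', a scalar or a callable, and resample() draws new values. The gamma-distributed data are only a stand-in with a shape similar to the data set below:

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(5)
data = rng.gamma(shape=2.0, scale=0.15, size=1_000)   # synthetic stand-in data

kde = gaussian_kde(data)                               # default bw_method='scott'
kde_silverman = gaussian_kde(data, bw_method='silverman')

grid = np.linspace(0, 2, 401)
density = kde(grid)                                    # same as kde.pdf(grid)
# resample() returns shape (d, size); d == 1 for one-dimensional data
samples = kde.resample(size=2_000, seed=5)[0]
```

`kde.factor` holds the resulting multiplier of the data's standard deviation. Because the Gaussian kernels extend below zero, resample() can return negative values for data that start at 0, which is what the truncation in the code below deals with.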

That led me to this article, which mentions the sklearn.neighbors.KernelDensity class. Its advantage is that, besides the Gaussian kernel, it also offers a whole range of other kernels ('tophat', 'epanechnikov', 'exponential', 'linear', 'cosine'). However, evaluating the pdf is more cumbersome, because score_samples() returns the logarithm of the density, which then has to be exponentiated.
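Besides the exponential, the other practical difference is that KernelDensity's fit() and score_samples() expect a 2-D array of shape (n_samples, n_features). A sketch looping over the listed kernels, again with synthetic stand-in data and an arbitrary bandwidth:

```python
import numpy as np
from sklearn.neighbors import KernelDensity

rng = np.random.default_rng(6)
data = rng.gamma(shape=2.0, scale=0.15, size=1_000)   # synthetic stand-in data

grid = np.linspace(0, 2, 401)
densities = {}
for kernel in ['gaussian', 'tophat', 'epanechnikov', 'exponential', 'linear', 'cosine']:
    kde = KernelDensity(kernel=kernel, bandwidth=0.05).fit(data[:, None])  # 2-D input
    # score_samples() returns the log-density, hence the exp
    densities[kernel] = np.exp(kde.score_samples(grid[:, None]))

# sample() returns shape (n_samples, 1); ravel() flattens it
gauss = KernelDensity(kernel='gaussian', bandwidth=0.05).fit(data[:, None])
samples = gauss.sample(500, random_state=6).ravel()
```

KernelDensity.sample() is only implemented for the 'gaussian' and 'tophat' kernels, so the other kernels can be evaluated but not sampled from directly.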

Both versions come with their own set of different parameters that can be optimized, although I haven't fully explored them yet.
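One common way to choose the bandwidth, rather than tuning it by hand, is cross-validation. KernelDensity.score() returns the total log-likelihood of the data, so GridSearchCV can pick the bandwidth that best predicts held-out points. A sketch; the search grid and the synthetic data are arbitrary choices:

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KernelDensity

rng = np.random.default_rng(7)
data = rng.gamma(shape=2.0, scale=0.15, size=500)     # synthetic stand-in data

# 5-fold cross-validation over 20 log-spaced bandwidths between 0.01 and 1
search = GridSearchCV(
    KernelDensity(kernel='gaussian'),
    {'bandwidth': np.logspace(-2, 0, 20)},
    cv=5,
)
search.fit(data[:, None])                            # sklearn expects 2-D input
best_bw = search.best_params_['bandwidth']
best_kde = search.best_estimator_                    # refitted on all data
```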

Below is code that compares the two methods on some sample data:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from sklearn.neighbors import KernelDensity


# wrapper classes around the two Gaussian KDE implementations, sharing an
# optional truncation to [lower bound, upper bound]
class GaussianKDE:
    def __init__(self, bins, tc=True, lb=None, ub=None):
        self.truncated = tc
        if self.truncated:
            # 'bins' holds bin centers, so the outer bin edges lie half a
            # bin width beyond the first and last center
            if lb is None:
                lb = bins[0] - (bins[1] - bins[0])/2
            if ub is None:
                ub = bins[-1] + (bins[-1] - bins[-2])/2
            self.bounds = (lb, ub)

    def _pdf(self, pdf, x):
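        if self.truncated:
            # zero the density outside the bounds; the remaining density is not
            # renormalized, so it integrates to slightly less than 1
            pdf[(x < self.bounds[0]) | (x > self.bounds[1])] = 0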

        return pdf

    def rvs(self, size=0):
        # placeholder: the subclasses override this with an actual sampler
        return np.array([])

    def _rvs(self, rvs):
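        if self.truncated:
            # rejection step: discard variates outside the bounds and draw
            # replacements (recursively) until all of them lie inside
            mask = (rvs < self.bounds[0]) | (rvs > self.bounds[1])
            n_outside = np.sum(mask)
            if n_outside:
                rvs = rvs[~mask]
                rvs = np.append(rvs, self.rvs(size=n_outside))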

        return rvs


class SkleanGaussian(GaussianKDE):
    def __init__(
        self, bins, weights, bandwidth, truncated=True,
        lower_bound=None, upper_bound=None
    ):
        super().__init__(bins, tc=truncated, lb=lower_bound, ub=upper_bound)
        self.kde = KernelDensity(bandwidth=bandwidth, kernel='gaussian')
        self.kde.fit(bins[:, None], sample_weight=weights)

    def pdf(self, x):
        pdf = np.exp(self.kde.score_samples(x[:, None]))

        return self._pdf(pdf, x)

    def rvs(self, size=1):
        # sample() returns shape (size, 1); flatten it to a 1-D array
        rvs = self.kde.sample(n_samples=size).ravel()

        return self._rvs(rvs)


class ScipyGaussian(GaussianKDE):
    def __init__(
        self, bins, weights, bandwidth, truncated=True,
        lower_bound=None, upper_bound=None
    ):
        super().__init__(bins, tc=truncated, lb=lower_bound, ub=upper_bound)
        self.kde = gaussian_kde(bins, weights=weights, bw_method=bandwidth)

    def pdf(self, x):
        pdf = self.kde.pdf(x)

        return self._pdf(pdf, x)

    def rvs(self, size=1):
        rvs = self.kde.resample(size=size)[0]

        return self._rvs(rvs)


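# 'a' is the dataset listed at the end of this answer; paste it here, e.g.
# a = [0.06264390643490558, 0.09509194878588015, ...]
a = np.asarray(a)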

n_bins = 36
weights, bins = np.histogram(a, bins=n_bins, density=True)

# the scale converts the density-normalized pdf into counts per bin; for
# equal-width bins it equals len(a) * bin width
counts, _ = np.histogram(a, bins=n_bins, density=False)
scale = max(counts)/max(weights)

# sklearn's KernelDensity (at least in some versions) rejects zero sample
# weights, so empty bins get a tiny positive weight instead
weights[weights == 0] = 10e-10

# using the bin centers instead of the bin edges makes the fit more accurate
bin_widths = np.diff(bins)
bin_centers = bins[:-1] + bin_widths/2

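# I found that using the bin width as the bandwidth of the Gaussian kernels
# is usually a good choice. For the scipy kde, however, a scalar bw_method is
# not an absolute bandwidth but a factor that multiplies the (weighted)
# standard deviation of the data, which is why it has to be scaled (here by 4,
# found by trial and error) to give a similar result.
bandwidth = bin_widths[0]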

# plot the dataset
plt.hist(a, n_bins, density=False, label='original data')
fit_xdata = np.linspace(bins[0], bins[-1], 1000)

# fit the dataset using the sklearn gaussian kde
sk = SkleanGaussian(bin_centers, weights, bandwidth, lower_bound=0)
plt.plot(fit_xdata, sk.pdf(fit_xdata)*scale, label='sklearn', linewidth=4)
plt.hist(
    sk.rvs(size=len(a)), bins=bins, density=False, label='sklearn data',
    histtype='step', linewidth=4
)

# fit the dataset using the scipy gaussian kde
sp = ScipyGaussian(bin_centers, weights, bandwidth*4, lower_bound=0)
plt.plot(fit_xdata, sp.pdf(fit_xdata)*scale, label='scipy', linewidth=4)
plt.hist(
    sp.rvs(size=len(a)), bins=bins, density=False, label='scipy data',
    histtype='step', linewidth=4
)

plt.legend()
plt.show()
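The factor of 4 needed for the scipy version has a concrete cause. In gaussian_kde, a scalar bw_method is multiplied by the (weighted) standard deviation of the data, whereas bandwidth in KernelDensity is the kernel's standard deviation itself. Dividing the desired bandwidth by the data's weighted standard deviation should therefore make the two fits agree. A sketch with made-up bin centers and weights:

```python
import numpy as np
from scipy.stats import gaussian_kde
from sklearn.neighbors import KernelDensity

centers = np.linspace(0.025, 1.775, 36)     # made-up bin centers
weights = np.exp(-centers / 0.3)            # made-up bin weights
bandwidth = 0.05                            # desired kernel standard deviation

# gaussian_kde uses the weighted sample covariance of the data (ddof=1), scaled
# by bw_method**2; float() is needed because np.cov returns a 0-d array
data_std = float(np.sqrt(np.cov(centers, aweights=weights)))
sp = gaussian_kde(centers, weights=weights, bw_method=bandwidth / data_std)

sk = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
sk.fit(centers[:, None], sample_weight=weights)

grid = np.linspace(-0.2, 2.0, 500)
sp_pdf = sp(grid)
sk_pdf = np.exp(sk.score_samples(grid[:, None]))
```

With this conversion, `ScipyGaussian` could be given `bandwidth / data_std` (computed from `bin_centers` and `weights`) instead of `bandwidth*4`.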


Here is the dataset that I used for the plot above:

a = [
    0.06264390643490558, 0.09509194878588015, 0.03026355120009857,
    0.08127250971151266, 0.716335393111658, 0.1655346400736833,
    0.22021132083772382, 0.10405264304209878, 0.15158061573261902,
    1.1407146037314464, 1.3778213635979537, 0.13329059704017512,
    0.05554544035545377, 0.20906769301159273, 0.5566954753368197,
    0.3828496546267931, 0.06568476117108585, 0.5945406473694029,
    0.15191868810884088, 0.1897610554446672, 0.48630942301532426,
    0.4831209249147272, 1.439234543369555, 0.37311272029547893,
    0.5642500398077652, 0.1721244529796186, 0.5273312631232334,
    0.09616976653890508, 1.7021485946160702, 0.61447034734571,
    0.5523115399000325, 0.26668901873622003, 0.11724199817053041,
    0.07679793531486108, 0.5467101224578725, 0.19096230149841698,
    1.5183103660533463, 0.23257955113571152, 0.23779389633692136,
    0.4335181261541503, 0.23596773604981203, 0.02908997544346879,
    1.0674407653138862, 1.3494585876980802, 0.03885367859253224,
    0.1647270631314967, 1.3452485276638249, 0.055296850543858474,
    0.5326261321249462, 0.1095002052797961, 0.7355242953977977,
    0.32705388174865846, 0.16730022582292728, 0.7073702951150274,
    0.36460732110959304, 0.19467640490769073, 0.4495038666887294,
    0.09757906785593913, 0.3010801376352415, 0.12794699352187025,
    0.39989828215397, 0.20330797121132918, 0.29425252757674447,
    1.1953516399331021, 0.30819540047353705, 0.09655204788531789,
    0.6742888167179486, 0.34898376802601266, 0.525988432728582,
    0.26374070871269023, 0.2098040005553093, 0.7362005434660942,
    0.13011378085077427, 0.9702001966483852, 0.11879257050699737,
    0.11139207608379881, 0.26853989726608724, 0.1176863257430886,
    0.06942182181115912, 0.09337597357884074, 1.065787386865521,
    0.8600251431808253, 0.7414582724453467, 0.41368121072845515,
    0.10912611550111842, 0.4589679132595474, 0.6487664702146925,
    0.05081836603847919, 0.5982838285074653, 0.4951404470155634,
    0.874755188992434, 0.373927512890217, 0.23901475699659636,
    0.40534674921226244, 0.14572932928447854, 0.258095727129861,
    0.03701320217555316, 0.8608119476645506, 0.8035215302573417,
    0.25001363287622624, 0.3824321776006704, 0.11141062865370506,
    0.28757152982424217, 0.5418034135841103, 0.9979464062423947,
    0.33585844217537625, 0.10999515005480506, 0.9353993512283134,
    0.21633356328558692, 0.4135286376910485, 0.12517565255156923,
    0.5493703312508813, 0.18041891534047605, 0.6134781434007768,
    1.236849076997046, 0.11648039850271226, 0.48023683744489787,
    0.446085262613892, 0.15218060410132964, 0.5981226511467477,
    0.41655457237126264, 0.0916076697756312, 0.8921542343086389,
    0.37765806142120517, 0.5730556848410264, 0.4087943957867417,
    0.26539251087751103, 0.48334188724907456, 1.0940009339751164,
    0.4289527463607125, 0.09892174608694435, 0.09972326341531497,
    0.09302787943360856, 0.09254221754833591, 0.22443552408403697,
    0.15139820646247348, 0.4873604503030918, 0.23661873092128244,
    0.02237213029199188, 0.2745054397907624, 0.27676369147378477,
    0.4534408586858371, 0.21347828414822836, 0.06905488986155826,
    0.17340334860547618, 0.27604472469051894, 0.1531858326601331,
    0.18770701088419353, 0.6156182770718479, 0.5140259022206131,
    0.9790865042331606, 0.26292971777388546, 0.6524233338463798,
    0.5921673964274822, 0.339395316066237, 0.18831984075216548,
    0.17372650600658673, 1.300121408216778, 0.15889728482083987,
    0.12368630227927495, 0.08954277706375355, 0.3463700941028804,
    0.24220714837607554, 0.5109582845097397, 0.2728876481573656,
    0.17475295034837, 0.6306085097942234, 0.42162744054491025,
    0.4056695542708484, 0.18212274672855236, 0.8371282778503085,
    0.059853775614641065, 0.17923614412529443, 0.16822848734176118,
    0.10754075126512831, 0.42487936727784553, 0.14438712461237121,
    0.6615368265902856, 0.6232562393680249, 0.4318836965404065,
    0.7633134420222835, 0.4936370484262888, 0.33051384534054395,
    0.2587997878220695, 0.6585808316791898, 0.22124695640438757,
    0.20520843119564264, 1.068658697795139, 0.3585823254925308,
    0.15665653047993242, 0.43502435775611115, 0.17863764522983772,
    0.4428210519107696, 0.005188613331506297, 0.5316337457600517,
    0.1081311633487118, 0.47318606594252416, 0.3361081405496264,
    0.07761349433896658, 0.13053000096857073, 0.34115215289651174,
    0.1349288635587115, 0.23844269369432247, 0.2530265940902519,
    0.17630871792248376, 0.4146536126202932, 0.5687370367773086,
    0.312999754852458, 0.45252967096695673, 0.3586198213191352,
    0.26106509731982536, 0.6215295686341544, 0.20613684944624042,
    0.05627296046197301, 0.2952320632817755, 0.179266048785794,
    0.2341543444799701, 0.39873210928403, 0.3805417547615136,
    0.5093783829009244,
    0.2641373922159488, 0.17298488767008083, 0.2949659131539717,
    0.19821947853803826, 0.3222776906143646, 0.9128533648955071,
    0.07301370147748584, 0.3850687302360944, 0.3174591050711264,
    0.6087818303756414, 0.1029330789739462, 0.37809449054116956,
    1.2098043167134858, 0.14496481592915952, 0.39599261315360207,
    0.35485684520105976, 0.13860115769821593, 0.18103841926101988,
    1.0086819742783641, 0.5254484619509217, 0.5434383907132665,
    0.43297302616418437, 0.2959149312785611, 0.46141576362699227,
    0.5910517081695509, 0.5544360396270079, 0.333789874987028,
    0.35068713980019617, 0.4378869103419748, 0.2477388161628418,
    0.2776207053683075, 0.28138785354283696, 0.3121773128851953,
    0.8551094082649814, 0.36260811683188693, 0.13645457737286926,
    0.5032863207452049, 0.30009366753124794, 0.22732875715491696,
    0.46752362206958603, 0.13220007003480325, 0.09357959641336099,
    0.18308001164332718, 0.1344379481689656, 0.1974856028210634,
    0.42891788781552403, 0.2527505263451604, 0.16706368404648916,
    0.29815230814398436, 0.08718238488596576, 0.2012522435157994,
    0.0768729923513817, 0.26397928923437924, 0.2973978523275003,
    0.3583746943684728, 0.15870851376966866, 0.1304107410968139,
    0.11349589344395113, 0.3034133239576807, 0.3314294885480667,
    0.1835909929132002, 0.2610453989195211, 0.4578755183476205,
    0.19321346885847945, 0.10711391247695866, 0.4571782483977145,
    0.2991181260894714, 0.4417777063457217, 0.287704925522506,
    0.15802591941973748, 0.7832117789763846, 0.4028493656033323,
    0.11438325831296418, 0.33349574972657137, 0.2269718178112178,
    0.6094357005347584, 0.22206581172489798, 0.33134040379492596,
    0.7936463734651206, 0.2084780229120582, 0.24925017622807397,
    0.1312798413357054, 0.27232547836607107, 0.43756796411514304,
    0.8940605230980392, 0.7976405217299876, 1.540752406450383,
    0.12774291495349768, 0.14797402154192743, 0.11369093588919098,
    0.54802680049886, 0.7824536096080216, 0.21254517667187017,
    0.8873428013548432, 0.0515702638323821, 0.811207216889128,
    0.13866880127081788, 0.3095791172114035, 0.17209658043390716,
    0.04465397452524362, 0.056610798963341535, 0.16127494330380887,
    0.07352456663431992, 0.6196367956553391, 0.19341309771944393,
    0.01067667657836674, 0.1655774740325723, 0.11186982915353766,
    0.3695520033470526, 0.231503473221922, 0.14117834173014893,
    0.2998040376287597, 0.5554151652398902, 0.3086164916129593,
    0.16658637331182907, 0.08306798335693508, 0.15445153506855844,
    0.08333947143305998, 0.1806410523324605, 0.44749808461412754,
    0.47452679760299316, 0.08326594074423026, 0.30618902094836115,
    0.33147485563685103, 0.43707207237183293, 0.24505758162328187,
    0.5404070832286212, 0.09020491065214786, 0.21198592065200442,
    0.5707961579776272, 0.8746763226299498, 0.08260141353745308,
    0.35624785729306213, 0.34885579197930267, 0.4085317593525889,
    0.12362169974237566, 1.1109002064449287, 0.1462888288488881,
    0.7897067469658909, 0.2354360951628401, 0.5606853414763533,
    0.03048416989297525, 0.16228482915689738, 0.20759622014042464,
    0.22291858115560048, 0.07384607974087773, 0.7167466792800311,
    0.05612846888585919, 0.4116828260865948, 0.0718610699836005,
    0.28044619746943916, 0.31012076674169237, 0.2019293266605177,
    0.2773113174652737, 0.08911122228940253, 0.3725180990092935,
    0.13868693075185026, 0.616241851912476, 0.27467850649338554,
    0.29838997749883656, 0.14404731709232282, 0.8531089188901888,
    0.5216636427252435, 0.3914386252967373, 0.20230271478367096,
    0.3513050951675317, 0.5265638281073232, 0.2951340725921615,
    0.40219491596019924, 0.7222218692992858, 0.22006986102088374,
    0.4734266888255716, 0.23065250010233185, 0.3492795333266558,
    0.20954441795751014, 0.23215823642270114, 0.7466520337449913,
    0.2186694344946019, 0.3091832674278487, 0.12278575440349773,
    0.31083585426476856, 0.4186175488792576, 0.26152815441905136,
    0.2111673001412763, 0.5508069668518308, 0.41041320966392897,
    0.6696549009373107, 0.3951373064261152, 0.23462821245908802,
    0.07421264849147509, 0.15393755104174198, 0.17675695406916733,
    0.7086391854752552, 0.13678287126042765, 0.24567963252047093,
    1.7507239720494245, 0.2420117210905297, 0.9732947504562621,
    0.14372713368920265, 0.15527527790017748, 0.5594361595875883,
    0.32702156115831793, 0.33562865748193976, 0.7114573401140054,
    0.10110742894535309, 0.4154120724371111, 0.7103633082088364,
    0.004431506604978203, 0.17945593546206282, 0.47203048886021004,
    0.0931773052856488, 0.06748321024976894, 0.057769774664969145,
    0.06018386801407815, 0.04922693734941837, 0.15488630724090174,
    0.5828608048019022, 0.20072252614953653, 0.12744166702736842,
    0.003621774515854778, 0.575051076576879, 0.13327720382781044,
    0.18457209936381336, 0.4510854821150856, 0.2022858470184543,
    0.3166344207767408, 0.6210870797725996, 0.5059332396867006,
    0.13309628959165176, 0.6208306935070056, 1.091297773415639,
    0.40417245645876226, 0.3163877186319963, 0.2537384898988808,
    0.5773193815915504, 0.03409058825475501, 0.17013135868903237,
    0.16979181484424266, 0.44537353155821646, 0.3763459466475411,
    0.23965384075944998, 0.44429180138227026, 0.22481618350628446,
    0.29146020896413305, 0.7939904049175383, 0.04737656838952173,
    0.38965456751738603, 0.4183938596647599, 0.4237353177599,
    0.33986754377938416, 0.3786950459810704, 0.25595381961527014,
    0.5031533930439777, 0.21501620761777607, 0.18623177943921937,
    0.23893173190135122, 0.33537707415960405, 0.6150422678262184,
    0.1429232114467445, 0.5815825868278779, 0.623984618664029,
    0.44821273248526144, 0.2523244467058149, 0.38928537311463274,
    0.3131312407015937, 0.40233222010584074, 0.3512188676313344,
    0.1457322755366677, 0.3539015903980124, 0.2606778470219352,
    0.3984669438966321, 0.5128962614318571, 0.2521244605861394,
    0.07650193008100992, 0.1989495371540755, 0.7300753581115305,
    0.06784356156650419, 0.3481075782435549, 0.1409835093710039,
    0.15042622063061714, 0.3125205602544919, 0.33616909503086123,
    0.031236949122354703, 0.15570933495906555, 0.21906711000916115,
    0.021623992255017856, 0.4642906572190585, 0.4237911729703078,
    0.0019192385723672898, 0.27732114376106304, 0.1623237658761772,
    0.2580854517095076, 1.5447786799130103, 0.28921270696094425,
    0.16649067388964092, 0.3083875362880904, 0.3030460052851359,
    0.1699823130576581, 0.40892613802274674, 0.7422814174713747,
    0.6948183832978986, 0.5123473987809603, 0.45398273069305567,
    0.11354264325223219, 0.6534759753020928, 0.45842125126227207,
    0.31910372904138357, 0.674451857008122, 0.21352067300283756,
    0.139365287999925, 0.2625197786309779, 0.1726339379398195,
    0.030572480530779062, 0.19726829645886312, 0.5060947166303587,
    0.4532890275257452, 0.3065434284032874, 0.12682799636376038,
    0.090528811507983, 0.37764297456197476, 0.0733335457127021,
    0.4869517615744886, 0.8559878580825219, 0.09998230400655135,
    0.49532574591698303, 0.167188638066996, 0.3319661271360515,
    0.10239646461805112, 0.11857264423132291, 0.1492564990706423,
    0.2702397477098583, 0.13588441340039523, 0.19019806035250486,
    0.21384018516716766, 0.3211531080408406, 0.2499647212420634,
    0.12240413238556865, 0.12780065227225182, 0.2446387607364641,
    0.2243913924798307, 0.21715100406936516, 0.19134086206895268,
    0.23105901862935485, 0.4496479121201125, 0.3839717141679469,
    0.06582995176464726, 0.2693403773279218, 0.94563258264632,
    0.03317538295217102, 0.2467924570279441, 0.07147268251888911,
    0.6664566332618854, 0.3754209309946151, 0.11675547014158336,
    0.2808545079239046, 0.5447460300204945, 0.30905266116582814,
    0.20852310445684294, 0.24102989872215386, 0.39596709672795843,
    0.21958916357866948, 0.6340243268261101, 0.15598542254498654,
    0.16453314075910846, 0.24699051171072725, 0.3729704073809626,
    1.528338561270426, 1.3754498635580892, 0.38543485664198485,
    0.07543958222670002, 0.12367271594875584, 0.3027614689159825,
    1.0869247101332529, 0.3193620758305926, 0.783094928187596,
    0.625922084459429,
    0.1439254057729466, 0.05253020735175755, 0.13441219034877228,
    0.2839895155804683, 0.32019427898948266, 0.199992074441306,
    0.09573602363642589, 0.03851905382173678, 0.45817816823314905,
    0.39599748429650905, 0.021032608778116798, 0.6394982209846444,
    0.06160563003443469, 0.4389129185674811, 0.1600420871350196,
    0.05471160943534488, 0.3678550807669665, 0.4140780826571728,
    0.24824025383305012,
]
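Fitting histogram bin centers with bin weights is not required by either library. Both can be fitted to the raw values directly (for the list above, for example `gaussian_kde(a)`), which avoids choosing a bin count. Below is a sketch of that variant, with the truncation at 0 done by an iterative rejection loop instead of recursion. `sample_truncated` is a hypothetical helper, and the gamma data are a stand-in for `a`:

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(9)
data = rng.gamma(shape=2.0, scale=0.15, size=500)   # synthetic stand-in for 'a'

kde = gaussian_kde(data)                            # fit the raw values, no histogram


def sample_truncated(kde, size, lower=0.0, seed=None):
    """Draw from the KDE, redrawing values below `lower` (rejection sampling)."""
    gen = np.random.default_rng(seed)
    out = np.empty(0)
    while out.size < size:
        draw = kde.resample(size=size, seed=gen)[0]
        out = np.concatenate([out, draw[draw >= lower]])  # keep accepted values only
    return out[:size]


samples = sample_truncated(kde, 1_000, lower=0.0, seed=9)
```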