In [51]:
import numpy as np
from matplotlib import rc
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from mpl_toolkits import mplot3d
import matplotlib.cm as cm
# Render FuncAnimation objects as inline HTML5 video (requires ffmpeg on PATH).
rc('animation', html='html5')
sns.set(style='whitegrid', rc={'figure.figsize':(10,8)})

# Seed NumPy's global RNG so that the synthetic data and the random cluster
# initializations inside cwm() are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
In [52]:
# Synthetic 1-D regression problem: y = tanh(x) + Gaussian noise.
N = 100        # number of samples
sigma = 0.25   # standard deviation of the additive noise
X = np.random.uniform(-5, 5, N)
noise = np.random.normal(0, sigma, N)
Y = np.tanh(X) + noise

fig, ax = plt.subplots()
ax.plot(X, Y, 'o', label='noisy samples')
x_grid = np.linspace(-5, 5, 200)
ax.plot(x_grid, np.tanh(x_grid), '--', label=r'$\tanh(x)$')
ax.set(title=r'Training data: $y = \tanh(x) + \mathcal{N}(0, \sigma^2)$', xlabel='x', ylabel='y')
ax.legend()
plt.show()
In [110]:
# M - number of clusters
# f_coeff - number of coefficients of each local polynomial model
def cwm(X, Y, M, f_coeff, iters):
    """Fit a one-dimensional cluster-weighted model (CWM) with EM.

    Each of the ``M`` clusters has
      * a Gaussian input term N(mu_m, var_x_m),
      * a polynomial local model f_m(x) = sum_j beta_mj * x**j with
        ``f_coeff`` coefficients (1 = constant, 2 = linear, ...),
      * a Gaussian output term around f_m(x) with variance var_y_m,
    and a prior weight p(c_m). Equation numbers in the comments are the ones
    referenced by the notebook (16.24 - 16.37).

    Parameters
    ----------
    X, Y : array_like, shape (N,)
        Input and output samples.
    M : int
        Number of clusters.
    f_coeff : int
        Number of polynomial coefficients per local model.
    iters : int
        Number of EM iterations.

    Returns
    -------
    forecast : ndarray, shape (iters, N)
        Conditional forecast <y | x> at every sample, per iteration (16.27).
    forecast_error : ndarray, shape (iters, N)
        Conditional forecast *variance* <sigma_y^2 | x> (16.28); take the
        square root to get a standard deviation for error bars.
    cluster_prob : ndarray, shape (iters, M, N)
        Unnormalized cluster weights p(x | c_m) p(c_m), per iteration.

    The saved quantities for iteration ``it`` are computed from the
    parameters at the *start* of that iteration (before its M-step).
    The cluster means are initialized with ``np.random``, so seed the global
    RNG for reproducible results.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    N = len(X)

    # For the animations, save the following quantities per iteration:
    forecast = np.zeros((iters, N))         # conditional forecast (prediction of y given x)
    forecast_error = np.zeros((iters, N))   # conditional forecast variance
    cluster_prob = np.zeros((iters, M, N))  # cluster weights p(x | c_m) p(c_m)

    # Initialize the cluster means uniformly over the *input* range
    # (the previous max(Y) upper bound confined them to [min(X), max(Y)]).
    mu = np.random.uniform(X.min(), X.max(), M)

    # Give each cluster an equal prior probability.
    p_c = np.full(M, 1.0 / M)

    # Each input term is a 1-D Gaussian, so it has a single variance value.
    var_x = np.ones(M)

    # Each cluster also has its local-model coefficients beta.
    beta = np.ones((M, f_coeff))

    # Each output term also has a variance.
    var_y = np.ones(M)

    # Design matrix of monomials: powers[i, j] = X[i] ** j  (shape (N, f_coeff)).
    powers = X[:, None] ** np.arange(f_coeff)

    for it in range(iters):
        # P(x | c_m), equation (16.24); shape (N, M)
        p_x_c = np.exp(-(X[:, None] - mu) ** 2 / (2 * var_x)) / np.sqrt(2 * np.pi * var_x)

        # Joint input weight p(x | c_m) p(c_m); saved per cluster for plotting.
        joint_x = p_x_c * p_c
        cluster_prob[it] = joint_x.T

        # Local model predictions f_m(X[i]); shape (N, M)
        f_vals = powers @ beta.T

        # P(y | x, c_m), equation (16.26); shape (N, M)
        p_y_x_c = np.exp(-(Y[:, None] - f_vals) ** 2 / (2 * var_y)) / np.sqrt(2 * np.pi * var_y)

        # Conditional forecast (16.27) and forecast variance (16.28).
        denom = joint_x.sum(axis=1)
        forecast[it] = (f_vals * joint_x).sum(axis=1) / denom
        forecast_error[it] = ((var_y + f_vals ** 2) * joint_x).sum(axis=1) / denom - forecast[it] ** 2

        # P(c_m | y, x), equation (16.29): posterior responsibilities, rows sum to 1.
        # NOTE(review): if every cluster underflows for some point, its row sum is
        # 0 and the result becomes NaN; a log-domain E-step would avoid that.
        p_c_y_x = p_y_x_c * joint_x
        p_c_y_x /= p_c_y_x.sum(axis=1, keepdims=True)

        # Update P(c_m), equation (16.30).
        p_c = p_c_y_x.mean(axis=0)

        # Total responsibility of each cluster; normalizes the weighted averages below.
        weight = p_c_y_x.sum(axis=0)

        # Update the input variances, equation (16.32). Notice that we use the
        # old mus (deliberate in this notebook). The +0.1 is a small constant
        # that keeps the variances away from zero.
        var_x = (((X[:, None] - mu) ** 2) * p_c_y_x).sum(axis=0) / weight + 0.1

        # Update the means, equation (16.31).
        mu = (X[:, None] * p_c_y_x).sum(axis=0) / weight

        # Estimate the local model parameters: solve B beta_m = a, where
        # a_j = <y x^j>_m and B_jk = <x^j x^k>_m are cluster-weighted averages.
        for m in range(M):
            w = p_c_y_x[:, m]
            a = powers.T @ (w * Y) / weight[m]
            B = (powers * w[:, None]).T @ powers / weight[m]
            # solve() is more accurate than forming inv(B) explicitly.
            beta[m] = np.linalg.solve(B, a)

        # Update the output variances with the new betas, equation (16.37),
        # again with a small constant added.
        resid = Y[:, None] - powers @ beta.T
        var_y = (resid ** 2 * p_c_y_x).sum(axis=0) / weight + 0.1

    return forecast, forecast_error, cluster_prob

Constant local model with varying number of clusters

M = 1

In [148]:
# Snapshot of seaborn's active color cycle (palette=None means "current palette");
# the animation cells index into it to color the clusters.
current_palette = sns.color_palette(palette=None)
In [154]:
# One cluster, constant local model (f_coeff = 1)
forecast, forecast_error, cluster_prob = cwm(X, Y, 1, 1, iters=30)
# cwm returns the conditional forecast *variance* (16.28); error bars must be
# in the units of y, so draw +/- one standard deviation.
forecast_std = np.sqrt(forecast_error)

fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# original datapoints
ax.plot(X, Y, 'o')
ax.tick_params(labelbottom=False)  # x labels are shown on the lower panel
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
ax.set(title='CWM: M = 1 cluster, constant local model', ylabel='y')

# forecast with +/- 1 std error bars
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# cluster weights p(x | c_m) p(c_m)
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, 0.2)
ax2.set(xlabel='x', ylabel=r'$p(x \mid c_m)\, p(c_m)$')

cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[(m + 2) % len(current_palette)])
    cluster_lines.append(line)

plt.tight_layout()
plt.close()  # suppress the static figure; only the animation is displayed

def update_anim(i):
    """Redraw the figure for EM iteration ``i``."""
    # Replot the forecast data first
    plotline.set_data(X, forecast[i])

    # Update the error bars: one vertical segment [(x, lo), (x, hi)] per sample.
    # Credit to http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    lo = forecast[i] - forecast_std[i]
    hi = forecast[i] + forecast_std[i]
    barlinecols[0].set_segments([[(x, y_lo), (x, y_hi)] for x, y_lo, y_hi in zip(X, lo, hi)])

    # Update the cluster weights
    for line, weights in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, weights)

    return [plotline, barlinecols[0], *cluster_lines]

animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
Out[154]:

M = 2

In [155]:
# Two clusters, constant local model (f_coeff = 1)
forecast, forecast_error, cluster_prob = cwm(X, Y, 2, 1, iters=30)
# cwm returns the conditional forecast *variance* (16.28); error bars must be
# in the units of y, so draw +/- one standard deviation.
forecast_std = np.sqrt(forecast_error)

fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# original datapoints
ax.plot(X, Y, 'o')
ax.tick_params(labelbottom=False)  # x labels are shown on the lower panel
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
ax.set(title='CWM: M = 2 clusters, constant local model', ylabel='y')

# forecast with +/- 1 std error bars
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# cluster weights p(x | c_m) p(c_m)
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, 0.2)
ax2.set(xlabel='x', ylabel=r'$p(x \mid c_m)\, p(c_m)$')

cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[(m + 2) % len(current_palette)])
    cluster_lines.append(line)

plt.tight_layout()
plt.close()  # suppress the static figure; only the animation is displayed

def update_anim(i):
    """Redraw the figure for EM iteration ``i``."""
    # Replot the forecast data first
    plotline.set_data(X, forecast[i])

    # Update the error bars: one vertical segment [(x, lo), (x, hi)] per sample.
    # Credit to http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    lo = forecast[i] - forecast_std[i]
    hi = forecast[i] + forecast_std[i]
    barlinecols[0].set_segments([[(x, y_lo), (x, y_hi)] for x, y_lo, y_hi in zip(X, lo, hi)])

    # Update the cluster weights
    for line, weights in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, weights)

    return [plotline, barlinecols[0], *cluster_lines]

animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
Out[155]:

M = 3

In [156]:
# Three clusters, constant local model (f_coeff = 1)
forecast, forecast_error, cluster_prob = cwm(X, Y, 3, 1, iters=50)
# cwm returns the conditional forecast *variance* (16.28); error bars must be
# in the units of y, so draw +/- one standard deviation.
forecast_std = np.sqrt(forecast_error)

fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# original datapoints
ax.plot(X, Y, 'o')
ax.tick_params(labelbottom=False)  # x labels are shown on the lower panel
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
ax.set(title='CWM: M = 3 clusters, constant local model', ylabel='y')

# forecast with +/- 1 std error bars
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# cluster weights p(x | c_m) p(c_m)
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, 0.2)
ax2.set(xlabel='x', ylabel=r'$p(x \mid c_m)\, p(c_m)$')

cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[(m + 2) % len(current_palette)])
    cluster_lines.append(line)

plt.tight_layout()
plt.close()  # suppress the static figure; only the animation is displayed

def update_anim(i):
    """Redraw the figure for EM iteration ``i``."""
    # Replot the forecast data first
    plotline.set_data(X, forecast[i])

    # Update the error bars: one vertical segment [(x, lo), (x, hi)] per sample.
    # Credit to http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    lo = forecast[i] - forecast_std[i]
    hi = forecast[i] + forecast_std[i]
    barlinecols[0].set_segments([[(x, y_lo), (x, y_hi)] for x, y_lo, y_hi in zip(X, lo, hi)])

    # Update the cluster weights
    for line, weights in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, weights)

    return [plotline, barlinecols[0], *cluster_lines]

animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
Out[156]:

M = 8

In [164]:
# Eight clusters, constant local model (f_coeff = 1)
forecast, forecast_error, cluster_prob = cwm(X, Y, 8, 1, iters=80)
# cwm returns the conditional forecast *variance* (16.28); error bars must be
# in the units of y, so draw +/- one standard deviation.
forecast_std = np.sqrt(forecast_error)

fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# original datapoints
ax.plot(X, Y, 'o')
ax.tick_params(labelbottom=False)  # x labels are shown on the lower panel
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
ax.set(title='CWM: M = 8 clusters, constant local model', ylabel='y')

# forecast with +/- 1 std error bars
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# cluster weights p(x | c_m) p(c_m)
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, 0.2)
ax2.set(xlabel='x', ylabel=r'$p(x \mid c_m)\, p(c_m)$')

cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[(m + 2) % len(current_palette)])
    cluster_lines.append(line)

plt.tight_layout()
plt.close()  # suppress the static figure; only the animation is displayed

def update_anim(i):
    """Redraw the figure for EM iteration ``i``."""
    # Replot the forecast data first
    plotline.set_data(X, forecast[i])

    # Update the error bars: one vertical segment [(x, lo), (x, hi)] per sample.
    # Credit to http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    lo = forecast[i] - forecast_std[i]
    hi = forecast[i] + forecast_std[i]
    barlinecols[0].set_segments([[(x, y_lo), (x, y_hi)] for x, y_lo, y_hi in zip(X, lo, hi)])

    # Update the cluster weights
    for line, weights in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, weights)

    return [plotline, barlinecols[0], *cluster_lines]

animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
Out[164]:

Local linear model with varying number of clusters

M = 1

In [158]:
# One cluster, linear local model (f_coeff = 2)
forecast, forecast_error, cluster_prob = cwm(X, Y, 1, 2, iters=30)
# cwm returns the conditional forecast *variance* (16.28); error bars must be
# in the units of y, so draw +/- one standard deviation.
forecast_std = np.sqrt(forecast_error)

fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# original datapoints
ax.plot(X, Y, 'o')
ax.tick_params(labelbottom=False)  # x labels are shown on the lower panel
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
ax.set(title='CWM: M = 1 cluster, linear local model', ylabel='y')

# forecast with +/- 1 std error bars
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# cluster weights p(x | c_m) p(c_m)
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, 0.2)
ax2.set(xlabel='x', ylabel=r'$p(x \mid c_m)\, p(c_m)$')

cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[(m + 2) % len(current_palette)])
    cluster_lines.append(line)

plt.tight_layout()
plt.close()  # suppress the static figure; only the animation is displayed

def update_anim(i):
    """Redraw the figure for EM iteration ``i``."""
    # Replot the forecast data first
    plotline.set_data(X, forecast[i])

    # Update the error bars: one vertical segment [(x, lo), (x, hi)] per sample.
    # Credit to http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    lo = forecast[i] - forecast_std[i]
    hi = forecast[i] + forecast_std[i]
    barlinecols[0].set_segments([[(x, y_lo), (x, y_hi)] for x, y_lo, y_hi in zip(X, lo, hi)])

    # Update the cluster weights
    for line, weights in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, weights)

    return [plotline, barlinecols[0], *cluster_lines]

animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
Out[158]:

M = 2

In [160]:
# Two clusters, linear local model (f_coeff = 2)
forecast, forecast_error, cluster_prob = cwm(X, Y, 2, 2, iters=50)
# cwm returns the conditional forecast *variance* (16.28); error bars must be
# in the units of y, so draw +/- one standard deviation.
forecast_std = np.sqrt(forecast_error)

fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# original datapoints
ax.plot(X, Y, 'o')
ax.tick_params(labelbottom=False)  # x labels are shown on the lower panel
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
ax.set(title='CWM: M = 2 clusters, linear local model', ylabel='y')

# forecast with +/- 1 std error bars
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# cluster weights p(x | c_m) p(c_m)
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, 0.2)
ax2.set(xlabel='x', ylabel=r'$p(x \mid c_m)\, p(c_m)$')

cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[(m + 2) % len(current_palette)])
    cluster_lines.append(line)

plt.tight_layout()
plt.close()  # suppress the static figure; only the animation is displayed

def update_anim(i):
    """Redraw the figure for EM iteration ``i``."""
    # Replot the forecast data first
    plotline.set_data(X, forecast[i])

    # Update the error bars: one vertical segment [(x, lo), (x, hi)] per sample.
    # Credit to http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    lo = forecast[i] - forecast_std[i]
    hi = forecast[i] + forecast_std[i]
    barlinecols[0].set_segments([[(x, y_lo), (x, y_hi)] for x, y_lo, y_hi in zip(X, lo, hi)])

    # Update the cluster weights
    for line, weights in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, weights)

    return [plotline, barlinecols[0], *cluster_lines]

animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
Out[160]:

M = 3

In [161]:
# Three clusters, linear local model (f_coeff = 2)
forecast, forecast_error, cluster_prob = cwm(X, Y, 3, 2, iters=50)
# cwm returns the conditional forecast *variance* (16.28); error bars must be
# in the units of y, so draw +/- one standard deviation.
forecast_std = np.sqrt(forecast_error)

fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# original datapoints
ax.plot(X, Y, 'o')
ax.tick_params(labelbottom=False)  # x labels are shown on the lower panel
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
ax.set(title='CWM: M = 3 clusters, linear local model', ylabel='y')

# forecast with +/- 1 std error bars
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# cluster weights p(x | c_m) p(c_m)
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, 0.2)
ax2.set(xlabel='x', ylabel=r'$p(x \mid c_m)\, p(c_m)$')

cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[(m + 2) % len(current_palette)])
    cluster_lines.append(line)

plt.tight_layout()
plt.close()  # suppress the static figure; only the animation is displayed

def update_anim(i):
    """Redraw the figure for EM iteration ``i``."""
    # Replot the forecast data first
    plotline.set_data(X, forecast[i])

    # Update the error bars: one vertical segment [(x, lo), (x, hi)] per sample.
    # Credit to http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    lo = forecast[i] - forecast_std[i]
    hi = forecast[i] + forecast_std[i]
    barlinecols[0].set_segments([[(x, y_lo), (x, y_hi)] for x, y_lo, y_hi in zip(X, lo, hi)])

    # Update the cluster weights
    for line, weights in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, weights)

    return [plotline, barlinecols[0], *cluster_lines]

animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
Out[161]:

M = 8

In [163]:
# Eight clusters, linear local model (f_coeff = 2)
forecast, forecast_error, cluster_prob = cwm(X, Y, 8, 2, iters=80)
# cwm returns the conditional forecast *variance* (16.28); error bars must be
# in the units of y, so draw +/- one standard deviation.
forecast_std = np.sqrt(forecast_error)

fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# original datapoints
ax.plot(X, Y, 'o')
ax.tick_params(labelbottom=False)  # x labels are shown on the lower panel
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
ax.set(title='CWM: M = 8 clusters, linear local model', ylabel='y')

# forecast with +/- 1 std error bars
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# cluster weights p(x | c_m) p(c_m)
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, 0.2)
ax2.set(xlabel='x', ylabel=r'$p(x \mid c_m)\, p(c_m)$')

cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[(m + 2) % len(current_palette)])
    cluster_lines.append(line)

plt.tight_layout()
plt.close()  # suppress the static figure; only the animation is displayed

def update_anim(i):
    """Redraw the figure for EM iteration ``i``."""
    # Replot the forecast data first
    plotline.set_data(X, forecast[i])

    # Update the error bars: one vertical segment [(x, lo), (x, hi)] per sample.
    # Credit to http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    lo = forecast[i] - forecast_std[i]
    hi = forecast[i] + forecast_std[i]
    barlinecols[0].set_segments([[(x, y_lo), (x, y_hi)] for x, y_lo, y_hi in zip(X, lo, hi)])

    # Update the cluster weights
    for line, weights in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, weights)

    return [plotline, barlinecols[0], *cluster_lines]

animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
Out[163]:
In [ ]: