import numpy as np
from matplotlib import rc
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from mpl_toolkits import mplot3d
import matplotlib.cm as cm
# Render matplotlib animations as HTML5 video when displayed, and use
# seaborn's white-grid theme with a 10x8 default figure size.
rc('animation', html='html5')
sns.set(rc={'figure.figsize': (10, 8)}, style='whitegrid')

# Toy regression problem: N noisy observations of y = tanh(x), x ~ U(-5, 5).
N = 100          # number of samples
sigma = 0.25     # standard deviation of the additive Gaussian noise on y
X = np.random.uniform(low=-5, high=5, size=N)
noise = np.random.normal(loc=0, scale=sigma, size=N)
Y = noise + np.tanh(X)

# Quick look at the raw data.
plt.plot(X, Y, 'o')
plt.show()
def _gaussian_pdf(x, mean, var):
    """Return the 1-D normal density N(x; mean, var), broadcasting over all arguments."""
    return np.exp(-(x - mean) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)


def cwm(X, Y, M, f_coeff, iters, var_floor=0.1):
    """Fit a 1-D cluster-weighted model with EM and record its progress.

    Each of the M clusters has a Gaussian input term p(x | c_m) with mean
    mu_m and variance var_x[m], a prior p(c_m), and a polynomial local model
    f_m(x) = sum_j beta[m, j] * x**j with Gaussian output variance var_y[m].

    Parameters
    ----------
    X, Y : 1-D array-likes of equal length N
        Inputs and targets.
    M : int
        Number of clusters.
    f_coeff : int
        Number of polynomial coefficients per local model
        (1 = constant, 2 = linear, ...).
    iters : int
        Number of EM iterations.
    var_floor : float, optional
        Constant added to every updated input/output variance to keep the
        Gaussians from collapsing (default 0.1).

    Returns
    -------
    forecast : ndarray, shape (iters, N)
        Conditional forecast <y | x> at each X (eq. 16.27), computed with the
        parameters in effect at the start of each iteration.
    forecast_error : ndarray, shape (iters, N)
        Conditional forecast variance <sigma^2 | x> (eq. 16.28). This is a
        variance, not a standard deviation.
    cluster_prob : ndarray, shape (iters, M, N)
        p(x | c_m) p(c_m) for every cluster and sample.

    Raises
    ------
    numpy.linalg.LinAlgError
        If a cluster's weighted moment matrix B is singular.

    Cluster means are drawn with ``np.random.uniform``, so results depend on
    the global NumPy random state.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)
    N = len(X)

    # History buffers for later animation, one row per iteration.
    forecast = np.zeros((iters, N))
    forecast_error = np.zeros((iters, N))
    cluster_prob = np.zeros((iters, M, N))

    # Initialise cluster means uniformly over the *input* range.
    # (The upper bound used to be max(Y), which mixed input and output scales.)
    mu = np.random.uniform(X.min(), X.max(), M)
    p_c = np.full(M, 1.0 / M)          # equal prior for every cluster
    var_x = np.ones(M)                 # input variance per cluster
    beta = np.ones((M, f_coeff))       # local-model coefficients per cluster
    var_y = np.ones(M)                 # output variance per cluster

    # powers[i, j] = X[i] ** j, so powers @ beta.T evaluates every f_m at every X.
    powers = X[:, None] ** np.arange(f_coeff)

    for it in range(iters):
        # E-step quantities, all shaped (N, M).
        p_x_c = _gaussian_pdf(X[:, None], mu, var_x)        # P(x | c_m), eq. (16.24)
        joint_x = p_x_c * p_c                               # p(x | c_m) p(c_m)
        cluster_prob[it] = joint_x.T
        f_x = powers @ beta.T                               # f_m(x_i)
        p_y_x_c = _gaussian_pdf(Y[:, None], f_x, var_y)     # P(y | x, c_m), eq. (16.26)

        # Conditional forecast (16.27) and forecast variance (16.28).
        denom = joint_x.sum(axis=1)
        forecast[it] = (f_x * joint_x).sum(axis=1) / denom
        forecast_error[it] = ((var_y + f_x ** 2) * joint_x).sum(axis=1) / denom - forecast[it] ** 2

        # Posterior P(c_m | y, x), eq. (16.29), normalised over clusters.
        p_c_y_x = p_y_x_c * joint_x
        p_c_y_x /= p_c_y_x.sum(axis=1, keepdims=True)
        weight = p_c_y_x.sum(axis=0)                        # total responsibility per cluster

        # M-step: cluster priors, eq. (16.30).
        p_c = weight / N
        # Input variances, eq. (16.32), deliberately use the means from
        # *before* this iteration's mean update.
        var_x = ((X[:, None] - mu) ** 2 * p_c_y_x).sum(axis=0) / weight + var_floor
        # Input means, eq. (16.31).
        mu = (X @ p_c_y_x) / weight

        # Local models: solve B beta_m = a using cluster-weighted moments.
        for m in range(M):
            w = p_c_y_x[:, m]
            a = powers.T @ (Y * w) / weight[m]
            B = (powers.T * w) @ powers / weight[m]
            beta[m] = np.linalg.solve(B, a)

        # Output variances, eq. (16.37), using the freshly fitted local models.
        residual = Y[:, None] - powers @ beta.T
        var_y = (residual ** 2 * p_c_y_x).sum(axis=0) / weight + var_floor

    return forecast, forecast_error, cluster_prob
current_palette = sns.color_palette()

# One cluster, constant local model (f_coeff=1)
forecast, forecast_error, cluster_prob = cwm(X, Y, 1, 1, iters=30)
fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# Top panel: original data points and the conditional forecast.
ax.plot(X, Y, 'o')
ax.set_xticklabels([])
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
# forecast_error is a variance (eq. 16.28); draw +/- one standard deviation so
# the error bars share the units of y.
forecast_std = np.sqrt(forecast_error)
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# Bottom panel: p(x | c_m) p(c_m) per cluster. Keep the original 0.2 limit,
# but extend it when some iteration exceeds it so no frame is clipped.
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, max(0.2, 1.05 * np.nanmax(cluster_prob)))
cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[m + 2])
    cluster_lines.append(line)
plt.tight_layout()
plt.close()


def update_anim(i, X=X, forecast=forecast, forecast_std=forecast_std,
                cluster_prob=cluster_prob, plotline=plotline,
                barlines=barlinecols[0], cluster_lines=cluster_lines):
    """Redraw the figure for frame ``i`` (EM iteration ``i``).

    Data arrays and artists are bound as default arguments when the function
    is defined, so this callback keeps drawing this experiment's figure even
    if the module-level names are rebound afterwards.
    """
    centre = forecast[i]
    plotline.set_data(X, centre)
    # Error-bar update idea credited to
    # http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    # One vertical segment per sample: (x, centre - std) -> (x, centre + std).
    lower = np.column_stack([X, centre - forecast_std[i]])
    upper = np.column_stack([X, centre + forecast_std[i]])
    barlines.set_segments(np.stack([lower, upper], axis=1))
    for line, probs in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, probs)


animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
# Two clusters, constant local model (f_coeff=1)
forecast, forecast_error, cluster_prob = cwm(X, Y, 2, 1, iters=30)
fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# Top panel: original data points and the conditional forecast.
ax.plot(X, Y, 'o')
ax.set_xticklabels([])
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
# forecast_error is a variance (eq. 16.28); draw +/- one standard deviation so
# the error bars share the units of y.
forecast_std = np.sqrt(forecast_error)
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# Bottom panel: p(x | c_m) p(c_m) per cluster. Keep the original 0.2 limit,
# but extend it when some iteration exceeds it so no frame is clipped.
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, max(0.2, 1.05 * np.nanmax(cluster_prob)))
cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[m + 2])
    cluster_lines.append(line)
plt.tight_layout()
plt.close()


def update_anim(i, X=X, forecast=forecast, forecast_std=forecast_std,
                cluster_prob=cluster_prob, plotline=plotline,
                barlines=barlinecols[0], cluster_lines=cluster_lines):
    """Redraw the figure for frame ``i`` (EM iteration ``i``).

    Data arrays and artists are bound as default arguments when the function
    is defined, so this callback keeps drawing this experiment's figure even
    if the module-level names are rebound afterwards.
    """
    centre = forecast[i]
    plotline.set_data(X, centre)
    # Error-bar update idea credited to
    # http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    # One vertical segment per sample: (x, centre - std) -> (x, centre + std).
    lower = np.column_stack([X, centre - forecast_std[i]])
    upper = np.column_stack([X, centre + forecast_std[i]])
    barlines.set_segments(np.stack([lower, upper], axis=1))
    for line, probs in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, probs)


animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
# Three clusters, constant local model (f_coeff=1)
forecast, forecast_error, cluster_prob = cwm(X, Y, 3, 1, iters=50)
fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# Top panel: original data points and the conditional forecast.
ax.plot(X, Y, 'o')
ax.set_xticklabels([])
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
# forecast_error is a variance (eq. 16.28); draw +/- one standard deviation so
# the error bars share the units of y.
forecast_std = np.sqrt(forecast_error)
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# Bottom panel: p(x | c_m) p(c_m) per cluster. Keep the original 0.2 limit,
# but extend it when some iteration exceeds it so no frame is clipped.
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, max(0.2, 1.05 * np.nanmax(cluster_prob)))
cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[m + 2])
    cluster_lines.append(line)
plt.tight_layout()
plt.close()


def update_anim(i, X=X, forecast=forecast, forecast_std=forecast_std,
                cluster_prob=cluster_prob, plotline=plotline,
                barlines=barlinecols[0], cluster_lines=cluster_lines):
    """Redraw the figure for frame ``i`` (EM iteration ``i``).

    Data arrays and artists are bound as default arguments when the function
    is defined, so this callback keeps drawing this experiment's figure even
    if the module-level names are rebound afterwards.
    """
    centre = forecast[i]
    plotline.set_data(X, centre)
    # Error-bar update idea credited to
    # http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    # One vertical segment per sample: (x, centre - std) -> (x, centre + std).
    lower = np.column_stack([X, centre - forecast_std[i]])
    upper = np.column_stack([X, centre + forecast_std[i]])
    barlines.set_segments(np.stack([lower, upper], axis=1))
    for line, probs in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, probs)


animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
# Eight clusters, constant local model (f_coeff=1)
forecast, forecast_error, cluster_prob = cwm(X, Y, 8, 1, iters=80)
fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# Top panel: original data points and the conditional forecast.
ax.plot(X, Y, 'o')
ax.set_xticklabels([])
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
# forecast_error is a variance (eq. 16.28); draw +/- one standard deviation so
# the error bars share the units of y.
forecast_std = np.sqrt(forecast_error)
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# Bottom panel: p(x | c_m) p(c_m) per cluster. Keep the original 0.2 limit,
# but extend it when some iteration exceeds it so no frame is clipped.
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, max(0.2, 1.05 * np.nanmax(cluster_prob)))
cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[m + 2])
    cluster_lines.append(line)
plt.tight_layout()
plt.close()


def update_anim(i, X=X, forecast=forecast, forecast_std=forecast_std,
                cluster_prob=cluster_prob, plotline=plotline,
                barlines=barlinecols[0], cluster_lines=cluster_lines):
    """Redraw the figure for frame ``i`` (EM iteration ``i``).

    Data arrays and artists are bound as default arguments when the function
    is defined, so this callback keeps drawing this experiment's figure even
    if the module-level names are rebound afterwards.
    """
    centre = forecast[i]
    plotline.set_data(X, centre)
    # Error-bar update idea credited to
    # http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    # One vertical segment per sample: (x, centre - std) -> (x, centre + std).
    lower = np.column_stack([X, centre - forecast_std[i]])
    upper = np.column_stack([X, centre + forecast_std[i]])
    barlines.set_segments(np.stack([lower, upper], axis=1))
    for line, probs in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, probs)


animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
# One cluster, linear local model (f_coeff=2)
forecast, forecast_error, cluster_prob = cwm(X, Y, 1, 2, iters=30)
fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# Top panel: original data points and the conditional forecast.
ax.plot(X, Y, 'o')
ax.set_xticklabels([])
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
# forecast_error is a variance (eq. 16.28); draw +/- one standard deviation so
# the error bars share the units of y.
forecast_std = np.sqrt(forecast_error)
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# Bottom panel: p(x | c_m) p(c_m) per cluster. Keep the original 0.2 limit,
# but extend it when some iteration exceeds it so no frame is clipped.
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, max(0.2, 1.05 * np.nanmax(cluster_prob)))
cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[m + 2])
    cluster_lines.append(line)
plt.tight_layout()
plt.close()


def update_anim(i, X=X, forecast=forecast, forecast_std=forecast_std,
                cluster_prob=cluster_prob, plotline=plotline,
                barlines=barlinecols[0], cluster_lines=cluster_lines):
    """Redraw the figure for frame ``i`` (EM iteration ``i``).

    Data arrays and artists are bound as default arguments when the function
    is defined, so this callback keeps drawing this experiment's figure even
    if the module-level names are rebound afterwards.
    """
    centre = forecast[i]
    plotline.set_data(X, centre)
    # Error-bar update idea credited to
    # http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    # One vertical segment per sample: (x, centre - std) -> (x, centre + std).
    lower = np.column_stack([X, centre - forecast_std[i]])
    upper = np.column_stack([X, centre + forecast_std[i]])
    barlines.set_segments(np.stack([lower, upper], axis=1))
    for line, probs in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, probs)


animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
# Two clusters, linear local model (f_coeff=2)
forecast, forecast_error, cluster_prob = cwm(X, Y, 2, 2, iters=50)
fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# Top panel: original data points and the conditional forecast.
ax.plot(X, Y, 'o')
ax.set_xticklabels([])
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
# forecast_error is a variance (eq. 16.28); draw +/- one standard deviation so
# the error bars share the units of y.
forecast_std = np.sqrt(forecast_error)
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# Bottom panel: p(x | c_m) p(c_m) per cluster. Keep the original 0.2 limit,
# but extend it when some iteration exceeds it so no frame is clipped.
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, max(0.2, 1.05 * np.nanmax(cluster_prob)))
cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[m + 2])
    cluster_lines.append(line)
plt.tight_layout()
plt.close()


def update_anim(i, X=X, forecast=forecast, forecast_std=forecast_std,
                cluster_prob=cluster_prob, plotline=plotline,
                barlines=barlinecols[0], cluster_lines=cluster_lines):
    """Redraw the figure for frame ``i`` (EM iteration ``i``).

    Data arrays and artists are bound as default arguments when the function
    is defined, so this callback keeps drawing this experiment's figure even
    if the module-level names are rebound afterwards.
    """
    centre = forecast[i]
    plotline.set_data(X, centre)
    # Error-bar update idea credited to
    # http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    # One vertical segment per sample: (x, centre - std) -> (x, centre + std).
    lower = np.column_stack([X, centre - forecast_std[i]])
    upper = np.column_stack([X, centre + forecast_std[i]])
    barlines.set_segments(np.stack([lower, upper], axis=1))
    for line, probs in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, probs)


animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
# Three clusters, linear local model (f_coeff=2)
forecast, forecast_error, cluster_prob = cwm(X, Y, 3, 2, iters=50)
fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# Top panel: original data points and the conditional forecast.
ax.plot(X, Y, 'o')
ax.set_xticklabels([])
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
# forecast_error is a variance (eq. 16.28); draw +/- one standard deviation so
# the error bars share the units of y.
forecast_std = np.sqrt(forecast_error)
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# Bottom panel: p(x | c_m) p(c_m) per cluster. Keep the original 0.2 limit,
# but extend it when some iteration exceeds it so no frame is clipped.
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, max(0.2, 1.05 * np.nanmax(cluster_prob)))
cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[m + 2])
    cluster_lines.append(line)
plt.tight_layout()
plt.close()


def update_anim(i, X=X, forecast=forecast, forecast_std=forecast_std,
                cluster_prob=cluster_prob, plotline=plotline,
                barlines=barlinecols[0], cluster_lines=cluster_lines):
    """Redraw the figure for frame ``i`` (EM iteration ``i``).

    Data arrays and artists are bound as default arguments when the function
    is defined, so this callback keeps drawing this experiment's figure even
    if the module-level names are rebound afterwards.
    """
    centre = forecast[i]
    plotline.set_data(X, centre)
    # Error-bar update idea credited to
    # http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    # One vertical segment per sample: (x, centre - std) -> (x, centre + std).
    lower = np.column_stack([X, centre - forecast_std[i]])
    upper = np.column_stack([X, centre + forecast_std[i]])
    barlines.set_segments(np.stack([lower, upper], axis=1))
    for line, probs in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, probs)


animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)
# Eight clusters, linear local model (f_coeff=2)
forecast, forecast_error, cluster_prob = cwm(X, Y, 8, 2, iters=80)
fig, (ax, ax2) = plt.subplots(2, 1, gridspec_kw={'height_ratios': [2, 1]}, figsize=(9, 9))

# Top panel: original data points and the conditional forecast.
ax.plot(X, Y, 'o')
ax.set_xticklabels([])
ax.set_xlim(min(X), max(X))
ax.set_ylim(-1.75, 1.75)
# forecast_error is a variance (eq. 16.28); draw +/- one standard deviation so
# the error bars share the units of y.
forecast_std = np.sqrt(forecast_error)
plotline, caplines, barlinecols = ax.errorbar(X, forecast[0], yerr=forecast_std[0], ls='none')

# Bottom panel: p(x | c_m) p(c_m) per cluster. Keep the original 0.2 limit,
# but extend it when some iteration exceeds it so no frame is clipped.
ax2.set_xlim(min(X), max(X))
ax2.set_ylim(0, max(0.2, 1.05 * np.nanmax(cluster_prob)))
cluster_lines = []
for m in range(cluster_prob.shape[1]):
    line, = ax2.plot(X, cluster_prob[0, m], 'o', color=current_palette[m + 2])
    cluster_lines.append(line)
plt.tight_layout()
plt.close()


def update_anim(i, X=X, forecast=forecast, forecast_std=forecast_std,
                cluster_prob=cluster_prob, plotline=plotline,
                barlines=barlinecols[0], cluster_lines=cluster_lines):
    """Redraw the figure for frame ``i`` (EM iteration ``i``).

    Data arrays and artists are bound as default arguments when the function
    is defined, so this callback keeps drawing this experiment's figure even
    if the module-level names are rebound afterwards.
    """
    centre = forecast[i]
    plotline.set_data(X, centre)
    # Error-bar update idea credited to
    # http://matplotlib.1069221.n5.nabble.com/Update-values-of-errorbars-td18337.html
    # One vertical segment per sample: (x, centre - std) -> (x, centre + std).
    lower = np.column_stack([X, centre - forecast_std[i]])
    upper = np.column_stack([X, centre + forecast_std[i]])
    barlines.set_segments(np.stack([lower, upper], axis=1))
    for line, probs in zip(cluster_lines, cluster_prob[i]):
        line.set_data(X, probs)


animation.FuncAnimation(fig, update_anim, range(len(forecast)), blit=False, interval=500, repeat=True)