"""Minimise the Rosenbrock function with a hand-written Nelder-Mead simplex.

Run as a script to print the minimiser found from (-1, -1) and to plot the
path of the best vertex over a filled-contour view of the function.
Written to run on both Python 2.7 and Python 3.
"""
import numpy as np
# NOTE: the wildcard import is kept so code relying on this module's namespace
# still works, but it shadows builtins such as abs/all/sum; the code below
# therefore uses explicit ``np.`` names.
from numpy import *

# Number of points at which ``f`` has been evaluated. A vectorised call counts
# every element, so the figure reported is a true evaluation count.
func_eval = 0


def f(x, y):
    """Evaluate the Rosenbrock function ``(1 - x)**2 + 100*(y - x**2)**2``.

    ``x`` and ``y`` may be scalars or broadcast-compatible NumPy arrays; the
    result is computed element-wise. Side effect: adds the number of points
    evaluated to the module-level counter ``func_eval``.
    """
    global func_eval
    func_eval += np.broadcast(x, y).size
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


def _sort_simplex(simplex, values):
    """Return ``(simplex, values)`` reordered so vertex 0 has the lowest value."""
    order = np.argsort(values, kind='mergesort')  # stable: ties keep their order
    return simplex[order], values[order]


def nelder_mead(func, start, offsets=((0.05, 0.1), (0.2, 0.1)), tol=1e-6,
                max_iter=10000):
    """Minimise a two-variable function with a Nelder-Mead simplex search.

    Parameters
    ----------
    func : callable
        ``func(x, y)``; must accept NumPy arrays and evaluate element-wise,
        because the initial simplex and shrink steps score several vertices
        in one call.
    start : sequence of two floats
        First vertex of the initial simplex.
    offsets : pair of (dx, dy)
        Displacements from ``start`` giving the other two initial vertices.
    tol : float
        Stop once ``abs(func(best)) < tol``. This target-value test assumes
        the minimum value is 0, as it is for the Rosenbrock function.
    max_iter : int
        Upper bound on iterations, so a search that never meets ``tol``
        still terminates.

    Returns
    -------
    tuple
        ``(x_min, iterations, path, path_values, converged)`` where
        ``x_min`` is the best vertex found, ``path`` is an
        ``(iterations, 2)`` array of the best vertex after each iteration,
        ``path_values`` holds ``func`` at those points, and ``converged`` is
        True when the ``tol`` criterion was met.
    """
    start = np.asarray(start, dtype=float)
    simplex = np.vstack([start, start + np.asarray(offsets, dtype=float)])
    # Vertex values are cached and updated incrementally, so each iteration
    # evaluates only the new trial points.
    values = np.asarray(func(simplex[:, 0], simplex[:, 1]), dtype=float)
    simplex, values = _sort_simplex(simplex, values)

    path, path_values = [], []
    converged = False
    for _ in range(max_iter):
        worst = simplex[2].copy()
        second_value = values[1]
        centroid = (simplex[0] + simplex[1]) / 2.0
        candidate = None

        # Reflect the worst vertex through the centroid of the other two.
        reflected = 2.0 * centroid - worst
        f_reflected = func(reflected[0], reflected[1])
        if f_reflected < values[0]:
            # New best point: also try going twice as far in that direction.
            expanded = 3.0 * centroid - 2.0 * worst
            f_expanded = func(expanded[0], expanded[1])
            if f_expanded < f_reflected:
                candidate, f_candidate = expanded, f_expanded
            else:
                candidate, f_candidate = reflected, f_reflected
        elif f_reflected < second_value:
            candidate, f_candidate = reflected, f_reflected
        else:
            # Reflection did not beat the second-worst vertex: try an outside
            # contraction, then an inside one; each is evaluated only if needed.
            for point in (1.5 * centroid - 0.5 * worst,
                          0.5 * (centroid + worst)):
                f_point = func(point[0], point[1])
                if f_point < second_value:
                    candidate, f_candidate = point, f_point
                    break

        if candidate is None:
            # No trial point helped: shrink the other vertices towards the best.
            simplex[1:] = 0.5 * (simplex[1:] + simplex[0])
            values[1:] = func(simplex[1:, 0], simplex[1:, 1])
        else:
            simplex[2] = candidate
            values[2] = f_candidate

        # Re-sort every iteration so "best" and "worst" are always correct,
        # including after a shrink.
        simplex, values = _sort_simplex(simplex, values)
        path.append(simplex[0].copy())
        path_values.append(values[0])
        if abs(values[0]) < tol:
            converged = True
            break

    return (simplex[0].copy(), len(path),
            np.array(path, dtype=float).reshape(-1, 2),
            np.array(path_values, dtype=float), converged)


def plot_path(path, path_values, lim=3.0, step=0.1):
    """Show the search path over a filled-contour plot of ``f``.

    path        : (n, 2) array of best-vertex positions, as returned by
                  ``nelder_mead``.
    path_values : length-n sequence of ``f`` at those positions.
    lim, step   : the surface is sampled on ``np.arange(-lim, lim, step)`` in
                  both x and y (the defaults give the original -3..3 grid).

    Matplotlib is imported here rather than at module level so the optimiser
    can be imported and used without a GUI backend.
    """
    import matplotlib
    matplotlib.use('TkAgg')
    from matplotlib import pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old matplotlib)

    grid = np.arange(-lim, lim, step)
    xx, yy = np.meshgrid(grid, grid)
    z = f(xx, yy)

    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure on matplotlib >= 3.6.
    ax = fig.add_subplot(111, projection='3d')
    ax.set_zlim(0, np.amax(z))
    ax.plot(path[:, 0], path[:, 1], path_values, color='red', lw=1.5)
    ax.contourf(xx, yy, z, 200)
    plt.show()


def main():
    """Run the search from (-1, -1), report the result and plot the path."""
    x_min, iterations, path, path_values, converged = nelder_mead(f, [-1.0, -1.0])
    if converged:
        print("Done!")
    else:
        print("Stopped after " + str(iterations) + " iterations without converging")
    print("Minimum @ " + str(x_min))
    print("Iterations: " + str(iterations))
    print("Function Evaluations: " + str(func_eval))
    plot_path(path, path_values)


if __name__ == "__main__":
    main()