"""Downhill-simplex (Nelder-Mead) minimisation of the Rosenbrock function.

Running the script draws the Rosenbrock surface as a 3D wireframe, starts a
random three-vertex simplex, iterates ``simplex_step`` until the spread
between the worst and the best vertex value drops below ``tol`` and finally
overlays the path taken by the simplex on the plot.

The code is written so that it runs unchanged on Python 2 and Python 3.
"""
import math
import random

import matplotlib.pyplot as plt
# Imported for its side effect: registers the '3d' projection on older
# Matplotlib releases.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import numpy as np

fig = plt.figure()
# ``fig.gca(projection='3d')`` was removed in Matplotlib 3.6; ``add_subplot``
# is accepted by both old and new releases.
ax = fig.add_subplot(111, projection='3d')


def rosenbrock(x1, x2):
    """Return the Rosenbrock function ``(1 - x1)**2 + 100*(x2 - x1**2)**2``.

    Only arithmetic operators are used, so scalars and NumPy arrays
    (element-wise) are both accepted.
    """
    return (1 - x1)**2 + 100*(x2 - x1**2)**2


x1 = np.arange(-3, 3, .3)
x2 = np.arange(-4, 8, .3)
x1, x2 = np.meshgrid(x1, x2)
x3 = rosenbrock(x1, x2)
ax.plot_wireframe(x1, x2, x3, rstride=1, cstride=1, linewidth=1,
                  antialiased=True)


def get_mean(verts):
    """Return the centroid (mean of the ``xy`` arrays) of ``verts``.

    ``verts`` must be a non-empty sequence of objects with an ``xy`` array.
    """
    return np.array(sum([vert.xy for vert in verts])/len(verts))


# dh_simplex
class vertex:
    """One simplex corner: a 2D position plus its Rosenbrock value."""

    def __init__(self):
        # Random start inside the square [2, 3) x [-4, -3).
        self.xy = np.array([random.random()+2, random.random()-4])

    def getval(self):
        """Return the Rosenbrock value at this vertex's position."""
        return rosenbrock(*self.xy)

    def __repr__(self):
        return " vertex(%f,%f,%f) " % (self.xy[0], self.xy[1], self.getval())


def _vertex_at(xy):
    """Return a new vertex placed at position ``xy``."""
    vert = vertex()
    vert.xy = xy
    return vert


def simplex_step():
    """Perform one Nelder-Mead move on the module-level ``verts`` list.

    ``verts`` must be sorted from worst (highest value, index 0) to best
    (lowest value, index -1) and hold at least two vertices.  The worst
    vertex is replaced by a reflected, stretched or contracted point; if none
    of those helps, every vertex is pulled halfway towards the best one.

    Returns the (mutated in place) ``verts`` list.
    """
    worst = verts[0].xy
    best_val = verts[-1].getval()
    second_worst_val = verts[1].getval()
    # Centroid of every vertex EXCEPT the worst one: the trial points below
    # are all of the form mean + t*(mean - worst).
    mean = get_mean(verts[1:])

    # Reflection of the worst vertex through the centroid.
    reflected = _vertex_at(2*mean - worst)
    if reflected.getval() < best_val:
        # Reflection beat the best vertex: try going twice as far.
        stretched = _vertex_at(3*mean - 2*worst)
        if stretched.getval() < reflected.getval():
            verts[0] = stretched
            print("return from streched reflection")
        else:
            verts[0] = reflected
            print("return from reflection")
        return verts

    # Outside contraction: halfway between the centroid and the reflected
    # point.  The weights must sum to 1 to stay on the worst-centroid line.
    contracted = _vertex_at(1.5*mean - .5*worst)
    if contracted.getval() <= second_worst_val:
        verts[0] = contracted
        print("return from shrinking + reflection")
        return verts

    # Inside contraction: halfway between the worst vertex and the centroid.
    shrunk = _vertex_at(.5*(mean + worst))
    if shrunk.getval() < second_worst_val:
        verts[0] = shrunk
        print("return from shrinking")
        return verts

    # Nothing helped; apparently we are very close to a minimum, so shrink
    # the whole simplex towards the best vertex.
    best = verts[-1].xy
    for vert in verts:
        vert.xy = .5*(vert.xy + best)
    print("return from shrinking towards local minima")
    return verts


verts = [vertex() for _ in range(3)]
print("intital search position %s" % (verts,))
tol = .001
# Safety net: a stagnating simplex must not be able to loop forever.
max_iter = 10000
oldval = float('inf')
trace = [[vert.xy[0], vert.xy[1], vert.getval()] for vert in verts]
iterations = 0
while oldval > verts[-1].getval() + tol and iterations < max_iter:
    # Sort worst to best, as simplex_step expects.
    verts = sorted(verts, key=lambda vert: -vert.getval())
    oldval = verts[0].getval()
    simplex_step()
    for vert in verts:
        trace.append([vert.xy[0], vert.xy[1], vert.getval()])
    iterations += 1
print("Done!")

# Transpose the list of [x, y, z] rows into three coordinate sequences;
# list() is needed because zip() is lazy on Python 3.
trace = list(zip(*trace))
ax.plot(trace[0], trace[1], trace[2], c="black")
ax.scatter(trace[0], trace[1], c="black")
ax.set_zlim3d(0, 15000)
plt.show()