"""Minimise the Rosenbrock function with the Nelder-Mead downhill simplex method.

Starting from a small triangle near (x, y) = (-1, -1), the simplex is moved by
reflection, a doubled step, contraction ("shrink") and whole-simplex shrinkage
(Eqns 15.2-15.6 referenced below) until its area falls below ``min_area``.
When run as a script, every intermediate simplex is drawn over a translucent
3-D plot of the Rosenbrock surface.

Requires Python 3. The numerical part needs only NumPy; Matplotlib is imported
lazily by the plotting code.
"""
import warnings

import numpy as np

# Choose stopping criterion: simplex area
min_area = 0.00001

# Column of ones appended to the (x, y) vertex coordinates so that
# det([[x1, y1, 1], [x2, y2, 1], [x3, y3, 1]]) / 2 is the signed triangle area.
# See http://mathworld.wolfram.com/TriangleArea.html
A2 = np.ones((3, 1))


def Rosenbrock(v):
    """Return the Rosenbrock function (1 - x)**2 + 100*(y - x**2)**2.

    ``v`` is indexable with ``v[0] = x`` and ``v[1] = y``; the components may
    be scalars or equally-shaped NumPy arrays (evaluated elementwise).
    """
    x = v[0]
    y = v[1]
    return (1 - x)**2 + 100*(y - x**2)**2


def stop(A1, threshold=None):
    """Return True when the triangle with vertices ``A1`` has a small area.

    Parameters
    ----------
    A1 : ndarray, shape (3, 2)
        One ``(x, y)`` vertex per row.
    threshold : float, optional
        Area below which to stop; defaults to the module-level ``min_area``
        (read at call time).
    """
    limit = min_area if threshold is None else threshold
    # Evaluate area as determinant of matrix
    A = np.hstack([A1, A2])
    # Take absolute value, otherwise the area is signed (depends on vertex order)
    area = np.absolute(np.linalg.det(A) / 2)
    return area < limit


def plot_simplex(A, ax=None):
    """Draw simplex ``A`` on a 3-D axes.

    The triangle is drawn as a white polygon at height f(X, Y), and its
    outline is projected as thin black lines onto the plane z = 0.

    Parameters
    ----------
    A : ndarray, shape (3, 3)
        Rows ``[x, y, f(x, y)]``, one per vertex.
    ax : mpl_toolkits.mplot3d.Axes3D, optional
        Axes to draw on; defaults to Matplotlib's current axes.
    """
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
    if ax is None:
        import matplotlib.pyplot as plt
        ax = plt.gca()
    # Repeat the first vertex so the polygon / outline is closed
    A = np.vstack([A, A[0, :]])
    X = A[:, 0]
    Y = A[:, 1]
    Z_surf = A[:, 2]
    Z_proj = np.zeros_like(Z_surf)
    # list(zip(...)): on Python 3 zip() is a one-shot iterator, not a list
    verts_surf = [list(zip(X, Y, Z_surf))]
    ax.add_collection3d(Poly3DCollection(verts_surf, facecolors='w'))
    verts_proj = [list(zip(X, Y, Z_proj))]
    ax.add_collection3d(Line3DCollection(verts_proj, colors='k', linewidths=0.5))


def _move_simplex(func, M):
    """Apply one Nelder-Mead move to simplex ``M`` in place.

    ``M`` is a (3, 3) float array with rows ``[x, y, func((x, y))]``.
    Returns ``(label, evals)``: a short description of the move and the number
    of calls made to ``func``.
    """
    # Identify best / next-worst / worst vertices by function value
    rank = np.argsort(M[:, -1])
    best = rank[0]         # first element has smallest value
    worst = rank[-1]       # last element has largest value
    next_worst = rank[-2]  # second to last element has next to largest value

    def replace_worst(xy, fxy):
        M[worst, 0:2] = xy
        M[worst, 2] = fxy

    # Calculate mean of all but worst point
    mean = (M[:, 0:2].sum(0) - M[worst, 0:2]) / (M.shape[0] - 1)

    # Reflect worst point across the mean (Eqn 15.2)
    new_XY = 2*mean - M[worst, 0:2]
    new_f = func(new_XY)
    evals = 1

    # Case 1. If new point better than best point, double step size (Eqn 15.3)
    if new_f < M[best, 2]:
        new_XY_2 = 3*mean - 2*M[worst, 0:2]
        new_f_2 = func(new_XY_2)
        evals += 1
        # Overwrite worst point with the better of the two candidates
        if new_f_2 < new_f:
            replace_worst(new_XY_2, new_f_2)
            return 'doubled step size', evals
        replace_worst(new_XY, new_f)
        return 'reflected', evals

    # Case 2. If new point still worst, reflect and shrink (Eqn 15.4)
    elif new_f > M[next_worst, 2]:
        new_XY_2 = 1.5*mean - 0.5*M[worst, 0:2]
        new_f_2 = func(new_XY_2)
        evals += 1
        # If no longer worst, update and continue
        if new_f_2 < M[next_worst, 2]:
            replace_worst(new_XY_2, new_f_2)
            return 'reflected and shrinked', evals

        # If still worst, shrink step without reflecting (Eqn 15.5)
        new_XY_3 = 0.5*(mean + M[worst, 0:2])
        new_f_3 = func(new_XY_3)
        evals += 1
        if new_f_3 < M[next_worst, 2]:
            replace_worst(new_XY_3, new_f_3)
            return 'shrinked', evals

        # If STILL worst, shrink entire simplex towards the best vertex
        # (Eqn 15.6). The best vertex does not move, so it is not re-evaluated.
        for i in range(M.shape[0]):
            if i == best:
                continue
            M[i, 0:2] = 0.5*(M[i, 0:2] + M[best, 0:2])
            M[i, 2] = func(M[i, 0:2])
            evals += 1
        return 'shrinked entire simplex', evals

    # Case 3. If new point neither best nor worst, replace and continue
    else:
        replace_worst(new_XY, new_f)
        return 'replaced', evals


def nelder_mead(func, vertices, min_area=None, max_steps=10000, callback=None):
    """Minimise a function of two variables with the Nelder-Mead simplex method.

    Parameters
    ----------
    func : callable
        Objective taking a length-2 array ``(x, y)`` and returning a float.
    vertices : array_like, shape (3, 2)
        Starting triangle, one ``(x, y)`` vertex per row.
    min_area : float, optional
        Stop once the simplex area drops below this value; defaults to the
        module-level ``min_area``.
    max_steps : int
        Safety cap on the number of simplex moves; a ``RuntimeWarning`` is
        issued if it is reached before the area criterion is met.
    callback : callable, optional
        Called as ``callback(M)`` after every move, with ``M`` the current
        (3, 3) simplex array.

    Returns
    -------
    M : ndarray, shape (3, 3)
        Final simplex. First column: X, second column: Y, third column: f(X, Y).
    f_evals : int
        Number of calls made to ``func`` (including the initial vertices).
    n : int
        Number of simplex moves performed.
    """
    XY = np.asarray(vertices, dtype=float)
    F = np.array([[func(v)] for v in XY], dtype=float)
    M = np.hstack([XY, F])
    f_evals = XY.shape[0]

    n = 0
    # Check stopping criterion before every move
    while not stop(M[:, 0:2], min_area):
        if n >= max_steps:
            warnings.warn('Nelder-Mead stopped after %d steps without reaching '
                          'the minimum simplex area' % max_steps, RuntimeWarning)
            break
        label, evals = _move_simplex(func, M)
        f_evals += evals
        print('Step', n, ': ' + label)
        if callback is not None:
            callback(M)
        n += 1
    return M, f_evals, n


def main():
    """Run the optimiser from near (-1, -1) and plot every simplex visited."""
    # Imported here so the numerical code above needs only NumPy.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers '3d' on old Matplotlib
    from matplotlib import cm
    from matplotlib.ticker import LinearLocator, FormatStrFormatter
    import matplotlib.pyplot as plt

    # Prepare figure. Figure.gca(projection=...) was removed in Matplotlib 3.6.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Initialize vertices. Start near (x0, y0) = (-1, -1)
    vertices = np.array([[-1, -1], [-0.9, -1], [-1, -0.9]], dtype=float)

    # Move simplex, appending each updated simplex to the plot
    M, f_evals, _ = nelder_mead(Rosenbrock, vertices,
                                callback=lambda S: plot_simplex(S, ax))

    # Report centroid of final simplex and the function value there
    mean = M[:, 0:2].mean(axis=0)
    print('Minimum location', mean, '\nMinimum value', Rosenbrock(mean),
          '\nNumber of function evaluations', f_evals)

    # Plot Rosenbrock surface
    X_Rosen, Y_Rosen = np.meshgrid(np.arange(-1.3, 1.3, 0.05),
                                   np.arange(-1, 1.2, 0.05))
    Z_Rosen = Rosenbrock((X_Rosen, Y_Rosen))
    ax.plot_surface(X_Rosen, Y_Rosen, Z_Rosen, rstride=1, cstride=1,
                    cmap=cm.coolwarm, linewidth=0, antialiased=False,
                    alpha=0.3)

    # Adjust axes
    ax.set_zlim(0, 600)
    ax.zaxis.set_major_locator(LinearLocator(5))
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.0f'))

    # Render plot
    plt.show()


if __name__ == '__main__':
    main()