from matplotlib import cm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
# Render every figure in this script at 150 dots per inch.
plt.rcParams.update({'figure.dpi': 150})
# (a) Rosenbrock function f(x, y) = (1 - x)^2 + 100 (y - x^2)^2: surface and contour plots.
# Plot the Rosenbrock function on the grid [-1.1, 1.1) x [-1.1, 1.1):
# a 3-D surface in one figure and a 100-level contour map in a second one.
fig = plt.figure()
# Figure.gca() no longer accepts keyword arguments (removed in Matplotlib
# 3.6); add_subplot(projection='3d') is the supported way to get 3-D axes.
ax = fig.add_subplot(projection='3d')
# Separate figure/axes for the contour map.
cont_fig, cont = plt.subplots()
x = np.arange(-1.1, 1.1, 0.05)
y = np.arange(-1.1, 1.1, 0.05)
x, y = np.meshgrid(x, y)
f = (1 - x)**2 + 100 * (y - x * x)**2
surf = ax.plot_surface(x, y, f, cmap=cm.coolwarm, alpha=0.8)
cont.contour(x, y, f, 100, cmap=cm.gist_gray, linewidths=1)
# Cap the displayed height; the function reaches ~540 at the grid corner.
ax.set_zlim(0, 350)
cont.set(xlim=(-1, 1), ylim=(-1, 1))
cont.set_aspect('equal')
plt.show()
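# (b) Minimize the Rosenbrock function with the Nelder-Mead simplex method, drawing each simplex on the contour plot.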
def func(x):
    """Evaluate the 2-D Rosenbrock function (1 - x0)^2 + 100 (x1 - x0^2)^2.

    ``x`` is any indexable pair (x0, x1); its entries may be scalars or
    same-shaped numpy arrays, in which case the result is elementwise.
    """
    u, v = x[0], x[1]
    return (1 - u) ** 2 + 100 * (v - u ** 2) ** 2
def reflect(x, mean):
    """Return the reflection of point ``x`` through ``mean``: 2*mean - x."""
    centroid = np.asarray(mean)
    return 2 * centroid - x
def refgrow(x, mean):
    """Return the expansion point 3*mean - 2*x.

    Equivalently mean + 2*(mean - x): twice as far past ``mean`` as ``x``
    lies before it.
    """
    m = np.asarray(mean)
    p = np.asarray(x)
    return 3 * m - 2 * p
def refshrink(x, mean):
    """Return the outside-contraction point 1.5*mean - 0.5*x.

    Equivalently mean + 0.5*(mean - x): halfway between ``mean`` and the
    reflection of ``x``.
    """
    m = np.asarray(mean)
    p = np.asarray(x)
    return 1.5 * m - 0.5 * p
def shrink(x, mean):
    """Return the inside-contraction point: the midpoint of ``x`` and ``mean``."""
    total = np.asarray(mean) + x
    return 0.5 * total
def shrinkall(x):
    """Move vertices 0..2 of ``x`` halfway toward vertex ``x[2]``, in place.

    Vertex 2 maps onto itself. Returns the same (mutated) container.
    """
    for k in range(3):
        anchor = x[2]
        x[k] = 0.5 * np.add(x[k], anchor)
    return x
def move():
    """Advance the module-level triangle ``simplex`` by one Nelder-Mead step.

    The simplex is re-sorted so that simplex[0] is the worst vertex (largest
    ``func``) and simplex[2] the best. The worst vertex is then replaced by
    the reflected point, the expansion point (``refgrow``, coefficient 2),
    the outside contraction (``refshrink``, 0.5) or the inside contraction
    (``shrink``, 0.5). If no contraction is acceptable, the whole simplex
    shrinks toward the best vertex via ``shrinkall``.

    Side effect: rebinds the global ``simplex`` to a new float array. The
    replaced vertex is left in slot 0 (the array is not re-sorted on exit).
    """
    global simplex
    # Sort worst -> best. dtype=float so that assigning fractional candidate
    # points into the array can never truncate them.
    simplex = np.array(sorted(simplex, key=lambda v: -func(v)), dtype=float)
    # Copy, not a view: simplex[0] is overwritten below, and every candidate
    # must be computed from the ORIGINAL worst vertex.
    x_worst = simplex[0].copy()
    f_worst = func(x_worst)
    f_second = func(simplex[1])
    f_best = func(simplex[2])
    # Centroid of the two vertices that are kept.
    mean = np.multiply(0.5, np.add(simplex[1], simplex[2]))

    x_r = reflect(x_worst, mean)
    f_r = func(x_r)
    if f_r < f_best:
        # Reflection beats the best vertex: try going further (expansion)
        # and keep whichever of the two is lower.
        x_e = refgrow(x_worst, mean)
        simplex[0] = x_e if func(x_e) < f_r else x_r
    elif f_r < f_second:
        # Ordinary improvement: accept the reflection.
        simplex[0] = x_r
    elif f_r < f_worst:
        # Reflection only beats the worst vertex: outside contraction.
        x_oc = refshrink(x_worst, mean)
        if func(x_oc) <= f_r:
            simplex[0] = x_oc
        else:
            simplex = shrinkall(simplex)
    else:
        # Reflection is no better than the worst vertex: inside contraction.
        x_ic = shrink(x_worst, mean)
        if func(x_ic) < f_worst:
            simplex[0] = x_ic
        else:
            simplex = shrinkall(simplex)
from matplotlib.path import Path
from matplotlib.patches import PathPatch
import time
# Draw the Rosenbrock contours, then run up to 100 Nelder-Mead steps from
# the starting triangle below, overlaying every simplex on the plot.
fig, cont = plt.subplots()
x = np.arange(-1.1, 1.1, 0.05)
y = np.arange(-1.1, 1.1, 0.05)
x, y = np.meshgrid(x, y)
# func indexes its argument as (x0, x1), so it evaluates the whole grid.
f = func((x, y))
simplex = np.array([[-1, -1], [-1, -0.5], [-0.5, -1]])
cont.contour(x, y, f, 100, cmap=cm.gist_gray, linewidths=1, alpha=0.3)
cont.set(xlim=(-1, 1), ylim=(-1, 1))
cont.set_aspect('equal')
tol = 0.01  # stop once all vertex values agree to within this amount
for i in range(100):
    # Repeat the first vertex and close the path. With closed=True, Path
    # turns the last vertex into CLOSEPOLY, so all three edges are stroked
    # (an open 3-vertex path only strokes two of them).
    path = Path(np.vstack([simplex, simplex[:1]]), closed=True)
    pathpatch = PathPatch(path, facecolor='gray', edgecolor='black')
    cont.add_patch(pathpatch)
    move()
    # move() does not re-sort after replacing a vertex, so fixed indices are
    # not "second best"/"best" here. Measure the spread over all vertices.
    values = [func(v) for v in simplex]
    if max(values) - min(values) < tol:
        break
plt.show()
print(i)  # index of the last iteration executed