In [8]:
from matplotlib import cm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
plt.rcParams['figure.dpi'] = 150

(a)

In [9]:
fig = plt.figure()
ax = fig.gca(projection='3d')
surf, cont = plt.subplots()

x = np.arange(-1.1,1.1,0.05)
y = np.arange(-1.1,1.1,0.05)
x, y = np.meshgrid(x,y)
f = (1-x)**2+100*(y-x*x)**2

surf = ax.plot_surface(x,y,f,cmap=cm.coolwarm,alpha=0.8)
cont.contour(x,y,f,100,cmap=cm.gist_gray,linewidths=1)

ax.set_zlim(0, 350)
cont.set(xlim=(-1, 1), ylim=(-1, 1))
cont.set_aspect('equal')

plt.show()

(b)

In [10]:
def func(x):
    f = (1-x[0])**2+100*(x[1]-x[0]**2)**2
    return f

def reflect(x,mean):
    x = np.subtract(np.multiply(2,mean),x)
    return x

def refgrow(x,mean):
    x = np.subtract(np.multiply(3,mean),np.multiply(2,x))
    return x

def refshrink(x,mean):
    x = np.subtract(np.multiply(1.5,mean),np.multiply(0.5,x))
    return x

def shrink(x,mean):
    x = np.multiply(0.5,np.add(mean,x))
    return x
    
def shrinkall(x):
    for i in range(3):
        x[i] = np.multiply(0.5,np.add(x[i],x[2]))
    return x
In [11]:
def move():
    global simplex 
    simplex = np.array(sorted(simplex,key=lambda x:-func(x)))
    mean = np.multiply(0.5,np.add(simplex[1],simplex[2]))
    x_now = simplex[0]
    
    if func(reflect(x_now,mean)) < func(simplex[2]):
        simplex[0] = reflect(x_now,mean)
        if func(refgrow(x_now,mean)) < func(simplex[2]):
            simplex[0] = refgrow(x_now,mean)
            
    elif (func(refshrink(x_now,mean)) > func(simplex[2])):
        simplex[0] = refshrink(x_now,mean)
        
    elif func(shrink(x_now,mean)) < func(simplex[2]):
        simplex[0] = shrink(x_now,mean)
        
    else: simplex = shrinkall(simplex)   
In [12]:
from matplotlib.path import Path
from matplotlib.patches import PathPatch
import time

fig, cont = plt.subplots()

x = np.arange(-1.1,1.1,0.05)
y = np.arange(-1.1,1.1,0.05)
x, y = np.meshgrid(x,y)
f = (1-x)**2+100*(y-x*x)**2


simplex = np.array([[-1,-1],[-1,-0.5],[-0.5,-1]])

cont.contour(x,y,f,100,cmap=cm.gist_gray,linewidths=1,alpha=0.3)
cont.set(xlim=(-1, 1), ylim=(-1, 1))
cont.set_aspect('equal')

for i in range(100):
    path = Path(simplex)
    pathpatch = PathPatch(path,facecolor='gray',edgecolor='black')
    cont.add_patch(pathpatch)
    move()
    if func(simplex[1])-func(simplex[2])<0.01: break
    
plt.show()
print(i)
21