In [48]:
using PyPlot

# 14.1 - A
function plot_rosen(f)
    X = [-1.5:0.01:1.5]
    Y = [-1.0:0.01:3]

    F = [f([x, y]) for x in X, y in Y]

    Xs = [x for x in X, y in Y]
    Ys = [y for x in X, y in Y]
    plot_surface(Xs, Ys, F, alpha=0.2)
end

const C = 100
f(x) = (1-x[1])^2 + C*(x[2]-x[1]^2)^2
# get the algebraic gradient
dfdx(x) = [-2*(1-x[1]) - 4C*(x[2]-x[1]^2)*x[1],
           2C*(x[2]-x[1]^2)]
plot_rosen(f)
Out[48]:
PyObject <mpl_toolkits.mplot3d.art3d.Poly3DCollection object at 0x10bbd7050>
Warning: redefining constant C

In [3]:
# 14.1 - B, The Downhill Simplex Method

function simplex_mean(s, chosen)
    # takes the mean of the simplex, excluding the chosen point
    m = [0, 0]
    for c in 1:3
        if c != chosen
            m += s[c]
        end
    end
    return m ./ 2
end

reflect(x, mean) = 2mean - x
reflect_stretch(x, mean) = 3mean - 2x
reflect_shrink(x, mean) = 3mean/2 - x/2
shrink(x, mean) = (mean + x) / 2

function shrink_all!(s)
    # assume the simplex has already been sorted in best->worst order
    s[:,2] = (s[:,2] + s[:,1])/2
    s[:,3] = (s[:,3] + s[:,1])/2
    return s
end

feval_simplex(s) = reshape(mapslices(f, s, 1), 3)
@assert feval_simplex([0 0; 1 1; 2 2]') == [f([0, 0]), f([1, 1]), f([2, 2])]

function next_simplex!(s)
    # modifies the given simplex and returns the number of function evals in this iteration
    res = feval_simplex(simplex)
    # 3 evals
    sortidxs = sortperm(res)
    s[:] = s[:, sortidxs]
    res = res[sortidxs]
    
    mn = (s[:,1] + s[:,2]) / 2
    x_ref = reflect(s[:,3], mn)
    v_ref = f(x_ref)
    # 4 evals
    if v_ref < res[1]
        # reflecting got us the best point, how about stretching?
#        println("reflecting = great")
        x_ref_str = reflect_stretch(s[:,3], mn)
        v_ref_str = f(x_ref_str)
        # 5 evals
        if v_ref_str < v_ref
#            println("reflecting + stretching = good")
            # stretching got us an even better point, we've got our next simplex
            s[:,3] = x_ref_str
            return 5
        end
        # OK, reflecting helped but stretching didn't
        s[:,3] = x_ref
        return 5
    elseif v_ref < res[2]
#        println("reflecting = good")
        # after reflecting it's no longer the worst but not the best. good enough
        s[:,3] = x_ref
        return 4
    end
#    println("reflecting = bad")
    # after reflecting this is still the worst
    x_ref_shr = reflect_shrink(s[:,3], mn)
    v_ref_shr = f(x_ref_shr)
    # 5 evals
    if v_ref_shr < res[2]
#        println("reflecting + shrinking = good")
        s[:,3] = x_ref_shr
        return 5
    end
    # reflecting and shrinking didn't work, how about just shrinking
    x_shr = shrink(s[:,3], mn)
    v_shr = f(x_shr)
    # 6 evals
    if v_shr < res[2]
#        println("shrinking = good")
        s[:,3] = x_shr
        return 6
    end
#    println("shrinking everything")
    # wow, still the worst after all that. Let's just shrink everybody!
    shrink_all!(s)
    return 6
end

function plot_simplex(s)
    data = vcat(s, feval_simplex(s)')
    data = hcat(data, data[:,1])
    
    plot3D(data[1,:]', data[2,:]', zs=data[3,:]', color="black")
end

simplex = [-1 -1; -0.9 -1; -1 -0.9]'
last_val = maximum(feval_simplex(simplex))
static_count = 0
evals = 0
#simplex = [0 3; 0.1 3; 0 3.1]'
plot_rosen(f)
plot_simplex(simplex)
while true
    evals += next_simplex!(simplex)
    plot_simplex(simplex)
    val = maximum(feval_simplex(simplex))
    if abs(val - last_val) < 0.01
        static_count += 1
        if static_count > 10
            break
        end
    else
        static_count = 0
    end
    last_val = val
end
gca()[:view_init](elev=60, azim=20)
println("Final simplex after $(evals) evals:")
println(simplex)
Final simplex after 220 evals:
1.0161401221528772	.9873635411262629	.9877134772017717
1.0328986672684621	.976464021205909	.9744354104623267


In [5]:
# Lets try implementing a line search

# ye goldene ratio
const gold = (1+sqrt(5))/2

function find_enclosing(x_init, d)
    # finds an enclosing set of 2 points with one in the middle
    # takes a starting point and direction vector
    # returns a 6-tuple with (x1, y1, x2, y2, x3, y3)
    
    x1 = x_init
    y1 = f(x1)
    x2 = x1 + d
    y2 = f(x2)
    
    scatter3D(x1[1], x1[2], zs=f(x1), color="red", s=5)

    if y2 > y1
        # uh oh, we either overshot already or we're going in the wrong direction
        xtemp = x1
        ytemp = y1
        x1 = x2
        y1 = y2
        x2 = xtemp
        y2 = ytemp
    end
    x3 = x2 + (x2 - x1) * gold
    y3 = f(x3)
    plot3D([x1[1], x2[1]], [x1[2], x2[2]], zs=[y1, y2], color="red")
    scatter3D(x2[1], x2[2], zs=y2, color="red", s=5)
    plot3D([x2[1], x3[1]], [x2[2], x3[2]], zs=[y2, y3], color="red")
    scatter3D(x3[1], x3[2], zs=y3, color="red", s=5)
    while y3 < y2
        x1 = x2
        y1 = y2
        x2 = x3
        y2 = y3
        x3 = x2 + (x2 - x1) * gold
        y3 = f(x3)
        plot3D([x2[1], x3[1]], [x2[2], x3[2]], zs=[y2, y3], color="red")
        scatter3D(x3[1], x3[2], zs=y3, color="red", s=5)
    end
    (x1, y1, x2, y2, x3, y3)
end

function find_minimum(x1, y1, x2, y2, x3, y3, thresh)
    # assumes that x1 and x3 form a bracket that encloses the minimum, and x2 is an intermediate point
    # that forms a golden section between x1 and x3. It is also assumed that the smaller interval is
    # between x1 and x2. It is NOT assumed that x3 > x1.
    iterations = 0

    while norm(x3 - x1) > thresh && iterations < 10000
        #println("norm(x3-x1) = $(norm(x3 - x1))")
        #println("x1 = ($(x1[1]), $(x1[2])), y1 = $(y1)")
        #println("x2 = ($(x2[1]), $(x2[2])), y2 = $(y2)")
        #println("x3 = ($(x3[1]), $(x3[2])), y3 = $(y3)")
        for i in 1:2
            @assert (x1[i] <= x2[i] <= x3[i]) || (x1[i] >= x2[i] >= x3[i])
        end
        x4 = x1 + x3 - x2
        y4 = f(x4)
        if y4 > y2
            # here we also flip the order so that the small interval is first
            x3 = x1
            y3 = y1
            x1 = x4
            y1 = y4        
        else
            x1 = x2
            y1 = y2
            x2 = x4
            y2 = y4
        end
        iterations += 1
    end
#    println("Found minimum in $(iterations) iterations")
    (x2, y2)
end

line_minimize(x, d, thresh=1e-5) = find_minimum(find_enclosing(x, d)..., thresh)
hello
In [15]:
# use line minimization to find a minimum using Powell's Method

plot_rosen(f)
# set an initial value for x
x_0 = [-1.0, -0.5]
y_0 = f(x_0)
# set an initial set of directions to the axes
ds = eye(2, 2) .* 0.1
#ds = [0.0 1; 1 0]
ys = zeros(size(ds, 2))
iterations = 0
while iterations < 100
    x = x_0
    # search in each direction
    for i in 1:size(ds, 2)
        d = ds[:, i]
        #println("minimizing from ($(x[1]), $(x[2]))")
        # track the new y for each direction minimization
        x, ys[i] = line_minimize(x, d)
        #println("minimum in d=($(d[1]), $(d[2])) was ($(x[1]), $(x[2]))")
        scatter3D(x[1], x[2], zs=ys[i], color="purple", s=20)
    end
    d_new = x - x_0
    #d_new = d_new / norm(d_new)
    i = indmax(ds' * d_new)
    ds[:, i] = d_new
    #println("replacing direction $i with ($(d_new[1]), $(d_new[2]))")
    #println("---------")
    iterations += 1
    if norm(x - x_0) < 1e-7
        break
    end
    x_0 = x
end
println("Finished in $(iterations) iterations")

x = x_0
y = ys[end]
println("x = ($(x[1]), $(x[2]))")
println("y = $y")

gca()[:view_init](elev=60, azim=40)
Finished in 14 iterations
x = (0.9999948501361238, 0.9999902850064573)
y = 6.070940605450799e-11

In [49]:
# now we'll try to use Conjugate Gradient Descent

plot_rosen(f)
# set an initial value for x
x = [-1.0, -0.5]
y = f(x)
# set an initial direction
grad = dfdx(x)
d = -grad
iterations = 0
while iterations < 100 # emergency stop
    x_new, y_new = line_minimize(x, d / norm(d))
    iterations += 1
    if norm(x_new - x) < 1e-7
        break
    end
    grad_new = dfdx(x_new)
    gamma = dot(grad_new, (grad_new - grad)) / dot(grad, grad)
    if gamma < 0
        gamma = 0
    end
    d = grad_new + gamma * d
    x = x_new
    y = y_new
    scatter3D(x[1], x[2], zs=y, color="purple", s=20)

end
println("Finished in $(iterations) iterations")
println("x = ($(x[1]), $(x[2]))")
println("y = $y")

gca()[:view_init](elev=60, azim=40)
Finished in 100 iterations
x = (0.7055051820165372, 0.4968664005300224)
y = 0.08680309002403648