/*
 * NelderMead.java
 * Created on Mar 26, 2005
 */
package search;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;

/**
 * A Nelder-Mead simplex search that minimizes a {@link Function}.
 * <p>
 * The search starts from {@code dim + 1} caller-supplied points. Each iteration
 * sorts the simplex by function value and tries, in order:
 * <ol>
 *   <li>a reflection of the worst point through the centroid of the others,</li>
 *   <li>an expansion,</li>
 *   <li>an outside contraction,</li>
 *   <li>an inside contraction.</li>
 * </ol>
 * The first acceptable trial point replaces the worst point. If none is
 * acceptable, every point is moved halfway toward the best point.
 * <p>
 * The search stops in any of these cases, and always returns the best point
 * currently in the simplex:
 * <ul>
 *   <li>an accepted trial point improves on the old worst value by less than the
 *       tolerance;</li>
 *   <li>a full shrink is required although all simplex values already agree to
 *       within the tolerance;</li>
 *   <li>the iteration limit is reached.</li>
 * </ul>
 * <p>
 * Instances hold mutable state (the simplex is updated in place) and use no
 * synchronization. The code avoids generics to match the file's pre-Java-5 style.
 * Collection element types are noted in {@code /*<...>*&#47;} comments.
 * <p>
 * NOTE(review): assumes {@code Point} arithmetic ({@code add}, {@code sub},
 * {@code mul}, {@code div}) returns new points and does not mutate the receiver;
 * the original {@code total = total.add(p)} usage suggests this. Confirm.
 */
public final class NelderMead {

    /** Default stopping tolerance on the improvement of the worst value. */
    private static final double DEFAULT_TOLERANCE = .001;

    /** Default iteration cap; bounds searches that never meet the tolerance. */
    private static final int DEFAULT_MAX_ITERATIONS = 100000;

    private final Function f;
    private final int dim;
    private final List/*<Point>*/ simplex;
    private final int minIdx = 0, maxIdx, nextIdx;
    private final double tolerance;
    private final int maxIterations;
    private final Comparator pointComparator = new PointComparator();

    /** Cached function values, pruned to the current simplex points on each evaluation. */
    private final Map/*<Point, Double>*/ point2Value = new HashMap();

    /**
     * Creates a search with the default tolerance (0.001) and iteration cap.
     *
     * @param f the function to minimize
     * @param initialSimplex the starting points; at least two are required
     * @throws NullPointerException if {@code f} or {@code initialSimplex} is null
     * @throws IllegalArgumentException if fewer than two points are given
     */
    public NelderMead(Function f, Point[] initialSimplex) {
        this(f, initialSimplex, DEFAULT_TOLERANCE, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Creates a search with an explicit stopping tolerance and iteration cap.
     *
     * @param f the function to minimize
     * @param initialSimplex the starting points; at least two are required. The
     *        array is copied, so later changes to it do not affect the search.
     * @param tolerance the search stops once an accepted step improves the worst
     *        value by less than this amount; must be non-negative
     * @param maxIterations the maximum number of iterations per call to
     *        {@link #search()}; must be positive
     * @throws NullPointerException if {@code f} or {@code initialSimplex} is null
     * @throws IllegalArgumentException if fewer than two points are given, or
     *         {@code tolerance} or {@code maxIterations} is out of range
     */
    public NelderMead(Function f, Point[] initialSimplex, double tolerance, int maxIterations) {
        if (f == null) {
            throw new NullPointerException("f");
        }
        if (initialSimplex == null) {
            throw new NullPointerException("initialSimplex");
        }
        if (initialSimplex.length < 2) {
            throw new IllegalArgumentException(
                    "simplex needs at least 2 points, got " + initialSimplex.length);
        }
        if (!(tolerance >= 0)) { // also rejects NaN
            throw new IllegalArgumentException("tolerance must be >= 0, got " + tolerance);
        }
        if (maxIterations < 1) {
            throw new IllegalArgumentException("maxIterations must be >= 1, got " + maxIterations);
        }
        this.f = f;
        this.dim = initialSimplex.length - 1;
        this.simplex = new ArrayList(Arrays.asList(initialSimplex));
        this.maxIdx = simplex.size() - 1;
        this.nextIdx = maxIdx - 1;
        this.tolerance = tolerance;
        this.maxIterations = maxIterations;
    }

    private void putValue(Point p, double d) {
        point2Value.put(p, new Double(d));
    }

    private double getValue(Point p) {
        return ((Double) point2Value.get(p)).doubleValue();
    }

    /**
     * Runs the search from the current simplex and returns the best point found.
     * <p>
     * The simplex is updated in place, so calling this again resumes from where
     * the previous call stopped.
     *
     * @return the simplex point with the lowest function value when the search stopped
     */
    public Point search() {
        for (int iteration = 0; iteration < maxIterations; iteration++) {
            evaluateSimplex();
            // After sorting: best at minIdx, worst at maxIdx, second-worst at nextIdx.
            Collections.sort(simplex, pointComparator);
            Point minP = (Point) simplex.get(minIdx);
            Point maxP = (Point) simplex.get(maxIdx);
            double min = getValue(minP);
            double max = getValue(maxP);
            double next = getValue((Point) simplex.get(nextIdx));

            Point replacement = chooseReplacement(centroid(), maxP, min, next);
            if (replacement == null) {
                // If the values already agree within the tolerance, shrinking cannot
                // help; stopping here also keeps a flat function from shrinking forever.
                if (max - min < tolerance) {
                    return minP;
                }
                shrinkToward(minP);
                continue;
            }

            simplex.set(maxIdx, replacement);
            double newVal = getValue(replacement);
            if (max - newVal < tolerance) {
                // The replacement is the new best only if it beat the old best point.
                return newVal < min ? replacement : minP;
            }
        }
        return bestPoint();
    }

    /**
     * Tries the reflection, expansion and contraction steps against the worst point.
     * <p>
     * Evaluates {@code f} at the trial points in that order. Only the accepted
     * point's value is cached.
     *
     * @return the accepted trial point, or {@code null} if none beats the second-worst value
     */
    private Point chooseReplacement(Point mean, Point maxP, double min, double next) {
        // Reflection: mean + (mean - maxP).
        Point reflect = mean.mul(2).sub(maxP);
        double reflectVal = f.at(reflect);
        if (reflectVal < min) {
            // Expansion: mean + 2 (mean - maxP); keep whichever trial is lower.
            Point reflectGrow = mean.mul(3).sub(maxP.mul(2));
            double reflectGrowVal = f.at(reflectGrow);
            return reflectGrowVal < reflectVal
                    ? accept(reflectGrow, reflectGrowVal)
                    : accept(reflect, reflectVal);
        }
        if (reflectVal < next) {
            return accept(reflect, reflectVal);
        }
        // Outside contraction: mean + (mean - maxP) / 2.
        Point reflectShrink = mean.mul(3).sub(maxP).div(2);
        double reflectShrinkVal = f.at(reflectShrink);
        if (reflectShrinkVal < next) {
            return accept(reflectShrink, reflectShrinkVal);
        }
        // Inside contraction: midpoint of the centroid and the worst point.
        Point shrink = mean.add(maxP).div(2);
        double shrinkVal = f.at(shrink);
        if (shrinkVal < next) {
            return accept(shrink, shrinkVal);
        }
        return null;
    }

    /** Caches {@code value} for {@code p} and returns {@code p}. */
    private Point accept(Point p, double value) {
        putValue(p, value);
        return p;
    }

    /**
     * Returns the centroid of every point except the worst.
     * <p>
     * Requires the simplex to be sorted. Excludes by position rather than
     * {@code equals}, so duplicate points are counted correctly.
     */
    private Point centroid() {
        Point total = new Point(dim);
        for (Iterator/*<Point>*/ i = simplex.subList(0, maxIdx).iterator(); i.hasNext();) {
            total = total.add((Point) i.next());
        }
        return total.div(dim);
    }

    /** Moves every point after index {@code minIdx} halfway toward {@code best}. */
    private void shrinkToward(Point best) {
        for (ListIterator/*<Point>*/ i = simplex.listIterator(minIdx + 1); i.hasNext();) {
            Point p = (Point) i.next();
            i.set(p.add(best).div(2));
        }
    }

    /**
     * Ensures every simplex point has a cached value.
     * <p>
     * Calls {@code f} only for points not already cached, and drops cached values
     * of points no longer in the simplex so the cache stays bounded.
     */
    private void evaluateSimplex() {
        Map/*<Point, Double>*/ fresh = new HashMap();
        for (Iterator/*<Point>*/ i = simplex.iterator(); i.hasNext();) {
            Point p = (Point) i.next();
            Object cached = point2Value.get(p);
            fresh.put(p, cached != null ? cached : new Double(f.at(p)));
        }
        point2Value.clear();
        point2Value.putAll(fresh);
    }

    /** Evaluates the simplex and returns its lowest-valued point. */
    private Point bestPoint() {
        evaluateSimplex();
        return (Point) Collections.min(simplex, pointComparator);
    }

    /** Orders points by their cached function value, ascending. */
    private class PointComparator implements Comparator {
        public int compare(Object o1, Object o2) {
            Point p1 = (Point) o1;
            Point p2 = (Point) o2;
            double v1 = getValue(p1);
            double v2 = getValue(p2);
            return Double.compare(v1, v2);
        }
    }
}