00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028 #include "numeric/quasinewton.hxx"
00029 #include "numeric/funcobj.hxx"
00030 #include "numeric/type.hxx"
00031 #include "numeric/nlpmodel.hxx"
00032 #include "numeric/exception.hxx"
00033
00034 using namespace ::scsolver::numeric;
00035 using namespace ::std;
00036 using ::scsolver::numeric::nlp::QuasiNewton;
00037
/**
 * Writes "msg: (v0, v1, ...)" followed by a newline to stdout, formatting
 * each value with the printf "%g" conversion.
 *
 * @param vars values to print, in order; an empty vector prints "msg: ()".
 * @param msg  label written in front of the parenthesised list.
 */
void debugPrint(const vector<double>& vars, const char* msg)
{
    FILE* const out = stdout;
    fprintf(out, "%s: (", msg);
    for (vector<double>::const_iterator it = vars.begin(); it != vars.end(); ++it)
    {
        // Separator goes before every element except the first.
        if (it != vars.begin())
            fprintf(out, ", ");
        fprintf(out, "%g", *it);
    }
    fprintf(out, ")\n");
}
00051
00052 class TestFunc1 : public BaseFuncObj
00053 {
00054 public:
00055 TestFunc1() :
00056 BaseFuncObj(2)
00057 {
00058
00059 setVar(0, 0);
00060 setVar(1, 3);
00061 }
00062
00063 virtual ~TestFunc1()
00064 {
00065 }
00066
00067 virtual double eval() const
00068 {
00069 const vector<double>& vars = getVars();
00070 double term1 = vars[0] - 2;
00071 term1 *= term1*term1*term1;
00072
00073 double term2 = vars[0] - 2.0*vars[1];
00074 term2 *= term2;
00075
00076 return term1 + term2;
00077 }
00078
00082 virtual const string getFuncString() const
00083 {
00084 return string("(x1 - 2)^4 + (x1 - 2*x2)^2");
00085 }
00086 };
00087
00088 void runTest(BaseFuncObj* pFuncObj)
00089 {
00090 auto_ptr<BaseFuncObj> func(pFuncObj);
00091 nlp::Model model;
00092 model.setGoal(GOAL_MINIMIZE);
00093 model.setFuncObject(func.get());
00094
00095
00096 const vector<double>& vars = func->getVars();
00097 vector<double>::const_iterator itr = vars.begin(), itrEnd = vars.end();
00098 for (; itr != itrEnd; ++itr)
00099 model.pushVar(*itr);
00100
00101 auto_ptr<QuasiNewton> nlpSolver(new QuasiNewton);
00102 nlpSolver->setModel(&model);
00103 try
00104 {
00105 nlpSolver->solve();
00106 const vector<double>& sol = nlpSolver->getSolution();
00107 debugPrint(sol, "solution");
00108 }
00109 catch (const ::std::exception& e)
00110 {
00111 fprintf(stdout, " standard exception: %s\n", e.what());
00112 }
00113 }
00114
00115 int main()
00116 {
00117 runTest(new TestFunc1);
00118 }
00119