00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
#include "numeric/funcobj.hxx"

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <memory>
#include <string>
#include <vector>
00034
00035 using namespace ::scsolver::numeric;
00036 using namespace ::std;
00037
/**
 * Exception raised by the checks in this test driver when an expectation
 * does not hold.  The reason passed to the constructor is returned
 * verbatim by what().
 */
class TestFailed : public ::std::exception
{
public:
    /// @param reason  NUL-terminated description of the failed expectation (copied).
    explicit TestFailed(const char* reason) : m_message(reason) {}

    virtual ~TestFailed() throw() {}

    /// @return the reason supplied at construction.
    virtual const char* what() const throw()
    {
        return m_message.c_str();
    }

private:
    ::std::string m_message;   // owned copy of the failure reason
};
00058
00059 class TestBaseFunc1 : public BaseFuncObj
00060 {
00061 public:
00062 TestBaseFunc1() :
00063 BaseFuncObj(3)
00064 {
00065 }
00066
00067 virtual ~TestBaseFunc1()
00068 {
00069 }
00070
00071 virtual double eval() const
00072 {
00073 double x1 = getVar(0), x2 = getVar(1), x3 = getVar(2);
00074 return x1 + x2 + x3;
00075 }
00076
00080 virtual const::std::string getFuncString() const
00081 {
00082 return string("x1 + x2 + x3");
00083 }
00084 };
00085
00086 class TestBaseFunc2 : public BaseFuncObj
00087 {
00088 public:
00089 TestBaseFunc2() :
00090 BaseFuncObj(6)
00091 {
00092 }
00093
00094 virtual ~TestBaseFunc2()
00095 {
00096 }
00097
00098 virtual double eval() const
00099 {
00100 double x1 = getVar(0), x2 = getVar(1), x3 = getVar(2);
00101 double x4 = getVar(3), x5 = getVar(4), x6 = getVar(5);
00102 return x1 + x2 + x3 + x4 + x5 + x6;
00103 }
00104
00108 virtual const::std::string getFuncString() const
00109 {
00110 return string("x1 + x2 + x3 + x4 + x5 + x6");
00111 }
00112 };
00113
/**
 * Return a pseudo-random number in the half-open interval (0, 1].
 *
 * Zero is excluded on purpose: the values are used as variable ratios, and
 * checkVarRatio() divides by the first ratio, so a zero would yield inf/NaN.
 * No seed is set here; the sequence follows rand()'s current state.
 */
double getRandomNumber()
{
    // rand() yields [0, RAND_MAX]; shifting by one and scaling by
    // RAND_MAX + 1 maps that range onto (0, 1].
    const double raw = static_cast<double>(rand()) + 1.0;
    return raw / (static_cast<double>(RAND_MAX) + 1.0);
}
00124
00125 void checkVarValues(const BaseFuncObj& rBaseFunc, size_t varIndex, double varIndexValue)
00126 {
00127 fprintf(stdout, " checking variable values...\n");
00128
00129 size_t varCount = rBaseFunc.getVarCount();
00130 fprintf(stdout, " (");
00131 for (size_t i = 0; i < varCount; ++i)
00132 {
00133 double varValue = rBaseFunc.getVar(i);
00134 if (i > 0)
00135 fprintf(stdout, ", ");
00136 fprintf(stdout, "%g", varValue);
00137 if (i == varIndex)
00138 {
00139 if (varValue != varIndexValue)
00140 throw TestFailed("variable value is incorrect");
00141 }
00142 else
00143 {
00144 if (varValue != 0.0)
00145 throw TestFailed("locked variable value is not zero");
00146 }
00147 }
00148 fprintf(stdout, ")\n");
00149 }
00150
00151 void resetVarValues(BaseFuncObj& rFuncObj)
00152 {
00153 size_t varCount = rFuncObj.getVarCount();
00154 vector<double> initVars(varCount);
00155 rFuncObj.setVars(initVars);
00156
00157
00158 for (size_t i = 0; i < varCount; ++i)
00159 {
00160 if (rFuncObj.getVar(i) != 0.0)
00161 throw TestFailed("initial variable must be zero");
00162 }
00163 }
00164
00165 void checkVarRatio(BaseFuncObj& rFuncObj, const vector<double>& ratios)
00166 {
00167 fprintf(stdout, " checking ratio...\n");
00168 size_t varCount = rFuncObj.getVarCount();
00169 if (varCount < 2)
00170 return;
00171
00172 double var1 = rFuncObj.getVar(0);
00173 double ratio1 = ratios.at(0);
00174 for (size_t i = 1; i < varCount; ++i)
00175 {
00176 double _var1 = rFuncObj.getVar(i);
00177 double _var2 = var1 * ratios.at(i)/ratio1;
00178 double delta = _var1/_var2 - 1.0;
00179 fprintf(stdout, " var = %g vs %g \t (delta = %g)\n",
00180 _var1, _var2, delta);
00181 if ((delta > 0 ? delta : -delta) > 5.0e-16)
00182 throw TestFailed("ratio is incorrect");
00183 }
00184 }
00185
00186 void runTest(BaseFuncObj* p)
00187 {
00188 auto_ptr<BaseFuncObj> pFuncObj(p);
00189 size_t varCount = pFuncObj->getVarCount();
00190
00191 fprintf(stdout, "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n");
00192 fprintf(stdout, "running test on function: %s (variable count: %d)\n",
00193 pFuncObj->getFuncString().c_str(), pFuncObj->getVarCount());
00194 fprintf(stdout, "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n");
00195
00196
00197 for (size_t i = 0; i < varCount; ++i)
00198 {
00199 resetVarValues(*pFuncObj);
00200 SingleVarFuncObj& rSingleVarFunc = pFuncObj->getSingleVarFuncObj(i);
00201 for (size_t incStep = 0; incStep < 20; ++incStep)
00202 {
00203 double newVar = rSingleVarFunc.getVar() + 3.0;
00204 rSingleVarFunc.setVar(newVar);
00205 checkVarValues(*pFuncObj, i, newVar);
00206 }
00207 }
00208
00209
00210 resetVarValues(*pFuncObj);
00211 vector<double> ratios;
00212 for (size_t i = 0; i < varCount; ++i)
00213 ratios.push_back(getRandomNumber());
00214 SingleVarFuncObj& rSingleVarFunc = pFuncObj->getSingleVarFuncObjByRatio(ratios);
00215 for (size_t incStep = 0; incStep < 20; ++incStep)
00216 {
00217 double newVar = rSingleVarFunc.getVar() + 3.0;
00218 rSingleVarFunc.setVar(newVar);
00219 checkVarRatio(*pFuncObj, ratios);
00220 }
00221 }
00222
00223 int main()
00224 {
00225 try
00226 {
00227 runTest(new TestBaseFunc1);
00228 runTest(new TestBaseFunc2);
00229 fprintf(stdout, "test successful\n");
00230 }
00231 catch (const ::std::exception& e)
00232 {
00233 fprintf(stdout, "%s\n", e.what());
00234 }
00235 }