00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028 #include "numeric/polyeqnsolver.hxx"
00029 #include "numeric/matrix.hxx"
00030 #include <list>
00031 #include <cmath>
00032 #include <stdio.h>
00033
00034 using namespace ::scsolver::numeric;
00035 using namespace ::std;
00036
/**
 * Exception thrown by the test helpers in this file when a check fails.
 * main() catches it and prints the failure banner.
 */
struct TestFailed
{
};
00038
00039 class PolyEqnSolverTest : public PolyEqnSolver
00040 {
00041 public:
00042 void addDataPoint(double x, double y)
00043 {
00044 fprintf(stdout, "PolyEqnSolverTest::addDataPoint: adding (%g, %g)\n", x, y);
00045 DataPoint pt(x, y);
00046 m_DataPoints.push_back(pt);
00047 PolyEqnSolver::addDataPoint(x, y);
00048 }
00049
00050 const Matrix solve()
00051 {
00052 fprintf(stdout, "PolyEqnSolverTest::solve: ------------------------------\n");
00053 Matrix sol = PolyEqnSolver::solve();
00054
00055 if (sol.cols() != 1)
00056 {
00057 printf("solution must be a single-column matrix.\n");
00058 throw TestFailed();
00059 }
00060
00061 printf("solution = ");
00062 sol.trans().print(5);
00063 verifySolution(sol);
00064 return sol;
00065 }
00066
00067 void clear()
00068 {
00069 PolyEqnSolver::clear();
00070 size_t n = PolyEqnSolver::size();
00071 if (n != 0)
00072 {
00073 fprintf(stdout, "PolyEqnSolverTest::clear: data point size is not zero.\n");
00074 throw TestFailed();
00075 }
00076
00077
00078 m_DataPoints.clear();
00079 }
00080
00081 private:
00082
00083 void verifySolution(const Matrix& solution)
00084 {
00085 size_t deltaCount = 0;
00086 size_t n = solution.rows();
00087 list<DataPoint>::const_iterator itr = m_DataPoints.begin(), itrEnd = m_DataPoints.end();
00088 for (; itr != itrEnd; ++itr)
00089 {
00090 double verTerm = 1.0, leftSum = 0.0;
00091 for (size_t i = 0; i < n; ++i)
00092 {
00093 leftSum += solution(i, 0) * verTerm;
00094 verTerm *= itr->X;
00095 }
00096
00097 if (itr->Y != leftSum)
00098 {
00099 printf(" delta = %.20f\n", itr->Y - leftSum);
00100 ++deltaCount;
00101 }
00102 }
00103
00104 if (!deltaCount)
00105 printf("solution verified\n");
00106 else
00107 {
00108 printf("delta count = %d\n", deltaCount);
00109 throw TestFailed();
00110 }
00111 }
00112
00113 list<DataPoint> m_DataPoints;
00114 };
00115
00116
00117
00118 class QuadPeakTest
00119 {
00120 public:
00121 QuadPeakTest()
00122 {
00123 }
00124
00125 void set(double a, double b, double c)
00126 {
00127 printf("--------------------------------------------------------------------\n");
00128 m_A = a;
00129 m_B = b;
00130 m_C = c;
00131 Matrix coef(3, 1);
00132 coef(0, 0) = m_C;
00133 coef(1, 0) = m_B;
00134 coef(2, 0) = m_A;
00135 getQuadraticPeak(m_X, m_Y, coef);
00136 }
00137
00138 void print() const
00139 {
00140 printf("f(x) = %g x^2 ", m_A);
00141 if (m_B >= 0.0)
00142 printf("+ ");
00143 else
00144 printf("- ");
00145 printf("%g x ", fabs(m_B));
00146 if (m_C >= 0.0)
00147 printf("+ ");
00148 else
00149 printf("- ");
00150 printf("%g : f(x) peaks at (%g, %g)\n", m_C, m_X, m_Y);
00151 }
00152
00153 void verify() const
00154 {
00155 static const double step = 1.0;
00156 static const int count = 10;
00157
00158 double lambda = step;
00159 for (int i = 0; i < count; ++i)
00160 {
00161 double left = eval(m_X - lambda);
00162 double right = eval(m_X + lambda);
00163 double delta = fabs(left - right);
00164
00165 if (delta/left > 0.00000000000005)
00166 {
00167 printf(" delta is not zero (%.20f)\n", delta/left);
00168 throw TestFailed();
00169 }
00170
00171 lambda += step;
00172 }
00173 printf("quadratic peak verified\n");
00174 }
00175
00176 private:
00177
00178 double eval(double x) const
00179 {
00180 return x*x*m_A + x*m_B + m_C;
00181 }
00182
00183 double m_A;
00184 double m_B;
00185 double m_C;
00186
00187 double m_X;
00188 double m_Y;
00189 };
00190
00191 void runTest()
00192 {
00193 PolyEqnSolverTest polySolver;
00194 polySolver.addDataPoint(1.0, 32.0);
00195 polySolver.addDataPoint(5.0, 2.0);
00196 polySolver.addDataPoint(9.0, 10.0);
00197 polySolver.solve();
00198
00199 polySolver.clear();
00200 polySolver.addDataPoint(0.0, 2.0);
00201 polySolver.addDataPoint(2.0, 6.0);
00202 polySolver.solve();
00203
00204 polySolver.clear();
00205 polySolver.addDataPoint(1, 12);
00206 polySolver.addDataPoint(2, 8);
00207 polySolver.addDataPoint(3, 7);
00208 polySolver.solve();
00209
00210 try
00211 {
00212 polySolver.clear();
00213 polySolver.solve();
00214 throw TestFailed();
00215 }
00216 catch( const NotEnoughDataPoints& )
00217 {
00218 printf("NotEnoughDataPoints exception caught on zero data point (expected).\n");
00219 }
00220
00221 try
00222 {
00223 polySolver.clear();
00224 polySolver.addDataPoint(1.0, 1.0);
00225 polySolver.solve();
00226 throw TestFailed();
00227 }
00228 catch( const NotEnoughDataPoints& )
00229 {
00230 printf("NotEnoughDataPoints exception caught on 1 data point (expected).\n");
00231 }
00232 }
00233
00234 void runQuadPeakTest()
00235 {
00236 QuadPeakTest qpt;
00237 for (int a = -8; a < 9; ++a)
00238 {
00239 for (int b = -4; b < 5; ++b)
00240 {
00241 for (int c = -10; c < 11; ++c)
00242 {
00243 qpt.set(a, b, c);
00244 qpt.print();
00245 qpt.verify();
00246 }
00247 }
00248 }
00249 }
00250
00251 int main()
00252 {
00253 try
00254 {
00255 runTest();
00256 runQuadPeakTest();
00257 }
00258 catch ( const TestFailed& )
00259 {
00260 printf("***************************\n");
00261 printf("* UNIT TEST FAILED... *\n");
00262 printf("***************************\n");
00263 return 1;
00264 }
00265
00266 printf("***************************\n");
00267 printf("* UNIT TEST PASSED!!! *\n");
00268 printf("***************************\n");
00269 }