LevenbergMarquardt.php

<?php
// Levenberg-Marquardt in PHP
// http://www.idiom.com/~zilla/Computer/Javanumeric/LM.java
class LevenbergMarquardt {
    /**
     * Calculate the current sum-squared-error.
     *
     * Chi-squared is the distribution of squared Gaussian errors,
     * thus the name.
     *
     * @param double[][] $x domain points
     * @param double[]   $a model parameters
     * @param double[]   $y observed values
     * @param double[]   $s error scale (sigma) for each point
     * @param object     $f model implementing val($x, $a)
     */
    function chiSquared($x, $a, $y, $s, $f) {
        $npts = count($y);
        $sum = 0.0;
        for ($i = 0; $i < $npts; ++$i) {
            $d = $y[$i] - $f->val($x[$i], $a);
            $d = $d / $s[$i];
            $sum = $sum + ($d * $d);
        }
        return $sum;
    } // function chiSquared()
    /**
     * Minimize E = sum {(y[k] - f(x[k],a)) / s[k]}^2
     * The individual errors are optionally scaled by s[k].
     * Note that LMfunc implements the value and gradient of f(x,a),
     * NOT the value and gradient of E with respect to a!
     *
     * @param x array of domain points, each may be multidimensional
     * @param a the parameters/state of the model; taken by reference so
     *          accepted steps are visible to the caller
     * @param y corresponding array of values
     * @param s sigma for point i (each error is scaled by 1/s[i])
     * @param vary false to indicate the corresponding a[k] is to be held fixed
     * @param f the model, implementing val($x, $a) and grad($x, $a, $k)
     * @param lambda blend between steepest descent (lambda high) and
     *        jump to bottom of quadratic (lambda zero).
     *        Start with 0.001.
     * @param termepsilon termination accuracy (0.01)
     * @param maxiter stop and return after this many iterations if not done
     * @param verbose set to zero (no prints), 1, 2
     *
     * @return the new lambda for future iterations.
     *         Can use this and maxiter to interleave the LM descent with some
     *         other task, setting maxiter to something small.
     */
    function solve($x, &$a, $y, $s, $vary, $f, $lambda, $termepsilon, $maxiter, $verbose) {
        $npts = count($y);
        $nparm = count($a);
        if ($verbose > 0) {
            print("solve x[".count($x)."][".count($x[0])."]");
            print(" a[".count($a)."]");
            print(" y[".count($y)."]\n");
        }
        $e0 = $this->chiSquared($x, $a, $y, $s, $f);
        $done = false;

        // g = gradient, H = hessian, d = step to minimum
        // H d = -g, solve for d
        $H = array();
        $g = array();

        // precompute 1/sigma^2 for each point
        $oos2 = array();
        for ($i = 0; $i < $npts; ++$i) {
            $oos2[$i] = 1. / ($s[$i] * $s[$i]);
        }

        $iter = 0;
        $term = 0; // termination count test
        do {
            ++$iter;

            // hessian approximation
            for ($r = 0; $r < $nparm; ++$r) {
                for ($c = 0; $c < $nparm; ++$c) {
                    for ($i = 0; $i < $npts; ++$i) {
                        if ($i == 0) $H[$r][$c] = 0.;
                        $xi = $x[$i];
                        $H[$r][$c] += ($oos2[$i] * $f->grad($xi, $a, $r) * $f->grad($xi, $a, $c));
                    } // npts
                } // c
            } // r

            // boost diagonal towards gradient descent
            for ($r = 0; $r < $nparm; ++$r) {
                $H[$r][$r] *= (1. + $lambda);
            }

            // gradient
            for ($r = 0; $r < $nparm; ++$r) {
                for ($i = 0; $i < $npts; ++$i) {
                    if ($i == 0) $g[$r] = 0.;
                    $xi = $x[$i];
                    $g[$r] += ($oos2[$i] * ($y[$i] - $f->val($xi, $a)) * $f->grad($xi, $a, $r));
                } // npts
            } // r
            // scale (for consistency with NR, not necessary)
            if (false) {
                for ($r = 0; $r < $nparm; ++$r) {
                    $g[$r] = -0.5 * $g[$r];
                    for ($c = 0; $c < $nparm; ++$c) {
                        $H[$r][$c] *= 0.5;
                    }
                }
            }
            // solve H d = -g, evaluate error at new location
            // (the Java original used JAMA's Matrix.lu().solve();
            // here we use the solveLinear() helper below)
            $d = $this->solveLinear($H, $g);
            $na = array();
            for ($i = 0; $i < $nparm; ++$i) {
                $na[$i] = $a[$i] + $d[$i];
            }
            $e1 = $this->chiSquared($x, $na, $y, $s, $f);

            if ($verbose > 0) {
                print("\n\niteration ".$iter." lambda = ".$lambda."\n");
                print("a = [".implode(", ", $a)."]\n");
                if ($verbose > 1) {
                    print("g = [".implode(", ", $g)."]\n");
                    print("d = [".implode(", ", $d)."]\n");
                }
                print("e0 = ".$e0.", e1 = ".$e1);
                print($e1 < $e0 ? ": move accepted\n" : ": move rejected\n");
            }

            // termination test (slightly different than NR)
            if (abs($e1 - $e0) > $termepsilon) {
                $term = 0;
            } else {
                ++$term;
                if ($term == 4) {
                    if ($verbose > 0) {
                        print("terminating after ".$iter." iterations\n");
                    }
                    $done = true;
                }
            }
            if ($iter >= $maxiter) $done = true;

            // in the C++ version, found that changing this to e1 >= e0
            // was not a good idea. See comment there.
            if ($e1 > $e0 || is_nan($e1)) {
                // new location worse than before
                $lambda *= 10.;
            } else {
                // new location better, accept new parameters
                $lambda *= 0.1;
                $e0 = $e1;
                // copy element-wise: $a is taken by reference, so this is
                // how results get back to the caller
                for ($i = 0; $i < $nparm; ++$i) {
                    if ($vary[$i]) $a[$i] = $na[$i];
                }
            }
        } while (!$done);

        return $lambda;
    } // function solve()
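
    /**
     * Solve the linear system A d = b by Gaussian elimination with partial
     * pivoting. The Java original delegated this step to JAMA's
     * Matrix.lu().solve(); this helper is an addition made for the PHP port
     * (it is not in the original source) so the class is self-contained.
     * A minimal sketch: no handling of a singular or ill-conditioned matrix.
     */
    function solveLinear($A, $b) {
        $n = count($b);
        $M = $A; // PHP copies arrays on assignment, so $A and $b are untouched
        $v = $b;
        for ($col = 0; $col < $n; ++$col) {
            // partial pivoting: bring the largest remaining pivot into place
            $pivot = $col;
            for ($r = $col + 1; $r < $n; ++$r) {
                if (abs($M[$r][$col]) > abs($M[$pivot][$col])) $pivot = $r;
            }
            if ($pivot != $col) {
                $tmp = $M[$col]; $M[$col] = $M[$pivot]; $M[$pivot] = $tmp;
                $tmp = $v[$col]; $v[$col] = $v[$pivot]; $v[$pivot] = $tmp;
            }
            // eliminate entries below the pivot
            for ($r = $col + 1; $r < $n; ++$r) {
                $factor = $M[$r][$col] / $M[$col][$col];
                for ($c = $col; $c < $n; ++$c) {
                    $M[$r][$c] -= $factor * $M[$col][$c];
                }
                $v[$r] -= $factor * $v[$col];
            }
        }
        // back-substitution
        $d = array_fill(0, $n, 0.0);
        for ($r = $n - 1; $r >= 0; --$r) {
            $sum = $v[$r];
            for ($c = $r + 1; $c < $n; ++$c) {
                $sum -= $M[$r][$c] * $d[$c];
            }
            $d[$r] = $sum / $M[$r][$r];
        }
        return $d;
    } // function solveLinear()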
} // class LevenbergMarquardt
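
// --- Example usage ----------------------------------------------------------
// A sketch added for illustration; QuadraticModel and the sample data are
// hypothetical, not part of the original source. Any model object works as
// long as it exposes val($x, $a) and grad($x, $a, $k), i.e. the value of
// f(x,a) and its partial derivative with respect to a[k], matching the
// LMfunc interface referenced in the solve() docblock.
class QuadraticModel {
    // f(x, a) = a[0] + a[1]*x + a[2]*x^2, where each point x has one element
    function val($x, $a) {
        return $a[0] + $a[1] * $x[0] + $a[2] * $x[0] * $x[0];
    }
    // df/da[k]
    function grad($x, $a, $k) {
        if ($k == 0) return 1.0;
        if ($k == 1) return $x[0];
        return $x[0] * $x[0];
    }
}

/*
$x = array(array(0.0), array(1.0), array(2.0), array(3.0));
$y = array(1.0, 2.1, 5.0, 10.2);        // roughly 1 + x^2
$a = array(0.0, 0.0, 0.0);              // initial guess, updated in place
$s = array(1.0, 1.0, 1.0, 1.0);         // per-point sigma
$vary = array(true, true, true);        // fit all three parameters

$lm = new LevenbergMarquardt();
$lambda = $lm->solve($x, $a, $y, $s, $vary, new QuadraticModel(),
                     0.001, 0.01, 100, 0);
// $a now holds the fitted coefficients; $lambda can seed further iterations.
*/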