det.cpp 4.0 KB


  1. //-----------------------------------------------------------------------------
  2. //
  3. // Input: Matrix on stack
  4. //
  5. // Output: Determinant on stack
  6. //
  7. // Example:
  8. //
  9. // > det(((1,2),(3,4)))
  10. // -2
  11. //
  12. // Note:
  13. //
  14. // Uses Gaussian elimination when every element is a number; otherwise expands the sum of signed products over all permutations (Leibniz formula).
  15. //
  16. //-----------------------------------------------------------------------------
  17. #include "stdafx.h"
  18. #include "defs.h"
  19. static int
  20. check_arg(void)
  21. {
  22. if (!istensor(p1))
  23. return 0;
  24. else if (p1->u.tensor->ndim != 2)
  25. return 0;
  26. else if (p1->u.tensor->dim[0] != p1->u.tensor->dim[1])
  27. return 0;
  28. else
  29. return 1;
  30. }
  31. void
  32. det(void)
  33. {
  34. int i, n;
  35. U **a;
  36. save();
  37. p1 = pop();
  38. if (check_arg() == 0) {
  39. push_symbol(DET);
  40. push(p1);
  41. list(2);
  42. restore();
  43. return;
  44. }
  45. n = p1->u.tensor->nelem;
  46. a = p1->u.tensor->elem;
  47. for (i = 0; i < n; i++)
  48. if (!isnum(a[i]))
  49. break;
  50. if (i == n)
  51. yydetg();
  52. else {
  53. for (i = 0; i < p1->u.tensor->nelem; i++)
  54. push(p1->u.tensor->elem[i]);
  55. determinant(p1->u.tensor->dim[0]);
  56. }
  57. restore();
  58. }
  59. // determinant of n * n matrix elements on the stack
  60. void
  61. determinant(int n)
  62. {
  63. int h, i, j, k, q, s, sign, t;
  64. int *a, *c, *d;
  65. h = tos - n * n;
  66. a = (int *) malloc(3 * n * sizeof (int));
  67. if (a == NULL)
  68. out_of_memory();
  69. c = a + n;
  70. d = c + n;
  71. for (i = 0; i < n; i++) {
  72. a[i] = i;
  73. c[i] = 0;
  74. d[i] = 1;
  75. }
  76. sign = 1;
  77. push(zero);
  78. for (;;) {
  79. if (sign == 1)
  80. push_integer(1);
  81. else
  82. push_integer(-1);
  83. for (i = 0; i < n; i++) {
  84. k = n * a[i] + i;
  85. push(stack[h + k]);
  86. multiply(); // FIXME -- problem here
  87. }
  88. add();
  89. /* next permutation (Knuth's algorithm P) */
  90. j = n - 1;
  91. s = 0;
  92. P4: q = c[j] + d[j];
  93. if (q < 0) {
  94. d[j] = -d[j];
  95. j--;
  96. goto P4;
  97. }
  98. if (q == j + 1) {
  99. if (j == 0)
  100. break;
  101. s++;
  102. d[j] = -d[j];
  103. j--;
  104. goto P4;
  105. }
  106. t = a[j - c[j] + s];
  107. a[j - c[j] + s] = a[j - q + s];
  108. a[j - q + s] = t;
  109. c[j] = q;
  110. sign = -sign;
  111. }
  112. free(a);
  113. stack[h] = stack[tos - 1];
  114. tos = h + 1;
  115. }
  116. //-----------------------------------------------------------------------------
  117. //
  118. // Input: Matrix on stack
  119. //
  120. // Output: Determinant on stack
  121. //
  122. // Note:
  123. //
  124. // Uses Gaussian elimination which is faster for numerical matrices.
  125. //
  126. // Gaussian Elimination works by walking down the diagonal and clearing
  127. // out the columns below it.
  128. //
  129. //-----------------------------------------------------------------------------
  130. void
  131. detg(void)
  132. {
  133. save();
  134. p1 = pop();
  135. if (check_arg() == 0) {
  136. push_symbol(DET);
  137. push(p1);
  138. list(2);
  139. restore();
  140. return;
  141. }
  142. yydetg();
  143. restore();
  144. }
  145. void
  146. yydetg(void)
  147. {
  148. int i, n;
  149. n = p1->u.tensor->dim[0];
  150. for (i = 0; i < n * n; i++)
  151. push(p1->u.tensor->elem[i]);
  152. lu_decomp(n);
  153. tos -= n * n;
  154. push(p1);
  155. }
  156. //-----------------------------------------------------------------------------
  157. //
  158. // Input: n * n matrix elements on stack
  159. //
  160. // Output: p1 determinant
  161. //
  162. // p2 mangled
  163. //
  164. // upper diagonal matrix on stack
  165. //
  166. //-----------------------------------------------------------------------------
  167. #define M(i, j) stack[h + n * (i) + (j)]
  168. void
  169. lu_decomp(int n)
  170. {
  171. int d, h, i, j;
  172. h = tos - n * n;
  173. p1 = one;
  174. for (d = 0; d < n - 1; d++) {
  175. // diagonal element zero?
  176. if (equal(M(d, d), zero)) {
  177. // find a new row
  178. for (i = d + 1; i < n; i++)
  179. if (!equal(M(i, d), zero))
  180. break;
  181. if (i == n) {
  182. p1 = zero;
  183. break;
  184. }
  185. // exchange rows
  186. for (j = d; j < n; j++) {
  187. p2 = M(d, j);
  188. M(d, j) = M(i, j);
  189. M(i, j) = p2;
  190. }
  191. // negate det
  192. push(p1);
  193. negate();
  194. p1 = pop();
  195. }
  196. // update det
  197. push(p1);
  198. push(M(d, d));
  199. multiply();
  200. p1 = pop();
  201. // update lower diagonal matrix
  202. for (i = d + 1; i < n; i++) {
  203. // multiplier
  204. push(M(i, d));
  205. push(M(d, d));
  206. divide();
  207. negate();
  208. p2 = pop();
  209. // update one row
  210. M(i, d) = zero; // clear column below pivot d
  211. for (j = d + 1; j < n; j++) {
  212. push(M(d, j));
  213. push(p2);
  214. multiply();
  215. push(M(i, j));
  216. add();
  217. M(i, j) = pop();
  218. }
  219. }
  220. }
  221. // last diagonal element
  222. push(p1);
  223. push(M(n - 1, n - 1));
  224. multiply();
  225. p1 = pop();
  226. }