inv.cpp 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. //-----------------------------------------------------------------------------
  2. //
  3. // Input: Matrix on stack
  4. //
  5. // Output: Inverse on stack
  6. //
  7. // Example:
  8. //
  9. // > inv(((1,2),(3,4))
  10. // ((-2,1),(3/2,-1/2))
  11. //
  12. // Note:
  13. //
  14. // Uses Gaussian elimination for numerical matrices.
  15. //
  16. //-----------------------------------------------------------------------------
  17. #include "stdafx.h"
  18. #include "defs.h"
  19. static int
  20. check_arg(void)
  21. {
  22. if (!istensor(p1))
  23. return 0;
  24. else if (p1->u.tensor->ndim != 2)
  25. return 0;
  26. else if (p1->u.tensor->dim[0] != p1->u.tensor->dim[1])
  27. return 0;
  28. else
  29. return 1;
  30. }
  31. void
  32. inv(void)
  33. {
  34. int i, n;
  35. U **a;
  36. save();
  37. p1 = pop();
  38. if (check_arg() == 0) {
  39. push_symbol(INV);
  40. push(p1);
  41. list(2);
  42. restore();
  43. return;
  44. }
  45. n = p1->u.tensor->nelem;
  46. a = p1->u.tensor->elem;
  47. for (i = 0; i < n; i++)
  48. if (!isnum(a[i]))
  49. break;
  50. if (i == n)
  51. yyinvg();
  52. else {
  53. push(p1);
  54. adj();
  55. push(p1);
  56. det();
  57. p2 = pop();
  58. if (iszero(p2))
  59. stop("inverse of singular matrix");
  60. push(p2);
  61. divide();
  62. }
  63. restore();
  64. }
  65. void
  66. invg(void)
  67. {
  68. save();
  69. p1 = pop();
  70. if (check_arg() == 0) {
  71. push_symbol(INVG);
  72. push(p1);
  73. list(2);
  74. restore();
  75. return;
  76. }
  77. yyinvg();
  78. restore();
  79. }
  80. // inverse using gaussian elimination
  81. void
  82. yyinvg(void)
  83. {
  84. int h, i, j, n;
  85. n = p1->u.tensor->dim[0];
  86. h = tos;
  87. for (i = 0; i < n; i++)
  88. for (j = 0; j < n; j++)
  89. if (i == j)
  90. push(one);
  91. else
  92. push(zero);
  93. for (i = 0; i < n * n; i++)
  94. push(p1->u.tensor->elem[i]);
  95. decomp(n);
  96. p1 = alloc_tensor(n * n);
  97. p1->u.tensor->ndim = 2;
  98. p1->u.tensor->dim[0] = n;
  99. p1->u.tensor->dim[1] = n;
  100. for (i = 0; i < n * n; i++)
  101. p1->u.tensor->elem[i] = stack[h + i];
  102. tos -= 2 * n * n;
  103. push(p1);
  104. }
  105. //-----------------------------------------------------------------------------
  106. //
  107. // Input: n * n unit matrix on stack
  108. //
  109. // n * n operand on stack
  110. //
  111. // Output: n * n inverse matrix on stack
  112. //
  113. // n * n garbage on stack
  114. //
  115. // p2 mangled
  116. //
  117. //-----------------------------------------------------------------------------
  118. #define A(i, j) stack[a + n * (i) + (j)]
  119. #define U(i, j) stack[u + n * (i) + (j)]
  120. void
  121. decomp(int n)
  122. {
  123. int a, d, i, j, u;
  124. a = tos - n * n;
  125. u = a - n * n;
  126. for (d = 0; d < n; d++) {
  127. // diagonal element zero?
  128. if (equal(A(d, d), zero)) {
  129. // find a new row
  130. for (i = d + 1; i < n; i++)
  131. if (!equal(A(i, d), zero))
  132. break;
  133. if (i == n)
  134. stop("inverse of singular matrix");
  135. // exchange rows
  136. for (j = 0; j < n; j++) {
  137. p2 = A(d, j);
  138. A(d, j) = A(i, j);
  139. A(i, j) = p2;
  140. p2 = U(d, j);
  141. U(d, j) = U(i, j);
  142. U(i, j) = p2;
  143. }
  144. }
  145. // multiply the pivot row by 1 / pivot
  146. p2 = A(d, d);
  147. for (j = 0; j < n; j++) {
  148. if (j > d) {
  149. push(A(d, j));
  150. push(p2);
  151. divide();
  152. A(d, j) = pop();
  153. }
  154. push(U(d, j));
  155. push(p2);
  156. divide();
  157. U(d, j) = pop();
  158. }
  159. // clear out the column above and below the pivot
  160. for (i = 0; i < n; i++) {
  161. if (i == d)
  162. continue;
  163. // multiplier
  164. p2 = A(i, d);
  165. // add pivot row to i-th row
  166. for (j = 0; j < n; j++) {
  167. if (j > d) {
  168. push(A(i, j));
  169. push(A(d, j));
  170. push(p2);
  171. multiply();
  172. subtract();
  173. A(i, j) = pop();
  174. }
  175. push(U(i, j));
  176. push(U(d, j));
  177. push(p2);
  178. multiply();
  179. subtract();
  180. U(i, j) = pop();
  181. }
  182. }
  183. }
  184. }