tensor.cpp 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. #include "stdafx.h"
  2. #include "defs.h"
  3. static void promote_tensor(void);
  4. static int compatible(U *, U *);
  5. //-----------------------------------------------------------------------------
  6. //
  7. // Called from the "eval" module to evaluate tensor elements.
  8. //
  9. // p1 points to the tensor operand.
  10. //
  11. //-----------------------------------------------------------------------------
  12. void
  13. eval_tensor(void)
  14. {
  15. int i, ndim, nelem;
  16. U **a, **b;
  17. //---------------------------------------------------------------------
  18. //
  19. // create a new tensor for the result
  20. //
  21. //---------------------------------------------------------------------
  22. nelem = p1->u.tensor->nelem;
  23. ndim = p1->u.tensor->ndim;
  24. p2 = alloc_tensor(nelem);
  25. p2->u.tensor->ndim = ndim;
  26. for (i = 0; i < ndim; i++)
  27. p2->u.tensor->dim[i] = p1->u.tensor->dim[i];
  28. //---------------------------------------------------------------------
  29. //
  30. // b = eval(a)
  31. //
  32. //---------------------------------------------------------------------
  33. a = p1->u.tensor->elem;
  34. b = p2->u.tensor->elem;
  35. for (i = 0; i < nelem; i++) {
  36. push(a[i]);
  37. eval();
  38. b[i] = pop();
  39. }
  40. //---------------------------------------------------------------------
  41. //
  42. // push the result
  43. //
  44. //---------------------------------------------------------------------
  45. push(p2);
  46. promote_tensor();
  47. }
  48. //-----------------------------------------------------------------------------
  49. //
  50. // Add tensors
  51. //
  52. // Input: Operands on stack
  53. //
  54. // Output: Result on stack
  55. //
  56. //-----------------------------------------------------------------------------
  57. void
  58. tensor_plus_tensor(void)
  59. {
  60. int i, ndim, nelem;
  61. U **a, **b, **c;
  62. save();
  63. p2 = pop();
  64. p1 = pop();
  65. // are the dimension lists equal?
  66. ndim = p1->u.tensor->ndim;
  67. if (ndim != p2->u.tensor->ndim) {
  68. push(symbol(NIL));
  69. restore();
  70. return;
  71. }
  72. for (i = 0; i < ndim; i++)
  73. if (p1->u.tensor->dim[i] != p2->u.tensor->dim[i]) {
  74. push(symbol(NIL));
  75. restore();
  76. return;
  77. }
  78. // create a new tensor for the result
  79. nelem = p1->u.tensor->nelem;
  80. p3 = alloc_tensor(nelem);
  81. p3->u.tensor->ndim = ndim;
  82. for (i = 0; i < ndim; i++)
  83. p3->u.tensor->dim[i] = p1->u.tensor->dim[i];
  84. // c = a + b
  85. a = p1->u.tensor->elem;
  86. b = p2->u.tensor->elem;
  87. c = p3->u.tensor->elem;
  88. for (i = 0; i < nelem; i++) {
  89. push(a[i]);
  90. push(b[i]);
  91. add();
  92. c[i] = pop();
  93. }
  94. // push the result
  95. push(p3);
  96. restore();
  97. }
  98. //-----------------------------------------------------------------------------
  99. //
  100. // careful not to reorder factors
  101. //
  102. //-----------------------------------------------------------------------------
  103. void
  104. tensor_times_scalar(void)
  105. {
  106. int i, ndim, nelem;
  107. U **a, **b;
  108. save();
  109. p2 = pop();
  110. p1 = pop();
  111. ndim = p1->u.tensor->ndim;
  112. nelem = p1->u.tensor->nelem;
  113. p3 = alloc_tensor(nelem);
  114. p3->u.tensor->ndim = ndim;
  115. for (i = 0; i < ndim; i++)
  116. p3->u.tensor->dim[i] = p1->u.tensor->dim[i];
  117. a = p1->u.tensor->elem;
  118. b = p3->u.tensor->elem;
  119. for (i = 0; i < nelem; i++) {
  120. push(a[i]);
  121. push(p2);
  122. multiply();
  123. b[i] = pop();
  124. }
  125. push(p3);
  126. restore();
  127. }
  128. void
  129. scalar_times_tensor(void)
  130. {
  131. int i, ndim, nelem;
  132. U **a, **b;
  133. save();
  134. p2 = pop();
  135. p1 = pop();
  136. ndim = p2->u.tensor->ndim;
  137. nelem = p2->u.tensor->nelem;
  138. p3 = alloc_tensor(nelem);
  139. p3->u.tensor->ndim = ndim;
  140. for (i = 0; i < ndim; i++)
  141. p3->u.tensor->dim[i] = p2->u.tensor->dim[i];
  142. a = p2->u.tensor->elem;
  143. b = p3->u.tensor->elem;
  144. for (i = 0; i < nelem; i++) {
  145. push(p1);
  146. push(a[i]);
  147. multiply();
  148. b[i] = pop();
  149. }
  150. push(p3);
  151. restore();
  152. }
  153. int
  154. is_square_matrix(U *p)
  155. {
  156. if (istensor(p) && p->u.tensor->ndim == 2 && p->u.tensor->dim[0] == p->u.tensor->dim[1])
  157. return 1;
  158. else
  159. return 0;
  160. }
  161. //-----------------------------------------------------------------------------
  162. //
  163. // gradient of tensor
  164. //
  165. //-----------------------------------------------------------------------------
  166. void
  167. d_tensor_tensor(void)
  168. {
  169. int i, j, ndim, nelem;
  170. U **a, **b, **c;
  171. ndim = p1->u.tensor->ndim;
  172. nelem = p1->u.tensor->nelem;
  173. if (ndim + 1 >= MAXDIM)
  174. goto dont_evaluate;
  175. p3 = alloc_tensor(nelem * p2->u.tensor->nelem);
  176. p3->u.tensor->ndim = ndim + 1;
  177. for (i = 0; i < ndim; i++)
  178. p3->u.tensor->dim[i] = p1->u.tensor->dim[i];
  179. p3->u.tensor->dim[ndim] = p2->u.tensor->dim[0];
  180. a = p1->u.tensor->elem;
  181. b = p2->u.tensor->elem;
  182. c = p3->u.tensor->elem;
  183. for (i = 0; i < nelem; i++) {
  184. for (j = 0; j < p2->u.tensor->nelem; j++) {
  185. push(a[i]);
  186. push(b[j]);
  187. derivative();
  188. c[i * p2->u.tensor->nelem + j] = pop();
  189. }
  190. }
  191. push(p3);
  192. return;
  193. dont_evaluate:
  194. push_symbol(DERIVATIVE);
  195. push(p1);
  196. push(p2);
  197. list(3);
  198. }
  199. //-----------------------------------------------------------------------------
  200. //
  201. // gradient of scalar
  202. //
  203. //-----------------------------------------------------------------------------
  204. void
  205. d_scalar_tensor(void)
  206. {
  207. int i;
  208. U **a, **b;
  209. p3 = alloc_tensor(p2->u.tensor->nelem);
  210. p3->u.tensor->ndim = 1;
  211. p3->u.tensor->dim[0] = p2->u.tensor->dim[0];
  212. a = p2->u.tensor->elem;
  213. b = p3->u.tensor->elem;
  214. for (i = 0; i < p2->u.tensor->nelem; i++) {
  215. push(p1);
  216. push(a[i]);
  217. derivative();
  218. b[i] = pop();
  219. }
  220. push(p3);
  221. }
  222. //-----------------------------------------------------------------------------
  223. //
  224. // Derivative of tensor
  225. //
  226. //-----------------------------------------------------------------------------
  227. void
  228. d_tensor_scalar(void)
  229. {
  230. int i;
  231. U **a, **b;
  232. p3 = alloc_tensor(p1->u.tensor->nelem);
  233. p3->u.tensor->ndim = p1->u.tensor->ndim;
  234. for (i = 0; i < p1->u.tensor->ndim; i++)
  235. p3->u.tensor->dim[i] = p1->u.tensor->dim[i];
  236. a = p1->u.tensor->elem;
  237. b = p3->u.tensor->elem;
  238. for (i = 0; i < p1->u.tensor->nelem; i++) {
  239. push(a[i]);
  240. push(p2);
  241. derivative();
  242. b[i] = pop();
  243. }
  244. push(p3);
  245. }
  246. int
  247. compare_tensors(U *p1, U *p2)
  248. {
  249. int i;
  250. if (p1->u.tensor->ndim < p2->u.tensor->ndim)
  251. return -1;
  252. if (p1->u.tensor->ndim > p2->u.tensor->ndim)
  253. return 1;
  254. for (i = 0; i < p1->u.tensor->ndim; i++) {
  255. if (p1->u.tensor->dim[i] < p2->u.tensor->dim[i])
  256. return -1;
  257. if (p1->u.tensor->dim[i] > p2->u.tensor->dim[i])
  258. return 1;
  259. }
  260. for (i = 0; i < p1->u.tensor->nelem; i++) {
  261. if (equal(p1->u.tensor->elem[i], p2->u.tensor->elem[i]))
  262. continue;
  263. if (lessp(p1->u.tensor->elem[i], p2->u.tensor->elem[i]))
  264. return -1;
  265. else
  266. return 1;
  267. }
  268. return 0;
  269. }
  270. //-----------------------------------------------------------------------------
  271. //
  272. // Raise a tensor to a power
  273. //
  274. // Input: p1 tensor
  275. //
  276. // p2 exponent
  277. //
  278. // Output: Result on stack
  279. //
  280. //-----------------------------------------------------------------------------
  281. void
  282. power_tensor(void)
  283. {
  284. int i, k, n;
  285. // first and last dims must be equal
  286. k = p1->u.tensor->ndim - 1;
  287. if (p1->u.tensor->dim[0] != p1->u.tensor->dim[k]) {
  288. push_symbol(POWER);
  289. push(p1);
  290. push(p2);
  291. list(3);
  292. return;
  293. }
  294. push(p2);
  295. n = pop_integer();
  296. if (n == (int) 0x80000000) {
  297. push_symbol(POWER);
  298. push(p1);
  299. push(p2);
  300. list(3);
  301. return;
  302. }
  303. if (n == 0) {
  304. if (p1->u.tensor->ndim != 2)
  305. stop("power(tensor,0) with tensor rank not equal to 2");
  306. n = p1->u.tensor->dim[0];
  307. p1 = alloc_tensor(n * n);
  308. p1->u.tensor->ndim = 2;
  309. p1->u.tensor->dim[0] = n;
  310. p1->u.tensor->dim[1] = n;
  311. for (i = 0; i < n; i++)
  312. p1->u.tensor->elem[n * i + i] = one;
  313. push(p1);
  314. return;
  315. }
  316. if (n < 0) {
  317. n = -n;
  318. push(p1);
  319. inv();
  320. p1 = pop();
  321. }
  322. push(p1);
  323. for (i = 1; i < n; i++) {
  324. push(p1);
  325. inner();
  326. if (iszero(stack[tos - 1]))
  327. break;
  328. }
  329. }
  330. void
  331. copy_tensor(void)
  332. {
  333. int i;
  334. save();
  335. p1 = pop();
  336. p2 = alloc_tensor(p1->u.tensor->nelem);
  337. p2->u.tensor->ndim = p1->u.tensor->ndim;
  338. for (i = 0; i < p1->u.tensor->ndim; i++)
  339. p2->u.tensor->dim[i] = p1->u.tensor->dim[i];
  340. for (i = 0; i < p1->u.tensor->nelem; i++)
  341. p2->u.tensor->elem[i] = p1->u.tensor->elem[i];
  342. push(p2);
  343. restore();
  344. }
  345. // Tensors with elements that are also tensors get promoted to a higher rank.
  346. static void
  347. promote_tensor(void)
  348. {
  349. int i, j, k, nelem, ndim;
  350. save();
  351. p1 = pop();
  352. if (!istensor(p1)) {
  353. push(p1);
  354. restore();
  355. return;
  356. }
  357. p2 = p1->u.tensor->elem[0];
  358. for (i = 1; i < p1->u.tensor->nelem; i++)
  359. if (!compatible(p2, p1->u.tensor->elem[i]))
  360. stop("Cannot promote tensor due to inconsistent tensor components.");
  361. if (!istensor(p2)) {
  362. push(p1);
  363. restore();
  364. return;
  365. }
  366. ndim = p1->u.tensor->ndim + p2->u.tensor->ndim;
  367. if (ndim > MAXDIM)
  368. stop("tensor rank > 24");
  369. nelem = p1->u.tensor->nelem * p2->u.tensor->nelem;
  370. p3 = alloc_tensor(nelem);
  371. p3->u.tensor->ndim = ndim;
  372. for (i = 0; i < p1->u.tensor->ndim; i++)
  373. p3->u.tensor->dim[i] = p1->u.tensor->dim[i];
  374. for (j = 0; j < p2->u.tensor->ndim; j++)
  375. p3->u.tensor->dim[i + j] = p2->u.tensor->dim[j];
  376. k = 0;
  377. for (i = 0; i < p1->u.tensor->nelem; i++) {
  378. p2 = p1->u.tensor->elem[i];
  379. for (j = 0; j < p2->u.tensor->nelem; j++)
  380. p3->u.tensor->elem[k++] = p2->u.tensor->elem[j];
  381. }
  382. push(p3);
  383. restore();
  384. }
  385. static int
  386. compatible(U *p, U *q)
  387. {
  388. int i;
  389. if (!istensor(p) && !istensor(q))
  390. return 1;
  391. if (!istensor(p) || !istensor(q))
  392. return 0;
  393. if (p->u.tensor->ndim != q->u.tensor->ndim)
  394. return 0;
  395. for (i = 0; i < p->u.tensor->ndim; i++)
  396. if (p->u.tensor->dim[i] != q->u.tensor->dim[i])
  397. return 0;
  398. return 1;
  399. }
  #if SELFTEST

// Self-test script for this module, handed to test() below. After the
// leading "#test_tensor" entry, the strings appear to be
// (input, expected output) pairs, with "" expected after assignments
// (TODO confirm against test()).
static char *s[] = {

	"#test_tensor",

	"a=(1,2,3)",		"",
	"b=(4,5,6)",		"",
	"c=(7,8,9)",		"",

	"rank((a,b,c))",	"2",

	"(a,b,c)",		"((1,2,3),(4,5,6),(7,8,9))",

	// check tensor promotion
	"((1,0),(0,0))",	"((1,0),(0,0))",

	// clear the symbols defined above
	"a=quote(a)",		"",
	"b=quote(b)",		"",
	"c=quote(c)",		"",
};

void
test_tensor(void)
{
	test(__FILE__, s, sizeof s / sizeof s[0]);
}

#endif