// ra-ply.H
  1. // (c) Daniel Llorens - 2013-2014
  2. // This library is free software; you can redistribute it and/or modify it under
  3. // the terms of the GNU Lesser General Public License as published by the Free
  4. // Software Foundation; either version 3 of the License, or (at your option) any
  5. // later version.
  6. #ifndef RA_PLY_H
  7. #define RA_PLY_H
  8. /// @file ra-ply.H
  9. /// @brief Traverse (ply) array or array expression or array statement.
  10. // @TODO Lots of room for improvement: small (fixed sizes) and large (tiling, etc. see eval.cc in Blitz++).
#include "ra/ra-type.H"
#include <cassert>
#include <functional>
#include <type_traits>
  13. namespace ra {
// @TODO this to protect against older convention in vtraits; eventually remove.
// rank_t and dim_t must both be signed — presumably so sentinel values (e.g.
// DIM_ANY, used below) and negative stride arguments to adv() are representable;
// confirm against ra-type.H.
static_assert(mp::And<std::is_signed<rank_t>, std::is_signed<dim_t>>::value, "bad rank_t");
  16. // --------------
  17. // Run time order, two versions.
  18. // --------------
  19. // @TODO See ply_ravel() for traversal order.
  20. // @TODO A(i0, i1 ...) can be partial-applied as A(i0)(i1 ...) for faster indexing
  21. // @TODO Traversal order should be a parameter, since some operations (e.g. output, ravel) require a specific order.
  22. template <class A>
  23. void ply_index(A && a)
  24. {
  25. // @TODO try and merge the singular cases.
  26. if (a.done()) {
  27. return;
  28. }
  29. rank_t const rank = a.rank();
  30. auto sha(a.shape());
  31. using Shape = decltype(sha);
  32. Shape ind(ra_traits<Shape>::make(rank, 0));
  33. rank_t order[rank];
  34. for (rank_t i=0; i<rank; ++i) {
  35. order[i] = rank-1-i;
  36. }
  37. for (;;) {
  38. a.at(ind);
  39. for (int k=0; ; ++k) {
  40. if (k==rank) {
  41. return;
  42. } else if (++ind[order[k]]<sha[order[k]]) {
  43. break;
  44. } else {
  45. ind[order[k]] = 0;
  46. }
  47. }
  48. }
  49. }
// Traverse array expression looking to ravel the inner loop.
// size() and preferred_stride() are only used on the driving argument (largest rank).
// adv(), stride(), compact_stride() and flat() are used on all the leaf arguments. The strides must give 0 for k>=their own rank, to allow frame matching.
// @TODO Traversal order should be a parameter, since some operations (e.g. output, ravel) require a specific order.
template <class A>
void ply_ravel(A && a)
{
    // flat()+step would not see the zero sizes above the axes used for step, so we must check for empty expr anyway.
    if (a.done()) {
        return;
    }
    rank_t const rank = a.rank();
    if (rank==0) {
        // rank 0: a single element, no loop machinery needed.
        *(a.flat());
        return;
    }
    // traversal order: order[0] is the innermost (last) axis, C order.
    rank_t order[rank];
    for (rank_t i=0; i<rank; ++i) {
        order[i] = rank-1-i;
    }
    // find outermost compact dim: extend the raveled step across consecutive
    // axes as long as compact_stride() says every leaf is contiguous there.
    auto step = a.size(order[0]);
    int ocd = 1;
    for (; ocd!=rank && a.compact_stride(step, order[0], order[ocd]); ++ocd) {
        step *= a.size(order[ocd]);
    }
    step *= a.preferred_stride(order[0]);
    // all sub xpr strides advance in compact dims, as they might be different.
    auto const ss0(a.stride(order[0]));
    // @TODO don't need the full sha or ind. But try to use them to do ply_index more efficiently.
    auto sha(a.shape());
    using Shape = decltype(sha);
    Shape ind(ra_traits<Shape>::make(rank, 0));
    // @TODO Blitz++ uses explicit stack of end-of-dim p positions, has special cases for common/unit stride.
    for (;;) {
        // inner loop: walk the compact (raveled) axes with a flat pointer.
        auto p = a.flat();
        for (auto end=p+step; p!=end; p+=ss0) {
            *p;
        }
        // odometer over the non-raveled axes, starting above the compact block.
        for (int k=ocd; ; ++k) {
            if (k>=rank) {
                return;
            } else if (ind[order[k]]<sha[order[k]]-1) {
                ++ind[order[k]];
                a.adv(order[k], 1);
                break;
            } else {
                ind[order[k]] = 0;
                // rewind this axis back to 0 and carry to the next one.
                a.adv(order[k], 1-sha[order[k]]);
            }
        }
    }
}
  103. // -------------------------
  104. // Compile time order. See bench-ra-dot.C for use. Index version.
  105. // -------------------------
  106. template <class order, class A, class S>
  107. std::enable_if_t<mp::Len<order>::value==0>
  108. subindexf(A & a, S & s, S & i)
  109. {
  110. a.at(i);
  111. }
  112. template <class order, class A, class S>
  113. std::enable_if_t<(mp::Len<order>::value>0)>
  114. subindexf(A & a, S & s_, S & i_)
  115. {
  116. dim_t & i = i_[mp::First_<order>::value];
  117. dim_t const s = s_[mp::First_<order>::value];
  118. for (i=0; i!=s; ++i) {
  119. subindexf<mp::Drop1_<order>>(a, s_, i_);
  120. }
  121. }
  122. template <class A>
  123. void plyf_index(A && a)
  124. {
  125. auto s(a.shape());
  126. using Shape = decltype(s);
  127. Shape i(ra_traits<Shape>::make(s.size(), 0));
  128. subindexf<mp::Iota_<A::rank_s()>>(a, s, i); // cf with ply_index() for C order.
  129. }
  130. // -------------------------
  131. // Compile time order. See bench-array-dot-ra.C for use. No index version.
  132. // With compile-time recursion by rank, one can use adv<k>, but order must also be compile-time.
  133. // -------------------------
  134. template <class order, int ravel_rank, class A, class S>
  135. std::enable_if_t<mp::Len<order>::value==ravel_rank>
  136. subindex(A & a, dim_t const s, S const & ss0)
  137. {
  138. auto p = a.flat();
  139. for (auto end=p+s; p!=end; p+=ss0) {
  140. *p;
  141. }
  142. }
  143. template <class order, int ravel_rank, class A, class S>
  144. std::enable_if_t<(mp::Len<order>::value>ravel_rank)>
  145. subindex(A & a, dim_t const s, S const & ss0)
  146. {
  147. dim_t size = a.size(mp::First_<order>::value); // @TODO Precompute these at the top
  148. for (dim_t i=0, iend=size; i<iend; ++i) {
  149. subindex<mp::Drop1_<order>, ravel_rank>(a, s, ss0);
  150. a.adv(mp::First_<order>::value, 1);
  151. }
  152. a.adv(mp::First_<order>::value, -size);
  153. }
  154. // until() converts runtime jj into compile time j. @TODO a.adv<k>().
  155. template <class order, int j, class A, class S>
  156. std::enable_if_t<(mp::Len<order>::value<j)>
  157. until(int const jj, A & a, dim_t const s, S const & ss0)
  158. {
  159. assert(0 && "rank too high");
  160. }
  161. template <class order, int j, class A, class S>
  162. std::enable_if_t<(mp::Len<order>::value>=j)>
  163. until(int const jj, A & a, dim_t const s, S const & ss0)
  164. {
  165. if (jj==j) {
  166. subindex<order, j>(a, s, ss0);
  167. } else {
  168. until<order, j+1>(jj, a, s, ss0);
  169. }
  170. }
// plyf, rank-0 case: a single element, just evaluate it. The enable_if accepts
// rank_s()<=0, but a negative static rank (dynamic-rank sentinel, presumably —
// confirm against ra-type.H) is rejected by the static_assert below.
template <class A>
auto plyf(A && a) -> std::enable_if_t<(A::rank_s()<=0)>
{
    static_assert(A::rank_s()==0, "plyf needs static rank");
    *(a.flat());
}
  177. template <class A>
  178. auto plyf(A && a) -> std::enable_if_t<(A::rank_s()==1)>
  179. {
  180. subindex<mp::Iota_<1>, 1>(a, a.size(0)*a.preferred_stride(0), a.stride(0));
  181. }
  182. template <class A>
  183. auto plyf(A && a) -> std::enable_if_t<(A::rank_s()>1)>
  184. {
  185. rank_t const rank = a.rank();
  186. // find the outermost compact dim.
  187. auto step = a.size(rank-1);
  188. int j = 1;
  189. while (j!=rank && a.compact_stride(step, rank-1, rank-1-j)) {
  190. step *= a.size(rank-1-j);
  191. ++j;
  192. }
  193. step *= a.preferred_stride(rank-1);
  194. // all sub xpr strides advance in compact dims, as they might be different.
  195. // send with static j. Note that order here is inverse of order.
  196. until<mp::Iota_<A::rank_s()>, 0>(j, a, step, a.stride(rank-1));
  197. }
// ---------------------------
// Selectors, best performance for each type.
// ---------------------------
// Expressions that contain a tensorindex term need the explicit index vector,
// so they must go through ply_index().
template <class A>
enableif_<has_tensorindex<std::decay_t<A>>>
ply_either(A && a)
{
    ply_index(std::forward<A>(a));
}
  207. template <class A>
  208. std::enable_if_t<!has_tensorindex<std::decay_t<A>>::value && (A::size_s()==DIM_ANY || (A::rank_s()!=0 && A::rank_s()!=1))>
  209. ply_either(A && a)
  210. {
  211. ply_ravel(std::forward<A>(a));
  212. }
  213. template <class A>
  214. std::enable_if_t<!has_tensorindex<std::decay_t<A>>::value && (A::size_s()!=DIM_ANY && (A::rank_s()==0 || A::rank_s()==1))>
  215. ply_either(A && a)
  216. {
  217. plyf(std::forward<A>(a));
  218. }
  219. // ---------------------------
  220. // Short-circuiting pliers. @TODO These are reductions. How to do higher rank?
  221. // ---------------------------
  222. // @BUG Slow. Options for ply should be the same as for non-short circuit.
  223. template <class Op, class A, std::enable_if_t<is_array_iterator<A>::value && (A::rank_s()!=1 || has_tensorindex<A>::value), int> = 0>
  224. bool ply_index_short_circuit(A && a)
  225. {
  226. /* @TODO try and merge the singular cases. */
  227. if (a.done()) {
  228. return Op()(false);
  229. }
  230. rank_t const rank = a.rank();
  231. auto s(a.shape());
  232. using Shape = decltype(s);
  233. Shape i(ra_traits<Shape>::make(rank, 0));
  234. rank_t order[rank];
  235. for (rank_t i=0; i<rank; ++i) {
  236. order[i] = rank-1-i;
  237. }
  238. for (;;) {
  239. if (Op()(a.at(i))) {
  240. return Op()(true);
  241. }
  242. for (int k=0; ; ++k) {
  243. if (k==rank) {
  244. return Op()(false);
  245. } else if (++i[order[k]]<s[order[k]]) {
  246. break;
  247. } else {
  248. i[order[k]] = 0;
  249. }
  250. }
  251. }
  252. }
  253. template <class Op, class A, std::enable_if_t<is_array_iterator<A>::value && (A::rank_s()==1 && !has_tensorindex<A>::value), int> = 0>
  254. bool ply_index_short_circuit(A && a)
  255. {
  256. auto s = a.size(0)*a.preferred_stride(0);
  257. auto ss0 = a.stride(0);
  258. auto p = a.flat();
  259. for (auto end=p+s; p!=end; p+=ss0) {
  260. if (Op()(*p)) {
  261. return Op()(true);
  262. }
  263. }
  264. return Op()(false);
  265. }
// Not an array iterator yet: wrap the argument with start() and retry the
// short-circuit traversal on the resulting iterator.
template <class Op, class A, enableif_<mp::Not<is_array_iterator<A>>, int> = 0>
bool ply_index_short_circuit(A && a)
{
    return ply_index_short_circuit<Op>(start(std::forward<A>(a)));
}
  271. template <class A> bool any(A && a) { return ply_index_short_circuit<mp::identity>(std::forward<A>(a)); }
  272. template <class A> bool every(A && a) { return ply_index_short_circuit<std::logical_not<bool> >(std::forward<A>(a)); }
  273. } // namespace ra
  274. #endif // RA_PLY_H