ply.hh 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Traverse (ply) array or array expression or array statement.
  3. // (c) Daniel Llorens - 2013-2019, 2021
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. // TODO Lots of room for improvement: small (fixed sizes) and large (tiling, etc. see eval.cc in Blitz++).
  9. // TODO Traversal order should be a parameter, since some operations (e.g. output, ravel) require a specific order.
  10. // TODO Better heuristic for traversal order.
  11. // TODO std::execution::xxx-policy, validate output argument strides.
  12. #pragma once
  13. #include "atom.hh"
  14. #include <functional>
  15. namespace ra {
  16. // --------------
  17. // Run time order
  18. // --------------
  19. // Traverse array expression looking to ravel the inner loop.
  20. // step() must give 0 for k>=their own rank, to allow frame matching.
  21. template <IteratorConcept A>
  22. inline void
  23. ply_ravel(A && a)
  24. {
  25. rank_t rank = a.rank();
  26. // FIXME without assert compiler thinks var rank may be negative. See test in [ra40].
  27. #ifdef NDEBUG
  28. #pragma GCC diagnostic push
  29. #pragma GCC diagnostic ignored "-Wvla-larger-than="
  30. rank_t order[rank];
  31. dim_t sha[rank], ind[rank];
  32. #pragma GCC diagnostic pop
  33. #else
  34. assert(rank>=0);
  35. rank_t order[rank];
  36. dim_t sha[rank], ind[rank];
  37. #endif
  38. for (rank_t i=0; i<rank; ++i) {
  39. order[i] = rank-1-i;
  40. }
  41. switch (rank) {
  42. case 0: *(a.flat()); return;
  43. case 1: break;
  44. default: // TODO better heuristic
  45. // if (rank>1) {
  46. // std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
  47. // { return a.len(order[i])<a.len(order[j]); });
  48. // }
  49. ;
  50. }
  51. // outermost compact dim.
  52. rank_t * ocd = order;
  53. // FIXME see same thing below.
  54. #pragma GCC diagnostic push
  55. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  56. auto ss = a.len(*ocd);
  57. #pragma GCC diagnostic pop
  58. for (--rank, ++ocd; rank>0 && a.keep_step(ss, order[0], *ocd); --rank, ++ocd) {
  59. ss *= a.len(*ocd);
  60. }
  61. for (int k=0; k<rank; ++k) {
  62. ind[k] = 0;
  63. sha[k] = a.len(ocd[k]);
  64. if (sha[k]==0) { // for the raveled dimensions ss takes care.
  65. return;
  66. }
  67. RA_CHECK(sha[k]!=DIM_BAD, "undefined dim ", ocd[k]);
  68. }
  69. // all sub xpr steps advance in compact dims, as they might be different.
  70. auto const ss0 = a.step(order[0]);
  71. for (;;) {
  72. dim_t s = ss;
  73. for (auto p=a.flat(); s>0; --s, p+=ss0) {
  74. *p;
  75. }
  76. for (int k=0; ; ++k) {
  77. if (k>=rank) {
  78. return;
  79. } else if (ind[k]<sha[k]-1) {
  80. ++ind[k];
  81. a.adv(ocd[k], 1);
  82. break;
  83. } else {
  84. ind[k] = 0;
  85. a.adv(ocd[k], 1-sha[k]);
  86. }
  87. }
  88. }
  89. }
  90. // -------------------------
  91. // Compile time order.
  92. // -------------------------
  93. template <class order, int ravel_rank, class A, class S>
  94. constexpr void
  95. subindex(A & a, dim_t s, S const & ss0)
  96. {
  97. if constexpr (mp::len<order> == ravel_rank) {
  98. #pragma GCC diagnostic push
  99. #pragma GCC diagnostic warning "-Wstringop-overflow"
  100. #pragma GCC diagnostic warning "-Wstringop-overread"
  101. for (auto p=a.flat(); s>0; --s, p+=ss0) {
  102. *p;
  103. }
  104. #pragma GCC diagnostic pop
  105. } else {
  106. dim_t size = a.len(mp::first<order>::value); // TODO Precompute these at the top
  107. for (dim_t i=0, iend=size; i<iend; ++i) {
  108. subindex<mp::drop1<order>, ravel_rank>(a, s, ss0);
  109. a.adv(mp::first<order>::value, 1);
  110. }
  111. a.adv(mp::first<order>::value, -size);
  112. }
  113. }
  114. // convert runtime jj into compile time j. TODO a.adv<k>().
  115. template <class order, int j, class A, class S>
  116. constexpr void
  117. until(int const jj, A & a, dim_t const s, S const & ss0)
  118. {
  119. if constexpr (mp::len<order> >= j) {
  120. if (jj==j) {
  121. subindex<order, j>(a, s, ss0);
  122. } else {
  123. until<order, j+1>(jj, a, s, ss0);
  124. }
  125. } else {
  126. std::abort();
  127. }
  128. }
  129. // find outermost compact dim.
  130. template <class A>
  131. constexpr auto
  132. ocd()
  133. {
  134. rank_t const rank = A::rank_s();
  135. auto s = A::len_s(rank-1);
  136. int j = 1;
  137. while (j<rank && A::keep_step(s, rank-1, rank-1-j)) {
  138. s *= A::len_s(rank-1-j);
  139. ++j;
  140. }
  141. return std::make_tuple(s, j);
  142. };
  143. template <IteratorConcept A>
  144. constexpr void
  145. plyf(A && a)
  146. {
  147. constexpr rank_t rank = rank_s<A>();
  148. static_assert(rank>=0, "plyf needs static rank");
  149. if constexpr (rank_s<A>()==0) {
  150. *(a.flat());
  151. } else if constexpr (rank_s<A>()==1) {
  152. subindex<mp::iota<1>, 1>(a, a.len(0), a.step(0));
  153. // this can only be enabled when f() will be constexpr; static keep_step implies all else is also static.
  154. // important rank>1 for with static size operands [ra43].
  155. } else if constexpr (rank_s<A>()>1 && requires (dim_t d, rank_t i, rank_t j) { A::keep_step(d, i, j); }) {
  156. constexpr auto sj = ocd<std::decay_t<A>>();
  157. constexpr auto s = std::get<0>(sj);
  158. constexpr auto j = std::get<1>(sj);
  159. // all sub xpr steps advance in compact dims, as they might be different.
  160. // send with static j. Note that order here is inverse of order.
  161. until<mp::iota<rank_s<A>()>, 0>(j, a, s, a.step(rank-1));
  162. } else {
  163. // the unrolling above isn't worth it when s, j cannot be constexpr.
  164. auto s = a.len(rank-1);
  165. subindex<mp::iota<rank_s<A>()>, 1>(a, s, a.step(rank-1));
  166. }
  167. }
  168. // ---------------------------
  169. // Select best performance (or requirements) for each type.
  170. // ---------------------------
  171. template <IteratorConcept A>
  172. constexpr void
  173. ply(A && a)
  174. {
  175. if constexpr (size_s<A>()==DIM_ANY) {
  176. ply_ravel(std::forward<A>(a));
  177. } else {
  178. plyf(std::forward<A>(a));
  179. }
  180. }
  181. // ---------------------------
  182. // Short-circuiting pliers.
  183. // ---------------------------
  184. // TODO Refactor with ply_ravel. Make exit available to plyf.
  185. // TODO These are reductions. How to do higher rank?
  186. template <IteratorConcept A, class DEF>
  187. inline auto
  188. ply_ravel_exit(A && a, DEF && def)
  189. {
  190. rank_t rank = a.rank();
  191. // FIXME without assert compiler thinks var rank may be negative. See test in [ra40].
  192. #ifdef NDEBUG
  193. #pragma GCC diagnostic push
  194. #pragma GCC diagnostic ignored "-Wvla-larger-than="
  195. rank_t order[rank];
  196. dim_t sha[rank], ind[rank];
  197. #pragma GCC diagnostic pop
  198. #else
  199. assert(rank>=0);
  200. rank_t order[rank];
  201. dim_t sha[rank], ind[rank];
  202. #endif
  203. for (rank_t i=0; i<rank; ++i) {
  204. order[i] = rank-1-i;
  205. }
  206. switch (rank) {
  207. case 0: {
  208. if (auto what = *(a.flat()); std::get<0>(what)) {
  209. return std::get<1>(what);
  210. }
  211. return def;
  212. }
  213. case 1: break;
  214. default: // TODO better heuristic
  215. // if (rank>1) {
  216. // std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
  217. // { return a.len(order[i])<a.len(order[j]); });
  218. // }
  219. ;
  220. }
  221. // outermost compact dim.
  222. rank_t * ocd = order;
  223. // FIXME on github actions ubuntu-latest g++-11 -O3 :-|
  224. #pragma GCC diagnostic push
  225. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  226. auto ss = a.len(*ocd);
  227. #pragma GCC diagnostic pop
  228. for (--rank, ++ocd; rank>0 && a.keep_step(ss, order[0], *ocd); --rank, ++ocd) {
  229. ss *= a.len(*ocd);
  230. }
  231. for (int k=0; k<rank; ++k) {
  232. ind[k] = 0;
  233. sha[k] = a.len(ocd[k]);
  234. if (sha[k]==0) { // for the raveled dimensions ss takes care.
  235. return def;
  236. }
  237. RA_CHECK(sha[k]!=DIM_BAD, "undefined dim ", ocd[k]);
  238. }
  239. // all sub xpr steps advance in compact dims, as they might be different.
  240. auto const ss0 = a.step(order[0]);
  241. for (;;) {
  242. dim_t s = ss;
  243. for (auto p=a.flat(); s>0; --s, p+=ss0) {
  244. if (auto what = *p; std::get<0>(what)) {
  245. return std::get<1>(what);
  246. }
  247. }
  248. for (int k=0; ; ++k) {
  249. if (k>=rank) {
  250. return def;
  251. } else if (ind[k]<sha[k]-1) {
  252. ++ind[k];
  253. a.adv(ocd[k], 1);
  254. break;
  255. } else {
  256. ind[k] = 0;
  257. a.adv(ocd[k], 1-sha[k]);
  258. }
  259. }
  260. }
  261. }
  262. template <IteratorConcept A, class DEF>
  263. constexpr decltype(auto)
  264. early(A && a, DEF && def)
  265. {
  266. return ply_ravel_exit(std::forward<A>(a), std::forward<DEF>(def));
  267. }
  // Build the elementwise expression map(op, a...) and traverse it with ply for its side
  // effects; the element values are discarded.
  template <class Op, class ... A>
  constexpr void
  for_each(Op && op, A && ... a)
  {
      // auto&& + forward<decltype> keeps the exact value category map() produced.
      auto && xpr = map(std::forward<Op>(op), std::forward<A>(a) ...);
      ply(std::forward<decltype(xpr)>(xpr));
  }
  274. } // namespace ra