ply.hh 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Traverse (ply) array or array expression or array statement.
  3. // (c) Daniel Llorens - 2013-2019, 2021
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. // TODO Lots of room for improvement: small (fixed sizes) and large (tiling, etc. see eval.cc in Blitz++).
  9. // TODO Traversal order should be a parameter, since some operations (e.g. output, ravel) require a specific order.
  10. // TODO Better heuristic for traversal order.
  11. // TODO std::execution::xxx-policy, validate output argument strides.
  12. #pragma once
  13. #include "pick.hh"
  14. #include "expr.hh"
  15. namespace ra {
  16. // ---------------------
  17. // does expr tree contain Len?
  18. // ---------------------
  19. template <>
  20. constexpr bool has_len_def<Len> = true;
  21. template <IteratorConcept ... P>
  22. constexpr bool has_len_def<Pick<std::tuple<P ...>>> = (has_len<P> || ...);
  23. template <class Op, IteratorConcept ... P>
  24. constexpr bool has_len_def<Expr<Op, std::tuple<P ...>>> = (has_len<P> || ...);
  25. template <int w, class O, class N, class S>
  26. constexpr bool has_len_def<Iota<w, O, N, S>> = (has_len<O> || has_len<N> || has_len<S>);
  27. // ---------------------
  28. // replace Len in expr tree.
  29. // ---------------------
  30. template <class E_>
  31. struct WithLen
  32. {
  33. // constant & scalar appear in Iota args. dots_t and insert_t appear in subscripts.
  34. // FIXME what else? restrict to IteratorConcept<E_> || is_constant<E_> || is_scalar<E_> ...
  35. template <class E> constexpr static decltype(auto)
  36. f(dim_t len, E && e)
  37. {
  38. return std::forward<E>(e);
  39. }
  40. };
  41. template <>
  42. struct WithLen<Len>
  43. {
  44. template <class E> constexpr static decltype(auto)
  45. f(dim_t len, E && e)
  46. {
  47. return Scalar<dim_t>(len);
  48. }
  49. };
  50. template <class Op, IteratorConcept ... P, int ... I>
  51. requires (has_len<P> || ...)
  52. struct WithLen<Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>>
  53. {
  54. template <class E> constexpr static decltype(auto)
  55. f(dim_t len, E && e)
  56. {
  57. return expr(std::forward<E>(e).op, WithLen<std::decay_t<P>>::f(len, std::get<I>(std::forward<E>(e).t)) ...);
  58. }
  59. };
  60. template <IteratorConcept ... P, int ... I>
  61. requires (has_len<P> || ...)
  62. struct WithLen<Pick<std::tuple<P ...>, mp::int_list<I ...>>>
  63. {
  64. template <class E> constexpr static decltype(auto)
  65. f(dim_t len, E && e)
  66. {
  67. return pick(WithLen<std::decay_t<P>>::f(len, std::get<I>(std::forward<E>(e).t)) ...);
  68. }
  69. };
  70. template <int w, class O, class N, class S>
  71. requires (has_len<O> || has_len<N> || has_len<S>)
  72. struct WithLen<Iota<w, O, N, S>>
  73. {
  74. // usable iota types must be either is_constant or is_scalar.
  75. template <class T> constexpr static decltype(auto)
  76. coerce(T && t)
  77. {
  78. if constexpr (IteratorConcept<T>) {
  79. return FLAT(t);
  80. } else {
  81. return std::forward<T>(t);
  82. }
  83. }
  84. template <class E> constexpr static decltype(auto)
  85. f(dim_t len, E && e)
  86. {
  87. return iota<w>(coerce(WithLen<std::decay_t<N>>::f(len, std::forward<E>(e).n)),
  88. coerce(WithLen<std::decay_t<O>>::f(len, std::forward<E>(e).i)),
  89. coerce(WithLen<std::decay_t<S>>::f(len, std::forward<E>(e).s)));
  90. }
  91. };
  92. template <class E>
  93. constexpr decltype(auto)
  94. with_len(dim_t len, E && e)
  95. {
  96. return WithLen<std::decay_t<E>>::f(len, std::forward<E>(e));
  97. }
  // --------------
  // ply, run time order
  // --------------
  // Traverse array expression looking to ravel the inner loop.
  // step() must give 0 for k>=their own rank, to allow frame matching.
  // Every element is visited by dereferencing an iterator obtained from a.flat(); the dereferenced
  // value is discarded, so any effect comes from the expression itself. The last dim (rank-1) is
  // innermost. Inner dims that a.keep_step() allows to merge are raveled into a single loop of
  // length ss; the remaining outer dims are iterated odometer-style with a.adv().
  template <IteratorConcept A>
  inline void
  ply_ravel(A && a)
  {
      rank_t rank = a.rank();
  // FIXME without assert compiler thinks var rank may be negative. See test in [ra40].
  // order, sha and ind are variable-length arrays (a compiler extension in C++).
  #ifdef NDEBUG
  #pragma GCC diagnostic push
  #pragma GCC diagnostic ignored "-Wvla-larger-than="
      rank_t order[rank];
      dim_t sha[rank], ind[rank];
  #pragma GCC diagnostic pop
  #else
      assert(rank>=0);
      rank_t order[rank];
      dim_t sha[rank], ind[rank];
  #endif
      // order[0] is the innermost dim (rank-1), order[rank-1] the outermost (0).
      for (rank_t i=0; i<rank; ++i) {
          order[i] = rank-1-i;
      }
      switch (rank) {
      case 0: *(a.flat()); return; // rank 0: a single element.
      case 1: break;
      default: // TODO better heuristic
          // if (rank>1) {
          //     std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
          //               { return a.len(order[i])<a.len(order[j]); });
          // }
          ;
      }
      // outermost compact dim.
      rank_t * ocd = order;
  // FIXME see same thing below.
  #pragma GCC diagnostic push
  #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
      auto ss = a.len(*ocd);
  #pragma GCC diagnostic pop
      // Merge following dims into the inner loop while keep_step allows it. Afterwards ss is the
      // raveled inner length, rank counts the remaining outer dims, and ocd points at the first of them.
      for (--rank, ++ocd; rank>0 && a.keep_step(ss, order[0], *ocd); --rank, ++ocd) {
          ss *= a.len(*ocd);
      }
      for (int k=0; k<rank; ++k) {
          ind[k] = 0;
          sha[k] = a.len(ocd[k]);
          if (sha[k]==0) { // for the raveled dimensions ss takes care.
              return;
          }
          RA_CHECK(sha[k]!=DIM_BAD, "undefined dim ", ocd[k]);
      }
      // all sub xpr steps advance in compact dims, as they might be different.
      auto const ss0 = a.step(order[0]);
      for (;;) {
          // raveled inner loop: ss elements with stride ss0.
          dim_t s = ss;
          for (auto p=a.flat(); s>0; --s, p+=ss0) {
              *p;
          }
          // odometer step over the outer dims. A dim that reaches its end is moved back by
          // sha[k]-1 and the carry goes to dim k+1. Once every dim has wrapped, the traversal
          // is done, and each outer dim has a net advance of 0.
          for (int k=0; ; ++k) {
              if (k>=rank) {
                  return;
              } else if (ind[k]<sha[k]-1) {
                  ++ind[k];
                  a.adv(ocd[k], 1);
                  break;
              } else {
                  ind[k] = 0;
                  a.adv(ocd[k], 1-sha[k]);
              }
          }
      }
  }
  172. // -------------------------
  173. // ply, compile time order
  174. // -------------------------
  175. template <class order, int ravel_rank, class A, class S>
  176. constexpr void
  177. subindex(A & a, dim_t s, S const & ss0)
  178. {
  179. if constexpr (mp::len<order> == ravel_rank) {
  180. #pragma GCC diagnostic push
  181. #pragma GCC diagnostic warning "-Wstringop-overflow"
  182. #pragma GCC diagnostic warning "-Wstringop-overread"
  183. for (auto p=a.flat(); s>0; --s, p+=ss0) {
  184. *p;
  185. }
  186. #pragma GCC diagnostic pop
  187. } else {
  188. dim_t size = a.len(mp::first<order>::value); // TODO Precompute these at the top
  189. for (dim_t i=0, iend=size; i<iend; ++i) {
  190. subindex<mp::drop1<order>, ravel_rank>(a, s, ss0);
  191. a.adv(mp::first<order>::value, 1);
  192. }
  193. a.adv(mp::first<order>::value, -size);
  194. }
  195. }
  196. // convert runtime jj into compile time j. TODO a.adv<k>().
  197. template <class order, int j, class A, class S>
  198. constexpr void
  199. until(int const jj, A & a, dim_t const s, S const & ss0)
  200. {
  201. if constexpr (mp::len<order> >= j) {
  202. if (jj==j) {
  203. subindex<order, j>(a, s, ss0);
  204. } else {
  205. until<order, j+1>(jj, a, s, ss0);
  206. }
  207. } else {
  208. std::abort();
  209. }
  210. }
  211. // find outermost compact dim.
  212. template <class A>
  213. constexpr auto
  214. ocd()
  215. {
  216. rank_t const rank = A::rank_s();
  217. auto s = A::len_s(rank-1);
  218. int j = 1;
  219. while (j<rank && A::keep_step(s, rank-1, rank-1-j)) {
  220. s *= A::len_s(rank-1-j);
  221. ++j;
  222. }
  223. return std::make_tuple(s, j);
  224. };
  225. template <IteratorConcept A>
  226. constexpr void
  227. plyf(A && a)
  228. {
  229. constexpr rank_t rank = rank_s<A>();
  230. static_assert(rank>=0, "plyf needs static rank");
  231. if constexpr (rank_s<A>()==0) {
  232. *(a.flat());
  233. } else if constexpr (rank_s<A>()==1) {
  234. subindex<mp::iota<1>, 1>(a, a.len(0), a.step(0));
  235. // this can only be enabled when f() will be constexpr; static keep_step implies all else is also static.
  236. // important rank>1 for with static size operands [ra43].
  237. } else if constexpr (rank_s<A>()>1 && requires (dim_t d, rank_t i, rank_t j) { A::keep_step(d, i, j); }) {
  238. constexpr auto sj = ocd<std::decay_t<A>>();
  239. constexpr auto s = std::get<0>(sj);
  240. constexpr auto j = std::get<1>(sj);
  241. // all sub xpr steps advance in compact dims, as they might be different.
  242. // send with static j. Note that order here is inverse of order.
  243. until<mp::iota<rank_s<A>()>, 0>(j, a, s, a.step(rank-1));
  244. } else {
  245. // the unrolling above isn't worth it when s, j cannot be constexpr.
  246. auto s = a.len(rank-1);
  247. subindex<mp::iota<rank_s<A>()>, 1>(a, s, a.step(rank-1));
  248. }
  249. }
  250. // ---------------------------
  251. // select best performance (or requirements) for each type
  252. // ---------------------------
  253. template <IteratorConcept A>
  254. constexpr void
  255. ply(A && a)
  256. {
  257. static_assert(!has_len<A>, "len used outside subscript context.");
  258. if constexpr (size_s<A>()==DIM_ANY) {
  259. ply_ravel(std::forward<A>(a));
  260. } else {
  261. plyf(std::forward<A>(a));
  262. }
  263. }
  // ---------------------------
  // ply, short-circuiting
  // ---------------------------
  // TODO Refactor with ply_ravel. Make exit available to plyf.
  // TODO These are reductions. How to do higher rank?
  // Same traversal as ply_ravel, but each dereferenced element `what` is tuple-like. If
  // std::get<0>(what) is true, the traversal stops and returns std::get<1>(what). If no element
  // qualifies, or some outer dim has length 0, it returns def. All return statements deduce the
  // same `auto` type, so the decayed std::get<1>(what) type must match the decayed DEF.
  template <IteratorConcept A, class DEF>
  inline auto
  ply_ravel_exit(A && a, DEF && def)
  {
      static_assert(!has_len<A>, "len used outside subscript context.");
      rank_t rank = a.rank();
  // FIXME without assert compiler thinks var rank may be negative. See test in [ra40].
  // order, sha and ind are variable-length arrays (a compiler extension in C++).
  #ifdef NDEBUG
  #pragma GCC diagnostic push
  #pragma GCC diagnostic ignored "-Wvla-larger-than="
      rank_t order[rank];
      dim_t sha[rank], ind[rank];
  #pragma GCC diagnostic pop
  #else
      assert(rank>=0);
      rank_t order[rank];
      dim_t sha[rank], ind[rank];
  #endif
      // order[0] is the innermost dim (rank-1), order[rank-1] the outermost (0).
      for (rank_t i=0; i<rank; ++i) {
          order[i] = rank-1-i;
      }
      switch (rank) {
      case 0: {
          // rank 0: a single element decides the result.
          if (auto what = *(a.flat()); std::get<0>(what)) {
              return std::get<1>(what);
          }
          return def;
      }
      case 1: break;
      default: // TODO better heuristic
          // if (rank>1) {
          //     std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
          //               { return a.len(order[i])<a.len(order[j]); });
          // }
          ;
      }
      // outermost compact dim.
      rank_t * ocd = order;
  // FIXME on github actions ubuntu-latest g++-11 -O3 :-|
  #pragma GCC diagnostic push
  #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
      auto ss = a.len(*ocd);
  #pragma GCC diagnostic pop
      // Merge following dims into the inner loop while keep_step allows it. Afterwards ss is the
      // raveled inner length, rank counts the remaining outer dims, and ocd points at the first of them.
      for (--rank, ++ocd; rank>0 && a.keep_step(ss, order[0], *ocd); --rank, ++ocd) {
          ss *= a.len(*ocd);
      }
      for (int k=0; k<rank; ++k) {
          ind[k] = 0;
          sha[k] = a.len(ocd[k]);
          if (sha[k]==0) { // for the raveled dimensions ss takes care.
              return def;
          }
          RA_CHECK(sha[k]!=DIM_BAD, "undefined dim ", ocd[k]);
      }
      // all sub xpr steps advance in compact dims, as they might be different.
      auto const ss0 = a.step(order[0]);
      for (;;) {
          // raveled inner loop: ss elements with stride ss0; stop at the first flagged element.
          dim_t s = ss;
          for (auto p=a.flat(); s>0; --s, p+=ss0) {
              if (auto what = *p; std::get<0>(what)) {
                  return std::get<1>(what);
              }
          }
          // odometer step over the outer dims. A dim that reaches its end is moved back by
          // sha[k]-1 and the carry goes to dim k+1. When every dim has wrapped, nothing
          // was found.
          for (int k=0; ; ++k) {
              if (k>=rank) {
                  return def;
              } else if (ind[k]<sha[k]-1) {
                  ++ind[k];
                  a.adv(ocd[k], 1);
                  break;
              } else {
                  ind[k] = 0;
                  a.adv(ocd[k], 1-sha[k]);
              }
          }
      }
  }
  346. template <IteratorConcept A, class DEF>
  347. constexpr decltype(auto)
  348. early(A && a, DEF && def)
  349. {
  350. return ply_ravel_exit(std::forward<A>(a), std::forward<DEF>(def));
  351. }
  // Apply op elementwise over the arguments a ..., for effect only: build map(op, a ...) and ply it.
  template <class Op, class ... A>
  constexpr void
  for_each(Op && op, A && ... a)
  {
      auto && mapped = map(std::forward<Op>(op), std::forward<A>(a) ...);
      ply(std::forward<decltype(mapped)>(mapped));
  }
  358. } // namespace ra