expr.hh 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Operation nodes for expression templates.
  3. // (c) Daniel Llorens - 2011-2022
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include "match.hh"
  10. #include <functional>
  11. namespace ra {
  12. // ---------------------------
  13. // reframe
  14. // ---------------------------
  15. // Reframe is a variant of transpose that works on any IteratorConcept. As in transpose(), one names
  16. // the destination axis for each original axis. However, unlike general transpose, axes may not be
  17. // repeated. The main application is the rank conjunction below.
  // Zero-valued step for a step type S. A plain step type is initialized from 0; a std::tuple
  // of step types gets a tuple holding the zero step of each of its element types.
  template <class S>
  constexpr S zerostep = 0;

  template <class ... S>
  constexpr std::tuple<S ...> zerostep<std::tuple<S ...>> = { zerostep<S> ... };
  20. // Dest is a list of destination axes [l0 l1 ... li ... l(rank(A)-1)].
  21. // The dimensions of the reframed A are numbered as [0 ... k ... max(l)-1].
  22. // If li = k for some i, then axis k of the reframed A moves on axis i of the original iterator A.
  23. // If not, then axis k of the reframed A is 'dead' and doesn't move the iterator.
  24. // TODO invalid for RANK_ANY (since Dest is compile time). [ra7]
  25. template <class Dest, IteratorConcept A>
  26. struct Reframe
  27. {
  28. A a;
  29. constexpr static int orig(int k) { return mp::int_list_index<Dest>(k); }
  30. constexpr static rank_t rank_s() { return 1+mp::fold<mp::max, int_c<-1>, Dest>::value; }
  31. constexpr static rank_t rank() { return rank_s(); }
  32. constexpr static dim_t len_s(int k)
  33. {
  34. int l = orig(k);
  35. return l>=0 ? std::decay_t<A>::len_s(l) : DIM_BAD;
  36. }
  37. constexpr dim_t
  38. len(int k) const
  39. {
  40. int l = orig(k);
  41. return l>=0 ? a.len(l) : DIM_BAD;
  42. }
  43. constexpr void
  44. adv(rank_t k, dim_t d)
  45. {
  46. if (int l = orig(k); l>=0) {
  47. a.adv(l, d);
  48. }
  49. }
  50. constexpr auto
  51. step(int k) const
  52. {
  53. int l = orig(k);
  54. return l>=0 ? a.step(l) : zerostep<decltype(a.step(l))>;
  55. }
  56. constexpr bool
  57. keep_step(dim_t st, int z, int j) const
  58. {
  59. int wz = orig(z);
  60. int wj = orig(j);
  61. return wz>=0 && wj>=0 && a.keep_step(st, wz, wj);
  62. }
  63. constexpr decltype(auto)
  64. flat()
  65. {
  66. return a.flat();
  67. }
  68. constexpr decltype(auto)
  69. at(auto const & i) const
  70. {
  71. return a.at(mp::map_indices<std::array<dim_t, mp::len<Dest>>, Dest>(i));
  72. }
  73. };
  74. // Optimize no-op case.
  75. // TODO If A is CellBig, etc. beat Dest directly on it, same for eventual transpose_expr<>.
  76. template <class Dest, class A>
  77. constexpr decltype(auto)
  78. reframe(A && a)
  79. {
  80. if constexpr (std::is_same_v<Dest, mp::iota<1+mp::fold<mp::max, int_c<-1>, Dest>::value>>) {
  81. return std::forward<A>(a);
  82. } else {
  83. return Reframe<Dest, A> { std::forward<A>(a) };
  84. }
  85. }
  86. // ---------------------------
  87. // verbs and rank conjunction
  88. // ---------------------------
  // An operation tagged with a compile-time list of ranks. cranks is the list type (wrank()
  // builds it as mp::int_list<crank ...>), Op is the type of the wrapped operation and op is the
  // operation itself. Op may itself be a Verb, which is how rank conjunctions nest.
  template <class Cranks, class Fn>
  struct Verb
  {
      using Op = Fn;
      using cranks = Cranks;
      Op op;
  };
  // is_verb: predicate on a type A whose condition is that A is exactly Verb<A::cranks, A::Op>.
  // expr() and agree_op_() below branch on it as is_verb<Op>. The definition is generated by the
  // RA_IS_DEF macro, which is declared outside this file.
  // NOTE(review): whether A is decayed before the test depends on RA_IS_DEF - confirm there.
  96. RA_IS_DEF(is_verb, (std::is_same_v<A, Verb<typename A::cranks, typename A::Op>>))
  97. template <class cranks, class Op>
  98. constexpr auto
  99. wrank(cranks cranks_, Op && op)
  100. {
  101. return Verb<cranks, Op> { std::forward<Op>(op) };
  102. }
  103. template <rank_t ... crank, class Op>
  104. constexpr auto
  105. wrank(Op && op)
  106. {
  107. return Verb<mp::int_list<crank ...>, Op> { std::forward<Op>(op) };
  108. }
  109. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  110. struct Framematch_def;
  111. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  112. using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
  // Binary selector over types with a static ::value: value is 0 when choose_rank() of the pair
  // gives back A::value, and 1 otherwise. Framematch_def passes it to mp::indexof.
  template <class A, class B>
  struct max_i
  {
      constexpr static int value = (choose_rank(A::value, B::value) != A::value) ? 1 : 0;
  };
  118. // Get a list (per argument) of lists of live axes. The last frame match is handled by standard prefix matching.
  119. template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
  120. struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  121. {
  122. static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
  123. // live = number of live axes on this frame, for each argument. // TODO crank negative, inf.
  124. using live = mp::int_list<(rank_s<Ti>() - mp::len<Ri> - crank::value) ...>;
  125. using frameaxes = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri> - crank::value), skip>> ...>;
  126. using FM = Framematch<W, std::tuple<Ti ...>, frameaxes, skip + mp::ref<live, mp::indexof<max_i, live>>::value>;
  127. using R = typename FM::R;
  128. template <class VV> static decltype(auto) op(VV && v) { return FM::op(std::forward<VV>(v).op); } // cf [ra31]
  129. };
  130. // Terminal case where V doesn't have rank (is a raw op()).
  131. template <class V, class ... Ti, class ... Ri, rank_t skip>
  132. struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  133. {
  134. static_assert(sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
  135. // TODO -crank::value when the actual verb rank is used (eg to use CellBig<A, that_rank> instead of just begin()).
  136. using R = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri>), skip>> ...>;
  137. template <class VV> static decltype(auto) op(VV && v) { return std::forward<VV>(v); }
  138. };
  139. // ---------------------------
  140. // general expression
  141. // ---------------------------
  142. template <class Op, class T, class K=mp::iota<mp::len<T>>> struct Expr;
  143. template <class Op, IteratorConcept ... P, int ... I>
  144. struct Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tuple<P ...>>
  145. {
  146. template <class T>
  147. struct Flat
  148. {
  149. Op & op;
  150. T t;
  151. template <class S> constexpr void operator+=(S const & s) { ((std::get<I>(t) += std::get<I>(s)), ...); }
  152. // FIXME gcc 12.1 flags this (-O3 only).
  153. #pragma GCC diagnostic push
  154. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  155. constexpr decltype(auto) operator*() { return std::invoke(op, *std::get<I>(t) ...); }
  156. #pragma GCC diagnostic pop
  157. };
  158. template <class ... F>
  159. constexpr static auto
  160. flat(Op & op, F && ... f)
  161. {
  162. return Flat<std::tuple<F ...>> { op, { std::forward<F>(f) ... } };
  163. }
  164. using Match_ = Match<true, std::tuple<P ...>>;
  165. using Match_::t, Match_::rank_s, Match_::rank;
  166. Op op;
  167. // test/ra-9.cc [ra1]
  168. constexpr Expr(Op op_, P ... p_): Match_(std::forward<P>(p_) ...), op(std::forward<Op>(op_)) {}
  169. RA_DEF_ASSIGNOPS_SELF(Expr)
  170. RA_DEF_ASSIGNOPS_DEFAULT_SET
  171. constexpr decltype(auto)
  172. at(auto const & j) const
  173. {
  174. return std::invoke(op, std::get<I>(t).at(j) ...);
  175. }
  176. constexpr decltype(auto)
  177. flat() // FIXME can't be const bc of Flat::op. Carries over to Pick / Reframe .flat() ...
  178. {
  179. return flat(op, std::get<I>(t).flat() ...);
  180. }
  181. // needed for rank_s()==RANK_ANY, which don't decay to scalar when used as operator arguments.
  182. operator decltype(*(flat(op, std::get<I>(t).flat() ...))) ()
  183. {
  184. // for coord types; so ct only
  185. if constexpr ((rank_s()!=1 || size_s<Expr>()!=1) && rank_s()!=0) {
  186. static_assert(rank_s()==RANK_ANY);
  187. assert(rank()==0);
  188. }
  189. return *flat();
  190. }
  191. };
  192. template <class Op, IteratorConcept ... P>
  193. constexpr bool is_special_def<Expr<Op, std::tuple<P ...>>> = (is_special<P> || ...);
  194. template <class V, class ... T, int ... i>
  195. constexpr auto
  196. expr_verb(mp::int_list<i ...>, V && v, T && ... t)
  197. {
  198. using FM = Framematch<V, std::tuple<T ...>>;
  199. return expr(FM::op(std::forward<V>(v)), reframe<mp::ref<typename FM::R, i>>(std::forward<T>(t)) ...);
  200. }
  201. template <class Op, class ... P>
  202. constexpr auto
  203. expr(Op && op, P && ... p)
  204. {
  205. if constexpr (is_verb<Op>) {
  206. return expr_verb(mp::iota<sizeof...(P)> {}, std::forward<Op>(op), std::forward<P>(p) ...);
  207. } else {
  208. return Expr<Op, std::tuple<P ...>> { std::forward<Op>(op), std::forward<P>(p) ... };
  209. }
  210. }
  // Apply op over the arguments a ...: each argument is passed through start() and the results
  // are combined with op by expr().
  template <class Op, class ... A>
  constexpr auto map(Op && op, A && ... a)
  {
      return expr(std::forward<Op>(op),
                  start(std::forward<A>(a)) ...);
  }
  217. // ---------------
  218. // explicit agreement checks. FIXME provide separate agree_s().
  219. // ---------------
  220. template <class ... P>
  221. constexpr bool
  222. agree(P && ... p)
  223. {
  224. return agree_(ra::start(std::forward<P>(p)) ...);
  225. }
  226. template <class Op, class ... P>
  227. constexpr bool
  228. agree_op(Op && op, P && ... p)
  229. {
  230. return agree_op_(std::forward<Op>(op), ra::start(std::forward<P>(p)) ...);
  231. }
  232. template <class ... P>
  233. constexpr bool
  234. agree_(P && ... p)
  235. {
  236. return check_expr<false>(Match<false, std::tuple<P ...>> { std::forward<P>(p) ... });
  237. }
  238. template <class Op, class ... P>
  239. constexpr bool
  240. agree_op_(Op && op, P && ... p)
  241. {
  242. if constexpr (is_verb<Op>) {
  243. return agree_verb(mp::iota<sizeof...(P)> {}, std::forward<Op>(op), std::forward<P>(p) ...);
  244. } else {
  245. return agree_(std::forward<P>(p) ...);
  246. }
  247. }
  248. template <class V, class ... T, int ... i>
  249. constexpr bool
  250. agree_verb(mp::int_list<i ...>, V && v, T && ... t)
  251. {
  252. using FM = Framematch<V, std::tuple<T ...>>;
  253. return agree_op_(FM::op(std::forward<V>(v)), reframe<mp::ref<typename FM::R, i>>(std::forward<T>(t)) ...);
  254. }
  255. } // namespace ra