operators.hh 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Operator overloads for expression templates.
  3. // (c) Daniel Llorens - 2014-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include "big.hh"
  10. #include "optimize.hh"
  11. #include "complex.hh"
// RA_DO_OPT selects whether the binary array operators defined below (DEF_NAMED_BINARY_OP) wrap
// their map() expressions in RA_OPT(...), i.e. in optimize() (presumably provided by optimize.hh,
// included above). It defaults to 1; define it to any other value before including this header
// to make RA_OPT expand to nothing. RA_OPT is #undef'd at the end of this header.
#ifndef RA_DO_OPT
#define RA_DO_OPT 1 // enabled by default
#endif
#if RA_DO_OPT==1
#define RA_OPT optimize
#else
#define RA_OPT
#endif
// ---------------------------
// globals FIXME do we really need these?
// ---------------------------
// These global versions must be available so that e.g. ra::transpose<> may be searched by ADL even when giving explicit template args. See http://stackoverflow.com/questions/9838862 .
// Before C++20 (P0846), a call written f<...>(x) is only parsed as a template call if some
// template named f is visible by ordinary lookup; these declarations (taking the ra::no_arg tag,
// and only declared here) exist to make that true for transpose and iter so the ra:: overloads
// can then be found by ADL.
template <class A> constexpr void transpose(ra::no_arg);
template <int A> constexpr void iter(ra::no_arg);
  26. namespace ra {
  27. template <class T> constexpr bool is_scalar_def<std::complex<T>> = true;
  28. template <int ... Iarg, class A>
  29. constexpr decltype(auto)
  30. transpose(mp::int_list<Iarg ...>, A && a)
  31. {
  32. return transpose<Iarg ...>(std::forward<A>(a));
  33. }
  34. // ---------------------------
  35. // TODO integrate with is_beatable shortcuts, operator() in the various array types.
  36. // ---------------------------
  37. template <class II, int drop, class Op>
  38. constexpr decltype(auto)
  39. from_partial(Op && op)
  40. {
  41. if constexpr (drop==mp::len<II>) {
  42. return std::forward<Op>(op);
  43. } else {
  44. return wrank(mp::append<mp::makelist<drop, int_c<0>>, mp::drop<II, drop>> {},
  45. from_partial<II, drop+1>(std::forward<Op>(op)));
  46. }
  47. }
  48. template <class I> using index_rank = int_c<rank_s<I>()>;
  49. // TODO we should be able to do better by slicing at each dimension, etc. But verb<> only supports rank-0 for the innermost op.
  50. template <class A, class ... I>
  51. constexpr auto
  52. from(A && a, I && ... i)
  53. {
  54. if constexpr (0==sizeof...(i)) {
  55. return a();
  56. } else if constexpr (1==sizeof...(i)) {
  57. // support dynamic rank for 1 arg only (see test in test/from.cc).
  58. return map(std::forward<A>(a), std::forward<I>(i) ...);
  59. } else {
  60. using II = mp::map<index_rank, mp::tuple<decltype(std::forward<I>(i)) ...>>;
  61. return map(from_partial<II, 1>(std::forward<A>(a)), std::forward<I>(i) ...);
  62. }
  63. }
  64. // --------------------------------
  65. // Array versions of operators and functions
  66. // --------------------------------
  67. // We need zero/scalar specializations because the scalar/scalar operators maybe be templated (e.g. complex<>), so they won't be found when an implicit conversion from zero->scalar is also needed. That is, without those specializations, ra::View<complex, 0> * complex will fail.
  68. // These depend on OPNAME defined in optimize.hh and used there to match ET patterns.
  69. #define DEF_NAMED_BINARY_OP(OP, OPNAME) \
  70. template <class A, class B> requires (ra_irreducible<A, B>) \
  71. constexpr auto \
  72. operator OP(A && a, B && b) \
  73. { \
  74. return RA_OPT(map(OPNAME(), std::forward<A>(a), std::forward<B>(b))); \
  75. } \
  76. template <class A, class B> requires (ra_reducible<A, B>) \
  77. constexpr auto \
  78. operator OP(A && a, B && b) \
  79. { \
  80. return FLAT(std::forward<A>(a)) OP FLAT(std::forward<B>(b)); \
  81. }
  82. DEF_NAMED_BINARY_OP(+, std::plus<>)
  83. DEF_NAMED_BINARY_OP(-, std::minus<>)
  84. DEF_NAMED_BINARY_OP(*, std::multiplies<>)
  85. DEF_NAMED_BINARY_OP(/, std::divides<>)
  86. DEF_NAMED_BINARY_OP(==, std::equal_to<>)
  87. DEF_NAMED_BINARY_OP(>, std::greater<>)
  88. DEF_NAMED_BINARY_OP(<, std::less<>)
  89. DEF_NAMED_BINARY_OP(>=, std::greater_equal<>)
  90. DEF_NAMED_BINARY_OP(<=, std::less_equal<>)
  91. DEF_NAMED_BINARY_OP(!=, std::not_equal_to<>)
  92. DEF_NAMED_BINARY_OP(|, std::bit_or<>)
  93. DEF_NAMED_BINARY_OP(&, std::bit_and<>)
  94. DEF_NAMED_BINARY_OP(^, std::bit_xor<>)
  95. DEF_NAMED_BINARY_OP(<=>, std::compare_three_way)
  96. #undef DEF_NAMED_BINARY_OP
  97. // FIXME address sanitizer complains in bench-optimize.cc if we use std::identity. Maybe false positive
  98. struct unaryplus
  99. {
  100. template <class T> constexpr /* static P1169 in gcc13 */ auto
  101. operator()(T && t) const noexcept
  102. { return std::forward<T>(t); }
  103. };
  104. #define DEF_NAMED_UNARY_OP(OP, OPNAME) \
  105. template <class A> requires (ra_irreducible<A>) \
  106. constexpr auto \
  107. operator OP(A && a) \
  108. { \
  109. return map(OPNAME(), std::forward<A>(a)); \
  110. } \
  111. template <class A> requires (ra_reducible<A>) \
  112. constexpr auto \
  113. operator OP(A && a) \
  114. { \
  115. return OP FLAT(std::forward<A>(a)); \
  116. }
  117. DEF_NAMED_UNARY_OP(+, unaryplus)
  118. DEF_NAMED_UNARY_OP(-, std::negate<>)
  119. DEF_NAMED_UNARY_OP(!, std::logical_not<>)
  120. #undef DEF_NAMED_UNARY_OP
// When OP(a) isn't found from ra::, the deduction from rank(0) -> scalar doesn't work.
// TODO Cf examples/useret.cc, test/reexported.cc
// DEF_NAME_OP(OP) re-exports the global ::OP into ra:: with a using-declaration, so the scalar
// versions are still found from inside ra::, and adds two ra:: overloads:
//  - ra_irreducible<A ...>: an elementwise expression built with map(), whose lambda calls OP on
//    each set of elements; the lambda returns decltype(auto) so references are preserved
//    (needed by real_part/imag_part, which return refs).
//  - ra_reducible<A ...>: calls OP directly on FLAT() of every argument.
// NOTE(review): `using ::OP;` requires a global ::OP to be declared before this point for every
// name listed below (presumably provided by complex.hh / <cmath>); confirm when adding names.
#define DEF_NAME_OP(OP) \
using ::OP; \
template <class ... A> requires (ra_irreducible<A ...>) \
constexpr auto \
OP(A && ... a) \
{ \
return map([](auto && ... a) -> decltype(auto) { return OP(a ...); }, std::forward<A>(a) ...); \
} \
template <class ... A> requires (ra_reducible<A ...>) \
constexpr decltype(auto) \
OP(A && ... a) \
{ \
return OP(FLAT(std::forward<A>(a)) ...); \
}
FOR_EACH(DEF_NAME_OP, rel_error, pow, xI, conj, sqr, sqrm, sqrt, cos, sin)
FOR_EACH(DEF_NAME_OP, exp, expm1, log, log1p, log10, isfinite, isnan, isinf, clamp)
// NOTE(review): ra::odd is passed qualified, so its expansion differs from the other names
// (`using ::ra::odd;` and a qualified declarator); confirm this expands as intended.
FOR_EACH(DEF_NAME_OP, max, min, abs, ra::odd, asin, acos, atan, atan2, lerp, arg)
FOR_EACH(DEF_NAME_OP, cosh, sinh, tanh)
FOR_EACH(DEF_NAME_OP, real_part, imag_part) // return ref
#undef DEF_NAME_OP
  143. template <class T, class A>
  144. constexpr auto cast(A && a)
  145. {
  146. return map([](auto && b) { return T(b); }, std::forward<A>(a));
  147. }
  148. // TODO could be useful to deduce T as tuple of value_types (&).
  149. template <class T, class ... A>
  150. constexpr auto pack(A && ... a)
  151. {
  152. return map([](auto && ... a) { return T { a ... }; }, std::forward<A>(a) ...);
  153. }
  154. // FIXME needs a nested array for I, which is ugly.
  155. template <class A, class I>
  156. constexpr auto at(A && a, I && i)
  157. {
  158. return map([a = std::tuple<A>(std::forward<A>(a))]
  159. (auto && i) -> decltype(auto) { return std::get<0>(a).at(i); }, i);
  160. }
// --------------------------------
// selection or shortcutting
// --------------------------------
// where(w, t, f): t where w is true, f where w is false.
// The FLAT() calls below are needed bc rank 0 converts to and from scalar, so ?: can't pick the right (-> scalar) conversion.
// Overload for a scalar condition with ra_reducible t/f: choose between FLAT(t) and FLAT(f).
// decltype(auto) over ?: may yield a reference when both operands are lvalues of the same type.
template <class T, class F>
requires (ra_reducible<T, F>)
constexpr decltype(auto)
where(bool const w, T && t, F && f)
{
return w ? FLAT(t) : FLAT(f);
}
// Array overload (ra_irreducible): elementwise selection built with pick().
// Note the argument order: f is passed before t, presumably because pick() indexes its
// alternatives by the bool value (false -> first, true -> second).
template <class W, class T, class F>
requires (ra_irreducible<W, T, F>)
constexpr auto
where(W && w, T && t, F && f)
{
return pick(cast<bool>(std::forward<W>(w)), std::forward<F>(f), std::forward<T>(t));
}
// catch all for non-ra types.
// Plain ?: on the named (lvalue) parameters, so with matching types this returns a reference to
// one of the arguments; don't keep it past the caller's full-expression if they were temporaries.
template <class T, class F>
requires (!(ra_irreducible<T, F>) && !(ra_reducible<T, F>))
constexpr decltype(auto)
where(bool const w, T && t, F && f)
{
return w ? t : f;
}
  187. template <class A, class B>
  188. requires (ra_irreducible<A, B>)
  189. constexpr auto operator &&(A && a, B && b)
  190. {
  191. return where(std::forward<A>(a), cast<bool>(std::forward<B>(b)), false);
  192. }
  193. template <class A, class B>
  194. requires (ra_irreducible<A, B>)
  195. constexpr auto operator ||(A && a, B && b)
  196. {
  197. return where(std::forward<A>(a), true, cast<bool>(std::forward<B>(b)));
  198. }
  199. #define DEF_SHORTCIRCUIT_BINARY_OP(OP) \
  200. template <class A, class B> \
  201. requires (ra_reducible<A, B>) \
  202. constexpr auto operator OP(A && a, B && b) \
  203. { \
  204. return FLAT(a) OP FLAT(b); \
  205. }
  206. FOR_EACH(DEF_SHORTCIRCUIT_BINARY_OP, &&, ||);
  207. #undef DEF_SHORTCIRCUIT_BINARY_OP
  208. // --------------------------------
  209. // Some whole-array reductions.
  210. // TODO First rank reductions? Variable rank reductions?
  211. // --------------------------------
  212. template <class A>
  213. constexpr bool
  214. any(A && a)
  215. {
  216. return early(map([](bool x) { return std::make_tuple(x, x); }, std::forward<A>(a)), false);
  217. }
  218. template <class A>
  219. constexpr bool
  220. every(A && a)
  221. {
  222. return early(map([](bool x) { return std::make_tuple(!x, x); }, std::forward<A>(a)), true);
  223. }
  224. // FIXME variable rank? see J 'index of' (x i. y), etc.
  225. template <class A>
  226. constexpr auto
  227. index(A && a)
  228. {
  229. return early(map([](auto && a, auto && i) { return std::make_tuple(bool(a), i); },
  230. std::forward<A>(a), ra::iota(ra::start(a).len(0))),
  231. ra::dim_t(-1));
  232. }
  233. // [ma108]
  234. template <class A, class B>
  235. constexpr bool
  236. lexicographical_compare(A && a, B && b)
  237. {
  238. return early(map([](auto && a, auto && b)
  239. { return a==b ? std::make_tuple(false, true) : std::make_tuple(true, a<b); },
  240. a, b),
  241. false);
  242. }
  243. // FIXME only works with numeric types.
  244. template <class A>
  245. constexpr auto
  246. amin(A && a)
  247. {
  248. using std::min;
  249. using T = value_t<A>;
  250. T c = std::numeric_limits<T>::has_infinity ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::max();
  251. for_each([&c](auto && a) { if (a<c) { c = a; } }, a);
  252. return c;
  253. }
  254. template <class A>
  255. constexpr auto
  256. amax(A && a)
  257. {
  258. using std::max;
  259. using T = value_t<A>;
  260. T c = std::numeric_limits<T>::has_infinity ? -std::numeric_limits<T>::infinity() : std::numeric_limits<T>::lowest();
  261. for_each([&c](auto && a) { if (c<a) { c = a; } }, a);
  262. return c;
  263. }
  264. // FIXME encapsulate this kind of reference-reduction.
  265. // FIXME expr/ply mechanism doesn't allow partial iteration (adv then continue).
  266. template <class A, class Less = std::less<value_t<A>>>
  267. constexpr decltype(auto)
  268. refmin(A && a, Less && less = std::less<value_t<A>>())
  269. {
  270. RA_CHECK(a.size()>0);
  271. decltype(auto) s = ra::start(a);
  272. auto p = &(*s.flat());
  273. for_each([&less, &p](auto & a) { if (less(a, *p)) { p = &a; } }, s);
  274. return *p;
  275. }
  276. template <class A, class Less = std::less<value_t<A>>>
  277. constexpr decltype(auto)
  278. refmax(A && a, Less && less = std::less<value_t<A>>())
  279. {
  280. RA_CHECK(a.size()>0);
  281. decltype(auto) s = ra::start(a);
  282. auto p = &(*s.flat());
  283. for_each([&less, &p](auto & a) { if (less(*p, a)) { p = &a; } }, s);
  284. return *p;
  285. }
  286. template <class A>
  287. constexpr auto
  288. sum(A && a)
  289. {
  290. auto c = concrete_type<value_t<A>>(0);
  291. for_each([&c](auto && a) { c += a; }, a);
  292. return c;
  293. }
  294. template <class A>
  295. constexpr auto
  296. prod(A && a)
  297. {
  298. auto c = concrete_type<value_t<A>>(1);
  299. for_each([&c](auto && a) { c *= a; }, a);
  300. return c;
  301. }
  302. template <class A> constexpr auto reduce_sqrm(A && a) { return sum(sqrm(a)); }
  303. template <class A> constexpr auto norm2(A && a) { return std::sqrt(reduce_sqrm(a)); }
  304. template <class A, class B>
  305. constexpr auto
  306. dot(A && a, B && b)
  307. {
  308. std::decay_t<decltype(FLAT(a) * FLAT(b))> c(0.);
  309. for_each([&c](auto && a, auto && b)
  310. {
  311. #ifdef FP_FAST_FMA
  312. c = fma(a, b, c);
  313. #else
  314. c += a*b;
  315. #endif
  316. }, a, b);
  317. return c;
  318. }
  319. template <class A, class B>
  320. constexpr auto
  321. cdot(A && a, B && b)
  322. {
  323. std::decay_t<decltype(conj(FLAT(a)) * FLAT(b))> c(0.);
  324. for_each([&c](auto && a, auto && b)
  325. {
  326. #ifdef FP_FAST_FMA
  327. c = fma_conj(a, b, c);
  328. #else
  329. c += conj(a)*b;
  330. #endif
  331. }, a, b);
  332. return c;
  333. }
  334. // --------------------
  335. // Other whole-array ops.
  336. // --------------------
  337. template <class A>
  338. constexpr auto
  339. normv(A const & a)
  340. {
  341. auto b = concrete(a);
  342. b /= norm2(b);
  343. return b;
  344. }
  345. // FIXME benchmark w/o allocation and do Small/Big versions if it's worth it.
  346. template <class A, class B, class C>
  347. constexpr void
  348. gemm(A const & a, B const & b, C & c)
  349. {
  350. for_each(ra::wrank<1, 1, 2>(ra::wrank<1, 0, 1>([](auto && c, auto && a, auto && b) { c += a*b; })), c, a, b);
  351. }
  352. #define MMTYPE decltype(from(std::multiplies<>(), a(ra::all, 0), b(0)))
  353. // default for row-major x row-major. See bench-gemm.cc for variants.
  354. template <class S, class T>
  355. constexpr auto
  356. gemm(ra::View<S, 2> const & a, ra::View<T, 2> const & b)
  357. {
  358. int M = a.len(0);
  359. int N = b.len(1);
  360. int K = a.len(1);
  361. // no with_same_shape bc cannot index 0 for type if A/B are empty
  362. auto c = with_shape<MMTYPE>({M, N}, decltype(std::declval<S>()*std::declval<T>())());
  363. for (int k=0; k<K; ++k) {
  364. c += from(std::multiplies<>(), a(ra::all, k), b(k));
  365. }
  366. return c;
  367. }
  368. // we still want the Small version to be different.
  369. template <class A, class B>
  370. constexpr ra::Small<std::decay_t<decltype(FLAT(std::declval<A>()) * FLAT(std::declval<B>()))>, A::len(0), B::len(1)>
  371. gemm(A const & a, B const & b)
  372. {
  373. constexpr int M = a.len(0);
  374. constexpr int N = b.len(1);
  375. // no with_same_shape bc cannot index 0 for type if A/B are empty
  376. auto c = with_shape<MMTYPE>({M, N}, ra::none);
  377. for (int i=0; i<M; ++i) {
  378. for (int j=0; j<N; ++j) {
  379. c(i, j) = dot(a(i), b(ra::all, j));
  380. }
  381. }
  382. return c;
  383. }
  384. #undef MMTYPE
  385. template <class A, class B>
  386. constexpr auto
  387. gevm(A const & a, B const & b)
  388. {
  389. int const M = b.len(0);
  390. int const N = b.len(1);
  391. // no with_same_shape bc cannot index 0 for type if A/B are empty
  392. auto c = with_shape<decltype(a[0]*b(0))>({N}, 0);
  393. for (int i=0; i<M; ++i) {
  394. c += a[i]*b(i);
  395. }
  396. return c;
  397. }
  398. // FIXME a must be a view, so it doesn't work with e.g. gemv(conj(a), b).
  399. template <class A, class B>
  400. constexpr auto
  401. gemv(A const & a, B const & b)
  402. {
  403. int const M = a.len(0);
  404. int const N = a.len(1);
  405. // no with_same_shape bc cannot index 0 for type if A/B are empty
  406. auto c = with_shape<decltype(a(ra::all, 0)*b[0])>({M}, 0);
  407. for (int j=0; j<N; ++j) {
  408. c += a(ra::all, j) * b[j];
  409. }
  410. return c;
  411. }
  412. } // namespace ra
  413. #undef RA_OPT