atom.hh 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Terminal nodes for expression templates.
  3. // (c) Daniel Llorens - 2011-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include <vector>
  10. #include <utility>
  11. // --------------------
  12. // error function
  13. // --------------------
  14. #include <cassert>
  15. #include "bootstrap.hh"
  16. // If you define your own RA_ASSERT, you might remove this from here.
  17. #include <iostream>
  18. // https://en.cppreference.com/w/cpp/preprocessor/replace
  19. // See examples/throw.cc for how to override this RA_ASSERT.
// Default RA_ASSERT, only defined when the includer has not provided one.
// cond is the condition to verify; any further arguments are handed to ra::format
// to build the message written to std::cerr before the assert fires.
// During constant evaluation only a plain assert(cond) is issued and no message is printed.
// At run time cond is evaluated once into a local bool, so the trailing assert does not re-evaluate it.
#ifndef RA_ASSERT
#define RA_ASSERT(cond, ...) \
{ \
if (std::is_constant_evaluated()) { \
assert(cond /* FIXME maybe one day */); \
} else { \
if (bool c = cond; !c) [[unlikely]] { \
std::cerr << ra::format("**** ra: ", ##__VA_ARGS__, " ****") << std::endl; \
assert(c); \
} \
} \
}
#endif
// RA_CHECK forwards its arguments to RA_ASSERT, unless RA_DO_CHECK is defined as 0,
// in which case it expands to nothing (so its arguments are not evaluated).
#if defined(RA_DO_CHECK) && RA_DO_CHECK==0
#define RA_CHECK( ... )
#else
#define RA_CHECK( ... ) RA_ASSERT( __VA_ARGS__ )
#endif
// NOTE(review): no use of RA_AFTER_CHECK is visible in this file; presumably a marker that
// RA_CHECK has been set up by this point — confirm against its consumers.
#define RA_AFTER_CHECK Yes
  39. namespace ra {
  40. // --------------------
  41. // global introspection I
  42. // --------------------
  43. template <class V>
  44. requires (!std::is_void_v<V>)
  45. constexpr dim_t
  46. rank_s()
  47. {
  48. using dV = std::decay_t<V>;
  49. if constexpr (requires { dV::rank_s(); }) {
  50. return dV::rank_s();
  51. } else if constexpr (requires { ra_traits<V>::rank_s(); }) {
  52. return ra_traits<V>::rank_s();
  53. } else {
  54. return 0;
  55. }
  56. }
  57. template <class V>
  58. constexpr rank_t
  59. rank_s(V const &)
  60. {
  61. return rank_s<V>();
  62. }
  63. template <class V>
  64. requires (!std::is_void_v<V>)
  65. constexpr dim_t
  66. size_s()
  67. {
  68. using dV = std::decay_t<V>;
  69. if constexpr (requires { dV::size_s(); }) {
  70. return dV::size_s();
  71. } else if constexpr (requires { ra_traits<V>::size_s(); }) {
  72. return ra_traits<V>::size_s();
  73. } else {
  74. if constexpr (RANK_ANY==rank_s<V>()) {
  75. return DIM_ANY;
  76. // make it work for non-registered types.
  77. } else if constexpr (0==rank_s<V>()) {
  78. return 1;
  79. } else {
  80. dim_t s = 1;
  81. for (int i=0; i!=dV::rank_s(); ++i) {
  82. if (dim_t ss=dV::len_s(i); ss>=0) {
  83. s *= ss;
  84. } else {
  85. return ss; // either DIM_ANY or DIM_BAD
  86. }
  87. }
  88. return s;
  89. }
  90. }
  91. }
  92. template <class V>
  93. constexpr dim_t
  94. size_s(V const &)
  95. {
  96. return size_s<V>();
  97. }
  98. template <class V>
  99. constexpr rank_t
  100. rank(V const & v)
  101. {
  102. if constexpr (requires { v.rank(); }) {
  103. return v.rank();
  104. } else if constexpr (requires { ra_traits<V>::rank(v); }) {
  105. return ra_traits<V>::rank(v);
  106. } else {
  107. static_assert(!std::is_same_v<V, V>, "No rank() for this type.");
  108. }
  109. }
  110. template <class V>
  111. constexpr dim_t
  112. size(V const & v)
  113. {
  114. if constexpr (requires { v.size(); }) {
  115. return v.size();
  116. } else if constexpr (requires { ra_traits<V>::size(v); }) {
  117. return ra_traits<V>::size(v);
  118. } else {
  119. dim_t s = 1;
  120. for (rank_t k=0; k<rank(v); ++k) { s *= v.len(k); }
  121. return s;
  122. }
  123. }
  124. // To avoid, prefer implicit matching.
  125. template <class V>
  126. constexpr decltype(auto)
  127. shape(V const & v)
  128. {
  129. if constexpr (requires { v.shape(); }) {
  130. return v.shape();
  131. } else if constexpr (requires { ra_traits<V>::shape(v); }) {
  132. return ra_traits<V>::shape(v);
  133. } else if constexpr (constexpr rank_t rs=rank_s<V>(); rs>=0) {
  134. Small<dim_t, rs> s;
  135. for (rank_t k=0; k<rs; ++k) { s[k] = v.len(k); }
  136. return s;
  137. } else {
  138. static_assert(RANK_ANY==rs);
  139. rank_t r = v.rank();
  140. std::vector<dim_t> s(r);
  141. for (rank_t k=0; k<r; ++k) { s[k] = v.len(k); }
  142. return s;
  143. }
  144. }
  145. // To handle arrays of static/dynamic size.
  146. template <class A>
  147. inline void
  148. resize(A & a, dim_t s)
  149. {
  150. if constexpr (DIM_ANY==size_s<A>()) {
  151. a.resize(s);
  152. } else {
  153. RA_CHECK(s==dim_t(a.len_s(0)), "Bad resize ", s, " vs ", a.len_s(0));
  154. }
  155. }
  156. // --------------------
  157. // atom types
  158. // --------------------
  159. // IteratorConcept for rank 0 object. This can be used on foreign objects, or as an alternative to the rank conjunction.
  160. // We still want f(C) to be a specialization in most cases (ie avoid ply(f, C) when C is rank 0).
  161. template <class C>
  162. struct Scalar
  163. {
  164. C c;
  165. RA_DEF_ASSIGNOPS_DEFAULT_SET
  166. constexpr static rank_t rank_s() { return 0; }
  167. constexpr static rank_t rank() { return 0; }
  168. constexpr static dim_t len_s(int k) { RA_CHECK(k<0, "Bad axis ", k); std::abort(); }
  169. constexpr static dim_t len(int k) { RA_CHECK(k<0, "Bad axis ", k); std::abort(); }
  170. constexpr static void adv(rank_t k, dim_t d) {}
  171. constexpr static dim_t step(int k) { return 0; }
  172. constexpr static bool keep_step(dim_t st, int z, int j) { return true; }
  173. constexpr decltype(auto) flat() const { return *this; } // [ra39]
  174. constexpr decltype(auto) at(auto && j) const { return c; }
  175. // use self as Flat
  176. constexpr void operator+=(dim_t d) const {}
  177. constexpr C & operator*() { return this->c; }
  178. constexpr C const & operator*() const { return this->c; } // [ra39]
  179. };
  180. template <class C> constexpr auto scalar(C && c) { return Scalar<C> { std::forward<C>(c) }; }
  181. // IteratorConcept for foreign rank 1 objects.
  182. template <std::random_access_iterator I, dim_t N>
  183. struct Ptr
  184. {
  185. static_assert(N>=0 || N==DIM_BAD || N==DIM_ANY);
  186. I i;
  187. [[no_unique_address]] std::conditional_t<N==DIM_ANY, dim_t, int_c<N>> n;
  188. constexpr Ptr(I i) requires (N!=DIM_ANY): i(i) {}
  189. constexpr Ptr(I i, dim_t n) requires (N==DIM_ANY): i(i), n(n) {}
  190. RA_DEF_ASSIGNOPS_SELF(Ptr)
  191. RA_DEF_ASSIGNOPS_DEFAULT_SET
  192. constexpr static rank_t rank_s() { return 1; };
  193. constexpr static rank_t rank() { return 1; }
  194. constexpr static dim_t len_s(int k) { RA_CHECK(k==0, "Bad axis ", k); return N; }
  195. constexpr static dim_t len(int k) requires (N!=DIM_ANY) { RA_CHECK(k==0, "Bad axis ", k); return N; }
  196. constexpr dim_t len(int k) const requires (N==DIM_ANY) { RA_CHECK(k==0, "Bad axis ", k); return n; }
  197. constexpr static dim_t step(int k) { return k==0 ? 1 : 0; }
  198. constexpr static bool keep_step(dim_t st, int z, int j) { return st*step(z)==step(j); }
  199. constexpr void adv(rank_t k, dim_t d) { i += step(k) * d; }
  200. constexpr auto flat() const { return i; }
  201. constexpr decltype(auto) at(auto && j) const
  202. {
  203. RA_CHECK(DIM_BAD==N || inside(j[0], len(0)), " j ", j[0], " size ", len(0));
  204. return i[j[0]];
  205. }
  206. };
  207. template <class I> constexpr auto ptr(I i) { return Ptr<I, DIM_BAD> { i }; }
  208. template <class I, int N> constexpr auto ptr(I i, int_c<N>) { return Ptr<I, N> { i }; }
  209. template <class I> constexpr auto ptr(I i, dim_t n) { return Ptr<I, DIM_ANY> { i, n }; }
  210. template <std::ranges::random_access_range V> constexpr auto
  211. vector(V && v)
  212. {
  213. if constexpr (constexpr dim_t s = size_s<V>(); DIM_ANY==s) {
  214. return ptr(std::begin(std::forward<V>(v)), std::ssize(v));
  215. } else {
  216. return ptr(std::begin(std::forward<V>(v)), int_c<s> {});
  217. }
  218. }
  219. // Sequence and IteratorConcept for same. Iota isn't really an atom, but its exprs must all have rank 0 so it kind of is.
  220. // FIXME Sequence should be its own type, we can't represent a ct origin bc IteratorConcept interface takes up i.
  221. template <int w, class O, class N, class S>
  222. struct Iota
  223. {
  224. constexpr static dim_t nn = [] { if constexpr (is_constant<N>) { return N::value; } else { return DIM_ANY; } }();
  225. constexpr static dim_t ss = [] { if constexpr (is_constant<S>) { return S::value; } else { return DIM_ANY; } }();
  226. static_assert(w>=0);
  227. static_assert(is_constant<N> || 0==rank_s<N>());
  228. static_assert(is_constant<S> || 0==rank_s<S>());
  229. static_assert(nn>=0 || nn==DIM_BAD || (!is_constant<N> && nn==DIM_ANY)); // forbid N dim_c<DIM_ANY>
  230. O i = {};
  231. [[no_unique_address]] N const n = {};
  232. [[no_unique_address]] S const s = {};
  233. constexpr static O gets() requires (is_constant<S>) { return ss; }
  234. constexpr O gets() const requires (!is_constant<S>) { return s; }
  235. struct Flat
  236. {
  237. O i;
  238. S s;
  239. constexpr void operator+=(dim_t d) { i += O(d)*O(s); }
  240. constexpr auto operator*() const { return i; }
  241. };
  242. constexpr static rank_t rank_s() { return w+1; };
  243. constexpr static rank_t rank() { return w+1; }
  244. constexpr static dim_t len_s(int k) { RA_CHECK(k<=w, "Bad axis", k); return k==w ? nn : DIM_BAD; }
  245. constexpr static dim_t len(int k) requires (nn!=DIM_ANY) { RA_CHECK(k<=w, "Bad axis ", k); return k==w ? nn : DIM_BAD; }
  246. constexpr dim_t len(int k) const requires (nn==DIM_ANY) { RA_CHECK(k<=w, "Bad axis ", k); return k==w ? n : DIM_BAD; }
  247. constexpr static dim_t step(rank_t k) { return k==w ? 1 : 0; }
  248. constexpr static bool keep_step(dim_t st, int z, int j) { return st*step(z)==step(j); }
  249. constexpr void adv(rank_t k, dim_t d) { i += O(step(k) * d) * O(s); }
  250. constexpr auto flat() const { return Flat { i, s }; }
  251. constexpr auto at(auto && j) const { return i + O(j[w])*O(s); }
  252. };
// Iota along axis w with dim_t origin, static length DIM_BAD and static step 1.
template <int w> using TensorIndex = Iota<w, dim_t, dim_c<DIM_BAD>, dim_c<1>>;
  254. template <class T>
  255. constexpr auto
  256. default_one()
  257. {
  258. if constexpr (std::is_integral_v<T>) {
  259. return T(1);
  260. } else if constexpr (is_constant<T>) {
  261. static_assert(1==T::value);
  262. return T {};
  263. }
  264. }
  265. template <int w=0, class O=dim_t, class N=dim_c<DIM_BAD>, class S=dim_c<1>>
  266. constexpr auto
  267. iota(N && n = N {}, O && org = 0, S && s = default_one<S>())
  268. {
  269. if constexpr (std::is_integral_v<N>) {
  270. RA_CHECK(n>=0, "Bad iota length ", n);
  271. }
  272. using OO = std::conditional_t<is_constant<std::decay_t<O>> || is_scalar<std::decay_t<O>>, std::decay_t<O>, O>;
  273. using NN = std::conditional_t<is_constant<std::decay_t<N>> || is_scalar<std::decay_t<N>>, std::decay_t<N>, N>;
  274. using SS = std::conditional_t<is_constant<std::decay_t<S>> || is_scalar<std::decay_t<S>>, std::decay_t<S>, S>;
  275. return Iota<w, OO, NN, SS> { std::forward<O>(org), std::forward<N>(n), std::forward<S>(s) };
  276. }
// Declare one constexpr TensorIndex<w> object per w in 0...4. The object name is whatever
// JOIN(_, w) expands to (presumably _0 ... _4 — JOIN and FOR_EACH are defined elsewhere).
// The helper macro is removed again right after use.
#define DEF_TENSORINDEX(w) constexpr TensorIndex<w> JOIN(_, w);
FOR_EACH(DEF_TENSORINDEX, 0, 1, 2, 3, 4);
#undef DEF_TENSORINDEX
  280. // Never ply(), solely to be rewritten.
  281. struct Len
  282. {
  283. constexpr static rank_t rank_s() { return 0; }
  284. constexpr static rank_t rank() { return 0; }
  285. constexpr static dim_t len_s(int k) { std::abort(); }
  286. constexpr static dim_t len(int k) { std::abort(); }
  287. constexpr static void adv(rank_t k, dim_t d) { std::abort(); }
  288. constexpr static dim_t step(int k) { std::abort(); }
  289. constexpr static bool keep_step(dim_t st, int z, int j) { std::abort(); }
  290. constexpr static Len const & flat() { std::abort(); }
  291. constexpr void operator+=(dim_t d) const { std::abort(); }
  292. constexpr dim_t operator*() const { std::abort(); }
  293. };
// The Len placeholder object.
constexpr Len len {};
// let operators build expr trees.
static_assert(IteratorConcept<Len>);
// Register Len through is_special_def (the trait itself is declared elsewhere).
template <> constexpr bool is_special_def<Len> = true;
// NOTE(review): RA_IS_DEF is defined elsewhere; presumably this declares a has_len trait
// that defaults to false — confirm against the macro.
RA_IS_DEF(has_len, false);
  299. // --------------
  300. // coerce potential Iterators
  301. // --------------
  302. template <class T>
  303. constexpr void
  304. start(T && t) { static_assert(mp::always_false<T>, "Type cannot be start()ed."); }
  305. // undefined len iota (ti) is excluded from optimization and beating to allow e.g. B = A(... ti ...).
  306. // FIXME find a way?
// is_iota defaults to false (through RA_IS_DEF, defined elsewhere).
RA_IS_DEF(is_iota, false)
// Only Iota along axis 0 can qualify, and only when its static length nn is not DIM_BAD.
template <class O, class N, class S>
constexpr bool is_iota_def<Iota<0, O, N, S>> = (DIM_BAD != Iota<0, O, N, S>::nn);
  310. template <class T> requires (is_foreign_vector<T>)
  311. constexpr auto
  312. start(T && t) { return ra::vector(std::forward<T>(t)); }
  313. template <class T> requires (is_scalar<T>)
  314. constexpr auto
  315. start(T && t) { return ra::scalar(std::forward<T>(t)); }
  316. template <class T>
  317. constexpr auto
  318. start(std::initializer_list<T> v) { return ptr(v.begin(), v.size()); }
  319. // forward declare for Match; implemented in small.hh.
// Declaration only: start() for builtin arrays. No definition appears in this file.
template <class T> requires (is_builtin_array<T>)
constexpr auto
start(T && t);
  323. // neither CellBig nor CellSmall will retain rvalues [ra4].
  324. template <class T> requires (is_slice<T>)
  325. constexpr auto
  326. start(T && t) { return iter<0>(std::forward<T>(t)); }
// is_ra_scalar holds when A is exactly Scalar<X>, X being the declared type of A's member c.
// NOTE(review): the name A comes from the RA_IS_DEF macro, defined elsewhere — confirm there.
RA_IS_DEF(is_ra_scalar, (std::same_as<A, Scalar<decltype(std::declval<A>().c)>>))
  328. template <class T> requires (is_ra_scalar<T>)
  329. constexpr decltype(auto)
  330. start(T && t) { return std::forward<T>(t); }
  331. // iterators need to be restarted on each use (eg ra::cross()) [ra35].
  332. template <class T> requires (is_iterator<T> && !is_ra_scalar<T>)
  333. constexpr auto
  334. start(T && t) { return std::forward<T>(t); }
  335. // --------------------
  336. // global introspection II
  337. // --------------------
  338. // also used to paper over Scalar<X> vs X
  339. template <class A>
  340. constexpr decltype(auto)
  341. FLAT(A && a)
  342. {
  343. if constexpr (is_scalar<A>) {
  344. return std::forward<A>(a); // avoid dangling temp in this case [ra8]
  345. } else {
  346. return *(ra::start(std::forward<A>(a)).flat());
  347. }
  348. }
  349. // FIXME do we really want to drop const? See use in concrete_type.
// Decayed type of FLAT() applied to an A; decay removes references and cv-qualifiers.
template <class A> using value_t = std::decay_t<decltype(FLAT(std::declval<A>()))>;
  351. } // namespace ra