small.hh 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Arrays with static lengths/strides, cf big.hh.
  3. // (c) Daniel Llorens - 2013-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include "ply.hh"
  10. #include "expr.hh"
  11. namespace ra {
  12. // --------------------
  13. // Helpers for slicing
  14. // --------------------
  15. // FIXME condition should be zero rank, maybe convertibility, not is_integral
  // True when T, once decayed (reference and cv-qualifiers removed), is an integral type.
  template <class T>
  constexpr bool is_scalar_index = std::is_integral<std::decay_t<T>>::value;
  17. template <class I>
  18. struct is_beatable_def
  19. {
  20. constexpr static bool value = is_scalar_index<I>;
  21. constexpr static int skip_src = 1;
  22. constexpr static int skip = 0;
  23. constexpr static bool static_p = value; // can the beating be resolved statically?
  24. };
  25. template <class I> requires (is_iota<I>)
  26. struct is_beatable_def<I>
  27. {
  28. using T = decltype(I::i);
  29. constexpr static bool value = is_scalar_index<T> && (DIM_BAD != I::len_s(0));
  30. constexpr static int skip_src = 1;
  31. constexpr static int skip = 1;
  32. constexpr static bool static_p = false; // FIXME see Iota with ct N, S
  33. };
  34. // FIXME have a 'filler' version (e.g. with default n = -1) or maybe a distinct type.
  35. template <int n>
  36. struct is_beatable_def<dots_t<n>>
  37. {
  38. static_assert(n>=0, "bad count for dots_n");
  39. constexpr static bool value = (n>=0);
  40. constexpr static int skip_src = n;
  41. constexpr static int skip = n;
  42. constexpr static bool static_p = true;
  43. };
  44. template <int n>
  45. struct is_beatable_def<insert_t<n>>
  46. {
  47. static_assert(n>=0, "bad count for dots_n");
  48. constexpr static bool value = (n>=0);
  49. constexpr static int skip_src = 0;
  50. constexpr static int skip = n;
  51. constexpr static bool static_p = true;
  52. };
  // Look up the slicing traits of a subscript type after decaying it, so that references
  // and cv-qualified subscripts select the same is_beatable_def as the plain type.
  template <class I> using is_beatable = is_beatable_def<std::decay_t<I>>;
  54. // --------------------
  55. // Develop indices for Small
  56. // --------------------
  57. namespace indexer0 {
  58. template <class lens, class steps, class P, rank_t end, rank_t k=0>
  59. constexpr dim_t index(P const & p)
  60. {
  61. if constexpr (k==end) {
  62. return 0;
  63. } else {
  64. RA_CHECK(inside(p[k], mp::ref<lens, k>::value));
  65. return (p[k] * mp::ref<steps, k>::value) + index<lens, steps, P, end, k+1>(p);
  66. }
  67. }
  68. template <class lens, class steps, class P>
  69. constexpr dim_t shorter(P const & p) // for Container::at().
  70. {
  71. static_assert(mp::len<lens> >= size_s<P>(), "Too many indices.");
  72. return index<lens, steps, P, size_s<P>()>(p);
  73. }
  74. template <class lens, class steps, class P>
  75. constexpr dim_t longer(P const & p) // for IteratorConcept::at().
  76. {
  77. if constexpr (RANK_ANY==size_s<P>()) {
  78. RA_CHECK(mp::len<lens> <= p.size(), "Too few indices.");
  79. } else {
  80. static_assert(mp::len<lens> <= size_s<P>(), "Too few indices.");
  81. }
  82. return index<lens, steps, P, mp::len<lens>>(p);
  83. }
  84. } // namespace indexer0
  85. // --------------------
  86. // Small iterator
  87. // --------------------
  88. // TODO Refactor with CellBig / STLIterator
  89. // Used by CellBig / CellSmall.
  90. template <class C>
  91. struct CellFlat
  92. {
  93. C c;
  94. constexpr void operator+=(dim_t const s) { c.p += s; }
  95. constexpr C & operator*() { return c; }
  96. };
  97. // V is always SmallBase<SmallView, ...>
  98. template <class V, rank_t cellr_spec=0>
  99. struct CellSmall
  100. {
  101. static_assert(cellr_spec!=RANK_ANY && cellr_spec!=RANK_BAD, "bad cell rank");
  102. constexpr static rank_t fullr = ra::rank_s<V>();
  103. constexpr static rank_t cellr = dependent_cell_rank(fullr, cellr_spec);
  104. constexpr static rank_t framer = dependent_frame_rank(fullr, cellr_spec);
  105. static_assert(cellr>=0 || cellr==RANK_ANY, "bad cell rank");
  106. static_assert(framer>=0 || framer==RANK_ANY, "bad frame rank");
  107. static_assert(fullr==cellr || gt_rank(fullr, cellr), "bad cell rank");
  108. using cell_lens = mp::drop<typename V::lens, framer>;
  109. using cell_steps = mp::drop<typename V::steps, framer>;
  110. using lens = mp::take<typename V::lens, framer>; // these are steps on atom_type * p !!
  111. using steps = mp::take<typename V::steps, framer>;
  112. using atom_type = std::remove_reference_t<decltype(*(std::declval<V>().data()))>;
  113. using cell_type = SmallView<atom_type, cell_lens, cell_steps>;
  114. using value_type = std::conditional_t<0==cellr, atom_type, cell_type>;
  115. cell_type c;
  116. constexpr CellSmall(CellSmall const & ci): c { ci.c.p } {}
  117. // see STLIterator for the case of s_[0]=0, etc. [ra12].
  118. constexpr CellSmall(atom_type * p_): c { p_ } {}
  119. RA_DEF_ASSIGNOPS_DEFAULT_SET
  120. constexpr static rank_t rank_s() { return framer; }
  121. constexpr static rank_t rank() { return framer; }
  122. constexpr static dim_t len_s(int k) { RA_CHECK(inside(k, rank_s())); return V::len(k); }
  123. constexpr static dim_t len(int k) { RA_CHECK(inside(k, rank())); return V::len(k); }
  124. constexpr static dim_t step(int k) { return k<rank() ? V::step(k) : 0; }
  125. constexpr static bool keep_step(dim_t st, int z, int j) { return st*step(z)==step(j); }
  126. constexpr void adv(rank_t k, dim_t d) { c.p += step(k)*d; }
  127. constexpr auto
  128. flat() const
  129. {
  130. if constexpr (0==cellr) {
  131. return c.p;
  132. } else {
  133. return CellFlat<cell_type> { c };
  134. }
  135. }
  136. constexpr decltype(auto)
  137. at(auto const & i) const
  138. {
  139. if constexpr (0==cellr) {
  140. return c.p[indexer0::longer<lens, steps>(i)];
  141. } else {
  142. return cell_type(c.p + indexer0::longer<lens, steps>(i));
  143. }
  144. }
  145. };
  146. // --------------------
  147. // STLIterator for both CellSmall & CellBig
  148. // FIXME make it work for any array iterator, as in ply_ravel, ply_index.
  149. // --------------------
  150. template <class S, class I, class P>
  151. constexpr void
  152. next_in_cube(rank_t const framer, S const & dimv, I & i, P & p)
  153. {
  154. for (int k=framer-1; k>=0; --k) {
  155. ++i[k];
  156. if (i[k]<dimv[k].len) {
  157. p += dimv[k].step;
  158. return;
  159. } else {
  160. i[k] = 0;
  161. p -= dimv[k].step*(dimv[k].len-1);
  162. }
  163. }
  164. p = nullptr;
  165. }
  166. template <int k, class lens, class steps, class I, class P>
  167. constexpr void
  168. next_in_cube(I & i, P & p)
  169. {
  170. if constexpr (k>=0) {
  171. ++i[k];
  172. if (i[k]<mp::ref<lens, k>::value) {
  173. p += mp::ref<steps, k>::value;
  174. } else {
  175. i[k] = 0;
  176. p -= mp::ref<steps, k>::value*(mp::ref<lens, k>::value-1);
  177. next_in_cube<k-1, lens, steps>(i, p);
  178. }
  179. } else {
  180. p = nullptr;
  181. }
  182. }
  183. template <class Iterator>
  184. struct STLIterator
  185. {
  186. using value_type = typename Iterator::value_type;
  187. using difference_type = dim_t;
  188. using pointer = value_type *;
  189. using reference = value_type &;
  190. using iterator_category = std::forward_iterator_tag;
  191. using shape_type = decltype(ra::shape(std::declval<Iterator>()));
  192. Iterator ii;
  193. shape_type i;
  194. STLIterator(STLIterator const & it) = default;
  195. constexpr STLIterator & operator=(STLIterator const & it)
  196. {
  197. i = it.i;
  198. ii.Iterator::~Iterator(); // no-op except for View<RANK_ANY>. Still...
  199. new (&ii) Iterator(it.ii); // avoid ii = it.ii [ra11]
  200. return *this;
  201. }
  202. STLIterator(Iterator const & ii_)
  203. : ii(ii_),
  204. // shape_type may be std::array or std::vector.
  205. i([&] {
  206. if constexpr (DIM_ANY==size_s<shape_type>()) {
  207. return shape_type(ii.rank(), 0);
  208. } else {
  209. return shape_type {0};
  210. }
  211. }())
  212. {
  213. // [ra12] Null p_ so begin()==end() for empty range. ply() uses lens so this doesn't matter.
  214. if (0==ra::size(ii)) {
  215. ii.c.p = nullptr;
  216. }
  217. };
  218. template <class PP> bool operator==(PP const & j) const { return ii.c.p==j.ii.c.p; }
  219. template <class PP> bool operator!=(PP const & j) const { return ii.c.p!=j.ii.c.p; }
  220. decltype(auto) operator*() const { if constexpr (0==Iterator::cellr) return *ii.c.p; else return ii.c; }
  221. decltype(auto) operator*() { if constexpr (0==Iterator::cellr) return *ii.c.p; else return ii.c; }
  222. STLIterator & operator++()
  223. {
  224. if constexpr (0==Iterator::rank_s()) { // when rank==0, DIM_ANY check isn't enough
  225. ii.c.p = nullptr;
  226. } else if constexpr (DIM_ANY != ra::size_s<Iterator>()) {
  227. next_in_cube<Iterator::rank()-1, typename Iterator::lens, typename Iterator::steps>(i, ii.c.p);
  228. } else {
  229. next_in_cube(ii.rank(), ii.dimv, i, ii.c.p);
  230. }
  231. return *this;
  232. }
  233. };
  234. template <class T> STLIterator<T> stl_iterator(T && t) { return STLIterator<T>(std::forward<T>(t)); }
  235. // --------------------
  236. // Base for both small view & container
  237. // --------------------
  238. template <class lens_, class steps_, class ... I>
  239. struct FilterDims
  240. {
  241. using lens = lens_;
  242. using steps = steps_;
  243. };
  244. template <class lens_, class steps_, class I0, class ... I>
  245. struct FilterDims<lens_, steps_, I0, I ...>
  246. {
  247. constexpr static int s = is_beatable<I0>::skip;
  248. constexpr static int s_src = is_beatable<I0>::skip_src;
  249. using next = FilterDims<mp::drop<lens_, s_src>, mp::drop<steps_, s_src>, I ...>;
  250. using lens = mp::append<mp::take<lens_, s>, typename next::lens>;
  251. using steps = mp::append<mp::take<steps_, s>, typename next::steps>;
  252. };
  253. template <dim_t len0, dim_t step0>
  254. constexpr dim_t
  255. select(dim_t i0)
  256. {
  257. RA_CHECK(inside(i0, len0));
  258. return i0*step0;
  259. };
  260. template <dim_t len0, dim_t step0, int n>
  261. constexpr dim_t
  262. select(dots_t<n> i0)
  263. {
  264. return 0;
  265. }
  266. template <class lens, class steps>
  267. constexpr dim_t
  268. select_loop()
  269. {
  270. return 0;
  271. }
  272. template <class lens, class steps, class I0, class ... I>
  273. constexpr dim_t
  274. select_loop(I0 i0, I ... i)
  275. {
  276. constexpr int s_src = is_beatable<I0>::skip_src;
  277. return select<mp::first<lens>::value, mp::first<steps>::value>(i0)
  278. + select_loop<mp::drop<lens, s_src>, mp::drop<steps, s_src>>(i ...);
  279. }
  280. template <template <class ...> class Child_, class T_, class lens_, class steps_>
  281. struct SmallBase
  282. {
  283. using lens = lens_;
  284. using steps = steps_;
  285. using T = T_;
  286. template <class TT> using BadDimension = mp::int_c<(TT::value<0 || TT::value==DIM_ANY || TT::value==DIM_BAD)>;
  287. static_assert(!mp::apply<mp::orb, mp::map<BadDimension, lens>>::value, "Negative dimensions.");
  288. static_assert(mp::len<lens> == mp::len<steps>, "Mismatched lengths & steps."); // TODO static check on steps.
  289. using Child = Child_<T, lens, steps>;
  290. constexpr static rank_t rank() { return mp::len<lens>; }
  291. constexpr static rank_t rank_s() { return mp::len<lens>; }
  292. constexpr static dim_t size() { return mp::apply<mp::prod, lens>::value; }
  293. constexpr static dim_t size_s() { return size(); }
  294. constexpr static auto slens = mp::tuple_values<std::array<dim_t, rank()>, lens>();
  295. constexpr static auto ssteps = mp::tuple_values<std::array<dim_t, rank()>, steps>();
  296. constexpr static dim_t len(int k) { return slens[k]; }
  297. constexpr static dim_t len_s(int k) { return slens[k]; }
  298. constexpr static dim_t step(int k) { return ssteps[k]; }
  299. constexpr static auto shape() { return SmallView<ra::dim_t const, mp::int_list<rank_s()>, mp::int_list<1>>(slens.data()); }
  300. // allowing rank 1 for coord types
  301. constexpr static bool convertible_to_scalar = size()==1; // rank()==0 || (rank()==1 && size()==1);
  302. #define RA_CONST_OR_NOT(CONST) \
  303. constexpr T CONST * data() CONST { return static_cast<Child CONST &>(*this).p; } \
  304. template <class ... I> \
  305. constexpr decltype(auto) \
  306. operator()(I && ... i) CONST \
  307. { \
  308. constexpr int scalars = (0 + ... + is_scalar_index<I>); \
  309. if constexpr (scalars<rank() && (is_beatable<I>::static_p && ...)) { \
  310. using FD = FilterDims<lens, steps, I ...>; \
  311. return SmallView<T CONST, typename FD::lens, typename FD::steps> \
  312. (data()+select_loop<lens, steps>(i ...)); \
  313. } else if constexpr (scalars==rank()) { \
  314. return data()[select_loop<lens, steps>(i ...)]; \
  315. } else if constexpr ((!is_beatable<I>::static_p || ...)) { /* TODO More than one selector... */ \
  316. return from(*this, std::forward<I>(i) ...); \
  317. } else { \
  318. static_assert(mp::always_false<I ...>); /* p2593r0 */ \
  319. } \
  320. } \
  321. /* BUG I must be fixed size, otherwise we can't make out the output type. */ \
  322. template <class I> \
  323. constexpr decltype(auto) \
  324. at(I const & i) CONST \
  325. { \
  326. return SmallView<T CONST, mp::drop<lens, ra::size_s<I>()>, mp::drop<steps, ra::size_s<I>()>> \
  327. (data()+indexer0::shorter<lens, steps>(i)); \
  328. } \
  329. template <class ... I> \
  330. constexpr decltype(auto) \
  331. operator[](I && ... i) CONST \
  332. { \
  333. return (*this)(std::forward<I>(i) ...); \
  334. } \
  335. /* TODO support s(static ra::iota) */ \
  336. template <int ss, int oo=0> \
  337. constexpr auto \
  338. as() CONST \
  339. { \
  340. static_assert(rank()>=1, "bad rank for as<>"); \
  341. static_assert(ss>=0 && oo>=0 && ss+oo<=size(), "bad size for as<>"); \
  342. return SmallView<T CONST, mp::cons<mp::int_c<ss>, mp::drop1<lens>>, steps>(this->data()+oo*this->step(0)); \
  343. } \
  344. T CONST & \
  345. back() CONST \
  346. { \
  347. static_assert(rank()==1 && size()>0, "back() is not available"); \
  348. return (*this)[size()-1]; \
  349. } \
  350. constexpr operator T CONST & () CONST requires (convertible_to_scalar) { return data()[0]; }
  351. FOR_EACH(RA_CONST_OR_NOT, /*const*/, const)
  352. #undef RA_CONST_OR_NOT
  353. // see same thing for View.
  354. #define DEF_ASSIGNOPS(OP) \
  355. template <class X> \
  356. requires (!mp::is_tuple_v<std::decay_t<X>>) \
  357. constexpr Child & \
  358. operator OP(X && x) \
  359. { \
  360. ra::start(static_cast<Child &>(*this)) OP x; \
  361. return static_cast<Child &>(*this); \
  362. }
  363. FOR_EACH(DEF_ASSIGNOPS, =, *=, +=, -=, /=)
  364. #undef DEF_ASSIGNOPS
  365. // braces don't match X &&
  366. constexpr Child &
  367. operator=(nested_arg<T, lens> const & x)
  368. {
  369. ra::iter<-1>(static_cast<Child &>(*this)) = mp::from_tuple<std::array<typename nested_tuple<T, lens>::sub, len(0)>>(x);
  370. return static_cast<Child &>(*this);
  371. }
  372. // braces row-major ravel for rank!=1
  373. constexpr Child &
  374. operator=(ravel_arg<T, lens> const & x_)
  375. {
  376. auto x = mp::from_tuple<std::array<T, size()>>(x_);
  377. std::copy(x.begin(), x.end(), this->begin());
  378. return static_cast<Child &>(*this);
  379. }
  380. template <rank_t c=0> using iterator = ra::CellSmall<SmallBase<SmallView, T, lens, steps>, c>;
  381. template <rank_t c=0> using const_iterator = ra::CellSmall<SmallBase<SmallView, T const, lens, steps>, c>;
  382. template <rank_t c=0> constexpr iterator<c> iter() { return data(); }
  383. template <rank_t c=0> constexpr const_iterator<c> iter() const { return data(); }
  384. // FIXME see if we need to extend this for cellr!=0.
  385. // template <class P> using STLIterator = std::conditional_t<have_default_steps, P, STLIterator<Iterator<P>>>;
  386. constexpr static bool have_default_steps = std::same_as<steps, default_steps<lens>>;
  387. template <class I, class P> using pick_STLIterator = std::conditional_t<have_default_steps, P, ra::STLIterator<I>>;
  388. using STLIterator = pick_STLIterator<iterator<0>, T *>;
  389. using STLConstIterator = pick_STLIterator<const_iterator<0>, T const *>;
  390. // TODO begin() end() may be different types for ranged for (https://en.cppreference.com/w/cpp/language/range-for), but not for stl algos like std::copy. That's unfortunate as it would allow simplifying end().
  391. // TODO With default steps I can just return p. Make sure to test before changing this.
  392. constexpr STLIterator begin() { if constexpr (have_default_steps) return data(); else return iter(); }
  393. constexpr STLConstIterator begin() const { if constexpr (have_default_steps) return data(); else return iter(); }
  394. constexpr STLIterator end() { if constexpr (have_default_steps) return data()+size(); else return iterator<0>(nullptr); }
  395. constexpr STLConstIterator end() const { if constexpr (have_default_steps) return data()+size(); else return const_iterator<0>(nullptr); }
  396. };
  397. // ---------------------
  398. // Small view & container
  399. // ---------------------
  400. // Strides are compile time, so we can put most members in the view type.
  401. template <class T, class lens, class steps>
  402. struct SmallView: public SmallBase<SmallView, T, lens, steps>
  403. {
  404. using Base = SmallBase<SmallView, T, lens, steps>;
  405. using Base::operator=;
  406. T * p;
  407. constexpr SmallView(T * p_): p(p_) {}
  408. constexpr SmallView(SmallView const & s): p(s.p) {}
  409. constexpr operator T & () { static_assert(Base::convertible_to_scalar); return p[0]; }
  410. constexpr operator T const & () const { static_assert(Base::convertible_to_scalar); return p[0]; };
  411. };
  // extvector<T, N> is the compiler's builtin vector type of N elements of type T. It's
  // declared with ext_vector_type(N) under clang, and with vector_size (which takes the
  // size in bytes, so N*sizeof(T)) otherwise. Used by align_req() to get an alignment.
  #if defined (__clang__)
  template <class T, int N> using extvector __attribute__((ext_vector_type(N))) = T;
  #else
  template <class T, int N> using extvector __attribute__((vector_size(N*sizeof(T)))) = T;
  #endif
  // equal_to_t<Z>::value<T ...> is true when Z is the same type as any of T ...
  // (and false for an empty list).
  template <class Z>
  struct equal_to_t
  {
      template <class ... T>
      static constexpr bool value = std::disjunction_v<std::is_same<Z, T> ...>;
  };
  422. template <class T, size_t N>
  423. consteval size_t
  424. align_req()
  425. {
  426. if constexpr (equal_to_t<T>::template value<char, unsigned char,
  427. short, unsigned short,
  428. int, unsigned int,
  429. long, unsigned long,
  430. long long, unsigned long long,
  431. float, double>
  432. && 0<N && 0==(N & (N-1))) {
  433. return alignof(extvector<T, N>);
  434. } else {
  435. return alignof(T[N]);
  436. }
  437. }
  438. template <class T, class lens, class steps, class ... nested_args, class ... ravel_args>
  439. struct
  440. #if RA_DO_OPT_SMALLVECTOR==1
  441. alignas(align_req<T, mp::apply<mp::prod, lens>::value>())
  442. #else
  443. #endif
  444. SmallArray<T, lens, steps, std::tuple<nested_args ...>, std::tuple<ravel_args ...>>
  445. : public SmallBase<SmallArray, T, lens, steps>
  446. {
  447. using Base = SmallBase<SmallArray, T, lens, steps>;
  448. using Base::rank, Base::size;
  449. T p[Base::size()]; // cf what std::array does for zero size; wish zero size just worked :-/
  450. constexpr SmallArray() {}
  451. // braces don't match (X &&)
  452. constexpr SmallArray(nested_args const & ... x)
  453. {
  454. static_cast<Base &>(*this) = nested_arg<T, lens> { x ... };
  455. }
  456. // braces row-major ravel for rank!=1
  457. constexpr SmallArray(ravel_args const & ... x)
  458. {
  459. static_cast<Base &>(*this) = ravel_arg<T, lens> { x ... };
  460. }
  461. // needed if T isn't registered as scalar [ra44]
  462. constexpr SmallArray(T const & t)
  463. {
  464. for (auto & x: p) { x = t; }
  465. }
  466. // X && x makes this a better match than nested_args ... for 1 argument.
  467. template <class X>
  468. requires (!std::is_same_v<T, std::decay_t<X>> && !mp::is_tuple_v<std::decay_t<X>>)
  469. constexpr SmallArray(X && x)
  470. {
  471. static_cast<Base &>(*this) = x;
  472. }
  473. using View = SmallView<T, lens, steps>;
  474. using ViewConst = SmallView<T const, lens, steps>;
  475. // conversion to const
  476. constexpr operator View () { return View(p); }
  477. constexpr operator ViewConst () const { return ViewConst(p); }
  478. };
  // Deduction guide: SmallArray { a0, a ... } is a rank 1 array with the type of the first
  // argument as element type and as many elements as there are arguments.
  template <class A0, class ... A>
  SmallArray(A0, A ...) -> SmallArray<A0, mp::int_list<1+sizeof...(A)>>;
  481. // FIXME unfortunately necessary. Try to remove the need, also of (S, begin, end) in Container, once the nested_tuple constructors work.
  // Make a default-constructed A and fill it, starting from its begin(), with the range
  // [begin, end). The range isn't checked against the size of A.
  template <class A, class I, class J>
  A
  ravel_from_iterators(I && begin, J && end)
  {
      A result;
      auto dest = result.begin();
      std::copy(std::forward<I>(begin), std::forward<J>(end), dest);
      return result;
  }
  489. // ---------------------
  490. // Builtin arrays
  491. // ---------------------
  492. template <class T, class I=mp::iota<std::rank_v<T>>>
  493. struct builtin_array_lens;
  494. template <class T, int ... I>
  495. struct builtin_array_lens<T, mp::int_list<I ...>>
  496. {
  497. using type = mp::int_list<std::extent_v<T, I> ...>;
  498. };
  499. template <class T> using builtin_array_lens_t = typename builtin_array_lens<T>::type;
  500. template <class T>
  501. struct builtin_array_types
  502. {
  503. using A = std::remove_volatile_t<std::remove_reference_t<T>>; // preserve const
  504. using E = std::remove_all_extents_t<A>;
  505. using lens = builtin_array_lens_t<A>;
  506. using view = SmallView<E, lens>;
  507. };
  508. // forward declared in bootstrap.hh.
  509. template <class T> requires (is_builtin_array<T>)
  510. constexpr auto
  511. start(T && t)
  512. {
  513. using Z = builtin_array_types<T>;
  514. return typename Z::view((typename Z::E *)(t)).iter();
  515. }
  516. template <class T> requires (is_builtin_array<T>)
  517. struct ra_traits_def<T>
  518. {
  519. using S = typename builtin_array_types<T>::view;
  520. constexpr static decltype(auto) shape(T const & t) { return S::shape(); }
  521. constexpr static dim_t size(T const & t) { return S::size_s(); }
  522. constexpr static dim_t size_s() { return S::size_s(); }
  523. constexpr static rank_t rank(T const & t) { return S::rank(); }
  524. constexpr static rank_t rank_s() { return S::rank_s(); }
  525. };
  526. // --------------------
  527. // Small ops; cf view-ops.hh.
  528. // FIXME maybe there, or separate file.
  529. // TODO See if this can be merged with Reframe (e.g. beat(reframe(a)) -> transpose(a) ?)
  530. // --------------------
  531. template <class A, class i>
  532. struct axis_indices
  533. {
  534. template <class T> using match_index = mp::int_c<(T::value==i::value)>;
  535. using I = mp::iota<mp::len<A>>;
  536. using type = mp::Filter_<mp::map<match_index, A>, I>;
  537. // don't enforce, so allow dead axes (e.g. in transpose<1>(rank 1 array)).
  538. // static_assert((mp::len<type>)>0, "dst axis doesn't appear in transposed axes list");
  539. };
  540. template <class axes_list, class src_lens, class src_steps>
  541. struct axes_list_indices
  542. {
  543. static_assert(mp::len<axes_list> == mp::len<src_lens>, "Bad size for transposed axes list.");
  544. constexpr static int talmax = mp::fold<mp::max, void, axes_list>::value;
  545. constexpr static int talmin = mp::fold<mp::min, void, axes_list>::value;
  546. static_assert(talmin >= 0, "Bad index in transposed axes list.");
  547. // don't enforce, so allow dead axes (e.g. in transpose<1>(rank 1 array)).
  548. // static_assert(talmax < mp::len<src_lens>, "bad index in transposed axes list");
  549. template <class dst_i> struct dst_indices_
  550. {
  551. using type = typename axis_indices<axes_list, dst_i>::type;
  552. template <class i> using lensi = mp::ref<src_lens, i::value>;
  553. template <class i> using stepsi = mp::ref<src_steps, i::value>;
  554. using step = mp::fold<mp::sum, void, mp::map<stepsi, type>>;
  555. using len = mp::fold<mp::min, void, mp::map<lensi, type>>;
  556. };
  557. template <class dst_i> using dst_indices = typename dst_indices_<dst_i>::type;
  558. template <class dst_i> using dst_len = typename dst_indices_<dst_i>::len;
  559. template <class dst_i> using dst_step = typename dst_indices_<dst_i>::step;
  560. using dst = mp::iota<(talmax>=0 ? (1+talmax) : 0)>;
  561. using type = mp::map<dst_indices, dst>;
  562. using lens = mp::map<dst_len, dst>;
  563. using steps = mp::map<dst_step, dst>;
  564. };
  565. #define DEF_TRANSPOSE(CONST) \
  566. template <int ... Iarg, template <class ...> class Child, class T, class lens, class steps> \
  567. constexpr auto \
  568. transpose(SmallBase<Child, T, lens, steps> CONST & a) \
  569. { \
  570. using ti = axes_list_indices<mp::int_list<Iarg ...>, lens, steps>; \
  571. return SmallView<T CONST, typename ti::lens, typename ti::steps>(a.data()); \
  572. }; \
  573. \
  574. template <template <class ...> class Child, class T, class lens, class steps> \
  575. constexpr auto \
  576. diag(SmallBase<Child, T, lens, steps> CONST & a) \
  577. { \
  578. return transpose<0, 0>(a); \
  579. }
  580. FOR_EACH(DEF_TRANSPOSE, /* const */, const)
  581. #undef DEF_TRANSPOSE
  582. // TODO Used by ProductRule; waiting for proper generalization.
  583. template <template <class ...> class Child1, class T1, class lens1, class steps1,
  584. template <class ...> class Child2, class T2, class lens2, class steps2>
  585. constexpr auto
  586. cat(SmallBase<Child1, T1, lens1, steps1> const & a1, SmallBase<Child2, T2, lens2, steps2> const & a2)
  587. {
  588. using A1 = SmallBase<Child1, T1, lens1, steps1>;
  589. using A2 = SmallBase<Child2, T2, lens2, steps2>;
  590. static_assert(A1::rank()==1 && A2::rank()==1, "Bad ranks for cat."); // gcc accepts a1.rank(), etc.
  591. using T = std::decay_t<decltype(a1[0])>;
  592. Small<T, A1::size()+A2::size()> val;
  593. std::copy(a1.begin(), a1.end(), val.begin());
  594. std::copy(a2.begin(), a2.end(), val.begin()+a1.size());
  595. return val;
  596. }
  597. template <template <class ...> class Child1, class T1, class lens1, class steps1, class A2>
  598. requires (is_scalar<A2>)
  599. constexpr auto
  600. cat(SmallBase<Child1, T1, lens1, steps1> const & a1, A2 const & a2)
  601. {
  602. using A1 = SmallBase<Child1, T1, lens1, steps1>;
  603. static_assert(A1::rank()==1, "bad ranks for cat");
  604. using T = std::decay_t<decltype(a1[0])>;
  605. Small<T, A1::size()+1> val;
  606. std::copy(a1.begin(), a1.end(), val.begin());
  607. val[a1.size()] = a2;
  608. return val;
  609. }
  610. template <class A1, template <class ...> class Child2, class T2, class lens2, class steps2>
  611. requires (is_scalar<A1>)
  612. constexpr auto
  613. cat(A1 const & a1, SmallBase<Child2, T2, lens2, steps2> const & a2)
  614. {
  615. using A2 = SmallBase<Child2, T2, lens2, steps2>;
  616. static_assert(A2::rank()==1, "bad ranks for cat");
  617. using T = std::decay_t<decltype(a2[0])>;
  618. Small<T, 1+A2::size()> val;
  619. val[0] = a1;
  620. std::copy(a2.begin(), a2.end(), val.begin()+1);
  621. return val;
  622. }
  623. // FIXME should be local (constexpr lambda + mp::apply?)
  624. template <int s> struct explode_divop
  625. {
  626. template <class T> struct op_
  627. {
  628. static_assert((T::value/s)*s==T::value);
  629. using type = mp::int_c<T::value / s>;
  630. };
  631. template <class T> using op = typename op_<T>::type;
  632. };
  633. // See view-ops.hh:explode, collapse. FIXME support real->complex, etc.
  634. template <class super_t,
  635. template <class ...> class Child, class T, class lens, class steps>
  636. constexpr auto
  637. explode(SmallBase<Child, T, lens, steps> & a)
  638. {
  639. using ta = SmallBase<Child, T, lens, steps>;
  640. // the returned type has steps in super_t, but to support general steps we'd need steps in T. Maybe FIXME?
  641. static_assert(super_t::have_default_steps);
  642. constexpr rank_t ra = ta::rank_s();
  643. constexpr rank_t rb = super_t::rank_s();
  644. static_assert(std::is_same_v<mp::drop<lens, ra-rb>, typename super_t::lens>);
  645. static_assert(std::is_same_v<mp::drop<steps, ra-rb>, typename super_t::steps>);
  646. using csteps = mp::map<explode_divop<ra::size_s<super_t>()>::template op, mp::take<steps, ra-rb>>;
  647. return SmallView<super_t, mp::take<lens, ra-rb>, csteps>((super_t *) a.data());
  648. }
  649. } // namespace ra