expr.hh 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Expression templates with prefix matching.
  3. // (c) Daniel Llorens - 2011-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include "atom.hh"
  10. #include <functional>
  11. namespace ra {
  12. // --------------------
  13. // prefix match
  14. // --------------------
  15. constexpr rank_t
  16. choose_rank(rank_t ra, rank_t rb)
  17. {
  18. return BAD==rb ? ra : BAD==ra ? rb : ANY==ra ? ra : ANY==rb ? rb : (ra>=rb ? ra : rb);
  19. }
  20. // if non-negative args don't match, pick first (see below). FIXME maybe return invalid.
  21. constexpr dim_t
  22. choose_len(dim_t sa, dim_t sb)
  23. {
  24. return BAD==sa ? sb : BAD==sb ? sa : ANY==sa ? sb : sa;
  25. }
  26. template <bool checkp, class T, class K=mp::iota<mp::len<T>>> struct Match;
  27. template <bool checkp, IteratorConcept ... P, int ... I>
  28. struct Match<checkp, std::tuple<P ...>, mp::int_list<I ...>>
  29. {
  30. using T = std::tuple<P ...>;
  31. T t;
  32. // 0: fail, 1: rt, 2: pass
  33. consteval static int
  34. check_s()
  35. {
  36. if constexpr (sizeof...(P)<2) {
  37. return 2;
  38. } else if constexpr (ANY==rank_s()) {
  39. return 1; // FIXME could be tightened to 2 in some cases
  40. } else {
  41. bool tbc = false;
  42. for (int k=0; k<rank_s(); ++k) {
  43. dim_t ls = len_s(k);
  44. if (((k<std::decay_t<P>::rank_s() && ls!=choose_len(std::decay_t<P>::len_s(k), ls)) || ...)) {
  45. return 0;
  46. } else {
  47. int anyk = ((k<std::decay_t<P>::rank_s() && (ANY==std::decay_t<P>::len_s(k))) + ...);
  48. int fixk = ((k<std::decay_t<P>::rank_s() && (0<=std::decay_t<P>::len_s(k))) + ...);
  49. tbc = tbc || (anyk>0 && anyk+fixk>1);
  50. }
  51. }
  52. return tbc ? 1 : 2;
  53. }
  54. }
  55. constexpr bool
  56. check() const
  57. {
  58. if constexpr (sizeof...(P)<2) {
  59. return true;
  60. } else if constexpr (constexpr int c = check_s(); 0==c) {
  61. return false;
  62. } else if constexpr (1==c) {
  63. for (int k=0; k<rank(); ++k) {
  64. dim_t ls = len(k);
  65. if (((k<std::get<I>(t).rank() && ls!=choose_len(std::get<I>(t).len(k), ls)) || ...)) {
  66. RA_CHECK(!checkp, "Shape mismatch [", (std::array { std::get<I>(t).len(k) ... }), "] on axis ", k, ".");
  67. return false;
  68. }
  69. }
  70. }
  71. return true;
  72. }
  73. constexpr
  74. Match(P ... p_): t(std::forward<P>(p_) ...)
  75. {
  76. // TODO Maybe on ply, would avoid the checkp, make agree_xxx() unnecessary.
  77. if constexpr (checkp && !(has_len<P> || ...)) {
  78. static_assert(check_s(), "Shape mismatch.");
  79. RA_CHECK(check());
  80. }
  81. }
  82. // rank of largest subexpr, so we look at all of them.
  83. consteval static rank_t
  84. rank_s()
  85. {
  86. rank_t r = BAD;
  87. return ((r=choose_rank(r, ra::rank_s<P>())), ...);
  88. }
  89. consteval static rank_t
  90. rank()
  91. requires (ANY != Match::rank_s())
  92. {
  93. return rank_s();
  94. }
  95. constexpr rank_t
  96. rank() const
  97. requires (ANY == Match::rank_s())
  98. {
  99. rank_t r = BAD;
  100. ((r = choose_rank(r, std::get<I>(t).rank())), ...);
  101. assert(ANY!=r); // not at runtime
  102. return r;
  103. }
  104. // first nonnegative size, if none first ANY, if none then BAD
  105. constexpr static dim_t
  106. len_s(int k)
  107. {
  108. auto f = [&k]<class A>(dim_t s) {
  109. constexpr rank_t ar = A::rank_s();
  110. return (ar<0 || k<ar) ? choose_len(s, A::len_s(k)) : s;
  111. };
  112. dim_t s = BAD; ((s>=0 ? s : s = f.template operator()<std::decay_t<P>>(s)), ...);
  113. return s;
  114. }
  115. constexpr static dim_t
  116. len(int k)
  117. requires (requires (int kk) { P::len(kk); } && ...)
  118. {
  119. return len_s(k);
  120. }
  121. constexpr dim_t
  122. len(int k) const
  123. requires (!(requires (int kk) { P::len(kk); } && ...))
  124. {
  125. auto f = [&k](dim_t s, auto const & a) {
  126. return k<a.rank() ? choose_len(s, a.len(k)) : s;
  127. };
  128. dim_t s = BAD; ((s>=0 ? s : s = f(s, std::get<I>(t))), ...);
  129. assert(ANY!=s); // not at runtime
  130. return s;
  131. }
  132. constexpr void
  133. adv(rank_t k, dim_t d)
  134. {
  135. (std::get<I>(t).adv(k, d), ...);
  136. }
  137. constexpr auto
  138. step(int i) const
  139. {
  140. return std::make_tuple(std::get<I>(t).step(i) ...);
  141. }
  142. constexpr bool
  143. keep_step(dim_t st, int z, int j) const
  144. requires (!(requires (dim_t st, rank_t z, rank_t j) { P::keep_step(st, z, j); } && ...))
  145. {
  146. return (std::get<I>(t).keep_step(st, z, j) && ...);
  147. }
  148. constexpr static bool
  149. keep_step(dim_t st, int z, int j)
  150. requires (requires (dim_t st, rank_t z, rank_t j) { P::keep_step(st, z, j); } && ...)
  151. {
  152. return (std::decay_t<P>::keep_step(st, z, j) && ...);
  153. }
  154. };
  155. // ---------------------------
  156. // reframe
  157. // ---------------------------
  158. // Reframe is a variant of transpose that works on any IteratorConcept. As in transpose(), one names
  159. // the destination axis for each original axis. However, unlike general transpose, axes may not be
  160. // repeated. The main application is the rank conjunction below.
  // Zero step of type T; Reframe::step() returns it for 'dead' axes that don't move the iterator.
  template <class T>
  constexpr T zerostep = 0;
  // Tuple steps are zero componentwise.
  template <class ... T>
  constexpr std::tuple<T ...> zerostep<std::tuple<T ...>> = { zerostep<T> ... };
  163. // Dest is a list of destination axes [l0 l1 ... li ... l(rank(A)-1)].
  164. // The dimensions of the reframed A are numbered as [0 ... k ... max(l)], so its rank is 1+max(l).
  165. // If li = k for some i, then axis k of the reframed A moves on axis i of the original iterator A.
  166. // If not, then axis k of the reframed A is 'dead' and doesn't move the iterator.
  167. // TODO invalid for ANY (since Dest is compile time). [ra7]
  168. template <class Dest, IteratorConcept A>
  169. struct Reframe
  170. {
  171. A a;
  172. constexpr static int orig(int k) { return mp::int_list_index<Dest>(k); }
  173. consteval static rank_t rank_s() { return 1+mp::fold<mp::max, ic_t<-1>, Dest>::value; }
  174. consteval static rank_t rank() { return rank_s(); }
  175. constexpr static dim_t len_s(int k)
  176. {
  177. int l = orig(k);
  178. return l>=0 ? std::decay_t<A>::len_s(l) : BAD;
  179. }
  180. constexpr dim_t
  181. len(int k) const
  182. {
  183. int l = orig(k);
  184. return l>=0 ? a.len(l) : BAD;
  185. }
  186. constexpr void
  187. adv(rank_t k, dim_t d)
  188. {
  189. if (int l = orig(k); l>=0) {
  190. a.adv(l, d);
  191. }
  192. }
  193. constexpr auto
  194. step(int k) const
  195. {
  196. int l = orig(k);
  197. return l>=0 ? a.step(l) : zerostep<decltype(a.step(l))>;
  198. }
  199. constexpr bool
  200. keep_step(dim_t st, int z, int j) const
  201. {
  202. int wz = orig(z);
  203. int wj = orig(j);
  204. return wz>=0 && wj>=0 && a.keep_step(st, wz, wj);
  205. }
  206. constexpr decltype(auto)
  207. flat()
  208. {
  209. return a.flat();
  210. }
  211. constexpr decltype(auto)
  212. at(auto const & i) const
  213. {
  214. return a.at(mp::map_indices<dim_t, Dest>(i));
  215. }
  216. };
  217. // Optimize no-op case. TODO If A is CellBig, etc. beat Dest on it, same for eventual transpose_expr<>.
  218. template <class Dest, class A>
  219. constexpr decltype(auto)
  220. reframe(A && a)
  221. {
  222. if constexpr (std::is_same_v<Dest, mp::iota<1+mp::fold<mp::max, ic_t<-1>, Dest>::value>>) {
  223. return std::forward<A>(a);
  224. } else {
  225. return Reframe<Dest, A> { std::forward<A>(a) };
  226. }
  227. }
  228. // ---------------------------
  229. // verbs and rank conjunction
  230. // ---------------------------
  // Verb: an operation together with its list of cell ranks (one per argument, see Framematch_def).
  template <class cranks_, class Op_>
  struct Verb
  {
      using Op = Op_;
      using cranks = cranks_;
      Op_ op;
  };
  238. RA_IS_DEF(is_verb, (std::is_same_v<A, Verb<typename A::cranks, typename A::Op>>))
  239. template <class cranks, class Op>
  240. constexpr auto
  241. wrank(cranks cranks_, Op && op)
  242. {
  243. return Verb<cranks, Op> { std::forward<Op>(op) };
  244. }
  245. template <rank_t ... crank, class Op>
  246. constexpr auto
  247. wrank(Op && op)
  248. {
  249. return Verb<mp::int_list<crank ...>, Op> { std::forward<Op>(op) };
  250. }
  251. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  252. struct Framematch_def;
  253. template <class V, class T, class R=mp::makelist<mp::len<T>, mp::nil>, rank_t skip=0>
  254. using Framematch = Framematch_def<std::decay_t<V>, T, R, skip>;
  255. template <class A, class B>
  256. struct max_i
  257. {
  258. constexpr static int value = (A::value == choose_rank(A::value, B::value)) ? 0 : 1;
  259. };
  260. // Get a list (per argument) of lists of live axes. The last frame match is handled by standard prefix matching.
  261. template <class ... crank, class W, class ... Ti, class ... Ri, rank_t skip>
  262. struct Framematch_def<Verb<std::tuple<crank ...>, W>, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  263. {
  264. static_assert(sizeof...(Ti)==sizeof...(crank) && sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
  265. // live = number of live axes on this frame, for each argument. // TODO crank negative, inf.
  266. using live = mp::int_list<(rank_s<Ti>() - mp::len<Ri> - crank::value) ...>;
  267. using frameaxes = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri> - crank::value), skip>> ...>;
  268. using FM = Framematch<W, std::tuple<Ti ...>, frameaxes, skip + mp::ref<live, mp::indexof<max_i, live>>::value>;
  269. using R = typename FM::R;
  270. template <class VV> static decltype(auto) op(VV && v) { return FM::op(std::forward<VV>(v).op); } // cf [ra31]
  271. };
  272. // Terminal case where V doesn't have rank (is a raw op()).
  273. template <class V, class ... Ti, class ... Ri, rank_t skip>
  274. struct Framematch_def<V, std::tuple<Ti ...>, std::tuple<Ri ...>, skip>
  275. {
  276. static_assert(sizeof...(Ti)==sizeof...(Ri), "Bad arguments.");
  277. // TODO -crank::value when the actual verb rank is used (eg to use CellBig<... that_rank> instead of just begin()).
  278. using R = std::tuple<mp::append<Ri, mp::iota<(rank_s<Ti>() - mp::len<Ri>), skip>> ...>;
  279. template <class VV> static decltype(auto) op(VV && v) { return std::forward<VV>(v); }
  280. };
  281. // ---------------------------
  282. // general expression
  283. // ---------------------------
  284. template <class Op, class T, class K=mp::iota<mp::len<T>>> struct Expr;
  285. template <class Op, IteratorConcept ... P, int ... I>
  286. struct Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tuple<P ...>>
  287. {
  288. template <class T>
  289. struct Flat
  290. {
  291. Op & op;
  292. T t;
  293. template <class S> constexpr void operator+=(S const & s) { ((std::get<I>(t) += std::get<I>(s)), ...); }
  294. // FIXME flagged by gcc 12.1 -O3
  295. #pragma GCC diagnostic push
  296. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  297. constexpr decltype(auto) operator*() { return std::invoke(op, *std::get<I>(t) ...); }
  298. #pragma GCC diagnostic pop
  299. };
  300. template <class ... F>
  301. constexpr static auto
  302. flat(Op & op, F && ... f)
  303. {
  304. return Flat<std::tuple<F ...>> { op, { std::forward<F>(f) ... } };
  305. }
  306. using Match_ = Match<true, std::tuple<P ...>>;
  307. using Match_::t, Match_::rank_s, Match_::rank;
  308. Op op;
  309. // test/ra-9.cc [ra1]
  310. constexpr Expr(Op op_, P ... p_): Match_(std::forward<P>(p_) ...), op(std::forward<Op>(op_)) {}
  311. RA_DEF_ASSIGNOPS_SELF(Expr)
  312. RA_DEF_ASSIGNOPS_DEFAULT_SET
  313. constexpr decltype(auto)
  314. at(auto const & j) const
  315. {
  316. return std::invoke(op, std::get<I>(t).at(j) ...);
  317. }
  318. constexpr decltype(auto)
  319. flat() // FIXME can't be const bc of Flat::op. Carries over to Pick / Reframe .flat() ...
  320. {
  321. return flat(op, std::get<I>(t).flat() ...);
  322. }
  323. // needed for rank_s()==ANY, which don't decay to scalar when used as operator arguments.
  324. constexpr
  325. operator decltype(*(flat(op, std::get<I>(t).flat() ...))) ()
  326. {
  327. if constexpr (0!=rank_s() && (1!=rank_s() || 1!=size_s<Expr>())) { // for coord types; so ct only
  328. static_assert(rank_s()==ANY);
  329. assert(0==rank());
  330. }
  331. return *flat();
  332. }
  333. };
  334. template <class Op, IteratorConcept ... P>
  335. constexpr bool is_special_def<Expr<Op, std::tuple<P ...>>> = (is_special<P> || ...);
  336. template <class V, class ... T, int ... i>
  337. constexpr auto
  338. expr_verb(mp::int_list<i ...>, V && v, T && ... t)
  339. {
  340. using FM = Framematch<V, std::tuple<T ...>>;
  341. return expr(FM::op(std::forward<V>(v)), reframe<mp::ref<typename FM::R, i>>(std::forward<T>(t)) ...);
  342. }
  343. template <class Op, class ... P>
  344. constexpr auto
  345. expr(Op && op, P && ... p)
  346. {
  347. if constexpr (is_verb<Op>) {
  348. return expr_verb(mp::iota<sizeof...(P)> {}, std::forward<Op>(op), std::forward<P>(p) ...);
  349. } else {
  350. return Expr<Op, std::tuple<P ...>> { std::forward<Op>(op), std::forward<P>(p) ... };
  351. }
  352. }
  353. template <class Op, class ... A>
  354. constexpr auto
  355. map(Op && op, A && ... a)
  356. {
  357. return expr(std::forward<Op>(op), start(std::forward<A>(a)) ...);
  358. }
  359. // ---------------
  360. // explicit agreement checks
  361. // ---------------
  362. template <class ... P>
  363. constexpr bool
  364. agree(P && ... p)
  365. {
  366. return agree_(ra::start(std::forward<P>(p)) ...);
  367. }
  368. // 0: fail, 1: rt, 2: pass
  369. template <class ... P>
  370. constexpr int
  371. agree_s(P && ... p)
  372. {
  373. return agree_s_(ra::start(std::forward<P>(p)) ...);
  374. }
  375. template <class Op, class ... P>
  376. constexpr bool
  377. agree_op(Op && op, P && ... p)
  378. {
  379. return agree_op_(std::forward<Op>(op), ra::start(std::forward<P>(p)) ...);
  380. }
  381. template <class ... P>
  382. constexpr bool
  383. agree_(P && ... p)
  384. {
  385. return (Match<false, std::tuple<P ...>> { std::forward<P>(p) ... }).check();
  386. }
  387. template <class ... P>
  388. constexpr int
  389. agree_s_(P && ... p)
  390. {
  391. return Match<false, std::tuple<P ...>>::check_s();
  392. }
  393. template <class Op, class ... P>
  394. constexpr bool
  395. agree_op_(Op && op, P && ... p)
  396. {
  397. if constexpr (is_verb<Op>) {
  398. return agree_verb(mp::iota<sizeof...(P)> {}, std::forward<Op>(op), std::forward<P>(p) ...);
  399. } else {
  400. return agree_(std::forward<P>(p) ...);
  401. }
  402. }
  403. template <class V, class ... T, int ... i>
  404. constexpr bool
  405. agree_verb(mp::int_list<i ...>, V && v, T && ... t)
  406. {
  407. using FM = Framematch<V, std::tuple<T ...>>;
  408. return agree_op_(FM::op(std::forward<V>(v)), reframe<mp::ref<typename FM::R, i>>(std::forward<T>(t)) ...);
  409. }
  410. // ---------------------------
  411. // pick
  412. // ---------------------------
  413. template <class T, class J> struct pick_at_type;
  414. template <class ... P, class J> struct pick_at_type<std::tuple<P ...>, J>
  415. {
  416. using type = mp::apply<std::common_reference_t, std::tuple<decltype(std::declval<P>().at(std::declval<J>())) ...>>;
  417. };
  418. template <std::size_t I, class T, class J>
  419. constexpr pick_at_type<mp::drop1<std::decay_t<T>>, J>::type
  420. pick_at(std::size_t p0, T && t, J const & j)
  421. {
  422. if constexpr (I+2<std::tuple_size_v<std::decay_t<T>>) {
  423. if (p0==I) {
  424. return std::get<I+1>(t).at(j);
  425. } else {
  426. return pick_at<I+1>(p0, t, j);
  427. }
  428. } else {
  429. RA_CHECK(p0==I, " p0 ", p0, " I ", I);
  430. return std::get<I+1>(t).at(j);
  431. }
  432. }
  433. template <class T> struct pick_star_type;
  434. template <class ... P> struct pick_star_type<std::tuple<P ...>>
  435. {
  436. using type = mp::apply<std::common_reference_t, std::tuple<decltype(*std::declval<P>()) ...>>;
  437. };
  438. template <std::size_t I, class T>
  439. constexpr pick_star_type<mp::drop1<std::decay_t<T>>>::type
  440. pick_star(std::size_t p0, T && t)
  441. {
  442. if constexpr (I+2<std::tuple_size_v<std::decay_t<T>>) {
  443. if (p0==I) {
  444. return *(std::get<I+1>(t));
  445. } else {
  446. return pick_star<I+1>(p0, t);
  447. }
  448. } else {
  449. RA_CHECK(p0==I, " p0 ", p0, " I ", I);
  450. return *(std::get<I+1>(t));
  451. }
  452. }
  453. template <class T, class K=mp::iota<mp::len<T>>> struct Pick;
  454. template <IteratorConcept ... P, int ... I>
  455. struct Pick<std::tuple<P ...>, mp::int_list<I ...>>: public Match<true, std::tuple<P ...>>
  456. {
  457. static_assert(sizeof...(P)>1);
  458. template <class T_>
  459. struct Flat
  460. {
  461. T_ t;
  462. template <class S> constexpr void operator+=(S const & s) { ((std::get<I>(t) += std::get<I>(s)), ...); }
  463. constexpr decltype(auto) operator*() { return pick_star<0>(*std::get<0>(t), t); }
  464. };
  465. template <class ... P_>
  466. constexpr static auto
  467. flat(P_ && ... p)
  468. {
  469. return Flat<std::tuple<P_ ...>> { std::tuple<P_ ...> { std::forward<P_>(p) ... } };
  470. }
  471. using Match_ = Match<true, std::tuple<P ...>>;
  472. using Match_::t, Match_::rank_s, Match_::rank;
  473. // test/ra-9.cc [ra1]
  474. constexpr Pick(P ... p_): Match_(std::forward<P>(p_) ...) {}
  475. RA_DEF_ASSIGNOPS_SELF(Pick)
  476. RA_DEF_ASSIGNOPS_DEFAULT_SET
  477. constexpr decltype(auto)
  478. flat()
  479. {
  480. return flat(std::get<I>(t).flat() ...);
  481. }
  482. constexpr decltype(auto)
  483. at(auto const & j) const
  484. {
  485. return pick_at<0>(std::get<0>(t).at(j), t, j);
  486. }
  487. // needed for xpr with rank_s()==ANY, which don't decay to scalar when used as operator arguments.
  488. constexpr
  489. operator decltype(*(flat(std::get<I>(t).flat() ...))) ()
  490. {
  491. if constexpr (0!=rank_s() && (1!=rank_s() || 1!=size_s<Pick>())) { // for coord types; so ct only
  492. static_assert(rank_s()==ANY);
  493. assert(0==rank());
  494. }
  495. return *flat();
  496. }
  497. };
  498. template <IteratorConcept ... P>
  499. constexpr bool is_special_def<Pick<std::tuple<P ...>>> = (is_special<P> || ...);
  500. template <class ... P> Pick(P && ... p) -> Pick<std::tuple<P ...>>;
  501. template <class ... P>
  502. constexpr auto
  503. pick(P && ... p)
  504. {
  505. return Pick { start(std::forward<P>(p)) ... };
  506. }
  507. } // namespace ra