123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397 |
- // -*- mode: c++; coding: utf-8 -*-
- // ra-ra - Traverse (ply) array or array expression or array statement.
- // (c) Daniel Llorens - 2013-2019, 2021
- // This library is free software; you can redistribute it and/or modify it under
- // the terms of the GNU Lesser General Public License as published by the Free
- // Software Foundation; either version 3 of the License, or (at your option) any
- // later version.
- // TODO Lots of room for improvement: small (fixed sizes) and large (tiling, etc. see eval.cc in Blitz++).
- // TODO Traversal order should be a parameter, since some operations (e.g. output, ravel) require a specific order.
- // TODO Better heuristic for traversal order.
- // TODO std::execution::xxx-policy, validate output argument strides.
- #pragma once
- #include "pick.hh"
- #include "expr.hh"
- namespace ra {
- // ---------------------
- // does expr tree contain Len?
- // ---------------------
- template <>
- constexpr bool has_len_def<Len> = true;
- template <IteratorConcept ... P>
- constexpr bool has_len_def<Pick<std::tuple<P ...>>> = (has_len<P> || ...);
- template <class Op, IteratorConcept ... P>
- constexpr bool has_len_def<Expr<Op, std::tuple<P ...>>> = (has_len<P> || ...);
- template <int w, class O, class N, class S>
- constexpr bool has_len_def<Iota<w, O, N, S>> = (has_len<O> || has_len<N> || has_len<S>);
- // ---------------------
- // replace Len in expr tree.
- // ---------------------
- template <class E_>
- struct WithLen
- {
- // constant & scalar appear in Iota args. dots_t and insert_t appear in subscripts.
- // FIXME what else? restrict to IteratorConcept<E_> || is_constant<E_> || is_scalar<E_> ...
- template <class E> constexpr static decltype(auto)
- f(dim_t len, E && e)
- {
- return std::forward<E>(e);
- }
- };
- template <>
- struct WithLen<Len>
- {
- template <class E> constexpr static decltype(auto)
- f(dim_t len, E && e)
- {
- return Scalar<dim_t>(len);
- }
- };
- template <class Op, IteratorConcept ... P, int ... I>
- requires (has_len<P> || ...)
- struct WithLen<Expr<Op, std::tuple<P ...>, mp::int_list<I ...>>>
- {
- template <class E> constexpr static decltype(auto)
- f(dim_t len, E && e)
- {
- return expr(std::forward<E>(e).op, WithLen<std::decay_t<P>>::f(len, std::get<I>(std::forward<E>(e).t)) ...);
- }
- };
- template <IteratorConcept ... P, int ... I>
- requires (has_len<P> || ...)
- struct WithLen<Pick<std::tuple<P ...>, mp::int_list<I ...>>>
- {
- template <class E> constexpr static decltype(auto)
- f(dim_t len, E && e)
- {
- return pick(WithLen<std::decay_t<P>>::f(len, std::get<I>(std::forward<E>(e).t)) ...);
- }
- };
- template <int w, class O, class N, class S>
- requires (has_len<O> || has_len<N> || has_len<S>)
- struct WithLen<Iota<w, O, N, S>>
- {
- // usable iota types must be either is_constant or is_scalar.
- template <class T> constexpr static decltype(auto)
- coerce(T && t)
- {
- if constexpr (IteratorConcept<T>) {
- return FLAT(t);
- } else {
- return std::forward<T>(t);
- }
- }
- template <class E> constexpr static decltype(auto)
- f(dim_t len, E && e)
- {
- return iota<w>(coerce(WithLen<std::decay_t<N>>::f(len, std::forward<E>(e).n)),
- coerce(WithLen<std::decay_t<O>>::f(len, std::forward<E>(e).i)),
- coerce(WithLen<std::decay_t<S>>::f(len, std::forward<E>(e).s)));
- }
- };
- template <class E>
- constexpr decltype(auto)
- with_len(dim_t len, E && e)
- {
- return WithLen<std::decay_t<E>>::f(len, std::forward<E>(e));
- }
- // --------------
- // ply, run time order
- // --------------
- // Traverse array expression looking to ravel the inner loop.
- // step() must give 0 for k>=their own rank, to allow frame matching.
- template <IteratorConcept A>
- inline void
- ply_ravel(A && a)
- {
- rank_t rank = a.rank();
- // FIXME without assert compiler thinks var rank may be negative. See test in [ra40].
- #ifdef NDEBUG
- #pragma GCC diagnostic push
- #pragma GCC diagnostic ignored "-Wvla-larger-than="
- rank_t order[rank];
- dim_t sha[rank], ind[rank];
- #pragma GCC diagnostic pop
- #else
- assert(rank>=0);
- rank_t order[rank];
- dim_t sha[rank], ind[rank];
- #endif
- for (rank_t i=0; i<rank; ++i) {
- order[i] = rank-1-i;
- }
- switch (rank) {
- case 0: *(a.flat()); return;
- case 1: break;
- default: // TODO better heuristic
- // if (rank>1) {
- // std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
- // { return a.len(order[i])<a.len(order[j]); });
- // }
- ;
- }
- // outermost compact dim.
- rank_t * ocd = order;
- // FIXME see same thing below.
- #pragma GCC diagnostic push
- #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
- auto ss = a.len(*ocd);
- #pragma GCC diagnostic pop
- for (--rank, ++ocd; rank>0 && a.keep_step(ss, order[0], *ocd); --rank, ++ocd) {
- ss *= a.len(*ocd);
- }
- for (int k=0; k<rank; ++k) {
- ind[k] = 0;
- sha[k] = a.len(ocd[k]);
- if (sha[k]==0) { // for the raveled dimensions ss takes care.
- return;
- }
- RA_CHECK(sha[k]!=DIM_BAD, "undefined dim ", ocd[k]);
- }
- // all sub xpr steps advance in compact dims, as they might be different.
- auto const ss0 = a.step(order[0]);
- for (;;) {
- dim_t s = ss;
- for (auto p=a.flat(); s>0; --s, p+=ss0) {
- *p;
- }
- for (int k=0; ; ++k) {
- if (k>=rank) {
- return;
- } else if (ind[k]<sha[k]-1) {
- ++ind[k];
- a.adv(ocd[k], 1);
- break;
- } else {
- ind[k] = 0;
- a.adv(ocd[k], 1-sha[k]);
- }
- }
- }
- }
- // -------------------------
- // ply, compile time order
- // -------------------------
- template <class order, int ravel_rank, class A, class S>
- constexpr void
- subindex(A & a, dim_t s, S const & ss0)
- {
- if constexpr (mp::len<order> == ravel_rank) {
- #pragma GCC diagnostic push
- #pragma GCC diagnostic warning "-Wstringop-overflow"
- #pragma GCC diagnostic warning "-Wstringop-overread"
- for (auto p=a.flat(); s>0; --s, p+=ss0) {
- *p;
- }
- #pragma GCC diagnostic pop
- } else {
- dim_t size = a.len(mp::first<order>::value); // TODO Precompute these at the top
- for (dim_t i=0, iend=size; i<iend; ++i) {
- subindex<mp::drop1<order>, ravel_rank>(a, s, ss0);
- a.adv(mp::first<order>::value, 1);
- }
- a.adv(mp::first<order>::value, -size);
- }
- }
- // convert runtime jj into compile time j. TODO a.adv<k>().
- template <class order, int j, class A, class S>
- constexpr void
- until(int const jj, A & a, dim_t const s, S const & ss0)
- {
- if constexpr (mp::len<order> >= j) {
- if (jj==j) {
- subindex<order, j>(a, s, ss0);
- } else {
- until<order, j+1>(jj, a, s, ss0);
- }
- } else {
- std::abort();
- }
- }
- // find outermost compact dim.
- template <class A>
- constexpr auto
- ocd()
- {
- rank_t const rank = A::rank_s();
- auto s = A::len_s(rank-1);
- int j = 1;
- while (j<rank && A::keep_step(s, rank-1, rank-1-j)) {
- s *= A::len_s(rank-1-j);
- ++j;
- }
- return std::make_tuple(s, j);
- };
- template <IteratorConcept A>
- constexpr void
- plyf(A && a)
- {
- constexpr rank_t rank = rank_s<A>();
- static_assert(rank>=0, "plyf needs static rank");
- if constexpr (rank_s<A>()==0) {
- *(a.flat());
- } else if constexpr (rank_s<A>()==1) {
- subindex<mp::iota<1>, 1>(a, a.len(0), a.step(0));
- // this can only be enabled when f() will be constexpr; static keep_step implies all else is also static.
- // important rank>1 for with static size operands [ra43].
- } else if constexpr (rank_s<A>()>1 && requires (dim_t d, rank_t i, rank_t j) { A::keep_step(d, i, j); }) {
- constexpr auto sj = ocd<std::decay_t<A>>();
- constexpr auto s = std::get<0>(sj);
- constexpr auto j = std::get<1>(sj);
- // all sub xpr steps advance in compact dims, as they might be different.
- // send with static j. Note that order here is inverse of order.
- until<mp::iota<rank_s<A>()>, 0>(j, a, s, a.step(rank-1));
- } else {
- // the unrolling above isn't worth it when s, j cannot be constexpr.
- auto s = a.len(rank-1);
- subindex<mp::iota<rank_s<A>()>, 1>(a, s, a.step(rank-1));
- }
- }
- // ---------------------------
- // select best performance (or requirements) for each type
- // ---------------------------
- template <IteratorConcept A>
- constexpr void
- ply(A && a)
- {
- static_assert(!has_len<A>, "len used outside subscript context.");
- if constexpr (size_s<A>()==DIM_ANY) {
- ply_ravel(std::forward<A>(a));
- } else {
- plyf(std::forward<A>(a));
- }
- }
- // ---------------------------
- // ply, short-circuiting
- // ---------------------------
- // TODO Refactor with ply_ravel. Make exit available to plyf.
- // TODO These are reductions. How to do higher rank?
- template <IteratorConcept A, class DEF>
- inline auto
- ply_ravel_exit(A && a, DEF && def)
- {
- static_assert(!has_len<A>, "len used outside subscript context.");
- rank_t rank = a.rank();
- // FIXME without assert compiler thinks var rank may be negative. See test in [ra40].
- #ifdef NDEBUG
- #pragma GCC diagnostic push
- #pragma GCC diagnostic ignored "-Wvla-larger-than="
- rank_t order[rank];
- dim_t sha[rank], ind[rank];
- #pragma GCC diagnostic pop
- #else
- assert(rank>=0);
- rank_t order[rank];
- dim_t sha[rank], ind[rank];
- #endif
- for (rank_t i=0; i<rank; ++i) {
- order[i] = rank-1-i;
- }
- switch (rank) {
- case 0: {
- if (auto what = *(a.flat()); std::get<0>(what)) {
- return std::get<1>(what);
- }
- return def;
- }
- case 1: break;
- default: // TODO better heuristic
- // if (rank>1) {
- // std::sort(order, order+rank, [&a, &order](auto && i, auto && j)
- // { return a.len(order[i])<a.len(order[j]); });
- // }
- ;
- }
- // outermost compact dim.
- rank_t * ocd = order;
- // FIXME on github actions ubuntu-latest g++-11 -O3 :-|
- #pragma GCC diagnostic push
- #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
- auto ss = a.len(*ocd);
- #pragma GCC diagnostic pop
- for (--rank, ++ocd; rank>0 && a.keep_step(ss, order[0], *ocd); --rank, ++ocd) {
- ss *= a.len(*ocd);
- }
- for (int k=0; k<rank; ++k) {
- ind[k] = 0;
- sha[k] = a.len(ocd[k]);
- if (sha[k]==0) { // for the raveled dimensions ss takes care.
- return def;
- }
- RA_CHECK(sha[k]!=DIM_BAD, "undefined dim ", ocd[k]);
- }
- // all sub xpr steps advance in compact dims, as they might be different.
- auto const ss0 = a.step(order[0]);
- for (;;) {
- dim_t s = ss;
- for (auto p=a.flat(); s>0; --s, p+=ss0) {
- if (auto what = *p; std::get<0>(what)) {
- return std::get<1>(what);
- }
- }
- for (int k=0; ; ++k) {
- if (k>=rank) {
- return def;
- } else if (ind[k]<sha[k]-1) {
- ++ind[k];
- a.adv(ocd[k], 1);
- break;
- } else {
- ind[k] = 0;
- a.adv(ocd[k], 1-sha[k]);
- }
- }
- }
- }
- template <IteratorConcept A, class DEF>
- constexpr decltype(auto)
- early(A && a, DEF && def)
- {
- return ply_ravel_exit(std::forward<A>(a), std::forward<DEF>(def));
- }
// Builds the expression map(op, a ...) and traverses it with ply(); nothing is returned.
template <class Op, class ... A>
constexpr void
for_each(Op && op, A && ... a)
{
    auto && traversal = map(std::forward<Op>(op), std::forward<A>(a) ...);
    // hand it over with the value category map() produced.
    ply(std::forward<decltype(traversal)>(traversal));
}
- } // namespace ra
|