test.hh 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Test library.
  3. // (c) Daniel Llorens - 2012-2022
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. #pragma once
  9. #include <string>
  10. #include <iomanip>
  11. #include <iostream>
  12. #include <ctime>
  13. #include "ra.hh"
  14. namespace ra {
  15. struct TestRecorder
  16. {
  17. constexpr static double QNAN = std::numeric_limits<double>::quiet_NaN();
  18. constexpr static double PINF = std::numeric_limits<double>::infinity();
  19. // ra::amax ignore nans in the way fmax etc. do, and we don't want that here.
  20. template <class A>
  21. inline static auto
  22. amax_strict(A && a)
  23. {
  24. using std::max;
  25. using T = value_t<A>;
  26. T c = std::numeric_limits<T>::has_infinity ? -std::numeric_limits<T>::infinity() : std::numeric_limits<T>::lowest();
  27. return early(map([&c](auto && a) { if (c<a) { c = a; }; return std::make_tuple(isnan(a), QNAN*a); },
  28. std::forward<A>(a)), c);
  29. return c;
  30. }
  31. enum verbose_t { QUIET, // as NOISY if failed, else no output
  32. ERRORS, // as NOISY if failed, else info and fp errors (default)
  33. NOISY }; // full output of info, test arguments, fp errors
  34. std::ostream & o;
  35. int total=0, skipped=0, passed_good=0, passed_bad=0, failed_good=0, failed_bad=0;
  36. std::vector<int> bad;
  37. std::string info_str;
  38. verbose_t verbose_default, verbose;
  39. bool willskip = false;
  40. bool willexpectfail = false;
  41. bool willstrictshape = false;
  42. TestRecorder(std::ostream & o_=std::cout, verbose_t verbose_default_=ERRORS)
  43. : o(o_), verbose_default(verbose_default_), verbose(verbose_default_) {}
  44. template <class ... A> void
  45. section(A const & ... a)
  46. {
  47. o << "\n" << esc_bold << format(a ...) << esc_unbold << std::endl;
  48. }
  49. static std::string
  50. format_error(double e)
  51. {
  52. return format(esc_yellow, std::setprecision(2), e, esc_reset);
  53. }
  54. template <class ... A> TestRecorder &
  55. info(A && ... a)
  56. {
  57. bool empty = (info_str=="");
  58. info_str += esc_pink;
  59. info_str += (empty ? "" : "; ") + format(a ...) + esc_reset;
  60. return *this;
  61. }
  62. TestRecorder & quiet(verbose_t v=QUIET) { verbose = v; return *this; }
  63. TestRecorder & noisy(verbose_t v=NOISY) { verbose = v; return *this; }
  64. TestRecorder & skip(bool s=true) { willskip = s; return *this; }
  65. TestRecorder & strictshape(bool s=true) { willstrictshape = s; return *this; }
  66. TestRecorder & expectfail(bool s=true) { willexpectfail = s; return *this; }
  67. template <class A, class B>
  68. void
  69. test(bool c, A && info_full, B && info_min,
  70. std::source_location const loc = std::source_location::current())
  71. {
  72. switch (verbose) {
  73. case QUIET: {
  74. if (!c) {
  75. o << format(esc_cyan, "[", total, ":", loc, "]", esc_reset, " ...",
  76. esc_bold, esc_red, " FAILED", esc_reset,
  77. esc_yellow, (willskip ? " skipped" : ""), (willexpectfail ? " expected" : ""), esc_reset,
  78. " ", info_full())
  79. << std::endl;
  80. }
  81. }; break;
  82. case NOISY: case ERRORS: {
  83. o << format(esc_cyan, "[", total, ":", loc, "]", esc_reset, " ...")
  84. << (c ? std::string(esc_green) + " ok" + esc_reset
  85. : std::string(esc_bold) + esc_red + " FAILED" + esc_reset)
  86. << esc_yellow << (willskip ? " skipped" : "")
  87. << (willexpectfail ? (c ? " not expected" : " expected") : "") << esc_reset
  88. << " " << ((verbose==NOISY || c==willexpectfail) ? info_full() : info_min())
  89. << std::endl;
  90. }; break;
  91. default: std::abort();
  92. }
  93. info_str = "";
  94. verbose = verbose_default;
  95. if (!willskip) {
  96. ++(willexpectfail? (c ? passed_bad : failed_good) : (c ? passed_good : failed_bad));
  97. if (c==willexpectfail) {
  98. bad.push_back(total);
  99. }
  100. } else {
  101. ++skipped;
  102. }
  103. ++total;
  104. willstrictshape = willskip = willexpectfail = false;
  105. }
  106. #define LAZYINFO(...) [&] { return format(info_str, (info_str=="" ? "" : "; "), __VA_ARGS__); }
  107. template <class A>
  108. void
  109. test(bool c, A && info_full,
  110. std::source_location const loc = std::source_location::current())
  111. {
  112. test(c, info_full, info_full, loc);
  113. }
  114. void
  115. test(bool c,
  116. std::source_location const loc = std::source_location::current())
  117. {
  118. test(c, LAZYINFO(""), loc);
  119. }
  120. // Comp = ... is non-deduced context, so can't replace test_eq() with a default argument here.
  121. // where() is used to match shapes if either REF or A don't't have one.
  122. template <class A, class B, class Comp>
  123. bool
  124. test_comp(A && a, B && b, Comp && comp, char const * msg,
  125. std::source_location const loc = std::source_location::current())
  126. {
  127. if (willstrictshape
  128. ? [&] {
  129. if constexpr (ra::rank_s<decltype(a)>()==ra::rank_s<decltype(b)>()
  130. || ra::rank_s<decltype(a)>()==RANK_ANY || ra::rank_s<decltype(b)>()==RANK_ANY) {
  131. return ra::rank(a)==ra::rank(b) && every(ra::shape(a)==ra::shape(b));
  132. } else {
  133. return false;
  134. } }()
  135. : agree_op(comp, a, b)) {
  136. bool c = every(ra::map(comp, a, b));
  137. test(c, LAZYINFO(where(false, a, b), " (", msg, " ", where(true, a, b), ")"),
  138. LAZYINFO(""), loc);
  139. return c;
  140. } else {
  141. test(false,
  142. LAZYINFO("Mismatched args [", ra::noshape, ra::shape(a), "] [", ra::noshape, ra::shape(b), "]",
  143. willstrictshape ? " (strict shape)" : ""),
  144. LAZYINFO("Shape mismatch", willstrictshape ? " (strict shape)" : ""),
  145. loc);
  146. return false;
  147. }
  148. }
  149. template <class R, class A>
  150. bool
  151. test_eq(R && ref, A && a,
  152. std::source_location const loc = std::source_location::current())
  153. {
  154. return test_comp(ra::start(ref), ra::start(a), [](auto && a, auto && b) { return every(a==b); },
  155. "should be ==", loc);
  156. }
  157. template <class A, class B>
  158. bool
  159. test_lt(A && a, B && b,
  160. std::source_location const loc = std::source_location::current())
  161. {
  162. return test_comp(ra::start(a), ra::start(b), [](auto && a, auto && b) { return every(a<b); },
  163. "should be <", loc);
  164. }
  165. template <class A, class B>
  166. bool
  167. test_le(A && a, B && b,
  168. std::source_location const loc = std::source_location::current())
  169. {
  170. return test_comp(ra::start(a), ra::start(b), [](auto && a, auto && b) { return every(a<=b); },
  171. "should be <=", loc);
  172. }
  173. // These two are included so that the first argument can remain the reference.
  174. template <class A, class B>
  175. bool
  176. test_gt(A && a, B && b,
  177. std::source_location const loc = std::source_location::current())
  178. {
  179. return test_comp(ra::start(a), ra::start(b), [](auto && a, auto && b) { return every(a>b); },
  180. "should be >", loc);
  181. }
  182. template <class A, class B>
  183. bool
  184. test_ge(A && a, B && b,
  185. std::source_location const loc = std::source_location::current())
  186. {
  187. return test_comp(ra::start(a), ra::start(b), [](auto && a, auto && b) { return every(a>=b); },
  188. "should be >=", loc);
  189. }
  190. template <class R, class A>
  191. double
  192. test_rel_error(R && ref, A && a, double req, double level=0,
  193. std::source_location const loc = std::source_location::current())
  194. {
  195. double e = (level<=0)
  196. ? amax_strict(where(isfinite(ref),
  197. rel_error(ref, a),
  198. where(isinf(ref),
  199. where(ref==a, 0., PINF),
  200. where(isnan(a), 0., PINF))))
  201. : amax_strict(where(isfinite(ref),
  202. abs(ref-a)/level,
  203. where(isinf(ref),
  204. where(ref==a, 0., PINF),
  205. where(isnan(a), 0., PINF))));
  206. test(e<=req,
  207. LAZYINFO("rerr (", esc_yellow, "ref", esc_reset, ": ", ref, esc_yellow, ", got", esc_reset, ": ", a,
  208. ") = ", format_error(e), (level<=0 ? "" : format(" (level ", level, ")")), ", req. ", req),
  209. LAZYINFO("rerr: ", format_error(e), (level<=0 ? "" : format(" (level ", level, ")")),
  210. ", req. ", req),
  211. loc);
  212. return e;
  213. }
  214. template <class R, class A>
  215. double
  216. test_abs_error(R && ref, A && a, double req=0,
  217. std::source_location const loc = std::source_location::current())
  218. {
  219. double e = amax_strict(where(isfinite(ref),
  220. abs(ref-a),
  221. where(isinf(ref),
  222. where(ref==a, 0., PINF),
  223. where(isnan(a), 0., PINF))));
  224. test(e<=req,
  225. LAZYINFO("aerr (ref: ", ref, ", got: ", a, ") = ", format_error(e), ", req. ", req),
  226. LAZYINFO("aerr: ", format_error(e), ", req. ", req),
  227. loc);
  228. return e;
  229. }
  230. #undef LAZYINFO
  231. int
  232. summary() const
  233. {
  234. std::time_t t = std::time(nullptr);
  235. tm * tmp = std::localtime(&t);
  236. char buf[64];
  237. std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", tmp);
  238. o << "--------------\nTests end " << buf << ". ";
  239. o << format("Of ", total, " tests passed ", (passed_good+passed_bad),
  240. " (", passed_bad, " unexpected), failed ", (failed_good+failed_bad),
  241. " (", failed_bad, " unexpected), skipped ", skipped, ".\n");
  242. if (bad.size()>0) {
  243. o << format(bad.size(), " bad tests: [", esc_bold, esc_red, ra::noshape, format_array(bad),
  244. esc_reset, "].\n");
  245. }
  246. return bad.size();
  247. }
  248. };
  249. } // namespace ra