dual.hh 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. // -*- mode: c++; coding: utf-8 -*-
  2. // ra-ra - Dual numbers for automatic differentiation.
  3. // (c) Daniel Llorens - 2013-2023
  4. // This library is free software; you can redistribute it and/or modify it under
  5. // the terms of the GNU Lesser General Public License as published by the Free
  6. // Software Foundation; either version 3 of the License, or (at your option) any
  7. // later version.
  8. // See VanderBergen2012, Berland2006. Generally about automatic differentiation:
  9. // http://en.wikipedia.org/wiki/Automatic_differentiation
  10. // From the Taylor expansion of f(a) or f(a, b)...
  11. // f(a+εa') = f(a)+εa'f_a(a)
  12. // f(a+εa', b+εb') = f(a, b)+ε[a'f_a(a, b)+b'f_b(a, b)]
  13. #pragma once
  #include <cmath>
  #include <iosfwd>
  #include <type_traits>
  #include "macros.hh"
  17. using std::abs, std::sqrt, std::fma;
  18. template <class T>
  19. struct Dual
  20. {
  21. T re, du;
  22. constexpr static bool is_complex = requires { requires !(std::is_same_v<T, std::decay_t<decltype(std::declval<T>().real())>>); };
  23. template <class S> struct real_part { struct type {}; };
  24. template <class S> requires (is_complex) struct real_part<S> { using type = typename S::value_type; };
  25. using real_type = typename real_part<T>::type;
  26. constexpr Dual(T const & r, T const & d): re(r), du(d) {}
  27. constexpr Dual(T const & r): re(r), du(0.) {} // conversions are by default constants.
  28. constexpr Dual(real_type const & r) requires (is_complex): re(r), du(0.) {}
  29. constexpr Dual() {}
  30. #define DEF_ASSIGNOPS(OP) \
  31. constexpr Dual & operator JOIN(OP, =)(T const & r) { *this = *this OP r; return *this; } \
  32. constexpr Dual & operator JOIN(OP, =)(Dual const & r) { *this = *this OP r; return *this; } \
  33. constexpr Dual & operator JOIN(OP, =)(real_type const & r) requires (is_complex) { *this = *this OP r; return *this; }
  34. FOR_EACH(DEF_ASSIGNOPS, +, -, /, *)
  35. #undef DEF_ASSIGNOPS
  36. };
  37. // conversions are by default constants.
  38. template <class R> constexpr auto dual(Dual<R> const & r) { return r; }
  39. template <class R> constexpr auto dual(R const & r) { return Dual<R> { r, 0. }; }
  40. template <class R, class D>
  41. constexpr auto
  42. dual(R const & r, D const & d)
  43. {
  44. return Dual<std::common_type_t<R, D>> { r, d };
  45. }
  46. template <class A, class B>
  47. constexpr auto
  48. operator*(Dual<A> const & a, Dual<B> const & b)
  49. {
  50. return dual(a.re*b.re, a.re*b.du + a.du*b.re);
  51. }
  52. template <class A, class B>
  53. constexpr auto
  54. operator*(A const & a, Dual<B> const & b)
  55. {
  56. return dual(a*b.re, a*b.du);
  57. }
  58. template <class A, class B>
  59. constexpr auto
  60. operator*(Dual<A> const & a, B const & b)
  61. {
  62. return dual(a.re*b, a.du*b);
  63. }
  64. template <class A, class B, class C>
  65. constexpr auto
  66. fma(Dual<A> const & a, Dual<B> const & b, Dual<C> const & c)
  67. {
  68. return dual(fma(a.re, b.re, c.re), fma(a.re, b.du, fma(a.du, b.re, c.du)));
  69. }
  70. template <class A, class B>
  71. constexpr auto
  72. operator+(Dual<A> const & a, Dual<B> const & b)
  73. {
  74. return dual(a.re+b.re, a.du+b.du);
  75. }
  76. template <class A, class B>
  77. constexpr auto
  78. operator+(A const & a, Dual<B> const & b)
  79. {
  80. return dual(a+b.re, b.du);
  81. }
  82. template <class A, class B>
  83. constexpr auto
  84. operator+(Dual<A> const & a, B const & b)
  85. {
  86. return dual(a.re+b, a.du);
  87. }
  88. template <class A, class B>
  89. constexpr auto
  90. operator-(Dual<A> const & a, Dual<B> const & b)
  91. {
  92. return dual(a.re-b.re, a.du-b.du);
  93. }
  94. template <class A, class B>
  95. constexpr auto
  96. operator-(Dual<A> const & a, B const & b)
  97. {
  98. return dual(a.re-b, a.du);
  99. }
  100. template <class A, class B>
  101. constexpr auto
  102. operator-(A const & a, Dual<B> const & b)
  103. {
  104. return dual(a-b.re, -b.du);
  105. }
  106. template <class A>
  107. constexpr auto
  108. operator-(Dual<A> const & a)
  109. {
  110. return dual(-a.re, -a.du);
  111. }
  112. template <class A>
  113. constexpr decltype(auto)
  114. operator+(Dual<A> const & a)
  115. {
  116. return a;
  117. }
  118. template <class A>
  119. constexpr auto
  120. inv(Dual<A> const & a)
  121. {
  122. auto i = 1./a.re;
  123. return dual(i, -a.du*(i*i));
  124. }
  125. template <class A, class B>
  126. constexpr auto
  127. operator/(Dual<A> const & a, Dual<B> const & b)
  128. {
  129. return a*inv(b);
  130. }
  131. template <class A, class B>
  132. constexpr auto
  133. operator/(Dual<A> const & a, B const & b)
  134. {
  135. return a*inv(dual(b));
  136. }
  137. template <class A, class B>
  138. constexpr auto
  139. operator/(A const & a, Dual<B> const & b)
  140. {
  141. return dual(a)*inv(b);
  142. }
  143. template <class A>
  144. constexpr auto
  145. cos(Dual<A> const & a)
  146. {
  147. return dual(cos(a.re), -sin(a.re)*a.du);
  148. }
  149. template <class A>
  150. constexpr auto
  151. sin(Dual<A> const & a)
  152. {
  153. return dual(sin(a.re), +cos(a.re)*a.du);
  154. }
  155. template <class A>
  156. constexpr auto
  157. cosh(Dual<A> const & a)
  158. {
  159. return dual(cosh(a.re), +sinh(a.re)*a.du);
  160. }
  161. template <class A>
  162. constexpr auto
  163. sinh(Dual<A> const & a)
  164. {
  165. return dual(sinh(a.re), +cosh(a.re)*a.du);
  166. }
  167. template <class A>
  168. constexpr auto
  169. tan(Dual<A> const & a)
  170. {
  171. auto c = cos(a.du);
  172. return dual(tan(a.re), a.du/(c*c));
  173. }
  174. template <class A>
  175. constexpr auto
  176. exp(Dual<A> const & a)
  177. {
  178. return dual(exp(a.re), +exp(a.re)*a.du);
  179. }
  180. template <class A, class B>
  181. constexpr auto
  182. pow(Dual<A> const & a, B const & b)
  183. {
  184. return dual(pow(a.re, b), +b*pow(a.re, b-1)*a.du);
  185. }
  186. template <class A>
  187. constexpr auto
  188. log(Dual<A> const & a)
  189. {
  190. return dual(log(a.re), +a.du/a.re);
  191. }
  192. template <class A>
  193. constexpr auto
  194. sqrt(Dual<A> const & a)
  195. {
  196. return dual(sqrt(a.re), +a.du/(2.*sqrt(a.re)));
  197. }
  198. template <class A>
  199. constexpr auto
  200. sqr(Dual<A> const & a)
  201. {
  202. return a*a;
  203. }
  204. template <class A>
  205. constexpr auto
  206. abs(Dual<A> const & a)
  207. {
  208. return abs(a.re);
  209. }
  210. template <class A>
  211. constexpr bool
  212. isfinite(Dual<A> const & a)
  213. {
  214. return isfinite(a.re) && isfinite(a.du);
  215. }
  216. template <class A>
  217. constexpr auto
  218. xI(Dual<A> const & a)
  219. {
  220. return dual(xI(a.re), xI(a.du));
  221. }
  222. template <class A>
  223. std::ostream & operator<<(std::ostream & o, Dual<A> const & a)
  224. {
  225. return o << "[" << a.re << " " << a.du << "]";
  226. }
  227. template <class A>
  228. std::istream & operator>>(std::istream & i, Dual<A> & a)
  229. {
  230. std::string s;
  231. i >> s;
  232. if (s!="[") {
  233. i.setstate(std::ios::failbit);
  234. return i;
  235. }
  236. a >> a.re;
  237. a >> a.du;
  238. i >> s;
  239. if (s!="]") {
  240. i.setstate(std::ios::failbit);
  241. return i;
  242. }
  243. }