<!-- draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.html (66 KB) -->
  1. <!DOCTYPE html>
  2. <html lang="en">
  3. <head>
  4. <meta charset="utf-8" />
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />
  6. <title>draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU &#8212; Draugr 1.0.1 documentation</title>
  7. <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
  8. <link rel="stylesheet" type="text/css" href="../_static/alabaster.css" />
  9. <link rel="stylesheet" type="text/css" href="../_static/graphviz.css" />
  10. <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
  11. <script src="../_static/jquery.js"></script>
  12. <script src="../_static/underscore.js"></script>
  13. <script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
  14. <script src="../_static/doctools.js"></script>
  15. <link rel="canonical" href="https://pything.github.io/draugr/generated/draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.html" />
  16. <link rel="index" title="Index" href="../genindex.html" />
  17. <link rel="search" title="Search" href="../search.html" />
  18. <link rel="next" title="draugr.torch_utilities.optimisation.debugging.layer_fetching" href="draugr.torch_utilities.optimisation.debugging.layer_fetching.html" />
  19. <link rel="prev" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLUModel" href="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLUModel.html" />
  20. <link rel="stylesheet" href="../_static/custom.css" type="text/css" />
  21. <meta name="viewport" content="width=device-width, initial-scale=1" />
  22. </head><body>
  23. <div class="document">
  24. <div class="documentwrapper">
  25. <div class="bodywrapper">
  26. <div class="body" role="main">
  27. <section id="draugr-torch-utilities-optimisation-debugging-gradients-guided-guidedbackproprelu">
  28. <h1>draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU<a class="headerlink" href="#draugr-torch-utilities-optimisation-debugging-gradients-guided-guidedbackproprelu" title="Permalink to this heading">¶</a></h1>
  29. <dl class="py class">
  30. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU">
  31. <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">draugr.torch_utilities.optimisation.debugging.gradients.guided.</span></span><span class="sig-name descname"><span class="pre">GuidedBackPropReLU</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/draugr/torch_utilities/optimisation/debugging/gradients/guided.html#GuidedBackPropReLU"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU" title="Permalink to this definition">¶</a></dt>
  32. <dd><p>Bases: <code class="xref py py-class docutils literal notranslate"><span class="pre">Function</span></code></p>
  33. <dl class="py method">
  34. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.__init__">
  35. <span class="sig-name descname"><span class="pre">__init__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.__init__" title="Permalink to this definition">¶</a></dt>
  36. <dd></dd></dl>
  37. <p class="rubric">Methods</p>
  38. <table class="autosummary longtable docutils align-default">
  39. <colgroup>
  40. <col style="width: 10%" />
  41. <col style="width: 90%" />
  42. </colgroup>
  43. <tbody>
  44. <tr class="row-odd"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.__init__" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.__init__"><code class="xref py py-obj docutils literal notranslate"><span class="pre">__init__</span></code></a>(*args, **kwargs)</p></td>
  45. <td><p></p></td>
  46. </tr>
  47. <tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">apply</span></code></p></td>
  48. <td><p></p></td>
  49. </tr>
  50. <tr class="row-odd"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward"><code class="xref py py-obj docutils literal notranslate"><span class="pre">backward</span></code></a>(self, grad_output)</p></td>
  51. <td><p><dl class="field-list simple">
  52. <dt class="field-odd">param self</dt>
  53. <dd class="field-odd"><p></p></dd>
  54. </dl>
  55. </p></td>
  56. </tr>
  57. <tr class="row-even"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-obj docutils literal notranslate"><span class="pre">forward</span></code></a>(self, input_img)</p></td>
  58. <td><p><dl class="field-list simple">
  59. <dt class="field-odd">param self</dt>
  60. <dd class="field-odd"><p></p></dd>
  61. </dl>
  62. </p></td>
  63. </tr>
  64. <tr class="row-odd"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.jvp" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.jvp"><code class="xref py py-obj docutils literal notranslate"><span class="pre">jvp</span></code></a>(ctx, *grad_inputs)</p></td>
  65. <td><p>Defines a formula for differentiating the operation with forward mode automatic differentiation.</p></td>
  66. </tr>
  67. <tr class="row-even"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.mark_dirty" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.mark_dirty"><code class="xref py py-obj docutils literal notranslate"><span class="pre">mark_dirty</span></code></a>(*args)</p></td>
  68. <td><p>Marks given tensors as modified in an in-place operation.</p></td>
  69. </tr>
  70. <tr class="row-odd"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.mark_non_differentiable" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.mark_non_differentiable"><code class="xref py py-obj docutils literal notranslate"><span class="pre">mark_non_differentiable</span></code></a>(*args)</p></td>
  71. <td><p>Marks outputs as non-differentiable.</p></td>
  72. </tr>
  73. <tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">mark_shared_storage</span></code>(*pairs)</p></td>
  74. <td><p></p></td>
  75. </tr>
  76. <tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">name</span></code></p></td>
  77. <td><p></p></td>
  78. </tr>
  79. <tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">register_hook</span></code></p></td>
  80. <td><p></p></td>
  81. </tr>
  82. <tr class="row-odd"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.save_for_backward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.save_for_backward"><code class="xref py py-obj docutils literal notranslate"><span class="pre">save_for_backward</span></code></a>(*tensors)</p></td>
  83. <td><p>Saves given tensors for a future call to <code class="xref py py-func docutils literal notranslate"><span class="pre">backward()</span></code>.</p></td>
  84. </tr>
  85. <tr class="row-even"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.save_for_forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.save_for_forward"><code class="xref py py-obj docutils literal notranslate"><span class="pre">save_for_forward</span></code></a>(*tensors)</p></td>
  86. <td><p>Saves given tensors for a future call to <code class="xref py py-func docutils literal notranslate"><span class="pre">jvp()</span></code>.</p></td>
  87. </tr>
  88. <tr class="row-odd"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.set_materialize_grads" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.set_materialize_grads"><code class="xref py py-obj docutils literal notranslate"><span class="pre">set_materialize_grads</span></code></a>(value)</p></td>
  89. <td><p>Sets whether to materialize output grad tensors.</p></td>
  90. </tr>
  91. <tr class="row-even"><td><p><a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.vjp" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.vjp"><code class="xref py py-obj docutils literal notranslate"><span class="pre">vjp</span></code></a>(ctx, *grad_outputs)</p></td>
  92. <td><p>Defines a formula for differentiating the operation with backward mode automatic differentiation (alias to the vjp function).</p></td>
  93. </tr>
  94. </tbody>
  95. </table>
  96. <p class="rubric">Attributes</p>
  97. <table class="autosummary longtable docutils align-default">
  98. <colgroup>
  99. <col style="width: 10%" />
  100. <col style="width: 90%" />
  101. </colgroup>
  102. <tbody>
  103. <tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">dirty_tensors</span></code></p></td>
  104. <td><p></p></td>
  105. </tr>
  106. <tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">is_traceable</span></code></p></td>
  107. <td><p></p></td>
  108. </tr>
  109. <tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">materialize_grads</span></code></p></td>
  110. <td><p></p></td>
  111. </tr>
  112. <tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">metadata</span></code></p></td>
  113. <td><p></p></td>
  114. </tr>
  115. <tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">needs_input_grad</span></code></p></td>
  116. <td><p></p></td>
  117. </tr>
  118. <tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">next_functions</span></code></p></td>
  119. <td><p></p></td>
  120. </tr>
  121. <tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">non_differentiable</span></code></p></td>
  122. <td><p></p></td>
  123. </tr>
  124. <tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">requires_grad</span></code></p></td>
  125. <td><p></p></td>
  126. </tr>
  127. <tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">saved_for_forward</span></code></p></td>
  128. <td><p></p></td>
  129. </tr>
  130. <tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">saved_tensors</span></code></p></td>
  131. <td><p></p></td>
  132. </tr>
  133. <tr class="row-odd"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">saved_variables</span></code></p></td>
  134. <td><p></p></td>
  135. </tr>
  136. <tr class="row-even"><td><p><code class="xref py py-obj docutils literal notranslate"><span class="pre">to_save</span></code></p></td>
  137. <td><p></p></td>
  138. </tr>
  139. </tbody>
  140. </table>
  141. <dl class="py method">
  142. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward">
  143. <em class="property"><span class="pre">static</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">grad_output</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/draugr/torch_utilities/optimisation/debugging/gradients/guided.html#GuidedBackPropReLU.backward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward" title="Permalink to this definition">¶</a></dt>
  144. <dd><dl class="field-list simple">
  145. <dt class="field-odd">Parameters</dt>
  146. <dd class="field-odd"><ul class="simple">
  147. <li><p><strong>self</strong> – </p></li>
  148. <li><p><strong>grad_output</strong> – </p></li>
  149. </ul>
  150. </dd>
  151. <dt class="field-even">Returns</dt>
  152. <dd class="field-even"><p></p>
  153. </dd>
  154. </dl>
  155. </dd></dl>
  156. <dl class="py method">
  157. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward">
  158. <em class="property"><span class="pre">static</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_img</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../_modules/draugr/torch_utilities/optimisation/debugging/gradients/guided.html#GuidedBackPropReLU.forward"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="Permalink to this definition">¶</a></dt>
  159. <dd><dl class="field-list simple">
  160. <dt class="field-odd">Parameters</dt>
  161. <dd class="field-odd"><ul class="simple">
  162. <li><p><strong>self</strong> – </p></li>
  163. <li><p><strong>input_img</strong> – </p></li>
  164. </ul>
  165. </dd>
  166. <dt class="field-even">Returns</dt>
  167. <dd class="field-even"><p></p>
  168. </dd>
  169. </dl>
  170. </dd></dl>
  171. <dl class="py method">
  172. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.jvp">
  173. <em class="property"><span class="pre">static</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">jvp</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">ctx</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.10)"><span class="pre">Any</span></a></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">grad_inputs</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.10)"><span class="pre">Any</span></a></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.10)"><span class="pre">Any</span></a></span></span><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.jvp" title="Permalink to this definition">¶</a></dt>
  174. <dd><p>Defines a formula for differentiating the operation with forward mode
  175. automatic differentiation.
  176. This function is to be overridden by all subclasses.
  177. It must accept a context <code class="xref py py-attr docutils literal notranslate"><span class="pre">ctx</span></code> as the first argument, followed by
  178. as many inputs as the <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a> got (None will be passed in
  179. for non tensor inputs of the forward function),
  180. and it should return as many tensors as there were outputs to
  181. <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a>. Each argument is the gradient w.r.t the given input,
  182. and each returned value should be the gradient w.r.t. the
  183. corresponding output. If an output is not a Tensor or the function is not
  184. differentiable with respect to that output, you can just pass None as a
  185. gradient for that input.</p>
  186. <p>You can use the <code class="xref py py-attr docutils literal notranslate"><span class="pre">ctx</span></code> object to pass any value from the forward to this
  187. functions.</p>
  188. </dd></dl>
  189. <dl class="py method">
  190. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.mark_dirty">
  191. <span class="sig-name descname"><span class="pre">mark_dirty</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.mark_dirty" title="Permalink to this definition">¶</a></dt>
  192. <dd><p>Marks given tensors as modified in an in-place operation.</p>
  193. <p><strong>This should be called at most once, only from inside the</strong>
  194. <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a> <strong>method, and all arguments should be inputs.</strong></p>
  195. <p>Every tensor that’s been modified in-place in a call to <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a>
  196. should be given to this function, to ensure correctness of our checks.
  197. It doesn’t matter whether the function is called before or after
  198. modification.</p>
  199. <dl>
  200. <dt>Examples::</dt><dd><div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="k">class</span> <span class="nc">Inplace</span><span class="p">(</span><span class="n">Function</span><span class="p">):</span>
  201. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  202. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
  203. <span class="gp">&gt;&gt;&gt; </span> <span class="n">x_npy</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="c1"># x_npy shares storage with x</span>
  204. <span class="gp">&gt;&gt;&gt; </span> <span class="n">x_npy</span> <span class="o">+=</span> <span class="mi">1</span>
  205. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">mark_dirty</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
  206. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">x</span>
  207. <span class="go">&gt;&gt;&gt;</span>
  208. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  209. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@once_differentiable</span>
  210. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_output</span><span class="p">):</span>
  211. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">grad_output</span>
  212. <span class="go">&gt;&gt;&gt;</span>
  213. <span class="gp">&gt;&gt;&gt; </span><span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">double</span><span class="p">)</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
  214. <span class="gp">&gt;&gt;&gt; </span><span class="n">b</span> <span class="o">=</span> <span class="n">a</span> <span class="o">*</span> <span class="n">a</span>
  215. <span class="gp">&gt;&gt;&gt; </span><span class="n">Inplace</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># This would lead to wrong gradients!</span>
  216. <span class="gp">&gt;&gt;&gt; </span> <span class="c1"># but the engine would not know unless we mark_dirty</span>
  217. <span class="gp">&gt;&gt;&gt; </span><span class="n">b</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="c1"># RuntimeError: one of the variables needed for gradient</span>
  218. <span class="gp">&gt;&gt;&gt; </span> <span class="c1"># computation has been modified by an inplace operation</span>
  219. </pre></div>
  220. </div>
  221. </dd>
  222. </dl>
  223. </dd></dl>
  224. <dl class="py method">
  225. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.mark_non_differentiable">
  226. <span class="sig-name descname"><span class="pre">mark_non_differentiable</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.mark_non_differentiable" title="Permalink to this definition">¶</a></dt>
  227. <dd><p>Marks outputs as non-differentiable.</p>
  228. <p><strong>This should be called at most once, only from inside the</strong>
  229. <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a> <strong>method, and all arguments should be tensor outputs.</strong></p>
  230. <p>This will mark outputs as not requiring gradients, increasing the
  231. efficiency of backward computation. You still need to accept a gradient
  232. for each output in <code class="xref py py-meth docutils literal notranslate"><span class="pre">backward()</span></code>, but it’s always going to
  233. be a zero tensor with the same shape as the shape of a corresponding
  234. output.</p>
  235. <dl>
  236. <dt>This is used e.g. for indices returned from a sort. See example::</dt><dd><div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="k">class</span> <span class="nc">Func</span><span class="p">(</span><span class="n">Function</span><span class="p">):</span>
  237. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  238. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
  239. <span class="gp">&gt;&gt;&gt; </span> <span class="nb">sorted</span><span class="p">,</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">sort</span><span class="p">()</span>
  240. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">mark_non_differentiable</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span>
  241. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">idx</span><span class="p">)</span>
  242. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="nb">sorted</span><span class="p">,</span> <span class="n">idx</span>
  243. <span class="go">&gt;&gt;&gt;</span>
  244. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  245. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@once_differentiable</span>
  246. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">g1</span><span class="p">,</span> <span class="n">g2</span><span class="p">):</span> <span class="c1"># still need to accept g2</span>
  247. <span class="gp">&gt;&gt;&gt; </span> <span class="n">x</span><span class="p">,</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
  248. <span class="gp">&gt;&gt;&gt; </span> <span class="n">grad_input</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
  249. <span class="gp">&gt;&gt;&gt; </span> <span class="n">grad_input</span><span class="o">.</span><span class="n">index_add_</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">g1</span><span class="p">)</span>
  250. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">grad_input</span>
  251. </pre></div>
  252. </div>
  253. </dd>
  254. </dl>
  255. </dd></dl>
  256. <dl class="py method">
  257. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.save_for_backward">
  258. <span class="sig-name descname"><span class="pre">save_for_backward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">tensors</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.save_for_backward" title="Permalink to this definition">¶</a></dt>
  259. <dd><p>Saves given tensors for a future call to <code class="xref py py-func docutils literal notranslate"><span class="pre">backward()</span></code>.</p>
  260. <p><code class="docutils literal notranslate"><span class="pre">save_for_backward</span></code> should be called at most once, only from inside the
  261. <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a> method, and only with tensors.</p>
  262. <p>All tensors intended to be used in the backward pass should be saved
  263. with <code class="docutils literal notranslate"><span class="pre">save_for_backward</span></code> (as opposed to directly on <code class="docutils literal notranslate"><span class="pre">ctx</span></code>) to prevent
  264. incorrect gradients and memory leaks, and enable the application of saved
  265. tensor hooks. See <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.autograd.graph.saved_tensors_hooks</span></code>.</p>
  266. <p>In <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward"><code class="xref py py-func docutils literal notranslate"><span class="pre">backward()</span></code></a>, saved tensors can be accessed through the <code class="xref py py-attr docutils literal notranslate"><span class="pre">saved_tensors</span></code>
  267. attribute. Before returning them to the user, a check is made to ensure
  268. they weren’t used in any in-place operation that modified their content.</p>
  269. <p>Arguments can also be <code class="docutils literal notranslate"><span class="pre">None</span></code>. This is a no-op.</p>
  270. <p>See <span class="xref std std-ref">extending-autograd</span> for more details on how to use this method.</p>
  271. <dl>
  272. <dt>Example::</dt><dd><div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="k">class</span> <span class="nc">Func</span><span class="p">(</span><span class="n">Function</span><span class="p">):</span>
  273. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  274. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  275. <span class="gp">&gt;&gt;&gt; </span> <span class="n">w</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span> <span class="o">*</span> <span class="n">z</span>
  276. <span class="gp">&gt;&gt;&gt; </span> <span class="n">out</span> <span class="o">=</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span> <span class="o">+</span> <span class="n">y</span> <span class="o">*</span> <span class="n">z</span> <span class="o">+</span> <span class="n">w</span>
  277. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">out</span><span class="p">)</span>
  278. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">z</span> <span class="o">=</span> <span class="n">z</span> <span class="c1"># z is not a tensor</span>
  279. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">out</span>
  280. <span class="go">&gt;&gt;&gt;</span>
  281. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  282. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_out</span><span class="p">):</span>
  283. <span class="gp">&gt;&gt;&gt; </span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
  284. <span class="gp">&gt;&gt;&gt; </span> <span class="n">z</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">z</span>
  285. <span class="gp">&gt;&gt;&gt; </span> <span class="n">gx</span> <span class="o">=</span> <span class="n">grad_out</span> <span class="o">*</span> <span class="p">(</span><span class="n">y</span> <span class="o">+</span> <span class="n">y</span> <span class="o">*</span> <span class="n">z</span><span class="p">)</span>
  286. <span class="gp">&gt;&gt;&gt; </span> <span class="n">gy</span> <span class="o">=</span> <span class="n">grad_out</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">+</span> <span class="n">z</span> <span class="o">+</span> <span class="n">x</span> <span class="o">*</span> <span class="n">z</span><span class="p">)</span>
  287. <span class="gp">&gt;&gt;&gt; </span> <span class="n">gz</span> <span class="o">=</span> <span class="kc">None</span>
  288. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">gx</span><span class="p">,</span> <span class="n">gy</span><span class="p">,</span> <span class="n">gz</span>
  289. <span class="go">&gt;&gt;&gt;</span>
  290. <span class="gp">&gt;&gt;&gt; </span><span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">double</span><span class="p">)</span>
  291. <span class="gp">&gt;&gt;&gt; </span><span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">2.</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">double</span><span class="p">)</span>
  292. <span class="gp">&gt;&gt;&gt; </span><span class="n">c</span> <span class="o">=</span> <span class="mi">4</span>
  293. <span class="gp">&gt;&gt;&gt; </span><span class="n">d</span> <span class="o">=</span> <span class="n">Func</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
  294. </pre></div>
  295. </div>
  296. </dd>
  297. </dl>
  298. </dd></dl>
  299. <dl class="py method">
  300. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.save_for_forward">
  301. <span class="sig-name descname"><span class="pre">save_for_forward</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">tensors</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><span class="pre">Tensor</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.save_for_forward" title="Permalink to this definition">¶</a></dt>
  302. <dd><p>Saves given tensors for a future call to <code class="xref py py-func docutils literal notranslate"><span class="pre">jvp()</span></code>.</p>
  303. <p><code class="docutils literal notranslate"><span class="pre">save_for_forward</span></code> should be called only once, from inside the <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a>
  304. method, and only be called with tensors.</p>
  305. <p>In <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.jvp" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.jvp"><code class="xref py py-func docutils literal notranslate"><span class="pre">jvp()</span></code></a>, saved objects can be accessed through the <code class="xref py py-attr docutils literal notranslate"><span class="pre">saved_tensors</span></code>
  306. attribute.</p>
  307. <p>Arguments can also be <code class="docutils literal notranslate"><span class="pre">None</span></code>. This is a no-op.</p>
  308. <p>See <span class="xref std std-ref">extending-autograd</span> for more details on how to use this method.</p>
  309. <dl>
  310. <dt>Example::</dt><dd><div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="k">class</span> <span class="nc">Func</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
  311. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  312. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">z</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
  313. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
  314. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_forward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
  315. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">z</span> <span class="o">=</span> <span class="n">z</span>
  316. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y</span> <span class="o">*</span> <span class="n">z</span>
  317. <span class="go">&gt;&gt;&gt;</span>
  318. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  319. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">jvp</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x_t</span><span class="p">,</span> <span class="n">y_t</span><span class="p">,</span> <span class="n">_</span><span class="p">):</span>
  320. <span class="gp">&gt;&gt;&gt; </span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
  321. <span class="gp">&gt;&gt;&gt; </span> <span class="n">z</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">z</span>
  322. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">z</span> <span class="o">*</span> <span class="p">(</span><span class="n">y</span> <span class="o">*</span> <span class="n">x_t</span> <span class="o">+</span> <span class="n">x</span> <span class="o">*</span> <span class="n">y_t</span><span class="p">)</span>
  323. <span class="go">&gt;&gt;&gt;</span>
  324. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  325. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">vjp</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_out</span><span class="p">):</span>
  326. <span class="gp">&gt;&gt;&gt; </span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
  327. <span class="gp">&gt;&gt;&gt; </span> <span class="n">z</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">z</span>
  328. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">z</span> <span class="o">*</span> <span class="n">grad_out</span> <span class="o">*</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span> <span class="o">*</span> <span class="n">grad_out</span> <span class="o">*</span> <span class="n">x</span><span class="p">,</span> <span class="kc">None</span>
  329. <span class="go">&gt;&gt;&gt;</span>
  330. <span class="gp">&gt;&gt;&gt; </span> <span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">double</span><span class="p">)</span>
  331. <span class="gp">&gt;&gt;&gt; </span> <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">double</span><span class="p">)</span>
  332. <span class="gp">&gt;&gt;&gt; </span> <span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">2.</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">double</span><span class="p">)</span>
  333. <span class="gp">&gt;&gt;&gt; </span> <span class="n">c</span> <span class="o">=</span> <span class="mi">4</span>
  334. <span class="go">&gt;&gt;&gt;</span>
  335. <span class="gp">&gt;&gt;&gt; </span> <span class="k">with</span> <span class="n">fwAD</span><span class="o">.</span><span class="n">dual_level</span><span class="p">():</span>
  336. <span class="gp">&gt;&gt;&gt; </span> <span class="n">a_dual</span> <span class="o">=</span> <span class="n">fwAD</span><span class="o">.</span><span class="n">make_dual</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
  337. <span class="gp">&gt;&gt;&gt; </span> <span class="n">d</span> <span class="o">=</span> <span class="n">Func</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">a_dual</span><span class="p">,</span> <span class="n">b</span><span class="p">,</span> <span class="n">c</span><span class="p">)</span>
  338. </pre></div>
  339. </div>
  340. </dd>
  341. </dl>
  342. </dd></dl>
  343. <dl class="py method">
  344. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.set_materialize_grads">
  345. <span class="sig-name descname"><span class="pre">set_materialize_grads</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">value</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.10)"><span class="pre">bool</span></a></span></em><span class="sig-paren">)</span><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.set_materialize_grads" title="Permalink to this definition">¶</a></dt>
  346. <dd><p>Sets whether to materialize output grad tensors. Default is <code class="docutils literal notranslate"><span class="pre">True</span></code>.</p>
  347. <p><strong>This should be called only from inside the</strong> <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a> <strong>method</strong></p>
  348. <p>If <code class="docutils literal notranslate"><span class="pre">True</span></code>, undefined output grad tensors will be expanded to tensors full
  349. of zeros prior to calling the <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward"><code class="xref py py-func docutils literal notranslate"><span class="pre">backward()</span></code></a> method.</p>
  350. <dl>
  351. <dt>Example::</dt><dd><div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="k">class</span> <span class="nc">SimpleFunc</span><span class="p">(</span><span class="n">Function</span><span class="p">):</span>
  352. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  353. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
  354. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
  355. <span class="go">&gt;&gt;&gt;</span>
  356. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  357. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@once_differentiable</span>
  358. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">g1</span><span class="p">,</span> <span class="n">g2</span><span class="p">):</span>
  359. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">g1</span> <span class="o">+</span> <span class="n">g2</span> <span class="c1"># No check for None necessary</span>
  360. <span class="go">&gt;&gt;&gt;</span>
  361. <span class="gp">&gt;&gt;&gt; </span><span class="c1"># We modify SimpleFunc to handle non-materialized grad outputs</span>
  362. <span class="gp">&gt;&gt;&gt; </span><span class="k">class</span> <span class="nc">Func</span><span class="p">(</span><span class="n">Function</span><span class="p">):</span>
  363. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  364. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
  365. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">set_materialize_grads</span><span class="p">(</span><span class="kc">False</span><span class="p">)</span>
  366. <span class="gp">&gt;&gt;&gt; </span> <span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
  367. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">x</span><span class="o">.</span><span class="n">clone</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
  368. <span class="go">&gt;&gt;&gt;</span>
  369. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@staticmethod</span>
  370. <span class="gp">&gt;&gt;&gt; </span> <span class="nd">@once_differentiable</span>
  371. <span class="gp">&gt;&gt;&gt; </span> <span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">g1</span><span class="p">,</span> <span class="n">g2</span><span class="p">):</span>
  372. <span class="gp">&gt;&gt;&gt; </span> <span class="n">x</span><span class="p">,</span> <span class="o">=</span> <span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span>
  373. <span class="gp">&gt;&gt;&gt; </span> <span class="n">grad_input</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
  374. <span class="gp">&gt;&gt;&gt; </span> <span class="k">if</span> <span class="n">g1</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># We must check for None now</span>
  375. <span class="gp">&gt;&gt;&gt; </span> <span class="n">grad_input</span> <span class="o">+=</span> <span class="n">g1</span>
  376. <span class="gp">&gt;&gt;&gt; </span> <span class="k">if</span> <span class="n">g2</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
  377. <span class="gp">&gt;&gt;&gt; </span> <span class="n">grad_input</span> <span class="o">+=</span> <span class="n">g2</span>
  378. <span class="gp">&gt;&gt;&gt; </span> <span class="k">return</span> <span class="n">grad_input</span>
  379. <span class="go">&gt;&gt;&gt;</span>
  380. <span class="gp">&gt;&gt;&gt; </span><span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="mf">1.</span><span class="p">,</span> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
  381. <span class="gp">&gt;&gt;&gt; </span><span class="n">b</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">Func</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="c1"># induces g2 to be undefined</span>
  382. </pre></div>
  383. </div>
  384. </dd>
  385. </dl>
  386. </dd></dl>
  387. <dl class="py method">
  388. <dt class="sig sig-object py" id="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.vjp">
  389. <em class="property"><span class="pre">static</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">vjp</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">ctx</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.10)"><span class="pre">Any</span></a></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">grad_outputs</span></span><span class="p"><span class="pre">:</span></span><span class="w"> </span><span class="n"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.10)"><span class="pre">Any</span></a></span></em><span class="sig-paren">)</span> <span class="sig-return"><span class="sig-return-icon">&#x2192;</span> <span class="sig-return-typehint"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Any" title="(in Python v3.10)"><span class="pre">Any</span></a></span></span><a class="headerlink" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.vjp" title="Permalink to this definition">¶</a></dt>
  390. <dd><p>Defines a formula for differentiating the operation with backward mode
  391. automatic differentiation (alias to the vjp function).</p>
  392. <p>This function is to be overridden by all subclasses.</p>
  393. <p>It must accept a context <code class="xref py py-attr docutils literal notranslate"><span class="pre">ctx</span></code> as the first argument, followed by
  394. as many outputs as the <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a> returned (None will be passed in
  395. for non-tensor outputs of the forward function),
  396. and it should return as many tensors as there were inputs to
  397. <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a>. Each argument is the gradient w.r.t the given output,
  398. and each returned value should be the gradient w.r.t. the
  399. corresponding input. If an input is not a Tensor or is a Tensor not
  400. requiring grads, you can just pass None as a gradient for that input.</p>
  401. <p>The context can be used to retrieve tensors saved during the forward
  402. pass. It also has an attribute <code class="xref py py-attr docutils literal notranslate"><span class="pre">ctx.needs_input_grad</span></code> as a tuple
  403. of booleans representing whether each input needs gradient. E.g.,
  404. <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.backward"><code class="xref py py-func docutils literal notranslate"><span class="pre">backward()</span></code></a> will have <code class="docutils literal notranslate"><span class="pre">ctx.needs_input_grad[0]</span> <span class="pre">=</span> <span class="pre">True</span></code> if the
  405. first input to <a class="reference internal" href="#draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward" title="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.forward"><code class="xref py py-func docutils literal notranslate"><span class="pre">forward()</span></code></a> needs gradient computed w.r.t. the
  406. output.</p>
  407. </dd></dl>
  408. </dd></dl>
  409. </section>
  410. </div>
  411. </div>
  412. </div>
  413. <div class="sphinxsidebar" role="navigation" aria-label="main navigation">
  414. <div class="sphinxsidebarwrapper">
  415. <p class="logo"><a href="../index.html">
  416. <img class="logo" src="../_static/valknut.svg" alt="Logo"/>
  417. </a></p>
  418. <h1 class="logo"><a href="../index.html">Draugr</a></h1>
  419. <h3>Navigation</h3>
  420. <ul class="current">
  421. <li class="toctree-l1 current"><a class="reference internal" href="draugr.html">draugr</a><ul class="current">
  422. <li class="toctree-l2"><a class="reference internal" href="draugr.dist_is_editable.html">draugr.dist_is_editable</a></li>
  423. <li class="toctree-l2"><a class="reference internal" href="draugr.get_version.html">draugr.get_version</a></li>
  424. <li class="toctree-l2"><a class="reference internal" href="draugr.dlib_utilities.html">draugr.dlib_utilities</a></li>
  425. <li class="toctree-l2"><a class="reference internal" href="draugr.drawers.html">draugr.drawers</a></li>
  426. <li class="toctree-l2"><a class="reference internal" href="draugr.entry_points.html">draugr.entry_points</a></li>
  427. <li class="toctree-l2"><a class="reference internal" href="draugr.extensions.html">draugr.extensions</a></li>
  428. <li class="toctree-l2"><a class="reference internal" href="draugr.ffmpeg_utilities.html">draugr.ffmpeg_utilities</a></li>
  429. <li class="toctree-l2"><a class="reference internal" href="draugr.jax_utilities.html">draugr.jax_utilities</a></li>
  430. <li class="toctree-l2"><a class="reference internal" href="draugr.metrics.html">draugr.metrics</a></li>
  431. <li class="toctree-l2"><a class="reference internal" href="draugr.multiprocessing_utilities.html">draugr.multiprocessing_utilities</a></li>
  432. <li class="toctree-l2"><a class="reference internal" href="draugr.numpy_utilities.html">draugr.numpy_utilities</a></li>
  433. <li class="toctree-l2"><a class="reference internal" href="draugr.opencv_utilities.html">draugr.opencv_utilities</a></li>
  434. <li class="toctree-l2"><a class="reference internal" href="draugr.os_utilities.html">draugr.os_utilities</a></li>
  435. <li class="toctree-l2"><a class="reference internal" href="draugr.pandas_utilities.html">draugr.pandas_utilities</a></li>
  436. <li class="toctree-l2"><a class="reference internal" href="draugr.pygame_utilities.html">draugr.pygame_utilities</a></li>
  437. <li class="toctree-l2"><a class="reference internal" href="draugr.python_utilities.html">draugr.python_utilities</a></li>
  438. <li class="toctree-l2"><a class="reference internal" href="draugr.random_utilities.html">draugr.random_utilities</a></li>
  439. <li class="toctree-l2"><a class="reference internal" href="draugr.scipy_utilities.html">draugr.scipy_utilities</a></li>
  440. <li class="toctree-l2"><a class="reference internal" href="draugr.stopping.html">draugr.stopping</a></li>
  441. <li class="toctree-l2"><a class="reference internal" href="draugr.tensorboard_utilities.html">draugr.tensorboard_utilities</a></li>
  442. <li class="toctree-l2"><a class="reference internal" href="draugr.threading_utilities.html">draugr.threading_utilities</a></li>
  443. <li class="toctree-l2 current"><a class="reference internal" href="draugr.torch_utilities.html">draugr.torch_utilities</a><ul class="current">
  444. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.architectures.html">draugr.torch_utilities.architectures</a></li>
  445. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.datasets.html">draugr.torch_utilities.datasets</a></li>
  446. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.distributions.html">draugr.torch_utilities.distributions</a></li>
  447. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.evaluation.html">draugr.torch_utilities.evaluation</a></li>
  448. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.exporting.html">draugr.torch_utilities.exporting</a></li>
  449. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.generators.html">draugr.torch_utilities.generators</a></li>
  450. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.images.html">draugr.torch_utilities.images</a></li>
  451. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.opencv.html">draugr.torch_utilities.opencv</a></li>
  452. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.operations.html">draugr.torch_utilities.operations</a></li>
  453. <li class="toctree-l3 current"><a class="reference internal" href="draugr.torch_utilities.optimisation.html">draugr.torch_utilities.optimisation</a><ul class="current">
  454. <li class="toctree-l4 current"><a class="reference internal" href="draugr.torch_utilities.optimisation.debugging.html">draugr.torch_utilities.optimisation.debugging</a><ul class="current">
  455. <li class="toctree-l5 current"><a class="reference internal" href="draugr.torch_utilities.optimisation.debugging.gradients.html">draugr.torch_utilities.optimisation.debugging.gradients</a><ul class="current">
  456. <li class="toctree-l6"><a class="reference internal" href="draugr.torch_utilities.optimisation.debugging.gradients.checking.html">draugr.torch_utilities.optimisation.debugging.gradients.checking</a></li>
  457. <li class="toctree-l6"><a class="reference internal" href="draugr.torch_utilities.optimisation.debugging.gradients.flow.html">draugr.torch_utilities.optimisation.debugging.gradients.flow</a></li>
  458. <li class="toctree-l6"><a class="reference internal" href="draugr.torch_utilities.optimisation.debugging.gradients.grad_trace.html">draugr.torch_utilities.optimisation.debugging.gradients.grad_trace</a></li>
  459. <li class="toctree-l6 current"><a class="reference internal" href="draugr.torch_utilities.optimisation.debugging.gradients.guided.html">draugr.torch_utilities.optimisation.debugging.gradients.guided</a><ul class="current">
  460. <li class="toctree-l7"><a class="reference internal" href="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLUModel.html">draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLUModel</a></li>
  461. <li class="toctree-l7 current"><a class="current reference internal" href="#">draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU</a></li>
  462. </ul>
  463. </li>
  464. </ul>
  465. </li>
  466. <li class="toctree-l5"><a class="reference internal" href="draugr.torch_utilities.optimisation.debugging.layer_fetching.html">draugr.torch_utilities.optimisation.debugging.layer_fetching</a></li>
  467. <li class="toctree-l5"><a class="reference internal" href="draugr.torch_utilities.optimisation.debugging.opt_verification.html">draugr.torch_utilities.optimisation.debugging.opt_verification</a></li>
  468. </ul>
  469. </li>
  470. <li class="toctree-l4"><a class="reference internal" href="draugr.torch_utilities.optimisation.parameters.html">draugr.torch_utilities.optimisation.parameters</a></li>
  471. <li class="toctree-l4"><a class="reference internal" href="draugr.torch_utilities.optimisation.scheduling.html">draugr.torch_utilities.optimisation.scheduling</a></li>
  472. <li class="toctree-l4"><a class="reference internal" href="draugr.torch_utilities.optimisation.stopping.html">draugr.torch_utilities.optimisation.stopping</a></li>
  473. <li class="toctree-l4"><a class="reference internal" href="draugr.torch_utilities.optimisation.updates.html">draugr.torch_utilities.optimisation.updates</a></li>
  474. </ul>
  475. </li>
  476. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.persistence.html">draugr.torch_utilities.persistence</a></li>
  477. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.sessions.html">draugr.torch_utilities.sessions</a></li>
  478. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.system.html">draugr.torch_utilities.system</a></li>
  479. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.tensors.html">draugr.torch_utilities.tensors</a></li>
  480. <li class="toctree-l3"><a class="reference internal" href="draugr.torch_utilities.writers.html">draugr.torch_utilities.writers</a></li>
  481. </ul>
  482. </li>
  483. <li class="toctree-l2"><a class="reference internal" href="draugr.tqdm_utilities.html">draugr.tqdm_utilities</a></li>
  484. <li class="toctree-l2"><a class="reference internal" href="draugr.visualisation.html">draugr.visualisation</a></li>
  485. <li class="toctree-l2"><a class="reference internal" href="draugr.writers.html">draugr.writers</a></li>
  486. </ul>
  487. </li>
  488. </ul>
  489. <p class="caption" role="heading"><span class="caption-text">Notes</span></p>
  490. <ul>
  491. <li class="toctree-l1"><a class="reference internal" href="../getting_started.html">Getting Started</a></li>
  492. </ul>
  493. <div class="relations">
  494. <h3>Related Topics</h3>
  495. <ul>
  496. <li><a href="../index.html">Documentation overview</a><ul>
  497. <li><a href="draugr.html">draugr</a><ul>
  498. <li><a href="draugr.torch_utilities.html">draugr.torch_utilities</a><ul>
  499. <li><a href="draugr.torch_utilities.optimisation.html">draugr.torch_utilities.optimisation</a><ul>
  500. <li><a href="draugr.torch_utilities.optimisation.debugging.html">draugr.torch_utilities.optimisation.debugging</a><ul>
  501. <li><a href="draugr.torch_utilities.optimisation.debugging.gradients.html">draugr.torch_utilities.optimisation.debugging.gradients</a><ul>
  502. <li><a href="draugr.torch_utilities.optimisation.debugging.gradients.guided.html">draugr.torch_utilities.optimisation.debugging.gradients.guided</a><ul>
  503. <li>Previous: <a href="draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLUModel.html" title="previous chapter">draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLUModel</a></li>
  504. <li>Next: <a href="draugr.torch_utilities.optimisation.debugging.layer_fetching.html" title="next chapter">draugr.torch_utilities.optimisation.debugging.layer_fetching</a></li>
  505. </ul></li>
  506. </ul></li>
  507. </ul></li>
  508. </ul></li>
  509. </ul></li>
  510. </ul></li>
  511. </ul></li>
  512. </ul>
  513. </div>
  514. <div id="searchbox" style="display: none" role="search">
  515. <h3 id="searchlabel">Quick search</h3>
  516. <div class="searchformwrapper">
  517. <form class="search" action="../search.html" method="get">
  518. <input type="text" name="q" aria-labelledby="searchlabel" autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false">
  519. <input type="submit" value="Go" />
  520. </form>
  521. </div>
  522. </div>
  523. <script>document.getElementById('searchbox').style.display = "block"</script>
  524. </div>
  525. </div>
  526. <div class="clearer"></div>
  527. </div>
  528. <div class="footer">
  529. &copy;.
  530. |
  531. Powered by <a href="https://www.sphinx-doc.org/">Sphinx 5.0.2</a>
  532. &amp; <a href="https://github.com/bitprophet/alabaster">Alabaster 0.7.12</a>
  533. |
  534. <a href="../_sources/generated/draugr.torch_utilities.optimisation.debugging.gradients.guided.GuidedBackPropReLU.rst.txt"
  535. rel="nofollow">Page source</a>
  536. </div>
  537. </body>
  538. </html>