# mandel.py
# Mandelbrot example

# Standard library
import io

try:
    from cStringIO import StringIO  # Python 2 only
except ImportError:  # cStringIO does not exist on Python 3
    from io import StringIO

# Import libraries for simulation
import tensorflow as tf
import numpy as np

# Imports for visualization
import PIL.Image
from IPython.display import clear_output, Image, display
import scipy.ndimage as nd
  10. def DisplayFractal(a, fmt='jpeg'):
  11. """Display an array of iteration counts as a
  12. colorful picture of a fractal."""
  13. a_cyclic = (6.28*a/20.0).reshape(list(a.shape)+[1])
  14. # orig parameters
  15. # img = np.concatenate([10+20*np.cos(a_cyclic),
  16. # 30+50*np.sin(a_cyclic),
  17. # 155-80*np.cos(a_cyclic)], 2)
  18. img = np.concatenate([95+25*np.cos(a_cyclic+1.21),
  19. 20+155*np.sin(a_cyclic),
  20. #155-80*np.cos(a_cyclic)], 2)
  21. 155-8*np.cos(a_cyclic)], 2)
  22. img[a==a.max()] = 0
  23. a = img
  24. a = np.uint8(np.clip(a, 0, 255))
  25. f = StringIO()
  26. image = PIL.Image.fromarray(a)
  27. image.show()
  28. image.save(f, fmt)
  29. display(Image(data=f.getvalue()))
  30. sess = tf.InteractiveSession()
  31. # Use NumPy to create a 2D array of complex numbers on [-2,2]x[-2,2]
  32. Y, X = np.mgrid[-1.3:1.3:0.005, -2:1:0.005]
  33. Z = X+1j*Y
  34. #Now we define and initialize TensorFlow tensors.
  35. xs = tf.constant(Z.astype("complex64"))
  36. zs = tf.Variable(xs)
  37. ns = tf.Variable(tf.zeros_like(xs, "float32"))
  38. tf.initialize_all_variables().run()
  39. # Compute the new values of z: z^2 + x
  40. zs_ = zs*zs + xs
  41. # Have we diverged with this new value?
  42. not_diverged = tf.complex_abs(zs_) < 4
  43. # Operation to update the zs and the iteration count.
  44. #
  45. # Note: We keep computing zs after they diverge! This
  46. # is very wasteful! There are better, if a little
  47. # less simple, ways to do this.
  48. #
  49. step = tf.group(
  50. zs.assign(zs_),
  51. ns.assign_add(tf.cast(not_diverged, "float32"))
  52. )
  53. for i in range(200): step.run()
  54. # Let's see what we've got.
  55. DisplayFractal(ns.eval())