plotting_utils.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # BSD 3-Clause License
  2. # Copyright (c) 2018-2020, NVIDIA Corporation
  3. # All rights reserved.
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright notice, this
  7. # list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright notice,
  9. # this list of conditions and the following disclaimer in the documentation
  10. # and/or other materials provided with the distribution.
  11. # * Neither the name of the copyright holder nor the names of its
  12. # contributors may be used to endorse or promote products derived from
  13. # this software without specific prior written permission.
  14. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  15. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  16. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  17. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  18. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  19. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  20. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  21. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  22. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  23. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  24. """https://github.com/NVIDIA/tacotron2"""
  25. import matplotlib
  26. matplotlib.use("Agg")
  27. import matplotlib.pylab as plt
  28. import numpy as np
  29. def save_figure_to_numpy(fig):
  30. # save it to a numpy array.
  31. data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
  32. data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
  33. return data
  34. def plot_alignment_to_numpy(alignment, info=None):
  35. fig, ax = plt.subplots(figsize=(6, 4))
  36. im = ax.imshow(alignment, aspect='auto', origin='lower',
  37. interpolation='none')
  38. fig.colorbar(im, ax=ax)
  39. xlabel = 'Decoder timestep'
  40. if info is not None:
  41. xlabel += '\n\n' + info
  42. plt.xlabel(xlabel)
  43. plt.ylabel('Encoder timestep')
  44. plt.tight_layout()
  45. fig.canvas.draw()
  46. data = save_figure_to_numpy(fig)
  47. plt.close()
  48. return data
  49. def plot_spectrogram_to_numpy(spectrogram):
  50. fig, ax = plt.subplots(figsize=(12, 3))
  51. im = ax.imshow(spectrogram, aspect="auto", origin="lower",
  52. interpolation='none')
  53. plt.colorbar(im, ax=ax)
  54. plt.xlabel("Frames")
  55. plt.ylabel("Channels")
  56. plt.tight_layout()
  57. fig.canvas.draw()
  58. data = save_figure_to_numpy(fig)
  59. plt.close()
  60. return data
  61. def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
  62. fig, ax = plt.subplots(figsize=(12, 3))
  63. ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5,
  64. color='green', marker='+', s=1, label='target')
  65. ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5,
  66. color='red', marker='.', s=1, label='predicted')
  67. plt.xlabel("Frames (Green target, Red predicted)")
  68. plt.ylabel("Gate State")
  69. plt.tight_layout()
  70. fig.canvas.draw()
  71. data = save_figure_to_numpy(fig)
  72. plt.close()
  73. return data