images_utils.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. """
  2. Utility functions for image processing
  3. """
  4. import numpy as np
  5. import cv2
  6. from omegaconf import DictConfig
  7. import matplotlib.pyplot as plt
  8. def read_image(img_path, color_mode):
  9. """
  10. Read and return image as np array from given path.
  11. In case of color image, it returns image in BGR mode.
  12. """
  13. return cv2.imread(img_path, color_mode)
  14. def resize_image(img, height, width, resize_method=cv2.INTER_CUBIC):
  15. """
  16. Resize image
  17. """
  18. return cv2.resize(img, dsize=(width, height), interpolation=resize_method)
  19. def prepare_image(path: str, resize: DictConfig, normalize_type: str):
  20. """
  21. Prepare image for model.
  22. read image --> resize --> normalize --> return as float32
  23. """
  24. image = read_image(path, cv2.IMREAD_COLOR)
  25. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  26. if resize.VALUE:
  27. # TODO verify image resizing method
  28. image = resize_image(image, resize.HEIGHT, resize.WIDTH, cv2.INTER_AREA)
  29. if normalize_type == "normalize":
  30. image = image / 255.0
  31. image = image.astype(np.float32)
  32. return image
  33. def prepare_mask(path: str, resize: dict, normalize_mask: dict):
  34. """
  35. Prepare mask for model.
  36. read mask --> resize --> normalize --> return as int32
  37. """
  38. mask = read_image(path, cv2.IMREAD_GRAYSCALE)
  39. if resize.VALUE:
  40. mask = resize_image(mask, resize.HEIGHT, resize.WIDTH, cv2.INTER_NEAREST)
  41. if normalize_mask.VALUE:
  42. mask = mask / normalize_mask.NORMALIZE_VALUE
  43. mask = mask.astype(np.int32)
  44. return mask
  45. def image_to_mask_name(image_name: str):
  46. """
  47. Convert image file name to it's corresponding mask file name e.g.
  48. image name --> mask name
  49. image_28_0.png mask_28_0.png
  50. replace image with mask
  51. """
  52. return image_name.replace('image', 'mask')
  53. def postprocess_mask(mask, classes, output_type=np.int32):
  54. """
  55. Post process model output.
  56. Covert probabilities into indexes based on maximum value.
  57. """
  58. if classes == 1:
  59. mask = np.where(mask > .5, 1.0, 0.0)
  60. else:
  61. mask = np.argmax(mask, axis=-1)
  62. return mask.astype(output_type)
  63. def denormalize_mask(mask, classes):
  64. """
  65. Denormalize mask by multiplying each class with higher
  66. integer (255 / classes) for better visualization.
  67. """
  68. mask = mask * (255 / classes)
  69. return mask.astype(np.int32)
  70. def display(display_list, show_true_mask=False):
  71. """
  72. Show list of images. it could be
  73. either [image, true_mask, predicted_mask] or [image, predicted_mask].
  74. Set show_true_mask to True if true mask is available or vice versa
  75. """
  76. if show_true_mask:
  77. title_list = ('Input Image', 'True Mask', 'Predicted Mask')
  78. plt.figure(figsize=(12, 4))
  79. else:
  80. title_list = ('Input Image', 'Predicted Mask')
  81. plt.figure(figsize=(8, 4))
  82. for i in range(len(display_list)):
  83. plt.subplot(1, len(display_list), i + 1)
  84. if title_list is not None:
  85. plt.title(title_list[i])
  86. if len(np.squeeze(display_list[i]).shape) == 2:
  87. plt.imshow(np.squeeze(display_list[i]), cmap='gray')
  88. plt.axis('on')
  89. else:
  90. plt.imshow(np.squeeze(display_list[i]))
  91. plt.axis('on')
  92. plt.show()