predict.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. """
  2. Prediction script used to visualize model output
  3. """
  4. import os
  5. import hydra
  6. from omegaconf import DictConfig
  7. from data_generators import tf_data_generator
  8. from utils.general_utils import join_paths, suppress_warnings
  9. from utils.images_utils import display
  10. from utils.images_utils import postprocess_mask, denormalize_mask
  11. from models.model import prepare_model
  12. def predict(cfg: DictConfig):
  13. """
  14. Predict and visualize given data
  15. """
  16. # suppress TensorFlow and DALI warnings
  17. suppress_warnings()
  18. # set batch size to one
  19. cfg.HYPER_PARAMETERS.BATCH_SIZE = 1
  20. # data generator
  21. val_generator = tf_data_generator.DataGenerator(cfg, mode="VAL")
  22. # create model
  23. model = prepare_model(cfg)
  24. # weights model path
  25. checkpoint_path = join_paths(
  26. cfg.WORK_DIR,
  27. cfg.CALLBACKS.MODEL_CHECKPOINT.PATH,
  28. f"{cfg.MODEL.WEIGHTS_FILE_NAME}.hdf5"
  29. )
  30. assert os.path.exists(checkpoint_path), \
  31. f"Model weight's file does not exist at \n{checkpoint_path}"
  32. # load model weights
  33. model.load_weights(checkpoint_path, by_name=True, skip_mismatch=True)
  34. # model.summary()
  35. # check mask are available or not
  36. mask_available = True
  37. if cfg.DATASET.VAL.MASK_PATH is None or \
  38. str(cfg.DATASET.VAL.MASK_PATH).lower() == "none":
  39. mask_available = False
  40. showed_images = 0
  41. for batch_data in val_generator: # for each batch
  42. batch_images = batch_data[0]
  43. if mask_available:
  44. batch_mask = batch_data[1]
  45. # make prediction on batch
  46. batch_predictions = model.predict_on_batch(batch_images)
  47. if len(model.outputs) > 1:
  48. batch_predictions = batch_predictions[0]
  49. for index in range(len(batch_images)):
  50. image = batch_images[index] # for each image
  51. if cfg.SHOW_CENTER_CHANNEL_IMAGE:
  52. # for UNet3+ show only center channel as image
  53. image = image[:, :, 1]
  54. # do postprocessing on predicted mask
  55. prediction = batch_predictions[index]
  56. prediction = postprocess_mask(prediction, cfg.OUTPUT.CLASSES)
  57. # denormalize mask for better visualization
  58. prediction = denormalize_mask(prediction, cfg.OUTPUT.CLASSES)
  59. if mask_available:
  60. mask = batch_mask[index]
  61. mask = postprocess_mask(mask, cfg.OUTPUT.CLASSES)
  62. mask = denormalize_mask(mask, cfg.OUTPUT.CLASSES)
  63. # if np.unique(mask).shape[0] == 2:
  64. if mask_available:
  65. display([image, mask, prediction], show_true_mask=True)
  66. else:
  67. display([image, prediction], show_true_mask=False)
  68. showed_images += 1
  69. # stop after displaying below number of images
  70. # if showed_images >= 10: break
  71. @hydra.main(version_base=None, config_path="configs", config_name="config")
  72. def main(cfg: DictConfig):
  73. """
  74. Read config file and pass to prediction method
  75. """
  76. predict(cfg)
  77. if __name__ == "__main__":
  78. main()