model.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """
  2. Returns Unet3+ model
  3. """
  4. import tensorflow as tf
  5. from omegaconf import DictConfig
  6. from .backbones import vgg16_backbone, vgg19_backbone, unet3plus_backbone
  7. from .unet3plus import unet3plus
  8. from .unet3plus_deep_supervision import unet3plus_deepsup
  9. from .unet3plus_deep_supervision_cgm import unet3plus_deepsup_cgm
  10. def prepare_model(cfg: DictConfig, training=False):
  11. """
  12. Creates and return model object based on given model type.
  13. """
  14. input_shape = [cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH, cfg.INPUT.CHANNELS]
  15. input_layer = tf.keras.layers.Input(
  16. shape=input_shape,
  17. name="input_layer"
  18. ) # 320*320*3
  19. filters = [64, 128, 256, 512, 1024]
  20. # create backbone
  21. if cfg.MODEL.BACKBONE.TYPE == "unet3plus":
  22. backbone_layers = unet3plus_backbone(
  23. input_layer,
  24. filters
  25. )
  26. elif cfg.MODEL.BACKBONE.TYPE == "vgg16":
  27. backbone_layers = vgg16_backbone(input_layer, )
  28. elif cfg.MODEL.BACKBONE.TYPE == "vgg19":
  29. backbone_layers = vgg19_backbone(input_layer, )
  30. else:
  31. raise ValueError(
  32. "Wrong backbone type passed."
  33. "\nPlease check config file for possible options."
  34. )
  35. print(f"Using {cfg.MODEL.BACKBONE.TYPE} as a backbone.")
  36. if cfg.MODEL.TYPE == "unet3plus":
  37. # training parameter does not matter in this case
  38. outputs, model_name = unet3plus(
  39. backbone_layers,
  40. cfg.OUTPUT.CLASSES,
  41. filters
  42. )
  43. elif cfg.MODEL.TYPE == "unet3plus_deepsup":
  44. outputs, model_name = unet3plus_deepsup(
  45. backbone_layers,
  46. cfg.OUTPUT.CLASSES,
  47. filters,
  48. training
  49. )
  50. elif cfg.MODEL.TYPE == "unet3plus_deepsup_cgm":
  51. if cfg.OUTPUT.CLASSES != 1:
  52. raise ValueError(
  53. "UNet3+ with Deep Supervision and Classification Guided Module"
  54. "\nOnly works when model output classes are equal to 1"
  55. )
  56. outputs, model_name = unet3plus_deepsup_cgm(
  57. backbone_layers,
  58. cfg.OUTPUT.CLASSES,
  59. filters,
  60. training
  61. )
  62. else:
  63. raise ValueError(
  64. "Wrong model type passed."
  65. "\nPlease check config file for possible options."
  66. )
  67. return tf.keras.Model(
  68. inputs=input_layer,
  69. outputs=outputs,
  70. name=model_name
  71. )
  72. if __name__ == "__main__":
  73. """## Test model Compilation,"""
  74. from omegaconf import OmegaConf
  75. cfg = {
  76. "WORK_DIR": "H:\\Projects\\UNet3P",
  77. "INPUT": {"HEIGHT": 320, "WIDTH": 320, "CHANNELS": 3},
  78. "OUTPUT": {"CLASSES": 1},
  79. # available variants are unet3plus, unet3plus_deepsup, unet3plus_deepsup_cgm
  80. "MODEL": {"TYPE": "unet3plus",
  81. # available variants are unet3plus, vgg16, vgg19
  82. "BACKBONE": {"TYPE": "vgg19", }
  83. }
  84. }
  85. unet_3P = prepare_model(OmegaConf.create(cfg), True)
  86. unet_3P.summary()
  87. # tf.keras.utils.plot_model(unet_3P, show_layer_names=True, show_shapes=True)
  88. # unet_3P.save("unet_3P.hdf5")