loss.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """
  2. Implementation of different loss functions
  3. """
  4. import tensorflow as tf
  5. import tensorflow.keras.backend as K
  6. def iou(y_true, y_pred, smooth=1.e-9):
  7. """
  8. Calculate intersection over union (IoU) between images.
  9. Input shape should be Batch x Height x Width x #Classes (BxHxWxN).
  10. Using Mean as reduction type for batch values.
  11. """
  12. intersection = K.sum(K.abs(y_true * y_pred), axis=[1, 2, 3])
  13. union = K.sum(y_true, [1, 2, 3]) + K.sum(y_pred, [1, 2, 3])
  14. union = union - intersection
  15. iou = K.mean((intersection + smooth) / (union + smooth), axis=0)
  16. return iou
  17. def iou_loss(y_true, y_pred):
  18. """
  19. Jaccard / IoU loss
  20. """
  21. return 1 - iou(y_true, y_pred)
  22. def focal_loss(y_true, y_pred):
  23. """
  24. Focal loss
  25. """
  26. gamma = 2.
  27. alpha = 4.
  28. epsilon = 1.e-9
  29. y_true_c = tf.convert_to_tensor(y_true, tf.float32)
  30. y_pred_c = tf.convert_to_tensor(y_pred, tf.float32)
  31. model_out = tf.add(y_pred_c, epsilon)
  32. ce = tf.multiply(y_true_c, -tf.math.log(model_out))
  33. weight = tf.multiply(y_true_c, tf.pow(
  34. tf.subtract(1., model_out), gamma)
  35. )
  36. fl = tf.multiply(alpha, tf.multiply(weight, ce))
  37. reduced_fl = tf.reduce_max(fl, axis=-1)
  38. return tf.reduce_mean(reduced_fl)
  39. def ssim_loss(y_true, y_pred, smooth=1.e-9):
  40. """
  41. Structural Similarity Index loss.
  42. Input shape should be Batch x Height x Width x #Classes (BxHxWxN).
  43. Using Mean as reduction type for batch values.
  44. """
  45. ssim_value = tf.image.ssim(y_true, y_pred, max_val=1)
  46. return K.mean(1 - ssim_value + smooth, axis=0)
  47. class DiceCoefficient(tf.keras.metrics.Metric):
  48. """
  49. Dice coefficient metric. Can be used to calculate dice on probabilities
  50. or on their respective classes
  51. """
  52. def __init__(self, post_processed: bool,
  53. classes: int,
  54. name='dice_coef',
  55. **kwargs):
  56. """
  57. Set post_processed=False if dice coefficient needs to be calculated
  58. on probabilities. Set post_processed=True if probabilities needs to
  59. be first converted/mapped into their respective class.
  60. """
  61. super(DiceCoefficient, self).__init__(name=name, **kwargs)
  62. self.dice_value = self.add_weight(name='dice_value', initializer='zeros',
  63. aggregation=tf.VariableAggregation.MEAN) # SUM
  64. self.post_processed = post_processed
  65. self.classes = classes
  66. if self.classes == 1:
  67. self.axis = [1, 2, 3]
  68. else:
  69. self.axis = [1, 2, ]
  70. def update_state(self, y_true, y_pred, sample_weight=None):
  71. if self.post_processed:
  72. if self.classes == 1:
  73. y_true_ = y_true
  74. y_pred_ = tf.where(y_pred > .5, 1.0, 0.0)
  75. else:
  76. y_true_ = tf.math.argmax(y_true, axis=-1, output_type=tf.int32)
  77. y_pred_ = tf.math.argmax(y_pred, axis=-1, output_type=tf.int32)
  78. y_true_ = tf.cast(y_true_, dtype=tf.float32)
  79. y_pred_ = tf.cast(y_pred_, dtype=tf.float32)
  80. else:
  81. y_true_, y_pred_ = y_true, y_pred
  82. self.dice_value.assign(self.dice_coef(y_true_, y_pred_))
  83. def result(self):
  84. return self.dice_value
  85. def reset_state(self):
  86. self.dice_value.assign(0.0) # reset metric state
  87. def dice_coef(self, y_true, y_pred, smooth=1.e-9):
  88. """
  89. Calculate dice coefficient.
  90. Input shape could be either Batch x Height x Width x #Classes (BxHxWxN)
  91. or Batch x Height x Width (BxHxW).
  92. Using Mean as reduction type for batch values.
  93. """
  94. intersection = K.sum(y_true * y_pred, axis=self.axis)
  95. union = K.sum(y_true, axis=self.axis) + K.sum(y_pred, axis=self.axis)
  96. return K.mean((2. * intersection + smooth) / (union + smooth), axis=0)