unet3plus_utils.py 1005 B

12345678910111213141516171819202122232425262728293031
  1. """
  2. Utility functions for Unet3+ models
  3. """
  4. import tensorflow as tf
  5. import tensorflow.keras as k
  6. def conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same',
  7. is_bn=True, is_relu=True, n=2):
  8. """ Custom function for conv2d:
  9. Apply 3*3 convolutions with BN and relu.
  10. """
  11. for i in range(1, n + 1):
  12. x = k.layers.Conv2D(filters=kernels, kernel_size=kernel_size,
  13. padding=padding, strides=strides,
  14. kernel_regularizer=tf.keras.regularizers.l2(1e-4),
  15. kernel_initializer=k.initializers.he_normal(seed=5))(x)
  16. if is_bn:
  17. x = k.layers.BatchNormalization()(x)
  18. if is_relu:
  19. x = k.activations.relu(x)
  20. return x
  21. def dot_product(seg, cls):
  22. b, h, w, n = k.backend.int_shape(seg)
  23. seg = tf.reshape(seg, [-1, h * w, n])
  24. final = tf.einsum("ijk,ik->ijk", seg, cls)
  25. final = tf.reshape(final, [-1, h, w, n])
  26. return final