| 12345678910111213141516171819202122232425262728293031 |
- """
- Utility functions for Unet3+ models
- """
- import tensorflow as tf
- import tensorflow.keras as k
def conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same',
               is_bn=True, is_relu=True, n=2):
    """Stack ``n`` Conv2D stages, each optionally followed by BN and ReLU.

    Args:
        x: Input tensor fed to the first convolution.
        kernels: Number of filters used by every convolution in the stack.
        kernel_size: Convolution window size; defaults to 3x3.
        strides: Convolution strides. They are passed to every convolution
            in the stack, not only the first one.
        padding: Padding mode forwarded to ``Conv2D``.
        is_bn: When true, apply batch normalization after each convolution.
        is_relu: When true, apply ReLU at the end of each stage (after the
            batch normalization if that is enabled).
        n: Number of convolution stages. With ``n <= 0`` no layer is created
            and ``x`` is returned unchanged.

    Returns:
        The output tensor of the last stage.
    """
    # Only the number of repetitions matters, so the loop index is discarded.
    for _ in range(n):
        x = k.layers.Conv2D(
            filters=kernels,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            kernel_regularizer=k.regularizers.l2(1e-4),
            # NOTE: every stage builds its initializer with the same fixed seed.
            kernel_initializer=k.initializers.he_normal(seed=5),
        )(x)
        if is_bn:
            x = k.layers.BatchNormalization()(x)
        if is_relu:
            x = k.activations.relu(x)
    return x
def dot_product(seg, cls):
    """Scale each channel of ``seg`` by the matching per-sample weight in ``cls``.

    Args:
        seg: Rank-4 tensor indexed as ``[b, h, w, n]``.
        cls: Rank-2 tensor indexed as ``[b, n]``, holding one weight per
            sample and channel.

    Returns:
        A tensor indexed like ``seg`` where
        ``out[b, h, w, n] == seg[b, h, w, n] * cls[b, n]``.
    """
    # Multiply directly on the rank-4 layout. Flattening h and w first needs
    # both to be statically known (``h * w`` fails when either is None), and
    # the elementwise product is the same either way.
    return tf.einsum("bhwn,bn->bhwn", seg, cls)
|