model_fn.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """ Model function in charge to collect metrics and feed them to the optimizer """
  15. import horovod.tensorflow as hvd
  16. import tensorflow as tf
  17. from model.unet3d import Builder
  18. from model.losses import make_loss, eval_dice, total_dice
  19. from dataset.data_loader import CLASSES
  20. def unet_3d(features, labels, mode, params):
  21. """ Gather loss and feed it to the optimizer
  22. :param features: Input features
  23. :param labels: Input labels
  24. :param mode: Estimator's execution mode
  25. :param params: Dict with additional parameters
  26. :return: Estimator spec
  27. """
  28. # TODO: Find a better way to handle the empty params namespace
  29. try:
  30. normalization = params.normalization
  31. except:
  32. normalization = 'instancenorm'
  33. input_node = tf.identity(features, name='input_node')
  34. logits = Builder(n_classes=4, normalization=normalization, mode=mode)(input_node)
  35. logits = tf.identity(logits, name='output_node')
  36. if mode == tf.estimator.ModeKeys.PREDICT:
  37. prediction = tf.argmax(input=logits, axis=-1, output_type=tf.dtypes.int32, name="predictions")
  38. return tf.estimator.EstimatorSpec(mode=mode,
  39. predictions={'predictions': tf.cast(prediction, tf.int8)})
  40. labels = tf.cast(labels, tf.float32)
  41. if mode == tf.estimator.ModeKeys.EVAL:
  42. prediction = tf.argmax(input=logits, axis=-1, output_type=tf.dtypes.int32)
  43. prediction = tf.one_hot(prediction, 4)
  44. if not params.include_background:
  45. labels = labels[..., 1:]
  46. prediction = prediction[..., 1:]
  47. prediction = tf.cast(prediction, tf.float32)
  48. eval_acc = eval_dice(y_true=labels, y_pred=prediction)
  49. total_eval_acc = total_dice(prediction, labels)
  50. metrics = {CLASSES[i]: tf.compat.v1.metrics.mean(eval_acc[i]) for i in range(eval_acc.shape[-1])}
  51. metrics['whole_tumor'] = tf.compat.v1.metrics.mean(total_eval_acc)
  52. return tf.estimator.EstimatorSpec(mode=mode, loss=tf.reduce_mean(eval_acc),
  53. eval_metric_ops=metrics)
  54. if not params.include_background:
  55. labels = labels[..., 1:]
  56. logits = logits[..., 1:]
  57. loss = make_loss(params, y_pred=logits, y_true=labels)
  58. loss = tf.identity(loss, name="total_loss_ref")
  59. global_step = tf.compat.v1.train.get_or_create_global_step()
  60. boundaries = [params.max_steps // (2 * hvd.size()),
  61. params.max_steps // (2 * hvd.size()),
  62. 3 * params.max_steps // (4 * hvd.size())]
  63. lr = params.learning_rate
  64. values = [lr / 4, lr, lr / 5, lr / 20]
  65. learning_rate = tf.compat.v1.train.piecewise_constant(global_step, boundaries, values)
  66. optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
  67. if params.use_amp:
  68. loss_scale = tf.train.experimental.DynamicLossScale()
  69. optimizer = tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer(optimizer, loss_scale)
  70. optimizer = hvd.DistributedOptimizer(optimizer)
  71. with tf.control_dependencies(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)):
  72. train_op = optimizer.minimize(loss, global_step=global_step)
  73. return tf.estimator.EstimatorSpec(
  74. mode=mode, loss=loss, train_op=train_op)