  1. # Copyright 2020 Google Research. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Eval libraries."""
  16. import os
  17. from mpi4py import MPI
  18. from absl import app
  19. from absl import flags
  20. from absl import logging
  21. import tensorflow as tf
  22. import horovod.tensorflow.keras as hvd
  23. from model import anchors
  24. from model import coco_metric
  25. from model import dataloader
  26. from model import efficientdet_keras
  27. from model import label_util
  28. from model import postprocess
  29. from utils import hparams_config
  30. from utils import model_utils
  31. from utils import util_keras
  32. from utils.horovod_utils import get_rank, get_world_size, is_main_process
  33. flags.DEFINE_integer('eval_samples', 5000, 'Number of eval samples.')
  34. flags.DEFINE_string('val_file_pattern', None,
  35. 'Glob for eval tfrecords, e.g. coco/val-*.tfrecord.')
  36. flags.DEFINE_string('val_json_file', None,
  37. 'Groudtruth, e.g. annotations/instances_val2017.json.')
  38. flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name to use.')
  39. flags.DEFINE_string('ckpt_path', None, 'Checkpoint path to evaluate')
  40. flags.DEFINE_integer('batch_size', 8, 'Local batch size.')
  41. flags.DEFINE_string('only_this_epoch', None, 'Evaluate only this epoch checkpoint.')
  42. flags.DEFINE_bool('enable_map_parallelization', True, 'Parallelize stateless map transformations in dataloader')
  43. flags.DEFINE_bool('amp', True, 'Use mixed precision for eval.')
  44. flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file.')
  45. FLAGS = flags.FLAGS
  46. def main(_):
  47. hvd.init()
  48. gpus = tf.config.experimental.list_physical_devices('GPU')
  49. for gpu in gpus:
  50. tf.config.experimental.set_memory_growth(gpu, True)
  51. if gpus:
  52. tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
  53. if FLAGS.amp:
  54. policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
  55. tf.keras.mixed_precision.experimental.set_policy(policy)
  56. else:
  57. os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '0'
  58. config = hparams_config.get_efficientdet_config(FLAGS.model_name)
  59. config.override(FLAGS.hparams)
  60. config.val_json_file = FLAGS.val_json_file
  61. config.nms_configs.max_nms_inputs = anchors.MAX_DETECTION_POINTS
  62. config.drop_remainder = False # eval all examples w/o drop.
  63. config.image_size = model_utils.parse_image_size(config['image_size'])
  64. @tf.function
  65. def model_fn(images, labels):
  66. cls_outputs, box_outputs = model(images, training=False)
  67. detections = postprocess.generate_detections(config, cls_outputs, box_outputs,
  68. labels['image_scales'],
  69. labels['source_ids'])
  70. tf.numpy_function(evaluator.update_state,
  71. [labels['groundtruth_data'],
  72. postprocess.transform_detections(detections)], [])
  73. # Network
  74. model = efficientdet_keras.EfficientDetNet(config=config)
  75. model.build((None, *config.image_size, 3))
  76. # dataset
  77. batch_size = FLAGS.batch_size # local batch size.
  78. ds = dataloader.InputReader(
  79. FLAGS.val_file_pattern,
  80. is_training=False,
  81. max_instances_per_image=config.max_instances_per_image,
  82. enable_map_parallelization=FLAGS.enable_map_parallelization)(
  83. config, batch_size=batch_size)
  84. ds = ds.shard(get_world_size(), get_rank())
  85. # Evaluator for AP calculation.
  86. label_map = label_util.get_label_map(config.label_map)
  87. evaluator = coco_metric.EvaluationMetric(
  88. filename=config.val_json_file, label_map=label_map)
  89. util_keras.restore_ckpt(model, FLAGS.ckpt_path, config.moving_average_decay,
  90. steps_per_epoch=0, skip_mismatch=False, expect_partial=True)
  91. if FLAGS.eval_samples:
  92. num_samples = (FLAGS.eval_samples + get_world_size() - 1) // get_world_size()
  93. num_samples = (num_samples + batch_size - 1) // batch_size
  94. ds = ds.take(num_samples)
  95. evaluator.reset_states()
  96. # evaluate all images.
  97. pbar = tf.keras.utils.Progbar(num_samples)
  98. for i, (images, labels) in enumerate(ds):
  99. model_fn(images, labels)
  100. if is_main_process():
  101. pbar.update(i)
  102. # gather detections from all ranks
  103. evaluator.gather()
  104. if is_main_process():
  105. # compute the final eval results.
  106. metrics = evaluator.result()
  107. metric_dict = {}
  108. for i, name in enumerate(evaluator.metric_names):
  109. metric_dict[name] = metrics[i]
  110. if label_map:
  111. for i, cid in enumerate(sorted(label_map.keys())):
  112. name = 'AP_/%s' % label_map[cid]
  113. metric_dict[name] = metrics[i + len(evaluator.metric_names)]
  114. # csv format
  115. csv_metrics = ['AP','AP50','AP75','APs','APm','APl']
  116. csv_format = ",".join([str(round(metric_dict[key] * 100, 2)) for key in csv_metrics])
  117. print(FLAGS.model_name, metric_dict, "csv format:", csv_format)
  118. MPI.COMM_WORLD.Barrier()
  119. if __name__ == '__main__':
  120. flags.mark_flag_as_required('val_file_pattern')
  121. flags.mark_flag_as_required('val_json_file')
  122. flags.mark_flag_as_required('ckpt_path')
  123. logging.set_verbosity(logging.WARNING)
  124. app.run(main)