# Copyright 2020 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tool to inspect a model."""
  16. import os
  17. from absl import app
  18. from absl import flags
  19. from absl import logging
  20. import numpy as np
  21. from PIL import Image
  22. import tensorflow as tf
  23. from dllogger import StdOutBackend, JSONStreamBackend, Verbosity
  24. import dllogger as DLLogger
  25. from model import inference
  26. from utils import hparams_config
  27. from utils import model_utils
  28. from utils import setup
  29. flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model.')
  30. flags.DEFINE_string('mode', 'benchmark',
  31. 'Run mode: {dry, export, benchmark}')
  32. flags.DEFINE_string('trace_filename', None, 'Trace file name.')
  33. flags.DEFINE_integer('bm_runs', 100, 'Number of benchmark runs.')
  34. flags.DEFINE_string('tensorrt', None, 'TensorRT mode: {None, FP32, FP16, INT8}')
  35. flags.DEFINE_integer('batch_size', 1, 'Batch size for inference.')
  36. flags.DEFINE_string('ckpt_path', '_', 'checkpoint dir used for eval.')
  37. flags.DEFINE_string('export_ckpt', None, 'Output model ckpt path.')
  38. flags.DEFINE_string(
  39. 'hparams', '', 'Comma separated k=v pairs of hyperparameters or a module'
  40. ' containing attributes to use as hyperparameters.')
  41. flags.DEFINE_bool('amp', True, 'Enable mixed precision training')
  42. flags.DEFINE_bool('use_xla', True, 'Use XLA')
  43. flags.DEFINE_string('input_image', None, 'Input image path for inference.')
  44. flags.DEFINE_string('output_image_dir', None, 'Output dir for inference.')
  45. flags.DEFINE_string('dllogger_path', '/tmp/time_log.txt', 'Filepath for dllogger logs')
  46. # For video.
  47. flags.DEFINE_string('input_video', None, 'Input video path for inference.')
  48. flags.DEFINE_string('output_video', None,
  49. 'Output video path. If None, play it online instead.')
  50. # For visualization.
  51. flags.DEFINE_integer('max_boxes_to_draw', 100, 'Max number of boxes to draw.')
  52. flags.DEFINE_float('min_score_thresh', 0.4, 'Score threshold to show box.')
  53. flags.DEFINE_string('nms_method', 'hard', 'nms method, hard or gaussian.')
  54. # For saved model.
  55. flags.DEFINE_string('saved_model_dir', None,
  56. 'Folder path for saved model.')
  57. flags.DEFINE_string('tflite_path', None, 'Path for exporting tflite file.')
  58. flags.DEFINE_bool('debug', False, 'Debug mode.')
  59. FLAGS = flags.FLAGS
  60. def main(_):
  61. model_config = hparams_config.get_detection_config(FLAGS.model_name)
  62. model_config.override(FLAGS.hparams) # Add custom overrides
  63. model_config.is_training_bn = False
  64. model_config.image_size = model_utils.parse_image_size(model_config.image_size)
  65. # A hack to make flag consistent with nms configs.
  66. if FLAGS.min_score_thresh:
  67. model_config.nms_configs.score_thresh = FLAGS.min_score_thresh
  68. if FLAGS.nms_method:
  69. model_config.nms_configs.method = FLAGS.nms_method
  70. if FLAGS.max_boxes_to_draw:
  71. model_config.nms_configs.max_output_size = FLAGS.max_boxes_to_draw
  72. model_config.mixed_precision = FLAGS.amp
  73. setup.set_flags(FLAGS, model_config, training=False)
  74. model_params = model_config.as_dict()
  75. ckpt_path_or_file = FLAGS.ckpt_path
  76. if tf.io.gfile.isdir(ckpt_path_or_file):
  77. ckpt_path_or_file = tf.train.latest_checkpoint(ckpt_path_or_file)
  78. driver = inference.ServingDriver(FLAGS.model_name, ckpt_path_or_file,
  79. FLAGS.batch_size or None,
  80. FLAGS.min_score_thresh,
  81. FLAGS.max_boxes_to_draw, model_params)
  82. # dllogger setup
  83. backends = []
  84. backends+=[
  85. JSONStreamBackend(verbosity=Verbosity.VERBOSE, filename=FLAGS.dllogger_path),
  86. StdOutBackend(verbosity=Verbosity.DEFAULT)]
  87. DLLogger.init(backends=backends)
  88. DLLogger.metadata('inference_fps', {'unit': 'images/s'})
  89. DLLogger.metadata('inference_latency_ms', {'unit': 'ms'})
  90. DLLogger.metadata('latency_avg', {'unit': 's'})
  91. DLLogger.metadata('latency_90', {'unit': 's'})
  92. DLLogger.metadata('latency_95', {'unit': 's'})
  93. DLLogger.metadata('latency_99', {'unit': 's'})
  94. if FLAGS.mode == 'export':
  95. if tf.io.gfile.exists(FLAGS.saved_model_dir):
  96. tf.io.gfile.rmtree(FLAGS.saved_model_dir)
  97. driver.export(FLAGS.saved_model_dir, FLAGS.tflite_path, FLAGS.tensorrt)
  98. elif FLAGS.mode == 'benchmark':
  99. if FLAGS.saved_model_dir:
  100. driver.load(FLAGS.saved_model_dir)
  101. batch_size = FLAGS.batch_size or 1
  102. if FLAGS.input_image:
  103. image_file = tf.io.read_file(FLAGS.input_image)
  104. image_arrays = tf.image.decode_image(image_file)
  105. image_arrays.set_shape((None, None, 3))
  106. image_arrays = tf.expand_dims(image_arrays, 0)
  107. if batch_size > 1:
  108. image_arrays = tf.tile(image_arrays, [batch_size, 1, 1, 1])
  109. else:
  110. # use synthetic data if no image is provided.
  111. image_arrays = tf.ones((batch_size, *model_config.image_size, 3),
  112. dtype=tf.uint8)
  113. driver.benchmark(image_arrays, FLAGS.bm_runs, FLAGS.trace_filename)
  114. elif FLAGS.mode == 'dry':
  115. # transfer to tf2 format ckpt
  116. driver.build()
  117. if FLAGS.export_ckpt:
  118. driver.model.save_weights(FLAGS.export_ckpt)
  119. elif FLAGS.mode == 'video':
  120. import cv2 # pylint: disable=g-import-not-at-top
  121. if tf.saved_model.contains_saved_model(FLAGS.saved_model_dir):
  122. driver.load(FLAGS.saved_model_dir)
  123. cap = cv2.VideoCapture(FLAGS.input_video)
  124. if not cap.isOpened():
  125. print('Error opening input video: {}'.format(FLAGS.input_video))
  126. out_ptr = None
  127. if FLAGS.output_video:
  128. frame_width, frame_height = int(cap.get(3)), int(cap.get(4))
  129. out_ptr = cv2.VideoWriter(FLAGS.output_video,
  130. cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), 25,
  131. (frame_width, frame_height))
  132. while cap.isOpened():
  133. # Capture frame-by-frame
  134. ret, frame = cap.read()
  135. if not ret:
  136. break
  137. raw_frames = np.array([frame])
  138. detections_bs = driver.serve(raw_frames)
  139. boxes, scores, classes, _ = tf.nest.map_structure(np.array, detections_bs)
  140. new_frame = driver.visualize(
  141. raw_frames[0],
  142. boxes[0],
  143. scores[0],
  144. classes[0],
  145. min_score_thresh=model_config.nms_configs.score_thresh,
  146. max_boxes_to_draw=model_config.nms_configs.max_output_size)
  147. if out_ptr:
  148. # write frame into output file.
  149. out_ptr.write(new_frame)
  150. else:
  151. # show the frame online, mainly used for real-time speed test.
  152. cv2.imshow('Frame', new_frame)
  153. # Press Q on keyboard to exit
  154. if cv2.waitKey(1) & 0xFF == ord('q'):
  155. break
  156. if __name__ == '__main__':
  157. logging.set_verbosity(logging.ERROR)
  158. app.run(main)