infer.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright 2020 Google Research. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """A simple example on how to use keras model for inference."""
  16. import os
  17. from absl import app
  18. from absl import flags
  19. from absl import logging
  20. import numpy as np
  21. from PIL import Image
  22. import tensorflow as tf
  23. from utils import hparams_config
  24. from model import inference
  25. from model import efficientdet_keras
  26. flags.DEFINE_string('image_path', None, 'Location of test image.')
  27. flags.DEFINE_string('output_dir', None, 'Directory of annotated output images.')
  28. flags.DEFINE_string('model_dir', None, 'Location of the checkpoint to run.')
  29. flags.DEFINE_string('model_name', 'efficientdet-d0', 'Model name to use.')
  30. flags.DEFINE_string('hparams', '', 'Comma separated k=v pairs or a yaml file')
  31. flags.DEFINE_bool('debug', False, 'If true, run function in eager for debug.')
  32. flags.DEFINE_string('saved_model_dir', None, 'Saved model directory')
  33. FLAGS = flags.FLAGS
  34. def main(_):
  35. imgs = [np.array(Image.open(FLAGS.image_path))] * 2
  36. # Create model config.
  37. config = hparams_config.get_efficientdet_config('efficientdet-d0')
  38. config.is_training_bn = False
  39. config.image_size = '1920x1280'
  40. config.nms_configs.score_thresh = 0.4
  41. config.nms_configs.max_output_size = 100
  42. config.override(FLAGS.hparams)
  43. # Use 'mixed_float16' if running on GPUs.
  44. policy = tf.keras.mixed_precision.experimental.Policy('float32')
  45. tf.keras.mixed_precision.experimental.set_policy(policy)
  46. tf.config.experimental_run_functions_eagerly(FLAGS.debug)
  47. # Create and run the model.
  48. model = efficientdet_keras.EfficientDetModel(config=config)
  49. model.build((None, None, None, 3))
  50. model.load_weights(tf.train.latest_checkpoint(FLAGS.model_dir))
  51. model.summary()
  52. class ExportModel(tf.Module):
  53. def __init__(self, model):
  54. super().__init__()
  55. self.model = model
  56. @tf.function
  57. def f(self, imgs):
  58. return self.model(imgs, training=False, post_mode='global')
  59. imgs = tf.convert_to_tensor(imgs, dtype=tf.uint8)
  60. export_model = ExportModel(model)
  61. if FLAGS.saved_model_dir:
  62. tf.saved_model.save(
  63. export_model,
  64. FLAGS.saved_model_dir,
  65. signatures=export_model.f.get_concrete_function(
  66. tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8)))
  67. export_model = tf.saved_model.load(FLAGS.saved_model_dir)
  68. boxes, scores, classes, valid_len = export_model.f(imgs)
  69. # Visualize results.
  70. for i, img in enumerate(imgs):
  71. length = valid_len[i]
  72. img = inference.visualize_image(
  73. img,
  74. boxes[i].numpy()[:length],
  75. classes[i].numpy().astype(np.int)[:length],
  76. scores[i].numpy()[:length],
  77. label_map=config.label_map,
  78. min_score_thresh=config.nms_configs.score_thresh,
  79. max_boxes_to_draw=config.nms_configs.max_output_size)
  80. output_image_path = os.path.join(FLAGS.output_dir, str(i) + '.jpg')
  81. Image.fromarray(img).save(output_image_path)
  82. print('writing annotated image to %s' % output_image_path)
  83. if __name__ == '__main__':
  84. flags.mark_flag_as_required('image_path')
  85. flags.mark_flag_as_required('output_dir')
  86. flags.mark_flag_as_required('model_dir')
  87. logging.set_verbosity(logging.ERROR)
  88. app.run(main)