trt_utils.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # *****************************************************************************
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the NVIDIA CORPORATION nor the
  12. # names of its contributors may be used to endorse or promote products
  13. # derived from this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
  19. # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. #
  26. # *****************************************************************************
  27. import tensorrt as trt
  28. def is_dimension_dynamic(dim):
  29. return dim is None or dim <= 0
  30. def is_shape_dynamic(shape):
  31. return any([is_dimension_dynamic(dim) for dim in shape])
  32. def run_trt_engine(context, engine, tensors):
  33. bindings = [None]*engine.num_bindings
  34. for name,tensor in tensors['inputs'].items():
  35. idx = engine.get_binding_index(name)
  36. bindings[idx] = tensor.data_ptr()
  37. if engine.is_shape_binding(idx) and is_shape_dynamic(context.get_shape(idx)):
  38. context.set_shape_input(idx, tensor)
  39. elif is_shape_dynamic(engine.get_binding_shape(idx)):
  40. context.set_binding_shape(idx, tensor.shape)
  41. for name,tensor in tensors['outputs'].items():
  42. idx = engine.get_binding_index(name)
  43. bindings[idx] = tensor.data_ptr()
  44. context.execute_v2(bindings=bindings)
  45. def load_engine(engine_filepath, trt_logger):
  46. with open(engine_filepath, "rb") as f, trt.Runtime(trt_logger) as runtime:
  47. engine = runtime.deserialize_cuda_engine(f.read())
  48. return engine
  49. def engine_info(engine_filepath):
  50. TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
  51. engine = load_engine(engine_filepath, TRT_LOGGER)
  52. binding_template = r"""
  53. {btype} {{
  54. name: "{bname}"
  55. data_type: {dtype}
  56. dims: {dims}
  57. }}"""
  58. type_mapping = {"DataType.HALF": "TYPE_FP16",
  59. "DataType.FLOAT": "TYPE_FP32",
  60. "DataType.INT32": "TYPE_INT32",
  61. "DataType.BOOL" : "TYPE_BOOL"}
  62. print("engine name", engine.name)
  63. print("has_implicit_batch_dimension", engine.has_implicit_batch_dimension)
  64. start_dim = 0 if engine.has_implicit_batch_dimension else 1
  65. print("num_optimization_profiles", engine.num_optimization_profiles)
  66. print("max_batch_size:", engine.max_batch_size)
  67. print("device_memory_size:", engine.device_memory_size)
  68. print("max_workspace_size:", engine.max_workspace_size)
  69. print("num_layers:", engine.num_layers)
  70. for i in range(engine.num_bindings):
  71. btype = "input" if engine.binding_is_input(i) else "output"
  72. bname = engine.get_binding_name(i)
  73. dtype = engine.get_binding_dtype(i)
  74. bdims = engine.get_binding_shape(i)
  75. config_values = {
  76. "btype": btype,
  77. "bname": bname,
  78. "dtype": type_mapping[str(dtype)],
  79. "dims": list(bdims[start_dim:])
  80. }
  81. final_binding_str = binding_template.format_map(config_values)
  82. print(final_binding_str)
  83. def build_engine(model_file, shapes, max_ws=512*1024*1024, fp16=False):
  84. TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
  85. builder = trt.Builder(TRT_LOGGER)
  86. builder.fp16_mode = fp16
  87. config = builder.create_builder_config()
  88. config.max_workspace_size = max_ws
  89. if fp16:
  90. config.flags |= 1 << int(trt.BuilderFlag.FP16)
  91. profile = builder.create_optimization_profile()
  92. for s in shapes:
  93. profile.set_shape(s['name'], min=s['min'], opt=s['opt'], max=s['max'])
  94. config.add_optimization_profile(profile)
  95. explicit_batch = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
  96. network = builder.create_network(explicit_batch)
  97. with trt.OnnxParser(network, TRT_LOGGER) as parser:
  98. with open(model_file, 'rb') as model:
  99. parsed = parser.parse(model.read())
  100. for i in range(parser.num_errors):
  101. print("TensorRT ONNX parser error:", parser.get_error(i))
  102. engine = builder.build_engine(network, config=config)
  103. return engine