convert_model.py

#!/usr/bin/env python3
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. r"""
  16. `convert_model.py` script allows to convert between model formats with additional model optimizations
  17. for faster inference.
  18. It converts model from results of get_model function.
  19. Currently supported input and output formats are:
  20. - inputs
  21. - `tf-estimator` - `get_model` function returning Tensorflow Estimator
  22. - `tf-keras` - `get_model` function returning Tensorflow Keras Model
  23. - `tf-savedmodel` - Tensorflow SavedModel binary
  24. - `pyt` - `get_model` function returning PyTorch Module
  25. - output
  26. - `tf-savedmodel` - Tensorflow saved model
  27. - `tf-trt` - TF-TRT saved model
  28. - `ts-trace` - PyTorch traced ScriptModule
  29. - `ts-script` - PyTorch scripted ScriptModule
  30. - `onnx` - ONNX
  31. - `trt` - TensorRT plan file
  32. For tf-keras input you can use:
  33. - --large-model flag - helps loading model which exceeds maximum protobuf size of 2GB
  34. - --tf-allow-growth flag - control limiting GPU memory growth feature
  35. (https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth). By default it is disabled.
  36. """

import argparse
import logging
import os
from pathlib import Path

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_DEPRECATION_WARNINGS"] = "1"

# method from PEP-366 to support relative import in executed modules
if __name__ == "__main__" and __package__ is None:
    __package__ = Path(__file__).parent.name

from .deployment_toolkit.args import ArgParserGenerator
from .deployment_toolkit.core import (
    DATALOADER_FN_NAME,
    BaseConverter,
    BaseLoader,
    BaseSaver,
    Format,
    Precision,
    load_from_file,
)
from .deployment_toolkit.extensions import converters, loaders, savers

LOGGER = logging.getLogger("convert_model")

INPUT_MODEL_TYPES = [Format.TF_ESTIMATOR, Format.TF_KERAS, Format.TF_SAVEDMODEL, Format.PYT]
OUTPUT_MODEL_TYPES = [Format.TF_SAVEDMODEL, Format.TF_TRT, Format.ONNX, Format.TRT, Format.TS_TRACE, Format.TS_SCRIPT]

def _get_args():
    parser = argparse.ArgumentParser(description="Script for conversion between model formats.", allow_abbrev=False)
    parser.add_argument("--input-path", help="Path to input model file (python module or binary file)", required=True)
    parser.add_argument(
        "--input-type", help="Input model type", choices=[f.value for f in INPUT_MODEL_TYPES], required=True
    )
    parser.add_argument("--output-path", help="Path to output model file", required=True)
    parser.add_argument(
        "--output-type", help="Output model type", choices=[f.value for f in OUTPUT_MODEL_TYPES], required=True
    )
    parser.add_argument("--dataloader", help="Path to python module containing data loader")
    parser.add_argument("-v", "--verbose", help="Verbose logs", action="store_true", default=False)
    parser.add_argument(
        "--ignore-unknown-parameters",
        help="Ignore unknown parameters (argument often used in CI where set of arguments is constant)",
        action="store_true",
        default=False,
    )

    args, unparsed_args = parser.parse_known_args()

    Loader: BaseLoader = loaders.get(args.input_type)
    ArgParserGenerator(Loader, module_path=args.input_path).update_argparser(parser)

    converter_name = f"{args.input_type}--{args.output_type}"
    Converter: BaseConverter = converters.get(converter_name)
    if Converter is not None:
        ArgParserGenerator(Converter).update_argparser(parser)

    Saver: BaseSaver = savers.get(args.output_type)
    ArgParserGenerator(Saver).update_argparser(parser)

    if args.dataloader is not None:
        get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
        ArgParserGenerator(get_dataloader_fn).update_argparser(parser)

    if args.ignore_unknown_parameters:
        args, unknown_args = parser.parse_known_args()
        LOGGER.warning(f"Got additional args {unknown_args}")
    else:
        args = parser.parse_args()
    return args
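
# The module passed via --dataloader must expose a callable named by DATALOADER_FN_NAME; its keyword
# arguments become CLI flags through ArgParserGenerator and its return value is handed to the
# converter as dataloader_fn. A minimal hypothetical sketch (the function name, batch shape and
# calibration use are assumptions, not defined by this script):
#
#     import numpy as np
#
#     def get_dataloader_fn(batch_size: int = 1):
#         def _dataloader():
#             # yield a few synthetic batches, e.g. for TF-TRT/TensorRT calibration
#             for _ in range(8):
#                 yield np.zeros((batch_size, 3, 224, 224), dtype=np.float32)
#         return _dataloader
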
def main():
    args = _get_args()

    log_level = logging.INFO if not args.verbose else logging.DEBUG
    log_format = "%(asctime)s %(levelname)s %(name)s %(message)s"
    logging.basicConfig(level=log_level, format=log_format)

    LOGGER.info("args:")
    for key, value in vars(args).items():
        LOGGER.info(f"    {key} = {value}")

    requested_model_precision = Precision(args.precision)
    dataloader_fn = None

    # if conversion is required, temporarily change the model load precision to the one required by the converter
    # (TensorRT converters require fp32 models for all requested precisions)
    converter_name = f"{args.input_type}--{args.output_type}"
    Converter: BaseConverter = converters.get(converter_name)
    if Converter:
        args.precision = Converter.required_source_model_precision(requested_model_precision).value

    Loader: BaseLoader = loaders.get(args.input_type)
    loader = ArgParserGenerator(Loader, module_path=args.input_path).from_args(args)
    model = loader.load(args.input_path)

    LOGGER.info("inputs: %s", model.inputs)
    LOGGER.info("outputs: %s", model.outputs)

    if Converter:  # if conversion is needed
        # the dataloader must match the source model precision - so the requested precision is not recovered yet
        if args.dataloader is not None:
            get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
            dataloader_fn = ArgParserGenerator(get_dataloader_fn).from_args(args)

    # recover the precision requested by the user
    args.precision = requested_model_precision.value

    if Converter:
        converter = ArgParserGenerator(Converter).from_args(args)
        model = converter.convert(model, dataloader_fn=dataloader_fn)

    Saver: BaseSaver = savers.get(args.output_type)
    saver = ArgParserGenerator(Saver).from_args(args)
    saver.save(model, args.output_path)

    return 0


if __name__ == "__main__":
    main()