config_model_on_triton.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. #!/usr/bin/env python3
  2. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. r"""
  16. To configure model on Triton, you can use `config_model_on_triton.py` script.
  17. This will prepare layout of Model Repository, including Model Configuration.
  18. ```shell script
  19. python ./triton/config_model_on_triton.py \
  20. --model-repository /model_repository \
  21. --model-path /models/exported/model.onnx \
  22. --model-format onnx \
  23. --model-name ResNet50 \
  24. --model-version 1 \
  25. --max-batch-size 32 \
  26. --precision fp16 \
  27. --backend-accelerator trt \
  28. --load-model explicit \
  29. --timeout 120 \
  30. --verbose
  31. ```
  32. If Triton server to which we prepare model repository is running with **explicit model control mode**,
use the `--load-model` argument to send a load_model request to the Triton Inference Server.
  34. If server is listening on non-default address or port use `--server-url` argument to point server control endpoint.
  35. If it is required to use HTTP protocol to communicate with Triton server use `--http` argument.
  36. To improve inference throughput you can use
  37. [dynamic batching](https://github.com/triton-inference-server/server/blob/master/docs/model_configuration.md#dynamic-batcher)
  38. for your model by providing `--preferred-batch-sizes` and `--max-queue-delay-us` parameters.
For models which don't support batching, set `--max-batch-size` to 0.
  40. By default Triton will [automatically obtain inputs and outputs definitions](https://github.com/triton-inference-server/server/blob/master/docs/model_configuration.md#auto-generated-model-configuration).
but for TorchScript and TF GraphDef models the script uses a file with I/O specs. This file is automatically generated
  42. when the model is converted to ScriptModule (either traced or scripted).
  43. If there is a need to pass different than default path to I/O spec file use `--io-spec` CLI argument.
  44. I/O spec file is yaml file with below structure:
  45. ```yaml
  46. - inputs:
  47. - name: input
  48. dtype: float32 # np.dtype name
  49. shape: [None, 224, 224, 3]
  50. - outputs:
  51. - name: probabilities
  52. dtype: float32
  53. shape: [None, 1001]
  54. - name: classes
  55. dtype: int32
  56. shape: [None, 1]
  57. ```
  58. """
  59. import argparse
  60. import logging
  61. import time
  62. from model_navigator import Accelerator, Format, Precision
  63. from model_navigator.args import str2bool
  64. from model_navigator.log import set_logger, log_dict
  65. from model_navigator.triton import ModelConfig, TritonClient, TritonModelStore
  66. LOGGER = logging.getLogger("config_model")
  67. def _available_enum_values(my_enum):
  68. return [item.value for item in my_enum]
  69. def main():
  70. parser = argparse.ArgumentParser(
  71. description="Create Triton model repository and model configuration", allow_abbrev=False
  72. )
  73. parser.add_argument("--model-repository", required=True, help="Path to Triton model repository.")
  74. parser.add_argument("--model-path", required=True, help="Path to model to configure")
  75. # TODO: automation
  76. parser.add_argument(
  77. "--model-format",
  78. required=True,
  79. choices=_available_enum_values(Format),
  80. help="Format of model to deploy",
  81. )
  82. parser.add_argument("--model-name", required=True, help="Model name")
  83. parser.add_argument("--model-version", default="1", help="Version of model (default 1)")
  84. parser.add_argument(
  85. "--max-batch-size",
  86. type=int,
  87. default=32,
  88. help="Maximum batch size allowed for inference. "
  89. "A max_batch_size value of 0 indicates that batching is not allowed for the model",
  90. )
  91. # TODO: automation
  92. parser.add_argument(
  93. "--precision",
  94. type=str,
  95. default=Precision.FP16.value,
  96. choices=_available_enum_values(Precision),
  97. help="Model precision (parameter used only by Tensorflow backend with TensorRT optimization)",
  98. )
  99. # Triton Inference Server endpoint
  100. parser.add_argument(
  101. "--server-url",
  102. type=str,
  103. default="grpc://localhost:8001",
  104. help="Inference server URL in format protocol://host[:port] (default grpc://localhost:8001)",
  105. )
  106. parser.add_argument(
  107. "--load-model",
  108. choices=["none", "poll", "explicit"],
  109. help="Loading model while Triton Server is in given model control mode",
  110. )
  111. parser.add_argument(
  112. "--timeout", default=120, help="Timeout in seconds to wait till model load (default=120)", type=int
  113. )
  114. # optimization related
  115. parser.add_argument(
  116. "--backend-accelerator",
  117. type=str,
  118. choices=_available_enum_values(Accelerator),
  119. default=Accelerator.TRT.value,
  120. help="Select Backend Accelerator used to serve model",
  121. )
  122. parser.add_argument("--number-of-model-instances", type=int, default=1, help="Number of model instances per GPU")
  123. parser.add_argument(
  124. "--preferred-batch-sizes",
  125. type=int,
  126. nargs="*",
  127. help="Batch sizes that the dynamic batcher should attempt to create. "
  128. "In case --max-queue-delay-us is set and this parameter is not, default value will be --max-batch-size",
  129. )
  130. parser.add_argument(
  131. "--max-queue-delay-us",
  132. type=int,
  133. default=0,
  134. help="Max delay time which dynamic batcher shall wait to form a batch (default 0)",
  135. )
  136. parser.add_argument(
  137. "--capture-cuda-graph",
  138. type=int,
  139. default=0,
  140. help="Use cuda capture graph (used only by TensorRT platform)",
  141. )
  142. parser.add_argument("-v", "--verbose", help="Provide verbose logs", type=str2bool, default=False)
  143. args = parser.parse_args()
  144. set_logger(verbose=args.verbose)
  145. log_dict("args", vars(args))
  146. config = ModelConfig.create(
  147. model_path=args.model_path,
  148. # model definition
  149. model_name=args.model_name,
  150. model_version=args.model_version,
  151. model_format=args.model_format,
  152. precision=args.precision,
  153. max_batch_size=args.max_batch_size,
  154. # optimization
  155. accelerator=args.backend_accelerator,
  156. gpu_engine_count=args.number_of_model_instances,
  157. preferred_batch_sizes=args.preferred_batch_sizes or [],
  158. max_queue_delay_us=args.max_queue_delay_us,
  159. capture_cuda_graph=args.capture_cuda_graph,
  160. )
  161. model_store = TritonModelStore(args.model_repository)
  162. model_store.deploy_model(model_config=config, model_path=args.model_path)
  163. if args.load_model != "none":
  164. client = TritonClient(server_url=args.server_url, verbose=args.verbose)
  165. client.wait_for_server_ready(timeout=args.timeout)
  166. if args.load_model == "explicit":
  167. client.load_model(model_name=args.model_name)
  168. if args.load_model == "poll":
  169. time.sleep(15)
  170. client.wait_for_model(model_name=args.model_name, model_version=args.model_version, timeout_s=args.timeout)
# Script entry point: run the CLI only when executed directly, not when imported.
if __name__ == "__main__":
    main()