run_inference_on_fw.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. #!/usr/bin/env python3
  2. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
r"""
To infer the model on framework runtime, you can use `run_inference_on_fw.py` script.
It infers data obtained from pointed data loader locally and saves received data into npz files.
Those files are stored in directory pointed by `--output-dir` argument.

Example call:

```shell script
python ./triton/run_inference_on_fw.py \
    --input-path /models/exported/model.onnx \
    --input-type onnx \
    --dataloader triton/dataloader.py \
    --data-dir /data/imagenet \
    --batch-size 32 \
    --output-dir /results/dump_local \
    --dump-labels
```
"""
import argparse
import logging
import os
from pathlib import Path

# Silence TensorFlow C++ logging / deprecation noise; must be set before any
# TF-dependent module (pulled in by the deployment_toolkit imports) is loaded.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_DEPRECATION_WARNINGS"] = "0"

from tqdm import tqdm

# method from PEP-366 to support relative import in executed modules
if __package__ is None:
    __package__ = Path(__file__).parent.name

from .deployment_toolkit.args import ArgParserGenerator
from .deployment_toolkit.core import DATALOADER_FN_NAME, BaseLoader, BaseRunner, Format, load_from_file
from .deployment_toolkit.dump import NpzWriter
from .deployment_toolkit.extensions import loaders, runners

LOGGER = logging.getLogger("run_inference_on_fw")
  46. def _verify_and_format_dump(args, ids, x, y_pred, y_real):
  47. data = {"outputs": y_pred, "ids": {"ids": ids}}
  48. if args.dump_inputs:
  49. data["inputs"] = x
  50. if args.dump_labels:
  51. if not y_real:
  52. raise ValueError(
  53. "Found empty label values. Please provide labels in dataloader_fn or do not use --dump-labels argument"
  54. )
  55. data["labels"] = y_real
  56. return data
  57. def _parse_and_validate_args():
  58. supported_inputs = set(runners.supported_extensions) & set(loaders.supported_extensions)
  59. parser = argparse.ArgumentParser(description="Dump local inference output of given model", allow_abbrev=False)
  60. parser.add_argument("--input-path", help="Path to input model", required=True)
  61. parser.add_argument("--input-type", help="Input model type", choices=supported_inputs, required=True)
  62. parser.add_argument("--dataloader", help="Path to python file containing dataloader.", required=True)
  63. parser.add_argument("--output-dir", help="Path to dir where output files will be stored", required=True)
  64. parser.add_argument("--dump-labels", help="Dump labels to output dir", action="store_true", default=False)
  65. parser.add_argument("--dump-inputs", help="Dump inputs to output dir", action="store_true", default=False)
  66. parser.add_argument("-v", "--verbose", help="Verbose logs", action="store_true", default=False)
  67. args, *_ = parser.parse_known_args()
  68. get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
  69. ArgParserGenerator(get_dataloader_fn).update_argparser(parser)
  70. Loader: BaseLoader = loaders.get(args.input_type)
  71. ArgParserGenerator(Loader, module_path=args.input_path).update_argparser(parser)
  72. Runner: BaseRunner = runners.get(args.input_type)
  73. ArgParserGenerator(Runner).update_argparser(parser)
  74. args = parser.parse_args()
  75. types_requiring_io_params = []
  76. if args.input_type in types_requiring_io_params and not all(p for p in [args.inputs, args.outputs]):
  77. parser.error(f"For {args.input_type} input provide --inputs and --outputs parameters")
  78. return args
  79. def main():
  80. args = _parse_and_validate_args()
  81. log_level = logging.INFO if not args.verbose else logging.DEBUG
  82. log_format = "%(asctime)s %(levelname)s %(name)s %(message)s"
  83. logging.basicConfig(level=log_level, format=log_format)
  84. LOGGER.info(f"args:")
  85. for key, value in vars(args).items():
  86. LOGGER.info(f" {key} = {value}")
  87. Loader: BaseLoader = loaders.get(args.input_type)
  88. Runner: BaseRunner = runners.get(args.input_type)
  89. loader = ArgParserGenerator(Loader, module_path=args.input_path).from_args(args)
  90. runner = ArgParserGenerator(Runner).from_args(args)
  91. LOGGER.info(f"Loading {args.input_path}")
  92. model = loader.load(args.input_path)
  93. with runner.init_inference(model=model) as runner_session, NpzWriter(args.output_dir) as writer:
  94. get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
  95. dataloader_fn = ArgParserGenerator(get_dataloader_fn).from_args(args)
  96. LOGGER.info(f"Data loader initialized; Running inference")
  97. for ids, x, y_real in tqdm(dataloader_fn(), unit="batch", mininterval=10):
  98. y_pred = runner_session(x)
  99. data = _verify_and_format_dump(args, ids=ids, x=x, y_pred=y_pred, y_real=y_real)
  100. writer.write(**data)
  101. LOGGER.info(f"Inference finished")
# Script entry point (PEP 299-style guard keeps import side-effect free).
if __name__ == "__main__":
    main()