# client.py — Triton inference client for the DLRM recommender model.
  1. # Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import argparse
import json
from functools import partial

import numpy as np
import torch
from sklearn.metrics import roc_auc_score
from tensorrtserver.api import *
from tqdm import tqdm

from dlrm.data import data_loader
from dlrm.data.synthetic_dataset import SyntheticDataset
  23. def get_data_loader(batch_size, *, data_file, model_config):
  24. with open(model_config.dataset_config) as f:
  25. categorical_sizes = list(json.load(f).values())
  26. if data_file:
  27. data = data_loader.CriteoBinDataset(data_file=data_file,
  28. batch_size=batch_size, subset=None,
  29. numerical_features=model_config.num_numerical_features,
  30. categorical_features=len(categorical_sizes),
  31. online_shuffle=False)
  32. else:
  33. data = SyntheticDataset(num_entries=batch_size * 1024, batch_size=batch_size,
  34. dense_features=model_config.num_numerical_features,
  35. categorical_feature_sizes=categorical_sizes,
  36. device="cpu")
  37. return torch.utils.data.DataLoader(data,
  38. batch_size=None,
  39. num_workers=0,
  40. pin_memory=False)
  41. if __name__ == "__main__":
  42. parser = argparse.ArgumentParser()
  43. parser.add_argument("--triton-server-url", type=str, required=True,
  44. help="URL adress of trtion server (with port)")
  45. parser.add_argument("--triton-model-name", type=str, required=True,
  46. help="Triton deployed model name")
  47. parser.add_argument("--triton-model-version", type=int, default=-1,
  48. help="Triton model version")
  49. parser.add_argument("--protocol", type=str, default="HTTP",
  50. help="Communication protocol (HTTP/GRPC)")
  51. parser.add_argument("-v", "--verbose", action="store_true", default=False,
  52. help="Verbose mode.")
  53. parser.add_argument('-H', dest='http_headers', metavar="HTTP_HEADER",
  54. required=False, action='append',
  55. help='HTTP headers to add to inference server requests. ' +
  56. 'Format is -H"Header:Value".')
  57. parser.add_argument("--num_numerical_features", type=int, default=13)
  58. parser.add_argument("--dataset_config", type=str, required=True)
  59. parser.add_argument("--inference_data", type=str,
  60. help="Path to file with inference data.")
  61. parser.add_argument("--batch_size", type=int, default=1,
  62. help="Inference request batch size")
  63. parser.add_argument("--fp16", action="store_true", default=False,
  64. help="Use 16bit for numerical input")
  65. FLAGS = parser.parse_args()
  66. FLAGS.protocol = ProtocolType.from_str(FLAGS.protocol)
  67. # Create a health context, get the ready and live state of server.
  68. health_ctx = ServerHealthContext(FLAGS.triton_server_url, FLAGS.protocol,
  69. http_headers=FLAGS.http_headers, verbose=FLAGS.verbose)
  70. print("Health for model {}".format(FLAGS.triton_model_name))
  71. print("Live: {}".format(health_ctx.is_live()))
  72. print("Ready: {}".format(health_ctx.is_ready()))
  73. with ModelControlContext(FLAGS.triton_server_url, FLAGS.protocol) as ctx:
  74. ctx.load(FLAGS.triton_model_name)
  75. # Create a status context and get server status
  76. status_ctx = ServerStatusContext(FLAGS.triton_server_url, FLAGS.protocol, FLAGS.triton_model_name,
  77. http_headers=FLAGS.http_headers, verbose=FLAGS.verbose)
  78. print("Status for model {}".format(FLAGS.triton_model_name))
  79. print(status_ctx.get_server_status())
  80. # Create the inference context for the model.
  81. infer_ctx = InferContext(FLAGS.triton_server_url, FLAGS.protocol, FLAGS.triton_model_name,
  82. FLAGS.triton_model_version,
  83. http_headers=FLAGS.http_headers, verbose=FLAGS.verbose)
  84. dataloader = get_data_loader(FLAGS.batch_size,
  85. data_file=FLAGS.inference_data,
  86. model_config=FLAGS)
  87. results = []
  88. tgt_list = []
  89. for num, cat, target in tqdm(dataloader):
  90. num = num.cpu().numpy()
  91. if FLAGS.fp16:
  92. num = num.astype(np.float16)
  93. cat = cat.long().cpu().numpy()
  94. input_dict = {"input__0": tuple(num[i] for i in range(len(num))),
  95. "input__1": tuple(cat[i] for i in range(len(cat)))}
  96. output_keys = ["output__0"]
  97. output_dict = {x: InferContext.ResultFormat.RAW for x in output_keys}
  98. result = infer_ctx.run(input_dict, output_dict, len(num))
  99. results.append(result["output__0"])
  100. tgt_list.append(target.cpu().numpy())
  101. results = np.concatenate(results).squeeze()
  102. tgt_list = np.concatenate(tgt_list)
  103. score = roc_auc_score(tgt_list, results)
  104. print(F"Model score: {score}")
  105. with ModelControlContext(FLAGS.triton_server_url, FLAGS.protocol) as ctx:
  106. ctx.unload(FLAGS.triton_model_name)