# evaluate_accuracy.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# author: Tomasz Grel ([email protected])
import argparse
import os

import dllogger
import numpy as np

import dataloading.feature_spec
import deployment.hps.triton_ensemble_wrapper
import deployment.tf.triton_ensemble_wrapper
from dataloading.dataloader import create_input_pipelines
from nn.evaluator import Evaluator
from utils.logging import IterTimer, init_logging
  26. def log_results(auc, test_loss, latencies, batch_size, compute_latencies=False, warmup_steps=10):
  27. # don't benchmark the first few warmup steps
  28. latencies = latencies[warmup_steps:]
  29. result_data = {
  30. 'mean_inference_throughput': batch_size / np.mean(latencies),
  31. 'mean_inference_latency': np.mean(latencies)
  32. }
  33. if compute_latencies:
  34. for percentile in [90, 95, 99]:
  35. result_data[f'p{percentile}_inference_latency'] = np.percentile(latencies, percentile)
  36. result_data['auc'] = auc
  37. result_data['test_loss'] = test_loss
  38. dllogger.log(data=result_data, step=tuple())
  39. def parse_args():
  40. parser = argparse.ArgumentParser(description='')
  41. parser.add_argument('--dataset_path', type=str, required=True, help='')
  42. parser.add_argument('--dataset_type', default='tf_raw', type=str, help='')
  43. parser.add_argument('--feature_spec', default='feature_spec.yaml', type=str, help='')
  44. parser.add_argument('--batch_size', type=int, default=32768, help='Batch size')
  45. parser.add_argument('--auc_thresholds', type=int, default=8000, help='')
  46. parser.add_argument('--max_steps', type=int, default=None, help='')
  47. parser.add_argument('--print_freq', type=int, default=10, help='')
  48. parser.add_argument('--log_path', type=str, default='dlrm_tf_log.json', help='triton_inference_log.json')
  49. parser.add_argument('--verbose', action='store_true', default=False, help='')
  50. parser.add_argument('--test_on_train', action='store_true', default=False,
  51. help='Run validation on the training set.')
  52. parser.add_argument('--fused_embedding', action='store_true', default=False,
  53. help='Fuse the embedding table together for better GPU utilization.')
  54. parser.add_argument("--model_name", type=str, help="The name of the model used for inference.", required=True)
  55. parser.add_argument("--sparse_input_format", type=str, choices=["tf-savedmodel", "hps"],
  56. required=True, default="tf-savedmodel")
  57. args = parser.parse_args()
  58. return args
  59. def main():
  60. args = parse_args()
  61. init_logging(log_path=args.log_path, params_dict=args.__dict__)
  62. fspec = dataloading.feature_spec.FeatureSpec.from_yaml(os.path.join(args.dataset_path, args.feature_spec))
  63. num_tables = len(fspec.get_categorical_sizes())
  64. table_ids = list(range(num_tables)) # possibly wrong ordering, to be tested
  65. train_pipeline, validation_pipeline = create_input_pipelines(dataset_type=args.dataset_type,
  66. dataset_path=args.dataset_path,
  67. train_batch_size=args.batch_size,
  68. test_batch_size=args.batch_size,
  69. table_ids=table_ids,
  70. feature_spec=args.feature_spec,
  71. rank=0, world_size=1)
  72. if args.test_on_train:
  73. validation_pipeline = train_pipeline
  74. if args.sparse_input_format == 'hps':
  75. wrapper_cls = deployment.hps.triton_ensemble_wrapper.RecsysTritonEnsemble
  76. else:
  77. wrapper_cls = deployment.tf.triton_ensemble_wrapper.RecsysTritonEnsemble
  78. model = wrapper_cls(model_name=args.model_name, num_tables=num_tables, verbose=args.verbose,
  79. categorical_sizes=fspec.get_categorical_sizes(), fused_embedding=args.fused_embedding)
  80. timer = IterTimer(train_batch_size=args.batch_size, test_batch_size=args.batch_size,
  81. optimizer=None, print_freq=args.print_freq, enabled=True)
  82. evaluator = Evaluator(model=model, timer=timer, auc_thresholds=args.auc_thresholds,
  83. max_steps=args.max_steps, cast_dtype=None)
  84. auc, test_loss, latencies = evaluator(validation_pipeline=validation_pipeline)
  85. log_results(auc, test_loss, latencies, batch_size=args.batch_size)
  86. print('DONE')
  87. if __name__ == '__main__':
  88. main()