benchmarking.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from mxnet.io import DataIter
  15. import time
  16. class BenchmarkingDataIter:
  17. def __init__(self, data_iter, benchmark_iters=None):
  18. self.data_iter = data_iter
  19. self.benchmark_iters = benchmark_iters
  20. self.overall_time = 0
  21. self.num = 0
  22. def __iter__(self):
  23. iter(self.data_iter)
  24. return self
  25. def next(self):
  26. if self.benchmark_iters is not None and self.num >= self.benchmark_iters:
  27. raise StopIteration
  28. try:
  29. start_time = time.time()
  30. ret = self.data_iter.next()
  31. end_time = time.time()
  32. except StopIteration:
  33. if self.benchmark_iters is None:
  34. raise
  35. self.data_iter.reset()
  36. start_time = time.time()
  37. ret = self.data_iter.next()
  38. end_time = time.time()
  39. if self.num != 0:
  40. self.overall_time += end_time - start_time
  41. self.num += 1
  42. return ret
  43. def __next__(self):
  44. return self.next()
  45. def __getattr__(self, attr):
  46. return getattr(self.data_iter, attr)
  47. def get_avg_time(self):
  48. if self.num <= 1:
  49. avg = float('nan')
  50. else:
  51. avg = self.overall_time / (self.num - 1)
  52. return avg
  53. def reset(self):
  54. self.overall_time = 0
  55. self.num = 0
  56. self.data_iter.reset()