evaluator.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# author: Tomasz Grel ([email protected])
  16. import tensorflow as tf
  17. import time
  18. from .nn_utils import create_inputs_dict
  19. class Evaluator:
  20. def __init__(self, model, timer, auc_thresholds, max_steps=None, cast_dtype=None, distributed=False):
  21. self.model = model
  22. self.timer = timer
  23. self.max_steps = max_steps
  24. self.cast_dtype = cast_dtype
  25. self.distributed = distributed
  26. if self.distributed:
  27. import horovod.tensorflow as hvd
  28. self.hvd = hvd
  29. else:
  30. self.hvd = None
  31. self.auc_metric = tf.keras.metrics.AUC(num_thresholds=auc_thresholds, curve='ROC',
  32. summation_method='interpolation', from_logits=True)
  33. self.bce_op = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE, from_logits=True)
  34. def _reset(self):
  35. self.latencies, self.all_test_losses = [], []
  36. self.auc_metric.reset_state()
  37. @tf.function
  38. def update_auc_metric(self, labels, y_pred):
  39. self.auc_metric.update_state(labels, y_pred)
  40. @tf.function
  41. def compute_bce_loss(self, labels, y_pred):
  42. return self.bce_op(labels, y_pred)
  43. def _step(self, pipe):
  44. begin = time.time()
  45. batch = pipe.get_next()
  46. (numerical_features, categorical_features), labels = batch
  47. if self.cast_dtype is not None:
  48. numerical_features = tf.cast(numerical_features, self.cast_dtype)
  49. inputs = create_inputs_dict(numerical_features, categorical_features)
  50. y_pred = self.model(inputs, sigmoid=False, training=False)
  51. end = time.time()
  52. self.latencies.append(end - begin)
  53. if self.distributed:
  54. y_pred = self.hvd.allgather(y_pred)
  55. labels = self.hvd.allgather(labels)
  56. self.timer.step_test()
  57. if not self.distributed or self.hvd.rank() == 0:
  58. self.update_auc_metric(labels, y_pred)
  59. test_loss = self.compute_bce_loss(labels, y_pred)
  60. self.all_test_losses.append(test_loss)
  61. def __call__(self, validation_pipeline):
  62. self._reset()
  63. auc, test_loss = 0, 0
  64. pipe = iter(validation_pipeline.op())
  65. num_steps = len(validation_pipeline)
  66. if self.max_steps is not None and self.max_steps >= 0:
  67. num_steps = min(num_steps, self.max_steps)
  68. for _ in range(num_steps):
  69. self._step(pipe)
  70. if not self.distributed or self.hvd.rank() == 0:
  71. auc = self.auc_metric.result().numpy().item()
  72. test_loss = tf.reduce_mean(self.all_test_losses).numpy().item()
  73. return auc, test_loss, self.latencies