model.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # author: Tomasz Grel ([email protected])
  16. import tensorflow as tf
  17. import horovod.tensorflow as hvd
  18. import time
  19. import os
  20. from utils.distributed import dist_print
  21. from .dense_model import DenseModel, dense_model_parameters
  22. from .sparse_model import SparseModel, sparse_model_parameters
  23. from .nn_utils import create_inputs_dict
  24. class Model(tf.keras.Model):
  25. def __init__(self, **kwargs):
  26. super(Model, self).__init__()
  27. if kwargs:
  28. dense_model_kwargs = {k:kwargs[k] for k in dense_model_parameters}
  29. self.dense_model = DenseModel(**dense_model_kwargs)
  30. sparse_model_kwargs = {k:kwargs[k] for k in sparse_model_parameters}
  31. self.sparse_model = SparseModel(**sparse_model_kwargs)
  32. @staticmethod
  33. def create_from_checkpoint(checkpoint_path):
  34. if checkpoint_path is None:
  35. return None
  36. model = Model()
  37. model.dense_model = DenseModel.from_config(os.path.join(checkpoint_path, 'dense', 'config.json'))
  38. model.sparse_model = SparseModel.from_config(os.path.join(checkpoint_path, 'sparse', 'config.json'))
  39. model.restore_checkpoint(checkpoint_path)
  40. return model
  41. def force_initialization(self, global_batch_size):
  42. numerical_features = tf.zeros(shape=[global_batch_size // hvd.size(),
  43. self.dense_model.num_numerical_features])
  44. categorical_features = [tf.zeros(shape=[global_batch_size, 1], dtype=tf.int32)
  45. for _ in range(len(self.sparse_model.get_local_table_ids(hvd.rank())))]
  46. inputs = create_inputs_dict(numerical_features, categorical_features)
  47. self(inputs=inputs)
  48. @tf.function
  49. def call(self, inputs, sigmoid=False, training=False):
  50. numerical_features, cat_features = list(inputs.values())
  51. embedding_outputs = self.sparse_model(cat_features)
  52. embedding_outputs = tf.reshape(embedding_outputs, shape=[-1])
  53. x = self.dense_model(numerical_features, embedding_outputs, sigmoid=sigmoid, training=training)
  54. return x
  55. def save_checkpoint(self, checkpoint_path):
  56. dist_print('Saving a checkpoint...')
  57. begin_save = time.time()
  58. os.makedirs(checkpoint_path, exist_ok=True)
  59. if hvd.rank() == 0:
  60. dense_checkpoint_dir = os.path.join(checkpoint_path, 'dense')
  61. os.makedirs(dense_checkpoint_dir, exist_ok=True)
  62. self.dense_model.save_config(os.path.join(dense_checkpoint_dir, 'config.json'))
  63. self.dense_model.save_weights(os.path.join(dense_checkpoint_dir, 'dense'))
  64. sparse_checkpoint_dir = os.path.join(checkpoint_path, 'sparse')
  65. os.makedirs(sparse_checkpoint_dir, exist_ok=True)
  66. self.sparse_model.save_config(os.path.join(sparse_checkpoint_dir, 'config.json'))
  67. self.sparse_model.save_checkpoint(sparse_checkpoint_dir)
  68. end_save = time.time()
  69. dist_print('Saved a checkpoint to ', checkpoint_path)
  70. dist_print(f'Saving a checkpoint took {end_save - begin_save:.3f}')
  71. def restore_checkpoint(self, checkpoint_path):
  72. begin = time.time()
  73. dist_print('Restoring a checkpoint...')
  74. local_batch = 64
  75. self.force_initialization(global_batch_size=hvd.size()*local_batch)
  76. dense_checkpoint_path = os.path.join(checkpoint_path, 'dense', 'dense')
  77. self.dense_model.load_weights(dense_checkpoint_path)
  78. sparse_checkpoint_dir = os.path.join(checkpoint_path, 'sparse')
  79. self.sparse_model.load_checkpoint(sparse_checkpoint_dir)
  80. end = time.time()
  81. dist_print(f'Restoring a checkpoint took: {end-begin:.3f} seconds')
  82. return self