# dense_model.py
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# author: Tomasz Grel ([email protected])
  16. import json
  17. import tensorflow.keras.initializers as initializers
  18. import math
  19. from tensorflow.python.keras.saving.saving_utils import model_input_signature
  20. from .dcn import CrossNetwork
  21. from . import interaction
  22. import tensorflow as tf
  23. import horovod.tensorflow as hvd
  24. try:
  25. from tensorflow_dot_based_interact.python.ops import dot_based_interact_ops
  26. except ImportError:
  27. print('WARNING: Could not import the custom dot-interaction kernels')
  28. dense_model_parameters = ['embedding_dim', 'interaction', 'bottom_mlp_dims',
  29. 'top_mlp_dims', 'num_numerical_features', 'categorical_cardinalities',
  30. 'transpose', 'num_cross_layers', 'cross_layer_projection_dim',
  31. 'batch_size']
  32. class DenseModel(tf.keras.Model):
  33. def __init__(self, **kwargs):
  34. super(DenseModel, self).__init__()
  35. for field in dense_model_parameters:
  36. self.__dict__[field] = kwargs[field]
  37. self.num_all_categorical_features = len(self.categorical_cardinalities)
  38. self.bottom_mlp_dims = [int(d) for d in self.bottom_mlp_dims]
  39. self.top_mlp_dims = [int(d) for d in self.top_mlp_dims]
  40. if self.interaction != 'cross' and any(dim != self.embedding_dim[0] for dim in self.embedding_dim):
  41. raise ValueError(f'For DLRM all embedding dimensions should be equal, '
  42. f'got interaction={interaction}, embedding_dim={self.embedding_dim}')
  43. if self.interaction != 'cross' and self.bottom_mlp_dims[-1] != self.embedding_dim[0]:
  44. raise ValueError(f'Final dimension of the Bottom MLP should match embedding dimension. '
  45. f'Got: {self.bottom_mlp_dims[-1]} and {self.embedding_dim} respectively.')
  46. self._create_interaction_op()
  47. self._create_bottom_mlp()
  48. self._create_top_mlp()
  49. self.bottom_mlp_padding = self._compute_padding(num_features=self.num_numerical_features)
  50. self.top_mlp_padding = self._compute_padding(num_features=self._get_top_mlp_input_features())
  51. def _create_interaction_op(self):
  52. if self.interaction == 'dot_custom_cuda':
  53. self.interact_op = dot_based_interact_ops.dot_based_interact
  54. elif self.interaction == 'dot_tensorflow':
  55. # TODO: add support for datasets with no dense features
  56. self.interact_op = interaction.DotInteractionGather(num_features=self.num_all_categorical_features + 1)
  57. elif self.interaction == 'cross':
  58. self.interact_op = CrossNetwork(num_layers=self.num_cross_layers,
  59. projection_dim=self.cross_layer_projection_dim)
  60. else:
  61. raise ValueError(f'Unknown interaction {self.interaction}')
  62. @staticmethod
  63. def _compute_padding(num_features, multiple=8):
  64. pad_to = math.ceil(num_features / multiple) * multiple
  65. return pad_to - num_features
  66. def _get_top_mlp_input_features(self):
  67. if self.interaction == 'cross':
  68. num_features = sum(self.embedding_dim)
  69. if self.num_numerical_features != 0:
  70. num_features += self.bottom_mlp_dims[-1]
  71. return num_features
  72. else:
  73. num_features = self.num_all_categorical_features
  74. if self.num_numerical_features != 0:
  75. num_features += 1
  76. num_features = num_features * (num_features - 1)
  77. num_features = num_features // 2
  78. num_features = num_features + self.bottom_mlp_dims[-1]
  79. return num_features
  80. def _create_bottom_mlp(self):
  81. self.bottom_mlp_layers = []
  82. for dim in self.bottom_mlp_dims:
  83. kernel_initializer = initializers.GlorotNormal()
  84. bias_initializer = initializers.RandomNormal(stddev=math.sqrt(1. / dim))
  85. l = tf.keras.layers.Dense(dim, activation='relu',
  86. kernel_initializer=kernel_initializer,
  87. bias_initializer=bias_initializer)
  88. self.bottom_mlp_layers.append(l)
  89. def _create_top_mlp(self):
  90. self.top_mlp = []
  91. for i, dim in enumerate(self.top_mlp_dims):
  92. if i == len(self.top_mlp_dims) - 1:
  93. # final layer
  94. activation = 'linear'
  95. else:
  96. activation = 'relu'
  97. kernel_initializer = initializers.GlorotNormal()
  98. bias_initializer = initializers.RandomNormal(stddev=math.sqrt(1. / dim))
  99. l = tf.keras.layers.Dense(dim, activation=activation,
  100. kernel_initializer=kernel_initializer,
  101. bias_initializer=bias_initializer)
  102. self.top_mlp.append(l)
  103. def transpose_nonequal_embedding_dim(self, embedding_outputs, numerical_features):
  104. # We get a table-major format here for inference,
  105. # but the sizes of the tables are not the same.
  106. # Therefore a simple transposition will not work,
  107. # we need to perform multiple splits and concats instead.
  108. # TODO: test this.
  109. embedding_outputs = tf.reshape(embedding_outputs, shape=[-1])
  110. batch_size = numerical_features.shape[0]
  111. split_sizes = [batch_size * dim for dim in self.embedding_dim]
  112. embedding_outputs = tf.split(embedding_outputs, num_or_size_splits=split_sizes)
  113. embedding_outputs = [tf.split(eout, num_or_size_splits=dim) for eout, dim in zip(embedding_outputs,
  114. self.emdedding_dim)]
  115. transposed_outputs = [] * batch_size
  116. for i, o in enumerate(transposed_outputs):
  117. ith_sample = [out[i] for out in embedding_outputs]
  118. ith_sample = tf.concat(ith_sample, axis=1)
  119. transposed_outputs[i] = ith_sample
  120. transposed_outputs = tf.concat(transposed_outputs, axis=0)
  121. return tf.reshape(transposed_outputs, shape=[batch_size, sum(self.embedding_dim)])
  122. def transpose_input(self, embedding_outputs, numerical_features):
  123. if any(dim != self.embedding_dim[0] for dim in self.embedding_dim):
  124. return self.transpose_nonequal_embedding_dim(embedding_outputs, numerical_features)
  125. else:
  126. embedding_outputs = tf.reshape(embedding_outputs, shape=[self.num_all_categorical_features, -1, self.embedding_dim[0]])
  127. return tf.transpose(embedding_outputs, perm=[1, 0, 2])
  128. def reshape_input(self, embedding_outputs):
  129. if self.interaction == 'cross':
  130. return tf.reshape(embedding_outputs, shape=[-1, sum(self.embedding_dim)])
  131. else:
  132. return tf.reshape(embedding_outputs, shape=[-1, self.num_all_categorical_features, self.embedding_dim[0]])
  133. @tf.function
  134. def call(self, numerical_features, embedding_outputs, sigmoid=False, training=False):
  135. numerical_features = tf.reshape(numerical_features, shape=[-1, self.num_numerical_features])
  136. bottom_mlp_out = self._call_bottom_mlp(numerical_features, training)
  137. if self.transpose:
  138. embedding_outputs = self.transpose_input(embedding_outputs, numerical_features)
  139. embedding_outputs = self.reshape_input(embedding_outputs)
  140. x = self._call_interaction(embedding_outputs, bottom_mlp_out)
  141. x = self._call_top_mlp(x)
  142. if sigmoid:
  143. x = tf.math.sigmoid(x)
  144. x = tf.cast(x, tf.float32)
  145. return x
  146. def _pad_bottom_mlp_input(self, numerical_features, training):
  147. if training:
  148. # When training, padding with a statically fixed batch size so that XLA has better shape information.
  149. # This yields a significant (~15%) speedup for singleGPU DLRM.
  150. padding = tf.zeros(shape=[self.batch_size // hvd.size(), self.bottom_mlp_padding],
  151. dtype=self.compute_dtype)
  152. x = tf.concat([numerical_features, padding], axis=1)
  153. else:
  154. # For inference, use tf.pad.
  155. # This way inference can be performed with any batch size on the deployed SavedModel.
  156. x = tf.pad(numerical_features, [[0, 0], [0, self.bottom_mlp_padding]])
  157. return x
  158. def _call_bottom_mlp(self, numerical_features, training):
  159. numerical_features = tf.cast(numerical_features, dtype=self.compute_dtype)
  160. x = self._pad_bottom_mlp_input(numerical_features, training)
  161. with tf.name_scope('bottom_mlp'):
  162. for l in self.bottom_mlp_layers:
  163. x = l(x)
  164. x = tf.expand_dims(x, axis=1)
  165. bottom_mlp_out = x
  166. return bottom_mlp_out
  167. def _call_interaction(self, embedding_outputs, bottom_mlp_out):
  168. if self.interaction == 'cross':
  169. bottom_mlp_out = tf.reshape(bottom_mlp_out, [-1, self.bottom_mlp_dims[-1]])
  170. x = tf.concat([bottom_mlp_out, embedding_outputs], axis=1)
  171. x = self.interact_op(x)
  172. else:
  173. bottom_part_output = tf.concat([bottom_mlp_out, embedding_outputs], axis=1)
  174. x = tf.reshape(bottom_part_output, shape=[-1, self.num_all_categorical_features + 1, self.embedding_dim[0]])
  175. bottom_mlp_out = tf.reshape(bottom_mlp_out, shape=[-1, self.bottom_mlp_dims[-1]])
  176. x = self.interact_op(x, bottom_mlp_out)
  177. return x
  178. def _call_top_mlp(self, x):
  179. if self.interaction != 'dot_custom_cuda':
  180. x = tf.reshape(x, [-1, self._get_top_mlp_input_features()])
  181. x = tf.pad(x, [[0, 0], [0, self.top_mlp_padding]])
  182. with tf.name_scope('top_mlp'):
  183. for i, l in enumerate(self.top_mlp):
  184. x = l(x)
  185. return x
  186. def save_model(self, path, save_input_signature=False):
  187. if save_input_signature:
  188. input_sig = model_input_signature(self, keep_original_batch_size=True)
  189. call_graph = tf.function(self)
  190. signatures = call_graph.get_concrete_function(input_sig[0])
  191. else:
  192. signatures = None
  193. tf.keras.models.save_model(model=self, filepath=path, overwrite=True, signatures=signatures)
  194. def force_initialization(self, batch_size=64, training=False, flattened_input=True):
  195. if flattened_input:
  196. embeddings_output = tf.zeros([batch_size * sum(self.embedding_dim)])
  197. numerical_input = tf.zeros([batch_size * self.num_numerical_features])
  198. else:
  199. embeddings_output = tf.zeros([batch_size, sum(self.embedding_dim)])
  200. numerical_input = tf.zeros([batch_size, self.num_numerical_features])
  201. _ = self(numerical_input, embeddings_output, sigmoid=False, training=training)
  202. @staticmethod
  203. def load_model(path):
  204. print('Loading a saved model from', path)
  205. loaded = tf.keras.models.load_model(path)
  206. return loaded
  207. def save_config(self, path):
  208. config = {k : self.__dict__[k] for k in dense_model_parameters}
  209. with open(path, 'w') as f:
  210. json.dump(obj=config, fp=f, indent=4)
  211. @staticmethod
  212. def from_config(path):
  213. with open(path) as f:
  214. config = json.load(fp=f)
  215. return DenseModel(**config)