|
|
@@ -20,7 +20,6 @@ import numpy as np
|
|
|
import json
|
|
|
|
|
|
from distributed_embeddings.python.layers import dist_model_parallel as dmp
|
|
|
-from distributed_embeddings.python.layers import embedding
|
|
|
|
|
|
from utils.checkpointing import get_variable_path
|
|
|
|
|
|
@@ -29,7 +28,19 @@ from .embedding import EmbeddingInitializer, DualEmbeddingGroup
|
|
|
|
|
|
sparse_model_parameters = ['use_mde_embeddings', 'embedding_dim', 'column_slice_threshold',
|
|
|
'embedding_zeros_initializer', 'embedding_trainable', 'categorical_cardinalities',
|
|
|
- 'concat_embedding', 'cpu_offloading_threshold_gb']
|
|
|
+ 'concat_embedding', 'cpu_offloading_threshold_gb',
|
|
|
+ 'data_parallel_input', 'row_slice_threshold', 'data_parallel_threshold']
|
|
|
+
|
|
|
+def _gigabytes_to_elements(gb, dtype=tf.float32):
+    """Convert a size given in gigabytes into a number of ``dtype`` elements.
+
+    One gigabyte is counted as 10**9 bytes.
+
+    Args:
+        gb: Size in gigabytes, or None.
+        dtype: Element type used to size one element. Only ``tf.float32``
+            (4 bytes per element) is handled.
+
+    Returns:
+        None when ``gb`` is None; otherwise ``gb * 10**9 / 4``. The division
+        is a true division, so the result is not truncated to an integer.
+
+    Raises:
+        ValueError: If ``gb`` is not None and ``dtype`` is not ``tf.float32``.
+    """
+    # Nothing to convert: hand None straight back, before any dtype check.
+    if gb is None:
+        return None
+
+    if dtype == tf.float32:
+        total_bytes = gb * 10**9
+        return total_bytes / 4  # a float32 element occupies 4 bytes
+
+    raise ValueError(f'Unsupported dtype: {dtype}')
|
|
|
|
|
|
class SparseModel(tf.keras.Model):
|
|
|
def __init__(self, **kwargs):
|
|
|
@@ -61,21 +72,21 @@ class SparseModel(tf.keras.Model):
|
|
|
for table_size, dim in zip(self.categorical_cardinalities, self.embedding_dim):
|
|
|
if hvd.rank() == 0:
|
|
|
print(f'Creating embedding with size: {table_size} {dim}')
|
|
|
- if self.use_mde_embeddings:
|
|
|
- e = embedding.Embedding(input_dim=table_size, output_dim=dim,
|
|
|
- combiner='sum', embeddings_initializer=initializer_cls())
|
|
|
- else:
|
|
|
- e = tf.keras.layers.Embedding(input_dim=table_size, output_dim=dim,
|
|
|
- embeddings_initializer=initializer_cls())
|
|
|
+ e = tf.keras.layers.Embedding(input_dim=table_size, output_dim=dim,
|
|
|
+ embeddings_initializer=initializer_cls())
|
|
|
self.embedding_layers.append(e)
|
|
|
|
|
|
+ gpu_size = _gigabytes_to_elements(self.cpu_offloading_threshold_gb)
|
|
|
self.embedding = dmp.DistributedEmbedding(self.embedding_layers,
|
|
|
strategy='memory_balanced',
|
|
|
- dp_input=False,
|
|
|
- column_slice_threshold=self.column_slice_threshold)
|
|
|
+ dp_input=self.data_parallel_input,
|
|
|
+ column_slice_threshold=self.column_slice_threshold,
|
|
|
+ row_slice_threshold=self.row_slice_threshold,
|
|
|
+ data_parallel_threshold=self.data_parallel_threshold,
|
|
|
+ gpu_embedding_size=gpu_size)
|
|
|
|
|
|
def get_local_table_ids(self, rank):
|
|
|
- if self.use_concat_embedding:
|
|
|
+ if self.use_concat_embedding or self.data_parallel_input:
|
|
|
return list(range(self.num_all_categorical_features))
|
|
|
else:
|
|
|
return self.embedding.strategy.input_ids_list[rank]
|
|
|
@@ -127,4 +138,10 @@ class SparseModel(tf.keras.Model):
|
|
|
def from_config(path):
|
|
|
with open(path) as f:
|
|
|
config = json.load(fp=f)
|
|
|
+ if 'data_parallel_input' not in config:
|
|
|
+ config['data_parallel_input'] = False
|
|
|
+ if 'row_slice_threshold' not in config:
|
|
|
+ config['row_slice_threshold'] = None
|
|
|
+ if 'data_parallel_threshold' not in config:
|
|
|
+ config['data_parallel_threshold'] = None
|
|
|
return SparseModel(**config)
|