feature_spec.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from typing import Dict
  16. import numpy as np
  17. import yaml
  18. from sim.data.defaults import (CARDINALITY_SELECTOR, DIMENSIONS_SELECTOR, DTYPE_SELECTOR, LABEL_CHANNEL,
  19. NEGATIVE_HISTORY_CHANNEL, POSITIVE_HISTORY_CHANNEL, TARGET_ITEM_FEATURES_CHANNEL,
  20. TEST_MAPPING, TRAIN_MAPPING, USER_FEATURES_CHANNEL)
  21. class FeatureSpec:
  22. def __init__(self, feature_spec=None, source_spec=None, channel_spec=None, metadata=None, base_directory=None):
  23. self.feature_spec: Dict = feature_spec if feature_spec is not None else {}
  24. self.source_spec: Dict = source_spec if source_spec is not None else {}
  25. self.channel_spec: Dict = channel_spec if channel_spec is not None else {}
  26. self.metadata: Dict = metadata if metadata is not None else {}
  27. self.base_directory: str = base_directory
  28. @classmethod
  29. def from_yaml(cls, path):
  30. with open(path, 'r') as feature_spec_file:
  31. base_directory = os.path.dirname(path)
  32. feature_spec = yaml.safe_load(feature_spec_file)
  33. return cls.from_dict(feature_spec, base_directory=base_directory)
  34. @classmethod
  35. def from_dict(cls, source_dict, base_directory):
  36. return cls(base_directory=base_directory, **source_dict)
  37. def to_dict(self):
  38. attributes_to_dump = ['feature_spec', 'source_spec', 'channel_spec', 'metadata']
  39. return {attr: self.__dict__[attr] for attr in attributes_to_dump}
  40. def to_string(self):
  41. return yaml.dump(self.to_dict())
  42. def to_yaml(self, output_path=None):
  43. if not output_path:
  44. output_path = self.base_directory + '/feature_spec.yaml'
  45. with open(output_path, 'w') as output_file:
  46. print(yaml.dump(self.to_dict()), file=output_file)
  47. @staticmethod
  48. def get_default_features_names(number_of_user_features, number_of_item_features):
  49. user_feature_fstring = 'user_feat_{}'
  50. item_feature_fstring = 'item_feat_{}_{}'
  51. label_feature_name = "label"
  52. item_channels_feature_name_suffixes = ['trgt', 'pos', 'neg']
  53. user_features_names = [user_feature_fstring.format(i) for i in range(number_of_user_features)]
  54. item_features_names = [item_feature_fstring.format(i, channel_suffix)
  55. for channel_suffix in item_channels_feature_name_suffixes
  56. for i in range(number_of_item_features)]
  57. return [label_feature_name] + user_features_names + item_features_names
  58. @staticmethod
  59. def get_default_feature_spec(user_features_cardinalities, item_features_cardinalities, max_seq_len):
  60. number_of_user_features = len(user_features_cardinalities)
  61. number_of_item_features = len(item_features_cardinalities)
  62. all_features_names = FeatureSpec.get_default_features_names(number_of_user_features, number_of_item_features)
  63. user_features = {
  64. f_name: {
  65. DTYPE_SELECTOR: str(np.dtype(np.int64)),
  66. CARDINALITY_SELECTOR: int(cardinality)
  67. } for i, (f_name, cardinality)
  68. in enumerate(zip(all_features_names[1:1+number_of_user_features], user_features_cardinalities))
  69. }
  70. item_channels = [TARGET_ITEM_FEATURES_CHANNEL, POSITIVE_HISTORY_CHANNEL, NEGATIVE_HISTORY_CHANNEL]
  71. item_channels_feature_dicts = [{} for _ in range(len(item_channels))]
  72. item_channels_info = list(zip(item_channels, item_channels_feature_dicts))
  73. for i, cardinality in enumerate(item_features_cardinalities):
  74. for j, (channel, dictionary) in enumerate(item_channels_info):
  75. feature_name = all_features_names[1 + number_of_user_features + i + j * number_of_item_features]
  76. dictionary[feature_name] = {
  77. DTYPE_SELECTOR: str(np.dtype(np.int64)),
  78. CARDINALITY_SELECTOR: int(cardinality)
  79. }
  80. if channel != TARGET_ITEM_FEATURES_CHANNEL:
  81. dictionary[feature_name][DIMENSIONS_SELECTOR] = [max_seq_len]
  82. feature_spec = {
  83. feat_name: feat_spec
  84. for dictionary in [user_features] + item_channels_feature_dicts
  85. for feat_name, feat_spec in dictionary.items()
  86. }
  87. feature_spec[all_features_names[0]] = {DTYPE_SELECTOR: str(np.dtype(np.bool))}
  88. channel_spec = {
  89. USER_FEATURES_CHANNEL: list(user_features),
  90. TARGET_ITEM_FEATURES_CHANNEL: list(item_channels_feature_dicts[0]),
  91. POSITIVE_HISTORY_CHANNEL: list(item_channels_feature_dicts[1]),
  92. NEGATIVE_HISTORY_CHANNEL: list(item_channels_feature_dicts[2]),
  93. LABEL_CHANNEL: all_features_names[:1]
  94. }
  95. source_spec = {
  96. split: [
  97. {
  98. 'type': 'tfrecord',
  99. 'features': all_features_names,
  100. 'files': []
  101. }
  102. ] for split in [TRAIN_MAPPING, TEST_MAPPING]
  103. }
  104. return FeatureSpec(feature_spec=feature_spec, channel_spec=channel_spec, source_spec=source_spec)