dcn.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # Copyright 2021 The TensorFlow Recommenders Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. #
  16. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
  17. #
  18. # Licensed under the Apache License, Version 2.0 (the "License");
  19. # you may not use this file except in compliance with the License.
  20. # You may obtain a copy of the License at
  21. #
  22. # http://www.apache.org/licenses/LICENSE-2.0
  23. #
  24. # Unless required by applicable law or agreed to in writing, software
  25. # distributed under the License is distributed on an "AS IS" BASIS,
  26. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  27. # See the License for the specific language governing permissions and
  28. # limitations under the License.
  29. #
  30. """Implements `Cross` Layer, the cross layer in Deep & Cross Network (DCN)."""
  31. from typing import Union, Text, Optional
  32. import tensorflow as tf
  33. @tf.keras.utils.register_keras_serializable()
  34. class Cross(tf.keras.layers.Layer):
  35. """Cross Layer in Deep & Cross Network to learn explicit feature interactions.
  36. A layer that creates explicit and bounded-degree feature interactions
  37. efficiently. The `call` method accepts `inputs` as a tuple of size 2
  38. tensors. The first input `x0` is the base layer that contains the original
  39. features (usually the embedding layer); the second input `xi` is the output
  40. of the previous `Cross` layer in the stack, i.e., the i-th `Cross`
  41. layer. For the first `Cross` layer in the stack, x0 = xi.
  42. The output is x_{i+1} = x0 .* (W * xi + bias + diag_scale * xi) + xi,
  43. where .* designates elementwise multiplication, W could be a full-rank
  44. matrix, or a low-rank matrix U*V to reduce the computational cost, and
  45. diag_scale increases the diagonal of W to improve training stability (
  46. especially for the low-rank case).
  47. References:
  48. 1. [R. Wang et al.](https://arxiv.org/pdf/2008.13535.pdf)
  49. See Eq. (1) for full-rank and Eq. (2) for low-rank version.
  50. 2. [R. Wang et al.](https://arxiv.org/pdf/1708.05123.pdf)
  51. Example:
  52. ```python
  53. # after embedding layer in a functional model:
  54. input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64)
  55. x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6)
  56. x1 = Cross()(x0, x0)
  57. x2 = Cross()(x0, x1)
  58. logits = tf.keras.layers.Dense(units=10)(x2)
  59. model = tf.keras.Model(input, logits)
  60. ```
  61. Args:
  62. projection_dim: project dimension to reduce the computational cost.
  63. Default is `None` such that a full (`input_dim` by `input_dim`) matrix
  64. W is used. If enabled, a low-rank matrix W = U*V will be used, where U
  65. is of size `input_dim` by `projection_dim` and V is of size
  66. `projection_dim` by `input_dim`. `projection_dim` need to be smaller
  67. than `input_dim`/2 to improve the model efficiency. In practice, we've
  68. observed that `projection_dim` = d/4 consistently preserved the
  69. accuracy of a full-rank version.
  70. diag_scale: a non-negative float used to increase the diagonal of the
  71. kernel W by `diag_scale`, that is, W + diag_scale * I, where I is an
  72. identity matrix.
  73. use_bias: whether to add a bias term for this layer. If set to False,
  74. no bias term will be used.
  75. kernel_initializer: Initializer to use on the kernel matrix.
  76. bias_initializer: Initializer to use on the bias vector.
  77. kernel_regularizer: Regularizer to use on the kernel matrix.
  78. bias_regularizer: Regularizer to use on bias vector.
  79. Input shape: A tuple of 2 (batch_size, `input_dim`) dimensional inputs.
  80. Output shape: A single (batch_size, `input_dim`) dimensional output.
  81. """
  82. def __init__(
  83. self,
  84. projection_dim: Optional[int] = None,
  85. diag_scale: Optional[float] = 0.0,
  86. use_bias: bool = True,
  87. kernel_initializer: Union[
  88. Text, tf.keras.initializers.Initializer] = "truncated_normal",
  89. bias_initializer: Union[Text,
  90. tf.keras.initializers.Initializer] = "zeros",
  91. kernel_regularizer: Union[Text, None,
  92. tf.keras.regularizers.Regularizer] = None,
  93. bias_regularizer: Union[Text, None,
  94. tf.keras.regularizers.Regularizer] = None,
  95. **kwargs):
  96. super(Cross, self).__init__(**kwargs)
  97. self._projection_dim = projection_dim
  98. self._diag_scale = diag_scale
  99. self._use_bias = use_bias
  100. self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
  101. self._bias_initializer = tf.keras.initializers.get(bias_initializer)
  102. self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
  103. self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
  104. self._input_dim = None
  105. self._supports_masking = True
  106. if self._diag_scale < 0:
  107. raise ValueError(
  108. "`diag_scale` should be non-negative. Got `diag_scale` = {}".format(
  109. self._diag_scale))
  110. def build(self, input_shape):
  111. last_dim = input_shape[-1]
  112. if self._projection_dim is None:
  113. self._dense = tf.keras.layers.Dense(
  114. last_dim,
  115. kernel_initializer=self._kernel_initializer,
  116. bias_initializer=self._bias_initializer,
  117. kernel_regularizer=self._kernel_regularizer,
  118. bias_regularizer=self._bias_regularizer,
  119. use_bias=self._use_bias,
  120. )
  121. else:
  122. self._dense_u = tf.keras.layers.Dense(
  123. self._projection_dim,
  124. kernel_initializer=self._kernel_initializer,
  125. kernel_regularizer=self._kernel_regularizer,
  126. use_bias=False,
  127. )
  128. self._dense_v = tf.keras.layers.Dense(
  129. last_dim,
  130. kernel_initializer=self._kernel_initializer,
  131. bias_initializer=self._bias_initializer,
  132. kernel_regularizer=self._kernel_regularizer,
  133. bias_regularizer=self._bias_regularizer,
  134. use_bias=self._use_bias,
  135. )
  136. self.built = True
  137. def call(self, x0: tf.Tensor, x: Optional[tf.Tensor] = None) -> tf.Tensor:
  138. """Computes the feature cross.
  139. Args:
  140. x0: The input tensor
  141. x: Optional second input tensor. If provided, the layer will compute
  142. crosses between x0 and x; if not provided, the layer will compute
  143. crosses between x0 and itself.
  144. Returns:
  145. Tensor of crosses.
  146. """
  147. if not self.built:
  148. self.build(x0.shape)
  149. if x is None:
  150. x = x0
  151. if x0.shape[-1] != x.shape[-1]:
  152. raise ValueError(
  153. "`x0` and `x` dimension mismatch! Got `x0` dimension {}, and x "
  154. "dimension {}. This case is not supported yet.".format(
  155. x0.shape[-1], x.shape[-1]))
  156. if self._projection_dim is None:
  157. prod_output = self._dense(x)
  158. else:
  159. prod_output = self._dense_v(self._dense_u(x))
  160. if self._diag_scale:
  161. prod_output = prod_output + self._diag_scale * x
  162. return x0 * prod_output + x
  163. def get_config(self):
  164. config = {
  165. "projection_dim":
  166. self._projection_dim,
  167. "diag_scale":
  168. self._diag_scale,
  169. "use_bias":
  170. self._use_bias,
  171. "kernel_initializer":
  172. tf.keras.initializers.serialize(self._kernel_initializer),
  173. "bias_initializer":
  174. tf.keras.initializers.serialize(self._bias_initializer),
  175. "kernel_regularizer":
  176. tf.keras.regularizers.serialize(self._kernel_regularizer),
  177. "bias_regularizer":
  178. tf.keras.regularizers.serialize(self._bias_regularizer),
  179. }
  180. base_config = super(Cross, self).get_config()
  181. return dict(list(base_config.items()) + list(config.items()))
  182. class CrossNetwork(tf.Module):
  183. def __init__(self, num_layers, projection_dim=None):
  184. self.cross_layers = []
  185. for _ in range(num_layers):
  186. self.cross_layers.append(Cross(projection_dim=projection_dim))
  187. def __call__(self, x0):
  188. x = x0
  189. for cl in self.cross_layers:
  190. x = cl(x0=x0, x=x)
  191. return x