dcnv2.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # author: Tomasz Grel ([email protected])
  16. from absl import app, flags
  def define_dcnv2_specific_flags():
      """Register the DCNv2-specific command-line flags with absl.

      Declares the flags for batch sizes, MLP layer sizes, embedding width,
      optimizer choice and hyper-parameters, the learning-rate schedule, and
      the cross-layer configuration. Takes no arguments and returns ``None``;
      its only effect is the sequence of ``flags.DEFINE_*`` calls below.
      """
      # --- Batch sizes ---
      flags.DEFINE_integer("batch_size", default=64 * 1024,
                           help="Batch size used for training")
      flags.DEFINE_integer("valid_batch_size", default=64 * 1024,
                           help="Batch size used for validation")

      # --- Model architecture ---
      # NOTE(review): the list defaults hold ints, while absl parses values
      # supplied on the command line into lists of strings; consumers
      # presumably convert each element with int() -- confirm.
      flags.DEFINE_list("top_mlp_dims", [1024, 1024, 512, 256, 1],
                        "Linear layer sizes for the top MLP")
      flags.DEFINE_list("bottom_mlp_dims", [512, 256, 128],
                        "Linear layer sizes for the bottom MLP")
      # NOTE(review): declared as a string flag rather than an integer one;
      # presumably parsed by the consumer -- confirm before changing the type.
      flags.DEFINE_string("embedding_dim", default='128',
                          help='Number of columns in the embedding tables')
      flags.DEFINE_enum("interaction", default="cross",
                        enum_values=["dot_custom_cuda", "dot_tensorflow", "cross"],
                        help="Feature interaction implementation to use")
      flags.DEFINE_integer("num_cross_layers", default=3,
                           help='Number of cross layers for DCNv2')
      flags.DEFINE_integer("cross_layer_projection_dim", default=512,
                           help='Projection dimension used in the cross layers')

      # --- Optimizer ---
      flags.DEFINE_enum("optimizer", default="adam", enum_values=['sgd', 'adam'],
                        help='The optimization algorithm to be used.')
      flags.DEFINE_float("learning_rate", default=0.0001, help="Learning rate")
      flags.DEFINE_float("beta1", default=0.9, help="Beta1 for the Adam optimizer")
      # Help text previously read "Bea2"; corrected to "Beta2".
      flags.DEFINE_float("beta2", default=0.999, help="Beta2 for the Adam optimizer")

      # --- Learning-rate schedule ---
      flags.DEFINE_integer(
          "warmup_steps", default=100,
          help='Number of steps over which to linearly increase the LR at the beginning')
      flags.DEFINE_integer(
          "decay_start_step", default=48000,
          help='Optimization step at which to start the poly LR decay')
      flags.DEFINE_integer(
          "decay_steps", default=24000,
          help='Number of steps over which to decay from base LR to 0')
  # The DCNv2 flags are registered first and `main` is imported only
  # afterwards; keep this order.
  # NOTE(review): presumably `main` expects these flags to exist already
  # when it is imported -- confirm before moving the import to the top.
  define_dcnv2_specific_flags()

  import main  # noqa: E402  (deliberately placed after flag registration)
  def _main(argv):
      """Entry point handed to ``app.run``; delegates to ``main.main()``.

      Args:
          argv: Argument list supplied by the caller; not used here.
      """
      del argv  # Unused: main.main() is invoked without arguments.
      main.main()
  # Run only when executed as a script: hand `_main` to absl's app runner.
  if __name__ == "__main__":
      app.run(_main)