main.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # DO NOT REMOVE THIS IMPORT
  15. # It is here to initialize nvtabular before tensorflow is initialized.
  16. # Removing it leads to a drop in nvtabular dataloader performance
  17. # Do not put other imports before this without running performance validation
  18. import nvtabular # noqa # pylint: disable=unused-import
  19. # See above
  20. import os
  21. os.environ["TF_GPU_ALLOCATOR"]="cuda_malloc_async"
  22. from trainer.model.widedeep import wide_deep_model
  23. from trainer.run import run
  24. from trainer.utils.arguments import parse_args
  25. from trainer.utils.setup import create_config
  26. def main():
  27. args = parse_args()
  28. config = create_config(args)
  29. model, _ = wide_deep_model(args, config["feature_spec"], config["embedding_dimensions"])
  30. run(args, model, config)
  31. if __name__ == "__main__":
  32. main()