setup.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
# Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. from setuptools import setup, find_packages
  16. from torch.utils.cpp_extension import BuildExtension, CUDAExtension
  17. abspath = os.path.dirname(os.path.realpath(__file__))
  18. print(find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]))
  19. setup(name="dlrm",
  20. package_dir={'dlrm': 'dlrm'},
  21. version="1.0.0",
  22. description="Reimplementation of Facebook's DLRM",
  23. packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
  24. zip_safe=False,
  25. ext_modules=[
  26. CUDAExtension(name="dlrm.cuda_ext.fused_embedding",
  27. sources=[
  28. os.path.join(abspath, "dlrm/cuda_src/pytorch_embedding_ops.cpp"),
  29. os.path.join(abspath, "dlrm/cuda_src/gather_gpu_fused_pytorch_impl.cu")
  30. ],
  31. extra_compile_args={
  32. 'cxx': [],
  33. 'nvcc': ["-arch=sm_70",
  34. '-gencode', 'arch=compute_80,code=sm_80']
  35. }),
  36. CUDAExtension(name="dlrm.cuda_ext.interaction_volta",
  37. sources=[
  38. os.path.join(abspath, "dlrm/cuda_src/dot_based_interact_volta/pytorch_ops.cpp"),
  39. os.path.join(abspath, "dlrm/cuda_src/dot_based_interact_volta/dot_based_interact_pytorch_types.cu")
  40. ],
  41. extra_compile_args={
  42. 'cxx': [],
  43. 'nvcc': [
  44. '-DCUDA_HAS_FP16=1',
  45. '-D__CUDA_NO_HALF_OPERATORS__',
  46. '-D__CUDA_NO_HALF_CONVERSIONS__',
  47. '-D__CUDA_NO_HALF2_OPERATORS__',
  48. '-gencode', 'arch=compute_70,code=sm_70']
  49. }),
  50. CUDAExtension(name="dlrm.cuda_ext.interaction_ampere",
  51. sources=[
  52. os.path.join(abspath, "dlrm/cuda_src/dot_based_interact_ampere/pytorch_ops.cpp"),
  53. os.path.join(abspath, "dlrm/cuda_src/dot_based_interact_ampere/dot_based_interact_pytorch_types.cu")
  54. ],
  55. extra_compile_args={
  56. 'cxx': [],
  57. 'nvcc': [
  58. '-DCUDA_HAS_FP16=1',
  59. '-D__CUDA_NO_HALF_OPERATORS__',
  60. '-D__CUDA_NO_HALF_CONVERSIONS__',
  61. '-D__CUDA_NO_HALF2_OPERATORS__',
  62. '-gencode', 'arch=compute_80,code=sm_80']
  63. }),
  64. CUDAExtension(name="dlrm.cuda_ext.sparse_gather",
  65. sources=[
  66. os.path.join(abspath, "dlrm/cuda_src/sparse_gather/sparse_pytorch_ops.cpp"),
  67. os.path.join(abspath, "dlrm/cuda_src/sparse_gather/gather_gpu.cu")
  68. ],
  69. extra_compile_args={
  70. 'cxx': [],
  71. 'nvcc': ["-arch=sm_70",
  72. '-gencode', 'arch=compute_80,code=sm_80']
  73. })
  74. ],
  75. cmdclass={"build_ext": BuildExtension})