Переглянути джерело

[DLRM/TF2] Support TensorFlow 2.10

Tomasz Grel 3 роки тому
батько
коміт
d6f4301a38

+ 3 - 2
TensorFlow2/Recommendation/DLRM/Dockerfile

@@ -22,12 +22,13 @@ WORKDIR /dlrm
 
 ADD requirements.txt .
 
-RUN pip install -r requirements.txt
+RUN pip install --upgrade pip && pip install -r requirements.txt
 
 RUN rm -rf distributed-embeddings &&\
     git clone https://github.com/NVIDIA-Merlin/distributed-embeddings.git &&\
     cd distributed-embeddings &&\
-    git checkout 427f869ac &&\
+    git checkout v0.2 &&\
+    git submodule init && git submodule update &&\
     pip uninstall -y distributed-embeddings &&\
     make clean &&\
     make pip_pkg -j all &&\

+ 10 - 4
TensorFlow2/Recommendation/DLRM/tensorflow-dot-based-interact/Makefile

@@ -19,8 +19,14 @@ PYTHON_BIN_PATH = python
 
 TF_CFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
 TF_LFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
-
-CFLAGS = ${TF_CFLAGS} -fPIC -O2 -std=c++14
+TF_VERSION := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(int(tf.__version__.split(".")[1]))')
+ifeq ($(shell expr $(TF_VERSION) \>= 10), 1)
+	  CPP_STD := 17
+	else
+	  CPP_STD := 14
+	endif
+
+CFLAGS = ${TF_CFLAGS} -fPIC -O2 -std=c++${CPP_STD}
 LDFLAGS = -shared ${TF_LFLAGS}
 
 .DEFAULT_GOAL := lib
@@ -50,11 +56,11 @@ endif
 
 volta: $(VOLTA_TARGET_OBJECT)
 $(VOLTA_TARGET_OBJECT): $(CC_SRC_DIR)/kernels/volta/dot_based_interact_volta.cu
-	$(NVCC) -std=c++14 -c -o $@ $^  $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_70
+	$(NVCC) -std=c++${CPP_STD} -c -o $@ $^  $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_70
 
 ampere: $(AMPERE_TARGET_OBJECT)
 $(AMPERE_TARGET_OBJECT): $(CC_SRC_DIR)/kernels/ampere/dot_based_interact_ampere.cu
-	$(NVCC) -std=c++14 -c -o $@ $^  $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_80
+	$(NVCC) -std=c++${CPP_STD} -c -o $@ $^  $(TF_CFLAGS) -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC -DNDEBUG --expt-relaxed-constexpr -arch=sm_80
 
 lib: $(TARGET_LIB)
 $(TARGET_LIB): $(CC_SRCS) $(VOLTA_TARGET_OBJECT) $(AMPERE_TARGET_OBJECT)