Просмотр исходного кода

Byshiue patch 2 (#805)

* fix: destroy the cuBLASLt matmul descriptor and matrix layout descriptors after use; previously they were never destroyed, which caused a memory leak
byshiue 5 лет назад
Родитель
Коммит
3d0d45b409

+ 10 - 0
FasterTransformer/v3.0/fastertransformer/common.h

@@ -193,6 +193,11 @@ void cublasLtMM_withAlgo(int *res, int batchCount, int m, int n, int k,
                  res,
                  CtransformDesc,
                  (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream);
+
+  cublasLtMatmulDescDestroy(matmulDesc);
+  cublasLtMatrixLayoutDestroy(AtransformDesc);
+  cublasLtMatrixLayoutDestroy(BtransformDesc);
+  cublasLtMatrixLayoutDestroy(CtransformDesc);
 }
 
 //for int8 IO cublasLtMM with algo
@@ -281,6 +286,11 @@ void cublasLtMM_withAlgo_int8IO(int8_t *res, int batchCount, int m, int n, int k
                  res,
                  CtransformDesc,
                  (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream);
+
+  cublasLtMatmulDescDestroy(matmulDesc);
+  cublasLtMatrixLayoutDestroy(AtransformDesc);
+  cublasLtMatrixLayoutDestroy(BtransformDesc);
+  cublasLtMatrixLayoutDestroy(CtransformDesc);
 }
 
 template <typename T>

+ 10 - 0
FasterTransformer/v3.1/fastertransformer/common.h

@@ -243,6 +243,11 @@ void cublasLtMM_withAlgo(int *res, int batchCount, int m, int n, int k,
                  res,
                  CtransformDesc,
                  (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream);
+
+  cublasLtMatmulDescDestroy(matmulDesc);
+  cublasLtMatrixLayoutDestroy(AtransformDesc);
+  cublasLtMatrixLayoutDestroy(BtransformDesc);
+  cublasLtMatrixLayoutDestroy(CtransformDesc);
 }
 
 //for int8 IO cublasLtMM with algo
@@ -384,6 +389,11 @@ void cublasLtMM_withAlgo_int8IO(int8_t *res, int batchCount, int m, int n, int k
                  res,
                  CtransformDesc,
                  (findAlgo == 1 ? (&algo) : NULL), NULL, 0, stream);
+
+  cublasLtMatmulDescDestroy(matmulDesc);
+  cublasLtMatrixLayoutDestroy(AtransformDesc);
+  cublasLtMatrixLayoutDestroy(BtransformDesc);
+  cublasLtMatrixLayoutDestroy(CtransformDesc);
 }
 
 template <typename T>