Explorar el Código

[FT] 1. Fix the bug of TensorRT plugin of FasterTransformer encoder. (#640)

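* [FT] 1. Fix the bug in the TensorRT plugin of the FasterTransformer encoder: the plugin's data-type checks and the builder's FP16 mode were hard-coded to FP32 and now follow the template type T.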
BO-YANG HSUEH hace 5 años
padre
commit
1aa6813450

+ 2 - 2
FasterTransformer/v2.1/fastertransformer/trt_plugin/bert_transformer_plugin.h

@@ -269,13 +269,13 @@ class TransformerPlugin: public IPluginV2
 
     bool supportsFormat(nvinfer1::DataType type, PluginFormat format) const override 
     {
-      return type == nvinfer1::DataType::kFLOAT && format == PluginFormat::kNCHW;
+      return type == TransformerTrtTraits<T>::DataType && format == PluginFormat::kNCHW;
     }
 
     void configureWithFormat(const Dims* pInputDim, int nInputDim, const Dims* pOutputDim, 
         int nOutputDim, nvinfer1::DataType dataType, nvinfer1::PluginFormat pluginFormat, int maxBatchSize) override 
     {
-      assert(dataType == nvinfer1::DataType::kFLOAT && pluginFormat == nvinfer1::PluginFormat::kNCHW);
+      assert(dataType == TransformerTrtTraits<T>::DataType && pluginFormat == nvinfer1::PluginFormat::kNCHW);
       assert(nInputDim == 2);
       assert(pInputDim[0].nbDims == 2 && pInputDim[0].d[0] == seq_len_ && pInputDim[0].d[1] == hidden_dim_);
       assert(pInputDim[1].nbDims == 2 && pInputDim[1].d[0] == seq_len_ && pInputDim[1].d[1] == seq_len_);

+ 1 - 1
FasterTransformer/v2.1/fastertransformer/trt_plugin/trt_model.h

@@ -104,7 +104,7 @@ class TRT_Transformer
 
       builder->setMaxBatchSize(batch_size_);
       builder->setMaxWorkspaceSize(1 << 20);
-      builder->setFp16Mode(false);
+      builder->setFp16Mode(sizeof(T) == 2);
 
       engine_ = builder->buildCudaEngine(*network);
       assert(engine_);