|
|
@@ -269,13 +269,13 @@ class TransformerPlugin: public IPluginV2
|
|
|
|
|
|
bool supportsFormat(nvinfer1::DataType type, PluginFormat format) const override
|
|
|
{
|
|
|
- return type == nvinfer1::DataType::kFLOAT && format == PluginFormat::kNCHW;
|
|
|
+ return type == TransformerTrtTraits<T>::DataType && format == PluginFormat::kNCHW;
|
|
|
}
|
|
|
|
|
|
void configureWithFormat(const Dims* pInputDim, int nInputDim, const Dims* pOutputDim,
|
|
|
int nOutputDim, nvinfer1::DataType dataType, nvinfer1::PluginFormat pluginFormat, int maxBatchSize) override
|
|
|
{
|
|
|
- assert(dataType == nvinfer1::DataType::kFLOAT && pluginFormat == nvinfer1::PluginFormat::kNCHW);
|
|
|
+ assert(dataType == TransformerTrtTraits<T>::DataType && pluginFormat == nvinfer1::PluginFormat::kNCHW);
|
|
|
assert(nInputDim == 2);
|
|
|
assert(pInputDim[0].nbDims == 2 && pInputDim[0].d[0] == seq_len_ && pInputDim[0].d[1] == hidden_dim_);
|
|
|
assert(pInputDim[1].nbDims == 2 && pInputDim[1].d[0] == seq_len_ && pInputDim[1].d[1] == seq_len_);
|