|
|
@@ -1,95 +0,0 @@
|
|
|
-#!/usr/bin/env python3
|
|
|
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
|
|
-#
|
|
|
-# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
-# you may not use this file except in compliance with the License.
|
|
|
-# You may obtain a copy of the License at
|
|
|
-#
|
|
|
-# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
-#
|
|
|
-# Unless required by applicable law or agreed to in writing, software
|
|
|
-# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
-# See the License for the specific language governing permissions and
|
|
|
-# limitations under the License.
|
|
|
-
|
|
|
-import tensorrt as trt
|
|
|
-import os
|
|
|
-
|
|
|
-import pycuda.driver as cuda
|
|
|
-import pycuda.autoinit # lgtm[py/unused-import]
|
|
|
-import numpy as np
|
|
|
-import helpers.tokenization as tokenization
|
|
|
-import helpers.data_processing as dp
|
|
|
-
|
|
|
-class BertCalibrator(trt.IInt8MinMaxCalibrator):
|
|
|
- def __init__(self, squad_json, vocab_file, cache_file, batch_size, max_seq_length, num_inputs):
|
|
|
- # Whenever you specify a custom constructor for a TensorRT class,
|
|
|
- # you MUST call the constructor of the parent explicitly.
|
|
|
- trt.IInt8MinMaxCalibrator.__init__(self)
|
|
|
-
|
|
|
- self.cache_file = cache_file
|
|
|
-
|
|
|
- # Every time get_batch is called, the next batch of size batch_size will be copied to the device and returned.
|
|
|
- self.data = dp.read_squad_json(squad_json)
|
|
|
- self.max_seq_length = max_seq_length
|
|
|
- self.batch_size = batch_size
|
|
|
- self.current_index = 0
|
|
|
- self.num_inputs = num_inputs
|
|
|
- self.tokenizer = tokenization.BertTokenizer(vocab_file=vocab_file, do_lower_case=True)
|
|
|
- self.doc_stride = 128
|
|
|
- self.max_query_length = 64
|
|
|
-
|
|
|
- # Allocate enough memory for a whole batch.
|
|
|
- self.device_inputs = [cuda.mem_alloc(self.max_seq_length * trt.int32.itemsize * self.batch_size) for binding in range(3)]
|
|
|
-
|
|
|
- def free(self):
|
|
|
- for dinput in self.device_inputs:
|
|
|
- dinput.free()
|
|
|
-
|
|
|
- def get_batch_size(self):
|
|
|
- return self.batch_size
|
|
|
-
|
|
|
- # TensorRT passes along the names of the engine bindings to the get_batch function.
|
|
|
- # You don't necessarily have to use them, but they can be useful to understand the order of
|
|
|
- # the inputs. The bindings list is expected to have the same ordering as 'names'.
|
|
|
- def get_batch(self, bindings, names):
|
|
|
- if self.current_index + self.batch_size > self.num_inputs:
|
|
|
- print("Calibrating index {:} batch size {:} exceed max input limit {:} sentences".format(self.current_index, self.batch_size, self.num_inputs))
|
|
|
- return None
|
|
|
-
|
|
|
- current_batch = int(self.current_index / self.batch_size)
|
|
|
- if current_batch % 10 == 0:
|
|
|
- print("Calibrating batch {:}, containing {:} sentences".format(current_batch, self.batch_size))
|
|
|
-
|
|
|
- input_ids = []
|
|
|
- segment_ids = []
|
|
|
- input_mask = []
|
|
|
- for i in range(self.batch_size):
|
|
|
- example = self.data[self.current_index + i]
|
|
|
- features = dp.convert_example_to_features(example.doc_tokens, example.question_text, self.tokenizer, self.max_seq_length, self.doc_stride, self.max_query_length)
|
|
|
- if len(input_ids) and len(segment_ids) and len(input_mask):
|
|
|
- input_ids = np.concatenate((input_ids, features[0].input_ids))
|
|
|
- segment_ids = np.concatenate((segment_ids, features[0].segment_ids))
|
|
|
- input_mask = np.concatenate((input_mask, features[0].input_mask))
|
|
|
- else:
|
|
|
- input_ids = features[0].input_ids
|
|
|
- segment_ids = features[0].segment_ids
|
|
|
- input_mask = features[0].input_mask
|
|
|
-
|
|
|
- cuda.memcpy_htod(self.device_inputs[0], input_ids.ravel())
|
|
|
- cuda.memcpy_htod(self.device_inputs[1], segment_ids.ravel())
|
|
|
- cuda.memcpy_htod(self.device_inputs[2], input_mask.ravel())
|
|
|
-
|
|
|
- self.current_index += self.batch_size
|
|
|
- return self.device_inputs
|
|
|
-
|
|
|
- def read_calibration_cache(self):
|
|
|
- # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
|
|
|
- if os.path.exists(self.cache_file):
|
|
|
- with open(self.cache_file, "rb") as f:
|
|
|
- return f.read()
|
|
|
-
|
|
|
- def write_calibration_cache(self, cache):
|
|
|
- with open(self.cache_file, "wb") as f:
|
|
|
- f.write(cache)
|