|
|
@@ -3733,10 +3733,12 @@ pytorch.Utility = class {
|
|
|
const obj = key === '' ? root : root[key];
|
|
|
if (obj && obj instanceof Map && obj.has('engine')) {
|
|
|
// https://github.com/NVIDIA-AI-IOT/torch2trt/blob/master/torch2trt/torch2trt.py
|
|
|
+ const data = obj.get('engine');
|
|
|
const signature = [ 0x70, 0x74, 0x72, 0x74 ]; // ptrt
|
|
|
- const buffer = obj.get('engine');
|
|
|
- if (buffer instanceof Uint8Array && buffer.length > signature.length && signature.every((value, index) => value === buffer[index])) {
|
|
|
- throw new pytorch.Error('Invalid file content. File contains undocumented PyTorch TensorRT engine data.');
|
|
|
+ if (data instanceof Uint8Array && data.length > signature.length && signature.every((value, index) => value === data[index])) {
|
|
|
+ const buffer = data.slice(0, 24);
|
|
|
+ const content = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
|
|
|
+ throw new pytorch.Error("Invalid file content. File contains undocumented PyTorch TensorRT engine data (" + content.substring(8) + ").");
|
|
|
}
|
|
|
}
|
|
|
if (obj) {
|