ソースを参照

TensorRT UFF detection (#511)

Lutz Roeder 5 年 前
コミット
f7e87f4c7f
4 ファイル変更58 行追加0 行削除
  1. 1 0
      setup.py
  2. 49 0
      src/tensorrt.js
  3. 1 0
      src/view.js
  4. 7 0
      test/models.json

+ 1 - 0
setup.py

@@ -121,6 +121,7 @@ setuptools.setup(
             'pytorch.js', 'pytorch-metadata.json', 'python.js',
             'sklearn.js', 'sklearn-metadata.json',
             'tengine.js', 'tengine-metadata.json', 
+            'tensorrt.js', 
             'tf.js', 'tf-metadata.json', 'tf-proto.js', 
             'tflite.js', 'tflite-metadata.json', 'tflite-schema.js',
             'torch.js', 'torch-metadata.json',

+ 49 - 0
src/tensorrt.js

@@ -0,0 +1,49 @@
+/* jshint esversion: 6 */
+/* eslint "indent": [ "error", 4, { "SwitchCase": 1 } ] */
+
+// Experimental
+
+var tensorrt = tensorrt || {};
+
+tensorrt.ModelFactory = class {
+
+    match(context) {
+        const identifier = context.identifier;
+        const extension = identifier.split('.').pop().toLowerCase();
+        if (extension === 'uff' || extension === 'pb') {
+            const tags = context.tags('pb');
+            if (tags.size > 0 &&
+                tags.has(1) && tags.get(1) === 0 &&
+                tags.has(2) && tags.get(2) === 0 &&
+                tags.has(3) && tags.get(3) === 2 &&
+                tags.has(4) && tags.get(4) === 2 &&
+                tags.has(5) && tags.get(5) === 2) {
+                return true;
+            }
+        }
+        if (extension === 'pbtxt') {
+            const tags = context.tags('pbtxt');
+            if (tags.has('version') && tags.has('descriptors') && tags.has('graphs')) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    open(context /*, host */) {
+        const identifier = context.identifier;
+        throw new tensorrt.Error("TensorRT UFF is a proprietary file format in '" + identifier + "'.");
+    }
+};
+
+tensorrt.Error = class extends Error {
+
+    constructor(message) {
+        super(message);
+        this.name = 'Error loading TensorRT model.';
+    }
+};
+
+if (typeof module !== 'undefined' && typeof module.exports === 'object') {
+    module.exports.ModelFactory = tensorrt.ModelFactory;
+}

+ 1 - 0
src/view.js

@@ -1140,6 +1140,7 @@ view.ModelFactoryService = class {
         this.register('./tflite', [ '.tflite', '.lite', '.tfl', '.bin', '.pb', '.tmfile', '.h5', '.model', '.json' ]);
         this.register('./tf', [ '.pb', '.meta', '.pbtxt', '.prototxt', '.json', '.index', '.ckpt' ]);
         this.register('./mediapipe', [ '.pbtxt' ]);
+        this.register('./tensorrt', [ '.uff', '.pb', '.pbtxt' ]);
         this.register('./sklearn', [ '.pkl', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5' ]);
         this.register('./cntk', [ '.model', '.cntk', '.cmf', '.dnn' ]);
         this.register('./paddle', [ '.paddle', '__model__' ]);

+ 7 - 0
test/models.json

@@ -4511,6 +4511,13 @@
     "format": "Tengine v2.0",
     "link":   "https://github.com/pierricklee/tmfile-sample"
   },
+  {
+    "type":   "tensorrt",
+    "target": "tmp_v2_coco.pbtxt",
+    "source": "https://github.com/lutzroeder/netron/files/4741218/tmp_v2_coco.zip[tmp_v2_coco.pbtxt]",
+    "error":  "TensorRT UFF is a proprietary file format in 'tmp_v2_coco.pbtxt'.",
+    "link":   "https://github.com/lutzroeder/netron/issues/511"
+  },
   {
       "type":   "tf",
       "target": "bert_uncased_L-12_H-768_A-12.meta",