瀏覽代碼

Darknet .weights detection (#395)

Lutz Roeder 5 年之前
父節點
當前提交
07ad4b0ef8
共有 4 個文件被更改,包括 97 次插入42 次删除
  1. 82 39
      source/darknet.js
  2. 12 0
      source/pytorch-metadata.json
  3. 1 1
      source/view.js
  4. 2 2
      test/models.json

+ 82 - 39
source/darknet.js

@@ -6,44 +6,62 @@ var base = base || require('./base');
 darknet.ModelFactory = class {
 
     match(context) {
-        try {
-            const reader = base.TextReader.create(context.buffer);
-            for (;;) {
-                const line = reader.read();
-                if (line === undefined) {
-                    break;
+        const identifier = context.identifier;
+        const extension = identifier.split('.').pop().toLowerCase();
+        switch (extension) {
+            case 'weights':
+                if (darknet.Weights.open(context.buffer)) {
+                    return true;
                 }
-                const text = line.trim();
-                if (text.length === 0 || text.startsWith('#')) {
-                    continue;
+                break;
+            default:
+                try {
+                    const reader = base.TextReader.create(context.buffer);
+                    for (;;) {
+                        const line = reader.read();
+                        if (line === undefined) {
+                            break;
+                        }
+                        const text = line.trim();
+                        if (text.length === 0 || text.startsWith('#')) {
+                            continue;
+                        }
+                        if (text.startsWith('[') && text.endsWith(']')) {
+                            return true;
+                        }
+                    }
                 }
-                if (text.startsWith('[') && text.endsWith(']')) {
-                    return true;
+                catch (err) {
+                    // continue regardless of error
                 }
-            }
-        }
-        catch (err) {
-            // continue regardless of error
+                break;
         }
         return false;
     }
 
     open(context, host) {
         return darknet.Metadata.open(host).then((metadata) => {
+            const open = (metadata, cfg, weights) => {
+                return new darknet.Model(metadata, cfg, darknet.Weights.open(weights));
+            };
             const identifier = context.identifier;
             const parts = identifier.split('.');
-            parts.pop();
+            const extension = parts.pop().toLowerCase();
             const basename = parts.join('.');
-            return context.request(basename + '.weights', null).then((weights) => {
-                return this._openModel(metadata, identifier, context.buffer, weights);
-            }).catch(() => {
-                return this._openModel(metadata, identifier, context.buffer, null);
-            });
+            switch (extension) {
+                case 'weights':
+                    return context.request(basename + '.cfg', null).then((cfg) => {
+                        return open(metadata, cfg, context.buffer);
+                    });
+                default:
+                    return context.request(basename + '.weights', null).then((weights) => {
+                        return open(metadata, context.buffer, weights);
+                    }).catch(() => {
+                        return open(metadata, context.buffer, null);
+                    });
+            }
         });
     }
-    _openModel( metadata, identifier, cfg, weights) {
-        return new darknet.Model(metadata, cfg, weights ? new darknet.Weights(weights) : null);
-    }
 };
 
 darknet.Model = class {
@@ -141,7 +159,7 @@ darknet.Graph = class {
         };
 
         const load_weights = (name, shape, visible) => {
-            const data = weights ? weights.bytes(4 * shape.reduce((a, b) => a * b)) : null;
+            const data = weights ? weights.read(4 * shape.reduce((a, b) => a * b)) : null;
             const type = new darknet.TensorType('float32', make_shape(shape, 'load_weights'));
             const initializer = new darknet.Tensor(type, data);
             const argument = new darknet.Argument('', null, initializer);
@@ -1087,18 +1105,49 @@ darknet.TensorShape = class {
 
 darknet.Weights = class {
 
+    static open(buffer) {
+        if (buffer) {
+            const reader = new darknet.Weights.BinaryReader(buffer);
+            const major = reader.int32();
+            const minor = reader.int32();
+            const revision = reader.int32();
+            const seen = ((major * 10 + minor) >= 2) ? reader.int64() : reader.int32();
+            const transpose = (major > 1000) || (minor > 1000);
+            // if (transpose) {
+            //     throw new darknet.Error("Unsupported transpose weights file version '" + [ major, minor, revision ].join('.') + "'.");
+            // }
+            if (!transpose) {
+                return new darknet.Weights(reader);
+            }
+        }
+        return null;
+    }
+
+    constructor(reader) {
+        this._reader = reader;
+    }
+
+    read(size) {
+        return this._reader.bytes(size);
+    }
+
+    validate() {
+        if (!this._reader.end()) {
+            throw new darknet.Error('Invalid weights size.');
+        }
+    }
+};
+
+darknet.Weights.BinaryReader = class {
+
     constructor(buffer) {
         this._buffer = buffer;
         this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
         this._position = 0;
-        const major = this.int32();
-        const minor = this.int32();
-        const revision = this.int32();
-        this._seen = ((major * 10 + minor) >= 2) ? this.int64() : this.int32();
-        const transpose = (major > 1000) || (minor > 1000);
-        if (transpose) {
-            throw new darknet.Error("Unsupported transpose weights file version '" + [ major, minor, revision ].join('.') + "'.");
-        }
+    }
+
+    end() {
+        return this._position === this._buffer.length;
     }
 
     int32() {
@@ -1125,12 +1174,6 @@ darknet.Weights = class {
             throw new darknet.Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
         }
     }
-
-    validate() {
-        if (this._position !== this._buffer.length) {
-            throw new darknet.Error('Invalid weights size.');
-        }
-    }
 };
 
 darknet.Metadata = class {

+ 12 - 0
source/pytorch-metadata.json

@@ -5640,5 +5640,17 @@
         { "name": "output", "type": "Tensor" }
       ]
     }
+  },
+  {
+    "name": "torch.type_as",
+    "schema": {
+      "inputs": [
+        { "name": "self", "type": "Tensor" },
+        { "name": "other", "type": "Tensor" }
+      ],
+      "outputs": [
+        { "name": "output", "type": "Tensor" }
+      ]
+    }
   }
 ]

+ 1 - 1
source/view.js

@@ -1196,7 +1196,7 @@ view.ModelFactoryService = class {
         this.register('./cntk', [ '.model', '.cntk', '.cmf', '.dnn' ]);
         this.register('./paddle', [ '.paddle', '.pdmodel', '__model__', '.pbtxt', '.txt', '.tar', '.tar.gz' ]);
         this.register('./bigdl', [ '.model', '.bigdl' ]);
-        this.register('./darknet', [ '.cfg', '.model', '.txt' ]);
+        this.register('./darknet', [ '.cfg', '.model', '.txt', '.weights' ]);
         this.register('./weka', [ '.model' ]);
         this.register('./rknn', [ '.rknn' ]);
         this.register('./dlc', [ '.dlc' ]);

+ 2 - 2
test/models.json

@@ -1724,8 +1724,8 @@
   },
   {
     "type":   "darknet",
-    "target": "darknet53_448.cfg,darknet53_448.weights",
-    "source": "https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/darknet53_448.cfg,https://pjreddie.com/media/files/darknet53_448.weights",
+    "target": "darknet53_448.weights,darknet53_448.cfg",
+    "source": "https://pjreddie.com/media/files/darknet53_448.weights,https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/darknet53_448.cfg",
     "format": "Darknet",
     "link":   "https://pjreddie.com/darknet/imagenet"
   },