Procházet zdrojové kódy

MXNet Model Archive support (#286) (#285)

Lutz Roeder před 6 roky
rodič
revize
85dfeebabc
6 změnil soubory, kde provedl 178 přidání a 71 odebrání
  1. 2 0
      electron-builder.yml
  2. 1 1
      src/app.js
  3. 135 53
      src/mxnet.js
  4. 1 1
      src/view-browser.html
  5. 11 4
      src/view.js
  6. 28 12
      test/models.json

+ 2 - 0
electron-builder.yml

@@ -18,6 +18,8 @@ fileAssociations:
     ext: mlmodel
   - name: "MXNet Model"
     ext: model
+  - name: "MXNet Model"
+    ext: mar
   - name: "CNTK Model"
     ext: model
   - name: "CNTK Model"

+ 1 - 1
src/app.js

@@ -119,7 +119,7 @@ class Application {
                     'mlmodel',
                     'caffemodel',
                     'model', 'dnn', 'cmf',
-                    'params',
+                    'mar', 'params',
                     'meta',
                     'tflite', 'lite',
                     'pt', 'pth', 't7',

+ 135 - 53
src/mxnet.js

@@ -12,7 +12,7 @@ mxnet.ModelFactory = class {
         var identifier = context.identifier;
         var extension = identifier.split('.').pop().toLowerCase();
         var buffer = null;
-        if (extension == 'model') {
+        if (extension == 'model' || extension == 'mar') {
             buffer = context.buffer;
             if (buffer && buffer.length > 2 && buffer[0] == 0x50 && buffer[1] == 0x4B) {
                 return true;
@@ -32,7 +32,7 @@ mxnet.ModelFactory = class {
                 }
             }
         }
-        else if (mxnet.ModelFactory._basename(identifier, 'params')) {
+        else if (extension == 'params') {
             buffer = context.buffer;
             var signature = [ 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
             if (buffer && buffer.length > signature.length && signature.every((v, i) => v == buffer[i])) {
@@ -73,15 +73,21 @@ mxnet.ModelFactory = class {
             case 'params':
                 params = context.buffer;
                 basename = mxnet.ModelFactory._basename(context.identifier, 'params');
-                return context.request(basename + '-symbol.json', 'utf-8').then((text) => {
-                    symbol = JSON.parse(text);
-                    if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
-                        format  = 'TVM';
-                    }
-                    return this._openModel(identifier, format, null, symbol, null, params, host);
-                }).catch(() => {
+                if (basename) {
+                    return context.request(basename + '-symbol.json', 'utf-8').then((text) => {
+                        symbol = JSON.parse(text);
+                        if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
+                            format  = 'TVM';
+                        }
+                        return this._openModel(identifier, format, null, symbol, null, params, host);
+                    }).catch(() => {
+                        return this._openModel(identifier, format, null, null, null, params, host);
+                    });
+                }
+                else {
                     return this._openModel(identifier, format, null, null, null, params, host);
-                });
+                }
+            case 'mar':
             case 'model':
                 var entries = {};
                 try {
@@ -94,7 +100,7 @@ mxnet.ModelFactory = class {
                     throw new mxnet.Error('Failed to decompress ZIP archive. ' + err.message);
                 }
 
-                var manifestEntry = entries['MANIFEST.json'];
+                var manifestEntry = entries['MANIFEST.json'] || entries['MAR-INF/MANIFEST.json'];
                 var rootFolder = '';
                 if (!manifestEntry) {
                     var folders = Object.keys(entries).filter((name) => name.endsWith('/')).filter((name) => entries[name + 'MANIFEST.json']);
@@ -113,46 +119,81 @@ mxnet.ModelFactory = class {
                 catch (err) {
                     throw new mxnet.Error('Failed to read manifest. ' + err.message);
                 }
-        
-                if (!manifest.Model) {
+
+                var modelFormat = null;
+                var symbolEntry = null;
+                var signatureEntry = null;
+                var paramsEntry = null;
+                if (manifest.Model) {
+                    modelFormat = manifest.Model['Model-Format'];
+                    if (modelFormat && modelFormat != 'MXNet-Symbolic') {
+                        throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
+                    }
+                    format = 'MXNet Model Server';
+                    if (manifest['Model-Archive-Version']) {
+                        format += ' v' + manifest['Model-Archive-Version'].toString();
+                    }
+                    if (!manifest.Model.Symbol) {
+                        throw new mxnet.Error('Manifest does not contain symbol entry.');
+                    }
+                    symbolEntry = entries[rootFolder + manifest.Model.Symbol];
+                    if (manifest.Model.Signature) {
+                        signatureEntry = entries[rootFolder + manifest.Model.Signature];
+                    }
+                    if (manifest.Model.Parameters) {
+                        paramsEntry = entries[rootFolder + manifest.Model.Parameters];
+                    }
+                }
+                else if (manifest.model) {
+                    format = 'MXNet Model Archive';
+                    if (manifest.specificationVersion) {
+                        format += ' v' + manifest.specificationVersion.toString();
+                    }
+                    if (manifest.model.modelName) {
+                        symbolEntry = entries[rootFolder + manifest.model.modelName + '-symbol.json']
+                        var key = null;
+                        for (key of Object.keys(entries)) {
+                            key = key.substring(rootFolder.length);
+                            if (key.endsWith('.params') && key.startsWith(manifest.model.modelName)) {
+                                paramsEntry = entries[key];
+                                break;
+                            }
+                        }
+                        if (!symbolEntry && !paramsEntry) {
+                            for (key of Object.keys(entries)) {
+                                key = key.substring(rootFolder.length);
+                                if (key.endsWith('.params')) {
+                                    paramsEntry = entries[key];
+                                    break;
+                                }
+                            }
+                        }
+                    }
+                }
+                else {
                     throw new mxnet.Error('Manifest does not contain model.');
                 }
 
-                var modelFormat = manifest.Model['Model-Format'];
-                if (modelFormat && modelFormat != 'MXNet-Symbolic') {
-                    throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
-                }
-        
-                if (!manifest.Model.Symbol) {
-                    throw new mxnet.Error('Manifest does not contain symbol entry.');
+                if (!symbolEntry && !paramsEntry) {
+                    throw new mxnet.Error("Model does not contain symbol entry.");
                 }
 
                 try {
-                    var symbolEntry = entries[rootFolder + manifest.Model.Symbol];
-                    symbol = JSON.parse(decoder.decode(symbolEntry.data));
+                    if (symbolEntry) {
+                        symbol = JSON.parse(decoder.decode(symbolEntry.data));
+                    }
                 }
                 catch (err) {
                     throw new mxnet.Error('Failed to load symbol entry.' + err.message);
                 }
 
-                var signature = null;
-                try {
-                    if (manifest.Model.Signature) {
-                        var signatureEntry = entries[rootFolder + manifest.Model.Signature];
-                        if (signatureEntry) {
-                            signature = JSON.parse(decoder.decode(signatureEntry.data));
-                        }
-                    }
-                }
-                catch (err) {
-                    // continue regardless of error
+                if (paramsEntry) {
+                    params = paramsEntry.data;
                 }
+                var signature = null;
                 try {
-                    if (manifest.Model.Parameters) {
-                        var parametersEntry = entries[rootFolder + manifest.Model.Parameters];
-                        if (parametersEntry) {
-                            params = parametersEntry.data;
-                        }
+                    if (signatureEntry) {
+                        signature = JSON.parse(decoder.decode(signatureEntry.data));
                     }
                 }
                 catch (err) {
@@ -160,12 +201,6 @@ mxnet.ModelFactory = class {
                 }
 
                 try {
-                    if (manifest) {
-                        format = 'MXNet Model Server';
-                        if (manifest['Model-Archive-Version']) {
-                            format += ' v' + manifest['Model-Archive-Version'].toString();
-                        }
-                    }
                     return this._openModel(identifier, format, manifest, symbol, signature, params, host);
                 }
                 catch (error) {
@@ -211,7 +246,7 @@ mxnet.ModelFactory = class {
     static _basename(identifier, extension, suffix) {
         var dots = identifier.split('.');
         if (dots.length >= 2 && dots.pop().toLowerCase() === extension) {
-            var dashes = dots.pop().split('-');
+            var dashes = dots.join('.').split('-');
             if (dashes.length >= 2) {
                 var token = dashes.pop();
                 if (suffix) {
@@ -261,7 +296,38 @@ mxnet.Model = class {
             }
             if (manifest.Engine && manifest.Engine.MXNet) {
                 var engineVersion = mxnet.Model._convert_version(manifest.Engine.MXNet);
-                this._engine = 'MXNet v' + (engineVersion ? engineVersion : manifest.Engine.MXNet.toString());
+                this._runtime = 'MXNet v' + (engineVersion ? engineVersion : manifest.Engine.MXNet.toString());
+            }
+            if (manifest.License) {
+                this._license = manifest.License;
+            }
+            if (manifest.model && manifest.model.modelName) {
+                this._name = manifest.model.modelName;
+            }
+            if (manifest.model && manifest.model.modelVersion) {
+                this._version = manifest.model.modelVersion;
+            }
+            if (manifest.model && manifest.model.modelName && this._name != manifest.model.description) {
+                this._description = manifest.model.description;
+            }
+            if (manifest.runtime) {
+                this._runtime = manifest.runtime;
+            }
+            if (manifest.engine && manifest.engine.engineName) {
+                var engine = manifest.engine.engineName;
+                if (manifest.engine.engineVersion) {
+                    engine = engine + ' ' + manifest.engine.engineVersion;
+                }
+                this._runtime =  this._runtime ? (this._runtime + ' (' + engine + ')') : engine;
+            }
+            if (manifest.publisher && manifest.publisher.author) {
+                this._author = manifest.publisher.author;
+                if (manifest.publisher.email) {
+                    this._author = this._author + ' <' + manifest.publisher.email + '>';
+                }
+            }
+            if (manifest.license) {
+                this._license = manifest.license;
             }
         }
 
@@ -280,20 +346,32 @@ mxnet.Model = class {
         this._graphs.push(new mxnet.Graph(metadata, manifest, symbol, signature, params));
     }
 
+    get format() {
+        return this._format;
+    }
+
     get name() {
         return this._name;
     }
 
-    get format() {
-        return this._format;
+    get version() {
+        return this._version;
     }
 
     get description() {
         return this._description;
     }
 
+    get author() {
+        return this._author;
+    }
+
+    get license() {
+        return this._license;
+    }
+
     get runtime() {
-        return this._engine;
+        return this._runtime;
     }
 
     get graphs() {
@@ -415,14 +493,18 @@ mxnet.Graph = class {
             var block = null;
             var blocks = [];
             var blockMap = {};
-            if (Object.keys(params).every((k) => k.indexOf('_') != -1)) {
+            var separator = Object.keys(params).every((k) => k.indexOf('_') != -1) ? '_' : '';
+            if (separator.length == 0) {
+                separator = Object.keys(params).every((k) => k.indexOf('.') != -1) ? '.' : '';
+            }
+            if (separator.length > 0) {
                 for (var id of Object.keys(params)) {
-                    var parts = id.split('_');
+                    var parts = id.split(separator);
                     var argumentName = parts.pop();
                     if (id.endsWith('moving_mean') || id.endsWith('moving_var')) {
-                        argumentName = [ parts.pop(), argumentName ].join('_');
+                        argumentName = [ parts.pop(), argumentName ].join(separator);
                     }
-                    var nodeName = parts.join('_');
+                    var nodeName = parts.join(separator);
                     block = blockMap[nodeName];
                     if (!block) {
                         block = { name: nodeName, op: 'Weights', params: [] };

+ 1 - 1
src/view-browser.html

@@ -68,7 +68,7 @@
         </svg>
     </a>
     <button id="open-file-button" class="center" style="top: 200px; width: 125px; opacity: 0;">Open Model...</button>
-    <input type="file" id="open-file-dialog" style="display:none" multiple="false" accept=".onnx, .pb, .meta, .tflite, .keras, .h5, .hdf5, .json, .model, .params, .dnn, .cmf, .mlmodel, .caffemodel, .pbtxt, .prototxt, .pkl, .pt, .pth, .t7, .joblib, .cfg, .xml">
+    <input type="file" id="open-file-dialog" style="display:none" multiple="false" accept=".onnx, .pb, .meta, .tflite, .keras, .h5, .hdf5, .json, .model, .mar, .params, .dnn, .cmf, .mlmodel, .caffemodel, .pbtxt, .prototxt, .pkl, .pt, .pth, .t7, .joblib, .cfg, .xml">
     <!-- Preload fonts to workaround Chrome SVG layout issue -->
     <div style="font-weight: normal; color: rgba(0, 0, 0, 0.01); user-select: none;">.</div>
     <div style="font-weight: bold; color: rgba(0, 0, 0, 0.01); user-select: none;">.</div>

+ 11 - 4
src/view.js

@@ -1049,7 +1049,7 @@ class ArchiveContext {
                 if (entry.name.startsWith(rootFolder)) {
                     var name = entry.name.substring(rootFolder.length);
                     if (identifier.length > 0 && identifier.indexOf('/') < 0) {
-                        this._entries[name] = entry.substring(rootFolder.length);
+                        this._entries[name] = entry;
                     }
                 }
             }
@@ -1092,7 +1092,7 @@ view.ModelFactoryService = class {
         this._host = host;
         this._extensions = [];
         this.register('./onnx', [ '.onnx', '.pb', '.pbtxt', '.prototxt' ]);
-        this.register('./mxnet', [ '.model', '.json', '.params' ]);
+        this.register('./mxnet', [ '.mar', '.model', '.json', '.params' ]);
         this.register('./keras', [ '.h5', '.keras', '.hdf5', '.json', '.model' ]);
         this.register('./coreml', [ '.mlmodel' ]);
         this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt' ]);
@@ -1274,10 +1274,17 @@ view.ModelFactoryService = class {
                         return Promise.reject(new ArchiveError('Archive does not contain model file.'));
                     }
                     else if (matches.length > 1) {
-                        return Promise.reject(new ArchiveError('Archive contains multiple model files.'));
+                        if (matches.length == 2 &&
+                            matches.some((e) => e.name.endsWith('.params')) &&
+                            matches.some((e) => e.name.endsWith('-symbol.json'))) {
+                            matches = matches.filter((e) => e.name.endsWith('.params'));
+                        }
+                        else {
+                            return Promise.reject(new ArchiveError('Archive contains multiple model files.'));
+                        }
                     }
                     var match = matches[0];
-                    return Promise.resolve(new ModelContext(new ArchiveContext(entries, rootFolder, match.name, match.data)));
+                    return Promise.resolve(new ModelContext(new ArchiveContext(archive.entries, rootFolder, match.name, match.data)));
                 }
             };
             return nextEntry();

+ 28 - 12
test/models.json

@@ -2100,6 +2100,19 @@
     "format": "MXNet Model Server v0.1",
     "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
   },
+  {
+    "type":   "mxnet",
+    "target": "caffenet.model",
+    "source": "https://s3.amazonaws.com/model-server/models/caffenet/caffenet.model",
+    "format": "MXNet Model Server v0.1",
+    "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
+  },
+  {
+    "type":   "mxnet",
+    "target": "crepe.mar",
+    "source": "https://s3.amazonaws.com/model-server/model_archive_1.0/crepe.mar",
+    "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
+  },
   {
     "type":   "mxnet",
     "target": "conv_weights_sharing.json",
@@ -2257,16 +2270,11 @@
     "source": "https://s3.amazonaws.com/mxnet-model-server/onnx-mobilenet/mobilenetv2-1.0.model",
     "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
   },
-  {
-    "type":   "mxnet",
-    "target": "resnet101v1.model",
-    "source": "https://s3.amazonaws.com/mxnet-model-server/onnx-resnetv1/resnet101v1.model",
-    "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
-  },
   {
     "type":   "mxnet",
     "target": "resnet101v2.model",
     "source": "https://s3.amazonaws.com/mxnet-model-server/onnx-resnetv2/resnet101v2.model",
+    "format": "MXNet Model Server v0.2", "runtime": "MXNet v1.2.0",
     "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
   },
   {
@@ -2275,12 +2283,6 @@
     "source": "https://s3.amazonaws.com/mxnet-model-server/onnx-resnetv1/resnet152v1.model",
     "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
   },
-  {
-    "type":   "mxnet",
-    "target": "resnet152v2.model",
-    "source": "https://s3.amazonaws.com/mxnet-model-server/onnx-resnetv2/resnet152v2.model",
-    "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
-  },
   {
     "type":   "mxnet",
     "target": "resnet18v1.model",
@@ -2333,6 +2335,20 @@
     "target": "squeezenet_v1.1.model",
     "source": "https://s3.amazonaws.com/model-server/models/squeezenet_v1.1/squeezenet_v1.1.model"
   },
+  {
+    "type":   "mxnet",
+    "target": "squeezenet_v1.1.mar",
+    "source": "https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar",
+    "format": "MXNet Model Archive v1.0",
+    "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
+  },
+  {
+    "type":   "mxnet",
+    "target": "squeezenet_v1.1.zip",
+    "source": "https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar",
+    "format": "MXNet",
+    "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
+  },
   {
     "type":   "mxnet",
     "target": "vgg16.model",