Przeglądaj źródła

Update mxnet.js (#723)

Lutz Roeder 4 lat temu
rodzic
commit
cccd39d23e
3 zmienionych plików z 217 dodań i 270 usunięć
  1. 204 253
      source/mxnet.js
  2. 5 3
      source/view.js
  3. 8 14
      test/models.json

+ 204 - 253
source/mxnet.js

@@ -10,22 +10,21 @@ mxnet.ModelFactory = class {
     match(context) {
         const identifier = context.identifier;
         const extension = identifier.split('.').pop().toLowerCase();
-        if (extension === 'model' || extension === 'mar') {
-            if (context.entries('zip').length > 0) {
-                return true;
-            }
-        }
-        else if (extension == 'json') {
-            const obj = context.open('json');
-            if (obj && obj.nodes && obj.arg_nodes && obj.heads) {
-                return true;
+        switch (extension) {
+            case 'json': {
+                const obj = context.open('json');
+                if (obj && obj.nodes && obj.arg_nodes && obj.heads) {
+                    return true;
+                }
+                break;
             }
-        }
-        else if (extension == 'params') {
-            const stream = context.stream;
-            const signature = [ 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
-            if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value == signature[index])) {
-                return true;
+            case 'params': {
+                const stream = context.stream;
+                const signature = [ 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
+                if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value == signature[index])) {
+                    return true;
+                }
+                break;
             }
         }
         return false;
@@ -33,32 +32,138 @@ mxnet.ModelFactory = class {
 
     open(context) {
         return mxnet.Metadata.open(context).then((metadata) => {
-            const basename = (identifier, extension, suffix) => {
-                const dots = identifier.split('.');
-                if (dots.length >= 2 && dots.pop().toLowerCase() === extension) {
-                    const dashes = dots.join('.').split('-');
-                    if (dashes.length >= 2) {
-                        const token = dashes.pop();
-                        if (suffix) {
-                            if (token != suffix) {
-                                return null;
+            const basename = (base, identifier, extension, suffix, append) => {
+                if (!base) {
+                    if (identifier.toLowerCase().endsWith(extension)) {
+                        const items = identifier.substring(0, identifier.length - extension.length).split('-');
+                        if (items.length >= 2) {
+                            const token = items.pop();
+                            if ((suffix && token === suffix) || /[a-zA-Z0-9]*/.exec(token)) {
+                                return items.join('-') + append;
                             }
                         }
-                        else {
-                            for (let i = 0; i < token.length; i++) {
-                                const c = token.charAt(i);
-                                if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
-                                    continue;
+                    }
+                }
+                return base;
+            };
+            const convertVersion = (value) => {
+                if (Array.isArray(value)) {
+                    if (value.length === 2 && value[0] === 'int') {
+                        const major = Math.floor(value[1] / 10000) % 100;
+                        const minor = Math.floor(value[1] / 100) % 100;
+                        const patch = Math.floor(value[1]) % 100;
+                        return [ major.toString(), minor.toString(), patch.toString() ].join('.');
+                    }
+                }
+                return null;
+            };
+            const requestManifest = () => {
+                const parse = (stream) => {
+                    try {
+                        const manifest = {};
+                        const decoder = new TextDecoder('utf-8');
+                        if (stream) {
+                            const buffer = stream.peek();
+                            const text = decoder.decode(buffer);
+                            const json = JSON.parse(text);
+                            if (json.Model) {
+                                const modelFormat = json.Model['Model-Format'];
+                                if (modelFormat && modelFormat != 'MXNet-Symbolic') {
+                                    throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
+                                }
+                                manifest.format = 'MXNet Model Server';
+                                if (json['Model-Archive-Version']) {
+                                    manifest.format += ' v' + json['Model-Archive-Version'].toString();
                                 }
-                                return null;
+                                if (!json.Model.Symbol) {
+                                    throw new mxnet.Error('Manifest does not contain symbol entry.');
+                                }
+                                manifest.symbol = json.Model.Symbol;
+                                if (json.Model.Signature) {
+                                    manifest.signature = json.Model.Signature;
+                                }
+                                if (json.Model.Parameters) {
+                                    manifest.params = json.Model.Parameters;
+                                }
+                                if (json.Model['Model-Name']) {
+                                    manifest.name = json.Model['Model-Name'];
+                                }
+                                if (json.Model.Description && manifest.name !== json.Model.Description) {
+                                    manifest.description = json.Model.Description;
+                                }
+                            }
+                            else if (json.model) {
+                                manifest.format = 'MXNet Model Archive';
+                                if (json.specificationVersion) {
+                                    manifest.format += ' v' + json.specificationVersion.toString();
+                                }
+                                if (json.model.modelName) {
+                                    manifest.symbol = json.model.modelName + '-symbol.json';
+                                }
+                                if (json.model.modelName) {
+                                    manifest.name = json.model.modelName;
+                                }
+                                if (manifest.model && json.model.modelVersion) {
+                                    manifest.version = json.model.modelVersion;
+                                }
+                                if (manifest.model && manifest.model.modelName && manifest.name != json.model.description) {
+                                    manifest.description = json.model.description;
+                                }
+                            }
+                            else {
+                                throw new mxnet.Error('Manifest does not contain model.');
+                            }
+                            if (json.Engine && json.Engine.MXNet) {
+                                const version = convertVersion(json.Engine.MXNet);
+                                manifest.runtime = 'MXNet v' + (version ? version : json.Engine.MXNet.toString());
+                            }
+                            if (json.License) {
+                                manifest.license = json.License;
+                            }
+                            if (json.runtime) {
+                                manifest.runtime = json.runtime;
+                            }
+                            if (json.engine && json.engine.engineName) {
+                                const engine = json.engine.engineVersion ? json.engine.engineName + ' ' + json.engine.engineVersion : json.engine.engineName;
+                                manifest.runtime = manifest.runtime ? (manifest.runtime + ' (' + engine + ')') : engine;
+                            }
+                            if (json.publisher && json.publisher.author) {
+                                manifest.author = json.publisher.author;
+                                if (json.publisher.email) {
+                                    manifest.author = manifest.author + ' <' + json.publisher.email + '>';
+                                }
+                            }
+                            if (json.license) {
+                                manifest.license = json.license;
+                            }
+                            if (json.Model && json.Model.Signature) {
+                                return context.request(json.Model.Signature).then((stream) => {
+                                    const buffer = stream.peek();
+                                    const text = decoder.decode(buffer);
+                                    manifest.signature = JSON.parse(text);
+                                    return manifest;
+                                }).catch (() => {
+                                    return manifest;
+                                });
                             }
                         }
-                        return dashes.join('-');
+                        return manifest;
                     }
-                }
-                return null;
+                    catch (err) {
+                        throw new mxnet.Error('Failed to read manifest. ' + err.message);
+                    }
+                };
+                return context.request('MANIFEST.json').then((stream) => {
+                    return parse(stream);
+                }).catch (() => {
+                    return context.request('MAR-INF/MANIFEST.json').then((stream) => {
+                        return parse(stream);
+                    }).catch(() => {
+                        return parse(null);
+                    });
+                });
             };
-            const open_model = (metadata, format, manifest, symbol, signature, params) => {
+            const createModel = (metadata, manifest, symbol, params) => {
                 const parameters = new Map();
                 if (params) {
                     try {
@@ -72,166 +177,66 @@ mxnet.ModelFactory = class {
                         // continue regardless of error
                     }
                 }
-                return new mxnet.Model(metadata, format, manifest, symbol, signature, parameters);
+                if (symbol) {
+                    if (!manifest.format) {
+                        const version = convertVersion(symbol && symbol.attrs && symbol.attrs.mxnet_version ? symbol.attrs.mxnet_version : null);
+                        manifest.format = 'MXNet' + (version ? ' v' + version : '');
+                    }
+                    if (symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
+                        manifest.producer  = 'TVM';
+                    }
+                }
+                return new mxnet.Model(metadata, manifest, symbol, parameters);
             };
             const identifier = context.identifier;
-            const extension = context.identifier.split('.').pop().toLowerCase();
-            let symbol = null;
-            let params = null;
-            let format = null;
-            let base = null;
+            const extension = identifier.split('.').pop().toLowerCase();
             switch (extension) {
-                case 'json':
+                case 'json': {
+                    let symbol = null;
                     try {
                         symbol = context.open('json');
-                        if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
-                            format  = 'TVM';
-                        }
                     }
                     catch (error) {
                         const message = error && error.message ? error.message : error.toString();
                         throw new mxnet.Error("Failed to load symbol entry (" + message.replace(/\.$/, '') + ').');
                     }
-                    base = basename(identifier, 'json', 'symbol');
-                    if (base) {
-                        return context.request(base + '-0000.params', null).then((stream) => {
-                            const buffer = stream.peek();
-                            return open_model(metadata, format, null, symbol, null, buffer);
-                        }).catch(() => {
-                            return open_model(metadata, format, null, symbol, null, params);
-                        });
-                    }
-                    return open_model(metadata, format, null, symbol, null, null);
-                case 'params':
-                    params = context.stream.peek();
-                    base = basename(context.identifier, 'params');
-                    if (base) {
-                        return context.request(base + '-symbol.json', 'utf-8').then((text) => {
-                            symbol = JSON.parse(text);
-                            if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
-                                format  = 'TVM';
-                            }
-                            return open_model(metadata, format, null, symbol, null, params);
-                        }).catch(() => {
-                            return open_model(metadata, format, null, null, null, params);
-                        });
-                    }
-                    return open_model(metadata, format, null, null, null, params);
-                case 'mar':
-                case 'model': {
-                    const entries = new Map();
-                    try {
-                        for (const entry of context.entries('zip')) {
-                            entries.set(entry.name, entry);
+                    const requestParams = (manifest) => {
+                        const file = basename(manifest.params, identifier, '.json', 'symbol', '-0000.params');
+                        if (file) {
+                            return context.request(file, null).then((stream) => {
+                                const buffer = stream.peek();
+                                return createModel(metadata, manifest, symbol, buffer);
+                            }).catch(() => {
+                                return createModel(metadata, manifest, symbol, null);
+                            });
                         }
-                    }
-                    catch (err) {
-                        throw new mxnet.Error('Failed to decompress Zip archive. ' + err.message);
-                    }
-
-                    let manifestEntry = entries.get(entries.has('MANIFEST.json') ? 'MANIFEST.json' : 'MAR-INF/MANIFEST.json');
-                    let rootFolder = '';
-                    if (!manifestEntry) {
-                        const folders = Array.from(entries.keys()).filter((name) => name.endsWith('/')).filter((name) => entries.get(name + 'MANIFEST.json'));
-                        if (folders.length != 1) {
-                            throw new mxnet.Error("Manifest not found.");
-                        }
-                        rootFolder = folders[0];
-                        manifestEntry = entries.get(rootFolder + 'MANIFEST.json');
-                    }
-
-                    const decoder = new TextDecoder('utf-8');
-                    let manifest = null;
-                    try {
-                        manifest = JSON.parse(decoder.decode(manifestEntry.data));
-                    }
-                    catch (err) {
-                        throw new mxnet.Error('Failed to read manifest. ' + err.message);
-                    }
-
-                    let modelFormat = null;
-                    let symbolEntry = null;
-                    let signatureEntry = null;
-                    let paramsEntry = null;
-                    if (manifest.Model) {
-                        modelFormat = manifest.Model['Model-Format'];
-                        if (modelFormat && modelFormat != 'MXNet-Symbolic') {
-                            throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
-                        }
-                        format = 'MXNet Model Server';
-                        if (manifest['Model-Archive-Version']) {
-                            format += ' v' + manifest['Model-Archive-Version'].toString();
-                        }
-                        if (!manifest.Model.Symbol) {
-                            throw new mxnet.Error('Manifest does not contain symbol entry.');
-                        }
-                        symbolEntry = entries.get(rootFolder + manifest.Model.Symbol);
-                        if (manifest.Model.Signature) {
-                            signatureEntry = entries.get(rootFolder + manifest.Model.Signature);
-                        }
-                        if (manifest.Model.Parameters) {
-                            paramsEntry = entries.get(rootFolder + manifest.Model.Parameters);
-                        }
-                    }
-                    else if (manifest.model) {
-                        format = 'MXNet Model Archive';
-                        if (manifest.specificationVersion) {
-                            format += ' v' + manifest.specificationVersion.toString();
-                        }
-                        if (manifest.model.modelName) {
-                            symbolEntry = entries.get(rootFolder + manifest.model.modelName + '-symbol.json');
-                            let key = null;
-                            for (key of Array.from(entries.keys())) {
-                                key = key.substring(rootFolder.length);
-                                if (key.endsWith('.params') && key.startsWith(manifest.model.modelName)) {
-                                    paramsEntry = entries.get(key);
-                                    break;
-                                }
-                            }
-                            if (!symbolEntry && !paramsEntry) {
-                                for (key of Object.keys(entries)) {
-                                    key = key.substring(rootFolder.length);
-                                    if (key.endsWith('.params')) {
-                                        paramsEntry = entries.get(key);
-                                        break;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                    else {
-                        throw new mxnet.Error('Manifest does not contain model.');
-                    }
-
-                    if (!symbolEntry && !paramsEntry) {
-                        throw new mxnet.Error("Model does not contain symbol entry.");
-                    }
-
-                    try {
-                        if (symbolEntry) {
-                            symbol = JSON.parse(decoder.decode(symbolEntry.data));
-                        }
-                    }
-                    catch (err) {
-                        throw new mxnet.Error('Failed to load symbol entry.' + err.message);
-                    }
-
-                    if (paramsEntry) {
-                        params = paramsEntry.data;
-                    }
-                    let signature = null;
-                    try {
-                        if (signatureEntry) {
-                            signature = JSON.parse(decoder.decode(signatureEntry.data));
+                        return createModel(metadata, manifest, symbol, null);
+                    };
+                    return requestManifest().then((manifest) => {
+                        return requestParams(manifest);
+                    });
+                }
+                case 'params': {
+                    const params = context.stream.peek();
+                    const requestSymbol = (manifest) => {
+                        const file = basename(manifest.symbol, identifier, '.params', null, '-symbol.json');
+                        if (file) {
+                            return context.request(file, 'utf-8').then((text) => {
+                                const symbol = JSON.parse(text);
+                                return createModel(metadata, manifest, symbol, params);
+                            }).catch(() => {
+                                return createModel(metadata, manifest, null, params);
+                            });
                         }
-                    }
-                    catch (err) {
-                        // continue regardless of error
-                    }
-                    return open_model(metadata, format, manifest, symbol, signature, params);
+                        return createModel(metadata, manifest, null, params);
+                    };
+                    return requestManifest().then((manifest) => {
+                        return requestSymbol(manifest);
+                    });
                 }
-                default:
+                default: {
                     throw new mxnet.Error('Unsupported file extension.');
+                }
             }
         });
     }
@@ -239,7 +244,7 @@ mxnet.ModelFactory = class {
 
 mxnet.Model = class {
 
-    constructor(metadata, format, manifest, symbol, signature, params) {
+    constructor(metadata, manifest, symbol, params) {
         if (!symbol && !params) {
             throw new mxnet.Error('JSON symbol data not available.');
         }
@@ -254,67 +259,25 @@ mxnet.Model = class {
                 throw new mxnet.Error('JSON file does not contain an MXNet \'heads\' property.');
             }
         }
-
-        if (manifest) {
-            if (manifest.Model && manifest.Model['Model-Name']) {
-                this._name = manifest.Model['Model-Name'];
-            }
-            if (manifest.Model && manifest.Model.Description && this._name != manifest.Model.Description) {
-                this._description = manifest.Model.Description;
-            }
-            if (manifest.Engine && manifest.Engine.MXNet) {
-                const engineVersion = mxnet.Model._convert_version(manifest.Engine.MXNet);
-                this._runtime = 'MXNet v' + (engineVersion ? engineVersion : manifest.Engine.MXNet.toString());
-            }
-            if (manifest.License) {
-                this._license = manifest.License;
-            }
-            if (manifest.model && manifest.model.modelName) {
-                this._name = manifest.model.modelName;
-            }
-            if (manifest.model && manifest.model.modelVersion) {
-                this._version = manifest.model.modelVersion;
-            }
-            if (manifest.model && manifest.model.modelName && this._name != manifest.model.description) {
-                this._description = manifest.model.description;
-            }
-            if (manifest.runtime) {
-                this._runtime = manifest.runtime;
-            }
-            if (manifest.engine && manifest.engine.engineName) {
-                const engine = manifest.engine.engineVersion ? manifest.engine.engineName + ' ' + manifest.engine.engineVersion : manifest.engine.engineName;
-                this._runtime =  this._runtime ? (this._runtime + ' (' + engine + ')') : engine;
-            }
-            if (manifest.publisher && manifest.publisher.author) {
-                this._author = manifest.publisher.author;
-                if (manifest.publisher.email) {
-                    this._author = this._author + ' <' + manifest.publisher.email + '>';
-                }
-            }
-            if (manifest.license) {
-                this._license = manifest.license;
-            }
-        }
-
-        this._format = format;
-        if (!this._format && symbol && symbol.attrs && symbol.attrs.mxnet_version) {
-            const version = mxnet.Model._convert_version(symbol.attrs.mxnet_version);
-            if (version) {
-                this._format = 'MXNet v' + version;
-            }
-        }
-        if (!this._format) {
-            this._format = 'MXNet';
-        }
-
-        this._graphs = [];
-        this._graphs.push(new mxnet.Graph(metadata, manifest, symbol, signature, params));
+        this._format = manifest.format || 'MXNet';
+        this._producer = manifest.producer || '';
+        this._name = manifest.name || '';
+        this._version = manifest.version;
+        this._description = manifest.description || '';
+        this._runtime = manifest.runtime || '';
+        this._author = manifest.author || '';
+        this._license = manifest.license || '';
+        this._graphs = [ new mxnet.Graph(metadata, manifest, symbol, params) ];
     }
 
     get format() {
         return this._format;
     }
 
+    get producer() {
+        return this._producer;
+    }
+
     get name() {
         return this._name;
     }
@@ -342,23 +305,11 @@ mxnet.Model = class {
     get graphs() {
         return this._graphs;
     }
-
-    static _convert_version(value) {
-        if (Array.isArray(value)) {
-            if (value.length == 2 && value[0] == 'int') {
-                const major = Math.floor(value[1] / 10000) % 100;
-                const minor = Math.floor(value[1] / 100) % 100;
-                const patch = Math.floor(value[1]) % 100;
-                return [ major.toString(), minor.toString(), patch.toString() ].join('.');
-            }
-        }
-        return null;
-    }
 };
 
 mxnet.Graph = class {
 
-    constructor(metadata, manifest, symbol, signature, params) {
+    constructor(metadata, manifest, symbol, params) {
         this._metadata = metadata;
         this._nodes = [];
         this._inputs = [];
@@ -376,14 +327,14 @@ mxnet.Graph = class {
         if (symbol) {
             const nodes = symbol.nodes;
             const inputs = {};
-            if (signature && signature.inputs) {
-                for (const input of signature.inputs) {
+            if (manifest && manifest.signature && manifest.signature.inputs) {
+                for (const input of manifest.signature.inputs) {
                     inputs[input.data_name] = input;
                 }
             }
             const outputs = {};
-            if (signature && signature.outputs) {
-                for (const output of signature.outputs) {
+            if (manifest && manifest.signature && manifest.signature.outputs) {
+                for (const output of manifest.signature.outputs) {
                     outputs[output.data_name] = output;
                 }
             }

+ 5 - 3
source/view.js

@@ -1326,7 +1326,7 @@ view.ArchiveContext = class {
             for (const entry of entries) {
                 if (entry.name.startsWith(rootFolder)) {
                     const name = entry.name.substring(rootFolder.length);
-                    if (name.length > 0 && name.indexOf('/') === -1) {
+                    if (name.length > 0 && (name.indexOf('/') === -1 || name.startsWith('MAR-INF/'))) {
                         this._entries[name] = entry;
                     }
                 }
@@ -1379,7 +1379,7 @@ view.ModelFactoryService = class {
         this._extensions = [];
         this.register('./pytorch', [ '.pt', '.pth', '.pt1', '.pyt', '.pkl', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn' ]);
         this.register('./onnx', [ '.onnx', '.pb', '.pbtxt', '.prototxt', '.model', '.pt', '.pth', '.pkl' ]);
-        this.register('./mxnet', [ '.mar', '.model', '.json', '.params' ]);
+        this.register('./mxnet', [ '.json', '.params' ]);
         this.register('./coreml', [ '.mlmodel' ]);
         this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt' ]);
         this.register('./caffe2', [ '.pb', '.pbtxt', '.prototxt' ]);
@@ -1727,7 +1727,9 @@ view.ModelFactoryService = class {
         if (identifier.endsWith('.zip') ||
             identifier.endsWith('.tar') ||
             identifier.endsWith('.tar.gz') ||
-            identifier.endsWith('.tgz')) {
+            identifier.endsWith('.tgz') ||
+            identifier.endsWith('.mar') ||
+            identifier.endsWith('.model')) {
             this._host.event('File', 'Accept', extension, 1);
             return true;
         }

+ 8 - 14
test/models.json

@@ -2701,7 +2701,7 @@
     "type":   "mxnet",
     "target": "mobilenet-v1-tvm.json",
     "source": "https://github.com/lutzroeder/netron/files/2636924/mobilenet-v1-tvm.json.zip[mobilenet-v1-tvm.json]",
-    "format": "TVM",
+    "format": "MXNet", "producer": "TVM",
     "link":   "https://github.com/lutzroeder/netron/issues/199"
   },
   {
@@ -2758,13 +2758,6 @@
     "format": "MXNet Model Archive v1.0",
     "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
   },
-  {
-    "type":   "mxnet",
-    "target": "squeezenet_v1.1.zip",
-    "source": "https://s3.amazonaws.com/model-server/model_archive_1.0/squeezenet_v1.1.mar",
-    "format": "MXNet",
-    "link":   "https://github.com/awslabs/mxnet-model-server/blob/master/docs/model_zoo.md"
-  },
   {
     "type":   "mxnet",
     "target": "vgg16.mar",
@@ -4119,6 +4112,13 @@
     "source": "https://github.com/lutzroeder/netron/files/6096621/densenet161.zip.pth.zip[densenet161.zip.pth]",
     "link":   "https://pytorch.org/docs/stable/torchvision/models.html"
   },
+  {
+    "type":   "pytorch",
+    "target": "densenet161.mar",
+    "source": "https://torchserve.pytorch.org/mar_files/densenet161.mar",
+    "format": "PyTorch v0.1.10",
+    "link":   "https://github.com/lutzroeder/netron/issues/286"
+  },
   {
     "type":   "pytorch",
     "target": "DRNL4x_dual_model.pth",
@@ -4172,12 +4172,6 @@
     "source": "https://github.com/lutzroeder/netron/files/6096630/inception_v3.zip.pth.zip[inception_v3.zip.pth]",
     "link":   "https://github.com/lutzroeder/netron/issues/133"
   },
-  {
-    "type":   "pytorch",
-    "target": "inception_v3_google-1a9a5a14.pth",
-    "format": "PyTorch v0.1.10",
-    "source": "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth"
-  },
   {
     "type":   "pytorch",
     "target": "iv3_pertensor.pt",