|
|
@@ -12,7 +12,7 @@ mxnet.ModelFactory = class {
|
|
|
var identifier = context.identifier;
|
|
|
var extension = identifier.split('.').pop().toLowerCase();
|
|
|
var buffer = null;
|
|
|
- if (extension == 'model') {
|
|
|
+ if (extension == 'model' || extension == 'mar') {
|
|
|
buffer = context.buffer;
|
|
|
if (buffer && buffer.length > 2 && buffer[0] == 0x50 && buffer[1] == 0x4B) {
|
|
|
return true;
|
|
|
@@ -32,7 +32,7 @@ mxnet.ModelFactory = class {
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- else if (mxnet.ModelFactory._basename(identifier, 'params')) {
|
|
|
+ else if (extension == 'params') {
|
|
|
buffer = context.buffer;
|
|
|
var signature = [ 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
|
|
|
if (buffer && buffer.length > signature.length && signature.every((v, i) => v == buffer[i])) {
|
|
|
@@ -73,15 +73,21 @@ mxnet.ModelFactory = class {
|
|
|
case 'params':
|
|
|
params = context.buffer;
|
|
|
basename = mxnet.ModelFactory._basename(context.identifier, 'params');
|
|
|
- return context.request(basename + '-symbol.json', 'utf-8').then((text) => {
|
|
|
- symbol = JSON.parse(text);
|
|
|
- if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
|
|
|
- format = 'TVM';
|
|
|
- }
|
|
|
- return this._openModel(identifier, format, null, symbol, null, params, host);
|
|
|
- }).catch(() => {
|
|
|
+ if (basename) {
|
|
|
+ return context.request(basename + '-symbol.json', 'utf-8').then((text) => {
|
|
|
+ symbol = JSON.parse(text);
|
|
|
+ if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
|
|
|
+ format = 'TVM';
|
|
|
+ }
|
|
|
+ return this._openModel(identifier, format, null, symbol, null, params, host);
|
|
|
+ }).catch(() => {
|
|
|
+ return this._openModel(identifier, format, null, null, null, params, host);
|
|
|
+ });
|
|
|
+ }
|
|
|
+ else {
|
|
|
return this._openModel(identifier, format, null, null, null, params, host);
|
|
|
- });
|
|
|
+ }
|
|
|
+ case 'mar':
|
|
|
case 'model':
|
|
|
var entries = {};
|
|
|
try {
|
|
|
@@ -94,7 +100,7 @@ mxnet.ModelFactory = class {
|
|
|
throw new mxnet.Error('Failed to decompress ZIP archive. ' + err.message);
|
|
|
}
|
|
|
|
|
|
- var manifestEntry = entries['MANIFEST.json'];
|
|
|
+ var manifestEntry = entries['MANIFEST.json'] || entries['MAR-INF/MANIFEST.json'];
|
|
|
var rootFolder = '';
|
|
|
if (!manifestEntry) {
|
|
|
var folders = Object.keys(entries).filter((name) => name.endsWith('/')).filter((name) => entries[name + 'MANIFEST.json']);
|
|
|
@@ -113,46 +119,81 @@ mxnet.ModelFactory = class {
|
|
|
catch (err) {
|
|
|
throw new mxnet.Error('Failed to read manifest. ' + err.message);
|
|
|
}
|
|
|
-
|
|
|
- if (!manifest.Model) {
|
|
|
+
|
|
|
+ var modelFormat = null;
|
|
|
+ var symbolEntry = null;
|
|
|
+ var signatureEntry = null;
|
|
|
+ var paramsEntry = null;
|
|
|
+ if (manifest.Model) {
|
|
|
+ modelFormat = manifest.Model['Model-Format'];
|
|
|
+ if (modelFormat && modelFormat != 'MXNet-Symbolic') {
|
|
|
+ throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
|
|
|
+ }
|
|
|
+ format = 'MXNet Model Server';
|
|
|
+ if (manifest['Model-Archive-Version']) {
|
|
|
+ format += ' v' + manifest['Model-Archive-Version'].toString();
|
|
|
+ }
|
|
|
+ if (!manifest.Model.Symbol) {
|
|
|
+ throw new mxnet.Error('Manifest does not contain symbol entry.');
|
|
|
+ }
|
|
|
+ symbolEntry = entries[rootFolder + manifest.Model.Symbol];
|
|
|
+ if (manifest.Model.Signature) {
|
|
|
+ signatureEntry = entries[rootFolder + manifest.Model.Signature];
|
|
|
+ }
|
|
|
+ if (manifest.Model.Parameters) {
|
|
|
+ paramsEntry = entries[rootFolder + manifest.Model.Parameters];
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else if (manifest.model) {
|
|
|
+ format = 'MXNet Model Archive';
|
|
|
+ if (manifest.specificationVersion) {
|
|
|
+ format += ' v' + manifest.specificationVersion.toString();
|
|
|
+ }
|
|
|
+ if (manifest.model.modelName) {
|
|
|
+ symbolEntry = entries[rootFolder + manifest.model.modelName + '-symbol.json']
|
|
|
+ var key = null;
|
|
|
+ for (key of Object.keys(entries)) {
|
|
|
+ key = key.substring(rootFolder.length);
|
|
|
+ if (key.endsWith('.params') && key.startsWith(manifest.model.modelName)) {
|
|
|
+ paramsEntry = entries[key];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (!symbolEntry && !paramsEntry) {
|
|
|
+ for (key of Object.keys(entries)) {
|
|
|
+ key = key.substring(rootFolder.length);
|
|
|
+ if (key.endsWith('.params')) {
|
|
|
+ paramsEntry = entries[key];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
throw new mxnet.Error('Manifest does not contain model.');
|
|
|
}
|
|
|
|
|
|
- var modelFormat = manifest.Model['Model-Format'];
|
|
|
- if (modelFormat && modelFormat != 'MXNet-Symbolic') {
|
|
|
- throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
|
|
|
- }
|
|
|
-
|
|
|
- if (!manifest.Model.Symbol) {
|
|
|
- throw new mxnet.Error('Manifest does not contain symbol entry.');
|
|
|
+ if (!symbolEntry && !paramsEntry) {
|
|
|
+ throw new mxnet.Error("Model does not contain symbol entry.");
|
|
|
}
|
|
|
|
|
|
try {
|
|
|
- var symbolEntry = entries[rootFolder + manifest.Model.Symbol];
|
|
|
- symbol = JSON.parse(decoder.decode(symbolEntry.data));
|
|
|
+ if (symbolEntry) {
|
|
|
+ symbol = JSON.parse(decoder.decode(symbolEntry.data));
|
|
|
+ }
|
|
|
}
|
|
|
catch (err) {
|
|
|
throw new mxnet.Error('Failed to load symbol entry.' + err.message);
|
|
|
}
|
|
|
|
|
|
- var signature = null;
|
|
|
- try {
|
|
|
- if (manifest.Model.Signature) {
|
|
|
- var signatureEntry = entries[rootFolder + manifest.Model.Signature];
|
|
|
- if (signatureEntry) {
|
|
|
- signature = JSON.parse(decoder.decode(signatureEntry.data));
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- catch (err) {
|
|
|
- // continue regardless of error
|
|
|
+ if (paramsEntry) {
|
|
|
+ params = paramsEntry.data;
|
|
|
}
|
|
|
+ var signature = null;
|
|
|
try {
|
|
|
- if (manifest.Model.Parameters) {
|
|
|
- var parametersEntry = entries[rootFolder + manifest.Model.Parameters];
|
|
|
- if (parametersEntry) {
|
|
|
- params = parametersEntry.data;
|
|
|
- }
|
|
|
+ if (signatureEntry) {
|
|
|
+ signature = JSON.parse(decoder.decode(signatureEntry.data));
|
|
|
}
|
|
|
}
|
|
|
catch (err) {
|
|
|
@@ -160,12 +201,6 @@ mxnet.ModelFactory = class {
|
|
|
}
|
|
|
|
|
|
try {
|
|
|
- if (manifest) {
|
|
|
- format = 'MXNet Model Server';
|
|
|
- if (manifest['Model-Archive-Version']) {
|
|
|
- format += ' v' + manifest['Model-Archive-Version'].toString();
|
|
|
- }
|
|
|
- }
|
|
|
return this._openModel(identifier, format, manifest, symbol, signature, params, host);
|
|
|
}
|
|
|
catch (error) {
|
|
|
@@ -211,7 +246,7 @@ mxnet.ModelFactory = class {
|
|
|
static _basename(identifier, extension, suffix) {
|
|
|
var dots = identifier.split('.');
|
|
|
if (dots.length >= 2 && dots.pop().toLowerCase() === extension) {
|
|
|
- var dashes = dots.pop().split('-');
|
|
|
+ var dashes = dots.join('.').split('-');
|
|
|
if (dashes.length >= 2) {
|
|
|
var token = dashes.pop();
|
|
|
if (suffix) {
|
|
|
@@ -261,7 +296,38 @@ mxnet.Model = class {
|
|
|
}
|
|
|
if (manifest.Engine && manifest.Engine.MXNet) {
|
|
|
var engineVersion = mxnet.Model._convert_version(manifest.Engine.MXNet);
|
|
|
- this._engine = 'MXNet v' + (engineVersion ? engineVersion : manifest.Engine.MXNet.toString());
|
|
|
+ this._runtime = 'MXNet v' + (engineVersion ? engineVersion : manifest.Engine.MXNet.toString());
|
|
|
+ }
|
|
|
+ if (manifest.License) {
|
|
|
+ this._license = manifest.License;
|
|
|
+ }
|
|
|
+ if (manifest.model && manifest.model.modelName) {
|
|
|
+ this._name = manifest.model.modelName;
|
|
|
+ }
|
|
|
+ if (manifest.model && manifest.model.modelVersion) {
|
|
|
+ this._version = manifest.model.modelVersion;
|
|
|
+ }
|
|
|
+ if (manifest.model && manifest.model.modelName && this._name != manifest.model.description) {
|
|
|
+ this._description = manifest.model.description;
|
|
|
+ }
|
|
|
+ if (manifest.runtime) {
|
|
|
+ this._runtime = manifest.runtime;
|
|
|
+ }
|
|
|
+ if (manifest.engine && manifest.engine.engineName) {
|
|
|
+ var engine = manifest.engine.engineName;
|
|
|
+ if (manifest.engine.engineVersion) {
|
|
|
+ engine = engine + ' ' + manifest.engine.engineVersion;
|
|
|
+ }
|
|
|
+ this._runtime = this._runtime ? (this._runtime + ' (' + engine + ')') : engine;
|
|
|
+ }
|
|
|
+ if (manifest.publisher && manifest.publisher.author) {
|
|
|
+ this._author = manifest.publisher.author;
|
|
|
+ if (manifest.publisher.email) {
|
|
|
+ this._author = this._author + ' <' + manifest.publisher.email + '>';
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (manifest.license) {
|
|
|
+ this._license = manifest.license;
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -280,20 +346,32 @@ mxnet.Model = class {
|
|
|
this._graphs.push(new mxnet.Graph(metadata, manifest, symbol, signature, params));
|
|
|
}
|
|
|
|
|
|
+ get format() {
|
|
|
+ return this._format;
|
|
|
+ }
|
|
|
+
|
|
|
get name() {
|
|
|
return this._name;
|
|
|
}
|
|
|
|
|
|
- get format() {
|
|
|
- return this._format;
|
|
|
+ get version() {
|
|
|
+ return this._version;
|
|
|
}
|
|
|
|
|
|
get description() {
|
|
|
return this._description;
|
|
|
}
|
|
|
|
|
|
+ get author() {
|
|
|
+ return this._author;
|
|
|
+ }
|
|
|
+
|
|
|
+ get license() {
|
|
|
+ return this._license;
|
|
|
+ }
|
|
|
+
|
|
|
get runtime() {
|
|
|
- return this._engine;
|
|
|
+ return this._runtime;
|
|
|
}
|
|
|
|
|
|
get graphs() {
|
|
|
@@ -415,14 +493,18 @@ mxnet.Graph = class {
|
|
|
var block = null;
|
|
|
var blocks = [];
|
|
|
var blockMap = {};
|
|
|
- if (Object.keys(params).every((k) => k.indexOf('_') != -1)) {
|
|
|
+ var separator = Object.keys(params).every((k) => k.indexOf('_') != -1) ? '_' : '';
|
|
|
+ if (separator.length == 0) {
|
|
|
+ separator = Object.keys(params).every((k) => k.indexOf('.') != -1) ? '.' : '';
|
|
|
+ }
|
|
|
+ if (separator.length > 0) {
|
|
|
for (var id of Object.keys(params)) {
|
|
|
- var parts = id.split('_');
|
|
|
+ var parts = id.split(separator);
|
|
|
var argumentName = parts.pop();
|
|
|
if (id.endsWith('moving_mean') || id.endsWith('moving_var')) {
|
|
|
- argumentName = [ parts.pop(), argumentName ].join('_');
|
|
|
+ argumentName = [ parts.pop(), argumentName ].join(separator);
|
|
|
}
|
|
|
- var nodeName = parts.join('_');
|
|
|
+ var nodeName = parts.join(separator);
|
|
|
block = blockMap[nodeName];
|
|
|
if (!block) {
|
|
|
block = { name: nodeName, op: 'Weights', params: [] };
|