|
|
@@ -10,22 +10,21 @@ mxnet.ModelFactory = class {
|
|
|
match(context) {
|
|
|
const identifier = context.identifier;
|
|
|
const extension = identifier.split('.').pop().toLowerCase();
|
|
|
- if (extension === 'model' || extension === 'mar') {
|
|
|
- if (context.entries('zip').length > 0) {
|
|
|
- return true;
|
|
|
- }
|
|
|
- }
|
|
|
- else if (extension == 'json') {
|
|
|
- const obj = context.open('json');
|
|
|
- if (obj && obj.nodes && obj.arg_nodes && obj.heads) {
|
|
|
- return true;
|
|
|
+ switch (extension) {
|
|
|
+ case 'json': {
|
|
|
+ const obj = context.open('json');
|
|
|
+ if (obj && obj.nodes && obj.arg_nodes && obj.heads) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ break;
|
|
|
}
|
|
|
- }
|
|
|
- else if (extension == 'params') {
|
|
|
- const stream = context.stream;
|
|
|
- const signature = [ 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
|
|
|
- if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value == signature[index])) {
|
|
|
- return true;
|
|
|
+ case 'params': {
|
|
|
+ const stream = context.stream;
|
|
|
+ const signature = [ 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 ];
|
|
|
+ if (stream.length > signature.length && stream.peek(signature.length).every((value, index) => value == signature[index])) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ break;
|
|
|
}
|
|
|
}
|
|
|
return false;
|
|
|
@@ -33,32 +32,138 @@ mxnet.ModelFactory = class {
|
|
|
|
|
|
open(context) {
|
|
|
return mxnet.Metadata.open(context).then((metadata) => {
|
|
|
- const basename = (identifier, extension, suffix) => {
|
|
|
- const dots = identifier.split('.');
|
|
|
- if (dots.length >= 2 && dots.pop().toLowerCase() === extension) {
|
|
|
- const dashes = dots.join('.').split('-');
|
|
|
- if (dashes.length >= 2) {
|
|
|
- const token = dashes.pop();
|
|
|
- if (suffix) {
|
|
|
- if (token != suffix) {
|
|
|
- return null;
|
|
|
+ const basename = (base, identifier, extension, suffix, append) => {
|
|
|
+ if (!base) {
|
|
|
+ if (identifier.toLowerCase().endsWith(extension)) {
|
|
|
+ const items = identifier.substring(0, identifier.length - extension.length).split('-');
|
|
|
+ if (items.length >= 2) {
|
|
|
+ const token = items.pop();
|
|
|
+ if ((suffix && token === suffix) || /[a-zA-Z0-9]*/.exec(token)) {
|
|
|
+ return items.join('-') + append;
|
|
|
}
|
|
|
}
|
|
|
- else {
|
|
|
- for (let i = 0; i < token.length; i++) {
|
|
|
- const c = token.charAt(i);
|
|
|
- if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
|
|
|
- continue;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return base;
|
|
|
+ };
|
|
|
+ const convertVersion = (value) => {
|
|
|
+ if (Array.isArray(value)) {
|
|
|
+ if (value.length === 2 && value[0] === 'int') {
|
|
|
+ const major = Math.floor(value[1] / 10000) % 100;
|
|
|
+ const minor = Math.floor(value[1] / 100) % 100;
|
|
|
+ const patch = Math.floor(value[1]) % 100;
|
|
|
+ return [ major.toString(), minor.toString(), patch.toString() ].join('.');
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ };
|
|
|
+ const requestManifest = () => {
|
|
|
+ const parse = (stream) => {
|
|
|
+ try {
|
|
|
+ const manifest = {};
|
|
|
+ const decoder = new TextDecoder('utf-8');
|
|
|
+ if (stream) {
|
|
|
+ const buffer = stream.peek();
|
|
|
+ const text = decoder.decode(buffer);
|
|
|
+ const json = JSON.parse(text);
|
|
|
+ if (json.Model) {
|
|
|
+ const modelFormat = json.Model['Model-Format'];
|
|
|
+ if (modelFormat && modelFormat != 'MXNet-Symbolic') {
|
|
|
+ throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
|
|
|
+ }
|
|
|
+ manifest.format = 'MXNet Model Server';
|
|
|
+ if (json['Model-Archive-Version']) {
|
|
|
+ manifest.format += ' v' + json['Model-Archive-Version'].toString();
|
|
|
}
|
|
|
- return null;
|
|
|
+ if (!json.Model.Symbol) {
|
|
|
+ throw new mxnet.Error('Manifest does not contain symbol entry.');
|
|
|
+ }
|
|
|
+ manifest.symbol = json.Model.Symbol;
|
|
|
+ if (json.Model.Signature) {
|
|
|
+ manifest.signature = json.Model.Signature;
|
|
|
+ }
|
|
|
+ if (json.Model.Parameters) {
|
|
|
+ manifest.params = json.Model.Parameters;
|
|
|
+ }
|
|
|
+ if (json.Model['Model-Name']) {
|
|
|
+ manifest.name = json.Model['Model-Name'];
|
|
|
+ }
|
|
|
+ if (json.Model.Description && manifest.name !== json.Model.Description) {
|
|
|
+ manifest.description = json.Model.Description;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else if (json.model) {
|
|
|
+ manifest.format = 'MXNet Model Archive';
|
|
|
+ if (json.specificationVersion) {
|
|
|
+ manifest.format += ' v' + json.specificationVersion.toString();
|
|
|
+ }
|
|
|
+ if (json.model.modelName) {
|
|
|
+ manifest.symbol = json.model.modelName + '-symbol.json';
|
|
|
+ }
|
|
|
+ if (json.model.modelName) {
|
|
|
+ manifest.name = json.model.modelName;
|
|
|
+ }
|
|
|
+ if (manifest.model && json.model.modelVersion) {
|
|
|
+ manifest.version = json.model.modelVersion;
|
|
|
+ }
|
|
|
+ if (manifest.model && manifest.model.modelName && manifest.name != json.model.description) {
|
|
|
+ manifest.description = json.model.description;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ throw new mxnet.Error('Manifest does not contain model.');
|
|
|
+ }
|
|
|
+ if (json.Engine && json.Engine.MXNet) {
|
|
|
+ const version = convertVersion(json.Engine.MXNet);
|
|
|
+ manifest.runtime = 'MXNet v' + (version ? version : json.Engine.MXNet.toString());
|
|
|
+ }
|
|
|
+ if (json.License) {
|
|
|
+ manifest.license = json.License;
|
|
|
+ }
|
|
|
+ if (json.runtime) {
|
|
|
+ manifest.runtime = json.runtime;
|
|
|
+ }
|
|
|
+ if (json.engine && json.engine.engineName) {
|
|
|
+ const engine = json.engine.engineVersion ? json.engine.engineName + ' ' + json.engine.engineVersion : json.engine.engineName;
|
|
|
+ manifest.runtime = manifest.runtime ? (manifest.runtime + ' (' + engine + ')') : engine;
|
|
|
+ }
|
|
|
+ if (json.publisher && json.publisher.author) {
|
|
|
+ manifest.author = json.publisher.author;
|
|
|
+ if (json.publisher.email) {
|
|
|
+ manifest.author = manifest.author + ' <' + json.publisher.email + '>';
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (json.license) {
|
|
|
+ manifest.license = json.license;
|
|
|
+ }
|
|
|
+ if (json.Model && json.Model.Signature) {
|
|
|
+ return context.request(json.Model.Signature).then((stream) => {
|
|
|
+ const buffer = stream.peek();
|
|
|
+ const text = decoder.decode(buffer);
|
|
|
+ manifest.signature = JSON.parse(text);
|
|
|
+ return manifest;
|
|
|
+ }).catch (() => {
|
|
|
+ return manifest;
|
|
|
+ });
|
|
|
}
|
|
|
}
|
|
|
- return dashes.join('-');
|
|
|
+ return manifest;
|
|
|
}
|
|
|
- }
|
|
|
- return null;
|
|
|
+ catch (err) {
|
|
|
+ throw new mxnet.Error('Failed to read manifest. ' + err.message);
|
|
|
+ }
|
|
|
+ };
|
|
|
+ return context.request('MANIFEST.json').then((stream) => {
|
|
|
+ return parse(stream);
|
|
|
+ }).catch (() => {
|
|
|
+ return context.request('MAR-INF/MANIFEST.json').then((stream) => {
|
|
|
+ return parse(stream);
|
|
|
+ }).catch(() => {
|
|
|
+ return parse(null);
|
|
|
+ });
|
|
|
+ });
|
|
|
};
|
|
|
- const open_model = (metadata, format, manifest, symbol, signature, params) => {
|
|
|
+ const createModel = (metadata, manifest, symbol, params) => {
|
|
|
const parameters = new Map();
|
|
|
if (params) {
|
|
|
try {
|
|
|
@@ -72,166 +177,66 @@ mxnet.ModelFactory = class {
|
|
|
// continue regardless of error
|
|
|
}
|
|
|
}
|
|
|
- return new mxnet.Model(metadata, format, manifest, symbol, signature, parameters);
|
|
|
+ if (symbol) {
|
|
|
+ if (!manifest.format) {
|
|
|
+ const version = convertVersion(symbol && symbol.attrs && symbol.attrs.mxnet_version ? symbol.attrs.mxnet_version : null);
|
|
|
+ manifest.format = 'MXNet' + (version ? ' v' + version : '');
|
|
|
+ }
|
|
|
+ if (symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
|
|
|
+ manifest.producer = 'TVM';
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return new mxnet.Model(metadata, manifest, symbol, parameters);
|
|
|
};
|
|
|
const identifier = context.identifier;
|
|
|
- const extension = context.identifier.split('.').pop().toLowerCase();
|
|
|
- let symbol = null;
|
|
|
- let params = null;
|
|
|
- let format = null;
|
|
|
- let base = null;
|
|
|
+ const extension = identifier.split('.').pop().toLowerCase();
|
|
|
switch (extension) {
|
|
|
- case 'json':
|
|
|
+ case 'json': {
|
|
|
+ let symbol = null;
|
|
|
try {
|
|
|
symbol = context.open('json');
|
|
|
- if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
|
|
|
- format = 'TVM';
|
|
|
- }
|
|
|
}
|
|
|
catch (error) {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
throw new mxnet.Error("Failed to load symbol entry (" + message.replace(/\.$/, '') + ').');
|
|
|
}
|
|
|
- base = basename(identifier, 'json', 'symbol');
|
|
|
- if (base) {
|
|
|
- return context.request(base + '-0000.params', null).then((stream) => {
|
|
|
- const buffer = stream.peek();
|
|
|
- return open_model(metadata, format, null, symbol, null, buffer);
|
|
|
- }).catch(() => {
|
|
|
- return open_model(metadata, format, null, symbol, null, params);
|
|
|
- });
|
|
|
- }
|
|
|
- return open_model(metadata, format, null, symbol, null, null);
|
|
|
- case 'params':
|
|
|
- params = context.stream.peek();
|
|
|
- base = basename(context.identifier, 'params');
|
|
|
- if (base) {
|
|
|
- return context.request(base + '-symbol.json', 'utf-8').then((text) => {
|
|
|
- symbol = JSON.parse(text);
|
|
|
- if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
|
|
|
- format = 'TVM';
|
|
|
- }
|
|
|
- return open_model(metadata, format, null, symbol, null, params);
|
|
|
- }).catch(() => {
|
|
|
- return open_model(metadata, format, null, null, null, params);
|
|
|
- });
|
|
|
- }
|
|
|
- return open_model(metadata, format, null, null, null, params);
|
|
|
- case 'mar':
|
|
|
- case 'model': {
|
|
|
- const entries = new Map();
|
|
|
- try {
|
|
|
- for (const entry of context.entries('zip')) {
|
|
|
- entries.set(entry.name, entry);
|
|
|
+ const requestParams = (manifest) => {
|
|
|
+ const file = basename(manifest.params, identifier, '.json', 'symbol', '-0000.params');
|
|
|
+ if (file) {
|
|
|
+ return context.request(file, null).then((stream) => {
|
|
|
+ const buffer = stream.peek();
|
|
|
+ return createModel(metadata, manifest, symbol, buffer);
|
|
|
+ }).catch(() => {
|
|
|
+ return createModel(metadata, manifest, symbol, null);
|
|
|
+ });
|
|
|
}
|
|
|
- }
|
|
|
- catch (err) {
|
|
|
- throw new mxnet.Error('Failed to decompress Zip archive. ' + err.message);
|
|
|
- }
|
|
|
-
|
|
|
- let manifestEntry = entries.get(entries.has('MANIFEST.json') ? 'MANIFEST.json' : 'MAR-INF/MANIFEST.json');
|
|
|
- let rootFolder = '';
|
|
|
- if (!manifestEntry) {
|
|
|
- const folders = Array.from(entries.keys()).filter((name) => name.endsWith('/')).filter((name) => entries.get(name + 'MANIFEST.json'));
|
|
|
- if (folders.length != 1) {
|
|
|
- throw new mxnet.Error("Manifest not found.");
|
|
|
- }
|
|
|
- rootFolder = folders[0];
|
|
|
- manifestEntry = entries.get(rootFolder + 'MANIFEST.json');
|
|
|
- }
|
|
|
-
|
|
|
- const decoder = new TextDecoder('utf-8');
|
|
|
- let manifest = null;
|
|
|
- try {
|
|
|
- manifest = JSON.parse(decoder.decode(manifestEntry.data));
|
|
|
- }
|
|
|
- catch (err) {
|
|
|
- throw new mxnet.Error('Failed to read manifest. ' + err.message);
|
|
|
- }
|
|
|
-
|
|
|
- let modelFormat = null;
|
|
|
- let symbolEntry = null;
|
|
|
- let signatureEntry = null;
|
|
|
- let paramsEntry = null;
|
|
|
- if (manifest.Model) {
|
|
|
- modelFormat = manifest.Model['Model-Format'];
|
|
|
- if (modelFormat && modelFormat != 'MXNet-Symbolic') {
|
|
|
- throw new mxnet.Error('Model format \'' + modelFormat + '\' not supported.');
|
|
|
- }
|
|
|
- format = 'MXNet Model Server';
|
|
|
- if (manifest['Model-Archive-Version']) {
|
|
|
- format += ' v' + manifest['Model-Archive-Version'].toString();
|
|
|
- }
|
|
|
- if (!manifest.Model.Symbol) {
|
|
|
- throw new mxnet.Error('Manifest does not contain symbol entry.');
|
|
|
- }
|
|
|
- symbolEntry = entries.get(rootFolder + manifest.Model.Symbol);
|
|
|
- if (manifest.Model.Signature) {
|
|
|
- signatureEntry = entries.get(rootFolder + manifest.Model.Signature);
|
|
|
- }
|
|
|
- if (manifest.Model.Parameters) {
|
|
|
- paramsEntry = entries.get(rootFolder + manifest.Model.Parameters);
|
|
|
- }
|
|
|
- }
|
|
|
- else if (manifest.model) {
|
|
|
- format = 'MXNet Model Archive';
|
|
|
- if (manifest.specificationVersion) {
|
|
|
- format += ' v' + manifest.specificationVersion.toString();
|
|
|
- }
|
|
|
- if (manifest.model.modelName) {
|
|
|
- symbolEntry = entries.get(rootFolder + manifest.model.modelName + '-symbol.json');
|
|
|
- let key = null;
|
|
|
- for (key of Array.from(entries.keys())) {
|
|
|
- key = key.substring(rootFolder.length);
|
|
|
- if (key.endsWith('.params') && key.startsWith(manifest.model.modelName)) {
|
|
|
- paramsEntry = entries.get(key);
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- if (!symbolEntry && !paramsEntry) {
|
|
|
- for (key of Object.keys(entries)) {
|
|
|
- key = key.substring(rootFolder.length);
|
|
|
- if (key.endsWith('.params')) {
|
|
|
- paramsEntry = entries.get(key);
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- else {
|
|
|
- throw new mxnet.Error('Manifest does not contain model.');
|
|
|
- }
|
|
|
-
|
|
|
- if (!symbolEntry && !paramsEntry) {
|
|
|
- throw new mxnet.Error("Model does not contain symbol entry.");
|
|
|
- }
|
|
|
-
|
|
|
- try {
|
|
|
- if (symbolEntry) {
|
|
|
- symbol = JSON.parse(decoder.decode(symbolEntry.data));
|
|
|
- }
|
|
|
- }
|
|
|
- catch (err) {
|
|
|
- throw new mxnet.Error('Failed to load symbol entry.' + err.message);
|
|
|
- }
|
|
|
-
|
|
|
- if (paramsEntry) {
|
|
|
- params = paramsEntry.data;
|
|
|
- }
|
|
|
- let signature = null;
|
|
|
- try {
|
|
|
- if (signatureEntry) {
|
|
|
- signature = JSON.parse(decoder.decode(signatureEntry.data));
|
|
|
+ return createModel(metadata, manifest, symbol, null);
|
|
|
+ };
|
|
|
+ return requestManifest().then((manifest) => {
|
|
|
+ return requestParams(manifest);
|
|
|
+ });
|
|
|
+ }
|
|
|
+ case 'params': {
|
|
|
+ const params = context.stream.peek();
|
|
|
+ const requestSymbol = (manifest) => {
|
|
|
+ const file = basename(manifest.symbol, identifier, '.params', null, '-symbol.json');
|
|
|
+ if (file) {
|
|
|
+ return context.request(file, 'utf-8').then((text) => {
|
|
|
+ const symbol = JSON.parse(text);
|
|
|
+ return createModel(metadata, manifest, symbol, params);
|
|
|
+ }).catch(() => {
|
|
|
+ return createModel(metadata, manifest, null, params);
|
|
|
+ });
|
|
|
}
|
|
|
- }
|
|
|
- catch (err) {
|
|
|
- // continue regardless of error
|
|
|
- }
|
|
|
- return open_model(metadata, format, manifest, symbol, signature, params);
|
|
|
+ return createModel(metadata, manifest, null, params);
|
|
|
+ };
|
|
|
+ return requestManifest().then((manifest) => {
|
|
|
+ return requestSymbol(manifest);
|
|
|
+ });
|
|
|
}
|
|
|
- default:
|
|
|
+ default: {
|
|
|
throw new mxnet.Error('Unsupported file extension.');
|
|
|
+ }
|
|
|
}
|
|
|
});
|
|
|
}
|
|
|
@@ -239,7 +244,7 @@ mxnet.ModelFactory = class {
|
|
|
|
|
|
mxnet.Model = class {
|
|
|
|
|
|
- constructor(metadata, format, manifest, symbol, signature, params) {
|
|
|
+ constructor(metadata, manifest, symbol, params) {
|
|
|
if (!symbol && !params) {
|
|
|
throw new mxnet.Error('JSON symbol data not available.');
|
|
|
}
|
|
|
@@ -254,67 +259,25 @@ mxnet.Model = class {
|
|
|
throw new mxnet.Error('JSON file does not contain an MXNet \'heads\' property.');
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- if (manifest) {
|
|
|
- if (manifest.Model && manifest.Model['Model-Name']) {
|
|
|
- this._name = manifest.Model['Model-Name'];
|
|
|
- }
|
|
|
- if (manifest.Model && manifest.Model.Description && this._name != manifest.Model.Description) {
|
|
|
- this._description = manifest.Model.Description;
|
|
|
- }
|
|
|
- if (manifest.Engine && manifest.Engine.MXNet) {
|
|
|
- const engineVersion = mxnet.Model._convert_version(manifest.Engine.MXNet);
|
|
|
- this._runtime = 'MXNet v' + (engineVersion ? engineVersion : manifest.Engine.MXNet.toString());
|
|
|
- }
|
|
|
- if (manifest.License) {
|
|
|
- this._license = manifest.License;
|
|
|
- }
|
|
|
- if (manifest.model && manifest.model.modelName) {
|
|
|
- this._name = manifest.model.modelName;
|
|
|
- }
|
|
|
- if (manifest.model && manifest.model.modelVersion) {
|
|
|
- this._version = manifest.model.modelVersion;
|
|
|
- }
|
|
|
- if (manifest.model && manifest.model.modelName && this._name != manifest.model.description) {
|
|
|
- this._description = manifest.model.description;
|
|
|
- }
|
|
|
- if (manifest.runtime) {
|
|
|
- this._runtime = manifest.runtime;
|
|
|
- }
|
|
|
- if (manifest.engine && manifest.engine.engineName) {
|
|
|
- const engine = manifest.engine.engineVersion ? manifest.engine.engineName + ' ' + manifest.engine.engineVersion : manifest.engine.engineName;
|
|
|
- this._runtime = this._runtime ? (this._runtime + ' (' + engine + ')') : engine;
|
|
|
- }
|
|
|
- if (manifest.publisher && manifest.publisher.author) {
|
|
|
- this._author = manifest.publisher.author;
|
|
|
- if (manifest.publisher.email) {
|
|
|
- this._author = this._author + ' <' + manifest.publisher.email + '>';
|
|
|
- }
|
|
|
- }
|
|
|
- if (manifest.license) {
|
|
|
- this._license = manifest.license;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- this._format = format;
|
|
|
- if (!this._format && symbol && symbol.attrs && symbol.attrs.mxnet_version) {
|
|
|
- const version = mxnet.Model._convert_version(symbol.attrs.mxnet_version);
|
|
|
- if (version) {
|
|
|
- this._format = 'MXNet v' + version;
|
|
|
- }
|
|
|
- }
|
|
|
- if (!this._format) {
|
|
|
- this._format = 'MXNet';
|
|
|
- }
|
|
|
-
|
|
|
- this._graphs = [];
|
|
|
- this._graphs.push(new mxnet.Graph(metadata, manifest, symbol, signature, params));
|
|
|
+ this._format = manifest.format || 'MXNet';
|
|
|
+ this._producer = manifest.producer || '';
|
|
|
+ this._name = manifest.name || '';
|
|
|
+ this._version = manifest.version;
|
|
|
+ this._description = manifest.description || '';
|
|
|
+ this._runtime = manifest.runtime || '';
|
|
|
+ this._author = manifest.author || '';
|
|
|
+ this._license = manifest.license || '';
|
|
|
+ this._graphs = [ new mxnet.Graph(metadata, manifest, symbol, params) ];
|
|
|
}
|
|
|
|
|
|
get format() {
|
|
|
return this._format;
|
|
|
}
|
|
|
|
|
|
+ get producer() {
|
|
|
+ return this._producer;
|
|
|
+ }
|
|
|
+
|
|
|
get name() {
|
|
|
return this._name;
|
|
|
}
|
|
|
@@ -342,23 +305,11 @@ mxnet.Model = class {
|
|
|
get graphs() {
|
|
|
return this._graphs;
|
|
|
}
|
|
|
-
|
|
|
- static _convert_version(value) {
|
|
|
- if (Array.isArray(value)) {
|
|
|
- if (value.length == 2 && value[0] == 'int') {
|
|
|
- const major = Math.floor(value[1] / 10000) % 100;
|
|
|
- const minor = Math.floor(value[1] / 100) % 100;
|
|
|
- const patch = Math.floor(value[1]) % 100;
|
|
|
- return [ major.toString(), minor.toString(), patch.toString() ].join('.');
|
|
|
- }
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
};
|
|
|
|
|
|
mxnet.Graph = class {
|
|
|
|
|
|
- constructor(metadata, manifest, symbol, signature, params) {
|
|
|
+ constructor(metadata, manifest, symbol, params) {
|
|
|
this._metadata = metadata;
|
|
|
this._nodes = [];
|
|
|
this._inputs = [];
|
|
|
@@ -376,14 +327,14 @@ mxnet.Graph = class {
|
|
|
if (symbol) {
|
|
|
const nodes = symbol.nodes;
|
|
|
const inputs = {};
|
|
|
- if (signature && signature.inputs) {
|
|
|
- for (const input of signature.inputs) {
|
|
|
+ if (manifest && manifest.signature && manifest.signature.inputs) {
|
|
|
+ for (const input of manifest.signature.inputs) {
|
|
|
inputs[input.data_name] = input;
|
|
|
}
|
|
|
}
|
|
|
const outputs = {};
|
|
|
- if (signature && signature.outputs) {
|
|
|
- for (const output of signature.outputs) {
|
|
|
+ if (manifest && manifest.signature && manifest.signature.outputs) {
|
|
|
+ for (const output of manifest.signature.outputs) {
|
|
|
outputs[output.data_name] = output;
|
|
|
}
|
|
|
}
|