|
|
@@ -32,13 +32,54 @@ mxnet.ModelFactory = class {
|
|
|
}
|
|
|
|
|
|
open(context) {
|
|
|
- return Promise.resolve().then(() => {
|
|
|
+ return mxnet.Metadata.open(context).then((metadata) => {
|
|
|
+ const basename = (identifier, extension, suffix) => {
|
|
|
+ const dots = identifier.split('.');
|
|
|
+ if (dots.length >= 2 && dots.pop().toLowerCase() === extension) {
|
|
|
+ const dashes = dots.join('.').split('-');
|
|
|
+ if (dashes.length >= 2) {
|
|
|
+ const token = dashes.pop();
|
|
|
+ if (suffix) {
|
|
|
+ if (token != suffix) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ for (let i = 0; i < token.length; i++) {
|
|
|
+ const c = token.charAt(i);
|
|
|
+ if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return dashes.join('-');
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ };
|
|
|
+ const open_model = (metadata, format, manifest, symbol, signature, params) => {
|
|
|
+ const parameters = new Map();
|
|
|
+ if (params) {
|
|
|
+ try {
|
|
|
+ const stream = new ndarray.Stream(params);
|
|
|
+ for (const key of Object.keys(stream.arrays)) {
|
|
|
+ const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
|
|
|
+ parameters.set(name, stream.arrays[key]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ catch (error) {
|
|
|
+ // continue regardless of error
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return new mxnet.Model(metadata, format, manifest, symbol, signature, parameters);
|
|
|
+ };
|
|
|
const identifier = context.identifier;
|
|
|
const extension = context.identifier.split('.').pop().toLowerCase();
|
|
|
let symbol = null;
|
|
|
let params = null;
|
|
|
let format = null;
|
|
|
- let basename = null;
|
|
|
+ let base = null;
|
|
|
switch (extension) {
|
|
|
case 'json':
|
|
|
try {
|
|
|
@@ -51,31 +92,31 @@ mxnet.ModelFactory = class {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
throw new mxnet.Error("Failed to load symbol entry (" + message.replace(/\.$/, '') + ').');
|
|
|
}
|
|
|
- basename = mxnet.ModelFactory._basename(identifier, 'json', 'symbol');
|
|
|
- if (basename) {
|
|
|
- return context.request(basename + '-0000.params', null).then((stream) => {
|
|
|
+ base = basename(identifier, 'json', 'symbol');
|
|
|
+ if (base) {
|
|
|
+ return context.request(base + '-0000.params', null).then((stream) => {
|
|
|
const buffer = stream.peek();
|
|
|
- return this._openModel(format, null, symbol, null, buffer, context);
|
|
|
+ return open_model(metadata, format, null, symbol, null, buffer);
|
|
|
}).catch(() => {
|
|
|
- return this._openModel(format, null, symbol, null, params, context);
|
|
|
+ return open_model(metadata, format, null, symbol, null, params);
|
|
|
});
|
|
|
}
|
|
|
- return this._openModel(format, null, symbol, null, null, context);
|
|
|
+ return open_model(metadata, format, null, symbol, null, null);
|
|
|
case 'params':
|
|
|
params = context.stream.peek();
|
|
|
- basename = mxnet.ModelFactory._basename(context.identifier, 'params');
|
|
|
- if (basename) {
|
|
|
- return context.request(basename + '-symbol.json', 'utf-8').then((text) => {
|
|
|
+ base = basename(context.identifier, 'params');
|
|
|
+ if (base) {
|
|
|
+ return context.request(base + '-symbol.json', 'utf-8').then((text) => {
|
|
|
symbol = JSON.parse(text);
|
|
|
if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
|
|
|
format = 'TVM';
|
|
|
}
|
|
|
- return this._openModel(format, null, symbol, null, params, context);
|
|
|
+ return open_model(metadata, format, null, symbol, null, params);
|
|
|
}).catch(() => {
|
|
|
- return this._openModel(format, null, null, null, params, context);
|
|
|
+ return open_model(metadata, format, null, null, null, params);
|
|
|
});
|
|
|
}
|
|
|
- return this._openModel(format, null, null, null, params, context);
|
|
|
+ return open_model(metadata, format, null, null, null, params);
|
|
|
case 'mar':
|
|
|
case 'model': {
|
|
|
const entries = new Map();
|
|
|
@@ -187,59 +228,13 @@ mxnet.ModelFactory = class {
|
|
|
catch (err) {
|
|
|
// continue regardless of error
|
|
|
}
|
|
|
-
|
|
|
- return this._openModel(format, manifest, symbol, signature, params, context);
|
|
|
+ return open_model(metadata, format, manifest, symbol, signature, params);
|
|
|
}
|
|
|
default:
|
|
|
throw new mxnet.Error('Unsupported file extension.');
|
|
|
}
|
|
|
});
|
|
|
}
|
|
|
-
|
|
|
- _openModel(format, manifest, symbol, signature, params, context) {
|
|
|
- return mxnet.Metadata.open(context).then((metadata) => {
|
|
|
- const parameters = new Map();
|
|
|
- if (params) {
|
|
|
- try {
|
|
|
- const stream = new ndarray.Stream(params);
|
|
|
- for (const key of Object.keys(stream.arrays)) {
|
|
|
- const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
|
|
|
- parameters.set(name, stream.arrays[key]);
|
|
|
- }
|
|
|
- }
|
|
|
- catch (error) {
|
|
|
- // continue regardless of error
|
|
|
- }
|
|
|
- }
|
|
|
- return new mxnet.Model(metadata, format, manifest, symbol, signature, parameters);
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
- static _basename(identifier, extension, suffix) {
|
|
|
- const dots = identifier.split('.');
|
|
|
- if (dots.length >= 2 && dots.pop().toLowerCase() === extension) {
|
|
|
- const dashes = dots.join('.').split('-');
|
|
|
- if (dashes.length >= 2) {
|
|
|
- const token = dashes.pop();
|
|
|
- if (suffix) {
|
|
|
- if (token != suffix) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- }
|
|
|
- else {
|
|
|
- for (let i = 0; i < token.length; i++) {
|
|
|
- const c = token.charAt(i);
|
|
|
- if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
- }
|
|
|
- return dashes.join('-');
|
|
|
- }
|
|
|
- }
|
|
|
- return null;
|
|
|
- }
|
|
|
};
|
|
|
|
|
|
mxnet.Model = class {
|
|
|
@@ -446,36 +441,32 @@ mxnet.Graph = class {
|
|
|
}
|
|
|
}
|
|
|
else if (params) {
|
|
|
- let block = null;
|
|
|
- const blocks = [];
|
|
|
- let separator = Object.keys(params).every((k) => k.indexOf('_') != -1) ? '_' : '';
|
|
|
+ const blocks = new Map();
|
|
|
+ let separator = Array.from(params.keys()).every((key) => key.indexOf('_') != -1) ? '_' : '';
|
|
|
if (separator.length == 0) {
|
|
|
- separator = Object.keys(params).every((k) => k.indexOf('.') != -1) ? '.' : '';
|
|
|
+ separator = Array.from(params.keys()).every((key) => key.indexOf('.') != -1) ? '.' : '';
|
|
|
}
|
|
|
if (separator.length > 0) {
|
|
|
- const blockMap = {};
|
|
|
- for (const id of Object.keys(params)) {
|
|
|
- const parts = id.split(separator);
|
|
|
+ for (const param of params) {
|
|
|
+ const key = param[0];
|
|
|
+ const parts = key.split(separator);
|
|
|
let argumentName = parts.pop();
|
|
|
- if (id.endsWith('moving_mean') || id.endsWith('moving_var')) {
|
|
|
+ if (key.endsWith('moving_mean') || key.endsWith('moving_var')) {
|
|
|
argumentName = [ parts.pop(), argumentName ].join(separator);
|
|
|
}
|
|
|
const nodeName = parts.join(separator);
|
|
|
- block = blockMap[nodeName];
|
|
|
- if (!block) {
|
|
|
- block = { name: nodeName, op: 'Weights', params: [] };
|
|
|
- blockMap[nodeName] = block;
|
|
|
- blocks.push(block);
|
|
|
+ if (!blocks.has(nodeName)) {
|
|
|
+ blocks.set(nodeName, { name: nodeName, op: 'Weights', params: [] });
|
|
|
}
|
|
|
- blockMap[nodeName].params.push({ name: argumentName, id: id });
|
|
|
+ blocks.get(nodeName).params.push({ name: argumentName, id: key });
|
|
|
}
|
|
|
}
|
|
|
else {
|
|
|
throw new mxnet.Error("Unsupported key format in params.");
|
|
|
}
|
|
|
|
|
|
- for (block of blocks) {
|
|
|
- this._nodes.push(new mxnet.Node(metadata, block, {}, {}, params));
|
|
|
+ for (const block of blocks.values()) {
|
|
|
+ this._nodes.push(new mxnet.Node(metadata, block, {}, {}, tensors));
|
|
|
}
|
|
|
}
|
|
|
}
|