Lutz Roeder 4 лет назад
Родитель
Сommit
8f6fa8d6db
1 измененных файлов с 68 добавлено и 77 удалено
  1. 68 77
      source/mxnet.js

+ 68 - 77
source/mxnet.js

@@ -32,13 +32,54 @@ mxnet.ModelFactory = class {
     }
 
     open(context) {
-        return Promise.resolve().then(() => {
+        return mxnet.Metadata.open(context).then((metadata) => {
+            const basename = (identifier, extension, suffix) => {
+                const dots = identifier.split('.');
+                if (dots.length >= 2 && dots.pop().toLowerCase() === extension) {
+                    const dashes = dots.join('.').split('-');
+                    if (dashes.length >= 2) {
+                        const token = dashes.pop();
+                        if (suffix) {
+                            if (token != suffix) {
+                                return null;
+                            }
+                        }
+                        else {
+                            for (let i = 0; i < token.length; i++) {
+                                const c = token.charAt(i);
+                                if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
+                                    continue;
+                                }
+                                return null;
+                            }
+                        }
+                        return dashes.join('-');
+                    }
+                }
+                return null;
+            };
+            const open_model = (metadata, format, manifest, symbol, signature, params) => {
+                const parameters = new Map();
+                if (params) {
+                    try {
+                        const stream = new ndarray.Stream(params);
+                        for (const key of Object.keys(stream.arrays)) {
+                            const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
+                            parameters.set(name, stream.arrays[key]);
+                        }
+                    }
+                    catch (error) {
+                        // continue regardless of error
+                    }
+                }
+                return new mxnet.Model(metadata, format, manifest, symbol, signature, parameters);
+            };
             const identifier = context.identifier;
             const extension = context.identifier.split('.').pop().toLowerCase();
             let symbol = null;
             let params = null;
             let format = null;
-            let basename = null;
+            let base = null;
             switch (extension) {
                 case 'json':
                     try {
@@ -51,31 +92,31 @@ mxnet.ModelFactory = class {
                         const message = error && error.message ? error.message : error.toString();
                         throw new mxnet.Error("Failed to load symbol entry (" + message.replace(/\.$/, '') + ').');
                     }
-                    basename = mxnet.ModelFactory._basename(identifier, 'json', 'symbol');
-                    if (basename) {
-                        return context.request(basename + '-0000.params', null).then((stream) => {
+                    base = basename(identifier, 'json', 'symbol');
+                    if (base) {
+                        return context.request(base + '-0000.params', null).then((stream) => {
                             const buffer = stream.peek();
-                            return this._openModel(format, null, symbol, null, buffer, context);
+                            return open_model(metadata, format, null, symbol, null, buffer);
                         }).catch(() => {
-                            return this._openModel(format, null, symbol, null, params, context);
+                            return open_model(metadata, format, null, symbol, null, params);
                         });
                     }
-                    return this._openModel(format, null, symbol, null, null, context);
+                    return open_model(metadata, format, null, symbol, null, null);
                 case 'params':
                     params = context.stream.peek();
-                    basename = mxnet.ModelFactory._basename(context.identifier, 'params');
-                    if (basename) {
-                        return context.request(basename + '-symbol.json', 'utf-8').then((text) => {
+                    base = basename(context.identifier, 'params');
+                    if (base) {
+                        return context.request(base + '-symbol.json', 'utf-8').then((text) => {
                             symbol = JSON.parse(text);
                             if (symbol && symbol.nodes && symbol.nodes.some((node) => node && node.op == 'tvm_op')) {
                                 format  = 'TVM';
                             }
-                            return this._openModel(format, null, symbol, null, params, context);
+                            return open_model(metadata, format, null, symbol, null, params);
                         }).catch(() => {
-                            return this._openModel(format, null, null, null, params, context);
+                            return open_model(metadata, format, null, null, null, params);
                         });
                     }
-                    return this._openModel(format, null, null, null, params, context);
+                    return open_model(metadata, format, null, null, null, params);
                 case 'mar':
                 case 'model': {
                     const entries = new Map();
@@ -187,59 +228,13 @@ mxnet.ModelFactory = class {
                     catch (err) {
                         // continue regardless of error
                     }
-
-                    return this._openModel(format, manifest, symbol, signature, params, context);
+                    return open_model(metadata, format, manifest, symbol, signature, params);
                 }
                 default:
                     throw new mxnet.Error('Unsupported file extension.');
             }
         });
     }
-
-    _openModel(format, manifest, symbol, signature, params, context) {
-        return mxnet.Metadata.open(context).then((metadata) => {
-            const parameters = new Map();
-            if (params) {
-                try {
-                    const stream = new ndarray.Stream(params);
-                    for (const key of Object.keys(stream.arrays)) {
-                        const name = (key.startsWith('arg:') || key.startsWith('aux:')) ? key.substring(4) : key;
-                        parameters.set(name, stream.arrays[key]);
-                    }
-                }
-                catch (error) {
-                    // continue regardless of error
-                }
-            }
-            return new mxnet.Model(metadata, format, manifest, symbol, signature, parameters);
-        });
-    }
-
-    static _basename(identifier, extension, suffix) {
-        const dots = identifier.split('.');
-        if (dots.length >= 2 && dots.pop().toLowerCase() === extension) {
-            const dashes = dots.join('.').split('-');
-            if (dashes.length >= 2) {
-                const token = dashes.pop();
-                if (suffix) {
-                    if (token != suffix) {
-                        return null;
-                    }
-                }
-                else {
-                    for (let i = 0; i < token.length; i++) {
-                        const c = token.charAt(i);
-                        if ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) {
-                            continue;
-                        }
-                        return null;
-                    }
-                }
-                return dashes.join('-');
-            }
-        }
-        return null;
-    }
 };
 
 mxnet.Model = class {
@@ -446,36 +441,32 @@ mxnet.Graph = class {
             }
         }
         else if (params) {
-            let block = null;
-            const blocks = [];
-            let separator = Object.keys(params).every((k) => k.indexOf('_') != -1) ? '_' : '';
+            const blocks = new Map();
+            let separator = Array.from(params.keys()).every((key) => key.indexOf('_') != -1) ? '_' : '';
             if (separator.length == 0) {
-                separator = Object.keys(params).every((k) => k.indexOf('.') != -1) ? '.' : '';
+                separator = Array.from(params.keys()).every((key) => key.indexOf('.') != -1) ? '.' : '';
             }
             if (separator.length > 0) {
-                const blockMap = {};
-                for (const id of Object.keys(params)) {
-                    const parts = id.split(separator);
+                for (const param of params) {
+                    const key = param[0];
+                    const parts = key.split(separator);
                     let argumentName = parts.pop();
-                    if (id.endsWith('moving_mean') || id.endsWith('moving_var')) {
+                    if (key.endsWith('moving_mean') || key.endsWith('moving_var')) {
                         argumentName = [ parts.pop(), argumentName ].join(separator);
                     }
                     const nodeName = parts.join(separator);
-                    block = blockMap[nodeName];
-                    if (!block) {
-                        block = { name: nodeName, op: 'Weights', params: [] };
-                        blockMap[nodeName] = block;
-                        blocks.push(block);
+                    if (!blocks.has(nodeName)) {
+                        blocks.set(nodeName, { name: nodeName, op: 'Weights', params: [] });
                     }
-                    blockMap[nodeName].params.push({ name: argumentName, id: id });
+                    blocks.get(nodeName).params.push({ name: argumentName, id: key });
                 }
             }
             else {
                 throw new mxnet.Error("Unsupported key format in params.");
             }
 
-            for (block of blocks) {
-                this._nodes.push(new mxnet.Node(metadata, block, {}, {}, params));
+            for (const block of blocks.values()) {
+                this._nodes.push(new mxnet.Node(metadata, block, {}, {}, tensors));
             }
         }
     }