Jelajahi Sumber

Update mnn.js (#341)

Lutz Roeder 3 tahun lalu
induk
melakukan
8403a3d613
1 mengubah file dengan 129 tambahan dan 103 penghapusan
  1. 129 103
      source/mnn.js

+ 129 - 103
source/mnn.js

@@ -72,44 +72,65 @@ mnn.Graph = class {
         this._nodes = [];
         this._inputs = [];
         this._outputs = [];
-        const inputSet = new Set();
         for (let i = 0; i < net.tensorName.length; i++) {
             if (net.tensorName[i] === '') {
                 net.tensorName[i] = '\n' + i.toString();
             }
         }
-        for (let i = 0; i < net.oplists.length; i++) {
-            const op = net.oplists[i];
-            if (op.type === mnn.schema.OpType.Input) {
-                const args = [];
-                for (let j = 0; j < op.outputIndexes.length; j++) {
-                    const index = op.outputIndexes[j];
-                    const name = net.tensorName[index];
+        const inputs = new Map();
+        for (const op of net.oplists) {
+            for (const input of op.inputIndexes) {
+                inputs.set(input, (inputs.get(input) || 0) + 1);
+            }
+        }
+        const consts = new Map();
+        const oplists = net.oplists.filter((op) => {
+            if (op.type === mnn.schema.OpType.Const &&
+                op.inputIndexes.length === 0 &&
+                op.outputIndexes.length === 1 &&
+                op.main instanceof mnn.schema.Blob &&
+                inputs.get(op.outputIndexes[0]) === 1) {
+                consts.set(op.outputIndexes[0], op);
+                return false;
+            }
+            return true;
+        });
+        const args = new Map();
+        const arg = (index) => {
+            if (!args.has(index)) {
+                const name = net.tensorName[index];
+                const op = consts.get(index);
+                if (op) {
+                    const tensor = op ? mnn.Utility.createTensor(op.main, 'Const') : null;
+                    const argument = new mnn.Argument(name, null, tensor);
+                    args.set(index, argument);
+                }
+                else {
                     const extraTensorDescribe = net.extraTensorDescribe[index];
                     const blob = extraTensorDescribe ? extraTensorDescribe.blob : null;
-                    const type = blob ? new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims)) : null;
-                    args.push(new mnn.Argument(name, type, null));
+                    const type = blob && blob.dims && blob.dims.length > 0 ? new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims), blob.dataFormat) : null;
+                    const argument = new mnn.Argument(name, type, null);
+                    args.set(index, argument);
                 }
+            }
+            return args.get(index);
+        };
+
+        for (const op of oplists) {
+            if (op.type === mnn.schema.OpType.Input) {
+                const args = Array.from(op.outputIndexes).map((index) => arg(index));
                 this._inputs.push(new mnn.Parameter(op.name, true, args));
             }
             else {
-                this._nodes.push(new mnn.Node(metadata, op, net));
-            }
-            for (let k = 0; k < op.inputIndexes.length; k++) {
-                const index = op.inputIndexes[k];
-                inputSet.add(index);
+                this._nodes.push(new mnn.Node(metadata, op, net, arg));
             }
         }
 
         for (let i = 0; i < net.tensorName.length; i++) {
-            if (!inputSet.has(i)) {
-                const name = net.tensorName[i];
-                const extraTensorDescribe = net.extraTensorDescribe[i];
-                const blob = extraTensorDescribe ? extraTensorDescribe.blob : null;
-                const type = blob ? new mnn.TensorType(blob.dataType, new mnn.TensorShape(blob.dims)) : null;
-                this._outputs.push(new mnn.Parameter(name, true, [
-                    new mnn.Argument(name, type, null)
-                ]));
+            if (!inputs.has(i)) {
+                const argument = arg(i);
+                const parameter = new mnn.Parameter(argument.name, true, [ argument ]);
+                this._outputs.push(parameter);
             }
         }
     }
@@ -133,7 +154,7 @@ mnn.Graph = class {
 
 mnn.Node = class {
 
-    constructor(metadata, op, net) {
+    constructor(metadata, op, net, arg) {
         const type = mnn.Utility.enum('OpType', op.type) || '(' + op.type.toString() + ')';
         this._type = metadata.type(type) || { name: type };
         this._name = op.name || '';
@@ -141,97 +162,73 @@ mnn.Node = class {
         this._inputs = [];
         this._outputs = [];
         this._chains = [];
-        const inputs = [];
-        for (let i = 0; i < op.inputIndexes.length; i++) {
-            const index = op.inputIndexes[i];
-            const id = net.tensorName[index];
-            inputs.push(new mnn.Argument(id, null, null));
+        if (op.inputIndexes && op.inputIndexes.length > 0) {
+            this._inputs.push(new mnn.Parameter('input', true, Array.from(op.inputIndexes).map((index) => arg(index))));
         }
-        this._inputs.push(new mnn.Parameter('input', true, inputs));
-        const outputs = [];
-        for (let i = 0; i < op.outputIndexes.length; i++) {
-            const index = op.outputIndexes[i];
-            const name = net.tensorName[index];
-            outputs.push(new mnn.Argument(name, null, null));
+        if (op.outputIndexes && op.outputIndexes.length > 0) {
+            this._outputs.push(new mnn.Parameter('output', true, Array.from(op.outputIndexes).map((index) => arg(index))));
         }
-        this._outputs.push(new mnn.Parameter('output', true, outputs));
-
-        const ignoreAttributes = new Set();
-        const parameter = op.main;
-        if (parameter) {
-            const parameters = [ parameter ];
-            if (parameter instanceof mnn.schema.Blob) {
-                const type = new mnn.TensorType(parameter.dataType, new mnn.TensorShape(parameter.dims));
-                let data = null;
-                switch (type.dataType) {
-                    case 'uint8': data = parameter.uint8s; break;
-                    case 'int8': data = parameter.int8s; break;
-                    case 'int32': data = parameter.int32s; break;
-                    case 'int64': data = parameter.int64s; break;
-                    case 'float16': data = parameter.uint8s; break;
-                    case 'float32': data = parameter.float32s; break;
-                    default:
-                        throw new mnn.Error("Unknown blob data type '" + JSON.stringify(type.dataType) + "'.");
-                }
-                this._inputs.push(new mnn.Parameter('value', true, [
-                    new mnn.Argument('', null, new mnn.Tensor('Blob', type, data))
-                ]));
+        const param = op.main;
+        if (param) {
+            const parameters = [ param ];
+            if (param instanceof mnn.schema.Blob) {
+                const tensor = mnn.Utility.createTensor(param, 'Blob');
+                const argument = new mnn.Argument('', null, tensor);
+                const parameter = new mnn.Parameter('value', true, [ argument ]);
+                this._inputs.push(parameter);
                 parameters.splice(0, parameters.length);
             }
-            else if (parameter instanceof mnn.schema.Convolution2D) {
-                const common = parameter.common;
+            else if (param instanceof mnn.schema.Convolution2D) {
+                const common = param.common;
                 const outputCount = common.outputCount;
                 const inputCount = common.inputCount;
                 const kernelX = common.kernelX;
                 const kernelY = common.kernelY;
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'weight', [ outputCount, inputCount, kernelX, kernelY ], parameter.weight);
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'bias', [ outputCount ], parameter.bias);
-                ignoreAttributes.add('weight');
-                ignoreAttributes.add('bias');
-                ignoreAttributes.add('quanParameter');
-                ignoreAttributes.add('symmetricQuan');
+                this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [ outputCount, inputCount, kernelX, kernelY ], param.weight);
+                this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ outputCount ], param.bias);
+                delete param.weight;
+                delete param.bias;
+                delete param.quanParameter;
+                delete param.symmetricQuan;
             }
-            else if (parameter instanceof mnn.schema.InnerProduct) {
-                const outputCount = parameter.outputCount;
-                const inputCount = parameter.weightSize / outputCount;
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'weight', [ outputCount, inputCount ], parameter.weight);
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'bias', [ outputCount ], parameter.bias);
-                ignoreAttributes.add('weight');
-                ignoreAttributes.add('bias');
-                ignoreAttributes.add('quanParameter');
+            else if (param instanceof mnn.schema.InnerProduct) {
+                const outputCount = param.outputCount;
+                const inputCount = param.weightSize / outputCount;
+                this._buildTensor('weight', mnn.schema.DataType.DT_FLOAT, [ outputCount, inputCount ], param.weight);
+                this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ outputCount ], param.bias);
+                delete param.weight;
+                delete param.bias;
+                delete param.quanParameter;
             }
-            else if (parameter instanceof mnn.schema.Scale) {
-                const scaleDataCount = parameter.channels;
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'scale', [ scaleDataCount ], parameter.scaleData);
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'bias', [ scaleDataCount ], parameter.biasData);
-                ignoreAttributes.add('scaleData');
-                ignoreAttributes.add('biasData');
+            else if (param instanceof mnn.schema.Scale) {
+                const scaleDataCount = param.channels;
+                this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [ scaleDataCount ], param.scaleData);
+                this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ scaleDataCount ], param.biasData);
+                delete param.scaleData;
+                delete param.biasData;
             }
-            else if (parameter instanceof mnn.schema.BatchNorm) {
-                const channels = parameter.channels;
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'mean', [ channels ], parameter.meanData);
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'slope', [ channels ], parameter.slopeData);
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'variance', [ channels ], parameter.varData);
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'bias', [ channels ], parameter.biasData);
-                ignoreAttributes.add('slopeData');
-                ignoreAttributes.add('meanData');
-                ignoreAttributes.add('varData');
-                ignoreAttributes.add('biasData');
+            else if (param instanceof mnn.schema.BatchNorm) {
+                const channels = param.channels;
+                this._buildTensor('mean', mnn.schema.DataType.DT_FLOAT, [ channels ], param.meanData);
+                this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [ channels ], param.slopeData);
+                this._buildTensor('variance', mnn.schema.DataType.DT_FLOAT, [ channels ], param.varData);
+                this._buildTensor('bias', mnn.schema.DataType.DT_FLOAT, [ channels ], param.biasData);
+                delete param.slopeData;
+                delete param.meanData;
+                delete param.varData;
+                delete param.biasData;
             }
-            else if (parameter instanceof mnn.schema.PRelu) {
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'slope', [ parameter.slopeCount ], parameter.slope);
-                ignoreAttributes.add('slope');
+            else if (param instanceof mnn.schema.PRelu) {
+                this._buildTensor('slope', mnn.schema.DataType.DT_FLOAT, [ param.slopeCount ], param.slope);
+                delete param.slopeCount;
             }
-            else if (parameter instanceof mnn.schema.Normalize) {
-                this._buildTensor(mnn.schema.DataType.DT_FLOAT, 'scale', [ parameter.scale.length ], parameter.scale);
-                ignoreAttributes.add('scale');
+            else if (param instanceof mnn.schema.Normalize) {
+                this._buildTensor('scale', mnn.schema.DataType.DT_FLOAT, [ param.scale.length ], param.scale);
+                delete param.scale;
             }
             while (parameters.length > 0) {
                 const parameter = parameters.shift();
                 for (const key of Object.keys(parameter)) {
-                    if (ignoreAttributes && ignoreAttributes.has(key)) {
-                        continue;
-                    }
                     if (Object.prototype.hasOwnProperty.call(parameter, key)) {
                         const value = parameter[key];
                         if (Object.keys(mnn.schema).find((key) => mnn.schema[key].prototype && value instanceof mnn.schema[key])) {
@@ -246,10 +243,13 @@ mnn.Node = class {
         }
     }
 
-    _buildTensor(dataType, name, dimensions, value) {
-        this._inputs.push(new mnn.Parameter(name, true, [
-            new mnn.Argument('', null, new mnn.Tensor('Weight', new mnn.TensorType(dataType, new mnn.TensorShape(dimensions)), value))
-        ]));
+    _buildTensor(name, dataType, dimensions, value) {
+        const shape = new mnn.TensorShape(dimensions);
+        const type = new mnn.TensorType(dataType, shape);
+        const tensor = new mnn.Tensor('Weight', type, value);
+        const argument = new mnn.Argument('', null, tensor);
+        const parameter = new mnn.Parameter(name, true, [ argument ]);
+        this._inputs.push(parameter);
     }
 
     get type() {
@@ -466,9 +466,15 @@ mnn.Tensor = class {
 
 mnn.TensorType = class {
 
-    constructor(dataType, shape) {
+    constructor(dataType, shape, format) {
         this._dataType = mnn.Utility.dataType(dataType);
         this._shape = shape;
+        switch (format) {
+            case mnn.schema.MNN_DATA_FORMAT.NCHW:   this._denotation = 'NCHW'; break;
+            case mnn.schema.MNN_DATA_FORMAT.NHWC:   this._denotation = 'NHWC'; break;
+            case mnn.schema.MNN_DATA_FORMAT.NC4HW4: this._denotation = 'NC4HW4'; break;
+            case mnn.schema.MNN_DATA_FORMAT.NHWC4:  this._denotation = 'NHWC4'; break;
+        }
     }
 
     get dataType() {
@@ -479,6 +485,10 @@ mnn.TensorType = class {
         return this._shape;
     }
 
+    get denotation() {
+        return this._denotation;
+    }
+
     toString() {
         return this._dataType + this._shape.toString();
     }
@@ -599,6 +609,22 @@ mnn.Utility = class {
         }
         return value.toString();
     }
+
+    static createTensor(param, kind) {
+        const type = new mnn.TensorType(param.dataType, new mnn.TensorShape(param.dims), param.dataFormat);
+        let data = null;
+        switch (type.dataType) {
+            case 'uint8': data = param.uint8s; break;
+            case 'int8': data = param.int8s; break;
+            case 'int32': data = param.int32s; break;
+            case 'int64': data = param.int64s; break;
+            case 'float16': data = param.uint8s; break;
+            case 'float32': data = param.float32s; break;
+            default:
+                throw new mnn.Error("Unknown blob data type '" + JSON.stringify(type.dataType) + "'.");
+        }
+        return new mnn.Tensor(kind, type, data);
+    }
 };
 
 mnn.Error = class extends Error {