Lutz Roeder 4 лет назад
Родитель
Сommit
ff80fbe224
1 измененных файлов с 304 добавлено и 31 удалено
  1. 304 31
      source/acuity.js

+ 304 - 31
source/acuity.js

@@ -82,6 +82,9 @@ acuity.Graph = class {
                         const sizes = layer.parameters.size.split(' ');
                         shape = [0, parseInt(sizes[0]), parseInt(sizes[1]), layer.parameters.channels];
                     }
+                    if (shape && shape.length === 4 && shape[0] === 0) {
+                        shape[0] = 1;
+                    }
                 }
                 argument.shape = shape;
                 return argument;
@@ -323,15 +326,14 @@ acuity.TensorShape = class {
     }
 
     get dimensions() {
+        if (Array.isArray(this._dimensions) && this._dimensions.length == 1 && this._dimensions[0] == 0) {
+            return [];
+        }
         return this._dimensions;
     }
 
-    set dimensions(dimensions) {
-        this._dimensions = dimensions;
-    }
-
     toString() {
-        if (!this._dimensions || this._dimensions.length == 0) {
+        if (!Array.isArray(this._dimensions) || this._dimensions.length == 0 || (this._dimensions.length == 1 && this._dimensions[0] == 0)) {
             return '';
         }
         return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']';
@@ -424,47 +426,312 @@ acuity.Inference =  class {
                 this._outputs.set(output.name, layer);
             }
         }
+        this._broadcasts = new Set([
+            'add', 'equal', 'fllor_mod', 'floor_div', 'greater', 'greater_equal', 'less', 'less_equal',
+            'logical_and', 'logical_or', 'minimum', 'multiply', 'not_equal', 'pow', 'real_div',
+            'squared_difference', 'subtract'
+        ]);
         this._passthroughs = new Set([
-            'a_times_b_plus_c', 'abs', 'cast', 'clipbyvalue', 'dequantize', 'dtype_converter',
-            'elu', 'exp', 'floor', 'floor_div', 'hard_swish', 'leakyrelu', 'log', 'log_softmax',
-            'neg', 'pow', 'prelu', 'quantize', 'relu', 'relu_keras', 'relun', 'rsqrt', 'sigmoid',
-            'sin', 'softmax', 'softrelu', 'sqrt', 'square', 'tanh'
+            'LocalResponseNormalization', 'a_times_b_plus_c', 'abs', 'batchnorm_single', 'batchnormalize',
+            'cast', 'cast', 'clipbyvalue', 'dequantize', 'dtype_converter', 'elu', 'exp', 'floor',
+            'groupnormalize', 'hard_sigmoid', 'hard_swish', 'instancenormalize', 'l2normalize', 'l2normalizescale',
+            'layernormalize', 'leakyrelu', 'log', 'log_softmax', 'mish', 'neg', 'norm_with_channel_mean',
+            'norm_with_min_max', 'norm_with_scale', 'pow', 'prelu', 'quantize', 'relu', 'relu_keras',
+            'relun', 'reverse', 'round', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'softrelu', 'sqrt', 'square', 'tanh'
+        ]);
+        this._reduces = new Set([
+            'reduceany', 'reducemax', 'reducemean', 'reducemin', 'reduceprod', 'reducesum'
         ]);
         this._operators = new Map();
-        this._operators.set('concat', (inputs, parameters) => {
+        this._operators.set('broadcast', (inputs) => {
+            const a = inputs[0];
+            const b = inputs[1];
+            const longer = a.length >= b.length ? a.slice() : b.slice();
+            const shorter = a.length < b.length ? a.slice() : b.slice();
+            const remain = longer.length - shorter.length;
+            for (let i = 0; i < remain; i++) {
+                shorter.splice(0, 0, 1);
+            }
+            for (let i = 0; i < longer.length; i++) {
+                longer[i] = longer[i] > shorter[i] ? longer[i] : shorter[i];
+            }
+            return [longer];
+        });
+        this._operators.set('concat', (inputs, params) => {
             const outputShape = inputs[0].slice();
-            outputShape[parameters.dim] = 0;
+            outputShape[params.dim] = 0;
             for (const shape of inputs) {
-                outputShape[parameters.dim] += shape[parameters.dim];
+                outputShape[params.dim] += shape[params.dim];
             }
             return [outputShape];
         });
-        this._operators.set('convolution', (inputs, parameters) => {
-            if (parameters.padding == 'VALID') {
-                const out_h = ~~((inputs[0][1] + parameters.stride_h - parameters.ksize_h) / parameters.stride_h);
-                const out_w = ~~((inputs[0][2] + parameters.stride_w - parameters.ksize_w) / parameters.stride_w);
-                return [[inputs[0][0], out_h, out_w, parameters.weights]];
+        this._operators.set('conv1d', (inputs, params) => {
+            if (params.padding == 'VALID') {
+                const out_h = ~~((inputs[0][1] + params.stride - params.ksize) / params.stride);
+                return [[inputs[0][0], out_h, params.weights]];
+            }
+            else if (params.padding == 'SAME') {
+                const out_h = ~~((inputs[0][1] + params.stride - 1) / params.stride);
+                return [[inputs[0][0], out_h, params.weights]];
+            }
+        });
+        this._operators.set('convolution', (inputs, params) => {
+            if (params.padding == 'VALID') {
+                const out_h = ~~((inputs[0][1] + params.stride_h + params.pad[0] + params.pad[1] - params.ksize_h) / params.stride_h);
+                const out_w = ~~((inputs[0][2] + params.stride_w + params.pad[2] + params.pad[3]- params.ksize_w) / params.stride_w);
+                return [[inputs[0][0], out_h, out_w, params.weights]];
             }
-            else if (parameters.padding == 'SAME') {
-                const out_h = ~~((inputs[0][1] + parameters.stride_h - 1) / parameters.stride_h);
-                const out_w = ~~((inputs[0][2] + parameters.stride_w - 1) / parameters.stride_w);
-                return [[inputs[0][0], out_h, out_w, parameters.weights]];
+            else if (params.padding == 'SAME') {
+                const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
+                const out_w = ~~((inputs[0][2] + params.stride_w - 1) / params.stride_w);
+                return [[inputs[0][0], out_h, out_w, params.weights]];
             }
         });
-        this._operators.set('fullconnect', (inputs, parameters) => {
-            return [inputs[0].slice(0, parameters.axis).concat([parameters.weights])];
+        this._operators.set('deconvolution', (inputs, params) => {
+            const newShape = params.output_shape.map((item, index) => {
+                return item == 0 ? inputs[0][index] : item;
+            });
+            return [ newShape ];
         });
-        this._operators.set('pooling', (inputs, parameters) => {
-            if (parameters.padding == 'VALID') {
-                const out_h = ~~((inputs[0][1] + parameters.stride_h - parameters.ksize_h) / parameters.stride_h);
-                const out_w = ~~((inputs[0][2] + parameters.stride_w - parameters.ksize_w) / parameters.stride_w);
-                return [[inputs[0][0], out_h, out_w, inputs[0][3]]];
+        this._operators.set('fullconnect', (inputs, params) => {
+            return [ inputs[0].slice(0, params.axis).concat([params.weights]) ];
+        });
+        this._operators.set('gather', (inputs, params) => {
+            const prefix = inputs[1].slice();
+            const suffix = inputs[0].slice(params.axis + 1);
+            const newShape = prefix.concat(suffix);
+            return [ newShape ];
+        });
+        this._operators.set('lstm', (inputs, params) => {
+            let batch = inputs[0][0];
+            const output = params.num_proj != null ? params.num_proj : params.weights;
+            if (params.time_major) {
+                batch = inputs[0][1];
+            }
+            let newShape = [];
+            if (params.return_sequences) {
+                newShape = [inputs[0][0], inputs[0][1], output];
+            }
+            else {
+                newShape = [batch, output];
+            }
+            return [ newShape, [batch, output], [batch, params.weights] ];
+        });
+        this._operators.set('matmul', (inputs, params) => {
+            const a = inputs[0];
+            const b = inputs[1];
+            let newShape = a.slice(0, -2);
+            if (params.transpose_a) {
+                newShape = newShape.concat(a.slice(-1));
             }
-            else if (parameters.padding == 'SAME') {
-                const out_h = ~~((inputs[0][1] + parameters.stride_h - 1) / parameters.stride_h);
-                const out_w = ~~((inputs[0][2] + parameters.stride_w - 1) / parameters.stride_w);
+            else {
+                newShape = newShape.concat(a.slice(-2, -1));
+            }
+            if (params.transpose_b) {
+                newShape = newShape.concat(b.slice(-2, -1));
+            }
+            else {
+                newShape = newShape.concat(b.slice(-1));
+            }
+            return [ newShape ];
+        });
+        this._operators.set('pad', (inputs, params) => {
+            const newShape = inputs[0].map((item, index) => {
+                return item + params.padding_value[index][0] + params.padding_value[index][1];
+            });
+            return [ newShape ];
+        });
+        this._operators.set('permute', (inputs, params) => {
+            const newShape = inputs[0].map((item, index) => {
+                return inputs[0][params.perm[index]];
+            });
+            return [ newShape ];
+        });
+        this._operators.set('pooling', (inputs, params) => {
+            if (params.padding == 'VALID') {
+                const out_h = ~~((inputs[0][1] + params.stride_h - params.ksize_h) / params.stride_h);
+                const out_w = ~~((inputs[0][2] + params.stride_w - params.ksize_w) / params.stride_w);
                 return [[inputs[0][0], out_h, out_w, inputs[0][3]]];
             }
+            else if (params.padding == 'SAME') {
+                const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
+                const out_w = ~~((inputs[0][2] + params.stride_w - 1) / params.stride_w);
+                return [ [inputs[0][0], out_h, out_w, inputs[0][3]] ];
+            }
+        });
+        this._operators.set('reduce', (inputs, params) => {
+            const newShape = inputs[0].slice();
+            if (params.keep_dims) {
+                for (const i in params.axis_list) {
+                    newShape[i] = 1;
+                }
+            }
+            else {
+                const axis_list = params.axis_list.map((item) => {
+                    return item < 0 ? newShape.length + item : item;
+                });
+                axis_list.sort((a, b) => {
+                    return b - a;
+                });
+                axis_list.map((item) => {
+                    newShape.splice(item, 1);
+                });
+                if (!newShape.length) {
+                    newShape.splice(0, 0, 0);
+                }
+            }
+            return [ newShape ];
+        });
+        this._operators.set('repeat', (inputs, params) => {
+            const newShape = inputs[0].slice();
+            newShape[params.axis] = params.maxlen;
+            return [ newShape ];
+        });
+        this._operators.set('reshape', (inputs, params) => {
+            const negativeIndexs = [];
+            let shape = params.shape;
+            if (typeof params.shape === 'string') {
+                shape = params.shape.split(/\s+/).map((item) => {
+                    return parseInt(item);
+                });
+            }
+            const newShape = shape.map((item, index) => {
+                if (item == 0) {
+                    return inputs[0][index];
+                }
+                else if (item == -1) {
+                    negativeIndexs.push(index);
+                    return 1;
+                }
+                else {
+                    return item;
+                }
+            });
+            if (negativeIndexs.length > 0) {
+                newShape[negativeIndexs[0]] = inputs[0].reduce((a, c) => a * c) / newShape.reduce((a, c) => a * c);
+            }
+            return [ newShape ];
+        });
+        this._operators.set('sequence_mask', (inputs, params) => {
+            return [inputs[0].slice().concat([params.maxlen])];
+        });
+        this._operators.set('slice', (inputs, params) => {
+            const newShape = params.size.map((item, index) => {
+                return item == -1 ? inputs[0][index] : item;
+            });
+            return [ newShape ];
+        });
+        this._operators.set('squeeze', (inputs, params) => {
+            const newShape = inputs[0].slice();
+            const axis_list = [...new Set(params.axis_list)].sort((a, b) => b - a);
+            axis_list.map((item) => {
+                newShape.splice(item, 1);
+            });
+            return [ newShape ];
+        });
+        this._operators.set('space2depth', (inputs, params) => {
+            const h = inputs[0][1] / params.block_size[0];
+            const w = inputs[0][2] / params.block_size[1];
+            const c = inputs[0][3] * params.block_size[1] * params.block_size[1];
+            return [ [inputs[0][0], h, w, c] ];
+        });
+        this._operators.set('split', (inputs, params) => {
+            const sizes = [];
+            const slices = params.slices.slice();
+            slices.splice(0, 0, 0);
+            slices.push(inputs[0][params.dim]);
+            slices.reduce((a, b) => {
+                sizes.push(b - a);
+                return b;
+            });
+            const newShapes = sizes.map((item) => {
+                const shape = inputs[0].slice();
+                shape[params.dim] = item;
+                return shape;
+            });
+            return newShapes;
+        });
+        this._operators.set('stack', (inputs, params) => {
+            const newShape = inputs[0].slice();
+            if (newShape.length == 1 && newShape[0] == 0) {
+                newShape[0] = 1;
+            }
+            else {
+                newShape.splice(params.axis, 0, inputs.length);
+            }
+            return [ newShape ];
+        });
+        this._operators.set('stridedslice', (inputs, params) => {
+            const input_shape = inputs[0].slice();
+            const begin = params.slice_begin.slice();
+            const end = params.slice_end.slice();
+            if (params.slice_begin_mask > 0) {
+                for (let i = 0; i < begin.length; i++) {
+                    if ((params.slice_begin_mask >>> i) & 0x1) {
+                        begin[i] = -1;
+                    }
+                }
+            }
+            if (params.slice_end_mask > 0) {
+                for (let i = 0; i < end.length; i++) {
+                    if ((params.slice_end_mask >>> i) & 0x1) {
+                        end[i] = -1;
+                    }
+                }
+            }
+            for (let i = 0; i < begin.length; i++) {
+                if (begin[i] == -1) {
+                    begin[i] = 0;
+                }
+            }
+            if (inputs[0].length == end.length){
+                for (let i = 0; i < end.length; i++) {
+                    if (end[i] == -1 || end[i] > input_shape[i]) {
+                        end[i] = input_shape[i];
+                    }
+                }
+            }
+            else if (inputs[0].length < end.length){
+                if (params.slice_new_axis_mask) {
+                    const len = (params.slice_new_axis_mask >>> 0).toString(2).length;
+                    for (let i = 0; i < len; i++) {
+                        if ((params.slice_new_axis_mask >>> i) & 0x1) {
+                            input_shape.splice(i, 0, 1);
+                        }
+                    }
+                    for (let i = 0; i < end.length; i++) {
+                        if (end[i] == -1) {
+                            end[i] = input_shape[i];
+                        }
+                    }
+                }
+            }
+            let newShape = [];
+            for (let i = 0; i < begin.length; i++) {
+                newShape = newShape.concat([(end[i] - begin[i])/params.slice_strides[i]]);
+            }
+            if (params.slice_shrink_axis_mask) {
+                const len = (params.slice_shrink_axis_mask >>> 0).toString(2).length;
+                for (let i = 0; i < len; i++) {
+                    if ((params.slice_shrink_axis_mask >>> i) & 0x1) {
+                        newShape.splice(i, 1);
+                    }
+                }
+            }
+            if (params.slice_new_axis_mask) {
+                const len = (params.slice_new_axis_mask >>> 0).toString(2).length;
+                for (let i = 0; i < len; i++) {
+                    if ((params.slice_new_axis_mask >>> i) & 0x1) {
+                        if (inputs[0].length == begin.length) {
+                            newShape.splice(i, 0, 1);
+                        }
+                        else if (inputs[0].length < begin.length) {
+                            newShape[i] = 1;
+                        }
+                    }
+                }
+            }
+            return [ newShape ];
         });
         for (const layer of outputLayers) {
             for (const output of layer.outputs) {
@@ -495,6 +762,12 @@ acuity.Inference =  class {
                 else if (this._passthroughs.has(layer.op)) {
                     callback = (inputs) => [ inputs[0].slice() ];
                 }
+                else if (this._broadcasts.has(layer.op)) {
+                    callback = this._operators.get('broadcast');
+                }
+                else if (this._reduces.has(layer.op)) {
+                    callback = this._operators.get('reduce');
+                }
                 else {
                     callback = () => [];
                 }