Jelajahi Sumber

Update acuity.js

Lutz Roeder 1 tahun lalu
induk
melakukan
17b35c5014
2 mengubah file dengan 786 tambahan dan 73 penghapusan
  1. 628 57
      source/acuity-metadata.json
  2. 158 16
      source/acuity.js

File diff ditekan karena terlalu besar
+ 628 - 57
source/acuity-metadata.json


+ 158 - 16
source/acuity.js

@@ -54,6 +54,8 @@ acuity.Graph = class {
                     } else if (Object.prototype.hasOwnProperty.call(layer.parameters, 'size') && Object.prototype.hasOwnProperty.call(layer.parameters, 'channels')) {
                     } else if (Object.prototype.hasOwnProperty.call(layer.parameters, 'size') && Object.prototype.hasOwnProperty.call(layer.parameters, 'channels')) {
                         const sizes = layer.parameters.size.split(' ');
                         const sizes = layer.parameters.size.split(' ');
                         shape = [0, parseInt(sizes[0], 10), parseInt(sizes[1], 10), layer.parameters.channels];
                         shape = [0, parseInt(sizes[0], 10), parseInt(sizes[1], 10), layer.parameters.channels];
+                    } else if (Object.prototype.hasOwnProperty.call(layer.parameters, 'is_scalar')) {
+                        shape = [1];
                     }
                     }
                     if (shape && shape.length === 4 && shape[0] === 0) {
                     if (shape && shape.length === 4 && shape[0] === 0) {
                         shape[0] = 1;
                         shape[0] = 1;
@@ -218,7 +220,8 @@ acuity.Inference = class {
         const broadcasts = new Set([
         const broadcasts = new Set([
             'add', 'equal', 'fllor_mod', 'floor_div', 'greater', 'greater_equal', 'less', 'less_equal',
             'add', 'equal', 'fllor_mod', 'floor_div', 'greater', 'greater_equal', 'less', 'less_equal',
             'logical_and', 'logical_or', 'minimum', 'multiply', 'not_equal', 'pow', 'real_div',
             'logical_and', 'logical_or', 'minimum', 'multiply', 'not_equal', 'pow', 'real_div',
-            'squared_difference', 'subtract'
+            'squared_difference', 'subtract', 'divide', 'addn', 'Divide', 'bitwise_and', 'bitwise_or',
+            'bitwise_xor', 'average', 'logical_not', 'logical_xor'
         ]);
         ]);
         const passthroughs = new Set([
         const passthroughs = new Set([
             'LocalResponseNormalization', 'a_times_b_plus_c', 'abs', 'batchnorm_single', 'batchnormalize',
             'LocalResponseNormalization', 'a_times_b_plus_c', 'abs', 'batchnorm_single', 'batchnormalize',
@@ -226,11 +229,17 @@ acuity.Inference = class {
             'groupnormalize', 'hard_sigmoid', 'hard_swish', 'instancenormalize', 'l2normalize', 'l2normalizescale',
             'groupnormalize', 'hard_sigmoid', 'hard_swish', 'instancenormalize', 'l2normalize', 'l2normalizescale',
             'layernormalize', 'leakyrelu', 'log', 'log_softmax', 'mish', 'neg', 'norm_with_channel_mean',
             'layernormalize', 'leakyrelu', 'log', 'log_softmax', 'mish', 'neg', 'norm_with_channel_mean',
             'norm_with_min_max', 'norm_with_scale', 'pow', 'prelu', 'quantize', 'relu', 'relu_keras',
             'norm_with_min_max', 'norm_with_scale', 'pow', 'prelu', 'quantize', 'relu', 'relu_keras',
-            'relun', 'reverse', 'round', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'softrelu', 'sqrt', 'square', 'tanh'
+            'relun', 'reverse', 'round', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'softrelu', 'sqrt', 'square', 'tanh',
+            'swish', 'gelu', 'dropout', 'eltwise', 'cos', 'l1_layernormalize', 'inverse_sigmoid', 'selu', 'mod',
+            'mish', 'minimum_with_clip', 'celu', 'cumsum', 'dft', 'dropout2', 'erf', 'noop', 'squashing', 'tan', 'ceil',
+            'atan', 'atan2', 'atanh', 'alpha_dropout', 'acosh', 'rmsnormalize', 'sign'
         ]);
         ]);
         const reduces = new Set([
         const reduces = new Set([
             'reduceany', 'reducemax', 'reducemean', 'reducemin', 'reduceprod', 'reducesum'
             'reduceany', 'reducemax', 'reducemean', 'reducemin', 'reduceprod', 'reducesum'
         ]);
         ]);
+        const poolings = new Set([
+            'pooling', 'l2pooling'
+        ]);
         const operators = new Map();
         const operators = new Map();
         operators.set('broadcast', ([a, b]) => {
         operators.set('broadcast', ([a, b]) => {
             const longer = a.length >= b.length ? a.slice() : b.slice();
             const longer = a.length >= b.length ? a.slice() : b.slice();
@@ -274,9 +283,34 @@ acuity.Inference = class {
             }
             }
             return null;
             return null;
         });
         });
+        operators.set('depthwise_conv1d', (inputs, params) => {
+            if (params.padding === 'VALID') {
+                const out_h = ~~((inputs[0][1] + params.stride + params.pad[0] + params.pad[1] - params.ksize) / params.stride);
+                return [[inputs[0][0], out_h, inputs[0][2] * params.multiplier]];
+            } else if (params.padding === 'SAME') {
+                const out_h = ~~((inputs[0][1] + params.stride - 1) / params.stride);
+                return [[inputs[0][0], out_h, inputs[0][2] * params.multiplier]];
+            }
+            return null;
+        });
+        operators.set('depthwise_convolution', (inputs, params) => {
+            if (params.padding === 'VALID') {
+                const out_h = ~~((inputs[0][1] + params.stride_h + params.pad[0] + params.pad[1] - params.ksize_h) / params.stride_h);
+                const out_w = ~~((inputs[0][2] + params.stride_w + params.pad[2] + params.pad[3] - params.ksize_w) / params.stride_w);
+                return [[inputs[0][0], out_h, out_w, inputs[0][3] * params.multiplier]];
+            } else if (params.padding === 'SAME') {
+                const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
+                const out_w = ~~((inputs[0][2] + params.stride_w - 1) / params.stride_w);
+                return [[inputs[0][0], out_h, out_w, inputs[0][3] * params.multiplier]];
+            }
+            return null;
+        });
         operators.set('deconvolution', (inputs, params) => {
         operators.set('deconvolution', (inputs, params) => {
             return [params.output_shape.map((item, index) => item === 0 ? inputs[0][index] : item)];
             return [params.output_shape.map((item, index) => item === 0 ? inputs[0][index] : item)];
         });
         });
+        operators.set('deconvolution1d', (inputs, params) => {
+            return [params.output_shape.map((item, index) => item === 0 ? inputs[0][index] : item)];
+        });
         operators.set('fullconnect', (inputs, params) => {
         operators.set('fullconnect', (inputs, params) => {
             return [inputs[0].slice(0, params.axis).concat([params.weights])];
             return [inputs[0].slice(0, params.axis).concat([params.weights])];
         });
         });
@@ -330,20 +364,19 @@ acuity.Inference = class {
         });
         });
         operators.set('reduce', (inputs, params) => {
         operators.set('reduce', (inputs, params) => {
             const newShape = inputs[0].slice();
             const newShape = inputs[0].slice();
-            if (params.keep_dims) {
-                for (const i in params.axis_list) {
-                    newShape[i] = 1;
-                }
-            } else {
-                const axis_list = params.axis_list.map((item) => {
-                    return item < 0 ? newShape.length + item : item;
-                });
-                axis_list.sort((a, b) => {
-                    return b - a;
+            const axis_list = params.axis_list.map((item) => {
+                return item < 0 ? newShape.length + item : item;
+            });
+            axis_list.sort((a, b) => {
+                return b - a;
+            });
+            axis_list.forEach((i) => {
+                newShape[i] = 1;
+            });
+            if (!params.keep_dims) {
+                axis_list.forEach((i) => {
+                    newShape.splice(i, 1);
                 });
                 });
-                for (const item of axis_list) {
-                    newShape.splice(item, 1);
-                }
                 if (!newShape.length) {
                 if (!newShape.length) {
                     newShape.splice(0, 0, 0);
                     newShape.splice(0, 0, 0);
                 }
                 }
@@ -398,6 +431,20 @@ acuity.Inference = class {
             const c = inputs[0][3] * params.block_size[1] * params.block_size[1];
             const c = inputs[0][3] * params.block_size[1] * params.block_size[1];
             return [[inputs[0][0], h, w, c]];
             return [[inputs[0][0], h, w, c]];
         });
         });
+        operators.set('depth2space', (inputs, params) => {
+            const h = inputs[0][1] * params.block_size;
+            const w = inputs[0][2] * params.block_size;
+            const c = inputs[0][3] / (params.block_size * params.block_size);
+            return [[inputs[0][0], h, w, c]];
+        });
+        operators.set('upsampling', (inputs, params) => {
+            const h = inputs[0][1] * params.factor;
+            const w = inputs[0][2] * params.factor;
+            return [[inputs[0][0], h, w, inputs[0][3]]];
+        });
+        operators.set('crop_image', (inputs, params) => {
+            return [[inputs[0][0], params.crop_size[0], params.crop_size[1], inputs[0][3]]];
+        });
         operators.set('split', (inputs, params) => {
         operators.set('split', (inputs, params) => {
             const sizes = [];
             const sizes = [];
             const slices = params.slices.slice();
             const slices = params.slices.slice();
@@ -492,6 +539,99 @@ acuity.Inference = class {
             }
             }
             return [newShape];
             return [newShape];
         });
         });
+        operators.set('image_resize', (inputs, params) => {
+            const newShape = inputs[0].slice();
+            /* eslint-disable prefer-destructuring */
+            newShape[1] = params.new_size[0];
+            newShape[2] = params.new_size[1];
+            /* eslint-enable prefer-destructuring */
+            return [newShape];
+        });
+        operators.set('argmax', (inputs, params) => {
+            const newShape = inputs[0].slice();
+            if (params.keepdims) {
+                newShape[params.axis] = 1;
+            } else {
+                newShape.splice(params.axis, 1);
+                if (!newShape.length) {
+                    newShape.splice(0, 0, 0);
+                }
+            }
+            return [newShape];
+        });
+        operators.set('argmin', operators.get('argmax'));
+        /* eslint-disable no-unused-vars */
+        operators.set('shapelayer', (inputs, params) => {
+            return [[inputs[0].length]];
+        });
+        operators.set('capsule_norm', (inputs, params) => {
+            return [[inputs[0][0], inputs[0][inputs[0].length - 1]]];
+        });
+        operators.set('size', (inputs, params) => {
+            return [[1]];
+        });
+        /* eslint-enable no-unused-vars */
+        operators.set('einsum', ((operators, inputs, params) => {
+            const identifyOperation = (inputs, equation) => {
+                const identifyFuncs = new Map();
+                identifyFuncs.set('matmul', (inputs, equation) => {
+                    if (inputs.length !== 2) {
+                        return { found: false };
+                    }
+
+                    const parts = equation.replace(/\s+/g, '').split(/,|->/);
+                    if (parts.length !== 3) {
+                        return { found: false };
+                    }
+
+                    const [first, second, output] = parts.map((p) => p.split(''));
+                    if (!(first.length === output.length || second.length === output.length)) {
+                        return { found: false };
+                    }
+
+                    let a = first.slice(-2);
+                    const b = second.slice(-2);
+                    const c = output.slice(-2);
+                    let transpose_a = false;
+                    let transpose_b = false;
+                    if (a[0] === c[0]) {
+                        transpose_a = false;
+                    } else if (a[1] === c[0]) {
+                        transpose_a = true;
+                        a = [].concat(a.reverse());
+                    } else {
+                        return { found: false };
+                    }
+
+                    if (a[1] === b[0]) {
+                        transpose_b = false;
+                    } else if (a[1] === b[1]) {
+                        transpose_b = true;
+                    } else {
+                        return { found: false };
+                    }
+                    return { found: true, op: 'matmul', params: { transpose_a, transpose_b } };
+                });
+
+                /* eslint-disable no-unused-vars */
+                for (const [name, func] of identifyFuncs.entries()) {
+                    const result = func(inputs, equation);
+                    if (result.found) {
+                        return result;
+                    }
+                }
+                /* eslint-enable no-unused-vars */
+                return { found: false };
+            };
+
+            const result = identifyOperation(inputs, params.equation);
+            if (result.found) {
+                if (operators.has(result.op)) {
+                    return operators.get(result.op)(inputs, result.params);
+                }
+            }
+            return [];
+        }).bind(undefined, operators));
         const infer = (output) => {
         const infer = (output) => {
             if (outputs.has(output.name)) {
             if (outputs.has(output.name)) {
                 let ready = true;
                 let ready = true;
@@ -515,6 +655,8 @@ acuity.Inference = class {
                         callback = operators.get('broadcast');
                         callback = operators.get('broadcast');
                     } else if (reduces.has(layer.op)) {
                     } else if (reduces.has(layer.op)) {
                         callback = operators.get('reduce');
                         callback = operators.get('reduce');
+                    } else if (poolings.has(layer.op)) {
+                        callback = operators.get('pooling');
                     }
                     }
                     if (!callback) {
                     if (!callback) {
                         callback = () => [];
                         callback = () => [];
@@ -546,4 +688,4 @@ acuity.Error = class extends Error {
     }
     }
 };
 };
 
 
-export const ModelFactory = acuity.ModelFactory;
+export const ModelFactory = acuity.ModelFactory;

Beberapa file tidak ditampilkan karena terlalu banyak file yang berubah dalam diff ini