|
|
@@ -54,6 +54,8 @@ acuity.Graph = class {
|
|
|
} else if (Object.prototype.hasOwnProperty.call(layer.parameters, 'size') && Object.prototype.hasOwnProperty.call(layer.parameters, 'channels')) {
|
|
|
const sizes = layer.parameters.size.split(' ');
|
|
|
shape = [0, parseInt(sizes[0], 10), parseInt(sizes[1], 10), layer.parameters.channels];
|
|
|
+ } else if (Object.prototype.hasOwnProperty.call(layer.parameters, 'is_scalar')) {
|
|
|
+ shape = [1];
|
|
|
}
|
|
|
if (shape && shape.length === 4 && shape[0] === 0) {
|
|
|
shape[0] = 1;
|
|
|
@@ -218,7 +220,8 @@ acuity.Inference = class {
|
|
|
const broadcasts = new Set([
|
|
|
'add', 'equal', 'fllor_mod', 'floor_div', 'greater', 'greater_equal', 'less', 'less_equal',
|
|
|
'logical_and', 'logical_or', 'minimum', 'multiply', 'not_equal', 'pow', 'real_div',
|
|
|
- 'squared_difference', 'subtract'
|
|
|
+ 'squared_difference', 'subtract', 'divide', 'addn', 'Divide', 'bitwise_and', 'bitwise_or',
|
|
|
+ 'bitwise_xor', 'average', 'logical_not', 'logical_xor'
|
|
|
]);
|
|
|
const passthroughs = new Set([
|
|
|
'LocalResponseNormalization', 'a_times_b_plus_c', 'abs', 'batchnorm_single', 'batchnormalize',
|
|
|
@@ -226,11 +229,17 @@ acuity.Inference = class {
|
|
|
'groupnormalize', 'hard_sigmoid', 'hard_swish', 'instancenormalize', 'l2normalize', 'l2normalizescale',
|
|
|
'layernormalize', 'leakyrelu', 'log', 'log_softmax', 'mish', 'neg', 'norm_with_channel_mean',
|
|
|
'norm_with_min_max', 'norm_with_scale', 'pow', 'prelu', 'quantize', 'relu', 'relu_keras',
|
|
|
- 'relun', 'reverse', 'round', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'softrelu', 'sqrt', 'square', 'tanh'
|
|
|
+ 'relun', 'reverse', 'round', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'softrelu', 'sqrt', 'square', 'tanh',
|
|
|
+ 'swish', 'gelu', 'dropout', 'eltwise', 'cos', 'l1_layernormalize', 'inverse_sigmoid', 'selu', 'mod',
|
|
|
+ 'mish', 'minimum_with_clip', 'celu', 'cumsum', 'dft', 'dropout2', 'erf', 'noop', 'squashing', 'tan', 'ceil',
|
|
|
+ 'atan', 'atan2', 'atanh', 'alpha_dropout', 'acosh', 'rmsnormalize', 'sign'
|
|
|
]);
|
|
|
const reduces = new Set([
|
|
|
'reduceany', 'reducemax', 'reducemean', 'reducemin', 'reduceprod', 'reducesum'
|
|
|
]);
|
|
|
+ const poolings = new Set([
|
|
|
+ 'pooling', 'l2pooling'
|
|
|
+ ]);
|
|
|
const operators = new Map();
|
|
|
operators.set('broadcast', ([a, b]) => {
|
|
|
const longer = a.length >= b.length ? a.slice() : b.slice();
|
|
|
@@ -274,9 +283,34 @@ acuity.Inference = class {
|
|
|
}
|
|
|
return null;
|
|
|
});
|
|
|
+ operators.set('depthwise_conv1d', (inputs, params) => {
|
|
|
+ if (params.padding === 'VALID') {
|
|
|
+ const out_h = ~~((inputs[0][1] + params.stride + params.pad[0] + params.pad[1] - params.ksize) / params.stride);
|
|
|
+ return [[inputs[0][0], out_h, inputs[0][2] * params.multiplier]];
|
|
|
+ } else if (params.padding === 'SAME') {
|
|
|
+ const out_h = ~~((inputs[0][1] + params.stride - 1) / params.stride);
|
|
|
+ return [[inputs[0][0], out_h, inputs[0][2] * params.multiplier]];
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ });
|
|
|
+ operators.set('depthwise_convolution', (inputs, params) => {
|
|
|
+ if (params.padding === 'VALID') {
|
|
|
+ const out_h = ~~((inputs[0][1] + params.stride_h + params.pad[0] + params.pad[1] - params.ksize_h) / params.stride_h);
|
|
|
+ const out_w = ~~((inputs[0][2] + params.stride_w + params.pad[2] + params.pad[3] - params.ksize_w) / params.stride_w);
|
|
|
+ return [[inputs[0][0], out_h, out_w, inputs[0][3] * params.multiplier]];
|
|
|
+ } else if (params.padding === 'SAME') {
|
|
|
+ const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
|
|
|
+ const out_w = ~~((inputs[0][2] + params.stride_w - 1) / params.stride_w);
|
|
|
+ return [[inputs[0][0], out_h, out_w, inputs[0][3] * params.multiplier]];
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ });
|
|
|
operators.set('deconvolution', (inputs, params) => {
|
|
|
return [params.output_shape.map((item, index) => item === 0 ? inputs[0][index] : item)];
|
|
|
});
|
|
|
+ operators.set('deconvolution1d', (inputs, params) => {
|
|
|
+ return [params.output_shape.map((item, index) => item === 0 ? inputs[0][index] : item)];
|
|
|
+ });
|
|
|
operators.set('fullconnect', (inputs, params) => {
|
|
|
return [inputs[0].slice(0, params.axis).concat([params.weights])];
|
|
|
});
|
|
|
@@ -330,20 +364,19 @@ acuity.Inference = class {
|
|
|
});
|
|
|
operators.set('reduce', (inputs, params) => {
|
|
|
const newShape = inputs[0].slice();
|
|
|
- if (params.keep_dims) {
|
|
|
- for (const i in params.axis_list) {
|
|
|
- newShape[i] = 1;
|
|
|
- }
|
|
|
- } else {
|
|
|
- const axis_list = params.axis_list.map((item) => {
|
|
|
- return item < 0 ? newShape.length + item : item;
|
|
|
- });
|
|
|
- axis_list.sort((a, b) => {
|
|
|
- return b - a;
|
|
|
+ const axis_list = params.axis_list.map((item) => {
|
|
|
+ return item < 0 ? newShape.length + item : item;
|
|
|
+ });
|
|
|
+ axis_list.sort((a, b) => {
|
|
|
+ return b - a;
|
|
|
+ });
|
|
|
+ axis_list.forEach((i) => {
|
|
|
+ newShape[i] = 1;
|
|
|
+ });
|
|
|
+ if (!params.keep_dims) {
|
|
|
+ axis_list.forEach((i) => {
|
|
|
+ newShape.splice(i, 1);
|
|
|
});
|
|
|
- for (const item of axis_list) {
|
|
|
- newShape.splice(item, 1);
|
|
|
- }
|
|
|
if (!newShape.length) {
|
|
|
newShape.splice(0, 0, 0);
|
|
|
}
|
|
|
@@ -398,6 +431,20 @@ acuity.Inference = class {
|
|
|
const c = inputs[0][3] * params.block_size[1] * params.block_size[1];
|
|
|
return [[inputs[0][0], h, w, c]];
|
|
|
});
|
|
|
+ operators.set('depth2space', (inputs, params) => {
|
|
|
+ const h = inputs[0][1] * params.block_size;
|
|
|
+ const w = inputs[0][2] * params.block_size;
|
|
|
+ const c = inputs[0][3] / (params.block_size * params.block_size);
|
|
|
+ return [[inputs[0][0], h, w, c]];
|
|
|
+ });
|
|
|
+ operators.set('upsampling', (inputs, params) => {
|
|
|
+ const h = inputs[0][1] * params.factor;
|
|
|
+ const w = inputs[0][2] * params.factor;
|
|
|
+ return [[inputs[0][0], h, w, inputs[0][3]]];
|
|
|
+ });
|
|
|
+ operators.set('crop_image', (inputs, params) => {
|
|
|
+ return [[inputs[0][0], params.crop_size[0], params.crop_size[1], inputs[0][3]]];
|
|
|
+ });
|
|
|
operators.set('split', (inputs, params) => {
|
|
|
const sizes = [];
|
|
|
const slices = params.slices.slice();
|
|
|
@@ -492,6 +539,99 @@ acuity.Inference = class {
|
|
|
}
|
|
|
return [newShape];
|
|
|
});
|
|
|
+ operators.set('image_resize', (inputs, params) => {
|
|
|
+ const newShape = inputs[0].slice();
|
|
|
+ /* eslint-disable prefer-destructuring */
|
|
|
+ newShape[1] = params.new_size[0];
|
|
|
+ newShape[2] = params.new_size[1];
|
|
|
+ /* eslint-enable prefer-destructuring */
|
|
|
+ return [newShape];
|
|
|
+ });
|
|
|
+ operators.set('argmax', (inputs, params) => {
|
|
|
+ const newShape = inputs[0].slice();
|
|
|
+ if (params.keepdims) {
|
|
|
+ newShape[params.axis] = 1;
|
|
|
+ } else {
|
|
|
+ newShape.splice(params.axis, 1);
|
|
|
+ if (!newShape.length) {
|
|
|
+ newShape.splice(0, 0, 0);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return [newShape];
|
|
|
+ });
|
|
|
+ operators.set('argmin', operators.get('argmax'));
|
|
|
+ /* eslint-disable no-unused-vars */
|
|
|
+ operators.set('shapelayer', (inputs, params) => {
|
|
|
+ return [[inputs[0].length]];
|
|
|
+ });
|
|
|
+ operators.set('capsule_norm', (inputs, params) => {
|
|
|
+ return [[inputs[0][0], inputs[0][inputs[0].length - 1]]];
|
|
|
+ });
|
|
|
+ operators.set('size', (inputs, params) => {
|
|
|
+ return [[1]];
|
|
|
+ });
|
|
|
+ /* eslint-enable no-unused-vars */
|
|
|
+ operators.set('einsum', ((operators, inputs, params) => {
|
|
|
+ const identifyOperation = (inputs, equation) => {
|
|
|
+ const identifyFuncs = new Map();
|
|
|
+ identifyFuncs.set('matmul', (inputs, equation) => {
|
|
|
+ if (inputs.length !== 2) {
|
|
|
+ return { found: false };
|
|
|
+ }
|
|
|
+
|
|
|
+ const parts = equation.replace(/\s+/g, '').split(/,|->/);
|
|
|
+ if (parts.length !== 3) {
|
|
|
+ return { found: false };
|
|
|
+ }
|
|
|
+
|
|
|
+ const [first, second, output] = parts.map((p) => p.split(''));
|
|
|
+ if (!(first.length === output.length || second.length === output.length)) {
|
|
|
+ return { found: false };
|
|
|
+ }
|
|
|
+
|
|
|
+ let a = first.slice(-2);
|
|
|
+ const b = second.slice(-2);
|
|
|
+ const c = output.slice(-2);
|
|
|
+ let transpose_a = false;
|
|
|
+ let transpose_b = false;
|
|
|
+ if (a[0] === c[0]) {
|
|
|
+ transpose_a = false;
|
|
|
+ } else if (a[1] === c[0]) {
|
|
|
+ transpose_a = true;
|
|
|
+ a = [].concat(a.reverse());
|
|
|
+ } else {
|
|
|
+ return { found: false };
|
|
|
+ }
|
|
|
+
|
|
|
+ if (a[1] === b[0]) {
|
|
|
+ transpose_b = false;
|
|
|
+ } else if (a[1] === b[1]) {
|
|
|
+ transpose_b = true;
|
|
|
+ } else {
|
|
|
+ return { found: false };
|
|
|
+ }
|
|
|
+ return { found: true, op: 'matmul', params: { transpose_a, transpose_b } };
|
|
|
+ });
|
|
|
+
|
|
|
+ /* eslint-disable no-unused-vars */
|
|
|
+ for (const [name, func] of identifyFuncs.entries()) {
|
|
|
+ const result = func(inputs, equation);
|
|
|
+ if (result.found) {
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ /* eslint-enable no-unused-vars */
|
|
|
+ return { found: false };
|
|
|
+ };
|
|
|
+
|
|
|
+ const result = identifyOperation(inputs, params.equation);
|
|
|
+ if (result.found) {
|
|
|
+ if (operators.has(result.op)) {
|
|
|
+ return operators.get(result.op)(inputs, result.params);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return [];
|
|
|
+ }).bind(undefined, operators));
|
|
|
const infer = (output) => {
|
|
|
if (outputs.has(output.name)) {
|
|
|
let ready = true;
|
|
|
@@ -515,6 +655,8 @@ acuity.Inference = class {
|
|
|
callback = operators.get('broadcast');
|
|
|
} else if (reduces.has(layer.op)) {
|
|
|
callback = operators.get('reduce');
|
|
|
+ } else if (poolings.has(layer.op)) {
|
|
|
+ callback = operators.get('pooling');
|
|
|
}
|
|
|
if (!callback) {
|
|
|
callback = () => [];
|
|
|
@@ -546,4 +688,4 @@ acuity.Error = class extends Error {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
-export const ModelFactory = acuity.ModelFactory;
|
|
|
+export const ModelFactory = acuity.ModelFactory;
|