Browse Source

Update acuity.js

Lutz Roeder 4 years ago
parent
commit
dfa802f5ff
1 changed files with 88 additions and 112 deletions
  1. 88 112
      source/acuity.js

+ 88 - 112
source/acuity.js

@@ -91,7 +91,7 @@ acuity.Graph = class {
             });
         }
 
-        new acuity.Inference(model.Layers);
+        acuity.Inference.infer(model.Layers);
 
         for (const pair of args) {
             const type = new acuity.TensorType(null, new acuity.TensorShape(pair[1].shape));
@@ -355,7 +355,7 @@ acuity.Tensor = class {
     }
 
     get state() {
-        return 'Not supported.';
+        return 'Tensor data not implemented.';
     }
 
     toString() {
@@ -412,10 +412,10 @@ acuity.Metadata = class {
     }
 };
 
-acuity.Inference =  class {
+acuity.Inference = class {
 
-    constructor(layers) {
-        this._outputs = new Map();
+    static infer(layers) {
+        const outputs = new Map();
         const outputLayers = [];
         for (const layerName of Object.keys(layers)) {
             const layer = layers[layerName];
@@ -423,15 +423,15 @@ acuity.Inference =  class {
                 outputLayers.push(layer);
             }
             for (const output of layer.outputs) {
-                this._outputs.set(output.name, layer);
+                outputs.set(output.name, layer);
             }
         }
-        this._broadcasts = new Set([
+        const broadcasts = new Set([
             'add', 'equal', 'fllor_mod', 'floor_div', 'greater', 'greater_equal', 'less', 'less_equal',
             'logical_and', 'logical_or', 'minimum', 'multiply', 'not_equal', 'pow', 'real_div',
             'squared_difference', 'subtract'
         ]);
-        this._passthroughs = new Set([
+        const passthroughs = new Set([
             'LocalResponseNormalization', 'a_times_b_plus_c', 'abs', 'batchnorm_single', 'batchnormalize',
             'cast', 'cast', 'clipbyvalue', 'dequantize', 'dtype_converter', 'elu', 'exp', 'floor',
             'groupnormalize', 'hard_sigmoid', 'hard_swish', 'instancenormalize', 'l2normalize', 'l2normalizescale',
@@ -439,11 +439,11 @@ acuity.Inference =  class {
             'norm_with_min_max', 'norm_with_scale', 'pow', 'prelu', 'quantize', 'relu', 'relu_keras',
             'relun', 'reverse', 'round', 'rsqrt', 'sigmoid', 'sin', 'softmax', 'softrelu', 'sqrt', 'square', 'tanh'
         ]);
-        this._reduces = new Set([
+        const reduces = new Set([
             'reduceany', 'reducemax', 'reducemean', 'reducemin', 'reduceprod', 'reducesum'
         ]);
-        this._operators = new Map();
-        this._operators.set('broadcast', (inputs) => {
+        const operators = new Map();
+        operators.set('broadcast', (inputs) => {
             const a = inputs[0];
             const b = inputs[1];
             const longer = a.length >= b.length ? a.slice() : b.slice();
@@ -455,69 +455,59 @@ acuity.Inference =  class {
             for (let i = 0; i < longer.length; i++) {
                 longer[i] = longer[i] > shorter[i] ? longer[i] : shorter[i];
             }
-            return [longer];
+            return [ longer ];
         });
-        this._operators.set('concat', (inputs, params) => {
+        operators.set('concat', (inputs, params) => {
             const outputShape = inputs[0].slice();
             outputShape[params.dim] = 0;
             for (const shape of inputs) {
                 outputShape[params.dim] += shape[params.dim];
             }
-            return [outputShape];
+            return [ outputShape ];
         });
-        this._operators.set('conv1d', (inputs, params) => {
+        operators.set('conv1d', (inputs, params) => {
             if (params.padding == 'VALID') {
                 const out_h = ~~((inputs[0][1] + params.stride - params.ksize) / params.stride);
-                return [[inputs[0][0], out_h, params.weights]];
+                return [ [ inputs[0][0], out_h, params.weights ] ];
             }
             else if (params.padding == 'SAME') {
                 const out_h = ~~((inputs[0][1] + params.stride - 1) / params.stride);
-                return [[inputs[0][0], out_h, params.weights]];
+                return [ [ inputs[0][0], out_h, params.weights ] ];
             }
         });
-        this._operators.set('convolution', (inputs, params) => {
+        operators.set('convolution', (inputs, params) => {
             if (params.padding == 'VALID') {
                 const out_h = ~~((inputs[0][1] + params.stride_h + params.pad[0] + params.pad[1] - params.ksize_h) / params.stride_h);
                 const out_w = ~~((inputs[0][2] + params.stride_w + params.pad[2] + params.pad[3]- params.ksize_w) / params.stride_w);
-                return [[inputs[0][0], out_h, out_w, params.weights]];
+                return [ [ inputs[0][0], out_h, out_w, params.weights ] ];
             }
             else if (params.padding == 'SAME') {
                 const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
                 const out_w = ~~((inputs[0][2] + params.stride_w - 1) / params.stride_w);
-                return [[inputs[0][0], out_h, out_w, params.weights]];
+                return [ [ inputs[0][0], out_h, out_w, params.weights ] ];
             }
         });
-        this._operators.set('deconvolution', (inputs, params) => {
-            const newShape = params.output_shape.map((item, index) => {
-                return item == 0 ? inputs[0][index] : item;
-            });
-            return [ newShape ];
+        operators.set('deconvolution', (inputs, params) => {
+            return [ params.output_shape.map((item, index) => item == 0 ? inputs[0][index] : item) ];
         });
-        this._operators.set('fullconnect', (inputs, params) => {
+        operators.set('fullconnect', (inputs, params) => {
             return [ inputs[0].slice(0, params.axis).concat([params.weights]) ];
         });
-        this._operators.set('gather', (inputs, params) => {
+        operators.set('gather', (inputs, params) => {
             const prefix = inputs[1].slice();
             const suffix = inputs[0].slice(params.axis + 1);
-            const newShape = prefix.concat(suffix);
-            return [ newShape ];
+            return [ prefix.concat(suffix) ];
         });
-        this._operators.set('lstm', (inputs, params) => {
+        operators.set('lstm', (inputs, params) => {
             let batch = inputs[0][0];
             const output = params.num_proj != null ? params.num_proj : params.weights;
             if (params.time_major) {
                 batch = inputs[0][1];
             }
-            let newShape = [];
-            if (params.return_sequences) {
-                newShape = [inputs[0][0], inputs[0][1], output];
-            }
-            else {
-                newShape = [batch, output];
-            }
+            const newShape = params.return_sequences ? [ inputs[0][0], inputs[0][1], output ] : [ batch, output ];
             return [ newShape, [batch, output], [batch, params.weights] ];
         });
-        this._operators.set('matmul', (inputs, params) => {
+        operators.set('matmul', (inputs, params) => {
             const a = inputs[0];
             const b = inputs[1];
             let newShape = a.slice(0, -2);
@@ -535,23 +525,17 @@ acuity.Inference =  class {
             }
             return [ newShape ];
         });
-        this._operators.set('pad', (inputs, params) => {
-            const newShape = inputs[0].map((item, index) => {
-                return item + params.padding_value[index][0] + params.padding_value[index][1];
-            });
-            return [ newShape ];
+        operators.set('pad', (inputs, params) => {
+            return [ inputs[0].map((item, index) => item + params.padding_value[index][0] + params.padding_value[index][1]) ];
         });
-        this._operators.set('permute', (inputs, params) => {
-            const newShape = inputs[0].map((item, index) => {
-                return inputs[0][params.perm[index]];
-            });
-            return [ newShape ];
+        operators.set('permute', (inputs, params) => {
+            return [ inputs[0].map((item, index) => inputs[0][params.perm[index]]) ];
         });
-        this._operators.set('pooling', (inputs, params) => {
+        operators.set('pooling', (inputs, params) => {
             if (params.padding == 'VALID') {
                 const out_h = ~~((inputs[0][1] + params.stride_h - params.ksize_h) / params.stride_h);
                 const out_w = ~~((inputs[0][2] + params.stride_w - params.ksize_w) / params.stride_w);
-                return [[inputs[0][0], out_h, out_w, inputs[0][3]]];
+                return [ [inputs[0][0], out_h, out_w, inputs[0][3]] ];
             }
             else if (params.padding == 'SAME') {
                 const out_h = ~~((inputs[0][1] + params.stride_h - 1) / params.stride_h);
@@ -559,7 +543,7 @@ acuity.Inference =  class {
                 return [ [inputs[0][0], out_h, out_w, inputs[0][3]] ];
             }
         });
-        this._operators.set('reduce', (inputs, params) => {
+        operators.set('reduce', (inputs, params) => {
             const newShape = inputs[0].slice();
             if (params.keep_dims) {
                 for (const i in params.axis_list) {
@@ -582,12 +566,12 @@ acuity.Inference =  class {
             }
             return [ newShape ];
         });
-        this._operators.set('repeat', (inputs, params) => {
+        operators.set('repeat', (inputs, params) => {
             const newShape = inputs[0].slice();
             newShape[params.axis] = params.maxlen;
             return [ newShape ];
         });
-        this._operators.set('reshape', (inputs, params) => {
+        operators.set('reshape', (inputs, params) => {
             const negativeIndexs = [];
             let shape = params.shape;
             if (typeof params.shape === 'string') {
@@ -599,29 +583,24 @@ acuity.Inference =  class {
                 if (item == 0) {
                     return inputs[0][index];
                 }
-                else if (item == -1) {
+                if (item == -1) {
                     negativeIndexs.push(index);
                     return 1;
                 }
-                else {
-                    return item;
-                }
+                return item;
             });
             if (negativeIndexs.length > 0) {
                 newShape[negativeIndexs[0]] = inputs[0].reduce((a, c) => a * c) / newShape.reduce((a, c) => a * c);
             }
             return [ newShape ];
         });
-        this._operators.set('sequence_mask', (inputs, params) => {
-            return [inputs[0].slice().concat([params.maxlen])];
+        operators.set('sequence_mask', (inputs, params) => {
+            return [ inputs[0].slice().concat([params.maxlen]) ];
         });
-        this._operators.set('slice', (inputs, params) => {
-            const newShape = params.size.map((item, index) => {
-                return item == -1 ? inputs[0][index] : item;
-            });
-            return [ newShape ];
+        operators.set('slice', (inputs, params) => {
+            return [ params.size.map((item, index) => item == -1 ? inputs[0][index] : item) ];
         });
-        this._operators.set('squeeze', (inputs, params) => {
+        operators.set('squeeze', (inputs, params) => {
             const newShape = inputs[0].slice();
             const axis_list = [...new Set(params.axis_list)].sort((a, b) => b - a);
             axis_list.map((item) => {
@@ -629,13 +608,13 @@ acuity.Inference =  class {
             });
             return [ newShape ];
         });
-        this._operators.set('space2depth', (inputs, params) => {
+        operators.set('space2depth', (inputs, params) => {
             const h = inputs[0][1] / params.block_size[0];
             const w = inputs[0][2] / params.block_size[1];
             const c = inputs[0][3] * params.block_size[1] * params.block_size[1];
             return [ [inputs[0][0], h, w, c] ];
         });
-        this._operators.set('split', (inputs, params) => {
+        operators.set('split', (inputs, params) => {
             const sizes = [];
             const slices = params.slices.slice();
             slices.splice(0, 0, 0);
@@ -644,14 +623,13 @@ acuity.Inference =  class {
                 sizes.push(b - a);
                 return b;
             });
-            const newShapes = sizes.map((item) => {
+            return sizes.map((item) => {
                 const shape = inputs[0].slice();
                 shape[params.dim] = item;
                 return shape;
             });
-            return newShapes;
         });
-        this._operators.set('stack', (inputs, params) => {
+        operators.set('stack', (inputs, params) => {
             const newShape = inputs[0].slice();
             if (newShape.length == 1 && newShape[0] == 0) {
                 newShape[0] = 1;
@@ -661,7 +639,7 @@ acuity.Inference =  class {
             }
             return [ newShape ];
         });
-        this._operators.set('stridedslice', (inputs, params) => {
+        operators.set('stridedslice', (inputs, params) => {
             const input_shape = inputs[0].slice();
             const begin = params.slice_begin.slice();
             const end = params.slice_end.slice();
@@ -733,53 +711,51 @@ acuity.Inference =  class {
             }
             return [ newShape ];
         });
-        for (const layer of outputLayers) {
-            for (const output of layer.outputs) {
-                this._infer(output);
-            }
-        }
-    }
-
-    _infer(output) {
-        if (this._outputs.has(output.name)) {
-            let inputShapeReady = true;
-            const layer = this._outputs.get(output.name);
-            for (const input of layer.inputs) {
-                if (input.shape === null) {
-                    this._infer(input);
+        const infer = (output) => {
+            if (outputs.has(output.name)) {
+                let ready = true;
+                const layer = outputs.get(output.name);
+                for (const input of layer.inputs) {
                     if (input.shape === null) {
-                        inputShapeReady = false;
-                        break;
+                        infer(input);
+                        if (input.shape === null) {
+                            ready = false;
+                            break;
+                        }
                     }
                 }
-            }
-
-            if (inputShapeReady) {
-                let callback = null;
-                if (this._operators.has(layer.op)) {
-                    callback = this._operators.get(layer.op);
-                }
-                else if (this._passthroughs.has(layer.op)) {
-                    callback = (inputs) => [ inputs[0].slice() ];
-                }
-                else if (this._broadcasts.has(layer.op)) {
-                    callback = this._operators.get('broadcast');
-                }
-                else if (this._reduces.has(layer.op)) {
-                    callback = this._operators.get('reduce');
-                }
-                else {
-                    callback = () => [];
-                }
-                const parameters = layer.parameters;
-                const inputs = layer.inputs.map((input) => input.shape);
-                const outputs = callback(inputs, parameters);
-                for (let i = 0; i < outputs.length; i++) {
-                    if (i < layer.outputs.length) {
-                        layer.outputs[i].shape = outputs[i];
+                if (ready) {
+                    let callback = null;
+                    if (operators.has(layer.op)) {
+                        callback = operators.get(layer.op);
+                    }
+                    else if (passthroughs.has(layer.op)) {
+                        callback = (inputs) => [ inputs[0].slice() ];
+                    }
+                    else if (broadcasts.has(layer.op)) {
+                        callback = operators.get('broadcast');
+                    }
+                    else if (reduces.has(layer.op)) {
+                        callback = operators.get('reduce');
+                    }
+                    else {
+                        callback = () => [];
+                    }
+                    const parameters = layer.parameters;
+                    const inputs = layer.inputs.map((input) => input.shape);
+                    const outputs = callback(inputs, parameters);
+                    for (let i = 0; i < outputs.length; i++) {
+                        if (i < layer.outputs.length) {
+                            layer.outputs[i].shape = outputs[i];
+                        }
                     }
                 }
             }
+        };
+        for (const layer of outputLayers) {
+            for (const output of layer.outputs) {
+                infer(output);
+            }
         }
     }
 };