Lutz Roeder 4 лет назад
Родитель
Сommit
177cf30a40
1 измененных файлов с 86 добавлено и 91 удалено
  1. 86 91
      source/barracuda.js

+ 86 - 91
source/barracuda.js

@@ -137,7 +137,7 @@ barracuda.Node = class {
 
     constructor(metadata, layer, type, initializers) {
         this._name = layer.name || '';
-        this._type = type ? type : metadata.type(layer.type) || { name: layer.type.toString() };
+        this._type = type ? type : metadata.type(layer.type);
         this._inputs = [];
         this._outputs = [];
         this._attributes = [];
@@ -171,10 +171,6 @@ barracuda.Node = class {
                 new barracuda.Argument(this._name)
             ]));
         }
-        /* if (this._type.name === 'Activation') {
-            const type = barracuda.Activation[layer.activation];
-            this._type = metadata.type(layer.activation) || { name: type };
-        } */
         if (layer.activation && layer.activation !== 0) {
             const type = barracuda.Activation[layer.activation];
             if (!type) {
@@ -579,96 +575,95 @@ barracuda.Metadata = class {
     }
 
     constructor() {
-        this._map = new Map();
-        this._register(0, 'Nop', '');
-        this._register(1, 'Dense', 'Layer', [ 'input', 'kernel', 'bias' ]);
-        this._register(2, 'MatMul', '', [ 'input', 'kernel', 'bias' ]);
-        this._register(20, 'Conv2D', 'Layer', [ 'input', 'kernel', 'bias' ]);
-        this._register(21, 'DepthwiseConv2D', 'Layer', [ 'input', 'kernel', 'bias' ]);
-        this._register(22, 'Conv2DTrans', '');
-        this._register(23, 'Upsample2D', '');
-        this._register(25, 'MaxPool2D', 'Pool');
-        this._register(26, 'AvgPool2D', 'Pool');
-        this._register(27, 'GlobalMaxPool2D', 'Pool');
-        this._register(28, 'GlobalAvgPool2D', 'Pool');
-        this._register(29, 'Border2D', '');
-        this._register(30, 'Conv3D', 'Layer');
-        this._register(32, 'Conv3DTrans', 'Layer');
-        this._register(33, 'Upsample3D', '');
-        this._register(35, 'MaxPool3D', 'Pool');
-        this._register(36, 'AvgPool3D', 'Pool');
-        this._register(37, 'GlobalMaxPool3D', 'Pool');
-        this._register(38, 'GlobalAvgPool3D', 'Pool');
-        this._register(39, 'Border3D', '');
-        this._register(50, 'Activation', '');
-        this._register(51, 'ScaleBias', 'Normalization', [ 'input', 'scale', 'bias' ]);
-        this._register(52, 'Normalization', 'Normalization');
-        this._register(53, 'LRN', 'Normalization');
-        this._register(60, 'Dropout', 'Dropout');
-        this._register(64, 'RandomNormal', '');
-        this._register(65, 'RandomUniform', '');
-        this._register(66, 'Multinomial', '');
-        this._register(67, 'OneHot', '');
-        this._register(68, 'TopKIndices', '');
-        this._register(69, 'TopKValues', '');
-        this._register(100, 'Add', '', [ 'inputs' ]);
-        this._register(101, 'Sub', '', [ 'inputs' ]);
-        this._register(102, 'Mul', '', [ 'inputs' ]);
-        this._register(103, 'RealDiv', '', [ 'inputs' ]);
-        this._register(104, 'Pow', '', [ 'inputs' ]);
-        this._register(110, 'Minimum', '', [ 'inputs' ]);
-        this._register(111, 'Maximum', '', [ 'inputs' ]);
-        this._register(112, 'Mean', '', [ 'inputs' ]);
-        this._register(120, 'ReduceL1', '', [ 'inputs' ]);
-        this._register(121, 'ReduceL2', '', [ 'inputs' ]);
-        this._register(122, 'ReduceLogSum', '', [ 'inputs' ]);
-        this._register(123, 'ReduceLogSumExp', '', [ 'inputs' ]);
-        this._register(124, 'ReduceMax', '', [ 'inputs' ]);
-        this._register(125, 'ReduceMean', '', [ 'inputs' ]);
-        this._register(126, 'ReduceMin', '', [ 'inputs' ]);
-        this._register(127, 'ReduceProd', '', [ 'inputs' ]);
-        this._register(128, 'ReduceSum', '', [ 'inputs' ]);
-        this._register(129, 'ReduceSumSquare', '', [ 'inputs' ]);
-        this._register(140, 'Greater', '');
-        this._register(141, 'GreaterEqual', '');
-        this._register(142, 'Less', '');
-        this._register(143, 'LessEqual', '');
-        this._register(144, 'Equal', '');
-        this._register(145, 'LogicalOr', '');
-        this._register(146, 'LogicalAnd', '');
-        this._register(147, 'LogicalNot', '');
-        this._register(148, 'LogicalXor', '');
-        this._register(160, 'Pad2DReflect', '');
-        this._register(161, 'Pad2DSymmetric', '');
-        this._register(162, 'Pad2DEdge', '');
-        this._register(200, 'Flatten', 'Shape');
-        this._register(201, 'Reshape', 'Shape');
-        this._register(202, 'Transpose', '');
-        this._register(203, 'Squeeze', '');
-        this._register(204, 'Unsqueeze', '');
-        this._register(205, 'Gather', '');
-        this._register(206, 'DepthToSpace', '');
-        this._register(207, 'SpaceToDepth', '');
-        this._register(208, 'Expand', '');
-        this._register(209, 'Resample2D', '');
-        this._register(210, 'Concat', 'Tensor', [ 'inputs' ]);
-        this._register(211, 'StridedSlice', 'Shape');
-        this._register(212, 'Tile', '');
-        this._register(213, 'Shape', '');
-        this._register(214, 'NonMaxSuppression', '');
-        this._register(215, 'LSTM', '');
-        this._register(255, 'Load', '');
-    }
-
-    _register(id, name, category, inputs) {
-        this._map.set(id, { name: name, category: category, inputs: (inputs || []).map((input) => { return { name: input }; }) });
+        this._types = new Map();
+        const register = (id, name, category, inputs) => {
+            this._types.set(id, { name: name, category: category, inputs: (inputs || []).map((input) => { return { name: input }; }) });
+        };
+        register(0, 'Nop', '');
+        register(1, 'Dense', 'Layer', [ 'input', 'kernel', 'bias' ]);
+        register(2, 'MatMul', '', [ 'input', 'kernel', 'bias' ]);
+        register(20, 'Conv2D', 'Layer', [ 'input', 'kernel', 'bias' ]);
+        register(21, 'DepthwiseConv2D', 'Layer', [ 'input', 'kernel', 'bias' ]);
+        register(22, 'Conv2DTrans', '');
+        register(23, 'Upsample2D', '');
+        register(25, 'MaxPool2D', 'Pool');
+        register(26, 'AvgPool2D', 'Pool');
+        register(27, 'GlobalMaxPool2D', 'Pool');
+        register(28, 'GlobalAvgPool2D', 'Pool');
+        register(29, 'Border2D', '');
+        register(30, 'Conv3D', 'Layer');
+        register(32, 'Conv3DTrans', 'Layer');
+        register(33, 'Upsample3D', '');
+        register(35, 'MaxPool3D', 'Pool');
+        register(36, 'AvgPool3D', 'Pool');
+        register(37, 'GlobalMaxPool3D', 'Pool');
+        register(38, 'GlobalAvgPool3D', 'Pool');
+        register(39, 'Border3D', '');
+        register(50, 'Activation', '', [ 'input' ]);
+        register(51, 'ScaleBias', 'Normalization', [ 'input', 'scale', 'bias' ]);
+        register(52, 'Normalization', 'Normalization');
+        register(53, 'LRN', 'Normalization');
+        register(60, 'Dropout', 'Dropout');
+        register(64, 'RandomNormal', '');
+        register(65, 'RandomUniform', '');
+        register(66, 'Multinomial', '');
+        register(67, 'OneHot', '');
+        register(68, 'TopKIndices', '');
+        register(69, 'TopKValues', '');
+        register(100, 'Add', '', [ 'inputs' ]);
+        register(101, 'Sub', '', [ 'inputs' ]);
+        register(102, 'Mul', '', [ 'inputs' ]);
+        register(103, 'RealDiv', '', [ 'inputs' ]);
+        register(104, 'Pow', '', [ 'inputs' ]);
+        register(110, 'Minimum', '', [ 'inputs' ]);
+        register(111, 'Maximum', '', [ 'inputs' ]);
+        register(112, 'Mean', '', [ 'inputs' ]);
+        register(120, 'ReduceL1', '', [ 'inputs' ]);
+        register(121, 'ReduceL2', '', [ 'inputs' ]);
+        register(122, 'ReduceLogSum', '', [ 'inputs' ]);
+        register(123, 'ReduceLogSumExp', '', [ 'inputs' ]);
+        register(124, 'ReduceMax', '', [ 'inputs' ]);
+        register(125, 'ReduceMean', '', [ 'inputs' ]);
+        register(126, 'ReduceMin', '', [ 'inputs' ]);
+        register(127, 'ReduceProd', '', [ 'inputs' ]);
+        register(128, 'ReduceSum', '', [ 'inputs' ]);
+        register(129, 'ReduceSumSquare', '', [ 'inputs' ]);
+        register(140, 'Greater', '');
+        register(141, 'GreaterEqual', '');
+        register(142, 'Less', '');
+        register(143, 'LessEqual', '');
+        register(144, 'Equal', '');
+        register(145, 'LogicalOr', '');
+        register(146, 'LogicalAnd', '');
+        register(147, 'LogicalNot', '');
+        register(148, 'LogicalXor', '');
+        register(160, 'Pad2DReflect', '');
+        register(161, 'Pad2DSymmetric', '');
+        register(162, 'Pad2DEdge', '');
+        register(200, 'Flatten', 'Shape');
+        register(201, 'Reshape', 'Shape');
+        register(202, 'Transpose', '');
+        register(203, 'Squeeze', '');
+        register(204, 'Unsqueeze', '');
+        register(205, 'Gather', '');
+        register(206, 'DepthToSpace', '');
+        register(207, 'SpaceToDepth', '');
+        register(208, 'Expand', '');
+        register(209, 'Resample2D', '');
+        register(210, 'Concat', 'Tensor', [ 'inputs' ]);
+        register(211, 'StridedSlice', 'Shape');
+        register(212, 'Tile', '');
+        register(213, 'Shape', '');
+        register(214, 'NonMaxSuppression', '');
+        register(215, 'LSTM', '');
+        register(255, 'Load', '');
     }
 
     type(name) {
-        if (this._map.has(name)) {
-            return this._map.get(name);
+        if (!this._types.has(name)) {
+            this._types.set(name, { name: name.toString() });
         }
-        return null;
+        return this._types.get(name);
     }
 };