Răsfoiți Sursa

CoreML biDirectionalLSTM weights support

Lutz Roeder, acum 8 ani
părinte
comite
cb6845d333
2 fișiere modificate, cu 57 de adăugiri și 2 ștergeri
  1. 12 1
      src/coreml-model.js
  2. 45 1
      src/coreml-operator.json

+ 12 - 1
src/coreml-model.js

@@ -386,11 +386,22 @@ class CoreMLNode {
             case 'uniDirectionalLSTM':
                 if (value instanceof coreml.LSTMWeightParams) {
                     Object.keys(value).forEach((key) => {
-                        this._initializers.push(new CoreMLTensor(key, value));
+                        this._initializers.push(new CoreMLTensor(key, value[key]));
                     });
                     return;
                 }
                 break;
+            case 'biDirectionalLSTM':
+                if (name == 'weightParams' && value.length == 2) {
+                    Object.keys(value[0]).forEach((key) => {
+                        this._initializers.push(new CoreMLTensor(key, value[0][key]));
+                    });
+                    Object.keys(value[1]).forEach((key) => {
+                        this._initializers.push(new CoreMLTensor(key + '_rev', value[1][key]));
+                    });
+                    return;
+                }    
+                break;
         }
 
         this._attributes.push(new CoreMLAttribute(this, name, value));

+ 45 - 1
src/coreml-operator.json

@@ -84,7 +84,45 @@
     "name": "biDirectionalLSTM",
     "schema": {
       "category": "Layer",
-      "description": "Bidirectional long short-term memory (LSTM) layer. The first LSTM operates on the input sequence in the forward direction. The second LSTM operates on the input sequence in the reverse direction."
+      "description": "Bidirectional long short-term memory (LSTM) layer. The first LSTM operates on the input sequence in the forward direction. The second LSTM operates on the input sequence in the reverse direction.",
+      "inputs": [
+        { "name": "input" },
+        { "name": "h" },
+        { "name": "c" },
+        { "name": "h_rev" },
+        { "name": "c_rev" },
+        { "name": "inputGateWeightMatrix", "hidden": true },
+        { "name": "forgetGateWeightMatrix", "hidden": true },
+        { "name": "blockInputWeightMatrix", "hidden": true },
+        { "name": "outputGateWeightMatrix", "hidden": true },
+        { "name": "inputGateRecursionMatrix", "hidden": true },
+        { "name": "forgetGateRecursionMatrix", "hidden": true },
+        { "name": "blockInputRecursionMatrix", "hidden": true },
+        { "name": "outputGateRecursionMatrix", "hidden": true },
+        { "name": "inputGateBiasVector", "hidden": true },
+        { "name": "forgetGateBiasVector", "hidden": true },
+        { "name": "blockInputBiasVector", "hidden": true },
+        { "name": "outputGateBiasVector", "hidden": true },
+        { "name": "inputGateWeightMatrix_rev", "hidden": true },
+        { "name": "forgetGateWeightMatrix_rev", "hidden": true },
+        { "name": "blockInputWeightMatrix_rev", "hidden": true },
+        { "name": "outputGateWeightMatrix_rev", "hidden": true },
+        { "name": "inputGateRecursionMatrix_rev", "hidden": true },
+        { "name": "forgetGateRecursionMatrix_rev", "hidden": true },
+        { "name": "blockInputRecursionMatrix_rev", "hidden": true },
+        { "name": "outputGateRecursionMatrix_rev", "hidden": true },
+        { "name": "inputGateBiasVector_rev", "hidden": true },
+        { "name": "forgetGateBiasVector_rev", "hidden": true },
+        { "name": "blockInputBiasVector_rev", "hidden": true },
+        { "name": "outputGateBiasVector_rev", "hidden": true }
+      ],
+      "outputs": [
+        { "name": "output" },
+        { "name": "h" },
+        { "name": "c" },
+        { "name": "h_rev" },
+        { "name": "c_rev" }
+      ]
     }
   },
   {
@@ -167,6 +205,12 @@
       ]
     }
   },
+  {
+    "name": "sequenceRepeat",
+    "schema": {
+      "category": "Shape"
+    }    
+  },
   {
     "name": "concat",
     "schema": {