Forráskód Böngészése

Add TensorFlow.js test file (#294)

Lutz Roeder 4 éve
szülő
commit
f73b0f93e7
2 módosított fájl, 46 hozzáadás és 11 törlés
  1. 39 11
      source/tf.js
  2. 7 0
      test/models.json

+ 39 - 11
source/tf.js

@@ -383,18 +383,39 @@ tf.ModelFactory = class {
                             let offset = 0;
                             for (const weight of manifest.weights) {
                                 const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
-                                if (!dtype_size_map.has(dtype)) {
-                                    throw new tf.Error("Unknown weight data type size '" + dtype + "'.");
-                                }
-                                const itemsize = dtype_size_map.get(dtype);
                                 const size = weight.shape.reduce((a, b) => a * b, 1);
-                                const length = itemsize * size;
-                                const tensor_content = buffer ? buffer.slice(offset, offset + length) : null;
-                                offset += length;
-                                if (nodes.has(weight.name)) {
-                                    const node = nodes.get(weight.name);
-                                    node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
-                                    node.attr.value.tensor.tensor_content = tensor_content;
+                                switch (dtype) {
+                                    case 'string': {
+                                        const data = [];
+                                        if (buffer && size > 0) {
+                                            const reader = new tf.BinaryReader(buffer.subarray(offset));
+                                            for (let i = 0; i < size; i++) {
+                                                data[i] = reader.string();
+                                            }
+                                            offset += reader.position;
+                                        }
+                                        if (nodes.has(weight.name)) {
+                                            const node = nodes.get(weight.name);
+                                            node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
+                                            node.attr.value.tensor.string_val = data;
+                                        }
+                                        break;
+                                    }
+                                    default: {
+                                        if (!dtype_size_map.has(dtype)) {
+                                            throw new tf.Error("Unknown weight data type size '" + dtype + "'.");
+                                        }
+                                        const itemsize = dtype_size_map.get(dtype);
+                                        const length = itemsize * size;
+                                        const tensor_content = buffer ? buffer.slice(offset, offset + length) : null;
+                                        offset += length;
+                                        if (nodes.has(weight.name)) {
+                                            const node = nodes.get(weight.name);
+                                            node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
+                                            node.attr.value.tensor.tensor_content = tensor_content;
+                                        }
+                                        break;
+                                    }
                                 }
                             }
                         }
@@ -1679,6 +1700,7 @@ tf.BinaryReader = class {
         this._position = 0;
         this._length = this._buffer.length;
         this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
+        this._decoder = new TextDecoder('utf-8');
     }
 
     get position() {
@@ -1739,6 +1761,12 @@ tf.BinaryReader = class {
         return this._dataView.getUint64(position, true);
     }
 
+    string() {
+        const size = this.uint32();
+        const buffer = this.read(size);
+        return this._decoder.decode(buffer);
+    }
+
     varint32() {
         return this.varint64();
     }

+ 7 - 0
test/models.json

@@ -5644,6 +5644,13 @@
     "format": "TensorFlow.js Keras v2.1.4",
     "link":   "https://github.com/tensorflow/tfjs-examples"
   },
+  {
+    "type":   "tfjs",
+    "target": "test_concat_const_string_tfjs.zip",
+    "source": "https://github.com/lutzroeder/netron/files/7207228/test_concat_const_string_tfjs.zip",
+    "format": "TensorFlow.js graph-model",
+    "link":   "https://github.com/lutzroeder/netron/issues/294"
+  },
   {
     "type":   "tfjs",
     "target": "yamnet.tar",