|
|
@@ -383,18 +383,39 @@ tf.ModelFactory = class {
|
|
|
let offset = 0;
|
|
|
for (const weight of manifest.weights) {
|
|
|
const dtype = weight.quantization && weight.quantization.dtype ? weight.quantization.dtype : weight.dtype;
|
|
|
- if (!dtype_size_map.has(dtype)) {
|
|
|
- throw new tf.Error("Unknown weight data type size '" + dtype + "'.");
|
|
|
- }
|
|
|
- const itemsize = dtype_size_map.get(dtype);
|
|
|
const size = weight.shape.reduce((a, b) => a * b, 1);
|
|
|
- const length = itemsize * size;
|
|
|
- const tensor_content = buffer ? buffer.slice(offset, offset + length) : null;
|
|
|
- offset += length;
|
|
|
- if (nodes.has(weight.name)) {
|
|
|
- const node = nodes.get(weight.name);
|
|
|
- node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
|
|
|
- node.attr.value.tensor.tensor_content = tensor_content;
|
|
|
+ switch (dtype) {
|
|
|
+ case 'string': {
|
|
|
+ const data = [];
|
|
|
+ if (buffer && size > 0) {
|
|
|
+ const reader = new tf.BinaryReader(buffer.subarray(offset));
|
|
|
+ for (let i = 0; i < size; i++) {
|
|
|
+ data[i] = reader.string();
|
|
|
+ }
|
|
|
+ offset += reader.position;
|
|
|
+ }
|
|
|
+ if (nodes.has(weight.name)) {
|
|
|
+ const node = nodes.get(weight.name);
|
|
|
+ node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
|
|
|
+ node.attr.value.tensor.string_val = data;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ default: {
|
|
|
+ if (!dtype_size_map.has(dtype)) {
|
|
|
+ throw new tf.Error("Unknown weight data type size '" + dtype + "'.");
|
|
|
+ }
|
|
|
+ const itemsize = dtype_size_map.get(dtype);
|
|
|
+ const length = itemsize * size;
|
|
|
+ const tensor_content = buffer ? buffer.slice(offset, offset + length) : null;
|
|
|
+ offset += length;
|
|
|
+ if (nodes.has(weight.name)) {
|
|
|
+ const node = nodes.get(weight.name);
|
|
|
+ node.attr.value.tensor.dtype = tf.Utility.dataTypeKey(dtype);
|
|
|
+ node.attr.value.tensor.tensor_content = tensor_content;
|
|
|
+ }
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -1679,6 +1700,7 @@ tf.BinaryReader = class {
|
|
|
this._position = 0;
|
|
|
this._length = this._buffer.length;
|
|
|
this._dataView = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
|
|
|
+ this._decoder = new TextDecoder('utf-8');
|
|
|
}
|
|
|
|
|
|
get position() {
|
|
|
@@ -1739,6 +1761,12 @@ tf.BinaryReader = class {
|
|
|
return this._dataView.getUint64(position, true);
|
|
|
}
|
|
|
|
|
|
+ string() {
|
|
|
+ const size = this.uint32();
|
|
|
+ const buffer = this.read(size);
|
|
|
+ return this._decoder.decode(buffer);
|
|
|
+ }
|
|
|
+
|
|
|
varint32() {
|
|
|
return this.varint64();
|
|
|
}
|