Lutz Roeder 3 éve
szülő
commit
79d1c438ad
2 módosított fájl, 41 hozzáadás és 5 törlés
  1. 22 1
      source/hdf5.js
  2. 19 4
      source/keras.js

+ 22 - 1
source/hdf5.js

@@ -210,7 +210,7 @@ hdf5.Variable = class {
         switch (this._dataLayout.layoutClass) {
             case 1: // Contiguous
                 if (this._dataLayout.address) {
-                    return this._reader.at(this._dataLayout.address).read(this._dataLayout.size);
+                    return this._reader.at(this._dataLayout.address).stream(this._dataLayout.size);
                 }
                 break;
             case 2: { // Chunked
@@ -481,11 +481,26 @@ hdf5.BinaryReader = class extends hdf5.Reader {
         }
     }
 
+    peek(length) {
+        const position = this._offset + this._position;
+        length = length !== undefined ? length : this._buffer.length - position;
+        this.take(length);
+        const buffer = this._buffer.subarray(position, position + length);
+        this._position = position - this._offset;
+        return buffer;
+    }
+
     read(length) {
         const position = this.take(length);
         return this._buffer.subarray(position, position + length);
     }
 
+    stream(length) {
+        const position = this.take(length);
+        const buffer = this._buffer.subarray(position, position + length);
+        return new hdf5.BinaryReader(buffer);
+    }
+
     size(terminator) {
         let position = this._offset + this._position;
         while (this._buffer[position] !== terminator) {
@@ -540,6 +555,12 @@ hdf5.StreamReader = class extends hdf5.Reader {
         return this._stream.read(length);
     }
 
+    stream(length) {
+        this._stream.seek(this._offset + this._position);
+        this.skip(length);
+        return this._stream.stream(length);
+    }
+
     byte() {
         const position = this.take(1);
         return this._view.getUint8(position);

+ 19 - 4
source/keras.js

@@ -1096,26 +1096,37 @@ keras.Tensor = class {
             context.state = 'Tensor data is empty.';
             return context;
         }
+
+        try {
+            context.data = this._data instanceof Uint8Array ? this._data : this._data.peek();
+        }
+        catch (err) {
+            context.state = err.message;
+            return context;
+        }
+
         switch (this._type.dataType) {
             case 'boolean':
             case 'float16':
             case 'float32':
             case 'float64':
             case 'uint8':
+            case 'int8':
             case 'int32':
-            case 'int64':
+            case 'int64': {
                 context.dataType = this._type.dataType;
-                context.view = new DataView(this._data.buffer, this._data.byteOffset, this._data.byteLength);
+                context.view = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
                 context.littleEndian = this._littleEndian;
                 break;
+            }
             case 'string':
                 context.dataType = this._type.dataType;
-                context.data = this._data;
                 break;
             default:
-                context.state = 'Tensor data type is not supported.';
+                context.state = "Tensor data type '" + this._type.dataType + "' is not supported.";
                 break;
         }
+
         context.shape = this._type.shape.dimensions;
         return context;
     }
@@ -1152,6 +1163,10 @@ keras.Tensor = class {
                         results.push(context.view.getUint8(context.index));
                         context.index += 1;
                         break;
+                    case 'int8':
+                        results.push(context.view.getInt8(context.index));
+                        context.index += 1;
+                        break;
                     case 'int32':
                         results.push(context.view.getInt32(context.index, littleEndian));
                         context.index += 4;