Lutz Roeder 4 лет назад
Родитель
Сommit
1f308335a9
4 измененных файлов с 68 добавлено и 25 удалено
  1. 6 0
      source/electron.js
  2. 54 23
      source/rknn.js
  3. 2 2
      source/view.js
  4. 6 0
      test/models.js

+ 6 - 0
source/electron.js

@@ -524,10 +524,16 @@ host.ElectronHost.BinaryStream = class {
 
     seek(position) {
         this._position = position >= 0 ? position : this._length + position;
+        if (this._position > this._buffer.length) {
+            throw new Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        }
     }
 
     skip(offset) {
         this._position += offset;
+        if (this._position > this._buffer.length) {
+            throw new Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        }
     }
 
     peek(length) {

+ 54 - 23
source/rknn.js

@@ -16,8 +16,8 @@ rknn.ModelFactory = class {
 
     open(context) {
         return rknn.Metadata.open(context).then((metadata) => {
-            const buffer = context.stream.peek();
-            const container = rknn.Container.open(buffer);
+            const stream = context.stream;
+            const container = rknn.Container.open(stream);
             return new rknn.Model(metadata, container.model, container.weights);
         });
     }
@@ -468,41 +468,72 @@ rknn.TensorShape = class {
 
 rknn.Container = class {
 
-    static open(buffer) {
-        if (buffer && buffer.length > 4 && [ 0x52, 0x4B, 0x4E, 0x4E, 0x00, 0x00, 0x00, 0x00 ].every((value, index) => buffer[index] === value)) {
-            return new rknn.Container(buffer);
+    static open(stream) {
+        const signature = [ 0x52, 0x4B, 0x4E, 0x4E, 0x00, 0x00, 0x00, 0x00 ];
+        if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
+            return new rknn.Container(stream);
         }
         return null;
     }
 
-    constructor(buffer) {
-        this._buffer = buffer;
-        const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
-        this._version = view.getUint64(8, true).toNumber();
-        let position = 16;
-        this._blocks = [];
-        while (position < buffer.length) {
-            const size = view.getUint64(position, true).toNumber();
-            position += 8;
-            this._blocks.push({ start: position, end: position + size });
-            position += size;
-        }
+    constructor(stream) {
+        this._reader = new rknn.Container.StreamReader(stream);
     }
 
     get version() {
+        this._read();
         return this._version;
     }
 
     get weights() {
-        const block = this._blocks[0];
-        return this._buffer.subarray(block.start, block.end);
+        this._read();
+        return this._weights;
     }
 
     get model() {
-        const block = this._blocks[1];
-        const buffer = this._buffer.subarray(block.start, block.end);
-        const reader = json.TextReader.create(buffer);
-        return reader.read();
+        this._read();
+        return this._model;
+    }
+
+    _read() {
+        if (this._reader) {
+            this._reader.uint64();
+            this._version = this._reader.uint64();
+            this._weights = this._reader.read();
+            const buffer = this._reader.read();
+            const reader = json.TextReader.create(buffer);
+            this._model = reader.read();
+            delete this._reader;
+        }
+    }
+};
+
+rknn.Container.StreamReader = class {
+
+    constructor(stream) {
+        this._stream = stream;
+        this._length = stream.length;
+        this._position = 0;
+    }
+
+    skip(offset) {
+        this._position += offset;
+        if (this._position > this._length) {
+            throw new rknn.Error('Expected ' + (this._position - this._length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        }
+    }
+
+    uint64() {
+        this.skip(8);
+        const buffer = this._stream.read(8);
+        const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
+        return view.getUint64(0, true).toNumber();
+    }
+
+    read() {
+        const size = this.uint64();
+        this.skip(size);
+        return this._stream.read(size);
     }
 };
 

+ 2 - 2
source/view.js

@@ -1398,7 +1398,7 @@ view.ModelFactoryService = class {
         this.register('./bigdl', [ '.model', '.bigdl' ]);
         this.register('./darknet', [ '.cfg', '.model', '.txt', '.weights' ]);
         this.register('./weka', [ '.model' ]);
-        this.register('./rknn', [ '.rknn' ]);
+        this.register('./rknn', [ '.rknn', '.onnx' ]);
         this.register('./dlc', [ '.dlc' ]);
         this.register('./keras', [ '.h5', '.hd5', '.hdf5', '.keras', '.json', '.cfg', '.model', '.pb', '.pth', '.weights', '.pkl', '.lite', '.tflite', '.ckpt' ]);
         this.register('./armnn', [ '.armnn', '.json' ]);
@@ -1532,7 +1532,7 @@ view.ModelFactoryService = class {
         stream.seek(0);
         const buffer = stream.peek(Math.min(16, stream.length));
         const bytes = Array.from(buffer).map((c) => (c < 16 ? '0' : '') + c.toString(16)).join('');
-        const content = buffer.length > 268435456 ? '(' + bytes + ') [' + stream.length.toString() + ']': '(' + bytes + ')';
+        const content = stream.length > 268435456 ? '(' + bytes + ') [' + stream.length.toString() + ']': '(' + bytes + ')';
         throw new view.Error("Unsupported file content " + content + " for extension '." + extension + "' in '" + identifier + "'.", !skip);
     }
 

+ 6 - 0
test/models.js

@@ -154,10 +154,16 @@ class TestBinaryStream {
 
     seek(position) {
         this._position = position >= 0 ? position : this._length + position;
+        if (this._position > this._buffer.length) {
+            throw new Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        }
     }
 
     skip(offset) {
         this._position += offset;
+        if (this._position > this._buffer.length) {
+            throw new Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        }
     }
 
     peek(length) {