|
|
@@ -1235,13 +1235,19 @@ view.ModelContext = class {
|
|
|
if (!tags) {
|
|
|
tags = new Map();
|
|
|
let reset = false;
|
|
|
- const signature = [ 0x50, 0x4B, 0x03, 0x04 ];
|
|
|
- if (this.stream.length < 4 || !this.stream.peek(4).every((value, index) => value === signature[index])) {
|
|
|
+ const signatures = [
|
|
|
+ // Reject PyTorch models
|
|
|
+ [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ],
|
|
|
+ // Reject TorchScript models
|
|
|
+ [ 0x50, 0x4b ]
|
|
|
+ ];
|
|
|
+ const stream = this.stream;
|
|
|
+ if (!signatures.some((signature) => signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value))) {
|
|
|
try {
|
|
|
switch (type) {
|
|
|
case 'pbtxt': {
|
|
|
reset = true;
|
|
|
- const decoder = base.TextDecoder.create(this.stream.peek());
|
|
|
+ const decoder = base.TextDecoder.create(stream.peek());
|
|
|
let count = 0;
|
|
|
for (let i = 0; i < 0x100; i++) {
|
|
|
const c = decoder.decode();
|
|
|
@@ -1252,7 +1258,7 @@ view.ModelContext = class {
|
|
|
}
|
|
|
}
|
|
|
if (count < 4) {
|
|
|
- const reader = protobuf.TextReader.create(this.stream.peek());
|
|
|
+ const reader = protobuf.TextReader.create(stream.peek());
|
|
|
reader.start(false);
|
|
|
while (!reader.end(false)) {
|
|
|
const tag = reader.tag();
|
|
|
@@ -1274,7 +1280,7 @@ view.ModelContext = class {
|
|
|
}
|
|
|
case 'pb': {
|
|
|
reset = true;
|
|
|
- const reader = protobuf.Reader.create(this.stream.peek());
|
|
|
+ const reader = protobuf.Reader.create(stream.peek());
|
|
|
const length = reader.length;
|
|
|
while (reader.position < length) {
|
|
|
const tag = reader.uint32();
|
|
|
@@ -1371,7 +1377,7 @@ view.ModelFactoryService = class {
|
|
|
this._host = host;
|
|
|
this._extensions = [];
|
|
|
this.register('./pytorch', [ '.pt', '.pth', '.pt1', '.pyt', '.pkl', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn' ]);
|
|
|
- this.register('./onnx', [ '.onnx', '.pb', '.pbtxt', '.prototxt', '.model' ]);
|
|
|
+ this.register('./onnx', [ '.onnx', '.pb', '.pbtxt', '.prototxt', '.model', '.pt', '.pth', '.pkl' ]);
|
|
|
this.register('./mxnet', [ '.mar', '.model', '.json', '.params' ]);
|
|
|
this.register('./coreml', [ '.mlmodel' ]);
|
|
|
this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt', '.txt' ]);
|