Parcourir la source

Update view.js

Lutz Roeder il y a 4 ans
Parent
commit
a55d59e380
12 fichiers modifiés avec 213 ajouts et 151 suppressions
  1. 1 2
      source/coreml.js
  2. 8 6
      source/dl4j.js
  3. 2 1
      source/flux.js
  4. 2 2
      source/hdf5.js
  5. 11 10
      source/json.js
  6. 1 1
      source/keras.js
  7. 1 1
      source/npz.js
  8. 48 14
      source/python.js
  9. 8 8
      source/pytorch.js
  10. 1 1
      source/rknn.js
  11. 126 104
      source/view.js
  12. 4 1
      source/zip.js

+ 1 - 2
source/coreml.js

@@ -139,8 +139,7 @@ coreml.ModelFactory = class {
                 };
                 const openManifestStream = (context, path) => {
                     return context.request(path + 'Manifest.json', null).then((stream) => {
-                        const buffer = stream.peek();
-                        const reader = json.TextReader.create(buffer);
+                        const reader = json.TextReader.open(stream);
                         const obj = reader.read();
                         return openManifest(obj, context, path);
                     });

+ 8 - 6
source/dl4j.js

@@ -26,15 +26,17 @@ dl4j.ModelFactory = class {
     }
 
     static _openContainer(entries) {
-        const configurationStream = entries.get('configuration.json');
-        const coefficientsStream = entries.get('coefficients.bin');
-        if (configurationStream) {
+        const stream = entries.get('configuration.json');
+        const coefficients = entries.get('coefficients.bin');
+        if (stream) {
             try {
-                const reader = json.TextReader.create(configurationStream.peek());
+                const reader = json.TextReader.open(stream);
                 const configuration = reader.read();
                 if (configuration && (configuration.confs || configuration.vertices)) {
-                    const coefficients = coefficientsStream ? coefficientsStream.peek() : [];
-                    return { configuration: configuration, coefficients: coefficients };
+                    return {
+                        configuration: configuration,
+                        coefficients: coefficients ? coefficients.peek() : []
+                    };
                 }
             }
             catch (error) {

+ 2 - 1
source/flux.js

@@ -20,7 +20,8 @@ flux.ModelFactory = class {
         return Promise.resolve().then(() => {
             let root = null;
             try {
-                const reader = json.BinaryReader.create(context.stream.peek());
+                const stream = context.stream;
+                const reader = json.BinaryReader.open(stream);
                 root = reader.read();
             }
             catch (error) {

+ 2 - 2
source/hdf5.js

@@ -1170,8 +1170,8 @@ hdf5.Filter = class {
     decode(data) {
         switch (this.id) {
             case 1: { // gzip
-                const buffer = data.subarray(2, data.length); // skip zlib header
-                return new zip.Inflater().inflateRaw(buffer);
+                const archive = zip.Archive.open(data);
+                return archive.entries.get('');
             }
             default:
                 throw hdf5.Error("Unsupported filter '" + this.name + "'.");

+ 11 - 10
source/json.js

@@ -5,17 +5,17 @@ var base = base || require('./base');
 
 json.TextReader = class {
 
-    constructor(buffer) {
-        this._buffer = buffer;
-        this._escape = { '"': '"', '\\': '\\', '/': '/', b: '\b', f: '\f', n: '\n', r: '\r', t: '\t' };
+    static open(data) {
+        return new json.TextReader(data);
     }
 
-    static create(buffer) {
-        return new json.TextReader(buffer);
+    constructor(data) {
+        this._data = data;
+        this._escape = { '"': '"', '\\': '\\', '/': '/', b: '\b', f: '\f', n: '\n', r: '\r', t: '\t' };
     }
 
     read() {
-        const decoder = base.TextDecoder.open(this._buffer);
+        const decoder = base.TextDecoder.open(this._data);
         const stack = [];
         this._decoder = decoder;
         this._position = 0;
@@ -357,12 +357,13 @@ json.TextReader = class {
 
 json.BinaryReader = class {
 
-    constructor(buffer) {
-        this._buffer = buffer;
+    static open(data) {
+        const buffer = data instanceof Uint8Array ? data : data.peek();
+        return new json.BinaryReader(buffer);
     }
 
-    static create(buffer) {
-        return new json.BinaryReader(buffer);
+    constructor(buffer) {
+        this._buffer = buffer;
     }
 
     read() {

+ 1 - 1
source/keras.js

@@ -51,7 +51,7 @@ keras.ModelFactory = class {
                     if (rootGroup.attribute('model_config') || rootGroup.attribute('layer_names')) {
                         const model_config_json = rootGroup.attribute('model_config');
                         if (model_config_json) {
-                            const reader = json.TextReader.create(model_config_json);
+                            const reader = json.TextReader.open(model_config_json);
                             model_config = reader.read();
                         }
                         backend = rootGroup.attribute('backend') || '';

+ 1 - 1
source/npz.js

@@ -72,7 +72,7 @@ npz.ModelFactory = class {
                             if (array.dataType !== 'O') {
                                 throw new npz.Error("Invalid data type '" + array.dataType + "'.");
                             }
-                            const unpickler = new python.Unpickler(array.data);
+                            const unpickler = python.Unpickler.open(array.data);
                             const root = unpickler.load((name, args) => execution.invoke(name, args));
                             array = { dataType: root.dtype.name, shape: null, data: null, byteOrder: '|' };
                         }

+ 48 - 14
source/python.js

@@ -2081,12 +2081,12 @@ python.Execution = class {
         });
         this.registerType('spacy._ml.PrecomputableAffine', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('spacy.syntax._parser_model.ParserModel', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.describe.Biases', class {
@@ -2116,52 +2116,52 @@ python.Execution = class {
         });
         this.registerType('thinc.neural._classes.affine.Affine', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural._classes.convolution.ExtractWindow', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural._classes.feature_extracter.FeatureExtracter', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural._classes.feed_forward.FeedForward', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural._classes.function_layer.FunctionLayer', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural._classes.hash_embed.HashEmbed', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural._classes.layernorm.LayerNorm', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural._classes.maxout.Maxout', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural._classes.resnet.Residual', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural._classes.softmax.Softmax', class {
             __setstate__(state) {
-                Object.assign(this, new python.Unpickler(state).load((name, args) => self.invoke(name, args), null));
+                Object.assign(this, python.Unpickler.open(state).load((name, args) => self.invoke(name, args), null));
             }
         });
         this.registerType('thinc.neural.mem.Memory', class {
@@ -2958,8 +2958,25 @@ python.Utility = class {
 
 python.Unpickler = class {
 
-    constructor(buffer) {
-        this._reader = buffer instanceof Uint8Array ? new python.Unpickler.BinaryReader(buffer) : new python.Unpickler.StreamReader(buffer);
+    static open(data) {
+        const reader = data instanceof Uint8Array ? new python.Unpickler.BinaryReader(data) : new python.Unpickler.StreamReader(data);
+        if (reader.length > 2) {
+            const head = reader.peek(2);
+            if (head[0] === 0x80 && head[1] < 7) {
+                return new python.Unpickler(reader);
+            }
+            reader.seek(-1);
+            const tail = reader.peek(1);
+            reader.seek(0);
+            if (tail[0] === 0x2e) {
+                return new python.Unpickler(reader);
+            }
+        }
+        return null;
+    }
+
+    constructor(reader) {
+        this._reader = reader;
     }
 
     load(function_call, persistent_load) {
@@ -3464,6 +3481,13 @@ python.Unpickler.BinaryReader = class {
         return this._length;
     }
 
+    seek(position) {
+        this._position = position >= 0 ? position : this._length + position;
+        if (this._position > this._buffer.length) {
+            throw new Error('Expected ' + (this._position - this._buffer.length) + ' more bytes. The file might be corrupted. Unexpected end of file.');
+        }
+    }
+
     skip(offset) {
         this._position += offset;
         if (this._position > this._buffer.length) {
@@ -3577,6 +3601,11 @@ python.Unpickler.StreamReader = class {
         return this._length;
     }
 
+    seek(position) {
+        this._stream.seek(position);
+        this._position = this._stream.position;
+    }
+
     skip(offset) {
         this._position += offset;
         if (this._position > this._length) {
@@ -3590,6 +3619,11 @@ python.Unpickler.StreamReader = class {
         return this._stream.stream(length);
     }
 
+    peek(length) {
+        this._stream.seek(this._position);
+        return this._stream.peek(length);
+    }
+
     read(length) {
         this._stream.seek(this._position);
         this.skip(length);

+ 8 - 8
source/pytorch.js

@@ -1892,7 +1892,7 @@ pytorch.Execution = class extends python.Execution {
         const buffer = this.source(file + '.debug_pkl');
         if (buffer) {
             return null;
-            // const unpickler = new python.Unpickler(buffer);
+            // const unpickler = python.Unpickler.open(buffer);
             // return unpickler.load((name, args) => this.invoke(name, args), null);
         }
         return null;
@@ -1969,7 +1969,7 @@ pytorch.Container.Tar = class {
         this._entries = null;
 
         if (entries.sys_info) {
-            const unpickler = new python.Unpickler(entries.sys_info);
+            const unpickler = python.Unpickler.open(entries.sys_info);
             const sys_info = unpickler.load((name, args) => execution.invoke(name, args));
             if (sys_info.protocol_version != 1000) {
                 throw new pytorch.Error("Unsupported protocol version '" + sys_info.protocol_version + "'.");
@@ -1985,7 +1985,7 @@ pytorch.Container.Tar = class {
 
         const deserialized_objects = {};
         if (entries.storages) {
-            const unpickler = new python.Unpickler(entries.storages);
+            const unpickler = python.Unpickler.open(entries.storages);
             const num_storages = unpickler.load((name, args) => execution.invoke(name, args));
             for (let i = 0; i < num_storages; i++) {
                 const args = unpickler.load();
@@ -2003,7 +2003,7 @@ pytorch.Container.Tar = class {
         }
 
         if (entries.tensors) {
-            const unpickler = new python.Unpickler(entries.tensors);
+            const unpickler = python.Unpickler.open(entries.tensors);
             const num_tensors = unpickler.load((name, args) => execution.invoke(name, args));
             for (let i = 0; i < num_tensors; i++) {
                 const args = unpickler.load();
@@ -2027,7 +2027,7 @@ pytorch.Container.Tar = class {
         }
 
         if (entries.pickle) {
-            const unpickler = new python.Unpickler(entries.pickle);
+            const unpickler = python.Unpickler.open(entries.pickle);
             const persistent_load = (saved_id) => {
                 return deserialized_objects[saved_id];
             };
@@ -2076,7 +2076,7 @@ pytorch.Container.Pickle = class {
         }
 
         const execution = new pytorch.Execution(null, this._exceptionCallback);
-        const unpickler = new python.Unpickler(this._stream.length < 0x7ffff000 ? this._stream.peek() : this._stream);
+        const unpickler = python.Unpickler.open(this._stream.length < 0x7ffff000 ? this._stream.peek() : this._stream);
 
         this._stream = null;
         this._exceptionCallback = null;
@@ -2372,7 +2372,7 @@ pytorch.Container.Zip = class {
                     const stream = this._entry('attributes.pkl');
                     if (stream) {
                         const buffer = stream.peek();
-                        const unpickler = new python.Unpickler(buffer);
+                        const unpickler = python.Unpickler.open(buffer);
                         this._attributes.push(...unpickler.load((name, args) => this.execution.invoke(name, args)));
                     }
                     while (queue.length > 0) {
@@ -2483,7 +2483,7 @@ pytorch.Container.Zip = class {
             }
             return storage;
         };
-        return new python.Unpickler(data).load((name, args) => this.execution.invoke(name, args), persistent_load);
+        return python.Unpickler.open(data).load((name, args) => this.execution.invoke(name, args), persistent_load);
     }
 
     _storage(dirname) {

+ 1 - 1
source/rknn.js

@@ -486,7 +486,7 @@ rknn.Container = class {
             this._version = this._reader.uint64();
             this._weights = this._reader.read();
             const buffer = this._reader.read();
-            const reader = json.TextReader.create(buffer);
+            const reader = json.TextReader.open(buffer);
             this._model = reader.read();
             delete this._reader;
         }

+ 126 - 104
source/view.js

@@ -1220,146 +1220,168 @@ view.ModelContext = class {
     open(type) {
         if (!this._content.has(type)) {
             this._content.set(type, undefined);
-            let reset = false;
-            switch (type) {
-                case 'json': {
-                    try {
-                        reset = true;
-                        const buffer = this.stream.peek();
-                        const reader = json.TextReader.create(buffer);
-                        const obj = reader.read();
-                        this._content.set(type, obj);
-                    }
-                    catch (err) {
-                        // continue regardless of error
+            const stream = this.stream;
+            const position = stream.position;
+            const skip =
+                Array.from(this._tags.values()).some((map) => map.size > 0) ||
+                Array.from(this._content.values()).some((obj) => obj !== undefined);
+            if (!skip) {
+                switch (type) {
+                    case 'json': {
+                        try {
+                            const reader = json.TextReader.open(stream);
+                            const obj = reader.read();
+                            this._content.set(type, obj);
+                        }
+                        catch (err) {
+                            // continue regardless of error
+                        }
+                        break;
                     }
-                    break;
-                }
-                case 'pkl': {
-                    try {
-                        if (this.stream.length > 2) {
-                            const stream = this.stream.peek(1)[0] === 0x78 ? zip.Archive.open(this.stream).entries.values().next().value : this.stream;
-                            const match = (stream) => {
-                                const head = stream.peek(2);
-                                if (head[0] === 0x80 && head[1] < 7) {
-                                    return true;
-                                }
-                                stream.seek(-1);
-                                const tail = stream.peek(1);
-                                stream.seek(0);
-                                if (tail[0] === 0x2e) {
-                                    return true;
+                    case 'pkl': {
+                        try {
+                            if (stream.length > 2) {
+                                const zlib = (stream) => {
+                                    const buffer = stream.peek(2);
+                                    if (buffer[0] === 0x78) {
+                                        const check = (buffer[0] << 8) + buffer[1];
+                                        if (check % 31 === 0) {
+                                            const archive = zip.Archive.open(stream);
+                                            return archive.entries.get('');
+                                        }
+                                    }
+                                    return stream;
+                                };
+                                const unpickler = python.Unpickler.open(zlib(stream));
+                                if (unpickler) {
+                                    const execution = new python.Execution(null, (error, fatal) => {
+                                        const message = error && error.message ? error.message : error.toString();
+                                        this.exception(new view.Error(message.replace(/\.$/, '') + " in '" + this.identifier + "'."), fatal);
+                                    });
+                                    const obj = unpickler.load((name, args) => execution.invoke(name, args));
+                                    this._content.set(type, obj);
                                 }
-                                return false;
-                            };
-                            if (match(stream)) {
-                                reset = true;
-                                const unpickler = new python.Unpickler(stream);
-                                const execution = new python.Execution(null, (error, fatal) => {
-                                    const message = error && error.message ? error.message : error.toString();
-                                    this.exception(new view.Error(message.replace(/\.$/, '') + " in '" + this.identifier + "'."), fatal);
-                                });
-                                const obj = unpickler.load((name, args) => execution.invoke(name, args));
-                                this._content.set(type, obj);
                             }
                         }
+                        catch (err) {
+                            // continue regardless of error
+                        }
+                        break;
                     }
-                    catch (err) {
-                        // continue regardless of error
-                    }
-                    break;
                 }
             }
-            if (reset) {
-                this.stream.seek(0);
+            if (stream.position !== position) {
+                stream.seek(0);
             }
         }
         return this._content.get(type);
     }
 
     tags(type) {
-        let tags = this._tags.get(type);
-        if (!tags) {
-            tags = new Map();
+        if (!this._tags.has(type)) {
+            const tags = new Map();
             const stream = this.stream;
+            const position = stream.position;
             if (stream) {
                 const signatures = [
-                    // Reject PyTorch models
-                    [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ],
-                    [ 0x50, 0x4b ]
+                    [ 0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ], // PyTorch
+                    [ 0x50, 0x4b ], // Zip
+                    [ 0x1f, 0x8b ] // Gzip
                 ];
-                if (!signatures.some((signature) => signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value))) {
+                const skip =
+                    signatures.some((signature) => signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) ||
+                    Array.from(this._tags.values()).some((map) => map.size > 0) ||
+                    Array.from(this._content.values()).some((obj) => obj !== undefined);
+                if (!skip) {
+                    const detectTextProto = (stream) => {
+                        const decoder = base.TextDecoder.open(stream);
+                        for (let i = 0; i < 0x100; i++) {
+                            const c = decoder.decode();
+                            if (c === undefined || c === '\0') {
+                                break;
+                            }
+                            if (c < ' ' && c !== '\n' && c !== '\r' && c !== '\t') {
+                                return false;
+                            }
+                        }
+                        return true;
+                    };
+                    const decodeTextProto = (stream, tags) => {
+                        const reader = protobuf.TextReader.open(stream);
+                        reader.start(false);
+                        while (!reader.end(false)) {
+                            const tag = reader.tag();
+                            tags.set(tag, true);
+                            if (reader.token() === '{') {
+                                reader.start();
+                                while (!reader.end()) {
+                                    const subtag = reader.tag();
+                                    tags.set(tag + '.' + subtag, true);
+                                    reader.skip();
+                                    reader.match(',');
+                                }
+                            }
+                            else {
+                                reader.skip();
+                            }
+                        }
+                    };
+                    const detectBinaryProto = (stream) => {
+                        const buffer = stream.peek(1);
+                        const type = buffer[0] & 7;
+                        if (type === 4 || type === 6 || type === 7) {
+                            return false;
+                        }
+                        return true;
+                    };
+                    const decodeBinaryProto = (stream, tags) => {
+                        const reader = protobuf.BinaryReader.open(stream);
+                        const length = reader.length;
+                        while (reader.position < length) {
+                            const tag = reader.uint32();
+                            const field = tag >>> 3;
+                            const type = tag & 7;
+                            if (type > 5 || field === 0) {
+                                tags.clear();
+                                break;
+                            }
+                            tags.set(field, type);
+                            try {
+                                reader.skipType(type);
+                            }
+                            catch (err) {
+                                tags.clear();
+                                break;
+                            }
+                        }
+                    };
                     try {
                         switch (type) {
                             case 'pbtxt': {
-                                const buffer = stream.peek();
-                                const decoder = base.TextDecoder.open(buffer);
-                                let count = 0;
-                                for (let i = 0; i < 0x100; i++) {
-                                    const c = decoder.decode();
-                                    switch (c) {
-                                        case '\n': case '\r': case '\t': case '\0': break;
-                                        case undefined: i = 0x100; break;
-                                        default: count += c < ' ' ? 1 : 0; break;
-                                    }
-                                }
-                                if (count < 4) {
-                                    const reader = protobuf.TextReader.open(stream);
-                                    reader.start(false);
-                                    while (!reader.end(false)) {
-                                        const tag = reader.tag();
-                                        tags.set(tag, true);
-                                        if (reader.token() === '{') {
-                                            reader.start();
-                                            while (!reader.end()) {
-                                                const subtag = reader.tag();
-                                                tags.set(tag + '.' + subtag, true);
-                                                reader.skip();
-                                                reader.match(',');
-                                            }
-                                        }
-                                        else {
-                                            reader.skip();
-                                        }
-                                    }
+                                if (detectTextProto(stream)) {
+                                    decodeTextProto(stream, tags);
                                 }
                                 break;
                             }
                             case 'pb': {
-                                const reader = protobuf.BinaryReader.open(stream);
-                                const length = reader.length;
-                                while (reader.position < length) {
-                                    const tag = reader.uint32();
-                                    const number = tag >>> 3;
-                                    const type = tag & 7;
-                                    if (type > 5 || number === 0) {
-                                        tags = new Map();
-                                        break;
-                                    }
-                                    tags.set(number, type);
-                                    try {
-                                        reader.skipType(type);
-                                    }
-                                    catch (err) {
-                                        tags = new Map();
-                                        break;
-                                    }
+                                if (detectBinaryProto(stream)) {
+                                    decodeBinaryProto(stream, tags);
                                 }
                                 break;
                             }
                         }
                     }
                     catch (error) {
-                        tags = new Map();
+                        tags.clear();
                     }
                 }
             }
-            if (stream.position !== 0) {
-                stream.seek(0);
+            if (stream.position !== position) {
+                stream.seek(position);
             }
             this._tags.set(type, tags);
         }
-        return tags;
+        return this._tags.get(type);
     }
 };
 

+ 4 - 1
source/zip.js

@@ -10,7 +10,10 @@ zip.Archive = class {
         if (stream.length > 2) {
             const buffer = stream.peek(2);
             if (buffer[0] === 0x78) { // zlib
-                return new zlib.Archive(stream);
+                const check = (buffer[0] << 8) + buffer[1];
+                if (check % 31 === 0) {
+                    return new zlib.Archive(stream);
+                }
             }
             const signature = buffer[0] === 0x50 && buffer[1] === 0x4B;
             const position = stream.position;