Browse Source

Update gzip support (#249)

Lutz Roeder 4 years ago
parent
commit
781510f9ab
3 changed files with 85 additions and 86 deletions
  1. 2 1
      source/gzip.js
  2. 74 83
      source/view.js
  3. 9 2
      test/models.json

+ 2 - 1
source/gzip.js

@@ -15,8 +15,9 @@ gzip.Archive = class {
     }
 
     constructor(stream) {
+        const position = stream.position;
         this._entries = [ new gzip.Entry(stream) ];
-        stream.seek(0);
+        stream.seek(position);
     }
 
     get entries() {

+ 74 - 83
source/view.js

@@ -1501,9 +1501,7 @@ view.ModelFactoryService = class {
     _unsupported(context) {
         const identifier = context.identifier;
         const extension = identifier.split('.').pop().toLowerCase();
-        for (const format of new Map([ [ 'Zip', zip ], [ 'tar', tar ] ])) {
-            const name = format[0];
-            const module = format[1];
+        for (const module of [ zip, tar, gzip ]) {
             let archive = null;
             try {
                 archive = module.Archive.open(context.stream);
@@ -1512,7 +1510,7 @@ view.ModelFactoryService = class {
                 // continue regardless of error
             }
             if (archive) {
-                throw new view.Error("Invalid file content. File contains " + name + " archive in '" + identifier + "'.", true);
+                throw new view.Error("Archive contains no model files in '" + identifier + "'.", true);
             }
         }
         const knownUnsupportedIdentifiers = new Set([
@@ -1603,30 +1601,13 @@ view.ModelFactoryService = class {
     _openArchive(context) {
         const entries = new Map();
         let stream = context.stream;
-        let extension;
-        let identifier = context.identifier;
+        const identifier = context.identifier;
         try {
-            extension = identifier.split('.').pop().toLowerCase();
-            const gzipArchive = gzip.Archive.open(stream);
-            if (gzipArchive) {
-                const entries = gzipArchive.entries;
-                if (entries.length === 1) {
-                    const entry = entries[0];
-                    if (entry.name) {
-                        identifier = entry.name;
-                    }
-                    else {
-                        identifier = identifier.substring(0, identifier.lastIndexOf('.'));
-                        switch (extension) {
-                            case 'tgz':
-                            case 'tar': {
-                                if (identifier.split('.').pop().toLowerCase() !== 'tar') {
-                                    identifier += '.tar';
-                                }
-                                break;
-                            }
-                        }
-                    }
+            const archive = gzip.Archive.open(stream);
+            if (archive) {
+                entries.set('gzip', archive.entries);
+                if (archive.entries.length === 1) {
+                    const entry = archive.entries[0];
                     stream = entry.stream;
                 }
             }
@@ -1635,7 +1616,6 @@ view.ModelFactoryService = class {
             const message = error && error.message ? error.message : error.toString();
             throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'.");
         }
-
         try {
             const formats = new Map([ [ 'zip', zip ], [ 'tar', tar ] ]);
             for (const pair of formats) {
@@ -1644,6 +1624,7 @@ view.ModelFactoryService = class {
                 const archive = module.Archive.open(stream);
                 if (archive) {
                     entries.set(format, archive.entries);
+                    entries.delete('gzip');
                     break;
                 }
             }
@@ -1706,6 +1687,64 @@ view.ModelFactoryService = class {
                 const folder = rotate(map).filter(equals).map(at(0)).join('/');
                 return folder.length === 0 ? folder : folder + '/';
             };
+            const filter = (queue) => {
+                let matches = [];
+                const nextEntry = () => {
+                    if (queue.length > 0) {
+                        const entry = queue.shift();
+                        const context = new view.ModelContext(new view.ArchiveContext(this._host, null, folder, entry.name, entry.stream));
+                        let modules = this._filter(context);
+                        const nextModule = () => {
+                            if (modules.length > 0) {
+                                const id = modules.shift();
+                                return this._host.require(id).then((module) => {
+                                    if (!module.ModelFactory) {
+                                        throw new view.ArchiveError("Failed to load module '" + id + "'.", null);
+                                    }
+                                    const factory = new module.ModelFactory();
+                                    if (factory.match(context)) {
+                                        matches.push(entry);
+                                        modules = [];
+                                    }
+                                    return nextModule();
+                                });
+                            }
+                            else {
+                                return nextEntry();
+                            }
+                        };
+                        return nextModule();
+                    }
+                    else {
+                        if (matches.length === 0) {
+                            return Promise.resolve(null);
+                        }
+                        // MXNet
+                        if (matches.length === 2 &&
+                            matches.some((e) => e.name.toLowerCase().endsWith('.params')) &&
+                            matches.some((e) => e.name.toLowerCase().endsWith('-symbol.json'))) {
+                            matches = matches.filter((e) => e.name.toLowerCase().endsWith('.params'));
+                        }
+                        // TensorFlow.js
+                        if (matches.length > 0 &&
+                            matches.some((e) => e.name.toLowerCase().endsWith('.bin')) &&
+                            matches.some((e) => e.name.toLowerCase().endsWith('.json'))) {
+                            matches = matches.filter((e) => e.name.toLowerCase().endsWith('.json'));
+                        }
+                        // TensorFlow Bundle
+                        if (matches.length > 1 &&
+                            matches.some((e) => e.name.toLowerCase().endsWith('.data-00000-of-00001'))) {
+                            matches = matches.filter((e) => !e.name.toLowerCase().endsWith('.data-00000-of-00001'));
+                        }
+                        if (matches.length > 1) {
+                            return Promise.reject(new view.ArchiveError('Archive contains multiple model files.'));
+                        }
+                        const match = matches.shift();
+                        return Promise.resolve(new view.ModelContext(new view.ArchiveContext(this._host, entries, folder, match.name, match.stream)));
+                    }
+                };
+                return nextEntry();
+            };
             const files = entries.filter((entry) => {
                 if (entry.name.endsWith('/')) {
                     return false;
@@ -1719,63 +1758,14 @@ view.ModelFactoryService = class {
                 return true;
             });
             const folder = rootFolder(files.map((entry) => entry.name));
-            let matches = [];
             const queue = files.slice(0).filter((entry) => entry.name.substring(folder.length).indexOf('/') < 0);
-            const nextEntry = () => {
-                if (queue.length > 0) {
-                    const entry = queue.shift();
-                    const context = new view.ModelContext(new view.ArchiveContext(this._host, null, folder, entry.name, entry.stream));
-                    let modules = this._filter(context);
-                    const nextModule = () => {
-                        if (modules.length > 0) {
-                            const id = modules.shift();
-                            return this._host.require(id).then((module) => {
-                                if (!module.ModelFactory) {
-                                    throw new view.ArchiveError("Failed to load module '" + id + "'.", null);
-                                }
-                                const factory = new module.ModelFactory();
-                                if (factory.match(context)) {
-                                    matches.push(entry);
-                                    modules = [];
-                                }
-                                return nextModule();
-                            });
-                        }
-                        else {
-                            return nextEntry();
-                        }
-                    };
-                    return nextModule();
-                }
-                else {
-                    if (matches.length === 0) {
-                        return Promise.resolve(null);
-                    }
-                    // MXNet
-                    if (matches.length === 2 &&
-                        matches.some((e) => e.name.toLowerCase().endsWith('.params')) &&
-                        matches.some((e) => e.name.toLowerCase().endsWith('-symbol.json'))) {
-                        matches = matches.filter((e) => e.name.toLowerCase().endsWith('.params'));
-                    }
-                    // TensorFlow.js
-                    if (matches.length > 0 &&
-                        matches.some((e) => e.name.toLowerCase().endsWith('.bin')) &&
-                        matches.some((e) => e.name.toLowerCase().endsWith('.json'))) {
-                        matches = matches.filter((e) => e.name.toLowerCase().endsWith('.json'));
-                    }
-                    // TensorFlow Bundle
-                    if (matches.length > 1 &&
-                        matches.some((e) => e.name.toLowerCase().endsWith('.data-00000-of-00001'))) {
-                        matches = matches.filter((e) => !e.name.toLowerCase().endsWith('.data-00000-of-00001'));
-                    }
-                    if (matches.length > 1) {
-                        return Promise.reject(new view.ArchiveError('Archive contains multiple model files.'));
-                    }
-                    const match = matches.shift();
-                    return Promise.resolve(new view.ModelContext(new view.ArchiveContext(this._host, entries, folder, match.name, match.stream)));
+            return filter(queue).then((context) => {
+                if (context) {
+                    return Promise.resolve(context);
                 }
-            };
-            return nextEntry();
+                const queue = files.slice(0).filter((entry) => entry.name.substring(folder.length).indexOf('/') >= 0);
+                return filter(queue);
+            });
         }
         catch (error) {
             return Promise.reject(new view.ArchiveError(error.message));
@@ -1796,6 +1786,7 @@ view.ModelFactoryService = class {
             identifier.endsWith('.tar') ||
             identifier.endsWith('.tar.gz') ||
             identifier.endsWith('.tgz') ||
+            identifier.endsWith('.gz') ||
             identifier.endsWith('.mar') ||
             identifier.endsWith('.model')) {
             this._host.event('File', 'Accept', extension, 1);

+ 9 - 2
test/models.json

@@ -17,7 +17,7 @@
     "type":   "_",
     "target": "coreml_invalid_file.mlmodel",
     "source": "https://github.com/lutzroeder/netron/files/3219681/coreml_invalid_file.mlmodel.zip",
-    "error":  "Invalid file content. File contains Zip archive in 'coreml_invalid_file.mlmodel'.",
+    "error":  "Archive contains no model files in 'coreml_invalid_file.mlmodel'.",
     "format": "Core ML v1",
     "link":   "https://github.com/lutzroeder/netron/issues/193"
   },
@@ -25,7 +25,7 @@
     "type":   "_",
     "target": "empty.zip",
     "source": "https://github.com/lutzroeder/netron/files/5581087/empty.zip",
-    "error":  "Invalid file content. File contains Zip archive in 'empty.zip'.",
+    "error":  "Archive contains no model files in 'empty.zip'.",
     "link":   "https://github.com/lutzroeder/netron/issues/458"
   },
   {
@@ -3065,6 +3065,13 @@
     "format": "ONNX v3",
     "link":   "https://github.com/lutzroeder/netron/issues/139"
   },
+  {
+    "type":   "onnx",
+    "target": "eisber_model3.pbtxt.gz",
+    "source": "https://github.com/lutzroeder/netron/files/6490172/eisber_model3.pbtxt.gz",
+    "format": "ONNX v3",
+    "link":   "https://github.com/lutzroeder/netron/issues/249"
+  },
   {
     "type":   "onnx",
     "target": "eisber_model3_invalid.pbtxt",