Ver código fonte

Update PyTorch test files

Lutz Roeder 3 anos atrás
pai
commit
741bf2c213
3 arquivos alterados com 67 adições e 54 exclusões
  1. 58 50
      source/view.js
  2. 5 0
      source/zip.js
  3. 4 4
      test/models.json

+ 58 - 50
source/view.js

@@ -1244,11 +1244,56 @@ view.Edge = class extends grapher.Edge {
 
 view.ModelContext = class {
 
-    constructor(context, formats) {
+    constructor(context) {
         this._context = context;
         this._tags = new Map();
         this._content = new Map();
-        this._formats = formats || new Map();
+        let stream = context.stream;
+        const entries = context.entries;
+        if (!stream && entries && entries.size > 0) {
+            this._entries = entries;
+            this._format = '';
+        }
+        else {
+            this._entries = new Map();
+            const entry = context instanceof view.EntryContext;
+            const identifier = context.identifier;
+            try {
+                const archive = gzip.Archive.open(stream);
+                if (archive) {
+                    this._entries = archive.entries;
+                    this._format = 'gzip';
+                    if (this._entries.size === 1) {
+                        stream = this._entries.values().next().value;
+                    }
+                }
+            }
+            catch (error) {
+                if (!entry) {
+                    const message = error && error.message ? error.message : error.toString();
+                    throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'.");
+                }
+            }
+            try {
+                const formats = new Map([ [ 'zip', zip ], [ 'tar', tar ] ]);
+                for (const pair of formats) {
+                    const format = pair[0];
+                    const module = pair[1];
+                    const archive = module.Archive.open(stream);
+                    if (archive) {
+                        this._entries = archive.entries;
+                        this._format = format;
+                        break;
+                    }
+                }
+            }
+            catch (error) {
+                if (!entry) {
+                    const message = error && error.message ? error.message : error.toString();
+                    throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'.");
+                }
+            }
+        }
     }
 
     get identifier() {
@@ -1272,7 +1317,10 @@ view.ModelContext = class {
     }
 
     entries(format) {
-        return this._formats.get(format) || new Map();
+        if (format !== undefined && format !== this._format) {
+            return new Map();
+        }
+        return this._entries;
     }
 
     open(type) {
@@ -1445,7 +1493,7 @@ view.ModelContext = class {
     }
 };
 
-view.ArchiveContext = class {
+view.EntryContext = class {
 
     constructor(host, entries, rootFolder, identifier, stream) {
         this._host = host;
@@ -1569,55 +1617,15 @@ view.ModelFactoryService = class {
 
     open(context) {
         return this._openSignature(context).then((context) => {
-            const containers = new Map();
-            let stream = context.stream;
-            const entries = context.entries;
-            if (!stream && entries && entries.size > 0) {
-                containers.set('', entries);
-            }
-            else {
-                const identifier = context.identifier;
-                try {
-                    const archive = gzip.Archive.open(stream);
-                    if (archive) {
-                        const entries = archive.entries;
-                        containers.set('gzip', entries);
-                        if (entries.size === 1) {
-                            stream = entries.values().next().value;
-                        }
-                    }
-                }
-                catch (error) {
-                    const message = error && error.message ? error.message : error.toString();
-                    throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'.");
-                }
-                try {
-                    const formats = new Map([ [ 'zip', zip ], [ 'tar', tar ] ]);
-                    for (const pair of formats) {
-                        const format = pair[0];
-                        const module = pair[1];
-                        const archive = module.Archive.open(stream);
-                        if (archive) {
-                            containers.set(format, archive.entries);
-                            containers.delete('gzip');
-                            break;
-                        }
-                    }
-                }
-                catch (error) {
-                    const message = error && error.message ? error.message : error.toString();
-                    throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'.");
-                }
-            }
-
-            const modelContext = new view.ModelContext(context, containers);
+            const modelContext = new view.ModelContext(context);
             /* eslint-disable consistent-return */
             return this._openContext(modelContext).then((model) => {
                 if (model) {
                     return model;
                 }
-                if (containers.size > 0) {
-                    return this._openEntries(containers.values().next().value).then((context) => {
+                const entries = modelContext.entries();
+                if (entries && entries.size > 0) {
+                    return this._openEntries(entries).then((context) => {
                         if (context) {
                             return this._openContext(context);
                         }
@@ -1899,7 +1907,7 @@ view.ModelFactoryService = class {
                 const nextEntry = () => {
                     if (queue.length > 0) {
                         const entry = queue.shift();
-                        const context = new view.ModelContext(new view.ArchiveContext(this._host, null, folder, entry.name, entry.stream));
+                        const context = new view.ModelContext(new view.EntryContext(this._host, null, folder, entry.name, entry.stream));
                         let modules = this._filter(context);
                         const nextModule = () => {
                             if (modules.length > 0) {
@@ -1975,7 +1983,7 @@ view.ModelFactoryService = class {
                         return Promise.reject(new view.ArchiveError('Archive contains multiple model files.'));
                     }
                     const match = matches.shift();
-                    return Promise.resolve(new view.ModelContext(new view.ArchiveContext(this._host, entries, folder, match.name, match.stream)));
+                    return Promise.resolve(new view.ModelContext(new view.EntryContext(this._host, entries, folder, match.name, match.stream)));
                 };
                 return nextEntry();
             };

+ 5 - 0
source/zip.js

@@ -576,6 +576,11 @@ zip.InflaterStream = class {
         return this._buffer.subarray(position, this._position);
     }
 
+    stream(length) {
+        const buffer = this.read(length);
+        return new zip.BinaryReader(buffer);
+    }
+
     byte() {
         const position = this._position;
         this.skip(1);

+ 4 - 4
test/models.json

@@ -4387,8 +4387,8 @@
   },
   {
     "type":     "pytorch",
-    "target":   "alexnet_traced.pt",
-    "source":   "https://github.com/lutzroeder/netron/files/6096602/alexnet_traced.pt.zip[alexnet_traced.pt]",
+    "target":   "alexnet_traced.pt.zip",
+    "source":   "https://github.com/lutzroeder/netron/files/6096602/alexnet_traced.pt.zip",
     "format":   "TorchScript v1.6",
     "link":     "https://github.com/lutzroeder/netron/issues/281"
   },
@@ -4401,8 +4401,8 @@
   },
   {
     "type":     "pytorch",
-    "target":   "alexnet.pt",
-    "source":   "https://github.com/lutzroeder/netron/files/6096605/alexnet.pt.zip[alexnet.pt]",
+    "target":   "alexnet.pt.zip",
+    "source":   "https://github.com/lutzroeder/netron/files/6096605/alexnet.pt.zip",
     "format":   "TorchScript v1.6",
     "link":     "https://github.com/lutzroeder/netron/issues/281"
   },