|
|
@@ -1501,9 +1501,7 @@ view.ModelFactoryService = class {
|
|
|
_unsupported(context) {
|
|
|
const identifier = context.identifier;
|
|
|
const extension = identifier.split('.').pop().toLowerCase();
|
|
|
- for (const format of new Map([ [ 'Zip', zip ], [ 'tar', tar ] ])) {
|
|
|
- const name = format[0];
|
|
|
- const module = format[1];
|
|
|
+ for (const module of [ zip, tar, gzip ]) {
|
|
|
let archive = null;
|
|
|
try {
|
|
|
archive = module.Archive.open(context.stream);
|
|
|
@@ -1512,7 +1510,7 @@ view.ModelFactoryService = class {
|
|
|
// continue regardless of error
|
|
|
}
|
|
|
if (archive) {
|
|
|
- throw new view.Error("Invalid file content. File contains " + name + " archive in '" + identifier + "'.", true);
|
|
|
+ throw new view.Error("Archive contains no model files in '" + identifier + "'.", true);
|
|
|
}
|
|
|
}
|
|
|
const knownUnsupportedIdentifiers = new Set([
|
|
|
@@ -1603,30 +1601,13 @@ view.ModelFactoryService = class {
|
|
|
_openArchive(context) {
|
|
|
const entries = new Map();
|
|
|
let stream = context.stream;
|
|
|
- let extension;
|
|
|
- let identifier = context.identifier;
|
|
|
+ const identifier = context.identifier;
|
|
|
try {
|
|
|
- extension = identifier.split('.').pop().toLowerCase();
|
|
|
- const gzipArchive = gzip.Archive.open(stream);
|
|
|
- if (gzipArchive) {
|
|
|
- const entries = gzipArchive.entries;
|
|
|
- if (entries.length === 1) {
|
|
|
- const entry = entries[0];
|
|
|
- if (entry.name) {
|
|
|
- identifier = entry.name;
|
|
|
- }
|
|
|
- else {
|
|
|
- identifier = identifier.substring(0, identifier.lastIndexOf('.'));
|
|
|
- switch (extension) {
|
|
|
- case 'tgz':
|
|
|
- case 'tar': {
|
|
|
- if (identifier.split('.').pop().toLowerCase() !== 'tar') {
|
|
|
- identifier += '.tar';
|
|
|
- }
|
|
|
- break;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
+ const archive = gzip.Archive.open(stream);
|
|
|
+ if (archive) {
|
|
|
+ entries.set('gzip', archive.entries);
|
|
|
+ if (archive.entries.length === 1) {
|
|
|
+ const entry = archive.entries[0];
|
|
|
stream = entry.stream;
|
|
|
}
|
|
|
}
|
|
|
@@ -1635,7 +1616,6 @@ view.ModelFactoryService = class {
|
|
|
const message = error && error.message ? error.message : error.toString();
|
|
|
throw new view.ArchiveError(message.replace(/\.$/, '') + " in '" + identifier + "'.");
|
|
|
}
|
|
|
-
|
|
|
try {
|
|
|
const formats = new Map([ [ 'zip', zip ], [ 'tar', tar ] ]);
|
|
|
for (const pair of formats) {
|
|
|
@@ -1644,6 +1624,7 @@ view.ModelFactoryService = class {
|
|
|
const archive = module.Archive.open(stream);
|
|
|
if (archive) {
|
|
|
entries.set(format, archive.entries);
|
|
|
+ entries.delete('gzip');
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
@@ -1706,6 +1687,64 @@ view.ModelFactoryService = class {
|
|
|
const folder = rotate(map).filter(equals).map(at(0)).join('/');
|
|
|
return folder.length === 0 ? folder : folder + '/';
|
|
|
};
|
|
|
+ const filter = (queue) => {
|
|
|
+ let matches = [];
|
|
|
+ const nextEntry = () => {
|
|
|
+ if (queue.length > 0) {
|
|
|
+ const entry = queue.shift();
|
|
|
+ const context = new view.ModelContext(new view.ArchiveContext(this._host, null, folder, entry.name, entry.stream));
|
|
|
+ let modules = this._filter(context);
|
|
|
+ const nextModule = () => {
|
|
|
+ if (modules.length > 0) {
|
|
|
+ const id = modules.shift();
|
|
|
+ return this._host.require(id).then((module) => {
|
|
|
+ if (!module.ModelFactory) {
|
|
|
+ throw new view.ArchiveError("Failed to load module '" + id + "'.", null);
|
|
|
+ }
|
|
|
+ const factory = new module.ModelFactory();
|
|
|
+ if (factory.match(context)) {
|
|
|
+ matches.push(entry);
|
|
|
+ modules = [];
|
|
|
+ }
|
|
|
+ return nextModule();
|
|
|
+ });
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ return nextEntry();
|
|
|
+ }
|
|
|
+ };
|
|
|
+ return nextModule();
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ if (matches.length === 0) {
|
|
|
+ return Promise.resolve(null);
|
|
|
+ }
|
|
|
+ // MXNet
|
|
|
+ if (matches.length === 2 &&
|
|
|
+ matches.some((e) => e.name.toLowerCase().endsWith('.params')) &&
|
|
|
+ matches.some((e) => e.name.toLowerCase().endsWith('-symbol.json'))) {
|
|
|
+ matches = matches.filter((e) => e.name.toLowerCase().endsWith('.params'));
|
|
|
+ }
|
|
|
+ // TensorFlow.js
|
|
|
+ if (matches.length > 0 &&
|
|
|
+ matches.some((e) => e.name.toLowerCase().endsWith('.bin')) &&
|
|
|
+ matches.some((e) => e.name.toLowerCase().endsWith('.json'))) {
|
|
|
+ matches = matches.filter((e) => e.name.toLowerCase().endsWith('.json'));
|
|
|
+ }
|
|
|
+ // TensorFlow Bundle
|
|
|
+ if (matches.length > 1 &&
|
|
|
+ matches.some((e) => e.name.toLowerCase().endsWith('.data-00000-of-00001'))) {
|
|
|
+ matches = matches.filter((e) => !e.name.toLowerCase().endsWith('.data-00000-of-00001'));
|
|
|
+ }
|
|
|
+ if (matches.length > 1) {
|
|
|
+ return Promise.reject(new view.ArchiveError('Archive contains multiple model files.'));
|
|
|
+ }
|
|
|
+ const match = matches.shift();
|
|
|
+ return Promise.resolve(new view.ModelContext(new view.ArchiveContext(this._host, entries, folder, match.name, match.stream)));
|
|
|
+ }
|
|
|
+ };
|
|
|
+ return nextEntry();
|
|
|
+ };
|
|
|
const files = entries.filter((entry) => {
|
|
|
if (entry.name.endsWith('/')) {
|
|
|
return false;
|
|
|
@@ -1719,63 +1758,14 @@ view.ModelFactoryService = class {
|
|
|
return true;
|
|
|
});
|
|
|
const folder = rootFolder(files.map((entry) => entry.name));
|
|
|
- let matches = [];
|
|
|
const queue = files.slice(0).filter((entry) => entry.name.substring(folder.length).indexOf('/') < 0);
|
|
|
- const nextEntry = () => {
|
|
|
- if (queue.length > 0) {
|
|
|
- const entry = queue.shift();
|
|
|
- const context = new view.ModelContext(new view.ArchiveContext(this._host, null, folder, entry.name, entry.stream));
|
|
|
- let modules = this._filter(context);
|
|
|
- const nextModule = () => {
|
|
|
- if (modules.length > 0) {
|
|
|
- const id = modules.shift();
|
|
|
- return this._host.require(id).then((module) => {
|
|
|
- if (!module.ModelFactory) {
|
|
|
- throw new view.ArchiveError("Failed to load module '" + id + "'.", null);
|
|
|
- }
|
|
|
- const factory = new module.ModelFactory();
|
|
|
- if (factory.match(context)) {
|
|
|
- matches.push(entry);
|
|
|
- modules = [];
|
|
|
- }
|
|
|
- return nextModule();
|
|
|
- });
|
|
|
- }
|
|
|
- else {
|
|
|
- return nextEntry();
|
|
|
- }
|
|
|
- };
|
|
|
- return nextModule();
|
|
|
- }
|
|
|
- else {
|
|
|
- if (matches.length === 0) {
|
|
|
- return Promise.resolve(null);
|
|
|
- }
|
|
|
- // MXNet
|
|
|
- if (matches.length === 2 &&
|
|
|
- matches.some((e) => e.name.toLowerCase().endsWith('.params')) &&
|
|
|
- matches.some((e) => e.name.toLowerCase().endsWith('-symbol.json'))) {
|
|
|
- matches = matches.filter((e) => e.name.toLowerCase().endsWith('.params'));
|
|
|
- }
|
|
|
- // TensorFlow.js
|
|
|
- if (matches.length > 0 &&
|
|
|
- matches.some((e) => e.name.toLowerCase().endsWith('.bin')) &&
|
|
|
- matches.some((e) => e.name.toLowerCase().endsWith('.json'))) {
|
|
|
- matches = matches.filter((e) => e.name.toLowerCase().endsWith('.json'));
|
|
|
- }
|
|
|
- // TensorFlow Bundle
|
|
|
- if (matches.length > 1 &&
|
|
|
- matches.some((e) => e.name.toLowerCase().endsWith('.data-00000-of-00001'))) {
|
|
|
- matches = matches.filter((e) => !e.name.toLowerCase().endsWith('.data-00000-of-00001'));
|
|
|
- }
|
|
|
- if (matches.length > 1) {
|
|
|
- return Promise.reject(new view.ArchiveError('Archive contains multiple model files.'));
|
|
|
- }
|
|
|
- const match = matches.shift();
|
|
|
- return Promise.resolve(new view.ModelContext(new view.ArchiveContext(this._host, entries, folder, match.name, match.stream)));
|
|
|
+ return filter(queue).then((context) => {
|
|
|
+ if (context) {
|
|
|
+ return Promise.resolve(context);
|
|
|
}
|
|
|
- };
|
|
|
- return nextEntry();
|
|
|
+ const queue = files.slice(0).filter((entry) => entry.name.substring(folder.length).indexOf('/') >= 0);
|
|
|
+ return filter(queue);
|
|
|
+ });
|
|
|
}
|
|
|
catch (error) {
|
|
|
return Promise.reject(new view.ArchiveError(error.message));
|
|
|
@@ -1796,6 +1786,7 @@ view.ModelFactoryService = class {
|
|
|
identifier.endsWith('.tar') ||
|
|
|
identifier.endsWith('.tar.gz') ||
|
|
|
identifier.endsWith('.tgz') ||
|
|
|
+ identifier.endsWith('.gz') ||
|
|
|
identifier.endsWith('.mar') ||
|
|
|
identifier.endsWith('.model')) {
|
|
|
this._host.event('File', 'Accept', extension, 1);
|