|
|
@@ -10,19 +10,10 @@ keras.ModelFactory = class {
|
|
|
match(context) {
|
|
|
const identifier = context.identifier;
|
|
|
const extension = identifier.split('.').pop().toLowerCase();
|
|
|
- if (extension == 'keras' || extension == 'h5' || extension == 'hd5' || extension == 'hdf5') {
|
|
|
- // Reject PyTorch models with .h5 file extension.
|
|
|
+ if (extension === 'h5' || extension === 'hd5' || extension === 'hdf5' || extension === 'keras' || extension === 'model') {
|
|
|
const buffer = context.buffer;
|
|
|
- const torch = [ 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
|
|
|
- if (buffer && buffer.length > 14 && buffer[0] == 0x80 && torch.every((v, i) => v == buffer[i + 2])) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- return true;
|
|
|
- }
|
|
|
- if (extension == 'model') {
|
|
|
- const buffer = context.buffer;
|
|
|
- const hdf5 = [ 0x89, 0x48, 0x44, 0x46 ];
|
|
|
- return (buffer && buffer.length > hdf5.length && hdf5.every((v, i) => v == buffer[i]));
|
|
|
+ const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
|
|
|
+ return (buffer && buffer.length > signature.length && signature.every((v, i) => v === buffer[i]));
|
|
|
}
|
|
|
if (extension == 'json' && !identifier.endsWith('-symbol.json')) {
|
|
|
const json = context.text;
|