Lutz Roeder пре 6 година
родитељ
комит
52aebcd000
4 измењених фајлова са 14 додато и 19 уклоњено
  1. 9 5
      src/chainer.js
  2. 3 12
      src/keras.js
  3. 1 1
      src/view.js
  4. 1 1
      test/models.json

+ 9 - 5
src/chainer.js

@@ -11,12 +11,15 @@ chainer.ModelFactory = class {
     match(context) {
         const identifier = context.identifier; 
         const extension = identifier.split('.').pop().toLowerCase();
-        switch (extension) {
-            case 'npz':
-                return context.entries.length > 0 && context.entries.every((entry) => entry.name.indexOf('/') !== -1);
-            case 'h5':
-            case 'hdf5':
+        if (extension === 'npz') {
+            return context.entries.length > 0 && context.entries.every((entry) => entry.name.indexOf('/') !== -1);
+        }
+        if (extension === 'h5' || extension === 'hd5' || extension === 'hdf5' || extension === 'keras' || extension === 'model') {
+            const buffer = context.buffer;
+            const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
+            if (buffer && buffer.length > signature.length && signature.every((v, i) => v === buffer[i])) {
                 return true;
+            }
         }
         return false;
     }
@@ -28,6 +31,7 @@ chainer.ModelFactory = class {
             case 'npz':
                 return this._openNumPy(context, host);
             case 'h5':
+            case 'hd5':
             case 'hdf5':
                 return this._openHdf5(context, host);
         }

+ 3 - 12
src/keras.js

@@ -10,19 +10,10 @@ keras.ModelFactory = class {
     match(context) {
         const identifier = context.identifier;
         const extension = identifier.split('.').pop().toLowerCase();
-        if (extension == 'keras' || extension == 'h5' || extension == 'hd5' || extension == 'hdf5') {
-            // Reject PyTorch models with .h5 file extension.
+        if (extension === 'h5' || extension === 'hd5' || extension === 'hdf5' || extension === 'keras' || extension === 'model') {
             const buffer = context.buffer;
-            const torch = [ 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
-            if (buffer && buffer.length > 14 && buffer[0] == 0x80 && torch.every((v, i) => v == buffer[i + 2])) {
-                return false;
-            }
-            return true;
-        }
-        if (extension == 'model') {
-            const buffer = context.buffer;
-            const hdf5 = [ 0x89, 0x48, 0x44, 0x46 ];
-            return (buffer && buffer.length > hdf5.length && hdf5.every((v, i) => v == buffer[i]));
+            const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
+            return (buffer && buffer.length > signature.length && signature.every((v, i) => v === buffer[i]));
         }
         if (extension == 'json' && !identifier.endsWith('-symbol.json')) {
             const json = context.text;

+ 1 - 1
src/view.js

@@ -1158,7 +1158,7 @@ view.ModelFactoryService = class {
         this.register('./mnn', ['.mnn']);
         this.register('./ncnn', [ '.param', '.bin', '.cfg.ncnn', '.weights.ncnn' ]);
         this.register('./flux', [ '.bson' ]);
-        this.register('./chainer', [ '.npz', '.h5', '.hdf5' ]);
+        this.register('./chainer', [ '.npz', '.h5', '.hd5', '.hdf5' ]);
         this.register('./dl4j', [ '.zip' ]);
         this.register('./mlnet', [ '.zip' ]);
     }

+ 1 - 1
test/models.json

@@ -2021,7 +2021,7 @@
     "type":   "keras",
     "target": "keras_invalid_file.h5",
     "source": "https://github.com/lutzroeder/netron/files/3364286/keras_invalid_file.zip[keras_invalid_file.h5]",
-    "error":  "Not a valid HDF5 file in 'keras_invalid_file.h5'.\nNot a valid HDF5 file in 'keras_invalid_file.h5'.",
+    "error":  "Unsupported file content for extension '.h5' in 'keras_invalid_file.h5'.",
     "link":   "https://github.com/lutzroeder/netron/issues/57"
   },
   {