Browse Source

Handle Keras HDF5 with .pth extension

Lutz Roeder 6 years ago
parent
commit
c0ced6d27a
2 changed files with 5 additions and 4 deletions
  1. 4 3
      src/keras.js
  2. 1 1
      src/view.js

+ 4 - 3
src/keras.js

@@ -9,10 +9,10 @@ keras.ModelFactory = class {
     match(context) {
         const identifier = context.identifier;
         const extension = identifier.split('.').pop().toLowerCase();
-        if (extension === 'h5' || extension === 'hd5' || extension === 'hdf5' || extension === 'keras' || extension === 'model' || extension == 'pb') {
+        if (extension === 'h5' || extension === 'hd5' || extension === 'hdf5' || extension === 'keras' || extension === 'model' || extension == 'pb' || extension == 'pth') {
             const buffer = context.buffer;
             const signature = [ 0x89, 0x48, 0x44, 0x46, 0x0D, 0x0A, 0x1A, 0x0A ];
-            return (buffer && buffer.length > signature.length && signature.every((v, i) => v === buffer[i]));
+            return buffer && buffer.length > signature.length && signature.every((v, i) => v === buffer[i]);
         }
         if (extension == 'json' && !identifier.endsWith('-symbol.json')) {
             const json = context.text;
@@ -57,7 +57,8 @@ keras.ModelFactory = class {
                     case 'hd5':
                     case 'hdf5':
                     case 'model':
-                    case 'pb': {
+                    case 'pb':
+                    case 'pth': {
                         const file = new hdf5.File(context.buffer);
                         rootGroup = file.rootGroup;
                         if (!rootGroup.attribute('model_config') && !rootGroup.attribute('layer_names')) {

+ 1 - 1
src/view.js

@@ -1128,7 +1128,7 @@ view.ModelFactoryService = class {
         this._extensions = [];
         this.register('./onnx', [ '.onnx', '.pb', '.pbtxt', '.prototxt' ]);
         this.register('./mxnet', [ '.mar', '.model', '.json', '.params' ]);
-        this.register('./keras', [ '.h5', '.hd5', '.hdf5', '.keras', '.json', '.model', '.pb' ]);
+        this.register('./keras', [ '.h5', '.hd5', '.hdf5', '.keras', '.json', '.model', '.pb', '.pth' ]);
         this.register('./coreml', [ '.mlmodel' ]);
         this.register('./caffe', [ '.caffemodel', '.pbtxt', '.prototxt', '.pt' ]);
         this.register('./caffe2', [ '.pb', '.pbtxt', '.prototxt' ]);