瀏覽代碼

PaddlePaddle support (#198)

Lutz Roeder 4 年之前
父節點
當前提交
22a0ece467
共有 2 個文件被更改,包括 6 次插入9 次删除
  1. 5 3
      source/paddle.js
  2. 1 6
      source/view.js

+ 5 - 3
source/paddle.js

@@ -8,9 +8,11 @@ paddle.ModelFactory = class {
     match(context) {
         const identifier = context.identifier;
         const extension = identifier.split('.').pop().toLowerCase();
-        if (identifier === '__model__' || identifier === 'model' ||
-            extension === 'paddle' || extension === 'pdmodel') {
-            return true;
+        if (identifier === '__model__' || extension === 'paddle' || extension === 'pdmodel') {
+            const tags = context.tags('pb');
+            if (tags.get(1) === 2) {
+                return true;
+            }
         }
         if (extension === 'pbtxt' || extension === 'txt') {
             const tags = context.tags('pbtxt');

+ 1 - 6
source/view.js

@@ -1408,7 +1408,7 @@ view.ModelFactoryService = class {
         this.register('./sklearn', [ '.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5', '.pkl.z', '.joblib.z' ]);
         this.register('./pickle', [ '.pkl', '.pickle', '.joblib', '.model', '.meta', '.pb', '.pt', '.h5', '.pkl.z', '.joblib.z' ]);
         this.register('./cntk', [ '.model', '.cntk', '.cmf', '.dnn' ]);
-        this.register('./paddle', [ '.pdmodel', '.pdparams', '.pdiparams', '.paddle', '__model__', '.pbtxt', '.txt', '.tar', '.tar.gz', 'model', 'params' ]);
+        this.register('./paddle', [ '.pdmodel', '.pdparams', '.pdiparams', '.paddle', '__model__', '.pbtxt', '.txt', '.tar', '.tar.gz' ]);
         this.register('./bigdl', [ '.model', '.bigdl' ]);
         this.register('./darknet', [ '.cfg', '.model', '.txt', '.weights' ]);
         this.register('./weka', [ '.model' ]);
@@ -1744,11 +1744,6 @@ view.ModelFactoryService = class {
                             matches.some((e) => e.name.toLowerCase().endsWith('.pdiparams'))) {
                             matches = matches.filter((e) => e.name.toLowerCase().endsWith('.pdmodel'));
                         }
-                        if (matches.length > 0 &&
-                            matches.some((e) => e.name.split('/').pop().toLowerCase() === 'model') &&
-                            matches.some((e) => e.name.split('/').pop().toLowerCase() === 'params')) {
-                            matches = matches.filter((e) => e.name.toLowerCase().split('/').pop().toLowerCase() === 'model');
-                        }
                         // TensorFlow Bundle
                         if (matches.length > 1 &&
                             matches.some((e) => e.name.toLowerCase().endsWith('.data-00000-of-00001'))) {