Lutz Roeder пре 6 месеци
родитељ
комит
360994ee83
17 измењених фајлова са 43 додато и 43 уклоњено
  1. 2 2
      source/coreml.js
  2. 2 2
      source/dl4j.js
  3. 3 3
      source/espresso.js
  4. 3 3
      source/hailo.js
  5. 4 4
      source/keras.js
  6. 2 2
      source/mxnet.js
  7. 2 2
      source/ncnn.js
  8. 2 2
      source/nnabla.js
  9. 2 2
      source/nnef.js
  10. 2 2
      source/onnx.js
  11. 2 2
      source/openvino.js
  12. 3 3
      source/paddle.js
  13. 5 5
      source/pytorch.js
  14. 3 3
      source/tf.js
  15. 2 2
      source/transformers.js
  16. 2 2
      source/tvm.js
  17. 2 2
      source/view.js

+ 2 - 2
source/coreml.js

@@ -74,8 +74,8 @@ coreml.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        if (context.type === 'coreml.metadata.mlmodelc' && (type === 'coreml.mil')) {
+    filter(context, match) {
+        if (context.type === 'coreml.metadata.mlmodelc' && (match.type === 'coreml.mil')) {
             return false;
         }
         return true;

+ 2 - 2
source/dl4j.js

@@ -22,8 +22,8 @@ dl4j.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        return context.type !== 'dl4j.configuration' || (type !== 'dl4j.coefficients' && type !== 'openvino.bin');
+    filter(context, match) {
+        return context.type !== 'dl4j.configuration' || (match.type !== 'dl4j.coefficients' && match.type !== 'openvino.bin');
     }
 
     async open(context) {

+ 3 - 3
source/espresso.js

@@ -24,11 +24,11 @@ espresso.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        if (context.type === 'espresso.net' && (type === 'espresso.weights' || type === 'espresso.shape' || type === 'coreml.metadata.mlmodelc')) {
+    filter(context, match) {
+        if (context.type === 'espresso.net' && (match.type === 'espresso.weights' || match.type === 'espresso.shape' || match.type === 'coreml.metadata.mlmodelc')) {
             return false;
         }
-        if (context.type === 'espresso.shape' && (type === 'espresso.weights' || type === 'coreml.metadata.mlmodelc')) {
+        if (context.type === 'espresso.shape' && (match.type === 'espresso.weights' || match.type === 'coreml.metadata.mlmodelc')) {
             return false;
         }
         return true;

+ 3 - 3
source/hailo.js

@@ -12,11 +12,11 @@ hailo.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        if (context.type === 'hailo.metadata' && (type === 'hailo.configuration' || type === 'npz' || type === 'onnx.proto')) {
+    filter(context, match) {
+        if (context.type === 'hailo.metadata' && (match.type === 'hailo.configuration' || match.type === 'npz' || match.type === 'onnx.proto')) {
             return false;
         }
-        if (context.type === 'hailo.configuration' && type === 'npz') {
+        if (context.type === 'hailo.configuration' && match.type === 'npz') {
             return false;
         }
         return true;

+ 4 - 4
source/keras.js

@@ -63,14 +63,14 @@ keras.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        if (context.type === 'keras.metadata.json' && (type === 'keras.config.json' || type === 'keras.model.weights.h5' || type === 'keras.model.weights.npz')) {
+    filter(context, match) {
+        if (context.type === 'keras.metadata.json' && (match.type === 'keras.config.json' || match.type === 'keras.model.weights.h5' || match.type === 'keras.model.weights.npz')) {
             return false;
         }
-        if (context.type === 'keras.config.json' && (type === 'keras.model.weights.h5' || type === 'keras.model.weights.npz')) {
+        if (context.type === 'keras.config.json' && (match.type === 'keras.model.weights.h5' || match.type === 'keras.model.weights.npz')) {
             return false;
         }
-        if (context.type === 'tfjs' && type === 'tf.tfjs.weights') {
+        if (context.type === 'tfjs' && match.type === 'tf.tfjs.weights') {
             return false;
         }
         return true;

+ 2 - 2
source/mxnet.js

@@ -22,8 +22,8 @@ mxnet.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        return context.type !== 'mxnet.json' || type !== 'mxnet.params';
+    filter(context, match) {
+        return context.type !== 'mxnet.json' || match.type !== 'mxnet.params';
     }
 
     async open(context) {

+ 2 - 2
source/ncnn.js

@@ -99,8 +99,8 @@ ncnn.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        return (context.type !== 'ncnn.model' && context.type !== 'ncnn.model.bin') || type !== 'ncnn.weights';
+    filter(context, match) {
+        return (context.type !== 'ncnn.model' && context.type !== 'ncnn.model.bin') || match.type !== 'ncnn.weights';
     }
 
     async open(context) {

+ 2 - 2
source/nnabla.js

@@ -70,8 +70,8 @@ nnabla.ModelFactory = class {
         }
     }
 
-    filter(context, type) {
-        return context.type !== 'nnabla.pbtxt' || (type !== 'hdf5.parameter.h5' && type !== 'keras.h5');
+    filter(context, match) {
+        return context.type !== 'nnabla.pbtxt' || (match.type !== 'hdf5.parameter.h5' && match.type !== 'keras.h5');
     }
 };
 

+ 2 - 2
source/nnef.js

@@ -30,8 +30,8 @@ nnef.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        return context.type !== 'nnef.graph' || type !== 'nnef.dat';
+    filter(context, match) {
+        return context.type !== 'nnef.graph' || match.type !== 'nnef.dat';
     }
 
     async open(context) {

+ 2 - 2
source/onnx.js

@@ -40,8 +40,8 @@ onnx.ModelFactory = class {
         return new onnx.Model(metadata, target);
     }
 
-    filter(context, type) {
-        return context.type !== 'onnx.proto' || (type !== 'onnx.data' && type !== 'onnx.meta' && type !== 'dot');
+    filter(context, match) {
+        return context.type !== 'onnx.proto' || (match.type !== 'onnx.data' && match.type !== 'onnx.meta' && match.type !== 'dot');
     }
 };
 

+ 2 - 2
source/openvino.js

@@ -95,8 +95,8 @@ openvino.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        return context.type !== 'openvino.xml' || type !== 'openvino.bin';
+    filter(context, match) {
+        return context.type !== 'openvino.xml' || match.type !== 'openvino.bin';
     }
 
     async open(context) {

+ 3 - 3
source/paddle.js

@@ -46,11 +46,11 @@ paddle.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        if (context.type === 'paddle.pb' && (type === 'paddle.params' || type === 'paddle.pickle')) {
+    filter(context, match) {
+        if (context.type === 'paddle.pb' && (match.type === 'paddle.params' || match.type === 'paddle.pickle')) {
             return false;
         }
-        if (context.type === 'paddle.naive.model' && type === 'paddle.naive.param') {
+        if (context.type === 'paddle.naive.model' && match.type === 'paddle.naive.param') {
             return false;
         }
         return true;

+ 5 - 5
source/pytorch.js

@@ -18,17 +18,17 @@ pytorch.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        if (context.type === 'pytorch.export' && type === 'pytorch.zip') {
+    filter(context, match) {
+        if (context.type === 'pytorch.export' && match.type === 'pytorch.zip') {
             return false;
         }
-        if (context.type === 'pytorch.index' && type === 'pytorch.zip') {
+        if (context.type === 'pytorch.index' && match.type === 'pytorch.zip') {
             return false;
         }
-        if (context.type === 'pytorch.model.json' && type === 'pytorch.data.pkl') {
+        if (context.type === 'pytorch.model.json' && match.type === 'pytorch.data.pkl') {
             return false;
         }
-        if (context.type === 'pytorch.model.json' && type === 'pickle') {
+        if (context.type === 'pytorch.model.json' && match.type === 'pickle') {
             return false;
         }
         return true;

+ 3 - 3
source/tf.js

@@ -233,11 +233,11 @@ tf.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        if (context.type === 'tf.bundle' && type === 'tf.data') {
+    filter(context, match) {
+        if (context.type === 'tf.bundle' && match.type === 'tf.data') {
             return false;
         }
-        if ((context.type === 'tf.json' || context.type === 'tf.json.gz') && type === 'tf.tfjs.weights') {
+        if ((context.type === 'tf.json' || context.type === 'tf.json.gz') && match.type === 'tf.tfjs.weights') {
             return false;
         }
         return true;

+ 2 - 2
source/transformers.js

@@ -66,7 +66,7 @@ transformers.ModelFactory = class {
         return new transformers.Model(config, tokenizer, tokenizer_config, vocab, generation_config, preprocessor_config);
     }
 
-    filter(context, type) {
+    filter(context, match) {
         const priority = new Map([
             ['transformers.config', 7],
             ['transformers.tokenizer', 6],
@@ -79,7 +79,7 @@ transformers.ModelFactory = class {
             ['safetensors', 0]
         ]);
         const a = priority.has(context.type) ? priority.get(context.type) : -1; // current
-        const b = priority.has(type) ? priority.get(type) : -1;
+        const b = priority.has(match.type) ? priority.get(match.type) : -1;
         if (a !== -1 && b !== -1) {
             return a < b;
         }

+ 2 - 2
source/tvm.js

@@ -21,8 +21,8 @@ tvm.ModelFactory = class {
         return null;
     }
 
-    filter(context, type) {
-        return context.type !== 'tvm.json' || type !== 'tvm.params';
+    filter(context, match) {
+        return context.type !== 'tvm.json' || match.type !== 'tvm.params';
     }
 
     async open(context) {

+ 2 - 2
source/view.js

@@ -6720,8 +6720,8 @@ view.ModelFactoryService = class {
                         }
                         delete context.value;
                         if (type) {
-                            matches = matches.filter((match) => !factory.filter || factory.filter(context, match.type));
-                            if (matches.every((match) => !match.factory.filter || match.factory.filter(match, context.type))) {
+                            matches = matches.filter((match) => !factory.filter || factory.filter(context, match));
+                            if (matches.every((match) => !match.factory.filter || match.factory.filter(match, context))) {
                                 context.factory = factory;
                                 matches.push(context);
                             }