Parcourir la source

Update view.js

Lutz Roeder il y a 2 ans
Parent
commit
d87a13a7b5
1 fichiers modifiés avec 14 ajouts et 13 suppressions
  1. 14 13
      source/view.js

+ 14 - 13
source/view.js

@@ -5255,7 +5255,7 @@ view.ModelFactoryService = class {
 
     constructor(host) {
         this._host = host;
-        this._extensions = new Set([ '.zip', '.tar', '.tar.gz', '.tgz', '.gz' ]);
+        this._patterns = new Set([ '.zip', '.tar', '.tar.gz', '.tgz', '.gz' ]);
         this._factories = [];
         this.register('./server', [ '.netron']);
         this.register('./pytorch', [ '.pt', '.pth', '.ptl', '.pt1', '.pyt', '.pyth', '.pkl', '.pickle', '.h5', '.t7', '.model', '.dms', '.tar', '.ckpt', '.chkpt', '.tckpt', '.bin', '.pb', '.zip', '.nn', '.torchmodel', '.torchscript', '.pytorch', '.ot', '.params', '.trt', '.ff', '.ptmf', '.jit', '.pte', '.bin.index.json', 'serialized_exported_program.json' ], [ '.model', '.pt2' ]);
@@ -5302,7 +5302,7 @@ view.ModelFactoryService = class {
         this.register('./imgdnn', [ '.dnn', 'params', '.json' ]);
         this.register('./flax', [ '.msgpack' ]);
         this.register('./om', [ '.om', '.onnx', '.pb', '.engine' ]);
-        this.register('./gguf', [ '.gguf' ]);
+        this.register('./gguf', [ '.gguf', /^[^.]+$/ ]);
         this.register('./nnabla', [ '.nntxt' ], [ '.nnp' ]);
         this.register('./hickle', [ '.h5', '.hkl' ]);
         this.register('./nnef', [ '.nnef', '.dat' ]);
@@ -5317,13 +5317,13 @@ view.ModelFactoryService = class {
         this.register('./weka', [ '.model' ]);
     }
 
-    register(id, factories, containers) {
-        for (const extension of factories) {
-            this._factories.push({ extension: extension, id: id });
-            this._extensions.add(extension);
+    register(module, factories, containers) {
+        for (const pattern of factories) {
+            this._factories.push({ pattern: pattern, module: module });
+            this._patterns.add(pattern);
         }
-        for (const extension of containers || []) {
-            this._extensions.add(extension);
+        for (const pattern of containers || []) {
+            this._patterns.add(pattern);
         }
     }
 
@@ -5773,8 +5773,9 @@ view.ModelFactoryService = class {
         const extension = identifier.indexOf('.') === -1 ? '' : identifier.split('.').pop().toLowerCase();
         identifier = identifier.toLowerCase().split('/').pop();
         let accept = false;
-        for (const extension of this._extensions) {
-            if ((typeof extension === 'string' && identifier.endsWith(extension)) || (extension instanceof RegExp && extension.exec(identifier))) {
+        for (const extension of this._patterns) {
+            if ((typeof extension === 'string' && identifier.endsWith(extension)) ||
+                (extension instanceof RegExp && extension.exec(identifier))) {
                 accept = true;
                 break;
             }
@@ -5790,9 +5791,9 @@ view.ModelFactoryService = class {
     _filter(context) {
         const identifier = context.identifier.toLowerCase().split('/').pop();
         const list = this._factories.filter((entry) =>
-            (typeof entry.extension === 'string' && identifier.endsWith(entry.extension)) ||
-            (entry.extension instanceof RegExp && entry.extension.exec(identifier)));
-        return Array.from(new Set(list.map((entry) => entry.id)));
+            (typeof entry.pattern === 'string' && identifier.endsWith(entry.pattern)) ||
+            (entry.pattern instanceof RegExp && entry.pattern.test(identifier)));
+        return Array.from(new Set(list.map((entry) => entry.module)));
     }
 
     async _openSignature(context) {