Lutz Roeder 5 anni fa
parent
commit
dc9b3e868d
1 ha cambiato i file con 26 aggiunte e 8 eliminazioni
  1. 26 8
      source/npz.js

+ 26 - 8
source/npz.js

@@ -464,8 +464,7 @@ npz.Utility = class {
     }
 
     static weights(obj) {
-        const keys = [ '', 'blobs' ];
-        for (const key of keys) {
+        const dict = (obj, key) => {
             const dict = key === '' ? obj : obj[key];
             if (dict) {
                 const weights = new Map();
@@ -479,20 +478,26 @@ npz.Utility = class {
                     return weights;
                 }
                 else if (!Array.isArray(dict)) {
+                    const set = new Set([ 'weight_order', 'lr', 'model_iter' ]);
                     for (const key in dict) {
                         const value = dict[key];
-                        if (key != 'weight_order' && key != 'lr') {
-                            if (!key || !npz.Utility.isTensor(value)) {
-                                return null;
+                        if (key) {
+                            if (npz.Utility.isTensor(value)) {
+                                weights.set(key, value);
+                                continue;
+                            }
+                            if (set.has(key)) {
+                                continue;
                             }
-                            weights.set(key, value);
                         }
+                        return null;
                     }
                     return weights;
                 }
             }
-        }
-        for (const key of keys) {
+            return null;
+        };
+        const list = (obj, key) => {
             const list = key === '' ? obj : obj[key];
             if (list && Array.isArray(list)) {
                 const weights = new Map();
@@ -505,6 +510,19 @@ npz.Utility = class {
                 }
                 return weights;
             }
+        };
+        const keys = [ '', 'blobs' ];
+        for (const key of keys) {
+            const weights = dict(obj, key);
+            if (weights) {
+                return weights;
+            }
+        }
+        for (const key of keys) {
+            const weights = list(obj, key);
+            if (weights) {
+                return weights;
+            }
         }
         return null;
     }