Lutz Roeder преди 2 години
родител
ревизия
336e5f1b05
променени са 1 файла, в които са добавени 8 реда и са изтрити 7 реда
  1. 8 7
      source/pytorch.js

+ 8 - 7
source/pytorch.js

@@ -3699,10 +3699,11 @@ pytorch.Utility = class {
             }
             return obj;
         };
-        const validate = (map) => {
-            let tensor = false;
-            if (map && map instanceof Map) {
-                for (const [key, value] of map) {
+        const validate = (entries) => {
+            let count = 0;
+            if (entries && entries instanceof Map) {
+                entries.delete('_extra_state');
+                for (const [key, value] of entries) {
                     const separator = key.indexOf('.') === -1 && key.indexOf('|') !== -1 ? '|' : '.';
                     const keys = key.split(separator);
                     if (keys[keys.length - 1] === '_metadata') {
@@ -3710,10 +3711,10 @@ pytorch.Utility = class {
                     } else if (keys.length >= 2 && keys[keys.length - 2] === '_packed_params') {
                         continue;
                     } else if (pytorch.Utility.isTensor(value)) {
-                        tensor = true;
+                        count++;
                         continue;
                     } else if (value && Array.isArray(value) && value.every((item) => pytorch.Utility.isTensor(item))) {
-                        tensor = true;
+                        count++;
                         continue;
                     } else if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
                         continue;
@@ -3723,7 +3724,7 @@ pytorch.Utility = class {
                     return false;
                 }
             }
-            return tensor;
+            return count > 0;
         };
         const flatten = (obj) => {
             if (!obj || Array.isArray(obj) || ArrayBuffer.isView(obj)) {