|
|
@@ -3699,10 +3699,11 @@ pytorch.Utility = class {
|
|
|
}
|
|
|
return obj;
|
|
|
};
|
|
|
- const validate = (map) => {
|
|
|
- let tensor = false;
|
|
|
- if (map && map instanceof Map) {
|
|
|
- for (const [key, value] of map) {
|
|
|
+ const validate = (entries) => {
|
|
|
+ let count = 0;
|
|
|
+ if (entries && entries instanceof Map) {
|
|
|
+ entries.delete('_extra_state');
|
|
|
+ for (const [key, value] of entries) {
|
|
|
const separator = key.indexOf('.') === -1 && key.indexOf('|') !== -1 ? '|' : '.';
|
|
|
const keys = key.split(separator);
|
|
|
if (keys[keys.length - 1] === '_metadata') {
|
|
|
@@ -3710,10 +3711,10 @@ pytorch.Utility = class {
|
|
|
} else if (keys.length >= 2 && keys[keys.length - 2] === '_packed_params') {
|
|
|
continue;
|
|
|
} else if (pytorch.Utility.isTensor(value)) {
|
|
|
- tensor = true;
|
|
|
+ count++;
|
|
|
continue;
|
|
|
} else if (value && Array.isArray(value) && value.every((item) => pytorch.Utility.isTensor(item))) {
|
|
|
- tensor = true;
|
|
|
+ count++;
|
|
|
continue;
|
|
|
} else if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
|
|
|
continue;
|
|
|
@@ -3723,7 +3724,7 @@ pytorch.Utility = class {
|
|
|
return false;
|
|
|
}
|
|
|
}
|
|
|
- return tensor;
|
|
|
+ return count > 0;
|
|
|
};
|
|
|
const flatten = (obj) => {
|
|
|
if (!obj || Array.isArray(obj) || ArrayBuffer.isView(obj)) {
|