|
|
@@ -3545,22 +3545,22 @@ pytorch.Utility = class {
|
|
|
return `${name} ${versions.get(value)}`;
|
|
|
}
|
|
|
|
|
|
- static find(obj) {
|
|
|
- if (obj) {
|
|
|
- if (pytorch.Utility.isTensor(obj)) {
|
|
|
+ static find(data) {
|
|
|
+ if (data) {
|
|
|
+ if (pytorch.Utility.isTensor(data)) {
|
|
|
const module = {};
|
|
|
module.__class__ = {
|
|
|
- __module__: obj.__class__.__module__,
|
|
|
- __name__: obj.__class__.__name__
|
|
|
+ __module__: data.__class__.__module__,
|
|
|
+ __name__: data.__class__.__name__
|
|
|
};
|
|
|
module._parameters = new Map();
|
|
|
- module._parameters.set('value', obj);
|
|
|
+ module._parameters.set('value', data);
|
|
|
return new Map([['', { _modules: new Map([['', module]]) }]]);
|
|
|
}
|
|
|
- if (!Array.isArray(obj) && !(obj instanceof Map) && obj === Object(obj) && Object.keys(obj).length === 0) {
|
|
|
+ if (!Array.isArray(data) && !(data instanceof Map) && data === Object(data) && Object.keys(data).length === 0) {
|
|
|
return new Map();
|
|
|
}
|
|
|
- const keys = Array.isArray(obj) ? [] : Object.keys(obj);
|
|
|
+ const keys = Array.isArray(data) ? [] : Object.keys(data);
|
|
|
if (keys.length > 1) {
|
|
|
keys.splice(0, keys.length);
|
|
|
}
|
|
|
@@ -3572,28 +3572,22 @@ pytorch.Utility = class {
|
|
|
'EMA_generator', 'runner', ''
|
|
|
]);
|
|
|
for (const key of keys) {
|
|
|
- const value = key === '' ? obj : obj[key];
|
|
|
- let graphs = null;
|
|
|
- graphs = graphs || pytorch.Utility._convertObjectList(value);
|
|
|
- graphs = graphs || pytorch.Utility._convertStateDict(value);
|
|
|
+ const obj = key === '' ? data : data[key];
|
|
|
+ if (obj && Array.isArray(obj)) {
|
|
|
+ if (obj.every((item) => typeof item === 'number' || typeof item === 'string')) {
|
|
|
+ return new Map([['', obj]]);
|
|
|
+ }
|
|
|
+ if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) {
|
|
|
+ return new Map([['', obj]]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ const graphs = pytorch.Utility._convertStateDict(obj);
|
|
|
if (graphs) {
|
|
|
return graphs;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- return new Map([['', obj]]);
|
|
|
- }
|
|
|
-
|
|
|
- static _convertObjectList(obj) {
|
|
|
- if (obj && Array.isArray(obj)) {
|
|
|
- if (obj.every((item) => typeof item === 'number' || typeof item === 'string')) {
|
|
|
- return new Map([['', obj]]);
|
|
|
- }
|
|
|
- if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) {
|
|
|
- return new Map([['', obj]]);
|
|
|
- }
|
|
|
- }
|
|
|
- return null;
|
|
|
+ return new Map([['', data]]);
|
|
|
}
|
|
|
|
|
|
static _convertStateDict(obj) {
|
|
|
@@ -3630,22 +3624,25 @@ pytorch.Utility = class {
|
|
|
return count > 0;
|
|
|
};
|
|
|
const isLayer = (obj) => {
|
|
|
- if (obj instanceof Map === false) {
|
|
|
+ if (Object(obj) === obj) {
|
|
|
obj = new Map(Object.entries(obj));
|
|
|
}
|
|
|
- for (const [key, value] of Array.from(obj)) {
|
|
|
- if (pytorch.Utility.isTensor(value)) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- if (key === '_metadata') {
|
|
|
- continue;
|
|
|
- }
|
|
|
- if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
|
|
|
- continue;
|
|
|
+ if (obj instanceof Map) {
|
|
|
+ for (const [key, value] of Array.from(obj)) {
|
|
|
+ if (pytorch.Utility.isTensor(value)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (key === '_metadata') {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ return false;
|
|
|
}
|
|
|
- return false;
|
|
|
+ return true;
|
|
|
}
|
|
|
- return true;
|
|
|
+ return false;
|
|
|
};
|
|
|
const flatten = (obj) => {
|
|
|
if (!obj || Array.isArray(obj) || ArrayBuffer.isView(obj)) {
|
|
|
@@ -3666,6 +3663,9 @@ pytorch.Utility = class {
|
|
|
}
|
|
|
const target = new Map();
|
|
|
for (const [name, obj] of map) {
|
|
|
+ if (obj && pytorch.Utility.isInstance(obj, 'builtins.type')) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
const value = flatten(obj);
|
|
|
if (value && value instanceof Map) {
|
|
|
for (const pair of value) {
|
|
|
@@ -3687,9 +3687,9 @@ pytorch.Utility = class {
|
|
|
}
|
|
|
} else if (obj instanceof Map && validate(obj)) {
|
|
|
map.set('', flatten(obj));
|
|
|
- } else if ((Object(obj) === obj && Object.entries(obj).every(([, value]) => value && isLayer(value)))) {
|
|
|
+ } else if (obj instanceof Map === false && Object(obj) === obj && Object.entries(obj).every(([, value]) => value && isLayer(value))) {
|
|
|
return new Map([['', { _modules: new Map(Object.entries(obj)) }]]);
|
|
|
- } else if (Object(obj) === obj && Object.entries(obj).every(([, value]) => validate(value))) {
|
|
|
+ } else if (obj instanceof Map === false && Object(obj) === obj && Object.entries(obj).every(([, value]) => validate(value))) {
|
|
|
for (const [name, value] of Object.entries(obj)) {
|
|
|
if (Object(value) === value) {
|
|
|
map.set(name, new Map(Object.entries(value)));
|
|
|
@@ -4218,13 +4218,36 @@ pytorch.Metadata = class {
|
|
|
numpy.Tensor = class {
|
|
|
|
|
|
constructor(array) {
|
|
|
- this.type = new pytorch.TensorType(array.dtype.__name__, new pytorch.TensorShape(array.shape));
|
|
|
+ this.type = new numpy.TensorType(array.dtype.__name__, new numpy.TensorShape(array.shape));
|
|
|
this.stride = array.strides.map((stride) => stride / array.itemsize);
|
|
|
this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes();
|
|
|
this.encoding = this.type.dataType === 'string' || this.type.dataType === 'object' ? '|' : array.dtype.byteorder;
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+numpy.TensorType = class {
|
|
|
+
|
|
|
+ constructor(dataType, shape) {
|
|
|
+ this.dataType = dataType || '?';
|
|
|
+ this.shape = shape;
|
|
|
+ }
|
|
|
+
|
|
|
+ toString() {
|
|
|
+ return this.dataType + this.shape.toString();
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+numpy.TensorShape = class {
|
|
|
+
|
|
|
+ constructor(dimensions) {
|
|
|
+ this.dimensions = dimensions;
|
|
|
+ }
|
|
|
+
|
|
|
+ toString() {
|
|
|
+ return this.dimensions && this.dimensions.length > 0 ? `[${this.dimensions.join(',')}]` : '';
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
pytorch.Error = class extends Error {
|
|
|
|
|
|
constructor(message) {
|