|
|
@@ -464,8 +464,7 @@ npz.Utility = class {
|
|
|
}
|
|
|
|
|
|
static weights(obj) {
|
|
|
- const keys = [ '', 'blobs' ];
|
|
|
- for (const key of keys) {
|
|
|
+ const dict = (obj, key) => {
|
|
|
const dict = key === '' ? obj : obj[key];
|
|
|
if (dict) {
|
|
|
const weights = new Map();
|
|
|
@@ -479,20 +478,26 @@ npz.Utility = class {
|
|
|
return weights;
|
|
|
}
|
|
|
else if (!Array.isArray(dict)) {
|
|
|
+ const set = new Set([ 'weight_order', 'lr', 'model_iter' ]);
|
|
|
for (const key in dict) {
|
|
|
const value = dict[key];
|
|
|
- if (key != 'weight_order' && key != 'lr') {
|
|
|
- if (!key || !npz.Utility.isTensor(value)) {
|
|
|
- return null;
|
|
|
+ if (key) {
|
|
|
+ if (npz.Utility.isTensor(value)) {
|
|
|
+ weights.set(key, value);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (set.has(key)) {
|
|
|
+ continue;
|
|
|
}
|
|
|
- weights.set(key, value);
|
|
|
}
|
|
|
+ return null;
|
|
|
}
|
|
|
return weights;
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
- for (const key of keys) {
|
|
|
+ return null;
|
|
|
+ };
|
|
|
+ const list = (obj, key) => {
|
|
|
const list = key === '' ? obj : obj[key];
|
|
|
if (list && Array.isArray(list)) {
|
|
|
const weights = new Map();
|
|
|
@@ -505,6 +510,19 @@ npz.Utility = class {
|
|
|
}
|
|
|
return weights;
|
|
|
}
|
|
|
+ };
|
|
|
+ const keys = [ '', 'blobs' ];
|
|
|
+ for (const key of keys) {
|
|
|
+ const weights = dict(obj, key);
|
|
|
+ if (weights) {
|
|
|
+ return weights;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (const key of keys) {
|
|
|
+ const weights = list(obj, key);
|
|
|
+ if (weights) {
|
|
|
+ return weights;
|
|
|
+ }
|
|
|
}
|
|
|
return null;
|
|
|
}
|