|
|
@@ -9,63 +9,106 @@ npz.ModelFactory = class {
|
|
|
|
|
|
match(context) {
|
|
|
const entries = context.entries('zip');
|
|
|
- return entries.length > 0 && entries.every((entry) => entry.name.endsWith('.npy'));
|
|
|
+ if (entries.length > 0 && entries.every((entry) => entry.name.endsWith('.npy'))) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ const tags = context.tags('pkl');
|
|
|
+ if (tags.size === 1 && tags.keys().next().value === '') {
|
|
|
+ if (npz.Utility.weights(tags.values().next().value)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false;
|
|
|
}
|
|
|
|
|
|
open(context) {
|
|
|
return context.require('./numpy').then((numpy) => {
|
|
|
- const modules = [];
|
|
|
- const modulesMap = new Map();
|
|
|
- const dataTypeMap = new Map([
|
|
|
- [ 'i1', 'int8'], [ 'i2', 'int16' ], [ 'i4', 'int32'], [ 'i8', 'int64' ],
|
|
|
- [ 'u1', 'uint8'], [ 'u2', 'uint16' ], [ 'u4', 'uint32'], [ 'u8', 'uint64' ],
|
|
|
- [ 'f2', 'float16'], [ 'f4', 'float32' ], [ 'f8', 'float64']
|
|
|
- ]);
|
|
|
- const execution = new python.Execution(null);
|
|
|
- for (const entry of context.entries('zip')) {
|
|
|
- if (!entry.name.endsWith('.npy')) {
|
|
|
- throw new npz.Error("Invalid file name '" + entry.name + "'.");
|
|
|
+ const tags = context.tags('pkl');
|
|
|
+ const groups = new Map();
|
|
|
+ let format = '';
|
|
|
+ if (tags.size === 1) {
|
|
|
+ format = 'NumPy Weights';
|
|
|
+ const weights = npz.Utility.weights(tags.values().next().value);
|
|
|
+ let separator = '_';
|
|
|
+ if (Array.from(weights.keys()).every((key) => key.indexOf('.') !== -1) &&
|
|
|
+ !Array.from(weights.keys()).every((key) => key.indexOf('_') !== -1)) {
|
|
|
+ separator = '.';
|
|
|
}
|
|
|
- const name = entry.name.replace(/\.npy$/, '');
|
|
|
- const parts = name.split('/');
|
|
|
- const parameterName = parts.pop();
|
|
|
- const moduleName = parts.join('/');
|
|
|
- if (!modulesMap.has(moduleName)) {
|
|
|
- const newModule = { name: moduleName, parameters: [] };
|
|
|
- modules.push(newModule);
|
|
|
- modulesMap.set(moduleName, newModule);
|
|
|
+ for (const pair of weights) {
|
|
|
+ const name = pair[0];
|
|
|
+ const value = pair[1];
|
|
|
+ const parts = name.split(separator);
|
|
|
+ const parameterName = parts.length > 1 ? parts.pop() : '?';
|
|
|
+ const groupName = parts.join(separator);
|
|
|
+ if (!groups.has(groupName)) {
|
|
|
+ groups.set(groupName, { name: groupName, parameters: [] });
|
|
|
+ }
|
|
|
+ const group = groups.get(groupName);
|
|
|
+ group.parameters.push({
|
|
|
+ name: parameterName,
|
|
|
+ tensor: {
|
|
|
+ name: name,
|
|
|
+ byteOrder: value.dtype.byteorder,
|
|
|
+ dataType: value.dtype.name,
|
|
|
+ shape: value.shape,
|
|
|
+ data: value.data
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
- const module = modulesMap.get(moduleName);
|
|
|
- const data = entry.data;
|
|
|
- let array = new numpy.Array(data);
|
|
|
- if (array.byteOrder === '|') {
|
|
|
- if (array.dataType !== 'O') {
|
|
|
- throw new npz.Error("Invalid data type '" + array.dataType + "'.");
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ format = 'NumPy Zip';
|
|
|
+ const dataTypeMap = new Map([
|
|
|
+ [ 'i1', 'int8'], [ 'i2', 'int16' ], [ 'i4', 'int32'], [ 'i8', 'int64' ],
|
|
|
+ [ 'u1', 'uint8'], [ 'u2', 'uint16' ], [ 'u4', 'uint32'], [ 'u8', 'uint64' ],
|
|
|
+ [ 'f2', 'float16'], [ 'f4', 'float32' ], [ 'f8', 'float64']
|
|
|
+ ]);
|
|
|
+ const execution = new python.Execution(null);
|
|
|
+ for (const entry of context.entries('zip')) {
|
|
|
+ if (!entry.name.endsWith('.npy')) {
|
|
|
+ throw new npz.Error("Invalid file name '" + entry.name + "'.");
|
|
|
+ }
|
|
|
+ const name = entry.name.replace(/\.npy$/, '');
|
|
|
+ const parts = name.split('/');
|
|
|
+ const parameterName = parts.pop();
|
|
|
+ const groupName = parts.join('/');
|
|
|
+ if (!groups.has(groupName)) {
|
|
|
+ groups.set(groupName, { name: groupName, parameters: [] });
|
|
|
+ }
|
|
|
+ const group = groups.get(groupName);
|
|
|
+ const data = entry.data;
|
|
|
+ let array = new numpy.Array(data);
|
|
|
+ if (array.byteOrder === '|') {
|
|
|
+ if (array.dataType !== 'O') {
|
|
|
+ throw new npz.Error("Invalid data type '" + array.dataType + "'.");
|
|
|
+ }
|
|
|
+ const unpickler = new python.Unpickler(array.data);
|
|
|
+ const root = unpickler.load((name, args) => execution.invoke(name, args));
|
|
|
+ array = { dataType: root.dtype.name, shape: null, data: null, byteOrder: '|' };
|
|
|
}
|
|
|
- const unpickler = new python.Unpickler(array.data);
|
|
|
- const root = unpickler.load((name, args) => execution.invoke(name, args));
|
|
|
- array = { dataType: root.dtype.name, shape: null, data: null, byteOrder: '|' };
|
|
|
+ group.parameters.push({
|
|
|
+ name: parameterName,
|
|
|
+ tensor: {
|
|
|
+ name: name,
|
|
|
+ byteOrder: array.byteOrder,
|
|
|
+ dataType: dataTypeMap.has(array.dataType) ? dataTypeMap.get(array.dataType) : array.dataType,
|
|
|
+ shape: array.shape,
|
|
|
+ data: array.data,
|
|
|
+ }
|
|
|
+ });
|
|
|
}
|
|
|
-
|
|
|
- module.parameters.push({
|
|
|
- name: parameterName,
|
|
|
- dataType: dataTypeMap.has(array.dataType) ? dataTypeMap.get(array.dataType) : array.dataType,
|
|
|
- shape: array.shape,
|
|
|
- data: array.data,
|
|
|
- byteOrder: array.byteOrder
|
|
|
- });
|
|
|
}
|
|
|
- return new npz.Model(modules, 'NumPy Zip');
|
|
|
+ return new npz.Model(format, groups.values());
|
|
|
});
|
|
|
}
|
|
|
};
|
|
|
|
|
|
npz.Model = class {
|
|
|
|
|
|
- constructor(modules, format) {
|
|
|
+ constructor(format, groups) {
|
|
|
this._format = format;
|
|
|
this._graphs = [];
|
|
|
- this._graphs.push(new npz.Graph(modules));
|
|
|
+ this._graphs.push(new npz.Graph(groups));
|
|
|
}
|
|
|
|
|
|
get format() {
|
|
|
@@ -79,10 +122,10 @@ npz.Model = class {
|
|
|
|
|
|
npz.Graph = class {
|
|
|
|
|
|
- constructor(modules) {
|
|
|
+ constructor(groups) {
|
|
|
this._nodes = [];
|
|
|
- for (const module of modules) {
|
|
|
- this._nodes.push(new npz.Node(module));
|
|
|
+ for (const group of groups) {
|
|
|
+ this._nodes.push(new npz.Node(group));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -144,14 +187,15 @@ npz.Argument = class {
|
|
|
|
|
|
npz.Node = class {
|
|
|
|
|
|
- constructor(module) {
|
|
|
- this._name = module.name;
|
|
|
+ constructor(group) {
|
|
|
+ this._name = group.name;
|
|
|
this._inputs = [];
|
|
|
- for (const parameter of module.parameters) {
|
|
|
+ for (const parameter of group.parameters) {
|
|
|
const name = this._name ? [ this._name, parameter.name ].join('/') : parameter.name;
|
|
|
- const initializer = new npz.Tensor(name, parameter.dataType, parameter.shape, parameter.data, parameter.byteOrder);
|
|
|
+ const tensor = parameter.tensor;
|
|
|
+ const initializer = new npz.Tensor(name, tensor.dataType, tensor.shape, tensor.data, tensor.byteOrder);
|
|
|
this._inputs.push(new npz.Parameter(parameter.name, [
|
|
|
- new npz.Argument(name, initializer)
|
|
|
+ new npz.Argument(tensor.name || '', initializer)
|
|
|
]));
|
|
|
}
|
|
|
}
|
|
|
@@ -192,7 +236,7 @@ npz.Tensor = class {
|
|
|
}
|
|
|
|
|
|
get kind() {
|
|
|
- return '';
|
|
|
+ return 'NumPy Array';
|
|
|
}
|
|
|
|
|
|
get name() {
|
|
|
@@ -415,6 +459,59 @@ npz.TensorShape = class {
|
|
|
}
|
|
|
};
|
|
|
|
|
|
+npz.Utility = class {
|
|
|
+
|
|
|
+ static isTensor(obj) {
|
|
|
+ return obj && obj.__module__ === 'numpy' && obj.__name__ === 'ndarray';
|
|
|
+ }
|
|
|
+
|
|
|
+ static weights(obj) {
|
|
|
+ const keys = [ '', 'blobs' ];
|
|
|
+ for (const key of keys) {
|
|
|
+ const dict = key === '' ? obj : obj[key];
|
|
|
+ if (dict) {
|
|
|
+ const weights = new Map();
|
|
|
+ if (dict instanceof Map) {
|
|
|
+ for (const pair of dict) {
|
|
|
+ if (!npz.Utility.isTensor(pair[1])) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ weights.set(pair[0], pair[1]);
|
|
|
+ }
|
|
|
+ return weights;
|
|
|
+ }
|
|
|
+ else if (!Array.isArray(dict)) {
|
|
|
+ for (const key in dict) {
|
|
|
+ const value = dict[key];
|
|
|
+ if (key != 'weight_order' && key != 'lr') {
|
|
|
+ if (!key || !npz.Utility.isTensor(value)) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ weights.set(key, value);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return weights;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (const key of keys) {
|
|
|
+ const list = key === '' ? obj : obj[key];
|
|
|
+ if (list && Array.isArray(list)) {
|
|
|
+ const weights = new Map();
|
|
|
+ for (let i = 0; i < list.length; i++) {
|
|
|
+ const value = list[i];
|
|
|
+ if (!npz.Utility.isTensor(value, 'numpy.ndarray')) {
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ weights.set(i.toString(), value);
|
|
|
+ }
|
|
|
+ return weights;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
npz.Error = class extends Error {
|
|
|
|
|
|
constructor(message) {
|