|
|
@@ -10,19 +10,23 @@ numpy.ModelFactory = class {
|
|
|
const stream = context.stream;
|
|
|
const signature = [ 0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59 ];
|
|
|
if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
|
|
|
- return 'npy';
|
|
|
+ return { name: 'npy' };
|
|
|
}
|
|
|
const entries = context.entries('zip');
|
|
|
if (entries.size > 0 && Array.from(entries.keys()).every((name) => name.endsWith('.npy'))) {
|
|
|
- return 'npz';
|
|
|
+ return { name: 'npz', value: entries };
|
|
|
}
|
|
|
const obj = context.open('pkl');
|
|
|
if (obj) {
|
|
|
if (numpy.Utility.isTensor(obj)) {
|
|
|
- return 'numpy.ndarray';
|
|
|
+ return { name: 'numpy.ndarray', value: obj };
|
|
|
}
|
|
|
- if (numpy.Utility.weights(obj)) {
|
|
|
- return 'pickle';
|
|
|
+ if (Array.isArray(obj) && obj.every((obj) => obj && obj.__class__ && obj.__class__.__name__ === 'Network' && (obj.__class__.__module__ === 'dnnlib.tflib.network' || obj.__class__.__module__ === 'tfutil'))) {
|
|
|
+ return { name: 'dnnlib.tflib.network', value: obj };
|
|
|
+ }
|
|
|
+ const weights = numpy.Utility.weights(obj);
|
|
|
+ if (weights) {
|
|
|
+ return { name: 'pickle', value: weights };
|
|
|
}
|
|
|
}
|
|
|
return undefined;
|
|
|
@@ -30,8 +34,8 @@ numpy.ModelFactory = class {
|
|
|
|
|
|
open(context, match) {
|
|
|
let format = '';
|
|
|
- const groups = new Map();
|
|
|
- switch (match) {
|
|
|
+ const graphs = [];
|
|
|
+ switch (match.name) {
|
|
|
case 'npy': {
|
|
|
format = 'NumPy Array';
|
|
|
const execution = new python.Execution(null);
|
|
|
@@ -39,19 +43,15 @@ numpy.ModelFactory = class {
|
|
|
const buffer = stream.peek();
|
|
|
const bytes = execution.invoke('io.BytesIO', [ buffer ]);
|
|
|
const array = execution.invoke('numpy.load', [ bytes ]);
|
|
|
- const group = { type: format, parameters: [] };
|
|
|
- group.parameters.push({
|
|
|
- name: 'value',
|
|
|
- tensor: { name: '', array: array }
|
|
|
- });
|
|
|
- groups.set('', group);
|
|
|
+ const layer = { type: 'numpy.ndarray', parameters: [ { name: 'value', tensor: { name: '', array: array } } ] };
|
|
|
+ graphs.push({ layers: [ layer ] });
|
|
|
break;
|
|
|
}
|
|
|
case 'npz': {
|
|
|
format = 'NumPy Zip';
|
|
|
+ const layers = new Map();
|
|
|
const execution = new python.Execution(null);
|
|
|
- const entries = context.entries('zip');
|
|
|
- for (const entry of entries) {
|
|
|
+ for (const entry of match.value) {
|
|
|
if (!entry[0].endsWith('.npy')) {
|
|
|
throw new numpy.Error("Invalid file name '" + entry[0] + "'.");
|
|
|
}
|
|
|
@@ -59,10 +59,10 @@ numpy.ModelFactory = class {
|
|
|
const parts = name.split('/');
|
|
|
const parameterName = parts.pop();
|
|
|
const groupName = parts.join('/');
|
|
|
- if (!groups.has(groupName)) {
|
|
|
- groups.set(groupName, { name: groupName, parameters: [] });
|
|
|
+ if (!layers.has(groupName)) {
|
|
|
+ layers.set(groupName, { name: groupName, parameters: [] });
|
|
|
}
|
|
|
- const group = groups.get(groupName);
|
|
|
+ const layer = layers.get(groupName);
|
|
|
const stream = entry[1];
|
|
|
const buffer = stream.peek();
|
|
|
const bytes = execution.invoke('io.BytesIO', [ buffer ]);
|
|
|
@@ -74,17 +74,18 @@ numpy.ModelFactory = class {
|
|
|
const unpickler = python.Unpickler.open(array.data);
|
|
|
array = unpickler.load((name, args) => execution.invoke(name, args));
|
|
|
}
|
|
|
- group.parameters.push({
|
|
|
+ layer.parameters.push({
|
|
|
name: parameterName,
|
|
|
tensor: { name: name, array: array }
|
|
|
});
|
|
|
}
|
|
|
+ graphs.push({ layers: Array.from(layers.values()) });
|
|
|
break;
|
|
|
}
|
|
|
case 'pickle': {
|
|
|
format = 'NumPy Weights';
|
|
|
- const obj = context.open('pkl');
|
|
|
- const weights = numpy.Utility.weights(obj);
|
|
|
+ const layers = new Map();
|
|
|
+ const weights = match.value;
|
|
|
let separator = '_';
|
|
|
if (Array.from(weights.keys()).every((key) => key.indexOf('.') !== -1) &&
|
|
|
!Array.from(weights.keys()).every((key) => key.indexOf('_') !== -1)) {
|
|
|
@@ -95,41 +96,64 @@ numpy.ModelFactory = class {
|
|
|
const array = pair[1];
|
|
|
const parts = name.split(separator);
|
|
|
const parameterName = parts.length > 1 ? parts.pop() : '?';
|
|
|
- const groupName = parts.join(separator);
|
|
|
- if (!groups.has(groupName)) {
|
|
|
- groups.set(groupName, { name: groupName, parameters: [] });
|
|
|
+ const layerName = parts.join(separator);
|
|
|
+ if (!layers.has(layerName)) {
|
|
|
+ layers.set(layerName, { name: layerName, parameters: [] });
|
|
|
}
|
|
|
- const group = groups.get(groupName);
|
|
|
- group.parameters.push({
|
|
|
+ const layer = layers.get(layerName);
|
|
|
+ layer.parameters.push({
|
|
|
name: parameterName,
|
|
|
tensor: { name: name, array: array }
|
|
|
});
|
|
|
}
|
|
|
+ graphs.push({ layers: Array.from(layers.values()) });
|
|
|
break;
|
|
|
}
|
|
|
case 'numpy.ndarray': {
|
|
|
format = 'NumPy NDArray';
|
|
|
- const array = context.open('pkl');
|
|
|
- const group = { type: 'numpy.ndarray', parameters: [] };
|
|
|
- group.parameters.push({
|
|
|
- name: 'data',
|
|
|
- tensor: { name: '', array: array }
|
|
|
- });
|
|
|
- groups.set('', group);
|
|
|
+ const layer = {
|
|
|
+ type: 'numpy.ndarray',
|
|
|
+ parameters: [ { name: 'value', tensor: { name: '', array: match.value } } ]
|
|
|
+ };
|
|
|
+ graphs.push({ layers: [ layer ] });
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ case 'dnnlib.tflib.network': {
|
|
|
+ format = 'dnnlib';
|
|
|
+ for (const obj of match.value) {
|
|
|
+ const layers = new Map();
|
|
|
+ for (const entry of obj.variables) {
|
|
|
+ const name = entry[0];
|
|
|
+ const value = entry[1];
|
|
|
+ if (numpy.Utility.isTensor(value)) {
|
|
|
+ const parts = name.split('/');
|
|
|
+ const parameterName = parts.length > 1 ? parts.pop() : '?';
|
|
|
+ const layerName = parts.join('/');
|
|
|
+ if (!layers.has(layerName)) {
|
|
|
+ layers.set(layerName, { name: layerName, parameters: [] });
|
|
|
+ }
|
|
|
+ const layer = layers.get(layerName);
|
|
|
+ layer.parameters.push({
|
|
|
+ name: parameterName,
|
|
|
+ tensor: { name: name, array: value }
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+ graphs.push({ name: obj.name, layers: Array.from(layers.values()) });
|
|
|
+ }
|
|
|
break;
|
|
|
}
|
|
|
}
|
|
|
- const model = new numpy.Model(format, groups.values());
|
|
|
+ const model = new numpy.Model(format, graphs);
|
|
|
return Promise.resolve(model);
|
|
|
}
|
|
|
};
|
|
|
|
|
|
numpy.Model = class {
|
|
|
|
|
|
- constructor(format, groups) {
|
|
|
+ constructor(format, graphs) {
|
|
|
this._format = format;
|
|
|
- this._graphs = [];
|
|
|
- this._graphs.push(new numpy.Graph(groups));
|
|
|
+ this._graphs = graphs.map((graph) => new numpy.Graph(graph));
|
|
|
}
|
|
|
|
|
|
get format() {
|
|
|
@@ -143,11 +167,13 @@ numpy.Model = class {
|
|
|
|
|
|
numpy.Graph = class {
|
|
|
|
|
|
- constructor(groups) {
|
|
|
- this._nodes = [];
|
|
|
- for (const group of groups) {
|
|
|
- this._nodes.push(new numpy.Node(group));
|
|
|
- }
|
|
|
+ constructor(graph) {
|
|
|
+ this._name = graph.name || '';
|
|
|
+ this._nodes = graph.layers.map((layer) => new numpy.Node(layer));
|
|
|
+ }
|
|
|
+
|
|
|
+ get name() {
|
|
|
+ return this._name;
|
|
|
}
|
|
|
|
|
|
get inputs() {
|
|
|
@@ -208,11 +234,11 @@ numpy.Argument = class {
|
|
|
|
|
|
numpy.Node = class {
|
|
|
|
|
|
- constructor(group) {
|
|
|
- this._name = group.name || '';
|
|
|
- this._type = { name: group.type || 'Module' };
|
|
|
+ constructor(layer) {
|
|
|
+ this._name = layer.name || '';
|
|
|
+ this._type = { name: layer.type || 'Module' };
|
|
|
this._inputs = [];
|
|
|
- for (const parameter of group.parameters) {
|
|
|
+ for (const parameter of layer.parameters) {
|
|
|
const initializer = new numpy.Tensor(parameter.tensor.array);
|
|
|
this._inputs.push(new numpy.Parameter(parameter.name, [
|
|
|
new numpy.Argument(parameter.tensor.name || '', initializer)
|