Quellcode durchsuchen

Update numpy.js (#859)

Lutz Roeder vor 4 Jahren
Ursprung
Commit
3ddc7c33d2
1 geänderte Dateien mit 72 neuen und 46 gelöschten Zeilen
  1. 72 46
      source/numpy.js

+ 72 - 46
source/numpy.js

@@ -10,19 +10,23 @@ numpy.ModelFactory = class {
         const stream = context.stream;
         const signature = [ 0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59 ];
         if (signature.length <= stream.length && stream.peek(signature.length).every((value, index) => value === signature[index])) {
-            return 'npy';
+            return { name: 'npy' };
         }
         const entries = context.entries('zip');
         if (entries.size > 0 && Array.from(entries.keys()).every((name) => name.endsWith('.npy'))) {
-            return 'npz';
+            return { name: 'npz', value: entries };
         }
         const obj = context.open('pkl');
         if (obj) {
             if (numpy.Utility.isTensor(obj)) {
-                return 'numpy.ndarray';
+                return { name: 'numpy.ndarray', value: obj };
             }
-            if (numpy.Utility.weights(obj)) {
-                return 'pickle';
+            if (Array.isArray(obj) && obj.every((obj) => obj && obj.__class__ && obj.__class__.__name__ === 'Network' && (obj.__class__.__module__ === 'dnnlib.tflib.network' || obj.__class__.__module__ === 'tfutil'))) {
+                return { name: 'dnnlib.tflib.network', value: obj };
+            }
+            const weights = numpy.Utility.weights(obj);
+            if (weights) {
+                return { name: 'pickle', value: weights };
             }
         }
         return undefined;
@@ -30,8 +34,8 @@ numpy.ModelFactory = class {
 
     open(context, match) {
         let format = '';
-        const groups = new Map();
-        switch (match) {
+        const graphs = [];
+        switch (match.name) {
             case 'npy': {
                 format = 'NumPy Array';
                 const execution = new python.Execution(null);
@@ -39,19 +43,15 @@ numpy.ModelFactory = class {
                 const buffer = stream.peek();
                 const bytes = execution.invoke('io.BytesIO', [ buffer ]);
                 const array = execution.invoke('numpy.load', [ bytes ]);
-                const group = { type: format, parameters: [] };
-                group.parameters.push({
-                    name: 'value',
-                    tensor: { name: '', array: array }
-                });
-                groups.set('', group);
+                const layer = { type: 'numpy.ndarray', parameters: [ { name: 'value', tensor: { name: '', array: array } } ] };
+                graphs.push({ layers: [ layer ] });
                 break;
             }
             case 'npz': {
                 format = 'NumPy Zip';
+                const layers = new Map();
                 const execution = new python.Execution(null);
-                const entries = context.entries('zip');
-                for (const entry of entries) {
+                for (const entry of match.value) {
                     if (!entry[0].endsWith('.npy')) {
                         throw new numpy.Error("Invalid file name '" + entry.name + "'.");
                     }
@@ -59,10 +59,10 @@ numpy.ModelFactory = class {
                     const parts = name.split('/');
                     const parameterName = parts.pop();
                     const groupName = parts.join('/');
-                    if (!groups.has(groupName)) {
-                        groups.set(groupName, { name: groupName, parameters: [] });
+                    if (!layers.has(groupName)) {
+                        layers.set(groupName, { name: groupName, parameters: [] });
                     }
-                    const group = groups.get(groupName);
+                    const layer = layers.get(groupName);
                     const stream = entry[1];
                     const buffer = stream.peek();
                     const bytes = execution.invoke('io.BytesIO', [ buffer ]);
@@ -74,17 +74,18 @@ numpy.ModelFactory = class {
                         const unpickler = python.Unpickler.open(array.data);
                         array = unpickler.load((name, args) => execution.invoke(name, args));
                     }
-                    group.parameters.push({
+                    layer.parameters.push({
                         name: parameterName,
                         tensor: { name: name, array: array }
                     });
                 }
+                graphs.push({ layers: Array.from(layers.values()) });
                 break;
             }
             case 'pickle': {
                 format = 'NumPy Weights';
-                const obj = context.open('pkl');
-                const weights = numpy.Utility.weights(obj);
+                const layers = new Map();
+                const weights = match.value;
                 let separator = '_';
                 if (Array.from(weights.keys()).every((key) => key.indexOf('.') !== -1) &&
                     !Array.from(weights.keys()).every((key) => key.indexOf('_') !== -1)) {
@@ -95,41 +96,64 @@ numpy.ModelFactory = class {
                     const array = pair[1];
                     const parts = name.split(separator);
                     const parameterName = parts.length > 1 ? parts.pop() : '?';
-                    const groupName = parts.join(separator);
-                    if (!groups.has(groupName)) {
-                        groups.set(groupName, { name: groupName, parameters: [] });
+                    const layerName = parts.join(separator);
+                    if (!layers.has(layerName)) {
+                        layers.set(layerName, { name: layerName, parameters: [] });
                     }
-                    const group = groups.get(groupName);
-                    group.parameters.push({
+                    const layer = layers.get(layerName);
+                    layer.parameters.push({
                         name: parameterName,
                         tensor: { name: name, array: array }
                     });
                 }
+                graphs.push({ layers: Array.from(layers.values()) });
                 break;
             }
             case 'numpy.ndarray': {
                 format = 'NumPy NDArray';
-                const array = context.open('pkl');
-                const group = { type: 'numpy.ndarray', parameters: [] };
-                group.parameters.push({
-                    name: 'data',
-                    tensor: { name: '', array: array }
-                });
-                groups.set('', group);
+                const layer = {
+                    type: 'numpy.ndarray',
+                    parameters: [ { name: 'value', tensor: { name: '', array: match.value } } ]
+                };
+                graphs.push({ layers: [ layer ] });
+                break;
+            }
+            case 'dnnlib.tflib.network': {
+                format = 'dnnlib';
+                for (const obj of match.value) {
+                    const layers = new Map();
+                    for (const entry of obj.variables) {
+                        const name = entry[0];
+                        const value = entry[1];
+                        if (numpy.Utility.isTensor(value)) {
+                            const parts = name.split('/');
+                            const parameterName = parts.length > 1 ? parts.pop() : '?';
+                            const layerName = parts.join('/');
+                            if (!layers.has(layerName)) {
+                                layers.set(layerName, { name: layerName, parameters: [] });
+                            }
+                            const layer = layers.get(layerName);
+                            layer.parameters.push({
+                                name: parameterName,
+                                tensor: { name: name, array: value }
+                            });
+                        }
+                    }
+                    graphs.push({ name: obj.name, layers: Array.from(layers.values()) });
+                }
                 break;
             }
         }
-        const model = new numpy.Model(format, groups.values());
+        const model = new numpy.Model(format, graphs);
         return Promise.resolve(model);
     }
 };
 
 numpy.Model = class {
 
-    constructor(format, groups) {
+    constructor(format, graphs) {
         this._format = format;
-        this._graphs = [];
-        this._graphs.push(new numpy.Graph(groups));
+        this._graphs = graphs.map((graph) => new numpy.Graph(graph));
     }
 
     get format() {
@@ -143,11 +167,13 @@ numpy.Model = class {
 
 numpy.Graph = class {
 
-    constructor(groups) {
-        this._nodes = [];
-        for (const group of groups) {
-            this._nodes.push(new numpy.Node(group));
-        }
+    constructor(graph) {
+        this._name = graph.name || '';
+        this._nodes = graph.layers.map((layer) => new numpy.Node(layer));
+    }
+
+    get name() {
+        return this._name;
     }
 
     get inputs() {
@@ -208,11 +234,11 @@ numpy.Argument = class {
 
 numpy.Node = class {
 
-    constructor(group) {
-        this._name = group.name || '';
-        this._type = { name: group.type || 'Module' };
+    constructor(layer) {
+        this._name = layer.name || '';
+        this._type = { name: layer.type || 'Module' };
         this._inputs = [];
-        for (const parameter of group.parameters) {
+        for (const parameter of layer.parameters) {
             const initializer = new numpy.Tensor(parameter.tensor.array);
             this._inputs.push(new numpy.Parameter(parameter.name, [
                 new numpy.Argument(parameter.tensor.name || '', initializer)