소스 검색

Update PyTorch test files (#543)

Lutz Roeder 1 년 전
부모
커밋
5475205a7d
3개의 변경된 파일82개의 추가작업 그리고 41개의 파일을 삭제
  1. 63 40
      source/pytorch.js
  2. 14 1
      source/view.js
  3. 5 0
      test/models.json

+ 63 - 40
source/pytorch.js

@@ -3545,22 +3545,22 @@ pytorch.Utility = class {
         return `${name} ${versions.get(value)}`;
     }
 
-    static find(obj) {
-        if (obj) {
-            if (pytorch.Utility.isTensor(obj)) {
+    static find(data) {
+        if (data) {
+            if (pytorch.Utility.isTensor(data)) {
                 const module = {};
                 module.__class__ = {
-                    __module__: obj.__class__.__module__,
-                    __name__: obj.__class__.__name__
+                    __module__: data.__class__.__module__,
+                    __name__: data.__class__.__name__
                 };
                 module._parameters = new Map();
-                module._parameters.set('value', obj);
+                module._parameters.set('value', data);
                 return new Map([['', { _modules: new Map([['', module]]) }]]);
             }
-            if (!Array.isArray(obj) && !(obj instanceof Map) && obj === Object(obj) && Object.keys(obj).length === 0) {
+            if (!Array.isArray(data) && !(data instanceof Map) && data === Object(data) && Object.keys(data).length === 0) {
                 return new Map();
             }
-            const keys = Array.isArray(obj) ? [] : Object.keys(obj);
+            const keys = Array.isArray(data) ? [] : Object.keys(data);
             if (keys.length > 1) {
                 keys.splice(0, keys.length);
             }
@@ -3572,28 +3572,22 @@ pytorch.Utility = class {
                 'EMA_generator', 'runner', ''
             ]);
             for (const key of keys) {
-                const value = key === '' ? obj : obj[key];
-                let graphs = null;
-                graphs = graphs || pytorch.Utility._convertObjectList(value);
-                graphs = graphs || pytorch.Utility._convertStateDict(value);
+                const obj = key === '' ? data : data[key];
+                if (obj && Array.isArray(obj)) {
+                    if (obj.every((item) => typeof item === 'number' || typeof item === 'string')) {
+                        return new Map([['', obj]]);
+                    }
+                    if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) {
+                        return new Map([['', obj]]);
+                    }
+                }
+                const graphs = pytorch.Utility._convertStateDict(obj);
                 if (graphs) {
                     return graphs;
                 }
             }
         }
-        return new Map([['', obj]]);
-    }
-
-    static _convertObjectList(obj) {
-        if (obj && Array.isArray(obj)) {
-            if (obj.every((item) => typeof item === 'number' || typeof item === 'string')) {
-                return new Map([['', obj]]);
-            }
-            if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) {
-                return new Map([['', obj]]);
-            }
-        }
-        return null;
+        return new Map([['', data]]);
     }
 
     static _convertStateDict(obj) {
@@ -3630,22 +3624,25 @@ pytorch.Utility = class {
             return count > 0;
         };
         const isLayer = (obj) => {
-            if (obj instanceof Map === false) {
+            if (Object(obj) === obj) {
                 obj = new Map(Object.entries(obj));
             }
-            for (const [key, value] of Array.from(obj)) {
-                if (pytorch.Utility.isTensor(value)) {
-                    continue;
-                }
-                if (key === '_metadata') {
-                    continue;
-                }
-                if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
-                    continue;
+            if (obj instanceof Map) {
+                for (const [key, value] of Array.from(obj)) {
+                    if (pytorch.Utility.isTensor(value)) {
+                        continue;
+                    }
+                    if (key === '_metadata') {
+                        continue;
+                    }
+                    if (typeof value === 'string' || typeof value === 'number' || typeof value === 'boolean') {
+                        continue;
+                    }
+                    return false;
                 }
-                return false;
+                return true;
             }
-            return true;
+            return false;
         };
         const flatten = (obj) => {
             if (!obj || Array.isArray(obj) || ArrayBuffer.isView(obj)) {
@@ -3666,6 +3663,9 @@ pytorch.Utility = class {
             }
             const target = new Map();
             for (const [name, obj] of map) {
+                if (obj && pytorch.Utility.isInstance(obj, 'builtins.type')) {
+                    return null;
+                }
                 const value = flatten(obj);
                 if (value && value instanceof Map) {
                     for (const pair of value) {
@@ -3687,9 +3687,9 @@ pytorch.Utility = class {
             }
         } else if (obj instanceof Map && validate(obj)) {
             map.set('', flatten(obj));
-        } else if ((Object(obj) === obj && Object.entries(obj).every(([, value]) => value && isLayer(value)))) {
+        } else if (obj instanceof Map === false && Object(obj) === obj && Object.entries(obj).every(([, value]) => value && isLayer(value))) {
             return new Map([['', { _modules: new Map(Object.entries(obj)) }]]);
-        } else if (Object(obj) === obj && Object.entries(obj).every(([, value]) => validate(value))) {
+        } else if (obj instanceof Map === false && Object(obj) === obj && Object.entries(obj).every(([, value]) => validate(value))) {
             for (const [name, value] of Object.entries(obj)) {
                 if (Object(value) === value) {
                     map.set(name, new Map(Object.entries(value)));
@@ -4218,13 +4218,36 @@ pytorch.Metadata = class {
 numpy.Tensor = class  {
 
     constructor(array) {
-        this.type = new pytorch.TensorType(array.dtype.__name__, new pytorch.TensorShape(array.shape));
+        this.type = new numpy.TensorType(array.dtype.__name__, new numpy.TensorShape(array.shape));
         this.stride = array.strides.map((stride) => stride / array.itemsize);
         this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes();
         this.encoding = this.type.dataType === 'string' || this.type.dataType === 'object' ? '|' : array.dtype.byteorder;
     }
 };
 
+numpy.TensorType = class {
+
+    constructor(dataType, shape) {
+        this.dataType = dataType || '?';
+        this.shape = shape;
+    }
+
+    toString() {
+        return this.dataType + this.shape.toString();
+    }
+};
+
+numpy.TensorShape = class {
+
+    constructor(dimensions) {
+        this.dimensions = dimensions;
+    }
+
+    toString() {
+        return this.dimensions && this.dimensions.length > 0 ? `[${this.dimensions.join(',')}]` : '';
+    }
+};
+
 pytorch.Error = class extends Error {
 
     constructor(message) {

+ 14 - 1
source/view.js

@@ -2042,10 +2042,23 @@ view.Node = class extends grapher.Node {
             item.separator = ' = ';
             return item;
         };
+        const isObject = (node) => {
+            if (node.name || node.identifier || node.description ||
+                (Array.isArray(node.inputs) && node.inputs.length > 0) ||
+                (Array.isArray(node.outputs) && node.outputs.length > 0) ||
+                (Array.isArray(node.attributes) && node.attributes.length > 0) ||
+                (Array.isArray(node.chain) && node.chain.length > 0)) {
+                return true;
+            }
+
+            return false;
+        };
         if (Array.isArray(node.inputs)) {
             for (const argument of node.inputs) {
                 const type = argument.type;
-                if (type === 'graph' || type === 'object' || type === 'object[]' || type === 'function' || type === 'function[]') {
+                if (type === 'graph' ||
+                    (type === 'object' && isObject(argument.value)) ||
+                    type === 'object[]' || type === 'function' || type === 'function[]') {
                     objects.push(argument);
                 } else if (options.weights && argument.visible !== false && argument.type !== 'attribute' && Array.isArray(argument.value) && argument.value.length === 1 && argument.value[0].initializer) {
                     const item = this.context.createArgument(argument);

+ 5 - 0
test/models.json

@@ -4899,6 +4899,7 @@
     "target":   "bad-hands-5.pt",
     "source":   "https://github.com/lutzroeder/netron/files/14471657/bad-hands-5.pt.zip[bad-hands-5.pt]",
     "format":   "PyTorch v1.6",
+    "assert":   "model.graphs[0].nodes[0].inputs.length == 6",
     "link":     "https://github.com/lutzroeder/netron/issues/720"
   },
   {
@@ -4913,6 +4914,7 @@
     "target":   "best_mask.pth",
     "source":   "https://github.com/user-attachments/files/16401712/best_mask.pth.zip[best_mask.pth]",
     "format":   "PyTorch v1.6",
+    "assert":   "model.graphs.length == 0",
     "link":     "https://github.com/lutzroeder/netron/issues/543"
   },
   {
@@ -5297,6 +5299,7 @@
     "target":   "mcunet-5fps.pkl",
     "source":   "https://github.com/user-attachments/files/16401553/mcunet-5fps.pkl.zip[mcunet-5fps.pkl]",
     "format":   "PyTorch v1.6",
+    "assert":   "model.graphs[0].nodes[0].inputs[0].value.inputs[9].value.inputs[0].value.type.dataType == 'int8'",
     "link":     "https://github.com/lutzroeder/netron/issues/543"
   },
   {
@@ -5304,6 +5307,7 @@
     "target":   "mnist_bfloat16.pt",
     "source":   "https://github.com/lutzroeder/netron/files/8556403/mnist_bfloat16.pt.zip[mnist_bfloat16.pt]",
     "format":   "PyTorch v1.6",
+    "assert":   "model.graphs[0].nodes[0].inputs[0].value[0].initializer.type.dataType == 'bfloat16'",
     "link":     "https://github.com/lutzroeder/netron/issues/720"
   },
   {
@@ -5809,6 +5813,7 @@
     "target":   "rng_state.pth",
     "source":   "https://github.com/user-attachments/files/16401709/rng_state.pth.zip[rng_state.pth]",
     "format":   "PyTorch v1.6",
+    "assert":   "model.graphs[0].nodes[0].inputs[3].name == 'numpy'",
     "link":     "https://github.com/lutzroeder/netron/issues/543"
   },
   {