Explorar o código

Add PyTorch test file (#720)

Lutz Roeder %!s(int64=4) %!d(string=hai) anos
pai
achega
e1dd6361f3
Modificáronse 2 ficheiros con 37 adicións e 10 borrados
  1. 30 10
      source/pytorch.js
  2. 7 0
      test/models.json

+ 30 - 10
source/pytorch.js

@@ -3384,17 +3384,15 @@ pytorch.Utility = class {
         return null;
     }
 
-    static _convertObjectList(list) {
-        if (list && Array.isArray(list) && list.every((obj) => obj && Object.keys(obj).filter((key) => pytorch.Utility.isTensor(obj[key]).length > 0))) {
-            const layers = [];
-            for (const obj of list) {
+    static _convertObjectList(obj) {
+        if (obj && Array.isArray(obj)) {
+            if (obj.every((item) => typeof item === 'number' || typeof item === 'string')) {
+                const layers = [];
                 const type = obj.__class__ ? obj.__class__.__module__ + '.' + obj.__class__.__name__ : '?';
                 const layer = { type: type, states: [], attributes: [] };
-                if (obj instanceof Map) {
-                    return null;
-                }
-                for (const key of Object.keys(obj)) {
-                    const value = obj[key];
+                for (let i = 0; i < obj.length; i++) {
+                    const key = i.toString();
+                    const value = obj[i];
                     if (pytorch.Utility.isTensor(value)) {
                         layer.states.push({ name: key, arguments: [ { id: '', value: value } ] });
                     }
@@ -3403,8 +3401,30 @@ pytorch.Utility = class {
                     }
                 }
                 layers.push(layer);
+                return [ { layers: layers } ];
+            }
+            if (obj.every((item) => item && Object.values(item).filter((value) => pytorch.Utility.isTensor(value)).length > 0)) {
+                const layers = [];
+                for (const item of obj) {
+                    const type = item.__class__ ? item.__class__.__module__ + '.' + item.__class__.__name__ : '?';
+                    const layer = { type: type, states: [], attributes: [] };
+                    if (item instanceof Map) {
+                        return null;
+                    }
+                    for (const entry of Object.entries(item)) {
+                        const key = entry[0];
+                        const value = entry[1];
+                        if (pytorch.Utility.isTensor(value)) {
+                            layer.states.push({ name: key, arguments: [ { id: '', value: value } ] });
+                        }
+                        else {
+                            layer.attributes.push({ name: key, value: value });
+                        }
+                    }
+                    layers.push(layer);
+                }
+                return [ { layers: layers } ];
             }
-            return [ { layers: layers } ];
         }
         return null;
     }

+ 7 - 0
test/models.json

@@ -4334,6 +4334,13 @@
     "format": "TorchScript v1.0",
     "link":   "https://github.com/ApolloAuto/apollo"
   },
+  {
+    "type":   "pytorch",
+    "target": "labels.pth",
+    "source": "https://github.com/lutzroeder/netron/files/7350657/labels.pth.zip[labels.pth]",
+    "format": "PyTorch v1.6",
+    "link":   "https://github.com/lutzroeder/netron/issues/720"
+  },
   {
     "type":   "pytorch",
     "target": "lane_scanning_vehicle_model.pt",