瀏覽代碼

Add PyTorch test file (#720)

Lutz Roeder 3 年之前
父節點
當前提交
e8aaaa4d68
共有 4 個文件被更改,包括 80 次插入5 次删除
  1. 61 1
      source/python.js
  2. 9 1
      source/pytorch.js
  3. 3 3
      source/view.js
  4. 7 0
      test/models.json

+ 61 - 1
source/python.js

@@ -1947,10 +1947,67 @@ python.Execution = class {
         this.registerType('numpy.inexact', class {});
         this.registerType('numpy.number', class extends numpy.generic {});
         this.registerType('numpy.integer', class extends numpy.number {});
-        this.registerType('numpy.signedinteger', class extends numpy.integer {});
         this.registerType('numpy.floating', class extends numpy.inexact {});
+        this.registerType('numpy.float32', class extends numpy.floating {});
         this.registerType('numpy.float64', class extends numpy.floating {});
+        this.registerType('numpy.signedinteger', class extends numpy.integer {});
+        this.registerType('numpy.int8', class extends numpy.signedinteger {});
+        this.registerType('numpy.int16', class extends numpy.signedinteger {});
+        this.registerType('numpy.int32', class extends numpy.signedinteger {});
         this.registerType('numpy.int64', class extends numpy.signedinteger {});
+        this.registerType('numpy.unsignedinteger', class extends numpy.integer {});
+        this.registerType('numpy.uint8', class extends numpy.unsignedinteger {});
+        this.registerType('numpy.uint16', class extends numpy.unsignedinteger {});
+        this.registerType('numpy.uint32', class extends numpy.unsignedinteger {});
+        this.registerType('numpy.uint64', class extends numpy.unsignedinteger {});
+        this.registerType('fastai.callback.core.TrainEvalCallback', class {});
+        this.registerType('fastai.callback.progress.ProgressCallback', class {});
+        this.registerType('fastai.data.core.DataLoaders', class {});
+        this.registerType('fastai.data.core.Datasets', class {});
+        this.registerType('fastai.data.core.TfmdDL', class {});
+        this.registerType('fastai.data.core.TfmdLists', class {});
+        this.registerType('fastai.data.load._FakeLoader', class {});
+        this.registerType('fastai.data.load._wif', class {});
+        this.registerType('fastai.data.transforms.Categorize', class {});
+        this.registerType('fastai.data.transforms.CategoryMap', class {});
+        this.registerType('fastai.data.transforms.IntToFloatTensor', class {});
+        this.registerType('fastai.data.transforms.Normalize', class {});
+        this.registerType('fastai.data.transforms.parent_label', class {});
+        this.registerType('fastai.data.transforms.ToTensor', class {});
+        this.registerType('fastai.imports.noop', class {});
+        this.registerType('fastai.layers.AdaptiveConcatPool2d', class {});
+        this.registerType('fastai.layers.Flatten', class {});
+        this.registerType('fastai.learner.AvgLoss', class {});
+        this.registerType('fastai.learner.AvgMetric', class {});
+        this.registerType('fastai.learner.AvgSmoothLoss', class {});
+        this.registerType('fastai.learner.Learner', class {});
+        this.registerType('fastai.learner.Recorder', class {});
+        this.registerType('fastai.losses.CrossEntropyLossFlat', class {});
+        this.registerType('fastai.metrics.error_rate', class {});
+        this.registerType('fastai.optimizer.Adam', class {});
+        this.registerType('fastai.torch_core._fa_rebuild_tensor', class {});
+        this.registerType('fastai.torch_core.TensorBase', class {});
+        this.registerType('fastai.torch_core.TensorCategory', class {});
+        this.registerType('fastai.torch_core.TensorImage', class {});
+        this.registerType('fastai.vision.augment._BrightnessLogit', class {});
+        this.registerType('fastai.vision.augment._ContrastLogit', class {});
+        this.registerType('fastai.vision.augment._WarpCoord', class {});
+        this.registerType('fastai.vision.augment.Brightness', class {});
+        this.registerType('fastai.vision.augment.Flip', class {});
+        this.registerType('fastai.vision.augment.flip_mat', class {});
+        this.registerType('fastai.vision.augment.RandomResizedCropGPU', class {});
+        this.registerType('fastai.vision.augment.Resize', class {});
+        this.registerType('fastai.vision.augment.rotate_mat', class {});
+        this.registerType('fastai.vision.augment.zoom_mat', class {});
+        this.registerType('fastai.vision.core.PILImage', class {});
+        this.registerType('fastai.vision.learner._resnet_split', class {});
+        this.registerType('fastcore.basics.fastuple', class {});
+        this.registerType('fastcore.dispatch._TypeDict', class {});
+        this.registerType('fastcore.dispatch.TypeDispatch', class {});
+        this.registerType('fastcore.foundation.L', class {});
+        this.registerType('fastcore.transform.Pipeline', class {});
+        this.registerType('fastcore.transform.Transform', class {});
+        this.registerType('functools.partial', class {});
         this.registerType('gensim.models.doc2vec.Doctag', class {});
         this.registerType('gensim.models.doc2vec.Doc2Vec', class {});
         this.registerType('gensim.models.doc2vec.Doc2VecTrainables', class {});
@@ -2433,6 +2490,7 @@ python.Execution = class {
         this.registerType('sklearn.preprocessing._function_transformer.FunctionTransformer', class {});
         this.registerType('sklearn.preprocessing._label.LabelBinarizer', class {});
         this.registerType('sklearn.preprocessing._label.LabelEncoder', class {});
+        this.registerType('sklearn.preprocessing._label.MultiLabelBinarizer', class {});
         this.registerType('sklearn.preprocessing._polynomial.PolynomialFeatures', class {});
         this.registerType('sklearn.preprocessing.data.Binarizer', class {});
         this.registerType('sklearn.preprocessing.data.MaxAbsScaler', class {});
@@ -2841,6 +2899,7 @@ python.Execution = class {
                 return this._reader.stream(size);
             }
         });
+        this.registerType('random.Random', class {});
         this.registerType('re.Pattern', class {
             constructor(pattern, flags) {
                 this.pattern = pattern;
@@ -3776,6 +3835,7 @@ python.Execution = class {
         this.registerType('torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork', class {});
         this.registerType('torchvision.ops.feature_pyramid_network.LastLevelMaxPool', class {});
         this.registerType('torchvision.ops.feature_pyramid_network.LastLevelP6P7', class {});
+        this.registerType('torchvision.ops.misc.Conv2dNormActivation', class {});
         this.registerType('torchvision.ops.misc.ConvNormActivation', class {});
         this.registerType('torchvision.ops.misc.ConvTranspose2d', class {});
         this.registerType('torchvision.ops.misc.FrozenBatchNorm2d', class {});

+ 9 - 1
source/pytorch.js

@@ -1770,7 +1770,15 @@ pytorch.Container.Zip.Pickle = class extends pytorch.Container.Zip {
         if (!this._graphs) {
             const execution = new pytorch.Container.Zip.Execution(null, this._exceptionCallback, this._metadata);
             const graph = new pytorch.Container.Zip.Pickle.Script(this._entries, execution);
-            this._graphs = graph.data.forward ? [ graph ] : pytorch.Utility.find(graph.data);
+            if (graph.data && graph.data.forward) {
+                this._graphs = [ graph ];
+            }
+            else if (graph.data && graph.data.__class__ && graph.data.__class__.__module__ == 'fastai.learner' && graph.data.__class__.__name__ == 'Learner') {
+                this._graphs = pytorch.Utility.find(graph.data.model);
+            }
+            else {
+                this._graphs = pytorch.Utility.find(graph.data);
+            }
         }
         return this._graphs;
     }

+ 3 - 3
source/view.js

@@ -951,8 +951,8 @@ view.Graph = class extends grapher.Graph {
                         }
                     }
                     if (groupName) {
-                        createCluster(groupName);
-                        this.setParent(viewNode.name, groupName);
+                        createCluster(groupName + '\ngroup');
+                        this.setParent(viewNode.name, groupName + '\ngroup');
                     }
                 }
             }
@@ -1100,7 +1100,7 @@ view.Node = class extends grapher.Node {
             this._add(node.inner);
         }
         if (node.nodes) {
-            this.canvas = this.canvas();
+            // this.canvas = this.canvas();
         }
     }
 

+ 7 - 0
test/models.json

@@ -4198,6 +4198,13 @@
     "error":    "Unsupported torch.add expression type in 'fasterrcnn_resnet50_fpn.pt'.",
     "link":     "https://github.com/lutzroeder/netron/issues/689"
   },
+  {
+    "type":     "pytorch",
+    "target":   "fruit_veg_model.pkl",
+    "source":   "https://github.com/lutzroeder/netron/files/9265633/fruit_veg_model.pkl.zip[fruit_veg_model.pkl]",
+    "format":   "PyTorch v1.6",
+    "link":     "https://github.com/lutzroeder/netron/issues/720"
+  },
   {
     "type":     "pytorch",
     "target":   "gcn2_tiny_320x240.pb",