ソースを参照

Update onnx.py

Lutz Roeder 3 年 前
コミット
e49e8e064b
1 ファイル変更33 行追加17 行削除
  1. 33 17
      source/onnx.py

+ 33 - 17
source/onnx.py

@@ -1,6 +1,7 @@
 ''' ONNX backend '''
 
 import collections
+import enum
 import json
 
 class ModelFactory:
@@ -10,7 +11,6 @@ class ModelFactory:
         print('Experimental')
         # import onnx.shape_inference
         # model = onnx.shape_inference.infer_shapes(model)
-        import onnx.onnx_pb # pylint: disable=import-outside-toplevel
         json_model = {}
         json_model['signature'] = 'netron:onnx'
         json_model['format'] = 'ONNX' + (' v' + str(model.ir_version) if model.ir_version else '')
@@ -63,7 +63,7 @@ class ModelFactory:
         }
         json_model['graphs'].append(json_graph)
         arguments = {}
-        def tensor(tensor):
+        def tensor(tensor): # pylint: disable=unused-argument
             return {}
         def argument(name, tensor_type=None, initializer=None):
             if not name in arguments:
@@ -105,42 +105,41 @@ class ModelFactory:
                         'arguments': [ argument(value) ]
                     })
             json_node['attributes'] = []
-            AttributeProto = onnx.onnx_pb.AttributeProto
             for _ in node.attribute:
-                if _.type == AttributeProto.UNDEFINED:
+                if _.type == _AttributeType.UNDEFINED:
                     attribute_type = None
                     value = None
-                elif _.type == AttributeProto.FLOAT:
+                elif _.type == _AttributeType.FLOAT:
                     attribute_type = 'float32'
                     value = _.f
-                elif _.type == AttributeProto.INT:
+                elif _.type == _AttributeType.INT:
                     attribute_type = 'int64'
                     value = _.i
-                elif _.type == AttributeProto.STRING:
+                elif _.type == _AttributeType.STRING:
                     attribute_type = 'string'
                     value = _.s.decode('latin1' if op_type == 'Int8GivenTensorFill' else 'utf-8')
-                elif _.type == AttributeProto.TENSOR:
+                elif _.type == _AttributeType.TENSOR:
                     attribute_type = 'tensor'
                     value = tensor(_.t)
-                elif _.type == AttributeProto.GRAPH:
+                elif _.type == _AttributeType.GRAPH:
                     attribute_type = 'tensor'
                     raise Exception('Unsupported graph attribute type')
-                elif _.type == AttributeProto.FLOATS:
+                elif _.type == _AttributeType.FLOATS:
                     attribute_type = 'float32[]'
-                    value = [ item for item in _.floats ]
-                elif _.type == AttributeProto.INTS:
+                    value = list(_.floats)
+                elif _.type == _AttributeType.INTS:
                     attribute_type = 'int64[]'
-                    value = [ item for item in _.ints ]
-                elif _.type == AttributeProto.STRINGS:
+                    value = list(_.ints)
+                elif _.type == _AttributeType.STRINGS:
                     attribute_type = 'string[]'
                     value = [ item.decode('utf-8') for item in _.strings ]
-                elif _.type == AttributeProto.TENSORS:
+                elif _.type == _AttributeType.TENSORS:
                     attribute_type = 'tensor[]'
                     raise Exception('Unsupported tensors attribute type')
-                elif _.type == AttributeProto.GRAPHS:
+                elif _.type == _AttributeType.GRAPHS:
                     attribute_type = 'graph[]'
                     raise Exception('Unsupported graphs attribute type')
-                elif _.type == AttributeProto.SPARSE_TENSOR:
+                elif _.type == _AttributeType.SPARSE_TENSOR:
                     attribute_type = 'tensor'
                     value = tensor(_.sparse_tensor)
                 else:
@@ -218,3 +217,20 @@ class ModelFactory:
     def category(self, name):
         ''' Get category for type '''
         return self.categories[name] if name in self.categories else ''
+
+class _AttributeType(enum.IntEnum):
+    UNDEFINED = 0
+    FLOAT = 1
+    INT = 2
+    STRING = 3
+    TENSOR = 4
+    GRAPH = 5
+    FLOATS = 6
+    INTS = 7
+    STRINGS = 8
+    TENSORS = 9
+    GRAPHS = 10
+    SPARSE_TENSOR = 11
+    SPARSE_TENSORS = 12
+    TYPE_PROTO = 13
+    TYPE_PROTOS = 14