|
|
@@ -1,6 +1,7 @@
|
|
|
''' ONNX backend '''
|
|
|
|
|
|
import collections
|
|
|
+import enum
|
|
|
import json
|
|
|
|
|
|
class ModelFactory:
|
|
|
@@ -10,7 +11,6 @@ class ModelFactory:
|
|
|
print('Experimental')
|
|
|
# import onnx.shape_inference
|
|
|
# model = onnx.shape_inference.infer_shapes(model)
|
|
|
- import onnx.onnx_pb # pylint: disable=import-outside-toplevel
|
|
|
json_model = {}
|
|
|
json_model['signature'] = 'netron:onnx'
|
|
|
json_model['format'] = 'ONNX' + (' v' + str(model.ir_version) if model.ir_version else '')
|
|
|
@@ -63,7 +63,7 @@ class ModelFactory:
|
|
|
}
|
|
|
json_model['graphs'].append(json_graph)
|
|
|
arguments = {}
|
|
|
- def tensor(tensor):
|
|
|
+ def tensor(tensor): # pylint: disable=unused-argument
|
|
|
return {}
|
|
|
def argument(name, tensor_type=None, initializer=None):
|
|
|
if not name in arguments:
|
|
|
@@ -105,42 +105,41 @@ class ModelFactory:
|
|
|
'arguments': [ argument(value) ]
|
|
|
})
|
|
|
json_node['attributes'] = []
|
|
|
- AttributeProto = onnx.onnx_pb.AttributeProto
|
|
|
for _ in node.attribute:
|
|
|
- if _.type == AttributeProto.UNDEFINED:
|
|
|
+ if _.type == _AttributeType.UNDEFINED:
|
|
|
attribute_type = None
|
|
|
value = None
|
|
|
- elif _.type == AttributeProto.FLOAT:
|
|
|
+ elif _.type == _AttributeType.FLOAT:
|
|
|
attribute_type = 'float32'
|
|
|
value = _.f
|
|
|
- elif _.type == AttributeProto.INT:
|
|
|
+ elif _.type == _AttributeType.INT:
|
|
|
attribute_type = 'int64'
|
|
|
value = _.i
|
|
|
- elif _.type == AttributeProto.STRING:
|
|
|
+ elif _.type == _AttributeType.STRING:
|
|
|
attribute_type = 'string'
|
|
|
value = _.s.decode('latin1' if op_type == 'Int8GivenTensorFill' else 'utf-8')
|
|
|
- elif _.type == AttributeProto.TENSOR:
|
|
|
+ elif _.type == _AttributeType.TENSOR:
|
|
|
attribute_type = 'tensor'
|
|
|
value = tensor(_.t)
|
|
|
- elif _.type == AttributeProto.GRAPH:
|
|
|
+ elif _.type == _AttributeType.GRAPH:
|
|
|
attribute_type = 'tensor'
|
|
|
raise Exception('Unsupported graph attribute type')
|
|
|
- elif _.type == AttributeProto.FLOATS:
|
|
|
+ elif _.type == _AttributeType.FLOATS:
|
|
|
attribute_type = 'float32[]'
|
|
|
- value = [ item for item in _.floats ]
|
|
|
- elif _.type == AttributeProto.INTS:
|
|
|
+ value = list(_.floats)
|
|
|
+ elif _.type == _AttributeType.INTS:
|
|
|
attribute_type = 'int64[]'
|
|
|
- value = [ item for item in _.ints ]
|
|
|
- elif _.type == AttributeProto.STRINGS:
|
|
|
+ value = list(_.ints)
|
|
|
+ elif _.type == _AttributeType.STRINGS:
|
|
|
attribute_type = 'string[]'
|
|
|
value = [ item.decode('utf-8') for item in _.strings ]
|
|
|
- elif _.type == AttributeProto.TENSORS:
|
|
|
+ elif _.type == _AttributeType.TENSORS:
|
|
|
attribute_type = 'tensor[]'
|
|
|
raise Exception('Unsupported tensors attribute type')
|
|
|
- elif _.type == AttributeProto.GRAPHS:
|
|
|
+ elif _.type == _AttributeType.GRAPHS:
|
|
|
attribute_type = 'graph[]'
|
|
|
raise Exception('Unsupported graphs attribute type')
|
|
|
- elif _.type == AttributeProto.SPARSE_TENSOR:
|
|
|
+ elif _.type == _AttributeType.SPARSE_TENSOR:
|
|
|
attribute_type = 'tensor'
|
|
|
value = tensor(_.sparse_tensor)
|
|
|
else:
|
|
|
@@ -218,3 +217,20 @@ class ModelFactory:
|
|
|
def category(self, name):
|
|
|
''' Get category for type '''
|
|
|
return self.categories[name] if name in self.categories else ''
|
|
|
+
|
|
|
+class _AttributeType(enum.IntEnum):
|
|
|
+ UNDEFINED = 0
|
|
|
+ FLOAT = 1
|
|
|
+ INT = 2
|
|
|
+ STRING = 3
|
|
|
+ TENSOR = 4
|
|
|
+ GRAPH = 5
|
|
|
+ FLOATS = 6
|
|
|
+ INTS = 7
|
|
|
+ STRINGS = 8
|
|
|
+ TENSORS = 9
|
|
|
+ GRAPHS = 10
|
|
|
+ SPARSE_TENSOR = 11
|
|
|
+ SPARSE_TENSORS = 12
|
|
|
+ TYPE_PROTO = 13
|
|
|
+ TYPE_PROTOS = 14
|