onnx.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. ''' ONNX backend '''
  2. import collections
  3. import enum
  4. import json
  class ModelFactory:
      ''' ONNX backend model factory '''

      def serialize(self, model):
          ''' Serialize ONNX model to JSON message

          Reads ir_version, producer_name, producer_version, model_version,
          doc_string, metadata_props and graph from model and returns the
          resulting JSON document as UTF-8 encoded bytes.

          Raises NotImplementedError (a subclass of Exception) when a node
          carries an attribute whose type cannot be serialized.
          '''
          print('Experimental')
          # import onnx.shape_inference
          # model = onnx.shape_inference.infer_shapes(model)
          json_model = {
              'signature': 'netron:onnx',
              'format': 'ONNX' + (' v' + str(model.ir_version) if model.ir_version else '')
          }
          if model.producer_name:
              producer_version = ' v' + model.producer_version if model.producer_version else ''
              json_model['producer'] = model.producer_name + producer_version
          if model.model_version:
              json_model['version'] = str(model.model_version)
          if model.doc_string:
              json_model['description'] = str(model.doc_string)
          json_metadata = self._metadata(model.metadata_props)
          if json_metadata:
              json_model['metadata'] = json_metadata
          # Only the main graph is emitted, so 'graphs' always has one entry.
          json_model['graphs'] = [self._graph(model.graph)]
          text = json.dumps(json_model, ensure_ascii=False)
          return text.encode('utf-8')

      @staticmethod
      def _metadata(metadata_props):
          ''' Convert metadata properties to an ordered list of name/value entries '''
          import html # pylint: disable=import-outside-toplevel
          metadata = collections.OrderedDict(
              (entry.key, entry.value) for entry in metadata_props)
          json_metadata = []
          # Well-known keys are listed first; 'converted_from' is shown as 'source'.
          well_known = (
              ('converted_from', 'source'),
              ('author', 'author'),
              ('company', 'company'),
          )
          for key, name in well_known:
              value = metadata.pop(key, None)
              if value:
                  json_metadata.append({'name': name, 'value': value})
          value = metadata.pop('license', None)
          license_url = metadata.pop('license_url', None)
          if license_url:
              # Both values come from the model file, so escape them before
              # embedding them in markup; an unescaped quote would otherwise
              # terminate the href attribute and allow markup injection.
              # NOTE(review): the URL scheme itself is not validated here.
              href = html.escape(license_url, quote=True)
              label = html.escape(value if value else license_url, quote=True)
              value = '<a href=\'' + href + '\'>' + label + '</a>'
          if value:
              json_metadata.append({'name': 'license', 'value': value})
          # Remaining properties keep their original order.
          for name, value in metadata.items():
              json_metadata.append({'name': name, 'value': value})
          return json_metadata

      @staticmethod
      def _tensor(tensor): # pylint: disable=unused-argument
          ''' Placeholder: tensor content is not serialized, always returns {} '''
          return {}

      def _graph(self, graph):
          ''' Convert graph value_info, initializer and node lists to a JSON graph '''
          json_graph = {
              'nodes': [],
              'inputs': [],
              'outputs': [],
              'arguments': []
          }
          # Maps an argument name to its index in json_graph['arguments'].
          arguments = {}

          def argument(name, initializer=None):
              ''' Return the index for name, registering the argument on first use '''
              index = arguments.get(name)
              if index is None:
                  index = len(json_graph['arguments'])
                  arguments[name] = index
                  json_graph['arguments'].append({'name': name})
              if initializer is not None:
                  json_graph['arguments'][index]['initializer'] = self._tensor(initializer)
              return index

          for value_info in graph.value_info:
              argument(value_info.name)
          for initializer in graph.initializer:
              argument(initializer.name, initializer)
          for node in graph.node:
              op_type = node.op_type
              json_node_type = {'name': op_type}
              category = self.category(op_type)
              if category:
                  json_node_type['category'] = category
              json_node = {'type': json_node_type}
              if node.name:
                  json_node['name'] = node.name
              # Every input/output is emitted under the generic name 'X'.
              json_node['inputs'] = [
                  {'name': 'X', 'arguments': [argument(value)]} for value in node.input
              ]
              json_node['outputs'] = [
                  {'name': 'X', 'arguments': [argument(value)]} for value in node.output
              ]
              json_node['attributes'] = [
                  self._attribute(op_type, attribute) for attribute in node.attribute
              ]
              json_graph['nodes'].append(json_node)
          return json_graph

      def _attribute(self, op_type, attribute):
          ''' Convert a node attribute to a JSON entry with name, optional type and value

          Raises NotImplementedError for graph, tensors, graphs and any other
          attribute type without a conversion.
          '''
          kind = attribute.type
          if kind == _AttributeType.UNDEFINED:
              attribute_type = None
              value = None
          elif kind == _AttributeType.FLOAT:
              attribute_type = 'float32'
              value = attribute.f
          elif kind == _AttributeType.INT:
              attribute_type = 'int64'
              value = attribute.i
          elif kind == _AttributeType.STRING:
              attribute_type = 'string'
              # Int8GivenTensorFill strings are decoded as latin1, all others as UTF-8.
              encoding = 'latin1' if op_type == 'Int8GivenTensorFill' else 'utf-8'
              value = attribute.s.decode(encoding)
          elif kind == _AttributeType.TENSOR:
              attribute_type = 'tensor'
              value = self._tensor(attribute.t)
          elif kind == _AttributeType.GRAPH:
              raise NotImplementedError('Unsupported graph attribute type')
          elif kind == _AttributeType.FLOATS:
              attribute_type = 'float32[]'
              value = list(attribute.floats)
          elif kind == _AttributeType.INTS:
              attribute_type = 'int64[]'
              value = list(attribute.ints)
          elif kind == _AttributeType.STRINGS:
              attribute_type = 'string[]'
              value = [item.decode('utf-8') for item in attribute.strings]
          elif kind == _AttributeType.TENSORS:
              raise NotImplementedError('Unsupported tensors attribute type')
          elif kind == _AttributeType.GRAPHS:
              raise NotImplementedError('Unsupported graphs attribute type')
          elif kind == _AttributeType.SPARSE_TENSOR:
              attribute_type = 'tensor'
              value = self._tensor(attribute.sparse_tensor)
          else:
              raise NotImplementedError("Unsupported attribute type '" + str(kind) + "'.")
          json_attribute = {'name': attribute.name}
          if attribute_type:
              json_attribute['type'] = attribute_type
          json_attribute['value'] = value
          return json_attribute

      # Maps an operator type name to the category reported for its nodes.
      categories = {
          'Constant': 'Constant',
          'Conv': 'Layer',
          'ConvInteger': 'Layer',
          'ConvTranspose': 'Layer',
          'FC': 'Layer',
          'RNN': 'Layer',
          'LSTM': 'Layer',
          'GRU': 'Layer',
          'Gemm': 'Layer',
          'FusedConv': 'Layer',
          'Dropout': 'Dropout',
          'Elu': 'Activation',
          'HardSigmoid': 'Activation',
          'LeakyRelu': 'Activation',
          'PRelu': 'Activation',
          'ThresholdedRelu': 'Activation',
          'Relu': 'Activation',
          'Selu': 'Activation',
          'Sigmoid': 'Activation',
          'Tanh': 'Activation',
          'LogSoftmax': 'Activation',
          'Softmax': 'Activation',
          'Softplus': 'Activation',
          'Softsign': 'Activation',
          'Clip': 'Activation',
          'BatchNormalization': 'Normalization',
          'InstanceNormalization': 'Normalization',
          'LpNormalization': 'Normalization',
          'LRN': 'Normalization',
          'Flatten': 'Shape',
          'Reshape': 'Shape',
          'Tile': 'Shape',
          'Xor': 'Logic',
          'Not': 'Logic',
          'Or': 'Logic',
          'Less': 'Logic',
          'And': 'Logic',
          'Greater': 'Logic',
          'Equal': 'Logic',
          'AveragePool': 'Pool',
          'GlobalAveragePool': 'Pool',
          'GlobalLpPool': 'Pool',
          'GlobalMaxPool': 'Pool',
          'LpPool': 'Pool',
          'MaxPool': 'Pool',
          'MaxRoiPool': 'Pool',
          'Concat': 'Tensor',
          'Slice': 'Tensor',
          'Split': 'Tensor',
          'Pad': 'Tensor',
          'ImageScaler': 'Data',
          'Crop': 'Data',
          'Upsample': 'Data',
          'Transpose': 'Transform',
          'Gather': 'Transform',
          'Unsqueeze': 'Transform',
          'Squeeze': 'Transform',
      }

      def category(self, name):
          ''' Get category for type, or an empty string when the type is not listed '''
          return self.categories.get(name, '')
  class _AttributeType(enum.IntEnum):
      ''' Integer codes compared against a node attribute's type field

      The values are expected to match the ONNX AttributeProto.AttributeType
      numbering.
      '''
      UNDEFINED = 0
      # Single-valued attributes
      FLOAT = 1
      INT = 2
      STRING = 3
      TENSOR = 4
      GRAPH = 5
      # List-valued counterparts of the codes above
      FLOATS = 6
      INTS = 7
      STRINGS = 8
      TENSORS = 9
      GRAPHS = 10
      # Sparse tensor and type proto codes
      SPARSE_TENSOR = 11
      SPARSE_TENSORS = 12
      TYPE_PROTO = 13
      TYPE_PROTOS = 14