onnx.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. ''' ONNX backend '''
  2. import collections
  3. import enum
  4. import json
  5. class ModelFactory:
  6. ''' ONNX backend model factory '''
  7. def serialize(self, model):
  8. ''' Serialize ONNX model to JSON message '''
  9. # import onnx.shape_inference
  10. # model = onnx.shape_inference.infer_shapes(model)
  11. json_model = {}
  12. json_model['signature'] = 'netron:onnx'
  13. json_model['format'] = 'ONNX' + (' v' + str(model.ir_version) if model.ir_version else '')
  14. if model.producer_name and len(model.producer_name) > 0:
  15. producer_version = ' v' + model.producer_version if model.producer_version else ''
  16. json_model['producer'] = model.producer_name + producer_version
  17. if model.model_version and model.model_version != 0:
  18. json_model['version'] = str(model.model_version)
  19. if model.doc_string and len(model.doc_string):
  20. json_model['description'] = str(model.doc_string)
  21. json_metadata = []
  22. metadata_props = [ [ entry.key, entry.value ] for entry in model.metadata_props ]
  23. metadata = collections.OrderedDict(metadata_props)
  24. value = metadata.get('converted_from')
  25. if value:
  26. json_metadata.append({ 'name': 'source', 'value': value })
  27. value = metadata.get('author')
  28. if value:
  29. json_metadata.append({ 'name': 'author', 'value': value })
  30. value = metadata.get('company')
  31. if value:
  32. json_metadata.append({ 'name': 'company', 'value': value })
  33. value = metadata.get('license')
  34. license_url = metadata.get('license_url')
  35. if license_url:
  36. value = '<a href=\'' + license_url + '\'>' + (value if value else license_url) + '</a>'
  37. if value:
  38. json_metadata.append({ 'name': 'license', 'value': value })
  39. if 'author' in metadata:
  40. metadata.pop('author')
  41. if 'company' in metadata:
  42. metadata.pop('company')
  43. if 'converted_from' in metadata:
  44. metadata.pop('converted_from')
  45. if 'license' in metadata:
  46. metadata.pop('license')
  47. if 'license_url' in metadata:
  48. metadata.pop('license_url')
  49. for name, value in metadata.items():
  50. json_metadata.append({ 'name': name, 'value': value })
  51. if len(json_metadata) > 0:
  52. json_model['metadata'] = json_metadata
  53. json_model['graphs'] = []
  54. graph = model.graph
  55. json_graph = {
  56. 'nodes': [],
  57. 'inputs': [],
  58. 'outputs': [],
  59. 'arguments': []
  60. }
  61. json_model['graphs'].append(json_graph)
  62. arguments = {}
  63. def tensor(tensor): # pylint: disable=unused-argument
  64. return {}
  65. def argument(name, tensor_type=None, initializer=None):
  66. if not name in arguments:
  67. json_argument = {}
  68. json_argument['name'] = name
  69. arguments[name] = len(json_graph['arguments'])
  70. json_graph['arguments'].append(json_argument)
  71. index = arguments[name]
  72. if tensor_type or initializer:
  73. json_argument = json_graph['arguments'][index]
  74. if initializer:
  75. json_argument['initializer'] = tensor(initializer)
  76. return index
  77. for value_info in graph.value_info:
  78. argument(value_info.name)
  79. for initializer in graph.initializer:
  80. argument(initializer.name, None, initializer)
  81. for node in graph.node:
  82. op_type = node.op_type
  83. json_node = {}
  84. json_node_type = {}
  85. json_node_type['name'] = op_type
  86. if self.category(op_type):
  87. json_node_type['category'] = self.category(op_type)
  88. json_node['type'] = json_node_type
  89. if node.name:
  90. json_node['name'] = node.name
  91. json_node['inputs'] = []
  92. for value in node.input:
  93. json_node['inputs'].append({
  94. 'name': 'X',
  95. 'arguments': [ argument(value) ]
  96. })
  97. json_node['outputs'] = []
  98. for value in node.output:
  99. json_node['outputs'].append({
  100. 'name': 'X',
  101. 'arguments': [ argument(value) ]
  102. })
  103. json_node['attributes'] = []
  104. for _ in node.attribute:
  105. if _.type == _AttributeType.UNDEFINED:
  106. attribute_type = None
  107. value = None
  108. elif _.type == _AttributeType.FLOAT:
  109. attribute_type = 'float32'
  110. value = _.f
  111. elif _.type == _AttributeType.INT:
  112. attribute_type = 'int64'
  113. value = _.i
  114. elif _.type == _AttributeType.STRING:
  115. attribute_type = 'string'
  116. value = _.s.decode('latin1' if op_type == 'Int8GivenTensorFill' else 'utf-8')
  117. elif _.type == _AttributeType.TENSOR:
  118. attribute_type = 'tensor'
  119. value = tensor(_.t)
  120. elif _.type == _AttributeType.GRAPH:
  121. attribute_type = 'tensor'
  122. raise Exception('Unsupported graph attribute type')
  123. elif _.type == _AttributeType.FLOATS:
  124. attribute_type = 'float32[]'
  125. value = list(_.floats)
  126. elif _.type == _AttributeType.INTS:
  127. attribute_type = 'int64[]'
  128. value = list(_.ints)
  129. elif _.type == _AttributeType.STRINGS:
  130. attribute_type = 'string[]'
  131. value = [ item.decode('utf-8') for item in _.strings ]
  132. elif _.type == _AttributeType.TENSORS:
  133. attribute_type = 'tensor[]'
  134. raise Exception('Unsupported tensors attribute type')
  135. elif _.type == _AttributeType.GRAPHS:
  136. attribute_type = 'graph[]'
  137. raise Exception('Unsupported graphs attribute type')
  138. elif _.type == _AttributeType.SPARSE_TENSOR:
  139. attribute_type = 'tensor'
  140. value = tensor(_.sparse_tensor)
  141. else:
  142. raise Exception("Unsupported attribute type '" + str(_.type) + "'.")
  143. json_attribute = {}
  144. json_attribute['name'] = _.name
  145. if attribute_type:
  146. json_attribute['type'] = attribute_type
  147. json_attribute['value'] = value
  148. json_node['attributes'].append(json_attribute)
  149. json_graph['nodes'].append(json_node)
  150. text = json.dumps(json_model, ensure_ascii=False)
  151. return text.encode('utf-8')
  152. categories = {
  153. 'Constant': 'Constant',
  154. 'Conv': 'Layer',
  155. 'ConvInteger': 'Layer',
  156. 'ConvTranspose': 'Layer',
  157. 'FC': 'Layer',
  158. 'RNN': 'Layer',
  159. 'LSTM': 'Layer',
  160. 'GRU': 'Layer',
  161. 'Gemm': 'Layer',
  162. 'FusedConv': 'Layer',
  163. 'Dropout': 'Dropout',
  164. 'Elu': 'Activation',
  165. 'HardSigmoid': 'Activation',
  166. 'LeakyRelu': 'Activation',
  167. 'PRelu': 'Activation',
  168. 'ThresholdedRelu': 'Activation',
  169. 'Relu': 'Activation',
  170. 'Selu': 'Activation',
  171. 'Sigmoid': 'Activation',
  172. 'Tanh': 'Activation',
  173. 'LogSoftmax': 'Activation',
  174. 'Softmax': 'Activation',
  175. 'Softplus': 'Activation',
  176. 'Softsign': 'Activation',
  177. 'Clip': 'Activation',
  178. 'BatchNormalization': 'Normalization',
  179. 'InstanceNormalization': 'Normalization',
  180. 'LpNormalization': 'Normalization',
  181. 'LRN': 'Normalization',
  182. 'Flatten': 'Shape',
  183. 'Reshape': 'Shape',
  184. 'Tile': 'Shape',
  185. 'Xor': 'Logic',
  186. 'Not': 'Logic',
  187. 'Or': 'Logic',
  188. 'Less': 'Logic',
  189. 'And': 'Logic',
  190. 'Greater': 'Logic',
  191. 'Equal': 'Logic',
  192. 'AveragePool': 'Pool',
  193. 'GlobalAveragePool': 'Pool',
  194. 'GlobalLpPool': 'Pool',
  195. 'GlobalMaxPool': 'Pool',
  196. 'LpPool': 'Pool',
  197. 'MaxPool': 'Pool',
  198. 'MaxRoiPool': 'Pool',
  199. 'Concat': 'Tensor',
  200. 'Slice': 'Tensor',
  201. 'Split': 'Tensor',
  202. 'Pad': 'Tensor',
  203. 'ImageScaler': 'Data',
  204. 'Crop': 'Data',
  205. 'Upsample': 'Data',
  206. 'Transpose': 'Transform',
  207. 'Gather': 'Transform',
  208. 'Unsqueeze': 'Transform',
  209. 'Squeeze': 'Transform',
  210. }
  211. def category(self, name):
  212. ''' Get category for type '''
  213. return self.categories[name] if name in self.categories else ''
  214. class _AttributeType(enum.IntEnum):
  215. UNDEFINED = 0
  216. FLOAT = 1
  217. INT = 2
  218. STRING = 3
  219. TENSOR = 4
  220. GRAPH = 5
  221. FLOATS = 6
  222. INTS = 7
  223. STRINGS = 8
  224. TENSORS = 9
  225. GRAPHS = 10
  226. SPARSE_TENSOR = 11
  227. SPARSE_TENSORS = 12
  228. TYPE_PROTO = 13
  229. TYPE_PROTOS = 14