onnx.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. ''' ONNX backend '''
  2. import collections
  3. import json
  4. class ModelFactory:
  5. ''' ONNX backend model factory '''
  6. def serialize(self, model):
  7. ''' Serialize ONNX model to JSON message '''
  8. print('Experimental')
  9. # import onnx.shape_inference
  10. # model = onnx.shape_inference.infer_shapes(model)
  11. import onnx.onnx_pb # pylint: disable=import-outside-toplevel
  12. json_model = {}
  13. json_model['signature'] = 'netron:onnx'
  14. json_model['format'] = 'ONNX' + (' v' + str(model.ir_version) if model.ir_version else '')
  15. if model.producer_name and len(model.producer_name) > 0:
  16. producer_version = ' v' + model.producer_version if model.producer_version else ''
  17. json_model['producer'] = model.producer_name + producer_version
  18. if model.model_version and model.model_version != 0:
  19. json_model['version'] = str(model.model_version)
  20. if model.doc_string and len(model.doc_string):
  21. json_model['description'] = str(model.doc_string)
  22. json_metadata = []
  23. metadata_props = [ [ entry.key, entry.value ] for entry in model.metadata_props ]
  24. metadata = collections.OrderedDict(metadata_props)
  25. value = metadata.get('converted_from')
  26. if value:
  27. json_metadata.append({ 'name': 'source', 'value': value })
  28. value = metadata.get('author')
  29. if value:
  30. json_metadata.append({ 'name': 'author', 'value': value })
  31. value = metadata.get('company')
  32. if value:
  33. json_metadata.append({ 'name': 'company', 'value': value })
  34. value = metadata.get('license')
  35. license_url = metadata.get('license_url')
  36. if license_url:
  37. value = '<a href=\'' + license_url + '\'>' + (value if value else license_url) + '</a>'
  38. if value:
  39. json_metadata.append({ 'name': 'license', 'value': value })
  40. if 'author' in metadata:
  41. metadata.pop('author')
  42. if 'company' in metadata:
  43. metadata.pop('company')
  44. if 'converted_from' in metadata:
  45. metadata.pop('converted_from')
  46. if 'license' in metadata:
  47. metadata.pop('license')
  48. if 'license_url' in metadata:
  49. metadata.pop('license_url')
  50. for name, value in metadata.items():
  51. json_metadata.append({ 'name': name, 'value': value })
  52. if len(json_metadata) > 0:
  53. json_model['metadata'] = json_metadata
  54. json_model['graphs'] = []
  55. graph = model.graph
  56. json_graph = {
  57. 'nodes': [],
  58. 'inputs': [],
  59. 'outputs': [],
  60. 'arguments': []
  61. }
  62. json_model['graphs'].append(json_graph)
  63. arguments = {}
  64. def tensor(tensor):
  65. return {}
  66. def argument(name, tensor_type=None, initializer=None):
  67. if not name in arguments:
  68. json_argument = {}
  69. json_argument['name'] = name
  70. arguments[name] = len(json_graph['arguments'])
  71. json_graph['arguments'].append(json_argument)
  72. index = arguments[name]
  73. if tensor_type or initializer:
  74. json_argument = json_graph['arguments'][index]
  75. if initializer:
  76. json_argument['initializer'] = tensor(initializer)
  77. return index
  78. for value_info in graph.value_info:
  79. argument(value_info.name)
  80. for initializer in graph.initializer:
  81. argument(initializer.name, None, initializer)
  82. for node in graph.node:
  83. op_type = node.op_type
  84. json_node = {}
  85. json_node_type = {}
  86. json_node_type['name'] = op_type
  87. if self.category(op_type):
  88. json_node_type['category'] = self.category(op_type)
  89. json_node['type'] = json_node_type
  90. if node.name:
  91. json_node['name'] = node.name
  92. json_node['inputs'] = []
  93. for value in node.input:
  94. json_node['inputs'].append({
  95. 'name': 'X',
  96. 'arguments': [ argument(value) ]
  97. })
  98. json_node['outputs'] = []
  99. for value in node.output:
  100. json_node['outputs'].append({
  101. 'name': 'X',
  102. 'arguments': [ argument(value) ]
  103. })
  104. json_node['attributes'] = []
  105. AttributeProto = onnx.onnx_pb.AttributeProto
  106. for _ in node.attribute:
  107. if _.type == AttributeProto.UNDEFINED:
  108. attribute_type = None
  109. value = None
  110. elif _.type == AttributeProto.FLOAT:
  111. attribute_type = 'float32'
  112. value = _.f
  113. elif _.type == AttributeProto.INT:
  114. attribute_type = 'int64'
  115. value = _.i
  116. elif _.type == AttributeProto.STRING:
  117. attribute_type = 'string'
  118. value = _.s.decode('latin1' if op_type == 'Int8GivenTensorFill' else 'utf-8')
  119. elif _.type == AttributeProto.TENSOR:
  120. attribute_type = 'tensor'
  121. value = tensor(_.t)
  122. elif _.type == AttributeProto.GRAPH:
  123. attribute_type = 'tensor'
  124. raise Exception('Unsupported graph attribute type')
  125. elif _.type == AttributeProto.FLOATS:
  126. attribute_type = 'float32[]'
  127. value = [ item for item in _.floats ]
  128. elif _.type == AttributeProto.INTS:
  129. attribute_type = 'int64[]'
  130. value = [ item for item in _.ints ]
  131. elif _.type == AttributeProto.STRINGS:
  132. attribute_type = 'string[]'
  133. value = [ item.decode('utf-8') for item in _.strings ]
  134. elif _.type == AttributeProto.TENSORS:
  135. attribute_type = 'tensor[]'
  136. raise Exception('Unsupported tensors attribute type')
  137. elif _.type == AttributeProto.GRAPHS:
  138. attribute_type = 'graph[]'
  139. raise Exception('Unsupported graphs attribute type')
  140. elif _.type == AttributeProto.SPARSE_TENSOR:
  141. attribute_type = 'tensor'
  142. value = tensor(_.sparse_tensor)
  143. else:
  144. raise Exception("Unsupported attribute type '" + str(_.type) + "'.")
  145. json_attribute = {}
  146. json_attribute['name'] = _.name
  147. if attribute_type:
  148. json_attribute['type'] = attribute_type
  149. json_attribute['value'] = value
  150. json_node['attributes'].append(json_attribute)
  151. json_graph['nodes'].append(json_node)
  152. text = json.dumps(json_model, ensure_ascii=False)
  153. return text.encode('utf-8')
  154. categories = {
  155. 'Constant': 'Constant',
  156. 'Conv': 'Layer',
  157. 'ConvInteger': 'Layer',
  158. 'ConvTranspose': 'Layer',
  159. 'FC': 'Layer',
  160. 'RNN': 'Layer',
  161. 'LSTM': 'Layer',
  162. 'GRU': 'Layer',
  163. 'Gemm': 'Layer',
  164. 'FusedConv': 'Layer',
  165. 'Dropout': 'Dropout',
  166. 'Elu': 'Activation',
  167. 'HardSigmoid': 'Activation',
  168. 'LeakyRelu': 'Activation',
  169. 'PRelu': 'Activation',
  170. 'ThresholdedRelu': 'Activation',
  171. 'Relu': 'Activation',
  172. 'Selu': 'Activation',
  173. 'Sigmoid': 'Activation',
  174. 'Tanh': 'Activation',
  175. 'LogSoftmax': 'Activation',
  176. 'Softmax': 'Activation',
  177. 'Softplus': 'Activation',
  178. 'Softsign': 'Activation',
  179. 'Clip': 'Activation',
  180. 'BatchNormalization': 'Normalization',
  181. 'InstanceNormalization': 'Normalization',
  182. 'LpNormalization': 'Normalization',
  183. 'LRN': 'Normalization',
  184. 'Flatten': 'Shape',
  185. 'Reshape': 'Shape',
  186. 'Tile': 'Shape',
  187. 'Xor': 'Logic',
  188. 'Not': 'Logic',
  189. 'Or': 'Logic',
  190. 'Less': 'Logic',
  191. 'And': 'Logic',
  192. 'Greater': 'Logic',
  193. 'Equal': 'Logic',
  194. 'AveragePool': 'Pool',
  195. 'GlobalAveragePool': 'Pool',
  196. 'GlobalLpPool': 'Pool',
  197. 'GlobalMaxPool': 'Pool',
  198. 'LpPool': 'Pool',
  199. 'MaxPool': 'Pool',
  200. 'MaxRoiPool': 'Pool',
  201. 'Concat': 'Tensor',
  202. 'Slice': 'Tensor',
  203. 'Split': 'Tensor',
  204. 'Pad': 'Tensor',
  205. 'ImageScaler': 'Data',
  206. 'Crop': 'Data',
  207. 'Upsample': 'Data',
  208. 'Transpose': 'Transform',
  209. 'Gather': 'Transform',
  210. 'Unsqueeze': 'Transform',
  211. 'Squeeze': 'Transform',
  212. }
  213. def category(self, name):
  214. ''' Get category for type '''
  215. return self.categories[name] if name in self.categories else ''