onnx.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. """ ONNX backend """
import collections
import enum
import html
import json
import os
  6. class ModelFactory:
  7. """ ONNX backend model factory """
  8. def open(self, model):
  9. return _Model(model)
  10. class _Model:
  11. def __init__(self, model):
  12. """ Serialize ONNX model to JSON message """
  13. # import onnx.shape_inference
  14. # model = onnx.shape_inference.infer_shapes(model)
  15. self.value = model
  16. self.metadata = _Metadata()
  17. self.graph = _Graph(model.graph, self.metadata)
  18. def to_json(self):
  19. """ Serialize model to JSON message """
  20. model = self.value
  21. json_model = {}
  22. json_model["signature"] = "netron:onnx"
  23. ir_version = model.ir_version
  24. json_model["format"] = "ONNX" + (f" v{ir_version}" if ir_version else "")
  25. if model.producer_name and len(model.producer_name) > 0:
  26. producer_version = model.producer_version
  27. producer_version = f" v{producer_version}" if producer_version else ""
  28. json_model["producer"] = model.producer_name + producer_version
  29. if model.model_version and model.model_version != 0:
  30. json_model["version"] = str(model.model_version)
  31. if model.doc_string and len(model.doc_string):
  32. json_model["description"] = str(model.doc_string)
  33. json_metadata = self._metadata_props(model.metadata_props)
  34. if len(json_metadata) > 0:
  35. json_model["metadata"] = json_metadata
  36. json_model["graphs"] = []
  37. json_model["graphs"].append(self.graph.to_json())
  38. return json_model
  39. def _metadata_props(self, metadata_props):
  40. json_metadata = []
  41. metadata_props = [ [ entry.key, entry.value ] for entry in metadata_props ]
  42. metadata = collections.OrderedDict(metadata_props)
  43. value = metadata.get("converted_from")
  44. if value:
  45. json_metadata.append({ "name": "source", "value": value })
  46. value = metadata.get("author")
  47. if value:
  48. json_metadata.append({ "name": "author", "value": value })
  49. value = metadata.get("company")
  50. if value:
  51. json_metadata.append({ "name": "company", "value": value })
  52. value = metadata.get("license")
  53. license_url = metadata.get("license_url")
  54. if license_url:
  55. value = f"<a href='{license_url}'>{value if value else license_url}</a>"
  56. if value:
  57. json_metadata.append({ "name": "license", "value": value })
  58. if "author" in metadata:
  59. metadata.pop("author")
  60. if "company" in metadata:
  61. metadata.pop("company")
  62. if "converted_from" in metadata:
  63. metadata.pop("converted_from")
  64. if "license" in metadata:
  65. metadata.pop("license")
  66. if "license_url" in metadata:
  67. metadata.pop("license_url")
  68. for name, value in metadata.items():
  69. json_metadata.append({ "name": name, "value": value })
  70. return json_metadata
  71. class _Graph:
  72. def __init__(self, graph, metadata):
  73. self.metadata = metadata
  74. self.graph = graph
  75. self.values_index = {}
  76. self.values = []
  77. def _tensor(self, tensor):
  78. return {}
  79. def value(self, name, tensor_type=None, initializer=None):
  80. if name not in self.values_index:
  81. argument = _Value(name, tensor_type, initializer)
  82. self.values_index[name] = len(self.values)
  83. self.values.append(argument)
  84. index = self.values_index[name]
  85. # argument.set_initializer(initializer)
  86. return index
  87. def attribute(self, _, op_type):
  88. if _.type == _AttributeType.UNDEFINED:
  89. attribute_type = None
  90. value = None
  91. elif _.type == _AttributeType.FLOAT:
  92. attribute_type = "float32"
  93. value = _.f
  94. elif _.type == _AttributeType.INT:
  95. attribute_type = "int64"
  96. value = _.i
  97. elif _.type == _AttributeType.STRING:
  98. attribute_type = "string"
  99. encoding = "latin1" if op_type == "Int8GivenTensorFill" else "utf-8"
  100. value = _.s.decode(encoding)
  101. elif _.type == _AttributeType.TENSOR:
  102. attribute_type = "tensor"
  103. value = self._tensor(_.t)
  104. elif _.type == _AttributeType.GRAPH:
  105. attribute_type = "graph"
  106. raise Exception("Unsupported graph attribute type")
  107. elif _.type == _AttributeType.FLOATS:
  108. attribute_type = "float32[]"
  109. value = list(_.floats)
  110. elif _.type == _AttributeType.INTS:
  111. attribute_type = "int64[]"
  112. value = list(_.ints)
  113. elif _.type == _AttributeType.STRINGS:
  114. attribute_type = "string[]"
  115. value = [ item.decode("utf-8") for item in _.strings ]
  116. elif _.type == _AttributeType.TENSORS:
  117. attribute_type = "tensor[]"
  118. raise Exception("Unsupported tensors attribute type")
  119. elif _.type == _AttributeType.GRAPHS:
  120. attribute_type = "graph[]"
  121. raise Exception("Unsupported graphs attribute type")
  122. elif _.type == _AttributeType.SPARSE_TENSOR:
  123. attribute_type = "tensor"
  124. value = self._tensor(_.sparse_tensor)
  125. else:
  126. raise Exception("Unsupported attribute type '" + str(_.type) + "'.")
  127. json_attribute = {}
  128. json_attribute["name"] = _.name
  129. if attribute_type:
  130. json_attribute["type"] = attribute_type
  131. json_attribute["value"] = value
  132. return json_attribute
  133. def to_json(self):
  134. graph = self.graph
  135. json_graph = {
  136. "nodes": [],
  137. "inputs": [],
  138. "outputs": [],
  139. "values": []
  140. }
  141. for value_info in graph.value_info:
  142. self.value(value_info.name)
  143. for initializer in graph.initializer:
  144. self.value(initializer.name, None, initializer)
  145. for node in graph.node:
  146. op_type = node.op_type
  147. json_node = {}
  148. json_node_type = {}
  149. json_node_type["name"] = op_type
  150. type_metadata = self.metadata.type(op_type)
  151. if type_metadata and "category" in type_metadata:
  152. json_node_type["category"] = type_metadata["category"]
  153. json_node["type"] = json_node_type
  154. if node.name:
  155. json_node["name"] = node.name
  156. json_node["inputs"] = []
  157. for value in node.input:
  158. json_node["inputs"].append({
  159. "name": "X",
  160. "value": [ self.value(value) ]
  161. })
  162. json_node["outputs"] = []
  163. for value in node.output:
  164. json_node["outputs"].append({
  165. "name": "X",
  166. "value": [ self.value(value) ]
  167. })
  168. json_node["attributes"] = []
  169. for _ in node.attribute:
  170. json_attribute = self.attribute(_, op_type)
  171. json_node["attributes"].append(json_attribute)
  172. json_graph["nodes"].append(json_node)
  173. for _ in self.values:
  174. json_graph["values"].append(_.to_json())
  175. return json_graph
  176. class _Value:
  177. def __init__(self, name, tensor_type=None, initializer=None):
  178. self.name = name
  179. self.type = tensor_type
  180. self.initializer = initializer
  181. def to_json(self):
  182. target = {}
  183. target["name"] = self.name
  184. # if self.initializer:
  185. # target['initializer'] = {}
  186. return target
  187. class _Metadata:
  188. metadata = {}
  189. def __init__(self):
  190. metadata_file = os.path.join(os.path.dirname(__file__), "onnx-metadata.json")
  191. with open(metadata_file, encoding="utf-8") as file:
  192. for item in json.load(file):
  193. name = item["name"]
  194. self.metadata[name] = item
  195. def type(self, name):
  196. if name in self.metadata:
  197. return self.metadata[name]
  198. return {}
  199. class _AttributeType(enum.IntEnum):
  200. UNDEFINED = 0
  201. FLOAT = 1
  202. INT = 2
  203. STRING = 3
  204. TENSOR = 4
  205. GRAPH = 5
  206. FLOATS = 6
  207. INTS = 7
  208. STRINGS = 8
  209. TENSORS = 9
  210. GRAPHS = 10
  211. SPARSE_TENSOR = 11
  212. SPARSE_TENSORS = 12
  213. TYPE_PROTO = 13
  214. TYPE_PROTOS = 14