2
0

onnx_script.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. """ ONNX metadata script """
  2. import collections
  3. import json
  4. import os
  5. import re
  6. import onnx.backend.test.case
  7. import onnx.defs
  8. import onnx.onnx_ml_pb2
  9. import onnxruntime
  10. attribute_type_table = [
  11. "undefined",
  12. "float32",
  13. "int64",
  14. "string",
  15. "tensor",
  16. "graph",
  17. "float32[]",
  18. "int64[]",
  19. "string[]",
  20. "tensor[]",
  21. "graph[]",
  22. "sparse_tensor",
  23. "sparse_tensor[]",
  24. "type_proto",
  25. "type_proto[]"
  26. ]
  27. def _format_description(description):
  28. def replace_line(match):
  29. link = match.group(1)
  30. url = match.group(2)
  31. if not url.startswith("http://") and not url.startswith("https://"):
  32. url = "https://github.com/onnx/onnx/blob/master/docs/" + url
  33. return "[" + link + "](" + url + ")"
  34. return re.sub('\\[(.+)\\]\\(([^ ]+?)( "(.+)")?\\)', replace_line, description)
  35. def _format_range(value):
  36. return "∞" if value == 2147483647 else str(value)
  37. class OnnxSchema:
  38. """ ONNX schema """
  39. def __init__(self, schema, snippets):
  40. self.schema = schema
  41. self.snippets = snippets
  42. self.name = self.schema.name
  43. self.module = self.schema.domain if self.schema.domain else "ai.onnx"
  44. self.version = self.schema.since_version
  45. self.key = self.name + ":" + self.module + ":" + str(self.version).zfill(4)
  46. def _get_attr_type(self, attribute_type, attribute_name, op_type, op_domain):
  47. key = op_domain + ":" + op_type + ":" + attribute_name
  48. if key in (":Cast:to", ":EyeLike:dtype", ":RandomNormal:dtype"):
  49. return "DataType"
  50. return attribute_type_table[attribute_type]
  51. def _get_attr_default_value(self, attr_value):
  52. if attr_value.HasField("i"):
  53. return attr_value.i
  54. if attr_value.HasField("s"):
  55. return attr_value.s.decode("utf8")
  56. if attr_value.HasField("f"):
  57. return attr_value.f
  58. return None
  59. def _update_attributes(self, value, schema):
  60. target = value["attributes"] = []
  61. attributes = sorted(schema.attributes.items())
  62. for _ in collections.OrderedDict(attributes).values():
  63. value = {}
  64. value["name"] = _.name
  65. attr_type = self._get_attr_type(_.type, _.name, schema.name, schema.domain)
  66. if attr_type:
  67. value["type"] = attr_type
  68. value["required"] = _.required
  69. default_value = self._get_attr_default_value(_.default_value)
  70. if default_value:
  71. value["default"] = default_value
  72. description = _format_description(_.description)
  73. if len(description) > 0:
  74. value["description"] = description
  75. target.append(value)
  76. def _update_inputs(self, value, inputs):
  77. target = value["inputs"] = []
  78. for _ in inputs:
  79. value = {}
  80. value["name"] = _.name
  81. value["type"] = _.type_str
  82. if _.option == onnx.defs.OpSchema.FormalParameterOption.Optional:
  83. value["option"] = "optional"
  84. elif _.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
  85. value["list"] = True
  86. description = _format_description(_.description)
  87. if len(description) > 0:
  88. value["description"] = description
  89. target.append(value)
  90. def _update_outputs(self, value, outputs):
  91. target = value["outputs"] = []
  92. for _ in outputs:
  93. value = {}
  94. value["name"] = _.name
  95. value["type"] = _.type_str
  96. if _.option == onnx.defs.OpSchema.FormalParameterOption.Optional:
  97. value["option"] = "optional"
  98. elif _.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
  99. value["list"] = True
  100. description = _format_description(_.description)
  101. if len(description) > 0:
  102. value["description"] = description
  103. target.append(value)
  104. def _update_type_constraints(self, value, type_constraints):
  105. value["type_constraints"] = []
  106. for _ in type_constraints:
  107. value["type_constraints"].append({
  108. "description": _.description,
  109. "type_param_str": _.type_param_str,
  110. "allowed_type_strs": _.allowed_type_strs
  111. })
  112. def _update_snippets(self, value, snippets):
  113. target = value["examples"] = []
  114. for summary, code in sorted(snippets):
  115. lines = code.splitlines()
  116. while len(lines) > 0 and re.search("\\s*#", lines[-1]):
  117. lines.pop()
  118. if len(lines) > 0 and len(lines[-1]) == 0:
  119. lines.pop()
  120. target.append({
  121. "summary": summary,
  122. "code": "\n".join(lines)
  123. })
  124. def to_dict(self):
  125. """ Serialize model to JSON message """
  126. value = {}
  127. value["name"] = self.name
  128. value["module"] = self.module
  129. value["version"] = self.version
  130. if self.schema.support_level != onnx.defs.OpSchema.SupportType.COMMON:
  131. value["status"] = self.schema.support_level.name.lower()
  132. description = _format_description(self.schema.doc.lstrip())
  133. if len(description) > 0:
  134. value["description"] = description
  135. if self.schema.attributes:
  136. self._update_attributes(value, self.schema)
  137. if self.schema.inputs:
  138. self._update_inputs(value, self.schema.inputs)
  139. value["min_input"] = self.schema.min_input
  140. value["max_input"] = self.schema.max_input
  141. if self.schema.outputs:
  142. self._update_outputs(value, self.schema.outputs)
  143. value["min_output"] = self.schema.min_output
  144. value["max_output"] = self.schema.max_output
  145. if self.schema.min_input != self.schema.max_input:
  146. value["inputs_range"] = _format_range(self.schema.min_input) + " - " \
  147. + _format_range(self.schema.max_input)
  148. if self.schema.min_output != self.schema.max_output:
  149. value["outputs_range"] = _format_range(self.schema.min_output) + " - " \
  150. + _format_range(self.schema.max_output)
  151. if self.schema.type_constraints:
  152. self._update_type_constraints(value, self.schema.type_constraints)
  153. if self.name in self.snippets:
  154. self._update_snippets(value, self.snippets[self.name])
  155. return value
  156. class OnnxRuntimeSchema:
  157. """ ONNX Runtime schema """
  158. def __init__(self, schema):
  159. self.schema = schema
  160. self.name = self.schema.name
  161. self.module = self.schema.domain if self.schema.domain else "ai.onnx"
  162. self.version = self.schema.since_version
  163. self.key = self.name + ":" + self.module + ":" + str(self.version).zfill(4)
  164. def _get_attr_type(self, attribute_type):
  165. return attribute_type_table[attribute_type]
  166. def _get_attr_default_value(self, attr_value):
  167. if attr_value.HasField("i"):
  168. return attr_value.i
  169. if attr_value.HasField("s"):
  170. return attr_value.s.decode("utf8")
  171. if attr_value.HasField("f"):
  172. return attr_value.f
  173. return None
  174. def _update_attributes(self, value, schema):
  175. target = value["attributes"] = []
  176. attributes = sorted(schema.attributes.items())
  177. for _ in collections.OrderedDict(attributes).values():
  178. value = {}
  179. value["name"] = _.name
  180. attribute_type = self._get_attr_type(_.type)
  181. if attribute_type:
  182. value["type"] = attribute_type
  183. value["required"] = _.required
  184. default_value = onnx.onnx_ml_pb2.AttributeProto()
  185. default_value.ParseFromString(_._default_value)
  186. default_value = self._get_attr_default_value(default_value)
  187. if default_value:
  188. value["default"] = default_value
  189. description = _format_description(_.description)
  190. if len(description) > 0:
  191. value["description"] = description
  192. target.append(value)
  193. def _update_inputs(self, value, inputs):
  194. target = value["inputs"] = []
  195. for _ in inputs:
  196. value = {}
  197. value["name"] = _.name
  198. value["type"] = _.typeStr
  199. schemadef = onnxruntime.capi.onnxruntime_pybind11_state.schemadef
  200. if _.option == schemadef.OpSchema.FormalParameterOption.Optional:
  201. value["option"] = "optional"
  202. elif _.option == schemadef.OpSchema.FormalParameterOption.Variadic:
  203. value["list"] = True
  204. description = _format_description(_.description)
  205. if len(description) > 0:
  206. value["description"] = description
  207. target.append(value)
  208. def _update_outputs(self, value, outputs):
  209. target = value["outputs"] = []
  210. for _ in outputs:
  211. value = {}
  212. value["name"] = _.name
  213. value["type"] = _.typeStr
  214. schemadef = onnxruntime.capi.onnxruntime_pybind11_state.schemadef
  215. if _.option == schemadef.OpSchema.FormalParameterOption.Optional:
  216. value["option"] = "optional"
  217. elif _.option == schemadef.OpSchema.FormalParameterOption.Variadic:
  218. value["list"] = True
  219. description = _format_description(_.description)
  220. if len(description) > 0:
  221. value["description"] = description
  222. target.append(value)
  223. def _update_type_constraints(self, value, type_constraints):
  224. value["type_constraints"] = []
  225. for _ in type_constraints:
  226. value["type_constraints"].append({
  227. "description": _.description,
  228. "type_param_str": _.type_param_str,
  229. "allowed_type_strs": _.allowed_type_strs
  230. })
  231. def to_dict(self):
  232. """ Serialize model to JSON message """
  233. value = {}
  234. value["name"] = self.name
  235. value["module"] = self.module
  236. value["version"] = self.version
  237. schemadef = onnxruntime.capi.onnxruntime_pybind11_state.schemadef
  238. if self.schema.support_level != schemadef.OpSchema.SupportType.COMMON:
  239. value["status"] = self.schema.support_level.name.lower()
  240. if self.schema.doc:
  241. description = _format_description(self.schema.doc.lstrip())
  242. if len(description) > 0:
  243. value["description"] = description
  244. if self.schema.attributes:
  245. self._update_attributes(value, self.schema)
  246. if self.schema.inputs:
  247. self._update_inputs(value, self.schema.inputs)
  248. value["min_input"] = self.schema.min_input
  249. value["max_input"] = self.schema.max_input
  250. if self.schema.outputs:
  251. self._update_outputs(value, self.schema.outputs)
  252. value["min_output"] = self.schema.min_output
  253. value["max_output"] = self.schema.max_output
  254. if self.schema.min_input != self.schema.max_input:
  255. value["inputs_range"] = _format_range(self.schema.min_input) + " - " \
  256. + _format_range(self.schema.max_input)
  257. if self.schema.min_output != self.schema.max_output:
  258. value["outputs_range"] = _format_range(self.schema.min_output) + " - " \
  259. + _format_range(self.schema.max_output)
  260. if self.schema.type_constraints:
  261. self._update_type_constraints(value, self.schema.type_constraints)
  262. return value
  263. def _metadata():
  264. root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
  265. file = os.path.join(root_dir, "source", "onnx-metadata.json")
  266. with open(file, encoding="utf-8") as handle:
  267. content = handle.read()
  268. categories = {}
  269. content = json.loads(content)
  270. for schema in content:
  271. if "category" in schema:
  272. name = schema["name"]
  273. categories[name] = schema["category"]
  274. types = collections.OrderedDict()
  275. numpy = __import__("numpy")
  276. with numpy.errstate(all="ignore"):
  277. snippets = onnx.backend.test.case.collect_snippets()
  278. for schema in onnx.defs.get_all_schemas_with_history():
  279. schema = OnnxSchema(schema, snippets)
  280. if schema.key not in types:
  281. types[schema.key] = schema.to_dict()
  282. for schema in onnxruntime.capi.onnxruntime_pybind11_state.get_all_operator_schema():
  283. schema = OnnxRuntimeSchema(schema)
  284. if schema.key not in types:
  285. types[schema.key] = schema.to_dict()
  286. for schema in content:
  287. key = f"{schema['name']}:{schema['module']}:{str(schema['version']).zfill(4)}"
  288. if key not in types:
  289. types[key] = schema
  290. types = [types[key] for key in sorted(types)]
  291. for schema in types:
  292. name = schema["name"]
  293. # copy = schema.copy()
  294. # schema.clear()
  295. # schema['name'] = name
  296. # schema['module'] = copy['module']
  297. if name in categories:
  298. schema["category"] = categories[name]
  299. # for key, value in copy.items():
  300. # if key not in schema:
  301. # schema[key] = value
  302. content = json.dumps(types, indent=2)
  303. with open(file, "w", encoding="utf-8") as handle:
  304. handle.write(content)
  305. def main():
  306. _metadata()
  307. if __name__ == "__main__":
  308. main()