# pytorch.py (11 KB)
  1. """ PyTorch backend """
  2. import json
  3. import os
  4. class ModelFactory:
  5. """ PyTorch backend model factory """
  6. def open(self, model):
  7. metadata = {}
  8. metadata_files = [
  9. ("pytorch-metadata.json", ""),
  10. ("onnx-metadata.json", "onnx::")
  11. ]
  12. path = os.path.dirname(__file__)
  13. for entry in metadata_files:
  14. file = os.path.join(path, entry[0])
  15. with open(file, encoding="utf-8") as handle:
  16. for item in json.load(handle):
  17. name = entry[1] + item["name"].split("(", 1)[0]
  18. metadata[name] = item
  19. metadata = Metadata(metadata)
  20. return _Model(metadata, model)
  21. class _Model:
  22. def __init__(self, metadata, model):
  23. self.graph = _Graph(metadata, model)
  24. def to_json(self):
  25. """ Serialize model to JSON message """
  26. import torch
  27. json_model = {
  28. "signature": "netron:pytorch",
  29. "format": "TorchScript v" + torch.__version__,
  30. "graphs": [ self.graph.to_json() ]
  31. }
  32. return json_model
  33. class _Graph:
  34. def __init__(self, metadata, model):
  35. self.metadata = metadata
  36. self.param = model
  37. self.value = model.graph
  38. self.nodes = []
  39. def _getattr(self, node):
  40. if node.kind() == "prim::Param":
  41. return (self.param, "")
  42. if node.kind() == "prim::GetAttr":
  43. name = node.s("name")
  44. obj, parent = self._getattr(node.input().node())
  45. value = getattr(obj, name)
  46. path = parent + "." + name if len(parent) > 0 else name
  47. return (value, path)
  48. raise NotImplementedError()
  49. def to_json(self):
  50. import torch
  51. graph = self.value
  52. json_graph = {
  53. "values": [],
  54. "nodes": [],
  55. "inputs": [],
  56. "outputs": []
  57. }
  58. data_type_map = dict([
  59. [ torch.float16, "float16"],
  60. [ torch.float32, "float32"],
  61. [ torch.float64, "float64"],
  62. [ torch.int32, "int32"],
  63. [ torch.int64, "int64"],
  64. ])
  65. def constant_value(node):
  66. if node.hasAttribute("value"):
  67. selector = node.kindOf("value")
  68. return getattr(node, selector)("value")
  69. return None
  70. values_index = {}
  71. def argument(value):
  72. if value not in values_index:
  73. json_value = {}
  74. json_value["name"] = str(value.unique())
  75. node = value.node()
  76. if node.kind() == "prim::GetAttr":
  77. tensor, name = self._getattr(node)
  78. if tensor is not None and len(name) > 0 and \
  79. isinstance(tensor, torch.Tensor):
  80. json_tensor_shape = {
  81. "dimensions": list(tensor.shape)
  82. }
  83. tensor_type = {
  84. "dataType": data_type_map[tensor.dtype],
  85. "shape": json_tensor_shape
  86. }
  87. json_value["name"] = name
  88. json_value["type"] = tensor_type
  89. json_value["initializer"] = { "type": tensor_type }
  90. elif node.kind() == "prim::Constant":
  91. tensor = constant_value(node)
  92. if tensor and isinstance(tensor, torch.Tensor):
  93. json_tensor_shape = {
  94. "dimensions": list(tensor.shape)
  95. }
  96. tensor_type = {
  97. "dataType": data_type_map[tensor.dtype],
  98. "shape": json_tensor_shape
  99. }
  100. json_value["type"] = tensor_type
  101. json_value["initializer"] = { "type": tensor_type }
  102. elif value.isCompleteTensor():
  103. json_tensor_shape = {
  104. "dimensions": value.type().sizes()
  105. }
  106. json_value["type"] = {
  107. "dataType": data_type_map[value.type().dtype()],
  108. "shape": json_tensor_shape
  109. }
  110. values = json_graph["values"]
  111. values_index[value] = len(values)
  112. values.append(json_value)
  113. return values_index[value]
  114. for value in graph.inputs():
  115. if len(value.uses()) != 0 and value.type().kind() != "ClassType":
  116. json_graph["inputs"].append({
  117. "name": value.debugName(),
  118. "value": [ argument(value) ]
  119. })
  120. for value in graph.outputs():
  121. json_graph["outputs"].append({
  122. "name": value.debugName(),
  123. "value": [ argument(value) ]
  124. })
  125. constants = {}
  126. for node in graph.nodes():
  127. if node.kind() == "prim::Constant":
  128. constants[node] = 0
  129. lists = {}
  130. for node in graph.nodes():
  131. if node.kind() == "prim::ListConstruct":
  132. if all(_.node() in constants for _ in node.inputs()):
  133. for _ in node.inputs():
  134. constants[_.node()] += 1
  135. lists[node] = 0
  136. def create_node(node):
  137. identifier = node.schema()
  138. schema, category = self.metadata.type(identifier)
  139. json_node = {
  140. "type": {
  141. "name": node.kind(),
  142. "category": category
  143. },
  144. "inputs": [],
  145. "outputs": [],
  146. "attributes": []
  147. }
  148. json_graph["nodes"].append(json_node)
  149. for name in node.attributeNames():
  150. selector = node.kindOf(name)
  151. value = getattr(node, selector)(name)
  152. json_attribute = {
  153. "name": name,
  154. "value": value
  155. }
  156. if torch.is_tensor(value):
  157. json_node["inputs"].append({
  158. "name": name,
  159. "value": []
  160. })
  161. else:
  162. json_node["attributes"].append(json_attribute)
  163. for i, value in enumerate(node.inputs()):
  164. arg = None
  165. if schema and i < len(schema.arguments):
  166. arg = schema.arguments[i]
  167. parameter_name = arg.name if arg else "input"
  168. real_type = arg.real_type if arg else None
  169. input_node = value.node()
  170. if input_node in constants:
  171. if (real_type and real_type.kind() == "TensorType") or \
  172. value.type().kind() == "TensorType":
  173. json_node["inputs"].append({
  174. "name": parameter_name,
  175. "value": [ argument(value) ]
  176. })
  177. else:
  178. json_attribute = {
  179. "name": parameter_name,
  180. "value": constant_value(input_node)
  181. }
  182. if real_type:
  183. json_attribute["type"] = self._argument_type(real_type)
  184. json_node["attributes"].append(json_attribute)
  185. constants[input_node] = constants[input_node] + 1
  186. continue
  187. if input_node in lists:
  188. value = [ constant_value(_.node()) for _ in input_node.inputs() ]
  189. json_attribute = {
  190. "name": parameter_name,
  191. "value": value
  192. }
  193. json_node["attributes"].append(json_attribute)
  194. lists[input_node] += 1
  195. continue
  196. if input_node.kind() == "prim::TupleUnpack":
  197. continue
  198. if input_node.kind() == "prim::TupleConstruct":
  199. continue
  200. json_node["inputs"].append({
  201. "name": parameter_name,
  202. "value": [ argument(value) ]
  203. })
  204. for i, value in enumerate(node.outputs()):
  205. ret = schema.returns[i] if schema and i < len(schema.returns) else None
  206. name = ret.name if ret else "output"
  207. json_node["outputs"].append({
  208. "name": name,
  209. "value": [ argument(value) ]
  210. })
  211. for node in graph.nodes():
  212. if node in lists:
  213. continue
  214. if node in constants:
  215. continue
  216. if node.kind() == "prim::GetAttr":
  217. continue
  218. create_node(node)
  219. for node in graph.nodes():
  220. if node.kind() == "prim::Constant" and \
  221. node in constants and constants[node] != len(node.output().uses()):
  222. create_node(node)
  223. if node.kind() == "prim::ListConstruct" and \
  224. node in lists and lists[node] != len(node.output().uses()):
  225. create_node(node)
  226. return json_graph
  227. def _argument_type(self, value):
  228. if value.kind() == "TensorType":
  229. return "Tensor"
  230. if value.kind() == "OptionalType":
  231. element_type = self._argument_type(value.getElementType())
  232. return f"{element_type}?"
  233. if value.kind() == "ListType":
  234. element_type = self._argument_type(value.getElementType())
  235. size = str(value.size) if hasattr(value, "size") else ""
  236. return f"{element_type}[{size}]"
  237. if value.kind() == "DictType":
  238. key_type = self._argument_type(value.getKeyType())
  239. value_type = self._argument_type(value.getValueType())
  240. return f"Dict({key_type}, {value_type})"
  241. if value.kind() == "TupleType":
  242. elements = []
  243. for element in value.elements():
  244. elements.append(self._argument_type(element))
  245. return f"({', '.join(elements)})"
  246. if value.kind() == "IntType":
  247. return "int64"
  248. if value.kind() == "SymIntType":
  249. return "SymInt"
  250. if value.kind() == "FloatType":
  251. return "float32"
  252. if value.kind() == "BoolType":
  253. return "boolean"
  254. if value.kind() == "StringType":
  255. return "string"
  256. if value.kind() == "NumberType":
  257. return "Scalar"
  258. if value.kind() == "ScalarTypeType":
  259. return "ScalarType"
  260. if value.kind() == "LayoutType":
  261. return "Layout"
  262. if value.kind() == "MemoryFormatType":
  263. return "MemoryFormat"
  264. if value.kind() == "DeviceObjType":
  265. return "Device"
  266. if value.kind() == "GeneratorType":
  267. return "Generator"
  268. if value.kind() == "VarType":
  269. return value.annotation_str
  270. raise NotImplementedError()
  271. class Metadata:
  272. def __init__(self, metadata):
  273. self.types = metadata
  274. def type(self, identifier):
  275. if identifier == "(no schema)":
  276. return (None, "")
  277. key = identifier.split("(", 1)[0]
  278. value = self.types.get(key)
  279. category = value["category"] if value and "category" in value else ""
  280. name, overload_name = key.split(".", 1) if key.find(".") > 0 else (key, "")
  281. import torch
  282. schema = torch._C._get_schema(name, overload_name)
  283. return (schema, category)