pytorch.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. ''' PyTorch backend '''
  2. import json
  3. import os
  4. class ModelFactory:
  5. ''' PyTorch backend model factory '''
  6. def serialize(self, model):
  7. ''' Serialize PyTorch model to JSON message '''
  8. print('Experimental')
  9. import torch # pylint: disable=import-outside-toplevel
  10. metadata = {}
  11. metadata_file = os.path.join(os.path.dirname(__file__), 'onnx-metadata.json')
  12. with open(metadata_file, 'r', encoding='utf-8') as file:
  13. for item in json.load(file):
  14. name = 'onnx::' + item['name']
  15. metadata[name] = item
  16. json_model = {}
  17. json_model['signature'] = 'netron:pytorch'
  18. json_model['format'] = 'TorchScript'
  19. json_model['graphs'] = []
  20. json_graph = {}
  21. json_graph['arguments'] = []
  22. json_graph['nodes'] = []
  23. json_graph['inputs'] = []
  24. json_graph['outputs'] = []
  25. json_model['graphs'].append(json_graph)
  26. data_type_map = dict([
  27. [ torch.float16, 'float16'], # pylint: disable=no-member
  28. [ torch.float32, 'float32'], # pylint: disable=no-member
  29. [ torch.float64, 'float64'], # pylint: disable=no-member
  30. [ torch.int32, 'int32'], # pylint: disable=no-member
  31. [ torch.int64, 'int64'], # pylint: disable=no-member
  32. ])
  33. arguments_map = {}
  34. def argument(value):
  35. if not value in arguments_map:
  36. json_argument = {}
  37. json_argument['name'] = str(value.unique()) + '>' + str(value.node().kind())
  38. if value.isCompleteTensor():
  39. json_tensor_shape = {
  40. 'dimensions': value.type().sizes()
  41. }
  42. json_argument['type'] = {
  43. 'dataType': data_type_map[value.type().dtype()],
  44. 'shape': json_tensor_shape
  45. }
  46. if value.node().kind() == "prim::Param":
  47. json_argument['initializer'] = {}
  48. arguments = json_graph['arguments']
  49. arguments_map[value] = len(arguments)
  50. arguments.append(json_argument)
  51. return arguments_map[value]
  52. for input_value in model.inputs():
  53. json_graph['inputs'].append({
  54. 'name': input_value.debugName(),
  55. 'arguments': [ argument(input_value) ]
  56. })
  57. for output_value in model.outputs():
  58. json_graph['outputs'].append({
  59. 'name': output_value.debugName(),
  60. 'arguments': [ argument(output_value) ]
  61. })
  62. for node in model.nodes():
  63. kind = node.kind()
  64. json_type = {
  65. 'name': kind
  66. }
  67. json_node = {
  68. 'type': json_type,
  69. 'inputs': [],
  70. 'outputs': [],
  71. 'attributes': []
  72. }
  73. json_graph['nodes'].append(json_node)
  74. for name in node.attributeNames():
  75. value = node[name]
  76. json_attribute = {
  77. 'name': name,
  78. 'value': value
  79. }
  80. if torch.is_tensor(value):
  81. json_node['inputs'].append({
  82. 'name': name,
  83. 'arguments': []
  84. })
  85. else:
  86. json_node['attributes'].append(json_attribute)
  87. for input_value in node.inputs():
  88. json_parameter = {
  89. 'name': 'x',
  90. 'arguments': [ argument(input_value) ]
  91. }
  92. json_node['inputs'].append(json_parameter)
  93. for output_value in node.outputs():
  94. json_node['outputs'].append({
  95. 'name': 'x',
  96. 'arguments': [ argument(output_value) ]
  97. })
  98. text = json.dumps(json_model, ensure_ascii=False)
  99. return text.encode('utf-8')