pytorch.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. ''' PyTorch backend '''
  2. import json
  3. import os
  4. class ModelFactory: # pylint: disable=too-few-public-methods
  5. ''' PyTorch backend model factory '''
  6. def open(self, model): # pylint: disable=missing-function-docstring
  7. metadata = {}
  8. metadata_files = [
  9. ('pytorch-metadata.json', ''),
  10. ('onnx-metadata.json', 'onnx::')
  11. ]
  12. path = os.path.dirname(__file__)
  13. for entry in metadata_files:
  14. file = os.path.join(path, entry[0])
  15. with open(file, 'r', encoding='utf-8') as handle:
  16. for item in json.load(handle):
  17. name = entry[1] + item['name']
  18. metadata[name] = item
  19. metadata = Metadata(metadata)
  20. return _Model(metadata, model)
  21. class _Model: # pylint: disable=too-few-public-methods
  22. def __init__(self, metadata, model):
  23. self.graph = _Graph(metadata, model)
  24. def to_json(self):
  25. ''' Serialize model to JSON message '''
  26. import torch # pylint: disable=import-outside-toplevel,import-error
  27. json_model = {
  28. 'signature': 'netron:pytorch',
  29. 'format': 'TorchScript v' + torch.__version__,
  30. 'graphs': [ self.graph.to_json() ]
  31. }
  32. return json_model
  33. class _Graph: # pylint: disable=too-few-public-methods
  34. def __init__(self, metadata, model):
  35. self.metadata = metadata
  36. self.param = model
  37. self.value = model.graph
  38. self.nodes = []
  39. def _getattr(self, node):
  40. if node.kind() == 'prim::Param':
  41. return (self.param, '')
  42. if node.kind() == 'prim::GetAttr':
  43. name = node.s('name')
  44. obj, parent = self._getattr(node.input().node())
  45. return (getattr(obj, name), parent + '.' + name if len(parent) > 0 else name)
  46. raise NotImplementedError()
  47. def to_json(self): # pylint: disable=missing-function-docstring,too-many-locals,too-many-statements,too-many-branches
  48. import torch # pylint: disable=import-outside-toplevel,import-error
  49. graph = self.value
  50. json_graph = {
  51. 'arguments': [],
  52. 'nodes': [],
  53. 'inputs': [],
  54. 'outputs': []
  55. }
  56. data_type_map = dict([
  57. [ torch.float16, 'float16'], # pylint: disable=no-member
  58. [ torch.float32, 'float32'], # pylint: disable=no-member
  59. [ torch.float64, 'float64'], # pylint: disable=no-member
  60. [ torch.int32, 'int32'], # pylint: disable=no-member
  61. [ torch.int64, 'int64'], # pylint: disable=no-member
  62. ])
  63. def constant_value(node):
  64. if node.hasAttribute('value'):
  65. selector = node.kindOf('value')
  66. return getattr(node, selector)('value')
  67. return None
  68. arguments_map = {}
  69. def argument(value):
  70. if not value in arguments_map:
  71. json_argument = {}
  72. json_argument['name'] = str(value.unique())
  73. node = value.node()
  74. if node.kind() == "prim::GetAttr":
  75. tensor, name = self._getattr(node)
  76. if tensor is not None and len(name) > 0 and \
  77. isinstance(tensor, torch.Tensor):
  78. json_argument['name'] = name
  79. json_argument['initializer'] = {}
  80. json_tensor_shape = {
  81. 'dimensions': list(tensor.shape)
  82. }
  83. json_argument['type'] = {
  84. 'dataType': data_type_map[tensor.dtype],
  85. 'shape': json_tensor_shape
  86. }
  87. elif node.kind() == "prim::Constant":
  88. tensor = constant_value(node)
  89. if tensor and isinstance(tensor, torch.Tensor):
  90. json_argument['initializer'] = {}
  91. json_tensor_shape = {
  92. 'dimensions': list(tensor.shape)
  93. }
  94. json_argument['type'] = {
  95. 'dataType': data_type_map[tensor.dtype],
  96. 'shape': json_tensor_shape
  97. }
  98. elif value.isCompleteTensor():
  99. json_tensor_shape = {
  100. 'dimensions': value.type().sizes()
  101. }
  102. json_argument['type'] = {
  103. 'dataType': data_type_map[value.type().dtype()],
  104. 'shape': json_tensor_shape
  105. }
  106. arguments = json_graph['arguments']
  107. arguments_map[value] = len(arguments)
  108. arguments.append(json_argument)
  109. return arguments_map[value]
  110. for value in graph.inputs():
  111. if len(value.uses()) != 0 and value.type().kind() != 'ClassType':
  112. json_graph['inputs'].append({
  113. 'name': value.debugName(),
  114. 'arguments': [ argument(value) ]
  115. })
  116. for value in graph.outputs():
  117. json_graph['outputs'].append({
  118. 'name': value.debugName(),
  119. 'arguments': [ argument(value) ]
  120. })
  121. constants = {}
  122. for node in graph.nodes():
  123. if node.kind() == 'prim::Constant':
  124. constants[node] = 0
  125. lists = {}
  126. for node in graph.nodes():
  127. if node.kind() == 'prim::ListConstruct':
  128. if all(_.node() in constants for _ in node.inputs()):
  129. for _ in node.inputs():
  130. constants[_.node()] += 1
  131. lists[node] = 0
  132. def create_node(node):
  133. schema = node.schema() if hasattr(node, 'schema') else None
  134. schema = self.metadata.type(schema) if schema and schema != '(no schema)' else None
  135. json_node = {
  136. 'type': {
  137. 'name': node.kind(),
  138. 'category': schema['category'] if schema and 'category' in schema else ''
  139. },
  140. 'inputs': [],
  141. 'outputs': [],
  142. 'attributes': []
  143. }
  144. json_graph['nodes'].append(json_node)
  145. for name in node.attributeNames():
  146. selector = node.kindOf(name)
  147. value = getattr(node, selector)(name)
  148. json_attribute = {
  149. 'name': name,
  150. 'value': value
  151. }
  152. if torch.is_tensor(value):
  153. json_node['inputs'].append({
  154. 'name': name,
  155. 'arguments': []
  156. })
  157. else:
  158. json_node['attributes'].append(json_attribute)
  159. for i, value in enumerate(node.inputs()):
  160. parameter = schema['inputs'][i] if schema and i < len(schema['inputs']) else None
  161. parameter_name = parameter['name'] if parameter and 'name' in parameter else 'input'
  162. parameter_type = parameter['type'] if parameter and 'type' in parameter else None
  163. input_node = value.node()
  164. if input_node in constants:
  165. if parameter_type == 'Tensor' or value.type().kind() == 'TensorType':
  166. json_node['inputs'].append({
  167. 'name': parameter_name,
  168. 'arguments': [ argument(value) ]
  169. })
  170. else:
  171. json_attribute = {
  172. 'name': parameter_name,
  173. 'value': constant_value(input_node)
  174. }
  175. if parameter and 'type' in parameter:
  176. json_attribute['type'] = parameter['type']
  177. json_node['attributes'].append(json_attribute)
  178. constants[input_node] = constants[input_node] + 1
  179. continue
  180. if input_node in lists:
  181. json_attribute = {
  182. 'name': parameter_name,
  183. 'value': [ constant_value(_.node()) for _ in input_node.inputs() ]
  184. }
  185. json_node['attributes'].append(json_attribute)
  186. lists[input_node] += 1
  187. continue
  188. if input_node.kind() == 'prim::TupleUnpack':
  189. continue
  190. if input_node.kind() == 'prim::TupleConstruct':
  191. continue
  192. json_node['inputs'].append({
  193. 'name': parameter_name,
  194. 'arguments': [ argument(value) ]
  195. })
  196. for i, value in enumerate(node.outputs()):
  197. parameter = schema['outputs'][i] if schema and i < len(schema['outputs']) else None
  198. name = parameter['name'] if parameter and 'name' in parameter else 'output'
  199. json_node['outputs'].append({
  200. 'name': name,
  201. 'arguments': [ argument(value) ]
  202. })
  203. for node in graph.nodes():
  204. if node in lists:
  205. continue
  206. if node in constants:
  207. continue
  208. if node.kind() == 'prim::GetAttr':
  209. continue
  210. create_node(node)
  211. for node in graph.nodes():
  212. if node.kind() == 'prim::Constant' and \
  213. node in constants and constants[node] != len(node.output().uses()):
  214. create_node(node)
  215. if node.kind() == 'prim::ListConstruct' and \
  216. node in lists and lists[node] != len(node.output().uses()):
  217. create_node(node)
  218. return json_graph
  219. class Metadata: # pylint: disable=too-few-public-methods,missing-class-docstring
  220. def __init__(self, metadata):
  221. self.types = metadata
  222. self.cache = set()
  223. self._primitives = {
  224. 'int': 'int64', 'float': 'float32', 'bool': 'boolean', 'str': 'string'
  225. }
  226. def type(self, schema): # pylint: disable=missing-function-docstring
  227. key = schema.name if isinstance(schema, Schema) else schema.split('(', 1)[0].strip()
  228. if key not in self.cache:
  229. self.cache.add(key)
  230. schema = schema if isinstance(schema, Schema) else Schema(schema)
  231. arguments = list(filter(lambda _: \
  232. not(_.kwarg_only and hasattr(_, 'alias')), schema.arguments))
  233. returns = schema.returns
  234. value = self.types.setdefault(schema.name, { 'name': schema.name, })
  235. inputs = value.get('inputs', [])
  236. outputs = value.get('outputs', [])
  237. inputs = [ inputs[i] if i < len(inputs) else {} for i in range(len(arguments)) ]
  238. outputs = [ outputs[i] if i < len(outputs) else {} for i in range(len(returns)) ]
  239. value['inputs'] = inputs
  240. value['outputs'] = outputs
  241. for i, _ in enumerate(arguments):
  242. argument = inputs[i]
  243. argument['name'] = _.name
  244. self._argument(argument, getattr(_, 'type'))
  245. if hasattr(_, 'default'):
  246. argument['default'] = _.default
  247. for i, _ in enumerate(returns):
  248. argument = outputs[i]
  249. if hasattr(_, 'name'):
  250. argument['name'] = _.name
  251. self._argument(argument, getattr(_, 'type'))
  252. return self.types[key]
  253. def _argument(self, argument, value):
  254. optional = False
  255. argument_type = ''
  256. while not isinstance(value, str):
  257. if isinstance(value, Schema.OptionalType):
  258. value = value.element_type
  259. optional = True
  260. elif isinstance(value, Schema.ListType):
  261. size = str(value.size) if hasattr(value, 'size') else ''
  262. argument_type = '[' + size + ']' + argument_type
  263. value = value.element_type
  264. elif isinstance(value, Schema.DictType):
  265. value = str(value)
  266. else:
  267. name = value.name
  268. name = self._primitives[name] if name in self._primitives else name
  269. argument_type = name + argument_type
  270. break
  271. if argument_type:
  272. argument['type'] = argument_type
  273. else:
  274. argument.pop('type', None)
  275. if optional:
  276. argument['optional'] = True
  277. else:
  278. argument.pop('optional', False)
  279. class Schema: # pylint: disable=too-few-public-methods,missing-class-docstring
  280. def __init__(self, value):
  281. lexer = Schema.Lexer(value)
  282. lexer.whitespace(0)
  283. self._parse_name(lexer)
  284. lexer.whitespace(0)
  285. if lexer.kind == '(':
  286. self._parse_arguments(lexer)
  287. lexer.whitespace(0)
  288. lexer.expect('->')
  289. lexer.whitespace(0)
  290. self._parse_returns(lexer)
  291. def __str__(self):
  292. arguments = []
  293. kwarg_only = False
  294. for _ in self.arguments:
  295. if not kwarg_only and _.kwarg_only:
  296. kwarg_only = True
  297. arguments.append('*')
  298. arguments.append(_.__str__())
  299. if self.is_vararg:
  300. arguments.append('...')
  301. returns = ', '.join(map(lambda _: _.__str__(), self.returns))
  302. returns = returns if len(self.returns) == 1 else '(' + returns + ')'
  303. return self.name + '(' + ', '.join(arguments) + ') -> ' + returns
  304. def _parse_name(self, lexer):
  305. self.name = lexer.expect('id')
  306. if lexer.eat(':'):
  307. lexer.expect(':')
  308. self.name = self.name + '::' + lexer.expect('id')
  309. if lexer.eat('.'):
  310. self.name = self.name + '.' + lexer.expect('id')
  311. def _parse_arguments(self, lexer):
  312. self.arguments = []
  313. self.is_vararg = False
  314. self.kwarg_only = False
  315. lexer.expect('(')
  316. if not lexer.eat(')'):
  317. while True:
  318. lexer.whitespace(0)
  319. if self.is_vararg:
  320. raise NotImplementedError()
  321. if lexer.eat('*'):
  322. self.kwarg_only = True
  323. elif lexer.eat('...'):
  324. self.is_vararg = True
  325. else:
  326. self.arguments.append(Schema.Argument(lexer, False, self.kwarg_only))
  327. lexer.whitespace(0)
  328. if not lexer.eat(','):
  329. break
  330. lexer.expect(')')
  331. def _parse_returns(self, lexer):
  332. self.returns = []
  333. self.is_varret = False
  334. if lexer.eat('...'):
  335. self.is_varret = True
  336. elif lexer.eat('('):
  337. lexer.whitespace(0)
  338. if not lexer.eat(')'):
  339. while True:
  340. lexer.whitespace(0)
  341. if self.is_varret:
  342. raise NotImplementedError()
  343. if lexer.eat('...'):
  344. self.is_varret = True
  345. else:
  346. self.returns.append(Schema.Argument(lexer, True, False))
  347. lexer.whitespace(0)
  348. if not lexer.eat(','):
  349. break
  350. lexer.expect(')')
  351. lexer.whitespace(0)
  352. else:
  353. self.returns.append(Schema.Argument(lexer, True, False))
  354. class Argument: # pylint: disable=too-few-public-methods
  355. def __init__(self, lexer, is_return, kwarg_only):
  356. value = Schema.Type.parse(lexer)
  357. lexer.whitespace(0)
  358. while True:
  359. if lexer.eat('['):
  360. size = None
  361. if lexer.kind == '#':
  362. size = int(lexer.value)
  363. lexer.next()
  364. lexer.expect(']')
  365. value = Schema.ListType(value, size)
  366. elif lexer.eat('?'):
  367. value = Schema.OptionalType(value)
  368. elif lexer.kind == '(' and not hasattr(self, 'alias'):
  369. self.alias = self._parse_alias(lexer)
  370. else:
  371. break
  372. self.type = value
  373. if is_return:
  374. lexer.whitespace(0)
  375. self.kwarg_only = False
  376. if lexer.kind == 'id':
  377. self.name = lexer.expect('id')
  378. else:
  379. lexer.whitespace(1)
  380. self.kwarg_only = kwarg_only
  381. self.name = lexer.expect('id')
  382. lexer.whitespace(0)
  383. if lexer.eat('='):
  384. lexer.whitespace(0)
  385. self.default = self._parse_value(lexer)
  386. def __str__(self):
  387. alias = '(' + self.alias + ')' if hasattr(self, 'alias') else ''
  388. name = ' ' + self.name if hasattr(self, 'name') else ''
  389. default = '=' + self.default.__str__() if hasattr(self, 'default') else ''
  390. return self.type.__str__() + alias + name + default
  391. def _parse_value(self, lexer):
  392. if lexer.kind == 'id':
  393. if lexer.value in ('True', 'False'):
  394. value = bool(lexer.value == 'True')
  395. elif lexer.value == 'None':
  396. value = None
  397. elif lexer.value in ('Mean', 'contiguous_format', 'long'):
  398. value = lexer.value
  399. else:
  400. raise NotImplementedError()
  401. elif lexer.kind == '#':
  402. value = float(lexer.value) if \
  403. lexer.value.find('.') != -1 or lexer.value.find('e') != -1 else \
  404. int(lexer.value)
  405. elif lexer.kind == 'string':
  406. value = lexer.value[1:-1]
  407. elif lexer.eat('['):
  408. value = []
  409. if not lexer.eat(']'):
  410. while True:
  411. lexer.whitespace(0)
  412. value.append(self._parse_value(lexer))
  413. lexer.whitespace(0)
  414. if not lexer.eat(','):
  415. break
  416. lexer.expect(']')
  417. return value
  418. else:
  419. raise NotImplementedError()
  420. lexer.next()
  421. return value
  422. def _parse_alias(self, lexer):
  423. value = ''
  424. lexer.expect('(')
  425. while not lexer.eat(')'):
  426. value += lexer.value
  427. lexer.next()
  428. return value
  429. class Type: # pylint: disable=too-few-public-methods,missing-class-docstring
  430. def __init__(self, name):
  431. self.name = name
  432. def __str__(self):
  433. return self.name
  434. @staticmethod
  435. def parse(lexer): # pylint: disable=missing-function-docstring
  436. name = lexer.expect('id')
  437. while lexer.eat('.'):
  438. name = name + '.' + lexer.expect('id')
  439. if name == 'Dict':
  440. lexer.expect('(')
  441. lexer.whitespace(0)
  442. key_type = Schema.Type.parse(lexer)
  443. lexer.whitespace(0)
  444. lexer.expect(',')
  445. lexer.whitespace(0)
  446. value_type = Schema.Type.parse(lexer)
  447. lexer.whitespace(0)
  448. lexer.expect(')')
  449. return Schema.DictType(key_type, value_type)
  450. return Schema.Type(name)
  451. class OptionalType: # pylint: disable=too-few-public-methods,missing-class-docstring
  452. def __init__(self, element_type):
  453. self.element_type = element_type
  454. def __str__(self):
  455. return self.element_type.__str__() + '?'
  456. class ListType: # pylint: disable=too-few-public-methods,missing-class-docstring
  457. def __init__(self, element_type, size):
  458. self.element_type = element_type
  459. if size:
  460. self.size = size
  461. def __str__(self):
  462. size = self.size.__str__() if hasattr(self, 'size') else ''
  463. return self.element_type.__str__() + '[' + size + ']'
  464. class DictType:
  465. def __init__(self, key_type, value_type):
  466. self._key_type = key_type
  467. self._value_type = value_type
  468. def __str__(self):
  469. return 'Dict[' + str(self._key_type) + ', ' + str(self._value_type) + ']'
  470. def getKeyType(self): # pylint: disable=invalid-name,missing-function-docstring
  471. return self._key_type
  472. def getValueType(self): # pylint: disable=invalid-name,,missing-function-docstring
  473. return self._value_type
  474. class Lexer: # pylint: disable=too-few-public-methods,missing-class-docstring
  475. def __init__(self, buffer):
  476. self.buffer = buffer
  477. self.position = 0
  478. self.value = ''
  479. self.next()
  480. def eat(self, kind): # pylint: disable=missing-function-docstring
  481. if self.kind != kind:
  482. return None
  483. value = self.value
  484. self.next()
  485. return value
  486. def expect(self, kind): # pylint: disable=missing-function-docstring
  487. if self.kind != kind:
  488. raise SyntaxError("Unexpected '" + self.kind + "' instead of '" + kind + "'.")
  489. value = self.value
  490. self.next()
  491. return value
  492. def whitespace(self, count): # pylint: disable=missing-function-docstring
  493. if self.kind != ' ':
  494. if count > len(self.value):
  495. raise IndexError()
  496. return False
  497. self.next()
  498. return True
  499. def next(self): # pylint: disable=missing-function-docstring,too-many-branches
  500. self.position += len(self.value)
  501. i = self.position
  502. if i >= len(self.buffer):
  503. self.kind = '\0'
  504. self.value = ''
  505. elif self.buffer[i] == ' ':
  506. while self.buffer[i] == ' ':
  507. i += 1
  508. self.kind = ' '
  509. self.value = self.buffer[self.position:i]
  510. elif self.buffer[i] == '.' and self.buffer[i+1] == '.' and self.buffer[i+2] == '.':
  511. self.kind = '...'
  512. self.value = '...'
  513. elif self.buffer[i] in ('(', ')', ':', '.', '[', ']', ',', '=', '?', '!', '*', '|'):
  514. self.kind = self.buffer[i]
  515. self.value = self.buffer[i]
  516. elif (self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
  517. (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or self.buffer[i] == '_':
  518. i += 1
  519. while i < len(self.buffer) and \
  520. ((self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
  521. (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or \
  522. (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '_'):
  523. i += 1
  524. self.kind = 'id'
  525. self.value = self.buffer[self.position:i]
  526. elif self.buffer[i] == '-' and self.buffer[i+1] == '>':
  527. self.kind = '->'
  528. self.value = '->'
  529. elif (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '-':
  530. i += 1
  531. while i < len(self.buffer) and \
  532. ((self.buffer[i] >= '0' and self.buffer[i] <= '9') or \
  533. self.buffer[i] == '.' or self.buffer[i] == 'e' or self.buffer[i] == '-'):
  534. i += 1
  535. self.kind = '#'
  536. self.value = self.buffer[self.position:i]
  537. elif self.buffer[i] in ("'", '"'):
  538. quote = self.buffer[i]
  539. i += 1
  540. while i < len(self.buffer) and self.buffer[i] != quote:
  541. i += 2 if self.buffer[i] == '\\' and self.buffer[i+1] in ("'", '"', '\\') else 1
  542. i += 1
  543. self.kind = 'string'
  544. self.value = self.buffer[self.position:i]
  545. else:
  546. raise NotImplementedError("Unsupported token at " + self.position)