pytorch-script.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from __future__ import unicode_literals
  2. from __future__ import print_function
  3. import io
  4. import json
  5. import pydoc
  6. import os
  7. import re
  8. import sys
  9. def metadata():
  10. json_file = os.path.join(os.path.dirname(__file__), '../src/pytorch-metadata.json')
  11. json_data = open(json_file).read()
  12. json_root = json.loads(json_data)
  13. schema_map = {}
  14. for entry in json_root:
  15. name = entry['name']
  16. schema = entry['schema']
  17. schema_map[name] = schema
  18. for entry in json_root:
  19. name = entry['name']
  20. schema = entry['schema']
  21. if 'package' in schema:
  22. class_name = schema['package'] + '.' + name
  23. # print(class_name)
  24. class_definition = pydoc.locate(class_name)
  25. if not class_definition:
  26. raise Exception('\'' + class_name + '\' not found.')
  27. docstring = class_definition.__doc__
  28. if not docstring:
  29. raise Exception('\'' + class_name + '\' missing __doc__.')
  30. # print(docstring)
  31. with io.open(json_file, 'w', newline='') as fout:
  32. json_data = json.dumps(json_root, sort_keys=True, indent=2)
  33. for line in json_data.splitlines():
  34. line = line.rstrip()
  35. if sys.version_info[0] < 3:
  36. line = unicode(line)
  37. fout.write(line)
  38. fout.write('\n')
  39. def download_torchvision_model(name, input):
  40. folder = os.path.expandvars('${test}/data/pytorch')
  41. if not os.path.exists(folder):
  42. os.makedirs(folder)
  43. base = folder + '/' + name.split('.')[-1]
  44. model = pydoc.locate(name)(pretrained=True)
  45. import torch
  46. torch.save(model, base + '.pkl.pth', _use_new_zipfile_serialization=False);
  47. torch.save(model, base + '.zip.pth', _use_new_zipfile_serialization=True);
  48. model.eval()
  49. torch.jit.script(model).save(base + '.pt')
  50. traced_model = torch.jit.trace(model, torch.rand(input))
  51. torch.jit.save(traced_model, base + '_traced.pt')
  52. def zoo():
  53. if not os.environ.get('test'):
  54. os.environ['test'] = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../test'))
  55. download_torchvision_model('torchvision.models.alexnet', [ 1, 3, 299, 299 ])
  56. download_torchvision_model('torchvision.models.densenet161', [ 1, 3, 224, 224 ])
  57. download_torchvision_model('torchvision.models.inception_v3', [ 1, 3, 299, 299 ])
  58. download_torchvision_model('torchvision.models.mobilenet_v2', [ 1, 3, 224, 224 ])
  59. download_torchvision_model('torchvision.models.resnet18', [ 1, 3, 224, 224 ])
  60. download_torchvision_model('torchvision.models.resnet101', [ 1, 3, 224, 224 ])
  61. download_torchvision_model('torchvision.models.shufflenet_v2_x1_0', [ 1, 3, 224, 224 ])
  62. download_torchvision_model('torchvision.models.squeezenet1_1', [ 1, 3, 224, 224 ])
  63. download_torchvision_model('torchvision.models.video.r3d_18', [ 1, 3, 4, 112, 112 ])
  64. download_torchvision_model('torchvision.models.vgg11_bn', [ 1, 3, 224, 224 ])
  65. download_torchvision_model('torchvision.models.vgg16', [ 1, 3, 224, 224 ])
  66. if __name__ == '__main__':
  67. command_table = { 'metadata': metadata, 'zoo': zoo }
  68. command = sys.argv[1];
  69. command_table[command]()