| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- from __future__ import unicode_literals
- from __future__ import print_function
- import io
- import json
- import pydoc
- import os
- import re
- import sys
def metadata(json_file=None):
    """Validate the PyTorch metadata JSON file and rewrite it normalized.

    Each entry must provide 'name' and 'schema'. For every entry whose schema
    has a 'package' key, '<package>.<name>' is resolved with pydoc.locate and
    must yield an object with a non-empty __doc__. Once all entries pass, the
    file is rewritten in place with sorted keys, two-space indentation, no
    trailing whitespace and LF line endings.

    Args:
        json_file: Path of the metadata file to process. When omitted, the
            file '../src/pytorch-metadata.json' relative to this script is
            used.

    Raises:
        Exception: If a referenced class cannot be located or has no __doc__.
    """
    if json_file is None:
        json_file = os.path.join(os.path.dirname(__file__), '../src/pytorch-metadata.json')

    # Read through a context manager so the handle is closed before the same
    # path is reopened for writing below.
    with io.open(json_file, 'r', encoding='utf-8') as fin:
        json_root = json.loads(fin.read())

    for entry in json_root:
        name = entry['name']
        schema = entry['schema']
        if 'package' not in schema:
            continue
        class_name = schema['package'] + '.' + name
        class_definition = pydoc.locate(class_name)
        if not class_definition:
            raise Exception('\'' + class_name + '\' not found.')
        if not class_definition.__doc__:
            raise Exception('\'' + class_name + '\' missing __doc__.')

    # Serialize before opening the file for writing: a failure here must not
    # leave a truncated metadata file behind.
    json_data = json.dumps(json_root, sort_keys=True, indent=2)
    with io.open(json_file, 'w', newline='', encoding='utf-8') as fout:
        for line in json_data.splitlines():
            # Older json versions emit trailing spaces after ',' when indenting.
            line = line.rstrip()
            if sys.version_info[0] < 3:
                # Python 2 only: io text streams accept unicode, not str.
                line = unicode(line)
            fout.write(line)
            fout.write('\n')
def download_torchvision_model(name, input):
    """Instantiate a pretrained model and export it in several formats.

    The factory named by ``name`` is resolved with pydoc.locate and called
    with ``pretrained=True``. Output files are written to
    '${test}/data/pytorch' (environment variable expanded, folder created if
    missing), all sharing a base name taken from the last dotted component
    of ``name``:

      <base>.pkl.pth    torch.save without the zipfile serialization
      <base>.zip.pth    torch.save with the zipfile serialization
      <base>.pt         torch.jit.script output
      <base>_traced.pt  torch.jit.trace output

    Args:
        name: Dotted path of the model factory, e.g.
            'torchvision.models.resnet18'.
        input: Size passed to torch.rand to build the example tensor used
            for tracing.
    """
    target_dir = os.path.expandvars('${test}/data/pytorch')
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    stem = '/'.join((target_dir, name.rsplit('.', 1)[-1]))
    factory = pydoc.locate(name)
    net = factory(pretrained=True)
    import torch
    # Both pickled variants are saved before the model is switched to eval mode.
    for suffix, use_zip in (('.pkl.pth', False), ('.zip.pth', True)):
        torch.save(net, stem + suffix, _use_new_zipfile_serialization=use_zip)
    net.eval()
    scripted = torch.jit.script(net)
    scripted.save(stem + '.pt')
    example = torch.rand(input)
    traced = torch.jit.trace(net, example)
    torch.jit.save(traced, stem + '_traced.pt')
def zoo():
    """Export the fixed set of torchvision test models.

    If the 'test' environment variable is unset or empty it is pointed at
    the '../test' folder next to this script, then each listed model is
    handed to download_torchvision_model together with the size of its
    tracing input.
    """
    if not os.environ.get('test'):
        script_dir = os.path.dirname(os.path.abspath(__file__))
        os.environ['test'] = os.path.normpath(os.path.join(script_dir, '../test'))
    # (model factory path, example input size), processed in this order.
    models = (
        ('torchvision.models.alexnet', [ 1, 3, 299, 299 ]),
        ('torchvision.models.densenet161', [ 1, 3, 224, 224 ]),
        ('torchvision.models.inception_v3', [ 1, 3, 299, 299 ]),
        ('torchvision.models.mobilenet_v2', [ 1, 3, 224, 224 ]),
        ('torchvision.models.resnet18', [ 1, 3, 224, 224 ]),
        ('torchvision.models.resnet101', [ 1, 3, 224, 224 ]),
        ('torchvision.models.shufflenet_v2_x1_0', [ 1, 3, 224, 224 ]),
        ('torchvision.models.squeezenet1_1', [ 1, 3, 224, 224 ]),
        ('torchvision.models.video.r3d_18', [ 1, 3, 4, 112, 112 ]),
        ('torchvision.models.vgg11_bn', [ 1, 3, 224, 224 ]),
        ('torchvision.models.vgg16', [ 1, 3, 224, 224 ]),
    )
    for model_name, input_size in models:
        download_torchvision_model(model_name, input_size)
if __name__ == '__main__':
    # Maps the first command-line argument to the function that handles it.
    command_table = { 'metadata': metadata, 'zoo': zoo }
    command = sys.argv[1] if len(sys.argv) > 1 else None
    if command not in command_table:
        # A missing or unknown command previously surfaced as a raw
        # IndexError/KeyError traceback; report the valid choices instead.
        usage = 'Usage: ' + os.path.basename(sys.argv[0]) + ' <' + '|'.join(sorted(command_table)) + '>\n'
        sys.stderr.write(usage)
        sys.exit(1)
    command_table[command]()
|