| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- import torch
- import models
def parse_args(parser):
    """
    Register the exporter's command-line arguments on ``parser``.

    Adds three positionals (``model_name``, ``input_ckpt``, ``output_ckpt``),
    the ``--cuda`` and ``--amp`` switches, and the ``--load-from`` /
    ``--convert-to`` format selectors. The same parser object that was
    passed in is returned.
    """
    add = parser.add_argument

    # Positionals: which model, where to read it from, where to write it.
    add('model_name', type=str, choices=['HiFi-GAN', 'FastPitch'],
        help='Name of the converted model')
    add('input_ckpt', type=str, help='Path to the input checkpoint')
    # NOTE: default=None has no effect on a required positional; it is kept
    # so the registered action stays identical.
    add('output_ckpt', default=None,
        help='Path to save the output checkpoint to')

    # Boolean switches.
    add('--cuda', action='store_true',
        help='Move model weights to GPU before export')
    add('--amp', action='store_true',
        help='Convert model to FP16 prior to saving')

    # Source and target checkpoint formats.
    add('--load-from', type=str, default='pyt', choices=['pyt', 'ts'],
        help='Source checkpoint format')
    add('--convert-to', type=str, default='ts', choices=['ts', 'ttrt'],
        help='Output checkpoint format')

    return parser
def main():
    """
    Exports PyT or TorchScript checkpoint to TorchScript or Torch-TensorRT.

    Command-line options are registered by ``parse_args``. Options this
    script does not recognize are collected by ``parse_known_args`` and
    passed on to the ``models`` helpers; if that list is still non-empty
    after the export, a warning listing them is printed.

    Exits through ``parser.error`` (usage message, status 2) when
    ``--load-from`` and ``--convert-to`` name the same format.
    """
    parser = argparse.ArgumentParser(description='PyTorch model export',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, unk_args = parser.parse_known_args()

    # Validate user input with parser.error instead of assert: asserts are
    # removed under `python -O`, which would let a same-format "conversion"
    # run silently.
    if args.load_from == args.convert_to:
        parser.error('Load and convert formats must be different')

    device = torch.device('cuda' if args.cuda else 'cpu')

    print(f'Converting {args.model_name} from "{args.load_from}"'
          f' to "{args.convert_to}" ({device}).')

    # Step 1: obtain a TorchScript module, either by loading one directly
    # or by scripting a PyTorch model.
    if args.load_from == 'ts':
        ts_model, _ = models.load_and_setup_ts_model(args.model_name,
                                                     args.input_ckpt, args.amp,
                                                     device)
    else:
        # Internal invariant: argparse `choices` leaves only 'pyt' here.
        assert args.load_from == 'pyt'
        pyt_model, _ = models.load_pyt_model_for_infer(
            args.model_name, parser, args.input_ckpt, args.amp, device,
            unk_args=unk_args, jitable=True)
        ts_model = torch.jit.script(pyt_model)

    # Step 2: save as TorchScript, or compile to Torch-TensorRT first.
    if args.convert_to == 'ts':
        torch.jit.save(ts_model, args.output_ckpt)
    else:
        # Internal invariant: argparse `choices` leaves only 'ttrt' here.
        assert args.convert_to == 'ttrt'
        # NOTE(review): the model name is hard-coded to 'HiFi-GAN' here even
        # though args.model_name may be 'FastPitch' - confirm against
        # models.convert_ts_to_trt whether this is intentional.
        trt_model = models.convert_ts_to_trt('HiFi-GAN', ts_model, parser,
                                             args.amp, unk_args)
        torch.jit.save(trt_model, args.output_ckpt)

    print(f'{args.model_name}: checkpoint saved to {args.output_ckpt}.')

    # Checked last because unk_args is handed to the models helpers above.
    # NOTE(review): presumably they drop the options they recognize - confirm.
    if unk_args:
        print(f'Warning: encountered unknown program options: {unk_args}')
# Run the export only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|