export.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import torch
  16. import models
  17. def parse_args(parser):
  18. """
  19. Parse commandline arguments.
  20. """
  21. parser.add_argument('model_name', type=str,
  22. choices=['HiFi-GAN', 'FastPitch'],
  23. help='Name of the converted model')
  24. parser.add_argument('input_ckpt', type=str,
  25. help='Path to the input checkpoint')
  26. parser.add_argument('output_ckpt', default=None,
  27. help='Path to save the output checkpoint to')
  28. parser.add_argument('--cuda', action='store_true',
  29. help='Move model weights to GPU before export')
  30. parser.add_argument('--amp', action='store_true',
  31. help='Convert model to FP16 prior to saving')
  32. parser.add_argument('--load-from', type=str, default='pyt',
  33. choices=['pyt', 'ts'],
  34. help='Source checkpoint format')
  35. parser.add_argument('--convert-to', type=str, default='ts',
  36. choices=['ts', 'ttrt'],
  37. help='Output checkpoint format')
  38. return parser
  39. def main():
  40. """
  41. Exports PyT or TorchScript checkpoint to TorchScript or Torch-TensorRT.
  42. """
  43. parser = argparse.ArgumentParser(description='PyTorch model export',
  44. allow_abbrev=False)
  45. parser = parse_args(parser)
  46. args, unk_args = parser.parse_known_args()
  47. device = torch.device('cuda' if args.cuda else 'cpu')
  48. assert args.load_from != args.convert_to, \
  49. 'Load and convert formats must be different'
  50. print(f'Converting {args.model_name} from "{args.load_from}"'
  51. f' to "{args.convert_to}" ({device}).')
  52. if args.load_from == 'ts':
  53. ts_model, _ = models.load_and_setup_ts_model(args.model_name,
  54. args.input_ckpt, args.amp,
  55. device)
  56. else:
  57. assert args.load_from == 'pyt'
  58. pyt_model, _ = models.load_pyt_model_for_infer(
  59. args.model_name, parser, args.input_ckpt, args.amp, device,
  60. unk_args=unk_args, jitable=True)
  61. ts_model = torch.jit.script(pyt_model)
  62. if args.convert_to == 'ts':
  63. torch.jit.save(ts_model, args.output_ckpt)
  64. else:
  65. assert args.convert_to == 'ttrt'
  66. trt_model = models.convert_ts_to_trt('HiFi-GAN', ts_model, parser,
  67. args.amp, unk_args)
  68. torch.jit.save(trt_model, args.output_ckpt)
  69. print(f'{args.model_name}: checkpoint saved to {args.output_ckpt}.')
  70. if unk_args:
  71. print(f'Warning: encountered unknown program options: {unk_args}')
  72. if __name__ == '__main__':
  73. main()