# export_model.py
# Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os

import paddle

import program
from dali import build_dataloader
from utils.config import parse_args, print_args
from utils.logger import setup_dllogger
from utils.mode import Mode
from utils.save_load import init_ckpt
  23. def main(args):
  24. '''
  25. Export saved model params to paddle inference model
  26. '''
  27. setup_dllogger(args.trt_export_log_path)
  28. if args.show_config:
  29. print_args(args)
  30. eval_dataloader = build_dataloader(args, Mode.EVAL)
  31. startup_prog = paddle.static.Program()
  32. eval_prog = paddle.static.Program()
  33. eval_fetchs, _, eval_feeds, _ = program.build(
  34. args,
  35. eval_prog,
  36. startup_prog,
  37. step_each_epoch=len(eval_dataloader),
  38. is_train=False)
  39. eval_prog = eval_prog.clone(for_test=True)
  40. device = paddle.set_device('gpu')
  41. exe = paddle.static.Executor(device)
  42. exe.run(startup_prog)
  43. path_to_ckpt = args.from_checkpoint
  44. if path_to_ckpt is None:
  45. logging.warning(
  46. 'The --from-checkpoint is not set, model weights will not be initialize.'
  47. )
  48. else:
  49. init_ckpt(path_to_ckpt, eval_prog, exe)
  50. logging.info('Checkpoint path is %s', path_to_ckpt)
  51. save_inference_dir = args.trt_inference_dir
  52. paddle.static.save_inference_model(
  53. path_prefix=os.path.join(save_inference_dir, args.model_arch_name),
  54. feed_vars=[eval_feeds['data']],
  55. fetch_vars=[eval_fetchs['label'][0]],
  56. executor=exe,
  57. program=eval_prog)
  58. logging.info('Successully export inference model to %s',
  59. save_inference_dir)
  60. if __name__ == '__main__':
  61. paddle.enable_static()
  62. main(parse_args(including_trt=True))