# postprocess_pretrained_ckpt.py
# Copyright (c) 2020 NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import collections
import json
import os

import tensorflow as tf

from modeling import PretrainingModel
from run_pretraining import PretrainingConfig
from utils import heading, log
  22. def from_pretrained_ckpt(args):
  23. config = PretrainingConfig(
  24. model_name='postprocessing',
  25. data_dir='postprocessing',
  26. generator_hidden_size=0.3333333,
  27. )
  28. # Padding for divisibility by 8
  29. if config.vocab_size % 8 != 0:
  30. config.vocab_size += 8 - (config.vocab_size % 8)
  31. if args.amp:
  32. policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16", loss_scale="dynamic")
  33. tf.keras.mixed_precision.experimental.set_policy(policy)
  34. print('Compute dtype: %s' % policy.compute_dtype) # Compute dtype: float16
  35. print('Variable dtype: %s' % policy.variable_dtype) # Variable dtype: float32
  36. # Set up model
  37. model = PretrainingModel(config)
  38. # Load checkpoint
  39. checkpoint = tf.train.Checkpoint(step=tf.Variable(1), model=model)
  40. checkpoint.restore(args.pretrained_checkpoint).expect_partial()
  41. log(" ** Restored from {} at step {}".format(args.pretrained_checkpoint, int(checkpoint.step) - 1))
  42. disc_dir = os.path.join(args.output_dir, 'discriminator')
  43. gen_dir = os.path.join(args.output_dir, 'generator')
  44. heading(" ** Saving discriminator")
  45. model.discriminator(model.discriminator.dummy_inputs)
  46. model.discriminator.save_pretrained(disc_dir)
  47. heading(" ** Saving generator")
  48. model.generator(model.generator.dummy_inputs)
  49. model.generator.save_pretrained(gen_dir)
  50. if __name__ == '__main__':
  51. # Parse essential args
  52. parser = argparse.ArgumentParser()
  53. parser.add_argument('--pretrained_checkpoint')
  54. parser.add_argument('--output_dir')
  55. parser.add_argument('--amp', action='store_true', default=False)
  56. args = parser.parse_args()
  57. from_pretrained_ckpt(args)