prepare_dataset.py 1.2 KB

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from argparse import ArgumentParser
  15. from vae.load.preprocessing import load_and_parse_ML_20M
  16. import numpy as np
  17. parser = ArgumentParser(description="Prepare data for VAE training")
  18. parser.add_argument('--data_dir', default='/data', type=str,
  19. help='Directory for storing the training data')
  20. parser.add_argument('--seed', default=0, type=int,
  21. help='Random seed')
  22. args = parser.parse_args()
  23. print('Preprocessing seed: ', args.seed)
  24. np.random.seed(args.seed)
  25. # load dataset
  26. (train_data,
  27. validation_data_input,
  28. validation_data_true,
  29. test_data_input,
  30. test_data_true) = load_and_parse_ML_20M(args.data_dir)