download_dataset.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import os
  16. import tarfile
  17. from google_drive_downloader import GoogleDriveDownloader as gdd
  18. PARSER = argparse.ArgumentParser(description="V-Net medical")
  19. PARSER.add_argument('--data_dir',
  20. type=str,
  21. default='./data',
  22. help="""Directory where to download the dataset""")
  23. PARSER.add_argument('--dataset',
  24. type=str,
  25. default='hippocampus',
  26. help="""Dataset to download""")
  27. def main():
  28. FLAGS = PARSER.parse_args()
  29. if not os.path.exists(FLAGS.data_dir):
  30. os.makedirs(FLAGS.data_dir)
  31. filename = ''
  32. if FLAGS.dataset == 'hippocampus':
  33. filename = 'Task04_Hippocampus.tar'
  34. gdd.download_file_from_google_drive(file_id='1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C',
  35. dest_path=os.path.join(FLAGS.data_dir, filename),
  36. unzip=False)
  37. print('Unpacking...')
  38. tf = tarfile.open(os.path.join(FLAGS.data_dir, filename))
  39. tf.extractall(path=FLAGS.data_dir)
  40. print('Cleaning up...')
  41. os.remove(os.path.join(FLAGS.data_dir, filename))
  42. print("Finished downloading files for V-Net medical to {}".format(FLAGS.data_dir))
  43. if __name__ == '__main__':
  44. main()