prepare_dataset.sh 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright (c) 2018, deepakn94, robieta. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # -----------------------------------------------------------------------
  16. #
  17. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  18. #
  19. # Licensed under the Apache License, Version 2.0 (the "License");
  20. # you may not use this file except in compliance with the License.
  21. # You may obtain a copy of the License at
  22. #
  23. # http://www.apache.org/licenses/LICENSE-2.0
  24. #
  25. # Unless required by applicable law or agreed to in writing, software
  26. # distributed under the License is distributed on an "AS IS" BASIS,
  27. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  28. # See the License for the specific language governing permissions and
  29. # limitations under the License.
  #!/bin/bash
  # NOTE(review): a shebang only takes effect on the very first line of a file;
  # here it follows the license header, so invoke as `bash prepare_dataset.sh`.
  #
  # Prepare a MovieLens dataset for NCF training.
  #
  # Usage: prepare_dataset.sh [DATASET_NAME] [RAW_DATADIR] [CACHED_DATADIR]
  #   DATASET_NAME    'ml-20m' (default) or 'ml-1m'
  #   RAW_DATADIR     directory expected to hold <DATASET_NAME>.zip (default: /data)
  #   CACHED_DATADIR  output directory for preprocessed data
  #                   (default: RAW_DATADIR/cache/DATASET_NAME)
  #
  # Steps: verify the zip is present, extract it if the ratings file is missing,
  # then run convert.py unless CACHED_DATADIR/train_ratings.pt already exists.
  set -e
  set -x

  DATASET_NAME=${1:-'ml-20m'}
  RAW_DATADIR=${2:-'/data'}
  CACHED_DATADIR=${3:-"${RAW_DATADIR}/cache/${DATASET_NAME}"}

  # You can add another option to this case in order to support other datasets.
  # Each entry must set ZIP_PATH, RATINGS_PATH and DOWNLOAD_URL.
  case "${DATASET_NAME}" in
    'ml-20m')
      ZIP_PATH="${RAW_DATADIR}/ml-20m.zip"
      RATINGS_PATH="${RAW_DATADIR}/ml-20m/ratings.csv"
      DOWNLOAD_URL='https://grouplens.org/datasets/movielens/20m/'
      ;;
    'ml-1m')
      ZIP_PATH="${RAW_DATADIR}/ml-1m.zip"
      RATINGS_PATH="${RAW_DATADIR}/ml-1m/ratings.dat"
      DOWNLOAD_URL='https://grouplens.org/datasets/movielens/1m/'
      ;;
    *)
      echo "Unsupported dataset name: ${DATASET_NAME}" >&2
      exit 1
      ;;
  esac

  # mkdir -p is a no-op for directories that already exist.
  mkdir -p -- "${RAW_DATADIR}" "${CACHED_DATADIR}"

  # Remove a leftover "log" file in the current directory (only if it is a
  # regular file, so a directory named "log" is left alone).
  # NOTE(review): nothing in this script writes "log"; presumably another step
  # does — confirm.
  if [[ -f log ]]; then
    rm -f -- log
  fi

  if [[ ! -f "${ZIP_PATH}" ]]; then
    echo "Dataset not found. Please download it from: ${DOWNLOAD_URL} and put it in ${ZIP_PATH}" >&2
    exit 1
  fi

  if [[ ! -f "${RATINGS_PATH}" ]]; then
    # -u: extract files missing on disk (and update older ones) into RAW_DATADIR.
    unzip -u "${ZIP_PATH}" -d "${RAW_DATADIR}"
    # Fail early with a clear message if the archive layout is not as expected,
    # instead of letting convert.py fail on a missing input file.
    if [[ ! -f "${RATINGS_PATH}" ]]; then
      echo "Expected ${RATINGS_PATH} after extracting ${ZIP_PATH}, but it is missing" >&2
      exit 1
    fi
  fi

  if [[ ! -f "${CACHED_DATADIR}/train_ratings.pt" ]]; then
    echo "preprocessing ${RATINGS_PATH} and save to disk"
    t0=$(date +%s)
    # NOTE(review): convert.py is resolved relative to the current working
    # directory, so this script must be run from the directory containing it.
    python convert.py --path "${RATINGS_PATH}" --output "${CACHED_DATADIR}"
    t1=$(date +%s)
    delta=$(( t1 - t0 ))
    echo "Finish preprocessing in ${delta} seconds"
  else
    echo 'Using cached preprocessed data'
  fi

  # printf instead of echo: bash's echo would print "\n" literally.
  printf 'Dataset %s successfully prepared at: %s\n\n' "${DATASET_NAME}" "${CACHED_DATADIR}"
  echo "You can now run the training with: python -m torch.distributed.launch --nproc_per_node=<number_of_GPUs> --use_env ncf.py --data ${CACHED_DATADIR}"