preprocess_dagm2007.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
#!/usr/bin/env python
# -*- coding: utf-8 -*-
##############################################################################
# Copyright (c) Jonathan Dekhtiar - [email protected]
# All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################
  10. import os
  11. import glob
  12. import ntpath
  13. import argparse
  14. from collections import defaultdict
  15. parser = argparse.ArgumentParser(description="DAGM2007_preprocessing")
  16. parser.add_argument('--data_dir', required=True, type=str, help="Path to DAGM 2007 private dataset")
  17. DEFECTIVE_COUNT = defaultdict(lambda: defaultdict(int))
  18. EXPECTED_DEFECTIVE_SAMPLES_PER_CLASS = {
  19. "Train": {
  20. 1: 79,
  21. 2: 66,
  22. 3: 66,
  23. 4: 82,
  24. 5: 70,
  25. 6: 83,
  26. 7: 150,
  27. 8: 150,
  28. 9: 150,
  29. 10: 150,
  30. },
  31. "Test": {
  32. 1: 71,
  33. 2: 84,
  34. 3: 84,
  35. 4: 68,
  36. 5: 80,
  37. 6: 67,
  38. 7: 150,
  39. 8: 150,
  40. 9: 150,
  41. 10: 150,
  42. }
  43. }
  44. if __name__ == "__main__":
  45. FLAGS, unknown_args = parser.parse_known_args()
  46. if len(unknown_args) > 0:
  47. for bad_arg in unknown_args:
  48. print("ERROR: Unknown command line arg: %s" % bad_arg)
  49. raise ValueError("Invalid command line arg(s)")
  50. if not os.path.exists(FLAGS.data_dir):
  51. raise ValueError('The dataset directory received `%s` does not exists' % FLAGS.data_dir)
  52. for challenge_id in range(10):
  53. challenge_name = "Class%d" % (challenge_id + 1)
  54. challenge_folder_path = os.path.join(FLAGS.data_dir, challenge_name)
  55. print("[DAGM Preprocessing] Parsing Class ID: %02d ..." % (challenge_id + 1))
  56. if not os.path.exists(challenge_folder_path):
  57. raise ValueError('The folder `%s` does not exists' % challenge_folder_path)
  58. for data_set in ["Train", "Test"]:
  59. challenge_set_folder_path = os.path.join(challenge_folder_path, data_set)
  60. if not os.path.exists(challenge_set_folder_path):
  61. raise ValueError('The folder `%s` does not exists' % challenge_set_folder_path)
  62. with open(os.path.join(challenge_folder_path, "%s_list.csv" % data_set.lower()), 'w') as data_list_file:
  63. data_list_file.write('image_filepath,lbl_image_filepath,is_defective\n')
  64. files = glob.glob(os.path.join(challenge_set_folder_path, "*.PNG"))
  65. for file in files:
  66. filepath, fullname = ntpath.split(file)
  67. filename, extension = os.path.splitext(os.path.basename(fullname))
  68. lbl_filename = "%s_label.PNG" % filename
  69. lbl_filepath = os.path.join(filepath, "Label", lbl_filename)
  70. if os.path.exists(lbl_filepath):
  71. defective = True
  72. else:
  73. defective = False
  74. lbl_filename = ""
  75. if defective:
  76. DEFECTIVE_COUNT[data_set][challenge_id + 1] += 1
  77. data_list_file.write('%s,%s,%d\n' % (fullname, lbl_filename, defective))
  78. if DEFECTIVE_COUNT[data_set][challenge_id +
  79. 1] != EXPECTED_DEFECTIVE_SAMPLES_PER_CLASS[data_set][challenge_id + 1]:
  80. raise RuntimeError(
  81. "There should be `%d` defective samples instead of `%d` in challenge (%s): %d" % (
  82. DEFECTIVE_COUNT[data_set][challenge_id + 1],
  83. EXPECTED_DEFECTIVE_SAMPLES_PER_CLASS[data_set][challenge_id + 1], data_set, challenge_id + 1
  84. )
  85. )