""" COCO dataset (quick and dirty)
Hacked together by Ross Wightman
"""
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch.utils.data as data
import os
import torch
import numpy as np
from PIL import Image
from pycocotools.coco import COCO
from effdet.anchors import Anchors, AnchorLabeler
  27. class CocoDetection(data.Dataset):
  28. """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
  29. Args:
  30. root (string): Root directory where images are downloaded to.
  31. ann_file (string): Path to json annotation file.
  32. transform (callable, optional): A function/transform that takes in an PIL image
  33. and returns a transformed version. E.g, ``transforms.ToTensor``
  34. """
  35. def __init__(self, root, ann_file, config, transform=None):
  36. super(CocoDetection, self).__init__()
  37. if isinstance(root, (str, bytes)):
  38. root = os.path.expanduser(root)
  39. self.root = root
  40. self.transform = transform
  41. self.yxyx = True # expected for TF model, most PT are xyxy
  42. self.include_masks = False
  43. self.include_bboxes_ignore = False
  44. self.has_annotations = 'image_info' not in ann_file
  45. self.coco = None
  46. self.cat_ids = []
  47. self.cat_to_label = dict()
  48. self.img_ids = []
  49. self.img_ids_invalid = []
  50. self.img_infos = []
  51. self._load_annotations(ann_file)
  52. self.anchors = Anchors(
  53. config.min_level, config.max_level,
  54. config.num_scales, config.aspect_ratios,
  55. config.anchor_scale, config.image_size)
  56. self.anchor_labeler = AnchorLabeler(self.anchors, config.num_classes, match_threshold=0.5)
  57. def _load_annotations(self, ann_file):
  58. assert self.coco is None
  59. self.coco = COCO(ann_file)
  60. self.cat_ids = self.coco.getCatIds()
  61. img_ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
  62. for img_id in sorted(self.coco.imgs.keys()):
  63. info = self.coco.loadImgs([img_id])[0]
  64. valid_annotation = not self.has_annotations or img_id in img_ids_with_ann
  65. if valid_annotation and min(info['width'], info['height']) >= 32:
  66. self.img_ids.append(img_id)
  67. self.img_infos.append(info)
  68. else:
  69. self.img_ids_invalid.append(img_id)
  70. def _parse_img_ann(self, img_id, img_info):
  71. ann_ids = self.coco.getAnnIds(imgIds=[img_id])
  72. ann_info = self.coco.loadAnns(ann_ids)
  73. bboxes = []
  74. bboxes_ignore = []
  75. cls = []
  76. for i, ann in enumerate(ann_info):
  77. if ann.get('ignore', False):
  78. continue
  79. x1, y1, w, h = ann['bbox']
  80. if self.include_masks and ann['area'] <= 0:
  81. continue
  82. if w < 1 or h < 1:
  83. continue
  84. # To subtract 1 or not, TF doesn't appear to do this so will keep it out for now.
  85. if self.yxyx:
  86. #bbox = [y1, x1, y1 + h - 1, x1 + w - 1]
  87. bbox = [y1, x1, y1 + h, x1 + w]
  88. else:
  89. #bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
  90. bbox = [x1, y1, x1 + w, y1 + h]
  91. if ann.get('iscrowd', False):
  92. if self.include_bboxes_ignore:
  93. bboxes_ignore.append(bbox)
  94. else:
  95. bboxes.append(bbox)
  96. cls.append(self.cat_to_label[ann['category_id']] if self.cat_to_label else ann['category_id'])
  97. if bboxes:
  98. bboxes = np.array(bboxes, dtype=np.float32)
  99. cls = np.array(cls, dtype=np.int64)
  100. else:
  101. bboxes = np.zeros((0, 4), dtype=np.float32)
  102. cls = np.array([], dtype=np.int64)
  103. if self.include_bboxes_ignore:
  104. if bboxes_ignore:
  105. bboxes_ignore = np.array(bboxes_ignore, dtype=np.float32)
  106. else:
  107. bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
  108. ann = dict(img_id=img_id, bbox=bboxes, cls=cls, img_size=(img_info['width'], img_info['height']))
  109. if self.include_bboxes_ignore:
  110. ann['bbox_ignore'] = bboxes_ignore
  111. return ann
  112. def __getitem__(self, index):
  113. """
  114. Args:
  115. index (int): Index
  116. Returns:
  117. tuple: Tuple (image, annotations (target)).
  118. """
  119. img_id = self.img_ids[index]
  120. img_info = self.img_infos[index]
  121. if self.has_annotations:
  122. ann = self._parse_img_ann(img_id, img_info)
  123. else:
  124. ann = dict(img_id=img_id, img_size=(img_info['width'], img_info['height']))
  125. path = img_info['file_name']
  126. img = Image.open(os.path.join(self.root, path)).convert('RGB')
  127. if self.transform is not None:
  128. img, ann = self.transform(img, ann)
  129. cls_targets, box_targets, num_positives = self.anchor_labeler.label_anchors(
  130. ann['bbox'], ann['cls'])
  131. ann.pop('bbox')
  132. ann.pop('cls')
  133. ann['num_positives'] = num_positives
  134. ann.update(cls_targets)
  135. ann.update(box_targets)
  136. return img, ann
  137. def __len__(self):
  138. return len(self.img_ids)