# transforms.py
  1. """ COCO transforms (quick and dirty)
  2. Hacked together by Ross Wightman
  3. """
  4. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import torch
  18. from PIL import Image
  19. import numpy as np
  20. import random
  21. import math
  22. IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
  23. IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
  24. IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
  25. IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
  26. class ImageToNumpy:
  27. def __call__(self, pil_img, annotations: dict):
  28. np_img = np.array(pil_img, dtype=np.uint8)
  29. if np_img.ndim < 3:
  30. np_img = np.expand_dims(np_img, axis=-1)
  31. np_img = np.moveaxis(np_img, 2, 0) # HWC to CHW
  32. return np_img, annotations
  33. class ImageToTensor:
  34. def __init__(self, dtype=torch.float32):
  35. self.dtype = dtype
  36. def __call__(self, pil_img, annotations: dict):
  37. np_img = np.array(pil_img, dtype=np.uint8)
  38. if np_img.ndim < 3:
  39. np_img = np.expand_dims(np_img, axis=-1)
  40. np_img = np.moveaxis(np_img, 2, 0) # HWC to CHW
  41. return torch.from_numpy(np_img).to(dtype=self.dtype), annotations
  42. class TargetToTensor:
  43. def __init__(self, dtype=torch.float32):
  44. self.dtype = dtype
  45. def __call__(self, pil_img, annotations: dict):
  46. annotations['bbox'] = torch.from_numpy(annotations['bbox']).to(dtype=self.dtype)
  47. annotations['cls'] = torch.from_numpy(annotations['cls']).to(dtype=torch.int64)
  48. return pil_img, annotations
  49. def _pil_interp(method):
  50. if method == 'bicubic':
  51. return Image.BICUBIC
  52. elif method == 'lanczos':
  53. return Image.LANCZOS
  54. elif method == 'hamming':
  55. return Image.HAMMING
  56. else:
  57. # default bilinear, do we want to allow nearest?
  58. return Image.BILINEAR
  59. def clip_boxes_(boxes, img_size):
  60. height, width = img_size
  61. clip_upper = np.array([height, width] * 2, dtype=boxes.dtype)
  62. np.clip(boxes, 0, clip_upper, out=boxes)
  63. def clip_boxes(boxes, img_size):
  64. clipped_boxes = boxes.copy()
  65. clip_boxes_(clipped_boxes, img_size)
  66. return clipped_boxes
  67. def _size_tuple(size):
  68. if isinstance(size, int):
  69. return size, size
  70. else:
  71. assert len(size) == 2
  72. return size
  73. class ResizePad:
  74. def __init__(self, target_size: int, interpolation: str = 'bilinear', fill_color: tuple = (0, 0, 0)):
  75. self.target_size = _size_tuple(target_size)
  76. self.interpolation = interpolation
  77. self.fill_color = fill_color
  78. def __call__(self, img, anno: dict):
  79. width, height = img.size
  80. img_scale_y = self.target_size[0] / height
  81. img_scale_x = self.target_size[1] / width
  82. img_scale = min(img_scale_y, img_scale_x)
  83. scaled_h = int(height * img_scale)
  84. scaled_w = int(width * img_scale)
  85. new_img = Image.new("RGB", (self.target_size[1], self.target_size[0]), color=self.fill_color)
  86. interp_method = _pil_interp(self.interpolation)
  87. img = img.resize((scaled_w, scaled_h), interp_method)
  88. new_img.paste(img)
  89. if 'bbox' in anno:
  90. # FIXME haven't tested this path since not currently using dataset annotations for train/eval
  91. bbox = anno['bbox']
  92. bbox[:, :4] *= img_scale
  93. clip_boxes_(bbox, (scaled_h, scaled_w))
  94. valid_indices = (bbox[:, :2] < bbox[:, 2:4]).all(axis=1)
  95. anno['bbox'] = bbox[valid_indices, :]
  96. anno['cls'] = anno['cls'][valid_indices]
  97. anno['img_scale'] = 1. / img_scale # back to original
  98. return new_img, anno
  99. class RandomResizePad:
  100. def __init__(self, target_size: int, scale: tuple = (0.1, 2.0), interpolation: str = 'bilinear',
  101. fill_color: tuple = (0, 0, 0)):
  102. self.target_size = _size_tuple(target_size)
  103. self.scale = scale
  104. self.interpolation = interpolation
  105. self.fill_color = fill_color
  106. def _get_params(self, img):
  107. # Select a random scale factor.
  108. scale_factor = random.uniform(*self.scale)
  109. scaled_target_height = scale_factor * self.target_size[0]
  110. scaled_target_width = scale_factor * self.target_size[1]
  111. # Recompute the accurate scale_factor using rounded scaled image size.
  112. width, height = img.size
  113. img_scale_y = scaled_target_height / height
  114. img_scale_x = scaled_target_width / width
  115. img_scale = min(img_scale_y, img_scale_x)
  116. # Select non-zero random offset (x, y) if scaled image is larger than target size
  117. scaled_h = int(height * img_scale)
  118. scaled_w = int(width * img_scale)
  119. offset_y = scaled_h - self.target_size[0]
  120. offset_x = scaled_w - self.target_size[1]
  121. offset_y = int(max(0.0, float(offset_y)) * random.uniform(0, 1))
  122. offset_x = int(max(0.0, float(offset_x)) * random.uniform(0, 1))
  123. return scaled_h, scaled_w, offset_y, offset_x, img_scale
  124. def __call__(self, img, anno: dict):
  125. scaled_h, scaled_w, offset_y, offset_x, img_scale = self._get_params(img)
  126. interp_method = _pil_interp(self.interpolation)
  127. img = img.resize((scaled_w, scaled_h), interp_method)
  128. right, lower = min(scaled_w, offset_x + self.target_size[1]), min(scaled_h, offset_y + self.target_size[0])
  129. img = img.crop((offset_x, offset_y, right, lower))
  130. new_img = Image.new("RGB", (self.target_size[1], self.target_size[0]), color=self.fill_color)
  131. new_img.paste(img)
  132. if 'bbox' in anno:
  133. # FIXME not fully tested
  134. bbox = anno['bbox'].copy() # FIXME copy for debugger inspection, back to inplace
  135. bbox[:, :4] *= img_scale
  136. box_offset = np.stack([offset_y, offset_x] * 2)
  137. bbox -= box_offset
  138. clip_boxes_(bbox, (scaled_h, scaled_w))
  139. valid_indices = (bbox[:, :2] < bbox[:, 2:4]).all(axis=1)
  140. anno['bbox'] = bbox[valid_indices, :]
  141. anno['cls'] = anno['cls'][valid_indices]
  142. anno['img_scale'] = 1. / img_scale # back to original
  143. return new_img, anno
  144. class RandomFlip:
  145. def __init__(self, horizontal=True, vertical=False, prob=0.5):
  146. self.horizontal = horizontal
  147. self.vertical = vertical
  148. self.prob = prob
  149. def _get_params(self):
  150. do_horizontal = random.random() < self.prob if self.horizontal else False
  151. do_vertical = random.random() < self.prob if self.vertical else False
  152. return do_horizontal, do_vertical
  153. def __call__(self, img, annotations: dict):
  154. do_horizontal, do_vertical = self._get_params()
  155. width, height = img.size
  156. def _fliph(bbox):
  157. x_max = width - bbox[:, 1]
  158. x_min = width - bbox[:, 3]
  159. bbox[:, 1] = x_min
  160. bbox[:, 3] = x_max
  161. def _flipv(bbox):
  162. y_max = height - bbox[:, 0]
  163. y_min = height - bbox[:, 2]
  164. bbox[:, 0] = y_min
  165. bbox[:, 2] = y_max
  166. if do_horizontal and do_vertical:
  167. img = img.transpose(Image.ROTATE_180)
  168. if 'bbox' in annotations:
  169. _fliph(annotations['bbox'])
  170. _flipv(annotations['bbox'])
  171. elif do_horizontal:
  172. img = img.transpose(Image.FLIP_LEFT_RIGHT)
  173. if 'bbox' in annotations:
  174. _fliph(annotations['bbox'])
  175. elif do_vertical:
  176. img = img.transpose(Image.FLIP_TOP_BOTTOM)
  177. if 'bbox' in annotations:
  178. _flipv(annotations['bbox'])
  179. return img, annotations
  180. def resolve_fill_color(fill_color, img_mean=IMAGENET_DEFAULT_MEAN):
  181. if isinstance(fill_color, tuple):
  182. assert len(fill_color) == 3
  183. fill_color = fill_color
  184. else:
  185. try:
  186. int_color = int(fill_color)
  187. fill_color = (int_color,) * 3
  188. except ValueError:
  189. assert fill_color == 'mean'
  190. fill_color = tuple([int(round(255 * x)) for x in img_mean])
  191. return fill_color
  192. class Compose:
  193. def __init__(self, transforms: list):
  194. self.transforms = transforms
  195. def __call__(self, img, annotations: dict):
  196. for t in self.transforms:
  197. img, annotations = t(img, annotations)
  198. return img, annotations
  199. def transforms_coco_eval(
  200. img_size=224,
  201. interpolation='bilinear',
  202. use_prefetcher=False,
  203. fill_color='mean',
  204. mean=IMAGENET_DEFAULT_MEAN,
  205. std=IMAGENET_DEFAULT_STD):
  206. fill_color = resolve_fill_color(fill_color, mean)
  207. image_tfl = [
  208. ResizePad(
  209. target_size=img_size, interpolation=interpolation, fill_color=fill_color),
  210. TargetToTensor(),
  211. ImageToNumpy(),
  212. ]
  213. assert use_prefetcher, "Only supporting prefetcher usage right now"
  214. image_tf = Compose(image_tfl)
  215. return image_tf
  216. def transforms_coco_train(
  217. img_size=224,
  218. interpolation='random',
  219. use_prefetcher=False,
  220. fill_color='mean',
  221. mean=IMAGENET_DEFAULT_MEAN,
  222. std=IMAGENET_DEFAULT_STD):
  223. fill_color = resolve_fill_color(fill_color, mean)
  224. image_tfl = [
  225. RandomFlip(horizontal=True, prob=0.5),
  226. RandomResizePad(
  227. target_size=img_size, interpolation=interpolation, fill_color=fill_color),
  228. TargetToTensor(),
  229. ImageToNumpy(),
  230. ]
  231. assert use_prefetcher, "Only supporting prefetcher usage right now"
  232. image_tf = Compose(image_tfl)
  233. return image_tf