data.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. # Copyright 2017-2018 The Apache Software Foundation
  2. #
  3. # Licensed to the Apache Software Foundation (ASF) under one
  4. # or more contributor license agreements. See the NOTICE file
  5. # distributed with this work for additional information
  6. # regarding copyright ownership. The ASF licenses this file
  7. # to you under the Apache License, Version 2.0 (the
  8. # "License"); you may not use this file except in compliance
  9. # with the License. You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing,
  14. # software distributed under the License is distributed on an
  15. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  16. # KIND, either express or implied. See the License for the
  17. # specific language governing permissions and limitations
  18. # under the License.
  19. #
  20. # -----------------------------------------------------------------------
  21. #
  22. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  23. #
  24. # Licensed under the Apache License, Version 2.0 (the "License");
  25. # you may not use this file except in compliance with the License.
  26. # You may obtain a copy of the License at
  27. #
  28. # http://www.apache.org/licenses/LICENSE-2.0
  29. #
  30. # Unless required by applicable law or agreed to in writing, software
  31. # distributed under the License is distributed on an "AS IS" BASIS,
  32. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  33. # See the License for the specific language governing permissions and
  34. # limitations under the License.
  35. import mxnet as mx
  36. import mxnet.ndarray as nd
  37. import random
  38. import argparse
  39. from mxnet.io import DataBatch, DataIter
  40. import numpy as np
  41. import horovod.mxnet as hvd
  42. import dali
  43. def add_data_args(parser):
  44. def float_list(x):
  45. return list(map(float, x.split(',')))
  46. def int_list(x):
  47. return list(map(int, x.split(',')))
  48. data = parser.add_argument_group('Data')
  49. data.add_argument('--data-train', type=str, help='the training data')
  50. data.add_argument('--data-train-idx', type=str, default='', help='the index of training data')
  51. data.add_argument('--data-val', type=str, help='the validation data')
  52. data.add_argument('--data-val-idx', type=str, default='', help='the index of validation data')
  53. data.add_argument('--data-pred', type=str, help='the image on which run inference (only for pred mode)')
  54. data.add_argument('--data-backend', choices=('dali-gpu', 'dali-cpu', 'mxnet', 'synthetic'), default='dali-gpu',
  55. help='set data loading & augmentation backend')
  56. data.add_argument('--image-shape', type=int_list, default=[3, 224, 224],
  57. help='the image shape feed into the network')
  58. data.add_argument('--rgb-mean', type=float_list, default=[123.68, 116.779, 103.939],
  59. help='a tuple of size 3 for the mean rgb')
  60. data.add_argument('--rgb-std', type=float_list, default=[58.393, 57.12, 57.375],
  61. help='a tuple of size 3 for the std rgb')
  62. data.add_argument('--input-layout', type=str, default='NCHW', choices=('NCHW', 'NHWC'),
  63. help='the layout of the input data')
  64. data.add_argument('--conv-layout', type=str, default='NCHW', choices=('NCHW', 'NHWC'),
  65. help='the layout of the data assumed by the conv operation')
  66. data.add_argument('--batchnorm-layout', type=str, default='NCHW', choices=('NCHW', 'NHWC'),
  67. help='the layout of the data assumed by the batchnorm operation')
  68. data.add_argument('--pooling-layout', type=str, default='NCHW', choices=('NCHW', 'NHWC'),
  69. help='the layout of the data assumed by the pooling operation')
  70. data.add_argument('--num-examples', type=int, default=1281167,
  71. help="the number of training examples (doesn't work with mxnet data backend)")
  72. data.add_argument('--data-val-resize', type=int, default=256,
  73. help='base length of shorter edge for validation dataset')
  74. return data
  75. def add_data_aug_args(parser):
  76. aug = parser.add_argument_group(
  77. 'MXNet data backend', 'entire group applies only to mxnet data backend')
  78. aug.add_argument('--data-mxnet-threads', type=int, default=40,
  79. help='number of threads for data decoding for mxnet data backend')
  80. aug.add_argument('--random-crop', type=int, default=0,
  81. help='if or not randomly crop the image')
  82. aug.add_argument('--random-mirror', type=int, default=1,
  83. help='if or not randomly flip horizontally')
  84. aug.add_argument('--max-random-h', type=int, default=0,
  85. help='max change of hue, whose range is [0, 180]')
  86. aug.add_argument('--max-random-s', type=int, default=0,
  87. help='max change of saturation, whose range is [0, 255]')
  88. aug.add_argument('--max-random-l', type=int, default=0,
  89. help='max change of intensity, whose range is [0, 255]')
  90. aug.add_argument('--min-random-aspect-ratio', type=float, default=0.75,
  91. help='min value of aspect ratio, whose value is either None or a positive value.')
  92. aug.add_argument('--max-random-aspect-ratio', type=float, default=1.33,
  93. help='max value of aspect ratio. If min_random_aspect_ratio is None, '
  94. 'the aspect ratio range is [1-max_random_aspect_ratio, '
  95. '1+max_random_aspect_ratio], otherwise it is '
  96. '[min_random_aspect_ratio, max_random_aspect_ratio].')
  97. aug.add_argument('--max-random-rotate-angle', type=int, default=0,
  98. help='max angle to rotate, whose range is [0, 360]')
  99. aug.add_argument('--max-random-shear-ratio', type=float, default=0,
  100. help='max ratio to shear, whose range is [0, 1]')
  101. aug.add_argument('--max-random-scale', type=float, default=1,
  102. help='max ratio to scale')
  103. aug.add_argument('--min-random-scale', type=float, default=1,
  104. help='min ratio to scale, should >= img_size/input_shape. '
  105. 'otherwise use --pad-size')
  106. aug.add_argument('--max-random-area', type=float, default=1,
  107. help='max area to crop in random resized crop, whose range is [0, 1]')
  108. aug.add_argument('--min-random-area', type=float, default=0.05,
  109. help='min area to crop in random resized crop, whose range is [0, 1]')
  110. aug.add_argument('--min-crop-size', type=int, default=-1,
  111. help='Crop both width and height into a random size in '
  112. '[min_crop_size, max_crop_size]')
  113. aug.add_argument('--max-crop-size', type=int, default=-1,
  114. help='Crop both width and height into a random size in '
  115. '[min_crop_size, max_crop_size]')
  116. aug.add_argument('--brightness', type=float, default=0,
  117. help='brightness jittering, whose range is [0, 1]')
  118. aug.add_argument('--contrast', type=float, default=0,
  119. help='contrast jittering, whose range is [0, 1]')
  120. aug.add_argument('--saturation', type=float, default=0,
  121. help='saturation jittering, whose range is [0, 1]')
  122. aug.add_argument('--pca-noise', type=float, default=0,
  123. help='pca noise, whose range is [0, 1]')
  124. aug.add_argument('--random-resized-crop', type=int, default=1,
  125. help='whether to use random resized crop')
  126. return aug
  127. def get_data_loader(args):
  128. if args.data_backend == 'dali-gpu':
  129. return (lambda *args, **kwargs: dali.get_rec_iter(*args, **kwargs, dali_cpu=False))
  130. if args.data_backend == 'dali-cpu':
  131. return (lambda *args, **kwargs: dali.get_rec_iter(*args, **kwargs, dali_cpu=True))
  132. if args.data_backend == 'synthetic':
  133. return get_synthetic_rec_iter
  134. if args.data_backend == 'mxnet':
  135. return get_rec_iter
  136. raise ValueError('Wrong data backend')
  137. class DataGPUSplit:
  138. def __init__(self, dataloader, ctx, dtype):
  139. self.dataloader = dataloader
  140. self.ctx = ctx
  141. self.dtype = dtype
  142. self.batch_size = dataloader.batch_size // len(ctx)
  143. self._num_gpus = len(ctx)
  144. def __iter__(self):
  145. return DataGPUSplit(iter(self.dataloader), self.ctx, self.dtype)
  146. def __next__(self):
  147. data = next(self.dataloader)
  148. ret = []
  149. for i in range(len(self.ctx)):
  150. start = i * len(data.data[0]) // len(self.ctx)
  151. end = (i + 1) * len(data.data[0]) // len(self.ctx)
  152. pad = max(0, min(data.pad - (len(self.ctx) - i - 1) * self.batch_size, self.batch_size))
  153. ret.append(mx.io.DataBatch(
  154. [data.data[0][start:end].as_in_context(self.ctx[i]).astype(self.dtype)],
  155. [data.label[0][start:end].as_in_context(self.ctx[i])],
  156. pad=pad))
  157. return ret
  158. def next(self):
  159. return next(self)
  160. def reset(self):
  161. self.dataloader.reset()
  162. def get_rec_iter(args, kv=None):
  163. gpus = args.gpus
  164. if 'horovod' in args.kv_store:
  165. rank = hvd.rank()
  166. nworker = hvd.size()
  167. gpus = [gpus[0]]
  168. batch_size = args.batch_size // hvd.size()
  169. else:
  170. rank = kv.rank if kv else 0
  171. nworker = kv.num_workers if kv else 1
  172. batch_size = args.batch_size
  173. if args.input_layout == 'NHWC':
  174. raise ValueError('ImageRecordIter cannot handle layout {}'.format(args.input_layout))
  175. train = DataGPUSplit(mx.io.ImageRecordIter(
  176. path_imgrec = args.data_train,
  177. path_imgidx = args.data_train_idx,
  178. label_width = 1,
  179. mean_r = args.rgb_mean[0],
  180. mean_g = args.rgb_mean[1],
  181. mean_b = args.rgb_mean[2],
  182. std_r = args.rgb_std[0],
  183. std_g = args.rgb_std[1],
  184. std_b = args.rgb_std[2],
  185. data_name = 'data',
  186. label_name = 'softmax_label',
  187. data_shape = args.image_shape,
  188. batch_size = batch_size,
  189. rand_crop = args.random_crop,
  190. max_random_scale = args.max_random_scale,
  191. random_resized_crop = args.random_resized_crop,
  192. min_random_scale = args.min_random_scale,
  193. max_aspect_ratio = args.max_random_aspect_ratio,
  194. min_aspect_ratio = args.min_random_aspect_ratio,
  195. max_random_area = args.max_random_area,
  196. min_random_area = args.min_random_area,
  197. min_crop_size = args.min_crop_size,
  198. max_crop_size = args.max_crop_size,
  199. brightness = args.brightness,
  200. contrast = args.contrast,
  201. saturation = args.saturation,
  202. pca_noise = args.pca_noise,
  203. random_h = args.max_random_h,
  204. random_s = args.max_random_s,
  205. random_l = args.max_random_l,
  206. max_rotate_angle = args.max_random_rotate_angle,
  207. max_shear_ratio = args.max_random_shear_ratio,
  208. rand_mirror = args.random_mirror,
  209. preprocess_threads = args.data_mxnet_threads,
  210. shuffle = True,
  211. num_parts = nworker,
  212. part_index = rank,
  213. seed = args.seed or '0',
  214. ), [mx.gpu(gpu) for gpu in gpus], args.dtype)
  215. if args.data_val is None:
  216. return (train, None)
  217. val = DataGPUSplit(mx.io.ImageRecordIter(
  218. path_imgrec = args.data_val,
  219. path_imgidx = args.data_val_idx,
  220. label_width = 1,
  221. mean_r = args.rgb_mean[0],
  222. mean_g = args.rgb_mean[1],
  223. mean_b = args.rgb_mean[2],
  224. std_r = args.rgb_std[0],
  225. std_g = args.rgb_std[1],
  226. std_b = args.rgb_std[2],
  227. data_name = 'data',
  228. label_name = 'softmax_label',
  229. batch_size = batch_size,
  230. round_batch = False,
  231. data_shape = args.image_shape,
  232. preprocess_threads = args.data_mxnet_threads,
  233. rand_crop = False,
  234. rand_mirror = False,
  235. num_parts = nworker,
  236. part_index = rank,
  237. resize = args.data_val_resize,
  238. ), [mx.gpu(gpu) for gpu in gpus], args.dtype)
  239. return (train, val)
  240. class SyntheticDataIter(DataIter):
  241. def __init__(self, num_classes, data_shape, max_iter, ctx, dtype):
  242. self.batch_size = data_shape[0]
  243. self.cur_iter = 0
  244. self.max_iter = max_iter
  245. self.dtype = dtype
  246. label = np.random.randint(0, num_classes, [self.batch_size,])
  247. data = np.random.uniform(-1, 1, data_shape)
  248. self.data = []
  249. self.label = []
  250. self._num_gpus = len(ctx)
  251. for dev in ctx:
  252. self.data.append(mx.nd.array(data, dtype=self.dtype, ctx=dev))
  253. self.label.append(mx.nd.array(label, dtype=self.dtype, ctx=dev))
  254. def __iter__(self):
  255. return self
  256. def next(self):
  257. self.cur_iter += 1
  258. if self.cur_iter <= self.max_iter:
  259. return [DataBatch(data=(data,), label=(label,), pad=0) for data, label in zip(self.data, self.label)]
  260. else:
  261. raise StopIteration
  262. def __next__(self):
  263. return self.next()
  264. def reset(self):
  265. self.cur_iter = 0
  266. def get_synthetic_rec_iter(args, kv=None):
  267. gpus = args.gpus
  268. if 'horovod' in args.kv_store:
  269. gpus = [gpus[0]]
  270. batch_size = args.batch_size // hvd.size()
  271. else:
  272. batch_size = args.batch_size
  273. if args.input_layout == 'NCHW':
  274. data_shape = (batch_size, *args.image_shape)
  275. elif args.input_layout == 'NHWC':
  276. data_shape = (batch_size, *args.image_shape[1:], args.image_shape[0])
  277. else:
  278. raise ValueError('Wrong input layout')
  279. train = SyntheticDataIter(args.num_classes, data_shape,
  280. args.num_examples // args.batch_size,
  281. [mx.gpu(gpu) for gpu in gpus], args.dtype)
  282. if args.data_val is None:
  283. return (train, None)
  284. val = SyntheticDataIter(args.num_classes, data_shape,
  285. args.num_examples // args.batch_size,
  286. [mx.gpu(gpu) for gpu in gpus], args.dtype)
  287. return (train, val)
  288. def load_image(args, path, ctx=mx.cpu()):
  289. image = mx.image.imread(path).astype('float32')
  290. image = mx.image.imresize(image, *args.image_shape[1:])
  291. image = (image - nd.array(args.rgb_mean)) / nd.array(args.rgb_std)
  292. image = image.as_in_context(ctx)
  293. if args.input_layout == 'NCHW':
  294. image = image.transpose((2, 0, 1))
  295. image = image.astype(args.dtype)
  296. if args.image_shape[0] == 4:
  297. dim = 0 if args.input_layout == 'NCHW' else 2
  298. image = nd.concat(image, nd.zeros((1, *image.shape[1:]), dtype=image.dtype, ctx=image.context), dim=dim)
  299. return image