data.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # Copyright 2017-2018 The Apache Software Foundation
  2. #
  3. # Licensed to the Apache Software Foundation (ASF) under one
  4. # or more contributor license agreements. See the NOTICE file
  5. # distributed with this work for additional information
  6. # regarding copyright ownership. The ASF licenses this file
  7. # to you under the Apache License, Version 2.0 (the
  8. # "License"); you may not use this file except in compliance
  9. # with the License. You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing,
  14. # software distributed under the License is distributed on an
  15. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  16. # KIND, either express or implied. See the License for the
  17. # specific language governing permissions and limitations
  18. # under the License.
  19. #
  20. # -----------------------------------------------------------------------
  21. #
  22. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  23. #
  24. # Licensed under the Apache License, Version 2.0 (the "License");
  25. # you may not use this file except in compliance with the License.
  26. # You may obtain a copy of the License at
  27. #
  28. # http://www.apache.org/licenses/LICENSE-2.0
  29. #
  30. # Unless required by applicable law or agreed to in writing, software
  31. # distributed under the License is distributed on an "AS IS" BASIS,
  32. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  33. # See the License for the specific language governing permissions and
  34. # limitations under the License.
  35. import mxnet as mx
  36. import random
  37. import argparse
  38. from mxnet.io import DataBatch, DataIter
  39. import numpy as np
  40. def add_data_args(parser):
  41. data = parser.add_argument_group('Data', 'the input images')
  42. data.add_argument('--data-train', type=str, help='the training data')
  43. data.add_argument('--data-train-idx', type=str, default='', help='the index of training data')
  44. data.add_argument('--data-val', type=str, help='the validation data')
  45. data.add_argument('--data-val-idx', type=str, default='', help='the index of validation data')
  46. data.add_argument('--rgb-mean', type=str, default='123.68,116.779,103.939',
  47. help='a tuple of size 3 for the mean rgb')
  48. data.add_argument('--rgb-std', type=str, default='1,1,1',
  49. help='a tuple of size 3 for the std rgb')
  50. data.add_argument('--pad-size', type=int, default=0,
  51. help='padding the input image')
  52. data.add_argument('--fill-value', type=int, default=127,
  53. help='Set the padding pixels value to fill_value')
  54. data.add_argument('--image-shape', type=str,
  55. help='the image shape feed into the network, e.g. (3,224,224)')
  56. data.add_argument('--num-classes', type=int, help='the number of classes')
  57. data.add_argument('--num-examples', type=int, help='the number of training examples')
  58. data.add_argument('--data-nthreads', type=int, default=4,
  59. help='number of threads for data decoding')
  60. data.add_argument('--benchmark-iters', type=int, default=None,
  61. help='run only benchmark-iters iterations from each epoch')
  62. data.add_argument('--input-layout', type=str, default='NCHW',
  63. help='the layout of the input data (e.g. NCHW)')
  64. data.add_argument('--conv-layout', type=str, default='NCHW',
  65. help='the layout of the data assumed by the conv operation (e.g. NCHW)')
  66. data.add_argument('--conv-algo', type=int, default=-1,
  67. help='set the convolution algos (fwd, dgrad, wgrad)')
  68. data.add_argument('--batchnorm-layout', type=str, default='NCHW',
  69. help='the layout of the data assumed by the batchnorm operation (e.g. NCHW)')
  70. data.add_argument('--batchnorm-eps', type=float, default=2e-5,
  71. help='the amount added to the batchnorm variance to prevent output explosion.')
  72. data.add_argument('--batchnorm-mom', type=float, default=0.9,
  73. help='the leaky-integrator factor controling the batchnorm mean and variance.')
  74. data.add_argument('--pooling-layout', type=str, default='NCHW',
  75. help='the layout of the data assumed by the pooling operation (e.g. NCHW)')
  76. data.add_argument('--verbose', type=int, default=0,
  77. help='turn on reporting of chosen algos for convolution, etc.')
  78. data.add_argument('--seed', type=int, default=None,
  79. help='set the seed for python, nd and mxnet rngs')
  80. data.add_argument('--custom-bn-off', type=int, default=0,
  81. help='disable use of custom batchnorm kernel')
  82. data.add_argument('--fuse-bn-relu', type=int, default=0,
  83. help='have batchnorm kernel perform activation relu')
  84. data.add_argument('--fuse-bn-add-relu', type=int, default=0,
  85. help='have batchnorm kernel perform add followed by activation relu')
  86. data.add_argument('--force-tensor-core', type=int, default=0,
  87. help='require conv algos to be tensor core')
  88. return data
  89. # Action to translate --set-resnet-aug flag to its component settings.
  90. class SetResnetAugAction(argparse.Action):
  91. def __init__(self, nargs=0, **kwargs):
  92. if nargs != 0:
  93. raise ValueError('nargs for SetResnetAug must be 0.')
  94. super(SetResnetAugAction, self).__init__(nargs=nargs, **kwargs)
  95. def __call__(self, parser, namespace, values, option_string=None):
  96. # standard data augmentation setting for resnet training
  97. setattr(namespace, 'random_crop', 1)
  98. setattr(namespace, 'random_resized_crop', 1)
  99. setattr(namespace, 'random_mirror', 1)
  100. setattr(namespace, 'min_random_area', 0.08)
  101. setattr(namespace, 'max_random_aspect_ratio', 4./3.)
  102. setattr(namespace, 'min_random_aspect_ratio', 3./4.)
  103. setattr(namespace, 'brightness', 0.4)
  104. setattr(namespace, 'contrast', 0.4)
  105. setattr(namespace, 'saturation', 0.4)
  106. setattr(namespace, 'pca_noise', 0.1)
  107. # record that this --set-resnet-aug 'macro arg' has been invoked
  108. setattr(namespace, self.dest, 1)
  109. # Similar to the above, but suitable for calling within a training script to set the defaults.
  110. def set_resnet_aug(aug):
  111. # standard data augmentation setting for resnet training
  112. aug.set_defaults(random_crop=0, random_resized_crop=1)
  113. aug.set_defaults(random_mirror=1)
  114. aug.set_defaults(min_random_area=0.08)
  115. aug.set_defaults(max_random_aspect_ratio=4./3., min_random_aspect_ratio=3./4.)
  116. aug.set_defaults(brightness=0.4, contrast=0.4, saturation=0.4, pca_noise=0.1)
  117. # Action to translate --set-data-aug-level <N> arg to its component settings.
  118. class SetDataAugLevelAction(argparse.Action):
  119. def __init__(self, option_strings, dest, nargs=None, **kwargs):
  120. if nargs is not None:
  121. raise ValueError("nargs not allowed")
  122. super(SetDataAugLevelAction, self).__init__(option_strings, dest, **kwargs)
  123. def __call__(self, parser, namespace, values, option_string=None):
  124. level = values
  125. # record that this --set-data-aug-level <N> 'macro arg' has been invoked
  126. setattr(namespace, self.dest, level)
  127. if level >= 1:
  128. setattr(namespace, 'random_crop', 1)
  129. setattr(namespace, 'random_mirror', 1)
  130. if level >= 2:
  131. setattr(namespace, 'max_random_h', 36)
  132. setattr(namespace, 'max_random_s', 50)
  133. setattr(namespace, 'max_random_l', 50)
  134. if level >= 3:
  135. setattr(namespace, 'max_random_rotate_angle', 10)
  136. setattr(namespace, 'max_random_shear_ratio', 0.1)
  137. setattr(namespace, 'max_random_aspect_ratio', 0.25)
  138. # Similar to the above, but suitable for calling within a training script to set the defaults.
  139. def set_data_aug_level(aug, level):
  140. if level >= 1:
  141. aug.set_defaults(random_crop=1, random_mirror=1)
  142. if level >= 2:
  143. aug.set_defaults(max_random_h=36, max_random_s=50, max_random_l=50)
  144. if level >= 3:
  145. aug.set_defaults(max_random_rotate_angle=10, max_random_shear_ratio=0.1, max_random_aspect_ratio=0.25)
  146. def add_data_aug_args(parser):
  147. aug = parser.add_argument_group(
  148. 'Image augmentations', 'implemented in src/io/image_aug_default.cc')
  149. aug.add_argument('--random-crop', type=int, default=0,
  150. help='if or not randomly crop the image')
  151. aug.add_argument('--random-mirror', type=int, default=0,
  152. help='if or not randomly flip horizontally')
  153. aug.add_argument('--max-random-h', type=int, default=0,
  154. help='max change of hue, whose range is [0, 180]')
  155. aug.add_argument('--max-random-s', type=int, default=0,
  156. help='max change of saturation, whose range is [0, 255]')
  157. aug.add_argument('--max-random-l', type=int, default=0,
  158. help='max change of intensity, whose range is [0, 255]')
  159. aug.add_argument('--min-random-aspect-ratio', type=float, default=None,
  160. help='min value of aspect ratio, whose value is either None or a positive value.')
  161. aug.add_argument('--max-random-aspect-ratio', type=float, default=0,
  162. help='max value of aspect ratio. If min_random_aspect_ratio is None, '
  163. 'the aspect ratio range is [1-max_random_aspect_ratio, '
  164. '1+max_random_aspect_ratio], otherwise it is '
  165. '[min_random_aspect_ratio, max_random_aspect_ratio].')
  166. aug.add_argument('--max-random-rotate-angle', type=int, default=0,
  167. help='max angle to rotate, whose range is [0, 360]')
  168. aug.add_argument('--max-random-shear-ratio', type=float, default=0,
  169. help='max ratio to shear, whose range is [0, 1]')
  170. aug.add_argument('--max-random-scale', type=float, default=1,
  171. help='max ratio to scale')
  172. aug.add_argument('--min-random-scale', type=float, default=1,
  173. help='min ratio to scale, should >= img_size/input_shape. '
  174. 'otherwise use --pad-size')
  175. aug.add_argument('--max-random-area', type=float, default=1,
  176. help='max area to crop in random resized crop, whose range is [0, 1]')
  177. aug.add_argument('--min-random-area', type=float, default=1,
  178. help='min area to crop in random resized crop, whose range is [0, 1]')
  179. aug.add_argument('--min-crop-size', type=int, default=-1,
  180. help='Crop both width and height into a random size in '
  181. '[min_crop_size, max_crop_size]')
  182. aug.add_argument('--max-crop-size', type=int, default=-1,
  183. help='Crop both width and height into a random size in '
  184. '[min_crop_size, max_crop_size]')
  185. aug.add_argument('--brightness', type=float, default=0,
  186. help='brightness jittering, whose range is [0, 1]')
  187. aug.add_argument('--contrast', type=float, default=0,
  188. help='contrast jittering, whose range is [0, 1]')
  189. aug.add_argument('--saturation', type=float, default=0,
  190. help='saturation jittering, whose range is [0, 1]')
  191. aug.add_argument('--pca-noise', type=float, default=0,
  192. help='pca noise, whose range is [0, 1]')
  193. aug.add_argument('--random-resized-crop', type=int, default=0,
  194. help='whether to use random resized crop')
  195. aug.add_argument('--set-resnet-aug', action=SetResnetAugAction,
  196. help='whether to employ standard resnet augmentations (see data.py)')
  197. aug.add_argument('--set-data-aug-level', type=int, default=None, action=SetDataAugLevelAction,
  198. help='set multiple data augmentations based on a `level` (see data.py)')
  199. return aug
  200. def get_rec_iter(args, kv=None):
  201. image_shape = tuple([int(l) for l in args.image_shape.split(',')])
  202. if args.input_layout == 'NHWC':
  203. image_shape = image_shape[1:] + (image_shape[0],)
  204. if kv:
  205. (rank, nworker) = (kv.rank, kv.num_workers)
  206. else:
  207. (rank, nworker) = (0, 1)
  208. rgb_mean = [float(i) for i in args.rgb_mean.split(',')]
  209. rgb_std = [float(i) for i in args.rgb_std.split(',')]
  210. if args.input_layout == 'NHWC':
  211. raise ValueError('ImageRecordIter cannot handle layout {}'.format(args.input_layout))
  212. train = mx.io.ImageRecordIter(
  213. path_imgrec = args.data_train,
  214. path_imgidx = args.data_train_idx,
  215. label_width = 1,
  216. mean_r = rgb_mean[0],
  217. mean_g = rgb_mean[1],
  218. mean_b = rgb_mean[2],
  219. std_r = rgb_std[0],
  220. std_g = rgb_std[1],
  221. std_b = rgb_std[2],
  222. data_name = 'data',
  223. label_name = 'softmax_label',
  224. data_shape = image_shape,
  225. batch_size = args.batch_size,
  226. rand_crop = args.random_crop,
  227. max_random_scale = args.max_random_scale,
  228. pad = args.pad_size,
  229. fill_value = args.fill_value,
  230. random_resized_crop = args.random_resized_crop,
  231. min_random_scale = args.min_random_scale,
  232. max_aspect_ratio = args.max_random_aspect_ratio,
  233. min_aspect_ratio = args.min_random_aspect_ratio,
  234. max_random_area = args.max_random_area,
  235. min_random_area = args.min_random_area,
  236. min_crop_size = args.min_crop_size,
  237. max_crop_size = args.max_crop_size,
  238. brightness = args.brightness,
  239. contrast = args.contrast,
  240. saturation = args.saturation,
  241. pca_noise = args.pca_noise,
  242. random_h = args.max_random_h,
  243. random_s = args.max_random_s,
  244. random_l = args.max_random_l,
  245. max_rotate_angle = args.max_random_rotate_angle,
  246. max_shear_ratio = args.max_random_shear_ratio,
  247. rand_mirror = args.random_mirror,
  248. preprocess_threads = args.data_nthreads,
  249. shuffle = True,
  250. num_parts = nworker,
  251. part_index = rank)
  252. if args.data_val is None:
  253. return (train, None)
  254. val = mx.io.ImageRecordIter(
  255. path_imgrec = args.data_val,
  256. path_imgidx = args.data_val_idx,
  257. label_width = 1,
  258. mean_r = rgb_mean[0],
  259. mean_g = rgb_mean[1],
  260. mean_b = rgb_mean[2],
  261. std_r = rgb_std[0],
  262. std_g = rgb_std[1],
  263. std_b = rgb_std[2],
  264. data_name = 'data',
  265. label_name = 'softmax_label',
  266. batch_size = args.batch_size,
  267. round_batch = False,
  268. data_shape = image_shape,
  269. preprocess_threads = args.data_nthreads,
  270. rand_crop = False,
  271. rand_mirror = False,
  272. num_parts = nworker,
  273. part_index = rank)
  274. return (train, val)