# Copyright 2017-2018 The Apache Software Foundation
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -----------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

import mxnet as mx
from mxnet.gluon import nn
from mxnet.gluon.block import HybridBlock
  39. def add_model_args(parser):
  40. model = parser.add_argument_group('Model')
  41. model.add_argument('--arch', default='resnetv15',
  42. choices=['resnetv1', 'resnetv15',
  43. 'resnextv1', 'resnextv15',
  44. 'xception'],
  45. help='model architecture')
  46. model.add_argument('--num-layers', type=int, default=50,
  47. help='number of layers in the neural network, \
  48. required by some networks such as resnet')
  49. model.add_argument('--num-groups', type=int, default=32,
  50. help='number of groups for grouped convolutions, \
  51. required by some networks such as resnext')
  52. model.add_argument('--num-classes', type=int, default=1000,
  53. help='the number of classes')
  54. model.add_argument('--batchnorm-eps', type=float, default=1e-5,
  55. help='the amount added to the batchnorm variance to prevent output explosion.')
  56. model.add_argument('--batchnorm-mom', type=float, default=0.9,
  57. help='the leaky-integrator factor controling the batchnorm mean and variance.')
  58. model.add_argument('--fuse-bn-relu', type=int, default=0,
  59. help='have batchnorm kernel perform activation relu')
  60. model.add_argument('--fuse-bn-add-relu', type=int, default=0,
  61. help='have batchnorm kernel perform add followed by activation relu')
  62. return model
class Builder:
    """Factory for the layers used by the networks in this file.

    Centralizes dtype, per-operator layout (NCHW/NHWC), and batchnorm
    settings so the network definitions below stay layout-agnostic.  The
    builder tracks the layout produced by the most recently built block in
    ``last_layout`` and automatically inserts a ``Transpose`` block whenever
    the next operator wants a different layout.
    """

    def __init__(self, dtype, input_layout, conv_layout, bn_layout,
                 pooling_layout, bn_eps, bn_mom, fuse_bn_relu, fuse_bn_add_relu):
        self.dtype = dtype
        self.input_layout = input_layout
        self.conv_layout = conv_layout
        self.bn_layout = bn_layout
        self.pooling_layout = pooling_layout
        self.bn_eps = bn_eps
        self.bn_mom = bn_mom
        self.fuse_bn_relu = fuse_bn_relu
        self.fuse_bn_add_relu = fuse_bn_add_relu
        self.act_type = 'relu'
        # Zero-init gamma for the last BN of a residual branch ("last") so a
        # fresh block starts out close to identity; 'ones' everywhere else.
        self.bn_gamma_initializer = lambda last: 'zeros' if last else 'ones'
        self.linear_initializer = lambda groups=1: mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                                                                  magnitude=2 * (groups ** 0.5))
        # Layout of the data as emitted by the most recently built block;
        # starts at the network input layout and is updated by transpose().
        self.last_layout = self.input_layout

    def copy(self):
        # Shallow copy — used to snapshot last_layout before building a
        # branch so a parallel branch can start from the same layout state.
        return copy.copy(self)

    def batchnorm(self, last=False):
        """BatchNorm in bn_layout; *last* selects zero-gamma initialization."""
        gamma_initializer = self.bn_gamma_initializer(last)
        bn_axis = 3 if self.bn_layout == 'NHWC' else 1
        return self.sequence(
            self.transpose(self.bn_layout),
            nn.BatchNorm(axis=bn_axis, momentum=self.bn_mom, epsilon=self.bn_eps,
                         gamma_initializer=gamma_initializer,
                         running_variance_initializer=gamma_initializer)
        )

    def batchnorm_add_relu(self, last=False):
        """BN whose output is summed with a second input, then ReLU-activated.

        Uses the fused BatchNormAddRelu kernel when fuse_bn_add_relu is set,
        otherwise the non-fused fallback block.
        """
        gamma_initializer = self.bn_gamma_initializer(last)
        if self.fuse_bn_add_relu:
            bn_axis = 3 if self.bn_layout == 'NHWC' else 1
            return self.sequence(
                self.transpose(self.bn_layout),
                BatchNormAddRelu(axis=bn_axis, momentum=self.bn_mom,
                                 epsilon=self.bn_eps, act_type=self.act_type,
                                 gamma_initializer=gamma_initializer,
                                 running_variance_initializer=gamma_initializer)
            )
        return NonFusedBatchNormAddRelu(self, last=last)

    def batchnorm_relu(self, last=False):
        """BN followed by ReLU (a single fused kernel when fuse_bn_relu is set)."""
        gamma_initializer = self.bn_gamma_initializer(last)
        if self.fuse_bn_relu:
            bn_axis = 3 if self.bn_layout == 'NHWC' else 1
            return self.sequence(
                self.transpose(self.bn_layout),
                nn.BatchNorm(axis=bn_axis, momentum=self.bn_mom,
                             epsilon=self.bn_eps, act_type=self.act_type,
                             gamma_initializer=gamma_initializer,
                             running_variance_initializer=gamma_initializer)
            )
        return self.sequence(self.batchnorm(last=last), self.activation())

    def activation(self):
        return nn.Activation(self.act_type)

    def global_avg_pool(self):
        return self.sequence(
            self.transpose(self.pooling_layout),
            nn.GlobalAvgPool2D(layout=self.pooling_layout)
        )

    def max_pool(self, pool_size, strides=1, padding=True):
        # padding=True means "same"-style padding of pool_size // 2.
        padding = pool_size // 2 if padding is True else int(padding)
        return self.sequence(
            self.transpose(self.pooling_layout),
            nn.MaxPool2D(pool_size, strides=strides, padding=padding,
                         layout=self.pooling_layout)
        )

    def conv(self, channels, kernel_size, padding=True, strides=1, groups=1, in_channels=0):
        """Bias-free Conv2D in conv_layout; padding=True gives 'same' padding."""
        padding = kernel_size // 2 if padding is True else int(padding)
        initializer = self.linear_initializer(groups=groups)
        return self.sequence(
            self.transpose(self.conv_layout),
            nn.Conv2D(channels, kernel_size=kernel_size, strides=strides,
                      padding=padding, use_bias=False, groups=groups,
                      in_channels=in_channels, layout=self.conv_layout,
                      weight_initializer=initializer)
        )

    def separable_conv(self, channels, kernel_size, in_channels, padding=True, strides=1):
        # Depthwise conv (groups == in_channels) followed by a 1x1 pointwise conv.
        return self.sequence(
            self.conv(in_channels, kernel_size, padding=padding,
                      strides=strides, groups=in_channels, in_channels=in_channels),
            self.conv(channels, 1, in_channels=in_channels)
        )

    def dense(self, units, in_units=0):
        return nn.Dense(units, in_units=in_units,
                        weight_initializer=self.linear_initializer())

    def transpose(self, to_layout):
        """Return a Transpose block into *to_layout*, or None if already there.

        Side effect: records *to_layout* as the new last_layout.
        """
        if self.last_layout == to_layout:
            return None
        ret = Transpose(self.last_layout, to_layout)
        self.last_layout = to_layout
        return ret

    def sequence(self, *seq):
        """Chain the non-None blocks in a HybridSequential (unwrapped if single)."""
        seq = list(filter(lambda x: x is not None, seq))
        if len(seq) == 1:
            return seq[0]
        ret = nn.HybridSequential()
        ret.add(*seq)
        return ret
  161. class Transpose(HybridBlock):
  162. def __init__(self, from_layout, to_layout):
  163. super().__init__()
  164. supported_layouts = ['NCHW', 'NHWC']
  165. if from_layout not in supported_layouts:
  166. raise ValueError('Not prepared to handle layout: {}'.format(from_layout))
  167. if to_layout not in supported_layouts:
  168. raise ValueError('Not prepared to handle layout: {}'.format(to_layout))
  169. self.from_layout = from_layout
  170. self.to_layout = to_layout
  171. def hybrid_forward(self, F, x):
  172. # Insert transpose if from_layout and to_layout don't match
  173. if self.from_layout == 'NCHW' and self.to_layout == 'NHWC':
  174. return F.transpose(x, axes=(0, 2, 3, 1))
  175. elif self.from_layout == 'NHWC' and self.to_layout == 'NCHW':
  176. return F.transpose(x, axes=(0, 3, 1, 2))
  177. else:
  178. return x
  179. def __repr__(self):
  180. s = '{name}({content})'
  181. if self.from_layout == self.to_layout:
  182. content = 'passthrough ' + self.from_layout
  183. else:
  184. content = self.from_layout + ' -> ' + self.to_layout
  185. return s.format(name=self.__class__.__name__,
  186. content=content)
class LayoutWrapper(HybridBlock):
    """Run *op* in *op_layout* while presenting *io_layout* to callers.

    Every input is transposed from io_layout to op_layout, *op* is applied,
    and its result is transposed back to io_layout.
    """
    def __init__(self, op, io_layout, op_layout, **kwargs):
        super(LayoutWrapper, self).__init__(**kwargs)
        with self.name_scope():
            self.layout1 = Transpose(io_layout, op_layout)
            self.op = op
            self.layout2 = Transpose(op_layout, io_layout)

    def hybrid_forward(self, F, *x):
        # Transpose each positional input, run op, transpose the result back.
        return self.layout2(self.op(*(self.layout1(y) for y in x)))
class BatchNormAddRelu(nn.BatchNorm):
    """BatchNorm fused with a residual add and ReLU in a single kernel.

    Dispatches to the F.BatchNormAddRelu operator; only ReLU is supported
    as the activation.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Pop act_type so it is not forwarded into the fused op's kwargs;
        # the fused kernel only implements ReLU.
        if self._kwargs.pop('act_type') != 'relu':
            raise ValueError('BatchNormAddRelu can be used only with ReLU as activation')

    def hybrid_forward(self, F, x, y, gamma, beta, running_mean, running_var):
        # y is the residual addend folded into the fused BN + add + ReLU op.
        return F.BatchNormAddRelu(data=x, addend=y, gamma=gamma, beta=beta,
                                  moving_mean=running_mean, moving_var=running_var, name='fwd', **self._kwargs)
class NonFusedBatchNormAddRelu(HybridBlock):
    """Fallback for BatchNormAddRelu: separate BN, elementwise add and ReLU."""
    def __init__(self, builder, **kwargs):
        super().__init__()
        # kwargs (e.g. last=...) are forwarded to the builder's batchnorm.
        self.bn = builder.batchnorm(**kwargs)
        self.act = builder.activation()

    def hybrid_forward(self, F, x, y):
        return self.act(self.bn(x) + y)
# Blocks
class ResNetBasicBlock(HybridBlock):
    """Two 3x3 convolutions with a residual connection (ResNet-18/34 block)."""
    def __init__(self, builder, channels, stride, downsample=False, in_channels=0,
                 version='1', resnext_groups=None, **kwargs):
        super().__init__()
        # The basic block has no grouped-convolution (ResNeXt) variant.
        assert not resnext_groups
        self.transpose = builder.transpose(builder.conv_layout)
        # Snapshot the builder so the downsample branch starts from the same
        # layout state as the main body.
        builder_copy = builder.copy()
        body = [
            builder.conv(channels, 3, strides=stride, in_channels=in_channels),
            builder.batchnorm_relu(),
            builder.conv(channels, 3),
        ]
        self.body = builder.sequence(*body)
        # last=True: zero-gamma init so the residual branch starts near zero.
        self.bn_add_relu = builder.batchnorm_add_relu(last=True)
        builder = builder_copy
        if downsample:
            # 1x1 strided projection to match the main branch's shape.
            self.downsample = builder.sequence(
                builder.conv(channels, 1, strides=stride, in_channels=in_channels),
                builder.batchnorm()
            )
        else:
            self.downsample = None

    def hybrid_forward(self, F, x):
        if self.transpose is not None:
            x = self.transpose(x)
        residual = x
        x = self.body(x)
        if self.downsample:
            residual = self.downsample(residual)
        # Fused (or emulated) BN + residual add + ReLU.
        x = self.bn_add_relu(x, residual)
        return x
  243. class ResNetBottleNeck(HybridBlock):
  244. def __init__(self, builder, channels, stride, downsample=False, in_channels=0,
  245. version='1', resnext_groups=None):
  246. super().__init__()
  247. stride1 = stride if version == '1' else 1
  248. stride2 = 1 if version == '1' else stride
  249. mult = 2 if resnext_groups else 1
  250. groups = resnext_groups or 1
  251. self.transpose = builder.transpose(builder.conv_layout)
  252. builder_copy = builder.copy()
  253. body = [
  254. builder.conv(channels * mult // 4, 1, strides=stride1, in_channels=in_channels),
  255. builder.batchnorm_relu(),
  256. builder.conv(channels * mult // 4, 3, strides=stride2),
  257. builder.batchnorm_relu(),
  258. builder.conv(channels, 1)
  259. ]
  260. self.body = builder.sequence(*body)
  261. self.bn_add_relu = builder.batchnorm_add_relu(last=True)
  262. builder = builder_copy
  263. if downsample:
  264. self.downsample = builder.sequence(
  265. builder.conv(channels, 1, strides=stride, in_channels=in_channels),
  266. builder.batchnorm()
  267. )
  268. else:
  269. self.downsample = None
  270. def hybrid_forward(self, F, x):
  271. if self.transpose is not None:
  272. x = self.transpose(x)
  273. residual = x
  274. x = self.body(x)
  275. if self.downsample:
  276. residual = self.downsample(residual)
  277. x = self.bn_add_relu(x, residual)
  278. return x
class XceptionBlock(HybridBlock):
    """Residual Xception block built from a channel-definition list.

    Each positive entry in *definition* is the output width of a separable
    conv; a non-positive entry stands for a stride-2 max-pool and triggers a
    strided 1x1 conv projection on the shortcut.
    """
    def __init__(self, builder, definition, in_channels, relu_at_beginning=True):
        super().__init__()
        self.transpose = builder.transpose(builder.conv_layout)
        # Snapshot the builder so the shortcut branch starts from the same
        # layout state as the main body.
        builder_copy = builder.copy()
        body = []
        if relu_at_beginning:
            body.append(builder.activation())
        last_channels = in_channels
        # Pair each entry with its successor (padded with 0) so the final
        # real conv of the block gets the zero-gamma ("last") batchnorm.
        for channels1, channels2 in zip(definition, definition[1:] + [0]):
            if channels1 > 0:
                body.append(builder.separable_conv(channels1, 3, in_channels=last_channels))
                if channels2 > 0:
                    body.append(builder.batchnorm_relu())
                else:
                    body.append(builder.batchnorm(last=True))
                last_channels = channels1
            else:
                body.append(builder.max_pool(3, 2))
        self.body = builder.sequence(*body)
        builder = builder_copy
        if any(map(lambda x: x <= 0, definition)):
            # The body downsamples (max-pool present): match spatial size and
            # width with a strided 1x1 projection on the shortcut.
            self.shortcut = builder.sequence(
                builder.conv(last_channels, 1, strides=2, in_channels=in_channels),
                builder.batchnorm(),
            )
        else:
            # Identity shortcut (empty sequence passes the input through).
            self.shortcut = builder.sequence()

    def hybrid_forward(self, F, x):
        return self.shortcut(x) + self.body(x)
# Nets
class ResNet(HybridBlock):
    """ResNet / ResNeXt: conv stem, four residual stages, pool and classifier.

    *layers* gives the number of blocks per stage; *channels* has one extra
    leading entry for the stem width.
    """
    def __init__(self, builder, block, layers, channels, classes=1000,
                 version='1', resnext_groups=None):
        super().__init__()
        # channels[0] is the stem width, channels[1:] are the stage widths.
        assert len(layers) == len(channels) - 1
        self.version = version
        with self.name_scope():
            features = [
                builder.conv(channels[0], 7, strides=2),
                builder.batchnorm_relu(),
                builder.max_pool(3, 2),
            ]
            for i, num_layer in enumerate(layers):
                # The first stage keeps the stem's resolution; later stages halve it.
                stride = 1 if i == 0 else 2
                features.append(self.make_layer(builder, block, num_layer, channels[i+1],
                                                stride, in_channels=channels[i],
                                                resnext_groups=resnext_groups))
            features.append(builder.global_avg_pool())
            self.features = builder.sequence(*features)
            self.output = builder.dense(classes, in_units=channels[-1])

    def make_layer(self, builder, block, layers, channels, stride,
                   in_channels=0, resnext_groups=None):
        """Build one stage: a (possibly strided/projected) first block plus identity blocks."""
        layer = []
        # The first block downsamples iff the channel count changes.
        layer.append(block(builder, channels, stride, channels != in_channels,
                           in_channels=in_channels, version=self.version,
                           resnext_groups=resnext_groups))
        for _ in range(layers-1):
            layer.append(block(builder, channels, 1, False, in_channels=channels,
                               version=self.version, resnext_groups=resnext_groups))
        return builder.sequence(*layer)

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x
class Xception(HybridBlock):
    """Xception network.

    *definition* is a triple: plain entry-flow conv widths, the
    XceptionBlock channel definitions, and exit-flow separable-conv widths.
    The default reproduces the standard Xception architecture.
    """
    def __init__(self, builder,
                 definition=([32, 64],
                             [[128, 128, 0], [256, 256, 0], [728, 728, 0],
                              *([[728, 728, 728]] * 8), [728, 1024, 0]],
                             [1536, 2048]),
                 classes=1000):
        super().__init__()
        definition1, definition2, definition3 = definition
        with self.name_scope():
            features = []
            last_channels = 0
            # Entry flow: plain convs, the first one strided.
            for i, channels in enumerate(definition1):
                features += [
                    builder.conv(channels, 3, strides=(2 if i == 0 else 1), in_channels=last_channels),
                    builder.batchnorm_relu(),
                ]
                last_channels = channels
            # Residual separable-conv blocks; the first skips its leading ReLU.
            for i, block_definition in enumerate(definition2):
                features.append(XceptionBlock(builder, block_definition, in_channels=last_channels,
                                              relu_at_beginning=False if i == 0 else True))
                # Output width of a block = its last positive definition entry.
                last_channels = list(filter(lambda x: x > 0, block_definition))[-1]
            # Exit flow: separable convs, then global pooling.
            for i, channels in enumerate(definition3):
                features += [
                    builder.separable_conv(channels, 3, in_channels=last_channels),
                    builder.batchnorm_relu(),
                ]
                last_channels = channels
            features.append(builder.global_avg_pool())
            self.features = builder.sequence(*features)
            self.output = builder.dense(classes, in_units=last_channels)

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x
# depth -> (block class, blocks per stage, [stem width, stage widths...])
resnet_spec = {18: (ResNetBasicBlock, [2, 2, 2, 2], [64, 64, 128, 256, 512]),
               34: (ResNetBasicBlock, [3, 4, 6, 3], [64, 64, 128, 256, 512]),
               50: (ResNetBottleNeck, [3, 4, 6, 3], [64, 256, 512, 1024, 2048]),
               101: (ResNetBottleNeck, [3, 4, 23, 3], [64, 256, 512, 1024, 2048]),
               152: (ResNetBottleNeck, [3, 8, 36, 3], [64, 256, 512, 1024, 2048])}
  384. def create_resnet(builder, version, num_layers=50, resnext=False, classes=1000):
  385. assert num_layers in resnet_spec, \
  386. "Invalid number of layers: {}. Options are {}".format(
  387. num_layers, str(resnet_spec.keys()))
  388. block_class, layers, channels = resnet_spec[num_layers]
  389. assert not resnext or num_layers >= 50, \
  390. "Cannot create resnext with less then 50 layers"
  391. net = ResNet(builder, block_class, layers, channels, version=version,
  392. resnext_groups=args.num_groups if resnext else None)
  393. return net
  394. class fp16_model(mx.gluon.block.HybridBlock):
  395. def __init__(self, net, **kwargs):
  396. super(fp16_model, self).__init__(**kwargs)
  397. with self.name_scope():
  398. self._net = net
  399. def hybrid_forward(self, F, x):
  400. y = self._net(x)
  401. y = F.cast(y, dtype='float32')
  402. return y
def get_model(arch, num_classes, num_layers, image_shape, dtype, amp,
              input_layout, conv_layout, batchnorm_layout, pooling_layout,
              batchnorm_eps, batchnorm_mom, fuse_bn_relu, fuse_bn_add_relu, **kwargs):
    """Create, hybridize and (optionally) cast the requested network.

    *arch* selects the architecture (resnet/resnext variants or xception);
    the layout and batchnorm arguments are forwarded into the Builder.
    Extra CLI options arrive via **kwargs and are ignored here.
    Raises ValueError for an unknown *arch*.
    """
    builder = Builder(
        dtype = dtype,
        input_layout = input_layout,
        conv_layout = conv_layout,
        bn_layout = batchnorm_layout,
        pooling_layout = pooling_layout,
        bn_eps = batchnorm_eps,
        bn_mom = batchnorm_mom,
        fuse_bn_relu = fuse_bn_relu,
        fuse_bn_add_relu = fuse_bn_add_relu,
    )
    if arch.startswith('resnet') or arch.startswith('resnext'):
        # 'v15' suffix selects the ResNet-1.5 stride placement.
        version = '1' if arch in {'resnetv1', 'resnextv1'} else '1.5'
        # NOTE(review): the --num-groups CLI value is not forwarded here even
        # though create_resnet needs a group count for resnext — verify the
        # group-count plumbing when using a resnext arch.
        net = create_resnet(
            builder = builder,
            version = version,
            resnext = arch.startswith('resnext'),
            num_layers = num_layers,
            classes = num_classes,
        )
    elif arch == 'xception':
        net = Xception(builder, classes=num_classes)
    else:
        raise ValueError('Wrong model architecture')
    net.hybridize(static_shape=True, static_alloc=True)
    # With AMP enabled, AMP manages precision itself; otherwise cast the net
    # explicitly and wrap fp16 nets so the final output is fp32.
    if not amp:
        net.cast(dtype)
        if dtype == 'float16':
            net = fp16_model(net)
    return net