arg_parser.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
  6. #
  7. # Licensed under the Apache License, Version 2.0 (the "License");
  8. # you may not use this file except in compliance with the License.
  9. # You may obtain a copy of the License at
  10. #
  11. # http://www.apache.org/licenses/LICENSE-2.0
  12. #
  13. # Unless required by applicable law or agreed to in writing, software
  14. # distributed under the License is distributed on an "AS IS" BASIS,
  15. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. # See the License for the specific language governing permissions and
  17. # limitations under the License.
  18. import os
  19. import sys
  20. def populate(parser):
  21. choices = ["pretrain", "finetune"]
  22. parser.add_argument("mode", help="Training mode", choices=choices)
  23. mode = parser.parse_args([a for a in sys.argv[1:] if a in choices]).mode
  24. if mode == "pretrain":
  25. populate_pretraining(parser)
  26. else:
  27. populate_finetuning(parser)
  28. populate_common(parser)
  29. return parser
  30. def populate_infer(parser):
  31. populate_finetuning(parser)
  32. populate_common(parser)
  33. _populate_infer(parser)
  34. return parser
  35. def populate_common(parser):
  36. train = parser.add_argument_group("training setup")
  37. train.add_argument("--epochs_this_job", default=0, type=int,
  38. help="Run for a number of epochs and exit")
  39. train.add_argument("--cudnn_benchmark", action="store_true",
  40. help="Enable cudnn benchmark")
  41. train.add_argument("--local_rank", "--local-rank", default=os.getenv("LOCAL_RANK", 0),
  42. type=int, help="GPU id used for distributed training")
  43. optim = parser.add_argument_group("optimization setup")
  44. optim.add_argument("--optimizer", default="adam", type=str,
  45. help="Optimization algorithm")
  46. optim.add_argument("--ema", type=float, default=0.0,
  47. help="Discount factor for EMA of model weights")
  48. io = parser.add_argument_group("feature and checkpointing setup")
  49. io.add_argument("--log_frequency", default=1, type=int,
  50. help="Number of steps between printing training stats")
  51. io.add_argument("--output_dir", type=str, required=True,
  52. help="Directory for logs and checkpoints")
  53. io.add_argument("--log_file", type=str, default=None,
  54. help="Path to save the training logfile.")
  55. io.add_argument("--benchmark_epochs_num", type=int, default=3,
  56. help="Number of last epochs to calculate throughput stats")
  57. ckpt = parser.add_argument_group("checkpoint")
  58. ckpt.add_argument("--no_save", action="store_true",
  59. help="Don't save models or checkpoints")
  60. ckpt.add_argument("--resume", action="store_true",
  61. help="Try to resume from last saved checkpoint")
  62. ckpt.add_argument("--ckpt", default=None, type=str,
  63. help="Path to a checkpoint for resuming training")
  64. ckpt.add_argument("--save_frequency", default=10, type=int,
  65. help="Checkpoint saving frequency in epochs")
  66. ckpt.add_argument("--keep_milestones", default=[100, 200, 300, 400],
  67. type=int, nargs="+",
  68. help="Milestone checkpoints to keep from removing")
  69. # io.add_argument("--save_best_from", default=380, type=int,
  70. # help="Epoch on which to begin tracking best checkpoint (dev WER)")
  71. common = parser.add_argument_group("common")
  72. common.add_argument("--seed", type=int, default=1,
  73. help="Pseudo random number generator seed")
  74. common.add_argument("--cpu", action="store_true",
  75. help="Use CPU instead of CUDA")
  76. common.add_argument("--amp", action="store_true",
  77. help="Use automatic mixed precision")
  78. common.add_argument("--fp16", action="store_true",
  79. help="If fp16 is being used")
  80. common.add_argument("--bf16", action="store_true",
  81. help="Train in bfloat16 precision")
  82. common.add_argument("--min_loss_scale", type=float, default=0.0001,
  83. help="Minimum FP16/AMP loss scale, after which "
  84. "training is stopped")
  85. common.add_argument("--fp16_init_scale", type=int, default=128,
  86. help="Default FP16 loss scale")
  87. common.add_argument("--fp32_transformer_layernorm", action="store_true",
  88. help="Calculate MHA LayerNorms in full precision")
  89. common.add_argument("--fp32_mha_softmax", action="store_true",
  90. help="Calculate multi-head attention to FP32")
  91. common.add_argument("--fp32_cosine_sim", action="store_true",
  92. help="Calculate cosine similarity in FP32")
  93. common.add_argument("--fp32_pos_conv", action="store_true",
  94. help="Calculate positional conv in FP32")
  95. common.add_argument("--fp32_conv_norms", action="store_true",
  96. help="Calculate normalization in conv layers in FP32")
  97. common.add_argument("--mha", type=str, default="fairseq",
  98. choices=["fairseq", "pyt"], help="MHA implementation")
  99. common.add_argument("--num_concat_batches", type=int, default=1)
  100. dataset = parser.add_argument_group("dataset")
  101. dataset.add_argument("--num_workers", type=int, default=6,
  102. help="How many subprocesses to use for data loading")
  103. dataset.add_argument("--skip_invalid_size_inputs_valid_test",
  104. action="store_true",
  105. help="Ignore too long or too short lines in valid and"
  106. " test set")
  107. dataset.add_argument("--max_tokens", type=int, default=1400000,
  108. help="Maximum number of tokens in a batch")
  109. dataset.add_argument("--max_tokens_valid", type=int, default=1400000,
  110. help="Maximum number of tokens in a validation batch "
  111. "(defaults to --max-tokens)")
  112. dataset.add_argument("--required_batch_size_multiple", type=int, default=8,
  113. help="Batch size will be a multiplier of this value")
  114. dataset.add_argument("--required_seq_len_multiple", type=int, default=2,
  115. help="Pad the input to encoder such that the sequence"
  116. " length is divisible by multiple")
  117. dataset.add_argument("--train_subset", type=str, default="train",
  118. help="Data subset to use for training (e.g. train, "
  119. "valid, test)")
  120. dataset.add_argument("--valid_subset", type=str, default="valid",
  121. help="Comma separated list of data subsets to use for"
  122. " validation (e.g. train, valid, test)")
  123. dataset.add_argument("--batch_size", type=int, default=None,
  124. help="Number of examples in a batch")
  125. dataset.add_argument("--batch_size_valid", type=int, default=None,
  126. help="Batch size of the validation batch (defaults "
  127. "to --batch-size)")
  128. task = parser.add_argument_group("task")
  129. task.add_argument("--data", type=str,
  130. default="/workspace/fairseq/librispeech",
  131. help="Path to data directory")
  132. task.add_argument("--sample_rate", type=int, default=16000,
  133. help="Target sample rate. audio files will be up/down "
  134. "sampled to this rate")
  135. task.add_argument("--enable_padding", action="store_true",
  136. help="Pad shorter samples instead of cropping")
  137. task.add_argument("--min_sample_size", type=int, default=None,
  138. help="Min sample size to crop to for batching")
  139. task.add_argument("--max_sample_size", type=int, default=None,
  140. help="Max sample size to crop to for batching")
  141. task.add_argument("--num_batch_buckets", type=int, default=0,
  142. help="If >0, then bucket source and target lengths into "
  143. "N buckets and pad accordingly; this is useful on "
  144. "TPUs to minimize the number of compilations")
  145. opt = parser.add_argument_group("optimization & optimizer")
  146. opt.add_argument("--max_update", type=int, default=400000,
  147. help="Force stop training at specified update")
  148. opt.add_argument("--update_freq", type=int, nargs="+", default=[64],
  149. help="Accumulate grads and update params every N batches")
  150. opt.add_argument("--lr", type=float, nargs="+", default=[0.0005],
  151. help="Max learning rate, must be more than cfg.min_lr")
  152. opt.add_argument("--adam_betas", type=float, nargs="+", default=[0.9, 0.98],
  153. help="Betas for Adam optimizer")
  154. opt.add_argument("--adam_eps", type=float, default=1e-06,
  155. help="Epsilon for Adam optimizer")
  156. opt.add_argument("--weight_decay", type=float, default=0.01,
  157. help="Weight decay")
  158. opt.add_argument("--clip_norm", type=float, default=0.0,
  159. help="Clip threshold of gradients")
  160. sched = parser.add_argument_group("lr_scheduler")
  161. sched.add_argument("--lr_policy", type=str, default="poly",
  162. choices=["poly", "exp"], help="LR decay policy")
  163. sched.add_argument("--warmup_updates", type=int, default=32000,
  164. help="Warmup the learning rate linearly for the first "
  165. "N updates")
  166. sched.add_argument("--hold_updates", type=int, default=0,
  167. help="The number of updates with const learning rate")
  168. sched.add_argument("--initial_lr_scale", type=float, default=0.0,
  169. help="Initial learning rate scale")
  170. sched.add_argument("--final_lr_scale", type=float, default=0.0,
  171. help="Final learning rate scale")
  172. sched.add_argument("--lr_poly_power", type=float, default=1.0,
  173. help="Poly lr policy policy power")
  174. sched.add_argument("--lr_exp_decay", type=float, default=None,
  175. help="Exp lr policy decay factor")
  176. drop = parser.add_argument_group("dropout")
  177. drop.add_argument("--dropout", type=float, default=0.1,
  178. help="Dropout probability for the transformer")
  179. drop.add_argument("--attention_dropout", type=float, default=0.0,
  180. help="Dropout probability for attention weights")
  181. drop.add_argument("--activation_dropout", type=float, default=0.0,
  182. help="Dropout probability after activation in FFN")
  183. drop.add_argument("--dropout_input", type=float, default=0.1,
  184. help="Dropout to apply to the input (after feat extr)")
  185. drop.add_argument("--dropout_features", type=float, default=0.1,
  186. help="Dropout to apply to the features (after feat extr)")
  187. mask = parser.add_argument_group("input masking")
  188. mask.add_argument("--apply_mask", action="store_true",
  189. help="Apply masking during fine-tuning")
  190. mask.add_argument("--mask_length", type=int, default=10,
  191. help="Repeat the mask indices multiple times")
  192. mask.add_argument("--mask_prob", type=float, default=0.5,
  193. help="Probability of replacing a token with mask "
  194. "(normalized by length)")
  195. mask.add_argument("--require_same_masks", type=bool, default=True,
  196. help="Whether to number of masked timesteps must be the"
  197. " same across all examples in a batch")
  198. mask.add_argument("--mask_selection", default="static",
  199. choices=["static", "uniform", "normal", "poisson"],
  200. help="How to choose masks")
  201. mask.add_argument("--mask_other", type=float, default=0,
  202. help="Secondary mask argument (used for more complex "
  203. "distributions), see help in compute_mask_indices")
  204. mask.add_argument("--no_mask_overlap", type=bool, default=False,
  205. help="Whether to allow masks to overlap")
  206. mask.add_argument("--mask_min_space", type=int, default=1,
  207. help="Min space between spans (if no overlap is enabled)")
  208. mask.add_argument("--mask_channel_length", type=int, default=10,
  209. help="Length of the mask for features (channels)")
  210. mask.add_argument("--mask_channel_prob", type=float, default=0.0,
  211. help="Probability of replacing a feature with 0")
  212. mask.add_argument("--mask_channel_before", type=bool, default=False,
  213. help="Apply channel-masking before frequency-masking")
  214. mask.add_argument("--mask_channel_selection", default="static",
  215. choices=["static", "uniform", "normal", "poisson"],
  216. help="How to choose mask length for channel masking")
  217. mask.add_argument("--mask_channel_other", type=float, default=0,
  218. help="Secondary mask argument (used for more complex "
  219. "distributions), see help in compute_mask_indicesh")
  220. mask.add_argument("--no_mask_channel_overlap", type=bool, default=False,
  221. help="Whether to allow channel masks to overlap")
  222. mask.add_argument("--mask_channel_min_space", type=int, default=1,
  223. help="Min space between spans (if no overlap is enabled)")
  224. parser.add_argument("--feature_grad_mult", type=float, default=0.1,
  225. help="Reset feature grad mult in wav2vec 2.0 to this")
  226. # NOTE In Fairseq this is called `--layerdrop` in fine-tuning yamls
  227. parser.add_argument("--encoder_layerdrop", type=float, default=0.05,
  228. help="Probability of dropping a layer in wav2vec 2.0")
  229. mask.add_argument("--mask_dropout", type=float, default=0.0,
  230. help="Percent of masks to unmask for each sample")
  231. def populate_finetuning(parser):
  232. """Args for fine-tuning, absent from pre-trained ckpts."""
  233. ft = parser.add_argument_group("supervised fine-tuning")
  234. ft.add_argument("--final_dropout", type=float, default=0.0,
  235. help="Dropout after transformer and before final proj")
  236. ft.add_argument("--w2v_path", type=str, default=None,
  237. help="Path to wav2vec 2.0 model")
  238. ft.add_argument("--blank_weight", type=float, default=0)
  239. ft.add_argument("--blank_mode", type=str, default="add")
  240. ft.add_argument("--labels", type=str, default="ltr",
  241. help="Extension of the label file to load for fine-tuning")
  242. ft.add_argument("--freeze_finetune_updates", type=int, default=0,
  243. help="Don't finetune wav2vec for this many updates")
  244. def populate_pretraining(parser):
  245. """During fine-tuning these parameters will be loaded from a ckpt."""
  246. model = parser.add_argument_group("model")
  247. model.add_argument("--extractor_mode", type=str, default="default",
  248. help="Mode for feature extractor. default has a single "
  249. "group norm with d groups in the first conv block,"
  250. " whereas layer_norm has layer norms in every "
  251. "block (meant to use with normalize=True)")
  252. model.add_argument("--encoder_layers", type=int, default=12,
  253. help="Num encoder layers in the transformer")
  254. model.add_argument("--encoder_embed_dim", type=int, default=768,
  255. help="Encoder embedding dimension")
  256. model.add_argument("--encoder_ffn_embed_dim", type=int, default=3072,
  257. help="Encoder embedding dimension for FFN")
  258. model.add_argument("--encoder_attention_heads", type=int, default=12,
  259. help="Num encoder attention heads")
  260. model.add_argument("--activation_fn", type=str, default="gelu",
  261. help="Activation function to use")
  262. model.add_argument("--final_dim", type=int, default=256,
  263. help="Project final representations and targets to this"
  264. " many dimensions. set to encoder_embed_dim "
  265. "is <= 0")
  266. model.add_argument("--layer_norm_first", action="store_true",
  267. help="Apply layernorm first in the transformer")
  268. model.add_argument("--conv_feature_layers", type=str,
  269. default="[(512,10,5)]+[(512,3,2)]*4+[(512,2,2)]+[(512,2,2)]",
  270. help="String describing convolutional feature "
  271. "extraction layers in form of a python list that "
  272. "contains [(dim, kernel_size, stride), ...]")
  273. model.add_argument("--conv_bias", action="store_true",
  274. help="Include bias in conv encoder")
  275. model.add_argument("--logit_temp", type=float, default=0.1,
  276. help="Temperature to divide logits by")
  277. model.add_argument("--quantize_targets", action="store_true",
  278. help="Use quantized targets")
  279. model.add_argument("--quantize_input", action="store_true",
  280. help="Use quantized inputs")
  281. model.add_argument("--target_glu", action="store_true",
  282. help="Adds projection + glu to targets")
  283. model.add_argument("--quantizer_depth", type=int, default=1,
  284. help="Number of quantizer layers")
  285. model.add_argument("--quantizer_factor", type=int, default=3,
  286. help="Dimensionality increase for inner quantizer "
  287. "layers (if depth > 1)")
  288. model.add_argument("--latent_vars", type=int, default=320,
  289. help="Number of latent variables V in each group of the"
  290. " codebook")
  291. model.add_argument("--latent_groups", type=int, default=2,
  292. help="Number of groups G of latent variables in the "
  293. "codebook")
  294. model.add_argument("--latent_dim", type=int, default=0,
  295. help="If > 0, uses this dimensionality for latent var"
  296. "iables. otherwise uses final_dim / latent_groups")
  297. model.add_argument("--num_negatives", type=int, default=100,
  298. help="Num of sampled negatives")
  299. model.add_argument("--negatives_from_everywhere", action="store_true",
  300. help="Sample negatives from everywhere, not just masked"
  301. " states")
  302. model.add_argument("--cross_sample_negatives", type=int, default=0,
  303. help="Num of cross sampled negatives")
  304. model.add_argument("--codebook_negatives", type=int, default=0,
  305. help="Number of negative examples codebook")
  306. model.add_argument("--conv_pos", type=int, default=128,
  307. help="Number of filters for convolutional positional "
  308. "embeddings")
  309. model.add_argument("--conv_pos_groups", type=int, default=16,
  310. help="Number of groups for convolutional positional "
  311. "embedding")
  312. model.add_argument("--latent_temp", type=float, nargs="+",
  313. default=[2.0, 0.5, 0.999995],
  314. help="Legacy (to be removed)")
  315. model.add_argument("--normalize", action="store_true",
  316. help="If set, normalizes input to have 0 mean and unit "
  317. "variance")
  318. parser.add_argument("--log_keys", type=str, nargs="*",
  319. default=["prob_perplexity", "code_perplexity", "temp"],
  320. help="Additional output keys to log")
  321. crit = parser.add_argument_group("criterion")
  322. crit.add_argument("--infonce", action="store_true",
  323. help="If set, uses cross entropy instead of binary cross"
  324. " entropy (i.e. InfoNCE loss)")
  325. crit.add_argument("--loss_weights", type=float, nargs="*",
  326. default=[0.1, 10.0], help="Weights for the loss terms")
  327. joc = parser.add_argument_group("joc experimental")
  328. joc.add_argument("--use_spectrogram_features", action="store_true",
  329. help="Train on input spectrograms")
  330. joc.add_argument("--rotary_embeddings", action="store_true",
  331. help="Use rotarty embeddings for Transformer layers")
  332. joc.add_argument("--hourglass_transformer", type=str, default=None,
  333. help="Specify the number of layers and shorteining, e.g.,"
  334. " [n_pre,(n_hourglass, shorten_factor),n_post]")
  335. joc.add_argument("--hourglass_resample", type=str, default="naive",
  336. help="Method of up/downsampling in the hourglass model")
  337. joc.add_argument("--spectrogram_feature_stacking", type=int, default=1)
  338. joc.add_argument("--spectrogram_feature_subsampling", type=int, default=1)
  339. joc.add_argument("--spectrogram_window_size", type=float, default=0.02)
  340. joc.add_argument("--spectrogram_window_stride", type=float, default=0.01)
  341. joc.add_argument("--spectrogram_n_filt", type=int, default=80)
  342. return parser
  343. def _populate_infer(parser):
  344. # Fine-tuning only
  345. infer = parser.add_argument_group("inference")
  346. infer.add_argument("--steps", default=0, type=int,
  347. help="Eval this many steps for every worker")
  348. infer.add_argument("--warmup_steps", default=0, type=int,
  349. help="Burn-in period before measuring latencies")
  350. infer.add_argument("--labels_path", type=str, default=None,
  351. help="Path to output labels file, e.g., dict.ltr.txt")
  352. infer.add_argument("--save_predictions", type=str, default=None,
  353. help="Save predictions in text form at this location")
  354. infer.add_argument("--save_logits", default=None, type=str,
  355. help="Save output logits under specified path")
  356. infer.add_argument("--transcribe_wav", type=str,
  357. help="Path to a single .wav file (16KHz)")
  358. infer.add_argument("--transcribe_filelist", type=str,
  359. help="Path to a filelist with one .wav path per line")
  360. infer.add_argument("--torchscript", action="store_true",
  361. help="Evaluate with a TorchScripted model")
  362. infer.add_argument("--w2v_path_for_args", type=str, default=None,
  363. help="Args to build model for inference (weights will "
  364. "be loaded from --w2v_path)")