nemo_dle_model_converter.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import argparse
  2. import io
  3. import sys
  4. from copy import deepcopy
  5. from functools import reduce
  6. from pathlib import Path
  7. from subprocess import CalledProcessError, check_output
  8. import torch
  9. import yaml
  10. import quartznet.config
  11. from common import helpers
  12. from common.features import FilterbankFeatures
  13. from quartznet.config import load as load_yaml
  14. from quartznet.model import QuartzNet, MaskedConv1d
  15. # Corresponding DLE <-> NeMo config keys
  16. cfg_key_map = {
  17. ("input_val", "audio_dataset", "sample_rate"): ("preprocessor", "sample_rate"),
  18. ("input_val", "filterbank_features", "dither"): ("preprocessor", "dither"),
  19. ("input_val", "filterbank_features", "frame_splicing"): ("preprocessor", "frame_splicing"),
  20. ("input_val", "filterbank_features", "n_fft"): ("preprocessor", "n_fft"),
  21. ("input_val", "filterbank_features", "n_filt"): ("preprocessor", "features"),
  22. ("input_val", "filterbank_features", "normalize"): ("preprocessor", "normalize"),
  23. ("input_val", "filterbank_features", "sample_rate"): ("preprocessor", "sample_rate"),
  24. ("input_val", "filterbank_features", "window"): ("preprocessor", "window"),
  25. ("input_val", "filterbank_features", "window_size"): ("preprocessor", "window_size"),
  26. ("input_val", "filterbank_features", "window_stride"): ("preprocessor", "window_stride"),
  27. ("labels",): ("decoder", "vocabulary"),
  28. ("quartznet", "decoder", "in_feats"): ("decoder", "feat_in"),
  29. ("quartznet", "encoder", "activation"): ("encoder", "activation"),
  30. ("quartznet", "encoder", "blocks"): ("encoder", "jasper"),
  31. ("quartznet", "encoder", "frame_splicing"): ("preprocessor", "frame_splicing"),
  32. ("quartznet", "encoder", "in_feats"): ("encoder", "feat_in"),
  33. ("quartznet", "encoder", "use_conv_masks"): ("encoder", "conv_mask"),
  34. }
  35. def load_nemo_ckpt(fpath):
  36. """Make a DeepLearningExamples state_dict and config from a .nemo file."""
  37. try:
  38. cmd = ['tar', 'Oxzf', fpath, './model_config.yaml']
  39. nemo_cfg = yaml.safe_load(io.BytesIO(check_output(cmd)))
  40. cmd = ['tar', 'Oxzf', fpath, './model_weights.ckpt']
  41. ckpt = torch.load(io.BytesIO(check_output(cmd)), map_location="cpu")
  42. except (FileNotFoundError, CalledProcessError):
  43. print('WARNING: Could not uncompress with tar. '
  44. 'Falling back to the tarfile module (might take a few minutes).')
  45. import tarfile
  46. with tarfile.open(fpath, "r:gz") as tar:
  47. f = tar.extractfile(tar.getmember("./model_config.yaml"))
  48. nemo_cfg = yaml.safe_load(f)
  49. f = tar.extractfile(tar.getmember("./model_weights.ckpt"))
  50. ckpt = torch.load(f, map_location="cpu")
  51. remap = lambda k: (k.replace("encoder.encoder", "encoder.layers")
  52. .replace("decoder.decoder_layers", "decoder.layers")
  53. .replace("conv.weight", "weight"))
  54. dle_ckpt = {'state_dict': {remap(k): v for k, v in ckpt.items()
  55. if "preproc" not in k}}
  56. dle_cfg = config_from_nemo(nemo_cfg)
  57. return dle_ckpt, dle_cfg
  58. def save_nemo_ckpt(dle_ckpt, dle_cfg, dest_path):
  59. """Save a DeepLearningExamples model as a .nemo file."""
  60. cfg = deepcopy(dle_cfg)
  61. dle_ckpt = torch.load(dle_ckpt, map_location="cpu")["ema_state_dict"]
  62. # Build a DLE model instance and fill with weights
  63. symbols = helpers.add_ctc_blank(cfg['labels'])
  64. enc_kw = quartznet.config.encoder(cfg)
  65. dec_kw = quartznet.config.decoder(cfg, n_classes=len(symbols))
  66. model = QuartzNet(enc_kw, dec_kw)
  67. model.load_state_dict(dle_ckpt, strict=True)
  68. # Reaname core modules, e.g., encoder.layers -> encoder.encoder
  69. model.encoder._modules['encoder'] = model.encoder._modules.pop('layers')
  70. model.decoder._modules['decoder_layers'] = model.decoder._modules.pop('layers')
  71. # MaskedConv1d is made via composition in NeMo, and via inheritance in DLE
  72. # Params for MaskedConv1d in NeMo have an additional '.conv.' infix
  73. def rename_convs(module):
  74. for name in list(module._modules.keys()):
  75. submod = module._modules[name]
  76. if isinstance(submod, MaskedConv1d):
  77. module._modules[f'{name}.conv'] = module._modules.pop(name)
  78. else:
  79. rename_convs(submod)
  80. rename_convs(model.encoder.encoder)
  81. # Use FilterbankFeatures to calculate fbanks and store with model weights
  82. feature_processor = FilterbankFeatures(
  83. **dle_cfg['input_val']['filterbank_features'])
  84. nemo_ckpt = model.state_dict()
  85. nemo_ckpt["preprocessor.featurizer.fb"] = feature_processor.fb
  86. nemo_ckpt["preprocessor.featurizer.window"] = feature_processor.window
  87. nemo_cfg = config_to_nemo(dle_cfg)
  88. # Prepare the directory for zipping
  89. ckpt_files = dest_path / "ckpt_files"
  90. ckpt_files.mkdir(exist_ok=True, parents=False)
  91. with open(ckpt_files / "model_config.yaml", "w") as f:
  92. yaml.dump(nemo_cfg, f)
  93. torch.save(nemo_ckpt, ckpt_files / "model_weights.ckpt")
  94. with tarfile.open(dest_path / "quartznet.nemo", "w:gz") as tar:
  95. tar.add(ckpt_files, arcname="./")
  96. def save_dle_ckpt(ckpt, cfg, dest_dir):
  97. torch.save(ckpt, dest_dir / "model.pt")
  98. with open(dest_dir / "model_config.yaml", "w") as f:
  99. yaml.dump(cfg, f)
  100. def set_nested_item(tgt, src, tgt_keys, src_keys):
  101. """Assigns nested dict keys, e.g., d1[a][b][c] = d2[e][f][g][h]."""
  102. tgt_nested = reduce(lambda d, k: d[k], tgt_keys[:-1], tgt)
  103. tgt_nested[tgt_keys[-1]] = reduce(lambda d, k: d[k], src_keys, src)
  104. def config_from_nemo(nemo_cfg):
  105. """Convert a DeepLearningExamples config to a NeMo format."""
  106. dle_cfg = {
  107. 'name': 'QuartzNet',
  108. 'input_val': {
  109. 'audio_dataset': {
  110. 'normalize_transcripts': True,
  111. },
  112. 'filterbank_features': {
  113. 'pad_align': 16,
  114. },
  115. },
  116. 'quartznet': {
  117. 'decoder': {},
  118. 'encoder': {},
  119. },
  120. }
  121. for dle_keys, nemo_keys in cfg_key_map.items():
  122. try:
  123. set_nested_item(dle_cfg, nemo_cfg, dle_keys, nemo_keys)
  124. except KeyError:
  125. print(f'WARNING: Could not load config {nemo_keys} as {dle_keys}.')
  126. # mapping kernel_size is not expressable with cfg_map
  127. for block in dle_cfg["quartznet"]["encoder"]["blocks"]:
  128. block["kernel_size"] = block.pop("kernel")
  129. return dle_cfg
  130. def config_to_nemo(dle_cfg):
  131. """Convert a DeepLearningExamples config to a NeMo format."""
  132. nemo_cfg = {
  133. "target": "nemo.collections.asr.models.ctc_models.EncDecCTCModel",
  134. "dropout": 0.0,
  135. "preprocessor": {
  136. "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
  137. "stft_conv": False,
  138. },
  139. "encoder": {
  140. "_target_": "nemo.collections.asr.modules.ConvASREncoder",
  141. "jasper": {}
  142. },
  143. "decoder": {
  144. "_target_": "nemo.collections.asr.modules.ConvASRDecoder",
  145. },
  146. }
  147. for dle_keys, nemo_keys in cfg_key_map.items():
  148. try:
  149. set_nested_item(nemo_cfg, dle_cfg, nemo_keys, dle_keys)
  150. except KeyError:
  151. print(f"WARNING: Could not load config {dle_keys} as {nemo_keys}.")
  152. nemo_cfg["sample_rate"] = nemo_cfg["preprocessor"]["sample_rate"]
  153. nemo_cfg["repeat"] = nemo_cfg["encoder"]["jasper"][1]["repeat"]
  154. nemo_cfg["separable"] = nemo_cfg["encoder"]["jasper"][1]["separable"]
  155. nemo_cfg["labels"] = nemo_cfg["decoder"]["vocabulary"]
  156. nemo_cfg["decoder"]["num_classes"] = len(nemo_cfg["decoder"]["vocabulary"])
  157. # mapping kernel_size is not expressable with cfg_map
  158. for block in nemo_cfg["encoder"]["jasper"]:
  159. if "kernel_size" in block:
  160. block["kernel"] = block.pop("kernel_size")
  161. return nemo_cfg
  162. if __name__ == "__main__":
  163. parser = argparse.ArgumentParser(description="QuartzNet DLE <-> NeMo model converter.")
  164. parser.add_argument("source_model", type=Path,
  165. help="A DLE or NeMo QuartzNet model to be converted (.pt or .nemo, respectively)")
  166. parser.add_argument("dest_dir", type=Path, help="Destination directory")
  167. parser.add_argument("--dle_config_yaml", type=Path,
  168. help="A DLE config .yaml file, required only to convert DLE -> NeMo")
  169. args = parser.parse_args()
  170. ext = args.source_model.suffix.lower()
  171. if ext == ".nemo":
  172. ckpt, cfg = load_nemo_ckpt(args.source_model)
  173. save_dle_ckpt(ckpt, cfg, args.dest_dir)
  174. elif ext == ".pt":
  175. dle_cfg = load_yaml(args.dle_config_yaml)
  176. save_nemo_ckpt(args.source_model, dle_cfg, args.dest_dir)
  177. else:
  178. raise ValueError(f"Unknown extension {ext}.")
  179. print('Converted succesfully.')