Просмотр исходного кода

[FastPitch/PyT] Fix handling heteronyms when training with a lexicone

Adrian Lancucki 4 лет назад
Родитель
Сommit
37a5e77ccc

+ 8 - 7
PyTorch/SpeechSynthesis/FastPitch/common/text/cmudict.py

@@ -29,14 +29,9 @@ class CMUDict:
     if file_or_path is None:
       self._entries = {}
     else:
-      self.initialize(file_or_path, keep_ambiguous)
+      self.initialize(file_or_path, heteronyms_path, keep_ambiguous)
 
-    if heteronyms_path is None:
-      self.heteronyms = []
-    else:
-      self.heteronyms = set(lines_to_list(heteronyms_path))
-
-  def initialize(self, file_or_path, keep_ambiguous=True):
+  def initialize(self, file_or_path, heteronyms_path, keep_ambiguous=True):
     if isinstance(file_or_path, str):
       try:
         with open(file_or_path, encoding='latin-1') as f:
@@ -55,6 +50,12 @@ class CMUDict:
       entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
     self._entries = entries
 
+    if heteronyms_path is None:
+      self.heteronyms = []
+    else:
+      self.heteronyms = set(lines_to_list(heteronyms_path))
+
+
   def __len__(self):
     if len(self._entries) == 0:
       raise ValueError("CMUDict not initialized")

+ 1 - 1
PyTorch/SpeechSynthesis/FastPitch/inference.py

@@ -296,7 +296,7 @@ def main():
     args, unk_args = parser.parse_known_args()
 
     if args.p_arpabet > 0.0:
-        cmudict.initialize(args.cmudict_path, keep_ambiguous=True)
+        cmudict.initialize(args.cmudict_path, args.heteronyms_path)
 
     torch.backends.cudnn.benchmark = args.cudnn_benchmark
 

+ 1 - 1
PyTorch/SpeechSynthesis/FastPitch/train.py

@@ -351,7 +351,7 @@ def main():
     args, _ = parser.parse_known_args()
 
     if args.p_arpabet > 0.0:
-        cmudict.initialize(args.cmudict_path, keep_ambiguous=True)
+        cmudict.initialize(args.cmudict_path, args.heteronyms_path)
 
     distributed_run = args.world_size > 1
 

+ 1 - 1
PyTorch/SpeechSynthesis/FastPitch/triton/convert_model.py

@@ -148,7 +148,7 @@ def main():
 
             if args.p_arpabet > 0.0:
                 from common.text import cmudict
-                cmudict.initialize(args.cmudict_path, keep_ambiguous=True)
+                cmudict.initialize(args.cmudict_path, args.heteronyms_path)
 
             get_dataloader_fn = load_from_file(args.dataloader, label="dataloader", target=DATALOADER_FN_NAME)
             dataloader_fn = ArgParserGenerator(get_dataloader_fn).from_args(args)

+ 1 - 1
PyTorch/SpeechSynthesis/FastPitch/triton/dataloader.py

@@ -48,7 +48,7 @@ def get_dataloader_fn(batch_size: int = 8,
                       mel_fmax: float = 8000.0):
 
     if p_arpabet > 0.0:
-        cmudict.initialize(cmudict_path, keep_ambiguous=True)
+        cmudict.initialize(cmudict_path, heteronyms_path)
 
     dataset = TTSDataset(dataset_path=dataset_path,
                          audiopaths_and_text=filelist,