Răsfoiți Sursa

[Jasper/PyT] Update DALI Jasper pipeline to functional API

Signed-off-by: Joaquin Anton <[email protected]>
Joaquin Anton 4 ani în urmă
părinte
comite
41e4a07a26

+ 1 - 1
PyTorch/SpeechRecognition/Jasper/common/dali/data_loader.py

@@ -119,7 +119,7 @@ class DaliDataLoader:
                                             train_pipeline=train_pipeline)
 
         return DaliJasperIterator([pipeline], transcripts=transcripts, symbols=symbols, batch_size=self.batch_size,
-                                  shard_size=self._shard_size(), train_iterator=train_pipeline)
+                                  reader_name="file_reader", train_iterator=train_pipeline)
 
     def _init_synth_iterator(self, batch_size, nfeatures, iters_per_epoch, ngpus):
         self.dataset_size = ngpus * iters_per_epoch * batch_size

+ 3 - 4
PyTorch/SpeechRecognition/Jasper/common/dali/iterator.py

@@ -42,7 +42,7 @@ class DaliJasperIterator(object):
     Use DataLoader instead.
     """
 
-    def __init__(self, dali_pipelines, transcripts, symbols, batch_size, shard_size, train_iterator: bool):
+    def __init__(self, dali_pipelines, transcripts, symbols, batch_size, reader_name, train_iterator: bool):
         self.transcripts = transcripts
         self.symbols = symbols
         self.batch_size = batch_size
@@ -51,9 +51,8 @@ class DaliJasperIterator(object):
 
         # in train pipeline shard_size is set to be divisible by batch_size, so PARTIAL policy is safe
         self.dali_it = DALIGenericIterator(
-            dali_pipelines, ["audio", "label", "audio_shape"], size=shard_size,
-            dynamic_shape=True, auto_reset=True, last_batch_padded=True,
-            last_batch_policy=LastBatchPolicy.PARTIAL)
+            dali_pipelines, ["audio", "label", "audio_shape"], reader_name=reader_name,
+            dynamic_shape=True, auto_reset=True, last_batch_policy=LastBatchPolicy.PARTIAL)
 
     @staticmethod
     def _str2list(s: str):

+ 70 - 102
PyTorch/SpeechRecognition/Jasper/common/dali/pipeline.py

@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import nvidia.dali
-import nvidia.dali.ops as ops
-import nvidia.dali.ops.random as random
+import nvidia.dali as dali
+import nvidia.dali.fn as fn
 import nvidia.dali.types as types
 import multiprocessing
 import numpy as np
@@ -23,7 +22,7 @@ import math
 import itertools
 
 
-class DaliPipeline(nvidia.dali.pipeline.Pipeline):
+class DaliPipeline():
     def __init__(self, *,
                  train_pipeline: bool,  # True if train pipeline, False if validation pipeline
                  device_id,
@@ -55,9 +54,8 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
                  mask_both_max_time,
                  mask_both_min_freq,
                  mask_both_max_freq,
-                 preprocessing_device="gpu"):
-        super().__init__(batch_size, num_threads, device_id)
-
+                 preprocessing_device="gpu",
+                 is_triton_pipeline=False):
         self._dali_init_log(locals())
 
         if torch.distributed.is_initialized():
@@ -71,7 +69,9 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
         assert self.preprocessing_device == "cpu" or self.preprocessing_device == "gpu", \
             "Incorrect preprocessing device. Please choose either 'cpu' or 'gpu'"
         self.frame_splicing_factor = frame_splicing_factor
-        assert frame_splicing_factor == 1, "DALI doesn't support frame splicing operation"
+
+        # TODO(janton): Implement this
+        assert frame_splicing_factor == 1, "Frame splicing is not yet implemented"
 
         self.resample_range = resample_range
         self.discrete_resample_range = discrete_resample_range
@@ -96,50 +96,76 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
         }
         self.do_remove_silence = True if silence_threshold is not None else False
 
-        self.read = ops.FileReader(device="cpu", file_root=file_root, file_list=file_list, shard_id=shard_id,
-                                   num_shards=n_shards, shuffle_after_epoch=train_pipeline)
+        @dali.pipeline_def
+        def dali_jasper_pipe():
+            if is_triton_pipeline:
+                assert not self.train, "Pipeline for Triton shall be a validation pipeline"
+                if torch.distributed.is_initialized():
+                    raise RuntimeError(
+                        "You're creating Triton pipeline, using multi-process mode. Please use single-process mode.")
+                encoded, label = fn.external_source(device="cpu", name="DALI_INPUT_0", no_copy=True)
+            else:
+                encoded, label = fn.readers.file(device="cpu", name="file_reader",
+                                                 file_root=file_root, file_list=file_list, shard_id=shard_id,
+                                                 num_shards=n_shards, shuffle_after_epoch=train_pipeline)
 
-        # TODO change ExternalSource to Uniform for new DALI release
-        if discrete_resample_range and resample_range is not None:
-            self.speed_perturbation_coeffs = ops.ExternalSource(device="cpu", cycle=True,
-                                                                source=self._discrete_resample_coeffs_generator)
-        elif resample_range is not None:
-            self.speed_perturbation_coeffs = random.Uniform(device="cpu", range=resample_range)
-        else:
-            self.speed_perturbation_coeffs = None
+            speed_perturbation_coeffs = None
+            if resample_range is not None:
+                if discrete_resample_range:
+                    values = [self.resample_range[0], 1.0, self.resample_range[1]]
+                    speed_perturbation_coeffs = fn.random.uniform(device="cpu", values=values)
+                else:
+                    speed_perturbation_coeffs = fn.random.uniform(device="cpu", range=resample_range)
+
+            if self.train and speed_perturbation_coeffs is not None:
+                dec_sample_rate_arg = speed_perturbation_coeffs * self.sample_rate
+            elif resample_range is None:
+                dec_sample_rate_arg = self.sample_rate
+            else:
+                dec_sample_rate_arg = None
+
+            audio, _ = fn.decoders.audio(encoded, sample_rate=dec_sample_rate_arg, dtype=types.FLOAT, downmix=True)
+
+            if self.do_remove_silence:
+                begin, length = fn.nonsilent_region(audio, cutoff_db=silence_threshold)
+                audio = fn.slice(audio, begin, length, axes=[0])
+
+            # Max duration drop is performed at DataLayer stage
 
-        self.decode = ops.AudioDecoder(device="cpu", sample_rate=self.sample_rate if resample_range is None else None,
-                                       dtype=types.FLOAT, downmix=True)
+            if self.preprocessing_device == "gpu":
+                audio = audio.gpu()
 
-        self.normal_distribution = random.Normal(device=preprocessing_device)
+            if self.dither_coeff != 0.:
+                audio = audio + fn.random.normal(device=preprocessing_device) * self.dither_coeff
 
-        self.preemph = ops.PreemphasisFilter(device=preprocessing_device, preemph_coeff=preemph_coeff)
+            audio = fn.preemphasis_filter(audio, preemph_coeff=preemph_coeff)
 
-        self.spectrogram = ops.Spectrogram(device=preprocessing_device, nfft=nfft,
-                                           window_length=window_size * sample_rate,
-                                           window_step=window_stride * sample_rate)
+            spec = fn.spectrogram(audio, nfft=nfft,
+                                  window_length=window_size * sample_rate, window_step=window_stride * sample_rate)
 
-        self.mel_fbank = ops.MelFilterBank(device=preprocessing_device, sample_rate=sample_rate, nfilter=self.nfeatures,
-                                           normalize=True)
+            mel_spec = fn.mel_filter_bank(spec, sample_rate=sample_rate, nfilter=self.nfeatures, normalize=True)
 
-        self.log_features = ops.ToDecibels(device=preprocessing_device, multiplier=np.log(10), reference=1.0,
-                                           cutoff_db=math.log(1e-20))
+            log_features = fn.to_decibels(mel_spec, multiplier=np.log(10), reference=1.0, cutoff_db=math.log(1e-20))
 
-        self.get_shape = ops.Shapes(device=preprocessing_device)
+            log_features_len = fn.shapes(log_features)
+            if self.frame_splicing_factor != 1:
+                log_features_len = self._div_ceil(log_features_len, self.frame_splicing_factor)
 
-        self.normalize = ops.Normalize(device=preprocessing_device, axes=[1])
+            log_features = fn.normalize(log_features, axes=[1])
+            log_features = fn.pad(log_features, axes=[1], fill_value=0, align=pad_align)
 
-        self.pad = ops.Pad(device=preprocessing_device, axes=[1], fill_value=0, align=pad_align)
+            if self.train and self._do_spectrogram_masking():
+                anchors, shapes  = fn.external_source(source=self._cutouts_generator, num_outputs=2, cycle=True)
+                log_features = fn.erase(log_features, anchor=anchors, shape=shapes, axes=[0, 1], fill_value=0,
+                                        normalized_anchor=True)
 
-        # Silence trimming
-        self.get_nonsilent_region = ops.NonsilentRegion(device="cpu", cutoff_db=silence_threshold)
-        self.trim_silence = ops.Slice(device="cpu", normalized_anchor=False, normalized_shape=False, axes=[0])
-        self.to_float = ops.Cast(device="cpu", dtype=types.FLOAT)
+            # When modifying DALI pipeline returns, make sure you update `output_map` in DALIGenericIterator invocation
+            return log_features.gpu(), label.gpu(), log_features_len.gpu()
 
-        # Spectrogram masking
-        self.spectrogram_cutouts = ops.ExternalSource(source=self._cutouts_generator, num_outputs=2, cycle=True)
-        self.mask_spectrogram = ops.Erase(device=preprocessing_device, axes=[0, 1], fill_value=0,
-                                          normalized_anchor=True)
+        self.pipe_handle = dali_jasper_pipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id)
+
+    def get_pipeline(self):
+        return self.pipe_handle
 
     @classmethod
     def from_config(cls, train_pipeline: bool, device_id, batch_size, file_root: str, file_list: str, config_data: dict,
@@ -202,7 +228,7 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
             mask_both_min_freq = 0
             mask_both_max_freq = 0
 
-        return cls(train_pipeline=train_pipeline,
+        inst = cls(train_pipeline=train_pipeline,
                    device_id=device_id,
                    preprocessing_device=device_type,
                    num_threads=num_cpu_threads,
@@ -233,6 +259,7 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
                    mask_both_max_time=mask_both_max_time,
                    mask_both_min_freq=mask_both_min_freq,
                    mask_both_max_freq=mask_both_max_freq)
+        return inst.get_pipeline()
 
     @staticmethod
     def _dali_init_log(args: dict):
@@ -248,15 +275,6 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
     def _div_ceil(dividend, divisor):
         return (dividend + (divisor - 1)) // divisor
 
-    def _get_audio_len(self, inp):
-        return self.get_shape(inp) if self.frame_splicing_factor == 1 else \
-            self._div_ceil(self.get_shape(inp), self.frame_splicing_factor)
-
-    def _remove_silence(self, inp):
-        begin, length = self.get_nonsilent_region(inp)
-        out = self.trim_silence(inp, self.to_float(begin), self.to_float(length))
-        return out
-
     def _do_spectrogram_masking(self):
         return self.mask_params['time_num_regions'] > 0 or self.mask_params['freq_num_regions'] > 0 or \
                self.mask_params['both_num_regions'] > 0
@@ -321,13 +339,6 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
         )
         return anchors, shapes
 
-    def _discrete_resample_coeffs_generator(self):
-        """
-        Generate resample coeffs from discrete set
-        """
-        yield np.random.choice([self.resample_range[0], 1.0, self.resample_range[1]],
-                               size=self.max_batch_size).astype('float32')
-
     def _cutouts_generator(self):
         """
         Generator, that wraps cutouts creation in order to randomize inputs
@@ -340,56 +351,13 @@ class DaliPipeline(nvidia.dali.pipeline.Pipeline):
             """
             return map(list, zip(*tuples))
 
-        [anchors, shapes] = tuples2list([self._generate_cutouts() for _ in range(self.max_batch_size)])
+        [anchors, shapes] = tuples2list([self._generate_cutouts() for _ in range(self.pipe_handle.max_batch_size)])
         yield np.array(anchors, dtype=np.float32), np.array(shapes, dtype=np.float32)
 
-    def define_graph(self):
-        audio, label = self.read()
-        if not self.train or self.speed_perturbation_coeffs is None:
-            audio, sr = self.decode(audio)
-        else:
-            resample_coeffs = self.speed_perturbation_coeffs() * self.sample_rate
-            audio, sr = self.decode(audio, sample_rate=resample_coeffs)
-
-        if self.do_remove_silence:
-            audio = self._remove_silence(audio)
-
-        # Max duration drop is performed at DataLayer stage
-
-        if self.preprocessing_device == "gpu":
-            audio = audio.gpu()
-
-        if self.dither_coeff != 0.:
-            audio = audio + self.normal_distribution(audio) * self.dither_coeff
-
-        audio = self.preemph(audio)
-
-        audio = self.spectrogram(audio)
-        audio = self.mel_fbank(audio)
-        audio = self.log_features(audio)
-
-        audio_len = self._get_audio_len(audio)
-
-        audio = self.normalize(audio)
-        audio = self.pad(audio)
-
-        if self.train and self._do_spectrogram_masking():
-            anchors, shapes = self.spectrogram_cutouts()
-            audio = self.mask_spectrogram(audio, anchor=anchors, shape=shapes)
-
-        # When modifying DALI pipeline returns, make sure you update `output_map` in DALIGenericIterator invocation
-        return audio.gpu(), label.gpu(), audio_len.gpu()
-
-
 class DaliTritonPipeline(DaliPipeline):
     def __init__(self, **kwargs):
+        kwargs['is_triton_pipeline'] = True
         super().__init__(**kwargs)
-        assert not kwargs['train_pipeline'], "Pipeline for Triton shall be a validation pipeline"
-        if torch.distributed.is_initialized():
-            raise RuntimeError(
-                "You're creating Triton pipeline, using multi-process mode. Please use single-process mode.")
-        self.read = ops.ExternalSource(name="DALI_INPUT_0", no_copy=True, device="cpu")
-
 
 def serialize_dali_triton_pipeline(output_path: str, config_data: dict, config_features: dict):
     pipe = DaliTritonPipeline.from_config(train_pipeline=False, device_id=-1, batch_size=-1, file_root=None,