From e69ec8ed89dbbb3698bdf1c0030231724f4313a5 Mon Sep 17 00:00:00 2001
From: taejinp
Date: Wed, 13 Nov 2024 18:46:20 -0800
Subject: [PATCH 01/16] Adding the first PR files: models and dataset

Signed-off-by: taejinp
---
 .../sortformer_diar_HL_callhome_part1.yaml    |   18 +
 .../sortformer_diar_HL_dihard.yaml            |   17 +
 .../sortformer_diar_encoder_infer.py          |  132 ++
 .../sortformer_diar_encoder_train.py          |   54 +
 .../asr/data/audio_to_diar_label.py           |  490 ++++++-
 .../asr/data/audio_to_diar_label_lhotse.py    |   76 +
 .../asr/models/sortformer_diar_models.py      |  565 ++++++++
 .../asr/parts/utils/asr_multispeaker_utils.py | 1231 +++++++++++++++++
 .../common/parts/preprocessing/collections.py |  192 ++-
 9 files changed, 2759 insertions(+), 16 deletions(-)
 create mode 100644 examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml
 create mode 100644 examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml
 create mode 100644 examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py
 create mode 100644 examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py
 create mode 100644 nemo/collections/asr/data/audio_to_diar_label_lhotse.py
 create mode 100644 nemo/collections/asr/models/sortformer_diar_models.py
 create mode 100644 nemo/collections/asr/parts/utils/asr_multispeaker_utils.py

diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml
new file mode 100644
index 000000000000..6b960e2d5950
--- /dev/null
+++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml
@@ -0,0 +1,18 @@
+# Postprocessing parameters for timestamp outputs from speaker diarization models.
+# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
+# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
+# These parameters were optimized with the hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
+# These parameters were optimized on the CallHome dataset from the NIST SRE 2000 Disc8, specifically on the split2 defined in: Kaldi, "Kaldi x-vector recipe v2," https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2.
+# Trial 17903 finished with value: 0.10261257411949805 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.0, 'min_duration_on': 0.39, 'min_duration_off': 0.39}. Best is trial 17903 with value: 0.10261257411949805.
+# Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055.
+parameters:
+  window_length_in_sec: 0.0 # Not used
+  shift_length_in_sec: 0.0 # Not used
+  smoothing: False # Not used
+  overlap: 0.5 # Not used
+  onset: 0.53 # Onset threshold for detecting the beginning of a speech segment
+  offset: 0.49 # Offset threshold for detecting the end of a speech segment
+  pad_onset: 0.23 # Duration added before each speech segment
+  pad_offset: 0.01 # Duration added after each speech segment
+  min_duration_on: 0.42 # Threshold for small non-speech deletion
+  min_duration_off: 0.34 # Threshold for short speech segment deletion
\ No newline at end of file
diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml
new file mode 100644
index 000000000000..bb9f362ad619
--- /dev/null
+++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml
@@ -0,0 +1,17 @@
+# Postprocessing parameters for timestamp outputs from speaker diarization models.
+# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
+# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
+# These parameters were optimized with the hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
+# These parameters were optimized on the development split of the DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477.
+# Trial 180 finished with value: 0.12329626986650599 and parameters: {'onset': 0.56, 'offset': 0.81, 'pad_onset': 0.05, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.16}. Best is trial 180 with value: 0.12329626986650599.
+parameters:
+  window_length_in_sec: 0.0 # Not used
+  shift_length_in_sec: 0.0 # Not used
+  smoothing: False # Not used
+  overlap: 0.5 # Not used
+  onset: 0.64 # Onset threshold for detecting the beginning of a speech segment
+  offset: 0.74 # Offset threshold for detecting the end of a speech segment
+  pad_onset: 0.06 # Duration added before each speech segment
+  pad_offset: 0.0 # Duration added after each speech segment
+  min_duration_on: 0.1 # Threshold for small non-speech deletion
+  min_duration_off: 0.15 # Threshold for short speech segment deletion
\ No newline at end of file
diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py
new file mode 100644
index 000000000000..aafd2b2cb6ed
--- /dev/null
+++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
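# Illustrative sketch (hypothetical helper, not part of this patch): how postprocessing
# parameters like the ones in the two YAML files above are conventionally applied to per-frame
# speaker probabilities -- thresholding with `onset`/`offset`, boundary padding with
# `pad_onset`/`pad_offset`, and removal of short gaps/segments with the two `min_duration_*`
# values. The exact roles of `min_duration_on` vs. `min_duration_off` inside NeMo's own
# binarization routine are assumed here; names and defaults are illustrative only.
def binarize_speaker_probs(probs, frame_shift=0.08, onset=0.53, offset=0.49,
                           pad_onset=0.23, pad_offset=0.01,
                           min_duration_on=0.42, min_duration_off=0.34):
    """Convert per-frame sigmoid outputs of one speaker into (start, end) segments in seconds."""
    segments, active, start = [], False, 0.0
    for i, p in enumerate(probs):
        t = i * frame_shift
        if not active and p >= onset:    # a segment opens when the probability crosses `onset`
            active, start = True, t
        elif active and p < offset:      # a segment closes when the probability drops below `offset`
            segments.append([start - pad_onset, t + pad_offset])
            active = False
    if active:                           # close a segment that runs to the end of the stream
        segments.append([start - pad_onset, len(probs) * frame_shift + pad_offset])
    merged = []                          # bridge non-speech gaps shorter than `min_duration_off`
    for seg in segments:
        if merged and seg[0] - merged[-1][1] < min_duration_off:
            merged[-1][1] = seg[1]
        else:
            merged.append(seg)
    # drop speech segments shorter than `min_duration_on` and clip negative starts at zero
    return [(max(0.0, s), e) for s, e in merged if e - s >= min_duration_on]

# Example (80 ms frames for one speaker):
# binarize_speaker_probs([0.1, 0.7, 0.9, 0.8, 0.2, 0.1, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9])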
+ +import pytorch_lightning as pl +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +import seaborn as sns +import numpy as np + +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +seed_everything(42) +import torch +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.manifold import TSNE +import pandas as pd +from nemo.collections.asr.data.audio_to_msdd_mock_label import generate_mock_embs + +def plot_enc_tsne(x, targets, memo): + # x = enc_states_list[-1].squeeze(0).cpu().detach().numpy() + tsne = TSNE(n_components=2, verbose=False, random_state=100) + zembs = tsne.fit_transform(x) + + # Step 1: Create a new column filled with 0.5 + new_column = torch.full((targets.size(0), 1), 0.5) + # Step 2: Concatenate the new column with the original tensor + updated_targets = torch.cat((new_column, targets), dim=1) + + df = pd.DataFrame() + df["y"] = updated_targets.argmax(dim=1).detach().cpu().numpy() + df["comp-1"] = zembs[:,0] + df["comp-2"] = zembs[:,1] + + # Plotting using seaborn + plt.figure(figsize=(10, 8)) + sns.scatterplot(x="comp-1", y="comp-2", hue=df.y.tolist(), + palette=sns.color_palette("hls", 10), + data=df).set(title="SortFormer HiddenState T-SNE projection") + + # Save the plot as a PNG file in the specified directory + plt.savefig(f'/home/taejinp/Downloads/tsne_plots/tsne_sortformer_plot_{memo}.png') + +def remove_speaker_models(ckpt_path): + ckpt_instance = torch.load(ckpt_path) + _state_dict = ckpt_instance['state_dict'] + + key_list = list(_state_dict.keys()) + for key in key_list: + if '_speaker_model.' in key or '_speaker_model_decoder.' in key: + # import ipdb; ipdb.set_trace() + del _state_dict[key] + + target_path = ckpt_path.replace('.ckpt', '.removed.ckpt') + torch.save(ckpt_instance, target_path) + return target_path + + +# @hydra_runner(config_path="../conf/neural_diarizer", config_name="msdd_5scl_15_05_50Povl_256x3x32x2.yaml") +def main(): + # logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + # trainer = pl.Trainer(**cfg.trainer) + # exp_manager(trainer, cfg.get("exp_manager", None)) + # ckpt_path = "/disk_c/taejinp_backup/msdd_model_train/NVB_SFmr_MixMockEmbsTest/version_18_f0:84/checkpoints/e613.ckpt" + ckpt_path = "/disk_c/taejinp_backup/msdd_model_train/SFmr_MixMockEmbsTest/version_21/checkpoints/ep2255.ckpt" + target_path = remove_speaker_models(ckpt_path) + sortformer_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=target_path) + unit_len = 25 + targets = torch.eye(4,4).repeat_interleave(unit_len,1).t() + targets[:,2:] = 0 + # targets[:,3:] = 0 + targets = targets[:2*unit_len, :] + new_column = torch.full((targets.size(0), 1), 0.5) + updated_targets = torch.cat((new_column, targets), dim=1) + mock_embs, audio_signal_length, targets = generate_mock_embs(targets=targets, seed=315, + mock_emb_noise_std=0.03, + mock_emb_degree_of_freedom=4, + min_noise_std=0.01,) + mock_embs = mock_embs.unsqueeze(0) + audio_signal = mock_embs + + audio_signal, audio_signal_length, targets + + audio_signal = audio_signal.cuda() + ms_seg_counts = torch.tensor([]).cuda() + ms_seg_timestamps = torch.tensor([]).cuda() + scale_mapping = torch.tensor([]).cuda() + sortformer_model.alpha = 0.0 + + _preds_mean, preds_, attn_score_stack, enc_states_list, preds_list = sortformer_model.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + 
ms_seg_timestamps=ms_seg_timestamps, + ms_seg_counts=ms_seg_counts, + scale_mapping=scale_mapping, + temp_targets=targets, + ) + + audio_signal_np = audio_signal.squeeze(0).cpu().detach().numpy() + plot_enc_tsne(audio_signal_np, targets, memo=f'input', ) + for layer_c in range(len(enc_states_list)): + print(f"Plotting TSNE for layer {layer_c} ...") + x = enc_states_list[layer_c].squeeze(0).cpu().detach().numpy() + plot_enc_tsne(x, targets, memo=f'layer{layer_c}', ) + preds = preds_.squeeze(0).cpu().detach().numpy() + plot_enc_tsne(preds, targets, memo=f'preds', ) + _preds_mean = _preds_mean.squeeze(0).cpu().detach().numpy() + plot_enc_tsne(_preds_mean, targets, memo=f'preds_mean', ) + + # Optionally, you can also show the plot if desired + plt.show() + import ipdb; ipdb.set_trace() + + # msdd_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) + # trainer.fit(msdd_model) + + +if __name__ == '__main__': + main() diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py new file mode 100644 index 000000000000..fb350113d596 --- /dev/null +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
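# Illustrative sketch (hypothetical helper, not part of this patch). Note that the example
# inference script above still calls `forward()` with MSDD-style arguments (ms_seg_timestamps,
# ms_seg_counts, scale_mapping, ...), whereas `SortformerEncLabelModel.forward()` defined in
# nemo/collections/asr/models/sortformer_diar_models.py (later in this patch) only takes
# `audio_signal` and `audio_signal_length`. A minimal offline-inference sketch against that
# signature, assuming a trained checkpoint and a 16 kHz mono waveform:
import torch
import soundfile as sf

from nemo.collections.asr.models import SortformerEncLabelModel


def run_offline_diarization(ckpt_path: str, wav_path: str) -> torch.Tensor:
    model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=ckpt_path).eval()
    waveform, _sr = sf.read(wav_path, dtype="float32")            # expected: 16 kHz, mono
    audio_signal = torch.tensor(waveform).unsqueeze(0)            # (1, num_samples)
    audio_signal_length = torch.tensor([audio_signal.shape[1]])   # (1,)
    with torch.no_grad():
        preds = model(audio_signal=audio_signal, audio_signal_length=audio_signal_length)
    return preds  # (1, num_frames, max_num_of_spks) sigmoid speaker-activity probabilities

# preds = run_offline_diarization("sortformer_diarizer.ckpt", "session.wav")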
+ +import pytorch_lightning as pl +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything + +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +""" +Example training session (single GPU training on telephonic datasets) + +python ./multiscale_diar_decoder.py --config-path='../conf/neural_diarizer' --config-name='msdd_5scl_15_05_50Povl_256x3x32x2.yaml' \ + trainer.devices=1 \ + model.base.diarizer.speaker_embeddings.model_path="titanet_large" \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ + model.train_ds.emb_dir="" \ + model.validation_ds.emb_dir="" \ + exp_manager.name='sample_train' \ + exp_manager.exp_dir='./msdd_exp' +""" + +seed_everything(42) + + +@hydra_runner(config_path="../conf/neural_diarizer", config_name="msdd_5scl_15_05_50Povl_256x3x32x2.yaml") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + sortformer_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) + # Initialize the weights of the model from another model, if provided via config + sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(sortformer_model) + + +if __name__ == '__main__': + + main() diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index a1cb6d0f1bdc..ffad8e4fd072 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,16 +15,17 @@ import os from collections import OrderedDict from statistics import mode -from typing import Dict, Optional - +from typing import Dict, List, Tuple, Optional import torch +import numpy as np from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat -from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data -from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import find_first_nonzero +from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data, get_subsegments +from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel, EndtoEndDiarizationSpeechLabel from nemo.core.classes import Dataset from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType - +from nemo.utils import logging def get_scale_mapping_list(uniq_timestamps): """ @@ -62,7 +63,7 @@ def get_scale_mapping_list(uniq_timestamps): return scale_mapping_argmat -def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_spks=None): +def extract_seg_info_from_rttm(rttm_lines, mapping_dict=None, target_spks=None): """ Get RTTM lines containing speaker labels, start time and end time. target_spks contains two targeted speaker indices for creating groundtruth label files. 
Only speakers in target_spks variable will be @@ -139,6 +140,128 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, return fr_level_target +def get_subsegments_to_timestamps( + subsegments: List[Tuple[float, float]], + feat_per_sec: int = 100, + max_end_ts: float=None, + decimals=2 + ): + """ + Convert subsegment timestamps to scale timestamps by multiplying with the feature rate and rounding. + All `ts` related tensors are dimensioned as (N, 2), where N is the number of subsegments. + + Args: + subsegments (List[Tuple[float, float]]): + A list of tuples where each tuple contains the start and end times of a subsegment. + feat_per_sec (int, optional): + The number of feature frames per second. Defaults to 100. + max_end_ts (float, optional): + The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None. + decimals (int, optional): + The number of decimal places to round the timestamps. Defaults to 2. + + Returns: + ts (torch.tensor): + A tensor containing the scaled and rounded timestamps for each subsegment. + """ + seg_ts = (torch.tensor(subsegments) * feat_per_sec).float() + ts_round = torch.round(seg_ts, decimals=decimals) + ts = ts_round.long() + ts[:, 1] = ts[:, 0] + ts[:, 1] + if max_end_ts is not None: + ts = np.clip(ts, 0, int(max_end_ts*feat_per_sec)) + return ts + +def extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines, round_digits=3): + """ + Extracts RTTM lines containing speaker labels, start time, and end time for a given audio segment. + + Args: + uniq_id (str): Unique identifier for the audio file and corresponding RTTM file. + offset (float): The starting time offset for the segment of interest. + duration (float): The duration of the segment of interest. + rttm_lines (list): List of RTTM lines in string format. + round_digits (int, optional): Number of decimal places to round the start and end times. Defaults to 3. + + Returns: + rttm_mat (tuple): A tuple containing lists of start times, end times, and speaker labels. + sess_to_global_spkids (dict): A mapping from session-specific speaker indices to global speaker identifiers. + """ + rttm_stt, rttm_end = offset, offset + duration + stt_list, end_list, speaker_list, speaker_set = [], [], [], [] + sess_to_global_spkids = dict() + + for rttm_line in rttm_lines: + start, end, speaker = convert_rttm_line(rttm_line) + + # Skip invalid RTTM lines where the start time is greater than the end time. + if start > end: + continue + + # Check if the RTTM segment overlaps with the specified segment of interest. + if (end > rttm_stt and start < rttm_end) or (start < rttm_end and end > rttm_stt): + # Adjust the start and end times to fit within the segment of interest. + start, end = max(start, rttm_stt), min(end, rttm_end) + else: + continue + + # Round the start and end times to the specified number of decimal places. + end_list.append(round(end, round_digits)) + stt_list.append(round(start, round_digits)) + + # Assign a unique index to each speaker and maintain a mapping. 
+ if speaker not in speaker_set: + speaker_set.append(speaker) + speaker_list.append(speaker_set.index(speaker)) + sess_to_global_spkids.update({speaker_set.index(speaker): speaker}) + + rttm_mat = (stt_list, end_list, speaker_list) + return rttm_mat, sess_to_global_spkids + +def get_frame_targets_from_rttm( + rttm_timestamps: list, + offset: float, + duration: float, + round_digits: int, + feat_per_sec: int, + max_spks: int, + ): + """ + Create a multi-dimensional vector sequence containing speaker timestamp information in RTTM. + The unit-length is the frame shift length of the acoustic feature. The feature-level annotations + `feat_level_target` will later be converted to base-segment level diarization label. + + Args: + rttm_timestamps (list): + List containing start and end time for each speaker segment label. + stt_list, end_list and speaker_list are contained. + feat_per_sec (int): + Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module. + target_spks (tuple): + Speaker indices that are generated from combinations. If there are only one or two speakers, + only a single target_spks variable is generated. + + Returns: + feat_level_target (torch.tensor): + Tensor containing label for each feature level frame. + """ + stt_list, end_list, speaker_list = rttm_timestamps + sorted_speakers = sorted(list(set(speaker_list))) + total_fr_len = int(duration * feat_per_sec) + if len(sorted_speakers) > max_spks: + logging.warning(f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: {max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!") + feat_level_target = torch.zeros(total_fr_len, max_spks) + for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)): + if end < offset or stt > offset + duration: + continue + stt, end = max(offset, stt), min(offset + duration, end) + spk = spk_rttm_key + if spk < max_spks: + stt_fr, end_fr = int((stt - offset) * feat_per_sec), int((end - offset)* feat_per_sec) + feat_level_target[stt_fr:end_fr, spk] = 1 + return feat_level_target + + class _AudioMSDDTrainDataset(Dataset): """ Dataset class that loads a json file containing paths to audio files, @@ -338,7 +461,7 @@ def parse_rttm_for_ms_targets(self, sample): """ rttm_lines = open(sample.rttm_file).readlines() uniq_id = self.get_uniq_id_with_range(sample) - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, target_spks=sample.target_spks ) @@ -370,14 +493,14 @@ def get_uniq_id_with_range(self, sample, deci=3): def get_ms_seg_timestamps(self, sample): """ - Get start and end time of segments in each scale. + Get start and end time of each diarization frame. Args: sample: `DiarizationSpeechLabel` instance from preprocessing.collections Returns: ms_seg_timestamps (torch.tensor): - Tensor containing Multiscale segment timestamps. + Tensor containing timestamps for each frame. ms_seg_counts (torch.tensor): Number of segments for each scale. This information is used for reshaping embedding batch during forward propagation. 
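# Illustrative worked example (hypothetical, not part of this patch): the frame-level target
# construction performed by `get_frame_targets_from_rttm` above, rewritten as a standalone
# snippet so the indexing can be checked by hand. With `feat_per_sec=100`, a segment
# (0.5 s, 1.25 s) for speaker 0 marks frames 50..124.
import torch


def frame_targets(rttm_timestamps, offset, duration, feat_per_sec=100, max_spks=4):
    stt_list, end_list, spk_list = rttm_timestamps
    target = torch.zeros(int(duration * feat_per_sec), max_spks)
    for stt, end, spk in zip(stt_list, end_list, spk_list):
        if end < offset or stt > offset + duration or spk >= max_spks:
            continue
        stt, end = max(offset, stt), min(offset + duration, end)
        target[int((stt - offset) * feat_per_sec):int((end - offset) * feat_per_sec), spk] = 1
    return target

# Two overlapping speakers inside a 3-second window:
tgt = frame_targets(([0.5, 1.0], [1.25, 2.75], [0, 1]), offset=0.0, duration=3.0)
print(tgt.shape, tgt[:, 0].sum().item(), tgt[:, 1].sum().item())  # torch.Size([300, 4]) 75.0 175.0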
@@ -529,7 +652,7 @@ def parse_rttm_multiscale(self, sample): rttm_lines = open(sample.rttm_file).readlines() uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] mapping_dict = self.emb_dict[max(self.emb_dict.keys())][uniq_id]['mapping'] - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict, sample.target_spks) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines, mapping_dict, sample.target_spks) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, sample.target_spks ) @@ -851,3 +974,348 @@ def __init__( def msdd_infer_collate_fn(self, batch): return _msdd_infer_collate_fn(self, batch) + +class _AudioToSpeechE2ESpkDiarDataset(Dataset): + """ + Dataset class that loads a json file containing paths to audio files, + RTTM files and number of speakers. This Dataset class is designed for + training or fine-tuning speaker embedding extractor and diarization decoder + at the same time. + + Example: + {"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm} + ... + {"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm} + + Args: + manifest_filepath (str): + Path to input manifest json files. + multiargs_dict (dict): + Dictionary containing the parameters for multiscale segmentation and clustering. + soft_label_thres (float): + Threshold that determines the label of each segment based on RTTM file information. + featurizer: + Featurizer instance for generating audio_signal from the raw waveform. + window_stride (float): + Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports.""" + output_types = { + "audio_signal": NeuralType(('B', 'T'), AudioSignal()), + "audio_length": NeuralType(('B'), LengthsType()), + "targets": NeuralType(('B', 'T', 'C'), ProbsType()), + "target_len": NeuralType(('B', 'C'), LengthsType()), + } + + return output_types + + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride: float, + min_subsegment_duration: float = 0.03, + global_rank: int = 0, + dtype=torch.float16, + round_digits: int = 2, + soft_targets: bool = False, + subsampling_factor: int = 8, + ): + super().__init__() + self.collection = EndtoEndDiarizationSpeechLabel( + manifests_files=manifest_filepath.split(','), + round_digits=round_digits, + ) + self.featurizer = featurizer + self.round_digits = round_digits + self.feat_per_sec = int(1 / window_stride) + self.diar_frame_length = round(subsampling_factor * window_stride, round_digits) + self.session_len_sec = session_len_sec + self.soft_label_thres = soft_label_thres + self.max_spks = num_spks + self.min_subsegment_duration = min_subsegment_duration + self.dtype = dtype + self.use_asr_style_frame_count = True + self.soft_targets = soft_targets + self.round_digits = 2 + self.floor_decimal = 10 ** self.round_digits + + def __len__(self): + return len(self.collection) + + def get_uniq_id_with_range(self, sample, deci=3): + """ + Generate unique training sample ID from unique file ID, offset and duration. The start-end time added + unique ID is required for identifying the sample since multiple short audio samples are generated from a single + audio file. 
The start time and end time of the audio stream uses millisecond units if `deci=3`. + + Args: + sample: + `DiarizationSpeechLabel` instance from collections. + + Returns: + uniq_id (str): + Unique sample ID which includes start and end time of the audio stream. + Example: abc1001_3122_6458 + """ + bare_uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] + offset = str(int(round(sample.offset, deci) * pow(10, deci))) + endtime = str(int(round(sample.offset + sample.duration, deci) * pow(10, deci))) + uniq_id = f"{bare_uniq_id}_{offset}_{endtime}" + return uniq_id + + def parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, target_len): + """ + Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file. + This function converts (start, end, speaker_id) format into base-scale (the finest scale) segment level + diarization label in a matrix form. + + Example of seg_target: + [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] + """ + rttm_lines = open(rttm_file).readlines() + rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines) + + fr_level_target = get_frame_targets_from_rttm(rttm_timestamps=rttm_timestamps, + offset=offset, + duration=duration, + round_digits=self.round_digits, + feat_per_sec=self.feat_per_sec, + max_spks=self.max_spks) + + soft_target_seg = self.get_soft_targets_seg(feat_level_target=fr_level_target, + target_len=target_len) + if self.soft_targets: + step_target = soft_target_seg + else: + step_target = (soft_target_seg >= self.soft_label_thres).float() + return step_target + + def get_soft_targets_seg(self, feat_level_target, target_len): + """ + Generate the final targets for the actual diarization step. + Here, frame level means step level which is also referred to as segments. + We follow the original paper and refer to the step level as "frames". + + Args: + feat_level_target (torch.tensor): + Tensor variable containing hard-labels of speaker activity in each feature-level segment. + target_len (torch.tensor): + Numbers of ms segments + + Returns: + soft_target_seg (torch.tensor): + Tensor variable containing soft-labels of speaker activity in each step-level segment. + """ + num_seg = torch.max(target_len) + targets = torch.zeros(num_seg, self.max_spks) + stride = int(self.feat_per_sec * self.diar_frame_length) + for index in range(num_seg): + if index == 0: + seg_stt_feat = 0 + else: + seg_stt_feat = stride * index - 1 - int(stride / 2) + if index == num_seg - 1: + seg_end_feat = feat_level_target.shape[0] + else: + seg_end_feat = stride * index - 1 + int(stride / 2) + targets[index] = torch.mean(feat_level_target[seg_stt_feat:seg_end_feat+1, :], axis=0) + return targets + + def get_segment_timestamps( + self, + duration: float, + offset: float = 0, + sample_rate: int = 16000, + ): + """ + Get start and end time of segments in each scale. + + Args: + sample: + `DiarizationSpeechLabel` instance from preprocessing.collections + Returns: + segment_timestamps (torch.tensor): + Tensor containing Multiscale segment timestamps. + target_len (torch.tensor): + Number of segments for each scale. This information is used for reshaping embedding batch + during forward propagation. 
+ """ + subsegments = get_subsegments(offset=offset, + window=round(self.diar_frame_length * 2, self.round_digits), + shift=self.diar_frame_length, + duration=duration, + min_subsegment_duration=self.min_subsegment_duration, + use_asr_style_frame_count=self.use_asr_style_frame_count, + sample_rate=sample_rate, + feat_per_sec=self.feat_per_sec, + ) + if self.use_asr_style_frame_count: + effective_dur = np.ceil((1+duration*sample_rate)/int(sample_rate/self.feat_per_sec)).astype(int)/self.feat_per_sec + else: + effective_dur = duration + ts_tensor = get_subsegments_to_timestamps(subsegments, self.feat_per_sec, decimals=2, max_end_ts=(offset+effective_dur)) + target_len = torch.tensor([ts_tensor.shape[0]]) + return target_len + + def __getitem__(self, index): + sample = self.collection[index] + if sample.offset is None: + sample.offset = 0 + offset = sample.offset + if self.session_len_sec < 0: + session_len_sec = sample.duration + else: + session_len_sec = min(sample.duration, self.session_len_sec) + + uniq_id = self.get_uniq_id_with_range(sample) + audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) + + # We should resolve the length mis-match from the round-off errors: `session_len_sec` and `audio_signal.shape[0]` + session_len_sec = np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal)/self.floor_decimal + audio_signal = audio_signal[:round(self.featurizer.sample_rate*session_len_sec)] + + audio_signal_length = torch.tensor(audio_signal.shape[0]).long() + audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') + target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) + targets = self.parse_rttm_for_targets_and_lens(uniq_id=uniq_id, + rttm_file=sample.rttm_file, + offset=offset, + duration=session_len_sec, + target_len=target_len) + return audio_signal, audio_signal_length, targets, target_len + +def _eesd_train_collate_fn(self, batch): + """ + Collate a batch of variables needed for training the end-to-end speaker diarization (EESD) model + from raw waveforms to diarization labels. The following variables are included in the training/validation batch: + + Args: + batch (tuple): + A tuple containing the variables for diarization training. + + Returns: + audio_signal (torch.Tensor): + A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` in the input manifest file. + feature_length (torch.Tensor): + A tensor containing the lengths of the raw waveform samples. + targets (torch.Tensor): + Groundtruth speaker labels for the given input embedding sequence. + target_lens (torch.Tensor): + A tensor containing the number of segments for each sample in the batch, necessary for reshaping inputs to the EESD model. 
+ """ + packed_batch = list(zip(*batch)) + audio_signal, feature_length, targets, target_len = packed_batch + audio_signal_list, feature_length_list = [], [] + target_len_list, targets_list = [], [] + + max_raw_feat_len = max([x.shape[0] for x in audio_signal]) + max_target_len = max([x.shape[0] for x in targets]) + if max([len(feat.shape) for feat in audio_signal]) > 1: + max_ch = max([feat.shape[1] for feat in audio_signal]) + else: + max_ch = 1 + for feat, feat_len, tgt, segment_ct in batch: + seq_len = tgt.shape[0] + if len(feat.shape) > 1: + pad_feat = (0, 0, 0, max_raw_feat_len - feat.shape[0]) + else: + pad_feat = (0, max_raw_feat_len - feat.shape[0]) + if feat.shape[0] < feat_len: + feat_len_pad = feat_len - feat.shape[0] + feat = torch.nn.functional.pad(feat, (0, feat_len_pad)) + pad_tgt = (0, 0, 0, max_target_len - seq_len) + padded_feat = torch.nn.functional.pad(feat, pad_feat) + padded_tgt = torch.nn.functional.pad(tgt, pad_tgt) + if max_ch > 1 and padded_feat.shape[1] < max_ch: + feat_ch_pad = max_ch - padded_feat.shape[1] + padded_feat = torch.nn.functional.pad(padded_feat, (0, feat_ch_pad)) + audio_signal_list.append(padded_feat) + feature_length_list.append(feat_len.clone().detach()) + target_len_list.append(segment_ct.clone().detach()) + targets_list.append(padded_tgt) + audio_signal = torch.stack(audio_signal_list) + feature_length = torch.stack(feature_length_list) + target_lens = torch.stack(target_len_list) + targets = torch.stack(targets_list) + return audio_signal, feature_length, targets, target_lens + +class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): + """ + Dataset class for loading a JSON file containing paths to audio files, + RTTM (Rich Transcription Time Marked) files, and the number of speakers. + This class is designed for training or fine-tuning a speaker embedding + extractor and diarization decoder simultaneously. + + The JSON manifest file should have entries in the following format: + + Example: + { + "audio_filepath": "/path/to/audio_0.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm" + } + ... + { + "audio_filepath": "/path/to/audio_n.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm" + } + + Args: + manifest_filepath (str): + Path to the input manifest JSON file containing paths to audio and RTTM files. + soft_label_thres (float): + Threshold for assigning soft labels to segments based on RTTM file information. + session_len_sec (float): + Duration of each session (in seconds) for training or fine-tuning. + num_spks (int): + Number of speakers in the audio files. + featurizer: + Instance of a featurizer for generating features from the raw waveform. + window_stride (float): + Window stride (in seconds) for extracting acoustic features, used to calculate + the number of feature frames. + global_rank (int): + Global rank of the current process (used for distributed training). + soft_targets (bool): + Whether or not to use soft targets during training. + + Methods: + eesd_train_collate_fn(batch): + Collates a batch of data for end-to-end speaker diarization training. 
+ """ + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride, + global_rank: int, + soft_targets: bool, + ): + super().__init__( + manifest_filepath=manifest_filepath, + soft_label_thres=soft_label_thres, + session_len_sec=session_len_sec, + num_spks=num_spks, + featurizer=featurizer, + window_stride=window_stride, + global_rank=global_rank, + soft_targets=soft_targets, + ) + + def eesd_train_collate_fn(self, batch): + return _eesd_train_collate_fn(self, batch) \ No newline at end of file diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py new file mode 100644 index 000000000000..e223e4ef2a56 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple + +import torch.utils.data +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_matrices + +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( + speaker_to_target, + get_hidden_length_from_sample_length, +) + +class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): + """ + This dataset is based on diarization datasets from audio_to_eesd_label.py. + Unlike native NeMo datasets, Lhotse dataset defines only the mapping from + a CutSet (meta-data) to a mini-batch with PyTorch tensors. + Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any). + Managing data, sampling, de-duplication across workers/nodes etc. is all handled + by Lhotse samplers instead. 
+ """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'targets': NeuralType(('B', 'T', 'N'), LabelsType()), + 'target_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + } + + def __init__(self, cfg): + super().__init__() + self.load_audio = AudioSamples(fault_tolerant=True) + self.cfg = cfg + self.num_speakers = self.cfg.get('num_speakers', 4) + self.num_sample_per_mel_frame = int(self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000)) # 160 + self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) + self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero',False) + + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: + audio, audio_lens, cuts = self.load_audio(cuts) + speaker_activities = [] + for cut in cuts: + speaker_activity = speaker_to_target( + a_cut=cut, + num_speakers=self.num_speakers, + num_sample_per_mel_frame=self.num_sample_per_mel_frame, + num_mel_frame_per_asr_frame=self.num_mel_frame_per_target_frame, + spk_tar_all_zero=self.spk_tar_all_zero, + boundary_segments=True + ) + speaker_activities.append(speaker_activity) + targets = collate_matrices(speaker_activities).to(audio.dtype) + target_lens_list = [] + for audio_len in audio_lens: + target_fr_len = get_hidden_length_from_sample_length(audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame) + target_lens_list.append([target_fr_len]) + target_lens = torch.tensor(target_lens_list) + + return audio, audio_lens, targets, target_lens diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py new file mode 100644 index 000000000000..c389f0eb627f --- /dev/null +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -0,0 +1,565 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
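# Illustrative sketch (hypothetical helper, not part of this patch): the sample -> mel-frame ->
# diarization-frame arithmetic that `LhotseAudioToSpeechE2ESpkDiarDataset` above relies on, with
# window_stride=0.01 s, sample_rate=16 kHz and subsampling_factor=8 (one diarization frame per
# 80 ms). The ASR-style mel-frame count mirrors the `use_asr_style_frame_count` computation in
# audio_to_diar_label.py; the exact rounding used by `get_hidden_length_from_sample_length` is
# an assumption here, not a re-implementation of that function.
import math


def approx_target_frame_count(num_samples: int,
                              num_sample_per_mel_frame: int = 160,
                              num_mel_frame_per_target_frame: int = 8) -> int:
    mel_frames = math.ceil((num_samples + 1) / num_sample_per_mel_frame)  # 10 ms mel frames
    return math.ceil(mel_frames / num_mel_frame_per_target_frame)         # 80 ms diarization frames

# A 90-second 16 kHz session: 1,440,000 samples -> roughly 1126 diarization frames of 80 ms each.
print(approx_target_frame_count(90 * 16000))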
+ +import time +import itertools +import random +import torch +from collections import OrderedDict +from typing import Dict, List, Optional, Union +from hydra.utils import instantiate +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from tqdm import tqdm +from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.core.neural_types.elements import ProbsType +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy +from nemo.collections.asr.models.asr_model import ExportableEncDecModel +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_pil_targets, get_ats_targets +from nemo.utils import logging + +try: + from torch.cuda.amp import autocast +except ImportError: + from contextlib import contextmanager + + @contextmanager + def autocast(enabled=None): + yield + +# torch.backends.cudnn.enabled = False + +__all__ = ['SortformerEncLabelModel'] + +class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): + """ + Encoder class for Sortformer diarization model. + Model class creates training, validation methods for setting up data performing model forward pass. + + This model class expects config dict for: + * preprocessor + * Transformer Encoder + * FastConformer Encoder + * Sortformer Modules + """ + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + return result + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + """ + Initialize an Sortformer Diarizer model and a pretrained NEST encoder. + In this init function, training and validation datasets are prepared. 
+ """ + random.seed(42) + self._trainer = trainer if trainer else None + self._cfg = cfg + + if self._trainer: + self.world_size = trainer.num_nodes * trainer.num_devices + else: + self.world_size = 1 + + if self._trainer is not None and self._cfg.get('augmentor', None) is not None: + self.augmentor = process_augmentations(self._cfg.augmentor) + else: + self.augmentor = None + super().__init__(cfg=self._cfg, trainer=trainer) + self.preprocessor = SortformerEncLabelModel.from_config_dict(self._cfg.preprocessor) + + if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None: + self.spec_augmentation = SortformerEncLabelModel.from_config_dict(self._cfg.spec_augment) + else: + self.spec_augmentation = None + + self.encoder = SortformerEncLabelModel.from_config_dict(self._cfg.encoder) + self.sortformer_modules = SortformerEncLabelModel.from_config_dict(self._cfg.sortformer_modules) + self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder) + self._init_loss_weights() + + self.eps = 1e-3 + self.loss = instantiate(self._cfg.loss) + + self.streaming_mode = self._cfg.get("streaming_mode", False) + self.save_hyperparameters("cfg") + self._init_eval_metrics() + + speaker_inds = list(range(self._cfg.max_num_of_spks)) + self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations + + def _init_loss_weights(self): + pil_weight = self._cfg.get("pil_weight", 0.0) + ats_weight = self._cfg.get("ats_weight", 1.0) + if pil_weight + ats_weight == 0: + raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0") + self.pil_weight = pil_weight/(pil_weight + ats_weight) + self.ats_weight = ats_weight/(pil_weight + ats_weight) + logging.info(f"Normalized weights for PIL {self.pil_weight} and ATS {self.ats_weight}") + + def _init_eval_metrics(self): + """ + If there is no label, then the evaluation metrics will be based on Permutation Invariant Loss (PIL). + """ + self._accuracy_test = MultiBinaryAccuracy() + self._accuracy_train = MultiBinaryAccuracy() + self._accuracy_valid = MultiBinaryAccuracy() + + self._accuracy_test_ats = MultiBinaryAccuracy() + self._accuracy_train_ats = MultiBinaryAccuracy() + self._accuracy_valid_ats = MultiBinaryAccuracy() + + def _reset_train_metrics(self): + self._accuracy_train.reset() + self._accuracy_train_ats.reset() + + def _reset_valid_metrics(self): + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + def __setup_dataloader_from_config(self, config): + # Switch to lhotse dataloader if specified in the config + if config.get("use_lhotse"): + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), + ) + + featurizer = WaveformFeaturizer( + sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=self.augmentor + ) + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + logging.info(f"Loading dataset from {config.manifest_filepath}") + + if self._trainer is not None: + global_rank = self._trainer.global_rank + else: + global_rank = 0 + time_flag = time.time() + logging.info("AAB: Starting Dataloader Instance loading... 
Step A") + + dataset = AudioToSpeechE2ESpkDiarDataset( + manifest_filepath=config.manifest_filepath, + soft_label_thres=config.soft_label_thres, + session_len_sec=config.session_len_sec, + num_spks=config.num_spks, + featurizer=featurizer, + window_stride=self._cfg.preprocessor.window_stride, + global_rank=global_rank, + soft_targets=config.soft_targets if 'soft_targets' in config else False, + ) + logging.info(f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader step B: {time.time() - time_flag}") + + self.data_collection = dataset.collection + self.collate_ds = dataset + + dataloader_instance = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config.batch_size, + collate_fn=self.collate_ds.eesd_train_collate_fn, + drop_last=config.get('drop_last', False), + shuffle=False, + num_workers=config.get('num_workers', 1), + pin_memory=config.get('pin_memory', False), + ) + logging.info(f"AAC: Dataloader Instance loading is done ETA Step B done: {time.time() - time_flag}") + return dataloader_instance + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,) + + def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): + self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + self._test_dl = self.__setup_dataloader_from_config(config=test_data_config,) + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + audio_eltype = AudioSignal() + return { + "audio_signal": NeuralType(('B', 'T'), audio_eltype), + "audio_signal_length": NeuralType(('B',), LengthsType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return OrderedDict( + { + "preds": NeuralType(('B', 'T', 'C'), ProbsType()), + } + ) + + def frontend_encoder(self, processed_signal, processed_signal_length): + """ + Generate encoder outputs from frontend encoder. + + Args: + process_signal (torch.Tensor): tensor containing audio-feature (mel spectrogram, mfcc, etc.) + processed_signal_length (torch.Tensor): tensor containing lengths of audio signal in integers + + Returns: + emb_seq (torch.Tensor): tensor containing encoder outputs + emb_seq_length (torch.Tensor): tensor containing lengths of encoder outputs + """ + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + self.encoder = self.encoder.to(self.device) + emb_seq, emb_seq_length = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + emb_seq = emb_seq.transpose(1, 2) + if self._cfg.encoder.d_model != self._cfg.tf_d_model: + self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) + emb_seq = self.sortformer_modules.encoder_proj(emb_seq) + return emb_seq, emb_seq_length + + def forward_infer(self, emb_seq): + """ + The main forward pass for diarization for offline diarization inference. + + Args: + emb_seq (torch.Tensor): tensor containing FastConformer encoder states (embedding vectors). 
+ Dimension: (batch_size, diar_frame_count, emb_dim) + + Returns: + preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. + Dimension: (batch_size, diar_frame_count, num_speakers) + encoder_states_list (list): List containing total speaker memory for each step for debugging purposes + Dimension: [(batch_size, diar_frame_count, inner dim), ... ] + """ + encoder_mask = self.sortformer_modules.length_to_mask(emb_seq) + trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask) + preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) + return preds + + def process_signal(self, audio_signal, audio_signal_length): + """ + Extract audio features from time-series signal for further processing in the model. + + This function performs the following steps: + 1. Moves the audio signal to the correct device. + 2. Normalizes the time-series audio signal. + 3. Extrac audio feature from from the time-series audio signal using the model's preprocessor. + + Args: + audio_signal (torch.Tensor): The input audio signal. + Shape: (batch_size, num_samples) + audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + Shape: (batch_size,) + + Returns: + tuple: A tuple containing: + - processed_signal (torch.Tensor): The preprocessed audio signal. + Shape: (batch_size, num_features, num_frames) + - processed_signal_length (torch.Tensor): The length of each processed signal. + Shape: (batch_size,) + """ + audio_signal = audio_signal.to(self.device) + audio_signal = (1/(audio_signal.max()+self.eps)) * audio_signal + processed_signal, processed_signal_length = self.preprocessor(input_signal=audio_signal, length=audio_signal_length) + return processed_signal, processed_signal_length + + def forward( + self, + audio_signal, + audio_signal_length, + ): + """ + Forward pass for training and inference. + + Args: + audio_signal (torch.Tensor): tensor containing audio waveform + Dimension: (batch_size, num_samples) + audio_signal_length (torch.Tensor): tensor containing lengths of audio waveforms + Dimension: (batch_size,) + + Returns: + preds (torch.Tensor): Sorted tensor containing predicted speaker labels + Dimension: (batch_size, diar_frame_count, num_speakers) + encoder_states_list (list): List containing total speaker memory for each step for debugging purposes + Dimension: [(batch_size, diar_frame_count, inner dim), ] + """ + processed_signal, processed_signal_length = self.process_signal(audio_signal=audio_signal, audio_signal_length=audio_signal_length) + processed_signal = processed_signal[:, :, :processed_signal_length.max()] + if self._cfg.get("streaming_mode", False): + raise NotImplementedError("Streaming mode is not implemented yet.") + else: + emb_seq, _ = self.frontend_encoder(processed_signal=processed_signal, processed_signal_length=processed_signal_length) + preds = self.forward_infer(emb_seq) + return preds + + def _get_aux_train_evaluations(self, preds, targets, target_lens): + """ + Compute auxiliary training evaluations including losses and metrics. + + This function calculates various losses and metrics for the training process, + including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. 
+ Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + (dict): A dictionary containing the following training metrics. + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + loss = self.ats_weight * ats_loss + self.pil_weight * pil_loss + + self._accuracy_train(preds, targets_pil, target_lens) + train_f1_acc, train_precision, train_recall = self._accuracy_train.compute() + + self._accuracy_train_ats(preds, targets_ats, target_lens) + train_f1_acc_ats, _, _ = self._accuracy_train_ats.compute() + + train_metrics = { + 'loss': loss, + 'ats_loss': ats_loss, + 'pil_loss': pil_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'train_f1_acc': train_f1_acc, + 'train_precision': train_precision, + 'train_recall': train_recall, + 'train_f1_acc_ats': train_f1_acc_ats, + } + return train_metrics + + def training_step(self, batch: list) -> dict: + """ + Performs a single training step. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal in time-series format. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. + + Returns: + (dict): A dictionary containing the 'loss' key with the calculated loss value. + """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward(audio_signal=audio_signal, audio_signal_length=audio_signal_length) + train_metrics = self._get_aux_train_evaluations(preds, targets, target_lens) + self._reset_train_metrics() + self.log_dict(train_metrics, sync_dist=True, on_step=True, on_epoch=False, logger=True) + return {'loss': train_metrics['loss']} + + def _get_aux_validation_evaluations(self, preds, targets, target_lens): + """ + Compute auxiliary validation evaluations including losses and metrics. + This function calculates various losses and metrics for the validation process, + including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. 
+ Shape: (batch_size,) + + Returns: + dict: A dictionary containing the following validation metrics + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + + val_ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + val_pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + val_loss = self.ats_weight * val_ats_loss + self.pil_weight * val_pil_loss + + self._accuracy_valid(preds, targets_pil, target_lens) + val_f1_acc, val_precision, val_recall = self._accuracy_valid.compute() + + self._accuracy_valid_ats(preds, targets_ats, target_lens) + valid_f1_acc_ats, _, _ = self._accuracy_valid_ats.compute() + + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + val_metrics = { + 'val_loss': val_loss, + 'val_ats_loss': val_ats_loss, + 'val_pil_loss': val_pil_loss, + 'val_f1_acc': val_f1_acc, + 'val_precision': val_precision, + 'val_recall': val_recall, + 'val_f1_acc_ats': valid_f1_acc_ats, + } + return val_metrics + + def validation_step(self, batch: list, dataloader_idx: int = 0): + """ + Performs a single validation step. + + This method processes a batch of data during the validation phase. It forward passes + the audio signal through the model, computes various validation metrics, and stores + these metrics for later aggregation. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. + dataloader_idx (int, optional): The index of the dataloader in case of multiple + validation dataloaders. Defaults to 0. + + Returns: + dict: A dictionary containing various validation metrics for this batch. 
+ """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + val_metrics = self._get_aux_validation_evaluations(preds, targets, target_lens) + if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(val_metrics) + else: + self.validation_step_outputs.append(val_metrics) + return val_metrics + + def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): + if not outputs: + logging.warning(f"`outputs` is None; empty outputs for dataloader={dataloader_idx}") + return None + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + val_ats_loss_mean = torch.stack([x['val_ats_loss'] for x in outputs]).mean() + val_pil_loss_mean = torch.stack([x['val_pil_loss'] for x in outputs]).mean() + val_f1_acc_mean = torch.stack([x['val_f1_acc'] for x in outputs]).mean() + val_precision_mean = torch.stack([x['val_precision'] for x in outputs]).mean() + val_recall_mean = torch.stack([x['val_recall'] for x in outputs]).mean() + val_f1_acc_ats_mean = torch.stack([x['val_f1_acc_ats'] for x in outputs]).mean() + + self._reset_valid_metrics() + + multi_val_metrics = { + 'val_loss': val_loss_mean, + 'val_ats_loss': val_ats_loss_mean, + 'val_pil_loss': val_pil_loss_mean, + 'val_f1_acc': val_f1_acc_mean, + 'val_precision': val_precision_mean, + 'val_recall': val_recall_mean, + 'val_f1_acc_ats': val_f1_acc_ats_mean, + } + return {'log': multi_val_metrics} + + def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target_lens): + """ + Compute auxiliary validation evaluations including losses and metrics. + + This function calculates various losses and metrics for the validation process, + including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + dict: A dictionary containing the following validation metrics + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + self._accuracy_test(preds, targets_pil, target_lens) + f1_acc, precision, recall = self._accuracy_test.compute() + self.batch_f1_accs_list.append(f1_acc) + self.batch_precision_list.append(precision) + self.batch_recall_list.append(recall) + logging.info(f"batch {batch_idx}: f1_acc={f1_acc}, precision={precision}, recall={recall}") + + self._accuracy_test_ats(preds, targets_ats, target_lens) + f1_acc_ats, precision_ats, recall_ats = self._accuracy_test_ats.compute() + self.batch_f1_accs_ats_list.append(f1_acc_ats) + logging.info(f"batch {batch_idx}: f1_acc_ats={f1_acc_ats}, precision_ats={precision_ats}, recall_ats={recall_ats}") + + self._accuracy_test.reset() + self._accuracy_test_ats.reset() + + def test_batch(self,): + """ + Perform batch testing on the model. + + This method iterates through the test data loader, making predictions for each batch, + and calculates various evaluation metrics. It handles both single and multi-sample batches. 
+ """ + self.preds_total_list, self.batch_f1_accs_list, self.batch_precision_list, self.batch_recall_list, self.batch_f1_accs_ats_list = [], [], [], [], [] + + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(self._test_dl)): + audio_signal, audio_signal_length, targets, target_lens = batch + audio_signal = audio_signal.to(self.device) + audio_signal_length = audio_signal_length.to(self.device) + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + preds = preds.detach().to('cpu') + if preds.shape[0] == 1: # batch size = 1 + self.preds_total_list.append(preds) + else: + self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) + torch.cuda.empty_cache() + self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) + + logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") + logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") + logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") + logging.info(f"Batch ATS F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_ats_list))}") + + def diarize(self,): + raise NotImplementedError diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py new file mode 100644 index 000000000000..a1d34e1f7480 --- /dev/null +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -0,0 +1,1231 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import copy +import math +import random +import logging +import itertools +from copy import deepcopy +import concurrent.futures +from cytoolz import groupby +from collections import defaultdict +from typing import Dict, Optional, Tuple, List + +import numpy as np +import soundfile +from tqdm import tqdm +from scipy.stats import norm + +import torch.utils.data +from lhotse.cut.set import mix +from lhotse.cut import CutSet, MixedCut, MonoCut, MixTrack +from lhotse import SupervisionSet, SupervisionSegment, dill_enabled, AudioSource, Recording +from lhotse.utils import uuid4 + +def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres:float = 0.5) -> torch.Tensor: + """ + Finds the first nonzero value in the matrix, discretizing it to the specified maximum capacity. + + Args: + mat (Tensor): A torch tensor representing the matrix. + max_cap_val (int): The maximum capacity to which the matrix values will be discretized. + thres (float): The threshold value for discretizing the matrix values. + + Returns: + mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first nonzero value in each row. 
+ """ + # Discretize the matrix to the specified maximum capacity + labels_discrete = mat.clone() + labels_discrete[labels_discrete < thres] = 0 + labels_discrete[labels_discrete >= thres] = 1 + + # non zero values mask + non_zero_mask = labels_discrete != 0 + # operations on the mask to find first nonzero values in the rows + mask_max_values, mask_max_indices = torch.max(non_zero_mask, dim=1) + # if the max-mask is zero, there is no nonzero value in the row + mask_max_indices[mask_max_values == 0] = max_cap_val + return mask_max_indices + +def find_best_permutation(match_score: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Finds the best permutation indices based on the match score. + + Args: + match_score (torch.Tensor): A tensor containing the match scores for each permutation. + Shape: (batch_size, num_permutations) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + """ + batch_best_perm = torch.argmax(match_score, axis=1) + rep_speaker_permutations = speaker_permutations.repeat(batch_best_perm.shape[0], 1).to(match_score.device) + perm_size = speaker_permutations.shape[0] + global_inds_vec = torch.arange(0, perm_size * batch_best_perm.shape[0], perm_size).to(batch_best_perm.device) + batch_best_perm + return rep_speaker_permutations[global_inds_vec.to(rep_speaker_permutations.device), :] + +def reconstruct_labels(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: + """ + Reconstructs the labels using the best permutation indices with matrix operations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + batch_perm_inds (torch.Tensor): A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. + Shape: (batch_size, num_frames, num_speakers) + """ + # Expanding batch_perm_inds to align with labels dimensions + batch_size, num_frames, num_speakers = labels.shape + batch_perm_inds_exp = batch_perm_inds.unsqueeze(1).expand(-1, num_frames, -1) + + # Reconstructing the labels using advanced indexing + reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) + return reconstructed_labels + +def get_ats_targets( + labels: torch.Tensor, + preds: torch.Tensor, + speaker_permutations: torch.Tensor, + thres: float = 0.5, + tolerance: float = 0 +) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal of all arrival-time ordered permutations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + preds (torch.Tensor): A tensor containing the predictions. + Shape: (batch_size, num_frames, num_speakers) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + thres (float): The threshold value for discretizing the matrix values. Default is 0.5. + tolerance (float): The tolerance for comparing the first speech frame indices. Default is 0. + + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. 
+ Shape: (batch_size, num_frames, num_speakers) + """ + # Find the first nonzero frame index for each speaker in each batch + nonzero_ind = find_first_nonzero(mat=labels, max_cap_val=labels.shape[1], thres=thres) # (batch_size, num_speakers) + + # Sort the first nonzero frame indices for arrival-time ordering + sorted_values = torch.sort(nonzero_ind)[0] # (batch_size, num_speakers) + perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_frames, num_permutations, num_speakers) + permed_nonzero_ind = find_first_nonzero(mat=permed_labels, max_cap_val=labels.shape[1]) # (batch_size, num_permutations, num_speakers) + + # Compare the first frame indices of sorted labels with those of the permuted labels using tolerance + perm_compare = torch.abs(sorted_values.unsqueeze(1) - permed_nonzero_ind) <= tolerance # (batch_size, num_permutations, num_speakers) + perm_mask = torch.all(perm_compare, dim=2).float() # (batch_size, num_permutations) + preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, perm_size, 1) # Exapnd the preds: (batch_size, num_frames, num_permutations, num_speakers) + + # Compute the match score for each permutation by comparing permuted labels with preds + match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) * perm_mask # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_frames, num_speakers) + return max_score_permed_labels # (batch_size, num_frames, num_speakers) + +def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal permutation based on the match score. + + Args: + labels (torch.Tensor): A tensor containing the ground truth labels. + Shape: (batch_size, num_speakers, num_classes) + preds (torch.Tensor): A tensor containing the predicted values. + Shape: (batch_size, num_speakers, num_classes) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor of permuted labels that best match the predictions. + Shape: (batch_size, num_speakers, num_classes) + """ + perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_classes, num_permutations, num_speakers) + # Repeat preds to match permutations for comparison + preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, speaker_permutations.shape[0], 1) # (batch_size, num_speakers, num_permutations, num_classes) + match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + # Reconstruct labels based on the best permutation for each batch + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) + return max_score_permed_labels # (batch_size, num_speakers, num_classes) + +def apply_spk_mapping(diar_preds: torch.Tensor, spk_mappings: torch.Tensor) -> torch.Tensor: + """ + Applies a speaker mapping to diar predictions. + + Args: + diar_preds (Tensor): The diar predictions tensor. 
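# Illustrative sketch (not from the patch): the speaker_permutations tensor expected by
# get_ats_targets / get_pil_targets above is simply every ordering of the speaker indices;
# building it with itertools.permutations is an assumption about the calling code.
import itertools
import torch

num_speakers = 3
speaker_permutations = torch.tensor(list(itertools.permutations(range(num_speakers))))  # (6, 3)
labels = torch.rand(2, 5, num_speakers)             # (batch, frames, speakers), toy values
permed_labels = labels[:, :, speaker_permutations]  # (batch, frames, num_permutations, speakers)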
+ Dimension: (batch_size, num_frames, num_speakers) + spk_mappings (Tensor): The speaker mappings tensor. + Dimension: (batch_size, num_speakers) + + Returns: + permuted_diar_preds (Tensor): The permuted diar predictions tensor with the given speaker mappings. + """ + expanded_mappings = spk_mappings.unsqueeze(1).expand(-1, diar_preds.size(1), -1) + permuted_diar_preds = torch.gather(diar_preds, 2, expanded_mappings) + return permuted_diar_preds + +def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool = False, pattern= r'<\|spltoken\d+\|>') -> Tuple[CutSet, torch.Tensor]: + """ + Applies a shuffle mapping to speaker text labels in the cuts. + Example: + Original cut.text: + "<|spltoken0|> we do shuffle <|spltoken1|> and map speakers <|spltoken0|> yes <|spltoken2|> we keep dimensions" + Speaker Mapping: [3, 0, 1, 2] + Shuffled cut.text: + "<|spltoken3|> we do shuffle <|spltoken0|> and map speakers <|spltoken3|> yes <|spltoken1|> we keep dimensions" + + Args: + cuts (List[MonoCut, MixedCut]): A list of Cut instances. + num_speakers (int): The total number of speakers. + shuffle_spk_mapping (bool): Whether to shuffle the speaker mappings. + pattern (str): A regular expression pattern for speaker tokens. + + Returns: + cuts (list): The updated CutSet with shuffled speaker mappings. + spk_mappings (Tensor): + If shuffle_speaker_mapping is True, shuffled speaker mappings in batch. + If shuffle_speaker_mapping is False, speaker mappings in batch is not permuted and returns torch.arange() values. + """ + batch_size = len(cuts) + if shuffle_spk_mapping: + permuted_indices = torch.rand(batch_size, num_speakers).argsort(dim=1) + spk_mappings = torch.gather(torch.arange(num_speakers).repeat(batch_size, 1), 1, permuted_indices) + str_pattern = pattern.replace("\\", '') + left_str, right_str = str_pattern.split('d+')[0], str_pattern.split('d+')[1] + for idx, cut in enumerate(cuts): + word_list = [] + for word in deepcopy(cut.text).split(): + if len(re.findall(pattern, word)) > 0: + spk_token_int = int(word.replace(left_str,'').replace(right_str, '')) + new_spk = spk_mappings[idx][spk_token_int] + word_list.append(f'{left_str}{new_spk}{right_str}') + else: + word_list.append(word) + cuts[idx].supervisions[0].text = ' '.join(word_list) + else: + spk_mappings = torch.arange(num_speakers).unsqueeze(0).repeat(batch_size, 1) + return cuts, spk_mappings + +def find_segments_from_rttm( + recording_id: str, + rttms, + start_after: float, + end_before: float, + adjust_offset: bool=True, + tolerance: float=0.001): + """ + Finds segments from the given rttm file. + This function is designed to replace rttm + + Args: + recording_id (str): The recording ID in string format. + rttms (SupervisionSet): The SupervisionSet instance. + start_after (float): The start time after which segments are selected. + end_before (float): The end time before which segments are selected. + adjust_offset (bool): Whether to adjust the offset of the segments. + tolerance (float): The tolerance for time matching. 0.001 by default. + + Returns: + segments (List[SupervisionSegment]): A list of SupervisionSegment instances. + """ + segment_by_recording_id = rttms._segments_by_recording_id + if segment_by_recording_id is None: + from cytoolz import groupby + segment_by_recording_id = groupby(lambda seg: seg.recording_id, rttms) + + return [ + # We only modify the offset - the duration remains the same, as we're only shifting the segment + # relative to the Cut's start, and not truncating anything. 
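# Illustrative sketch (not from the patch): the token rewrite that shuffle_spk_mapping above
# applies to one transcript, using the mapping [3, 0, 1, 2] from its docstring example.
import re

pattern = r'<\|spltoken\d+\|>'
mapping = [3, 0, 1, 2]
text = "<|spltoken0|> we do shuffle <|spltoken1|> and map speakers"

def remap(word: str) -> str:
    idx = int(re.search(r'\d+', word).group())
    return f"<|spltoken{mapping[idx]}|>"

shuffled = " ".join(remap(w) if re.fullmatch(pattern, w) else w for w in text.split())
# shuffled -> "<|spltoken3|> we do shuffle <|spltoken0|> and map speakers"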
+ segment.with_offset(-start_after) if adjust_offset else segment + for segment in segment_by_recording_id.get(recording_id, []) + if segment.start < end_before + tolerance + and segment.end > start_after + tolerance + ] + +def speaker_to_target( + a_cut, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, + spk_tar_all_zero: bool = False, + boundary_segments: bool = False, + soft_label: bool = False, + ignore_num_spk_mismatch: bool = True, + soft_thres: float = 0.5, + ): + ''' + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) + This function is needed for speaker diarization with ASR model trainings. + + Args: + a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. + num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): encoder subsampling_factor, 8 by default + spk_tar_all_zero (Tensor): set to True gives all zero "mask" + boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training + soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. + + Returns: + mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) + ''' + # get cut-related segments from rttms + # basename = os.path.basename(a_cut.rttm_filepath).replace('.rttm', '') + if isinstance(a_cut, MixedCut): + cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + elif isinstance(a_cut, MonoCut): + cut_list = [a_cut] + offsets = [0] + else: + raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") + + segments_total = [] + for i, cut in enumerate(cut_list): + rttms = SupervisionSet.from_rttm(cut.rttm_filepath) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) + + for seg in segments_iterator: + if seg.start < 0: + seg.duration += seg.start + seg.start = 0 + if seg.end > cut.duration: + seg.duration -= seg.end - cut.duration + seg.start += offsets[i] + segments_total.append(seg) + + # apply arrival time sorting to the existing segments + segments_total.sort(key = lambda rttm_sup: rttm_sup.start) + + seen = set() + seen_add = seen.add + speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] + + speaker_to_idx_map = { + spk: idx + for idx, spk in enumerate(speaker_ats) + } + if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers + raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers 
{num_speakers}") + + # initialize mask matrices (num_speaker, encoder_hidden_len) + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) + if spk_tar_all_zero: + frame_mask = torch.zeros((num_samples, num_speakers)) + else: + frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) + soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) + + if soft_label: + mask = soft_mask + else: + mask = (soft_mask > soft_thres).float() + + return mask + +def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tensor, num_speakers: int =4, feat_per_sec: int=100, ignore_num_spk_mismatch: bool = False): + """ + Generate mask matrix from segments list. + This function is needed for speaker diarization with ASR model trainings. + + Args: + segments: A list of Lhotse Supervision segments iterator. + cut (MonoCut, MixedCut): Lhotse MonoCut or MixedCut instance. + speaker_to_idx_map (dict): A dictionary mapping speaker names to indices. + num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. + + Returns: + mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). + Dimension: (num_speakers, num_frames) + """ + # get targets with 0.01s frame rate + num_samples = round(a_cut.duration * feat_per_sec) + mask = torch.zeros((num_samples, num_speakers)) + for rttm_sup in segments: + speaker_idx = speaker_to_idx_map[rttm_sup.speaker] + if speaker_idx >= num_speakers: + if ignore_num_spk_mismatch: + continue + else: + raise ValueError(f"Speaker Index {speaker_idx} exceeds the max index: {num_speakers-1}") + stt = max(rttm_sup.start, 0) + ent = min(rttm_sup.end, a_cut.duration) + stf = int(stt * feat_per_sec) + enf = int(ent * feat_per_sec) + mask[stf:enf, speaker_idx] = 1.0 + return mask + +def get_soft_mask(feat_level_target, num_samples, stride): + """ + Get soft mask from feat_level_target with stride. + This function is needed for speaker diarization with ASR model trainings. + + Args: + feat_level_target (Tensor): A numpy array of shape (num_frames, num_speakers). + Dimension: (num_frames, num_speakers) + num_sample (int): The total number of samples. + stride (int): The stride for the mask. + """ + + num_speakers = feat_level_target.shape[1] + mask = torch.zeros(num_samples, num_speakers) + + for index in range(num_samples): + if index == 0: + seg_stt_feat = 0 + else: + seg_stt_feat = stride * index - 1 - int(stride / 2) + if index == num_samples - 1: + seg_end_feat = feat_level_target.shape[0] + else: + seg_end_feat = stride * index - 1 + int(stride / 2) + mask[index] = torch.mean(feat_level_target[seg_stt_feat:seg_end_feat+1, :], axis=0) + return mask + +def get_hidden_length_from_sample_length( + num_samples: int, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8 +) -> int: + """ + Calculate the hidden length from the given number of samples. + This function is needed for speaker diarization with ASR model trainings. 
+ + This function computes the number of frames required for a given number of audio samples, + considering the number of samples per mel frame and the number of mel frames per ASR frame. + + Parameters: + num_samples (int): The total number of audio samples. + num_sample_per_mel_frame (int, optional): The number of samples per mel frame. Default is 160. + num_mel_frame_per_asr_frame (int, optional): The number of mel frames per ASR frame. Default is 8. + + Returns: + hidden_length (int): The calculated hidden length in terms of the number of frames. + """ + mel_frame_count = math.ceil((num_samples + 1) / num_sample_per_mel_frame) + hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) + return int(hidden_length) + +class ConcatenationMeetingSimulator(): + """ + This simulator concatenates the segments from different/same sessions to create a + multi-speaker meeting. + """ + + def __init__( + self, + intra_session_concat_prob: float|List[float] = [0, 1.0, 0.5, 0.2], + data_type: str = "msasr", + min_duration: float = 30.0, + max_duration: float = 40.0, + max_num_speakers: int = 4, + speaker_count_distribution: List[float] = [0, 2, 3, 4], + skip_long_segments: bool = True, + valid_dataset_ids: List[str] = [], + ): + """ + :param intra_session_concat_prob: the probability of concatenating segments from the same + session. [Default: 1] + :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', + the transcripts are included in the simulation,and the boundary segments are + not included. [Default: 'msasr'] + :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] + """ + super().__init__() + if isinstance(intra_session_concat_prob, float): + self.intra_session_concat_prob = [intra_session_concat_prob] * (max_num_speakers) + elif len(intra_session_concat_prob) == max_num_speakers: + self.intra_session_concat_prob = intra_session_concat_prob + else: + raise ValueError(f"intra_session_concat_prob must be either a float or a list of floats, but got {intra_session_concat_prob}") + if data_type not in ["msasr", "diar"]: + raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") + self.data_type = data_type + self.min_duration = min_duration + self.max_duration = max_duration + self.max_num_speakers = max_num_speakers + self.speaker_count_distribution = speaker_count_distribution + assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" + + if skip_long_segments: + self.skip_duration = max_duration / 2 + else: + self.skip_duration = max_duration + + self.valid_dataset_ids = valid_dataset_ids + + def fit(self, cuts) -> CutSet: + """ + Read the manifest file and return a CutSet object. + Each line in the manifest file should be a JSON object representing a segment. 
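# Illustrative sketch (not from the patch): a worked example of
# get_hidden_length_from_sample_length above for 10 s of 16 kHz audio with the default
# 160-sample mel hop and 8x encoder subsampling.
import math

num_samples = 10 * 16000                              # 160000 samples
mel_frame_count = math.ceil((num_samples + 1) / 160)  # 1001 mel frames
hidden_length = math.ceil(mel_frame_count / 8)        # 126 encoder frames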
+ """ + + self.id2cut = {} + self.sess2cut_ids = defaultdict(list) + self.sess2spks = defaultdict(set) + self.data2sess_ids = defaultdict(list) + self.spk2cut_ids = defaultdict(list) + self.data2num_spk2cut_ids = {} + self.sess2num_spk2cut_ids = {} + self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} + for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): + if cut.duration > self.skip_duration: + continue + if not hasattr(cut, 'dataset_id') or cut.dataset_id is None: + continue + if self.valid_dataset_ids and cut.dataset_id not in self.valid_dataset_ids: + continue + if cut.dataset_id not in self.data2num_spk2cut_ids: + self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) + if cut.recording_id not in self.sess2num_spk2cut_ids: + self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) + + speakers = cut.global_speaker_ids + if self.data_type == "msasr": + speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) + if len(speakers) != len(speaker_tokens): + # Lhotse automatically fixes the max duration of the cut, + # resulting in the mismatch of the number of speakers + # and speaker tokens for the last segment + # TODO: need to fix the issue in Lhotse that automatically fixes the max duration + continue + for spk in speakers: + self.spk2cut_ids[spk].append(cut.id) + self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) + + self.id2cut[cut.id] = cut + self.sess2cut_ids[cut.recording_id].append(cut.id) + self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) + self.sess2num_spk2cut_ids[cut.recording_id][len(speakers)].append(cut.id) + self.num_spk2cut_ids[len(speakers)].append(cut.id) + if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: + self.data2sess_ids[cut.dataset_id].append(cut.recording_id) + + self.cut_ids = list(self.id2cut.keys()) + self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) + + self.data2global_speaker = { + dataset_id: True for dataset_id in self.data2sess_ids.keys() + } + + def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: + + db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data + + if is_intra_session_concat: + # intra-dataset and intra-session concatenation + tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) + + else: + # intra-dataset but inter-session concatenation + tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) + + cut = MixedCut(id='concat_' + '_'.join([track.cut.id for track in tracks]), tracks=tracks) + if self.data_type == "msasr": + cut = self.reorder_spk_mapping(cut) + + assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" + assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" + + return cut + + def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + """ + Get the tracks for the MixedCut object. 
+ """ + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + + total_duration = 0.0 + total_spk_set = set() + tracks = [] + while True: + cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] + tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) + total_spk_set = total_spk_set.union(cut.global_speaker_ids) + total_duration += cut.duration + + # break condition + if total_duration >= self.min_duration: + if total_duration > self.max_duration: # exceed the maximum duration, starting over + total_duration = 0.0 + total_spk_set = set() + tracks = [] + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + break + else: + total_duration = 0.0 + total_spk_set = set() + tracks = [] + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + + return tracks, len(total_spk_set) + + def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + """ + Get the tracks for the MixedCut object. + """ + sample_cut = self.id2cut[random.choice(self.cut_ids)] + dataset_id = sample_cut.dataset_id + n_spk_list = [n_spk for n_spk, cut_ids in self.data2num_spk2cut_ids[dataset_id].items() if len(cut_ids) > 0] + sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) + + if min(sum_spk_list) > n_speakers: + raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") + + n_spk_left = n_speakers + total_duration = 0.0 + total_spk_set = set() + tracks = [] + num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] + while True: + #if n_spk_left == n_speakers: # for more speakers cases + # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk < n_spk_left]) + if n_spk_left >= 2: + n_spk = 2 + else: + # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk <= n_spk_left]) + n_spk = 1 + + while True: + cut = self.id2cut[random.choice(num_spk2cut_ids[n_spk])] + spks = set(cut.global_speaker_ids) + if not spks.intersection(total_spk_set): + break + + tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) + total_duration += cut.duration + n_spk_left -= n_spk + total_spk_set = total_spk_set.union(spks) + + # break condition + + if total_duration >= self.min_duration: + if total_duration > self.max_duration or len(total_spk_set) < n_speakers: # exceed the maximum duration, starting over + total_duration = 0.0 + n_spk_left = n_speakers + total_spk_set = set() + tracks = [] + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + break + else: + if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers + total_duration = 0.0 + n_spk_left = n_speakers + total_spk_set = set() + tracks = [] + + return tracks, len(total_spk_set) + + def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: + """ + Concatenate the texts of the input cuts. 
+ + """ + global_spk_mapping = {} + str_pattern = pattern.replace("\\", '') + left_str, right_str = str_pattern.split('d+') + for i, track in enumerate(cut.tracks): + local_inverse_spk_mapping = {} + local_spk_mapping = {} + for speaker in track.cut.global_speaker_ids: + if speaker not in global_spk_mapping: + global_spk_mapping[speaker] = len(global_spk_mapping) + if speaker not in local_spk_mapping: + local_spk_mapping[speaker] = len(local_spk_mapping) + local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker + + if i != 0: + text = '' + for word in track.cut.text.split(): + if len(re.findall(pattern, word)) > 0: + local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) + spk = local_inverse_spk_mapping[local_spk_idx] + global_spk_idx = global_spk_mapping[spk] + text += f'{left_str}{global_spk_idx}{right_str}' + else: + text += ' ' + word + track.cut.supervisions[0].text = text + cut.supervisions[i].text = text + else: + cut.supervisions[0].text = track.cut.text + # TODO: need to check the last speaker of last track and the first speaker of the current track + # if they are the same, we need to remove the the speaker token from the current track for segment-level + # Do not need to remove the speaker token for word-level + + return cut + + def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: + """ + Balance the speaker distribution for the simulated meetings. + Args: + num_meetings: The total number of simulated meetings. + speaker_count_distribution: The speaker count distribution for the simulated meetings. + For each number of speakers, calculate the number of meetings needed to balance the distribution. + """ + + total_spk = sum(speaker_count_distribution) + num_speakers2num_meetings = {} + for i_spk in range(self.max_num_speakers): + num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + + return num_speakers2num_meetings + + + @dill_enabled(True) + def simulate(self, + cuts: CutSet, + num_meetings: int = 10000, + seed: int = 0, + num_jobs: int = 1, + ) -> CutSet: + random.seed(seed) + + self.fit(cuts) + + + num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) + logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") + num_speakers2num_meetings[1] = 0 # skip 1-speaker samples + logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') + + # Step 0: Calculate the number of intra-session and inter-session concatentation samples + n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] + valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples + n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} + for n_spk, n_mt in num_speakers2num_meetings.items(): + logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) + if n_mt <= 0: + logging.warning(f"No concatentation samples for {n_spk} speakers. 
Will skip simulation for {n_spk} speakers.")
+ continue
+ n_intra_mt = int(n_mt * self.intra_session_concat_prob[n_spk-1])
+ n_inter_mt = n_mt - n_intra_mt
+ if n_spk in self.num_spk2sess_ids:
+ logging.warn(f"Will be generating {n_intra_mt} {n_spk}-speaker intra-session concatenation samples.")
+ n_spk2n_intra_mt[n_spk] = n_intra_mt
+ else:
+ logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.")
+ n_spk2n_intra_mt[n_spk] = 0
+ n_inter_mt = n_mt
+ if n_spk in valid_sim_n_spks:
+ logging.warn(f"Will be generating {n_inter_mt} {n_spk}-speaker inter-session concatenation samples.")
+ n_spk2n_inter_mt[n_spk] = n_inter_mt
+ else:
+ logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.")
+ if n_spk2n_intra_mt[n_spk] != 0:
+ n_spk2n_intra_mt[n_spk] = n_mt
+ logging.warn(f"Will be generating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatenation samples instead.")
+ else:
+ logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.")
+ logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatenation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatenation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""")
+ # Step 1: intra-session
+ num_intra_meetings = 0
+ intra_mixtures = []
+ logging.info(f"Simulating intra-session concatenation samples.")
+ for n_spk, n_mt in n_spk2n_intra_mt.items():
+ if n_mt <= 0:
+ continue
+
+ for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128):
+ intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True))
+ num_intra_meetings += n_mt
+ logging.info(f"Finished simulating intra-session concatenation samples. Total number of intra-session concatenation samples: {num_intra_meetings}")
+
+ # Step 2: inter-session
+ logging.info(f"Simulating inter-session concatenation samples.")
+
+ num_inter_meetings = 0
+ inter_mixtures = []
+ for n_spk, n_mt in n_spk2n_inter_mt.items():
+ if n_mt <= 0:
+ continue
+
+ for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128):
+ inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False))
+ num_inter_meetings += n_mt
+ logging.info(f"Finished simulating inter-session concatenation samples. Total number of inter-session concatenation samples: {num_inter_meetings}")
+
+ if num_inter_meetings + num_intra_meetings == 0:
+ logging.warning(f"No samples are generated. 
Probably the duration of the segments is not within the range of min {self.min_duration//2} and max {self.max_duration//2}, or the speaker count distribution is not correctly set.") + + + # Multi-processing gets slower, TODO + # else: + # futures = [] + # for n_spk, n_mt in num_speakers2num_meetings.items(): + # tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_jobs) + # futures.extend([tp.submit(self._create_mixture, n_spk) for _ in range(n_mt)]) + # pbar = tqdm(total=num_meetings, desc=f"Simulating mixtures", unit="line", ncols=128) + # count = 0 + # for f in concurrent.futures.as_completed(futures): + # count += 1 + # pbar.update() + # mixtures.append(f.result()) + # tp.shutdown() + # pbar.close() + + return CutSet.from_cuts(intra_mixtures + inter_mixtures) + + +class MixMeetingSimulator(): + """ + This simulator Mix the segments from different/same sessions to create a + multi-speaker meeting. + """ + + def __init__( + self, + intra_session_mix_prob: float|List[float] = [0, 0, 0, 0], + data_type: str = "msasr", + min_duration: float = 80.0, + max_duration: float = 100.0, + max_num_speakers: int = 4, + speaker_count_distribution: List[float] = [0, 0, 0.1, 4], + valid_dataset_ids: List[str] = [], + ): + """ + :param intra_session_mix_prob: the probability of concatenating segments from the same + session. [Default: 1] + :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', + the transcripts are included in the simulation,and the boundary segments are + not included. [Default: 'msasr'] + :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] + """ + super().__init__() + if isinstance(intra_session_mix_prob, float): + self.intra_session_mix_prob = [intra_session_mix_prob] * (max_num_speakers) + elif len(intra_session_mix_prob) == max_num_speakers: + self.intra_session_mix_prob = intra_session_mix_prob + else: + raise ValueError(f"intra_session_mix_prob must be either a float or a list of floats, but got {intra_session_mix_prob}") + if data_type not in ["msasr", "diar"]: + raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") + self.data_type = data_type + self.min_duration = min_duration + self.max_duration = max_duration + self.max_num_speakers = max_num_speakers + self.speaker_count_distribution = speaker_count_distribution + self.valid_dataset_ids = valid_dataset_ids + assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" + + def fit(self, cuts) -> CutSet: + """ + Read the manifest file and return a CutSet object. + Each line in the manifest file should be a JSON object representing a segment. 
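# Illustrative sketch (not from the patch): how apply_speaker_distribution converts the
# default speaker_count_distribution above ([0, 0, 0.1, 4]) into per-speaker-count meeting
# budgets; simulate() later zeroes out the 1-speaker entry.
num_meetings = 1000
speaker_count_distribution = [0, 0, 0.1, 4]
total = sum(speaker_count_distribution)
num_speakers2num_meetings = {
    n_spk + 1: round(num_meetings * weight / total)
    for n_spk, weight in enumerate(speaker_count_distribution)
}
# num_speakers2num_meetings -> {1: 0, 2: 0, 3: 24, 4: 976}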
+ """ + + self.id2cut = {} + self.sess2cut_ids = defaultdict(list) + self.sess2spks = defaultdict(set) + self.data2sess_ids = defaultdict(list) + self.spk2cut_ids = defaultdict(list) + self.data2num_spk2cut_ids = {} + self.sess2num_spk2cut_ids = {} + self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} + for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): + if not self.min_duration <= cut.duration <= self.max_duration: + continue + if not hasattr(cut, 'dataset_id') or cut.dataset_id is None: + continue + if self.valid_dataset_ids and cut.dataset_id not in self.valid_dataset_ids: + continue + if cut.dataset_id not in self.data2num_spk2cut_ids: + self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) + if cut.recording_id not in self.sess2num_spk2cut_ids: + self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) + + speakers = cut.global_speaker_ids + if self.data_type == "msasr": + speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) + if len(speakers) != len(speaker_tokens): + # Lhotse automatically fixes the max duration of the cut, + # resulting in the mismatch of the number of speakers + # and speaker tokens for the last segment + # TODO: need to fix the issue in Lhotse that automatically fixes the max duration + continue + for spk in speakers: + self.spk2cut_ids[spk].append(cut.id) + self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) + + self.id2cut[cut.id] = cut + self.sess2cut_ids[cut.recording_id].append(cut.id) + self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) + self.sess2num_spk2cut_ids[cut.recording_id][len(speakers)].append(cut.id) + self.num_spk2cut_ids[len(speakers)].append(cut.id) + if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: + self.data2sess_ids[cut.dataset_id].append(cut.recording_id) + + self.cut_ids = list(self.id2cut.keys()) + self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) + + self.data2global_speaker = { + dataset_id: True for dataset_id in self.data2sess_ids.keys() + } + + def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: + + db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data + + if is_intra_session_concat: + # intra-dataset and intra-session concatenation + tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) + + else: + # intra-dataset but inter-session concatenation + tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) + + cut = MixedCut(id='mix_' + '_'.join([track.cut.id for track in tracks]), tracks=tracks) + if self.data_type == "msasr": + cut = self.reorder_spk_mapping(cut) + + assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" + assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" + + return cut + + def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + """ + Get the tracks for the MixedCut object. 
+ """ + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + + total_spk_set = set() + tracks = [] + while True: + cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] + tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) + total_spk_set = total_spk_set.union(cut.global_speaker_ids) + total_duration = max(total_duration, cut.duration) + + # break condition + if total_duration >= self.min_duration: + if total_duration > self.max_duration: # exceed the maximum duration, starting over + total_duration = 0.0 + total_spk_set = set() + tracks = [] + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + break + else: + total_duration = 0.0 + total_spk_set = set() + tracks = [] + session_id = random.choice(self.num_spk2sess_ids[n_speakers]) + + return tracks, len(total_spk_set) + + def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + """ + Get the tracks for the MixedCut object. + """ + sample_cut = self.id2cut[random.choice(self.cut_ids)] + dataset_id = sample_cut.dataset_id + n_spk_list = [n_spk for n_spk, cut_ids in self.data2num_spk2cut_ids[dataset_id].items() if len(cut_ids) > 0] + sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) + + if min(sum_spk_list) > n_speakers: + raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") + + n_spk_left = n_speakers + total_duration = 0.0 + total_spk_set = set() + tracks = [] + num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] + while True: + if n_spk_left >= 2: + n_spk = 2 + else: + # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk <= n_spk_left]) + n_spk = 1 + + while True: + cut = self.id2cut[random.choice(num_spk2cut_ids[n_spk])] + spks = set(cut.global_speaker_ids) + if not spks.intersection(total_spk_set): + break + + tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) + total_duration = max(total_duration, cut.duration) + n_spk_left -= n_spk + total_spk_set = total_spk_set.union(spks) + + # break condition + + if total_duration >= self.min_duration: + if total_duration > self.max_duration or len(tracks) > 2: # exceed the maximum duration, starting over + total_duration = 0.0 + n_spk_left = n_speakers + total_spk_set = set() + tracks = [] + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + break + else: + if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers + total_duration = 0.0 + n_spk_left = n_speakers + total_spk_set = set() + tracks = [] + + return tracks, len(total_spk_set) + + def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: + """ + Concatenate the texts of the input cuts. 
+ + """ + global_spk_mapping = {} + str_pattern = pattern.replace("\\", '') + left_str, right_str = str_pattern.split('d+') + for i, track in enumerate(cut.tracks): + local_inverse_spk_mapping = {} + local_spk_mapping = {} + for speaker in track.cut.global_speaker_ids: + if speaker not in global_spk_mapping: + global_spk_mapping[speaker] = len(global_spk_mapping) + if speaker not in local_spk_mapping: + local_spk_mapping[speaker] = len(local_spk_mapping) + local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker + + if i != 0: + text = '' + for word in track.cut.text.split(): + if len(re.findall(pattern, word)) > 0: + local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) + spk = local_inverse_spk_mapping[local_spk_idx] + global_spk_idx = global_spk_mapping[spk] + text += f'{left_str}{global_spk_idx}{right_str}' + else: + text += ' ' + word + track.cut.supervisions[0].text = text + cut.supervisions[i].text = text + else: + cut.supervisions[0].text = track.cut.text + # TODO: need to check the last speaker of last track and the first speaker of the current track + # if they are the same, we need to remove the the speaker token from the current track for segment-level + # Do not need to remove the speaker token for word-level + + return cut + + def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: + """ + Balance the speaker distribution for the simulated meetings. + Args: + num_meetings: The total number of simulated meetings. + speaker_count_distribution: The speaker count distribution for the simulated meetings. + For each number of speakers, calculate the number of meetings needed to balance the distribution. + """ + + total_spk = sum(speaker_count_distribution) + num_speakers2num_meetings = {} + for i_spk in range(self.max_num_speakers): + num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + + return num_speakers2num_meetings + + + @dill_enabled(True) + def simulate(self, + cuts: CutSet, + num_meetings: int = 10000, + seed: int = 0, + num_jobs: int = 1, + ) -> CutSet: + random.seed(seed) + + self.fit(cuts) + + num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) + logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") + num_speakers2num_meetings[1] = 0 # skip 1-speaker samples + logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') + + # Step 0: Calculate the number of intra-session and inter-session concatentation samples + n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] + valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples + n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} + for n_spk, n_mt in num_speakers2num_meetings.items(): + logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) + if n_mt <= 0: + logging.warning(f"No intra-session concatentation samples for {n_spk} speakers. 
Will skip simulation for {n_spk} speakers.")
+ continue
+ n_intra_mt = int(n_mt * self.intra_session_mix_prob[n_spk-1])
+ n_inter_mt = n_mt - n_intra_mt
+ if n_spk in self.num_spk2sess_ids:
+ logging.warn(f"Will be generating {n_intra_mt} {n_spk}-speaker intra-session concatenation samples.")
+ n_spk2n_intra_mt[n_spk] = n_intra_mt
+ else:
+ logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.")
+ n_spk2n_intra_mt[n_spk] = 0
+ n_inter_mt = n_mt
+ if n_spk in valid_sim_n_spks:
+ logging.warn(f"Will be generating {n_inter_mt} {n_spk}-speaker inter-session concatenation samples.")
+ n_spk2n_inter_mt[n_spk] = n_inter_mt
+ else:
+ logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.")
+ if n_spk2n_intra_mt[n_spk] != 0:
+ n_spk2n_intra_mt[n_spk] = n_mt
+ logging.warn(f"Will be generating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatenation samples instead.")
+ else:
+ logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.")
+ logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatenation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatenation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""")
+ # Step 1: intra-session
+ num_intra_meetings = 0
+ intra_mixtures = []
+ logging.info(f"Simulating intra-session concatenation samples.")
+ for n_spk, n_mt in n_spk2n_intra_mt.items():
+ if n_mt <= 0:
+ continue
+
+ for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128):
+ intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True))
+ num_intra_meetings += n_mt
+ logging.info(f"Finished simulating intra-session concatenation samples. Total number of intra-session concatenation samples: {num_intra_meetings}")
+
+ # Step 2: inter-session
+ logging.info(f"Simulating inter-session concatenation samples.")
+
+ num_inter_meetings = 0
+ inter_mixtures = []
+ for n_spk, n_mt in n_spk2n_inter_mt.items():
+ if n_mt <= 0:
+ continue
+
+ for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128):
+ inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False))
+ num_inter_meetings += n_mt
+ logging.info(f"Finished simulating inter-session concatenation samples. Total number of inter-session concatenation samples: {num_inter_meetings}")
+
+ if num_inter_meetings + num_intra_meetings == 0:
+ logging.warning(f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration} and max {self.max_duration}, or the speaker count distribution is not correctly set.")
+
+ return CutSet.from_cuts(intra_mixtures + inter_mixtures)
+
+class LibriSpeechMixSimulator():
+
+ def __init__(
+ self,
+ min_duration: float = 80.0,
+ max_duration: float = 100.0,
+ n_mix_speakers: List[int] = [1, 2, 3],
+ speaker_count_distribution: List[float] = [1, 1, 1],
+ ):
+ """
+ :param min_duration: the minimum duration of the simulated meeting. [Default: 80.0]
+ :param max_duration: the maximum duration of the simulated meeting. 
[Default: 100.0] + """ + super().__init__() + self.min_duration = min_duration + self.max_duration = max_duration + self.n_mix_speakers = n_mix_speakers + self.speaker_count_distribution = speaker_count_distribution + assert len(speaker_count_distribution) == len(n_mix_speakers), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {len(n_mix_speakers)}" + + def fit(self, cuts) -> CutSet: + pass + + def simulate(self, + cuts: CutSet, + num_meetings: int = 10000, + seed: int = 0, + num_jobs: int = 1, + ) -> CutSet: + random.seed(seed) + + cut_set = [] + for n_speakers, n_mt in zip(self.n_mix_speakers, self.speaker_count_distribution): + if n_mt <= 0: + continue + for i in tqdm(range(n_mt), desc=f"Simulating {n_speakers}-speaker mixtures", ncols=128): + cut_set.append(self._create_mixture(n_speakers=n_speakers)) + return CutSet.from_cuts(cut_set) + +class LibriSpeechMixGenerator(): + def __init__(self): + pass + + def generate(self, cuts): + cut_set = [] + for cut in tqdm(cuts): + offsets = cut.delays + durations = cut.durations + wavs = cut.wavs + texts = cut.texts + speakers = cut.speakers + + tracks = [] + for i, (offset, duration, wav, text, speaker) in enumerate(zip(offsets, durations, wavs, texts, speakers)): + wav_dur = soundfile.info(wav).duration + wav_samples = soundfile.info(wav).frames + custom = { + 'speaker': speaker, + 'text': text, + } + cut_1spk = MonoCut( + id=wav.split('/')[-1].replace('.wav', ''), + start=0, + duration=duration, + channel=0, + supervisions=[], + recording=Recording( + id=wav.split('/')[-1].replace('.wav', ''), + sources=[ + AudioSource( + type='file', + channels=[0], + source=wav + ) + ], + sampling_rate=16000, + num_samples=wav_samples, + duration=wav_dur + ), + custom=custom + ) + + tracks.append(MixTrack(cut=cut_1spk, type=type(cut_1spk), offset=offset)) + sup = SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=offset+wav_dur, + text=cut.text, + ) + tracks[0].cut.supervisions.append(sup) + cut_multi_spk = MixedCut(id=cut.id, tracks=tracks) + + cut_set.append(cut_multi_spk) + + return CutSet.from_cuts(cut_set) \ No newline at end of file diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index b16ac50e4d56..144ae405de52 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -16,8 +16,7 @@ import json import os from itertools import combinations -from typing import Any, Callable, Dict, Iterable, List, Optional, Union - +from typing import Any, Dict, Iterable, List, Optional, Union import numpy as np import pandas as pd @@ -311,7 +310,7 @@ def __init__( class ASRAudioText(AudioText): """`AudioText` collector from asr structured json files.""" - def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[Callable] = None, *args, **kwargs): + def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): """Parse lists of audio files, durations and transcripts texts. 
Args: @@ -334,9 +333,8 @@ def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[ [], [], ) - speakers, orig_srs, token_labels, langs = [], [], [], [] - for item in manifest.item_iter(manifests_files, parse_func=parse_func): + for item in manifest.item_iter(manifests_files): ids.append(item['id']) audio_files.append(item['audio_file']) durations.append(item['duration']) @@ -1244,6 +1242,190 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: ) return item +class EndtoEndDiarizationLabel(_Collection): + """List of diarization audio-label correspondence with preprocessing.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='DiarizationLabelEntity', + field_names='audio_file uniq_id duration rttm_file offset', + ) + + def __init__( + self, + audio_files: List[str], + uniq_ids: List[str], + durations: List[float], + rttm_files: List[str], + offsets: List[float], + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + ): + """ + Instantiates audio-label manifest with filters and preprocessing. + + This method initializes the EndtoEndDiarizationLabel object by processing the input data + and applying optional filters and sorting. + + Args: + audio_files (List[str]): List of audio file paths. + uniq_ids (List[str]): List of unique identifiers for each audio file. + durations (List[float]): List of float durations for each audio file. + rttm_files (List[str]): List of RTTM path strings (Groundtruth diarization annotation file). + offsets (List[float]): List of offsets or None for each audio file. + max_number (Optional[int]): Maximum number of samples to collect. Defaults to None. + do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False. + index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. Defaults to False. + + """ + if index_by_file_id: + self.mapping = {} + output_type = self.OUTPUT_TYPE + data, duration_filtered = [], 0.0 + + zipped_items = zip( + audio_files, uniq_ids, durations, rttm_files, offsets + ) + for ( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) in zipped_items: + + if duration is None: + duration = 0 + + data.append( + output_type( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) + ) + + if index_by_file_id: + if isinstance(audio_file, list): + if len(audio_file) == 0: + raise ValueError(f"Empty audio file list: {audio_file}") + audio_file_name = sorted(audio_file)[0] + else: + audio_file_name = audio_file + file_id, _ = os.path.splitext(os.path.basename(audio_file)) + self.mapping[file_id] = len(data) - 1 + + # Max number of entities filter. + if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info( + "Filtered duration for loading collection is %f.", duration_filtered, + ) + logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips") + + super().__init__(data) + + +class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel): + """`DiarizationLabel` diarization data sample collector from structured json files.""" + + def __init__( + self, + manifests_files: Union[str, List[str]], + round_digits=2, + *args, + **kwargs, + ): + """ + Parse lists of audio files, durations, RTTM (Diarization annotation) files. 
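For orientation, a minimal, self-contained sketch of the per-line RTTM parsing this collector performs. The RTTM row below is a made-up example; in a standard RTTM file, fields 3, 4 and 7 hold onset, duration and speaker label.

# Illustrative sketch only: turn one RTTM row into a "start end speaker" label string,
# mirroring the field indices used by the diarization label collector.
line = "SPEAKER session0 1 12.31 4.26 <NA> <NA> speaker_0 <NA> <NA>"  # hypothetical row
fields = line.strip().split()
start = round(float(fields[3]), 2)        # onset in seconds
end = round(start + float(fields[4]), 2)  # onset + duration
speaker = fields[7]
print(f"{start} {end} {speaker}")         # -> "12.31 16.57 speaker_0"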
+ Since diarization model infers only two speakers, speaker pairs are generated + from the total number of speakers in the session. + + Args: + manifest_filepath (str): + Path to input manifest json files. + round_digit (int): + Number of digits to be rounded. + *args: Args to pass to `SpeechLabel` constructor. + **kwargs: Kwargs to pass to `SpeechLabel` constructor. + """ + self.round_digits = round_digits + audio_files, uniq_ids, durations, rttm_files, offsets = ( + [], + [], + [], + [], + [], + ) + + for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item_rttm): + # Training mode + rttm_labels = [] + with open(item['rttm_file'], 'r') as f: + for index, rttm_line in enumerate(f.readlines()): + rttm = rttm_line.strip().split() + start = round(float(rttm[3]), round_digits) + end = round(float(rttm[4]), round_digits) + round(float(rttm[3]), round_digits) + speaker = rttm[7] + rttm_labels.append('{} {} {}'.format(start, end, speaker)) + audio_files.append(item['audio_file']) + uniq_ids.append(item['uniq_id']) + durations.append(item['duration']) + rttm_files.append(item['rttm_file']) + offsets.append(item['offset']) + + super().__init__( + audio_files, + uniq_ids, + durations, + rttm_files, + offsets, + *args, + **kwargs, + ) + + def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: + """Parse each rttm file and save it to in Dict format""" + item = json.loads(line) + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + else: + raise ValueError( + f"Manifest file has invalid json line " f"structure: {line} without proper audio file key." + ) + if isinstance(item['audio_file'], list): + item['audio_file'] = [os.path.expanduser(audio_file_path) for audio_file_path in item['audio_file']] + else: + item['audio_file'] = os.path.expanduser(item['audio_file']) + + if not isinstance(item['audio_file'], list): + if 'uniq_id' not in item: + item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0] + elif 'uniq_id' not in item: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper uniq_id key.") + + if 'duration' not in item: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper duration key.") + item = dict( + audio_file=item['audio_file'], + uniq_id=item['uniq_id'], + duration=item['duration'], + rttm_file=item['rttm_filepath'], + offset=item.get('offset', None), + ) + return item + class Audio(_Collection): """Prepare a list of all audio items, filtered by duration.""" From 29143251c67bde5ef38aafa321bbd11a840bc1f2 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 00:01:08 -0800 Subject: [PATCH 02/16] Tested all unit-test files Signed-off-by: taejinp --- .../neural_diarizer/e2e_diarize_speech.py | 386 ++++++++++++++++++ nemo/collections/asr/metrics/der.py | 41 +- .../asr/metrics/multi_binary_acc.py | 51 ++- nemo/collections/asr/models/__init__.py | 7 +- .../asr/models/sortformer_diar_models.py | 3 +- .../asr/modules/sortformer_modules.py | 111 +++++ .../asr/parts/utils/speaker_utils.py | 163 +++++++- nemo/collections/asr/parts/utils/vad_utils.py | 136 +++--- 8 files changed, 780 insertions(+), 118 deletions(-) create mode 100644 examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py create mode 100644 nemo/collections/asr/modules/sortformer_modules.py diff --git 
a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py new file mode 100644 index 000000000000..98f2ee10e523 --- /dev/null +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -0,0 +1,386 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +python $BASEPATH/neural_diarizer/sortformer_diarization.py \ + model_path=/path/to/sortformer_model.nemo \ + batch_size=4 \ + dataset_manifest=/path/to/diarization_path_to_manifest.json + +""" +import pytorch_lightning as pl +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything + +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.core.config import hydra_runner +from nemo.collections.asr.metrics.der import score_labels +from hydra.core.config_store import ConfigStore + +import os +import yaml +from dataclasses import dataclass, is_dataclass +from typing import Optional, Union, List, Tuple, Dict + +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object +from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing + +from tqdm import tqdm +import torch +import logging +import optuna +import tempfile + +seed_everything(42) +torch.backends.cudnn.deterministic = True + +@dataclass +class PostProcessingParams: + window_length_in_sec: float = 0.15 + shift_length_in_sec: float = 0.01 + smoothing: bool = False + overlap: float = 0.5 + onset: float = 0.5 + offset: float = 0.5 + pad_onset: float = 0.0 + pad_offset: float = 0.0 + min_duration_on: float = 0.0 + min_duration_off: float = 0.0 + filter_speech_first: bool = True + +@dataclass +class DiarizationConfig: + # Required configs + model_path: Optional[str] = None # Path to a .nemo file + pretrained_name: Optional[str] = None # Name of a pretrained model + audio_dir: Optional[str] = None # Path to a directory which contains audio files + dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest + + postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations + no_der: bool = False + out_rttm_dir: Optional[str] = None + + # General configs + session_len_sec: float = -1 # End-to-end diarization session length in seconds + batch_size: int = 4 + num_workers: int = 0 + random_seed: Optional[int] = None # seed number going to be used in seed_everything() + bypass_postprocessing: bool = True # If True, postprocessing will be bypassed + + # Eval Settings: (0.25, False) should be default setting for sortformer eval. + collar: float = 0.25 # Collar in seconds for DER calculation + ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments + + # If `cuda` is a negative number, inference will be on CPU only. 
+ cuda: Optional[int] = None + matmul_precision: str = "highest" # Literal["highest", "high", "medium"] + + # Optuna Config + launch_pp_optim: bool = False # If True, launch optimization process for postprocessing parameters + optuna_study_name: str = "optim_postprocessing" + optuna_temp_dir: str = "/tmp/optuna" + optuna_storage: str = f"sqlite:///{optuna_study_name}.db" + optuna_log_file: str = f"{optuna_study_name}.log" + optuna_n_trials: int = 100000 + +def load_postprocessing_from_yaml(postprocessing_yaml): + """ + Load postprocessing parameters from a YAML file. + + Args: + postprocessing_yaml (str): + Path to a YAML file for postprocessing configurations. + + Returns: + postprocessing_params (dataclass): + Postprocessing parameters loaded from the YAML file. + """ + # Add PostProcessingParams as a field + postprocessing_params = OmegaConf.structured(PostProcessingParams()) + if postprocessing_yaml is None: + logging.info(f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied.") + else: + # Load postprocessing params from the provided YAML file + with open(postprocessing_yaml, 'r') as file: + yaml_params = yaml.safe_load(file)['parameters'] + # Update the postprocessing_params with the loaded values + logging.info(f"Postprocessing YAML file '{postprocessing_yaml}' has been loaded.") + for key, value in yaml_params.items(): + if hasattr(postprocessing_params, key): + setattr(postprocessing_params, key, value) + return postprocessing_params + +def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: + """ + Suggests hyperparameters for postprocessing using Optuna. + + Args: + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + + Returns: + PostProcessingParams: The updated postprocessing configuration with suggested hyperparameters. + """ + postprocessing_cfg.onset = trial.suggest_float("onset", 0.4, 0.8, step=0.01) + postprocessing_cfg.offset = trial.suggest_float("offset", 0.4, 0.9, step=0.01) + postprocessing_cfg.pad_onset = trial.suggest_float("pad_onset", 0.1, 0.5, step=0.01) + postprocessing_cfg.pad_offset = trial.suggest_float("pad_offset", 0.0, 0.2, step=0.01) + postprocessing_cfg.min_duration_on = trial.suggest_float("min_duration_on", 0.0, 0.75, step=0.01) + postprocessing_cfg.min_duration_off = trial.suggest_float("min_duration_off", 0.0, 0.75, step=0.01) + return postprocessing_cfg + +def get_tensor_path(cfg: DiarizationConfig) -> str: + """ + Constructs the file path for saving or loading prediction tensors based on the configuration. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + + Returns: + str: The constructed file path for the prediction tensor. 
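As an illustration of the YAML round-trip performed by load_postprocessing_from_yaml above, here is a self-contained sketch with a stand-in dataclass and made-up parameter values (not a tuned configuration):

# Illustrative sketch: read a 'parameters' block from a postprocessing YAML string
# and overwrite matching fields of a small stand-in dataclass.
import yaml
from dataclasses import dataclass, fields

@dataclass
class _PPParams:            # minimal stand-in for the PostProcessingParams dataclass
    onset: float = 0.5
    offset: float = 0.5
    pad_onset: float = 0.0
    min_duration_on: float = 0.0

yaml_text = """
parameters:
  onset: 0.53
  offset: 0.49
  pad_onset: 0.23
"""
params = _PPParams()
loaded = yaml.safe_load(yaml_text)["parameters"]
for f in fields(_PPParams):
    if f.name in loaded:
        setattr(params, f.name, loaded[f.name])
print(params)   # _PPParams(onset=0.53, offset=0.49, pad_onset=0.23, min_duration_on=0.0)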
+ """ + tensor_filename = os.path.basename(cfg.dataset_manifest).replace("manifest.", "").replace(".json", "") + model_base_path = os.path.dirname(cfg.model_path) + model_id = os.path.basename(cfg.model_path).replace(".ckpt", "").replace(".nemo", "") + bpath = f"{model_base_path}/pred_tensors" + if not os.path.exists(bpath): + os.makedirs(bpath) + tensor_path = f"{bpath}/__{model_id}__{tensor_filename}.pt" + return tensor_path + +def diarization_objective( + trial: optuna.Trial, + postprocessing_cfg: PostProcessingParams, + temp_out_dir: str, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + diar_model_preds_total_list: List[torch.Tensor], + collar: float = 0.25, + ignore_overlap: bool = False +) -> float: + """ + Objective function for Optuna hyperparameter optimization in speaker diarization. + + This function evaluates the diarization performance using a set of postprocessing parameters + suggested by Optuna. It converts prediction matrices to time-stamp segments, scores the + diarization results, and returns the Diarization Error Rate (DER) as the optimization metric. + + Args: + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + temp_out_dir (str): Temporary directory for storing intermediate outputs. + infer_audio_rttm_dict (Dict[str, Dict[str, str]]): Dictionary containing audio file paths, + offsets, durations, and RTTM file paths. + diar_model_preds_total_list (List[torch.Tensor]): List of prediction matrices containing + sigmoid values for each speaker. Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + collar (float, optional): Collar in seconds for DER calculation. Defaults to 0.25. + ignore_overlap (bool, optional): If True, DER will be calculated only for non-overlapping segments. + Defaults to False. + + Returns: + float: The Diarization Error Rate (DER) for the given set of postprocessing parameters. + """ + with tempfile.TemporaryDirectory(dir=temp_out_dir, prefix="Diar_PostProcessing_") as local_temp_out_dir: + if trial is not None: + postprocessing_cfg = optuna_suggest_params(postprocessing_cfg, trial) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments(audio_rttm_map_dict=infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=False) + metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=collar, + ignore_overlap=ignore_overlap + ) + der = abs(metric) + return der + +def run_optuna_hyperparam_search( + cfg: DiarizationConfig, # type: DiarizationConfig + postprocessing_cfg: PostProcessingParams, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + preds_list: List[torch.Tensor], + temp_out_dir: str, + ): + worker_function = lambda trial: diarization_objective( + trial=trial, + postprocessing_cfg=postprocessing_cfg, + temp_out_dir=temp_out_dir, + infer_audio_rttm_dict=infer_audio_rttm_dict, + diar_model_preds_total_list=preds_list, + collar=cfg.collar, + ) + study = optuna.create_study( + direction="minimize", + study_name=cfg.optuna_study_name, + storage=cfg.optuna_storage, + load_if_exists=True + ) + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Setup the root logger. 
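The Optuna pattern used by diarization_objective and run_optuna_hyperparam_search boils down to: suggest thresholds, score them, minimize. A toy, self-contained sketch with a synthetic objective standing in for DER:

# Toy sketch of the Optuna loop: the quadratic "error" below is a stand-in for DER.
import optuna

def toy_objective(trial: optuna.Trial) -> float:
    onset = trial.suggest_float("onset", 0.4, 0.8, step=0.01)
    offset = trial.suggest_float("offset", 0.4, 0.9, step=0.01)
    return (onset - 0.53) ** 2 + (offset - 0.49) ** 2   # pretend diarization error

study = optuna.create_study(direction="minimize", study_name="toy_postprocessing")
study.optimize(toy_objective, n_trials=50)
print(study.best_params)   # approaches {'onset': 0.53, 'offset': 0.49}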
+ if cfg.optuna_log_file is not None: + logger.addHandler(logging.FileHandler(cfg.optuna_log_file, mode="a")) + logger.addHandler(logging.StreamHandler()) + optuna.logging.enable_propagation() # Propagate logs to the root logger. + study.optimize(worker_function, n_trials=cfg.optuna_n_trials) + + +def convert_pred_mat_to_segments( + audio_rttm_map_dict: Dict[str, Dict[str, str]], + postprocessing_cfg, + batch_preds_list: List[torch.Tensor], + unit_10ms_frame_count:int = 8, + bypass_postprocessing: bool = False, + out_rttm_dir: str | None = None, + ): + """ + Convert prediction matrix to time-stamp segments. + + Args: + audio_rttm_map_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + batch_preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. + Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + unit_10ms_frame_count (int, optional): number of 10ms segments in a frame. Defaults to 8. + bypass_postprocessing (bool, optional): if True, postprocessing will be bypassed. Defaults to False. + + Returns: + all_hypothesis (list): list of pyannote objects for each audio file. + all_reference (list): list of pyannote objects for each audio file. + all_uems (list): list of pyannote objects for each audio file. + """ + batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], [] + cfg_vad_params = OmegaConf.structured(postprocessing_cfg) + for sample_idx, (uniq_id, audio_rttm_values) in tqdm(enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc="Running post-processing"): + spk_ts = [] + offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration'] + speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0) + speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] + for spk_id in range(speaker_assign_mat.shape[-1]): + ts_mat = ts_vad_post_processing(speaker_assign_mat[:, spk_id], + cfg_vad_params=cfg_vad_params, + unit_10ms_frame_count=unit_10ms_frame_count, + bypass_postprocessing=bypass_postprocessing) + ts_mat = ts_mat + offset + ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration)) + ts_seg_list = ts_mat.tolist() + speaker_timestamps[spk_id].extend(ts_seg_list) + spk_ts.append(ts_seg_list) + all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object(speaker_timestamps, + uniq_id, + audio_rttm_values, + all_hypothesis, + all_reference, + all_uems, + out_rttm_dir, + ) + batch_pred_ts_segs.append(spk_ts) + return all_hypothesis, all_reference, all_uems + +@hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) +def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: + for key in cfg: + cfg[key] = None if cfg[key] == 'None' else cfg[key] + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + if cfg.random_seed: + pl.seed_everything(cfg.random_seed) + + if cfg.model_path is None and cfg.pretrained_name is None: + raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!") + if cfg.audio_dir is None and cfg.dataset_manifest is None: + raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!") + + # setup GPU + torch.set_float32_matmul_precision(cfg.matmul_precision) + if cfg.cuda is None: + if torch.cuda.is_available(): + device = [0] # use 0th CUDA device + accelerator = 'gpu' + map_location = torch.device('cuda:0') + else: + device = 1 + accelerator = 'cpu' + map_location = torch.device('cpu') + else: + device = [cfg.cuda] + 
accelerator = 'gpu' + map_location = torch.device(f'cuda:{cfg.cuda}') + + if cfg.model_path.endswith(".ckpt"): + diar_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=cfg.model_path, map_location=map_location, strict=False) + elif cfg.model_path.endswith(".nemo"): + diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.model_path, map_location=map_location) + else: + raise ValueError("cfg.model_path must end with.ckpt or.nemo!") + + diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec + trainer = pl.Trainer(devices=device, accelerator=accelerator) + diar_model.set_trainer(trainer) + + diar_model = diar_model.eval() + diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest + infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest) + diar_model._cfg.test_ds.batch_size = cfg.batch_size + + # Model setup for inference + diar_model._cfg.test_ds.num_workers = cfg.num_workers + diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds) + + postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml) + tensor_path = get_tensor_path(cfg) + + if os.path.exists(tensor_path): + logging.info(f"A saved prediction tensor has been found. Loading the saved prediction tensors from {tensor_path}...") + diar_model_preds_total_list = torch.load(tensor_path) + else: + logging.info(f"No saved prediction tensors found. Running inference on the dataset...") + diar_model.test_batch() + diar_model_preds_total_list = diar_model.preds_total_list + torch.save(diar_model.preds_total_list, tensor_path) + + if cfg.launch_pp_optim: + # Launch a hyperparameter optimization process if launch_pp_optim is True + run_optuna_hyperparam_search(cfg=cfg, + postprocessing_cfg=postprocessing_cfg, + infer_audio_rttm_dict=infer_audio_rttm_dict, + preds_list=diar_model_preds_total_list, + temp_out_dir=cfg.optuna_temp_dir) + + # Evaluation + if not cfg.no_der: + if cfg.out_rttm_dir is not None and not os.path.exists(cfg.out_rttm_dir): + os.mkdir(cfg.out_rttm_dir) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments(infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=cfg.bypass_postprocessing, + out_rttm_dir=cfg.out_rttm_dir + ) + logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") + metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap + ) + logging.info(f"PostProcessingParams: {postprocessing_cfg}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index fc5cded970d0..16f62bbe9e4c 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -130,7 +130,13 @@ def uem_timeline_from_file(uem_file, uniq_name=''): def score_labels( - AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True, verbose: bool = True + AUDIO_RTTM_MAP, + all_reference, + all_hypothesis, + all_uem: List[List[float]]=None, + collar:float=0.25, + ignore_overlap: bool=True, + verbose: bool = True ) -> Optional[Tuple[DiarizationErrorRate, Dict]]: """ Calculate DER, CER, FA and MISS rate from hypotheses and references. 
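For reference, a minimal sketch of the pyannote scoring flow that score_labels builds on; the annotations, labels and times below are invented for the example:

# Build toy reference/hypothesis Annotations and accumulate DER over them.
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate

reference = Annotation(uri="toy_session")
reference[Segment(0.0, 4.0)] = "speaker_0"
reference[Segment(4.0, 8.0)] = "speaker_1"

hypothesis = Annotation(uri="toy_session")
hypothesis[Segment(0.0, 4.5)] = "spk_a"
hypothesis[Segment(4.5, 8.0)] = "spk_b"

metric = DiarizationErrorRate(collar=0.5, skip_overlap=False)
metric(reference, hypothesis, detailed=True)   # accumulates counts internally
der = abs(metric)                              # cumulative DER over all scored files
print(f"DER: {der:.4f}, confusion: {metric['confusion']:.2f} s")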
Hypothesis results are @@ -157,26 +163,41 @@ def score_labels( if len(all_reference) == len(all_hypothesis): metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap) - mapping_dict = {} - for (reference, hypothesis) in zip(all_reference, all_hypothesis): + mapping_dict, correct_spk_count = {}, 0 + for idx, (reference, hypothesis) in enumerate(zip(all_reference, all_hypothesis)): ref_key, ref_labels = reference _, hyp_labels = hypothesis - uem = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) - if uem is not None: - uem = uem_timeline_from_file(uem_file=uem, uniq_name=ref_key) - metric(ref_labels, hyp_labels, uem=uem, detailed=True) + if len(ref_labels.labels()) == len(hyp_labels.labels()): + correct_spk_count += 1 + if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): + logging.info(f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}") + uem_obj = None + if all_uem is not None: + metric(ref_labels, hyp_labels, uem=all_uem[idx], detailed=True) + elif AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) is not None: + uem_file = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) + uem_obj = uem_timeline_from_file(uem_file=uem_file, uniq_name=ref_key) + metric(ref_labels, hyp_labels, uem=uem_obj, detailed=True) + else: + metric(ref_labels, hyp_labels, detailed=True) mapping_dict[ref_key] = metric.optimal_mapping(ref_labels, hyp_labels) + spk_count_acc = correct_spk_count / len(all_reference) DER = abs(metric) + if metric['total'] == 0: + raise ValueError(f"Total evaluation time is 0. Abort.") CER = metric['confusion'] / metric['total'] FA = metric['false alarm'] / metric['total'] MISS = metric['missed detection'] / metric['total'] + itemized_errors = (DER, CER, FA, MISS) + if verbose: + # logging.info(f"\n{metric.report()}") + pass logging.info( - "Cumulative Results for collar {} sec and ignore_overlap {}: \n FA: {:.4f}\t MISS {:.4f}\t \ - Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format( - collar, ignore_overlap, FA, MISS, DER, CER + "Cumulative Results for collar {} sec and ignore_overlap {}: \n| FA: {:.4f} | MISS: {:.4f} | CER: {:.4f} | DER: {:.4f} | Spk. Count Acc. {:.4f}\n".format( + collar, ignore_overlap, FA, MISS, CER, DER, spk_count_acc ) ) diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 8cc21c53ad82..72781143208b 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -68,18 +68,19 @@ def on_validation_epoch_end(self): f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. 
""" - full_state_update = False def __init__(self, dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) - self.total_correct_counts = 0 - self.total_sample_counts = 0 - self.true_positive_count = 0 - self.false_positive_count = 0 - self.false_negative_count = 0 - - def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor) -> torch.Tensor: + self.add_state("total_correct_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("total_sample_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("true_positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("false_positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("false_negative_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.eps = 1e-6 + + def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False) -> torch.Tensor: with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] targets_list = [targets[k, : signal_lengths[k], :] for k in range(targets.shape[0])] @@ -91,22 +92,30 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: tor self.positive = self.preds.round().bool() == 1 self.negative = self.preds.round().bool() == 0 - self.positive_count = torch.sum(self.preds.round().bool() == True) - self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) - self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) - self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) - - self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) - self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + if cumulative: + self.positive_count += torch.sum(self.preds.round().bool() == True) + self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + else: + self.positive_count = torch.sum(self.preds.round().bool() == True) + self.true_positive_count = torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count = torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count = torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts = torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts = torch.prod(torch.tensor(self.targets.shape)) def compute(self): """ Compute F1 score from the accumulated values. Return -1 if the F1 score is NaN. 
""" - self.precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count) - self.recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count) - self.f1_score = 2 * self.precision * self.recall / (self.precision + self.recall) - if torch.isnan(self.f1_score): + precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count + self.eps) + recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count + self.eps) + f1_score = (2 * precision * recall / (precision + recall + self.eps)).detach().clone() + + if torch.isnan(f1_score): logging.warn("self.f1_score contains NaN value. Returning -1 instead of NaN value.") - self.f1_score = -1 - return self.f1_score + f1_score = -1 + return f1_score.float(), precision.float(), recall.float() diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index e4a1342b9c36..31194d8849f0 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -19,6 +19,7 @@ EncDecClassificationModel, EncDecFrameClassificationModel, ) +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel @@ -35,9 +36,5 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel -from nemo.collections.asr.models.ssl_models import ( - EncDecDenoiseMaskedTokenPredModel, - EncDecMaskedTokenPredModel, - SpeechEncDecSelfSupervisedModel, -) +from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index c389f0eb627f..50cdf6214d5b 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -555,7 +555,8 @@ def test_batch(self,): self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) torch.cuda.empty_cache() self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) - + # except: + # import ipdb; ipdb.set_trace() logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py new file mode 100644 index 000000000000..823cf98590e7 --- /dev/null +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -0,0 +1,111 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule +from nemo.core.neural_types import EncodedRepresentation, LengthsType, NeuralType, SpectrogramType +from nemo.core.neural_types.elements import ProbsType + +__all__ = ['SortformerModules'] + + +class SortformerModules(NeuralModule, Exportable): + """ + Multi-scale Diarization Decoder (MSDD) for overlap-aware diarization and improved diarization accuracy from clustering diarizer. + Based on the paper: Taejin Park et. al, "Multi-scale Speaker Diarization with Dynamic Scale Weighting", Interspeech 2022. + Arxiv version: https://arxiv.org/pdf/2203.15974.pdf + + Args: + num_spks (int): + Max number of speakers that are processed by the model. In `MSDD_module`, `num_spks=2` for pairwise inference. + hidden_size (int): + Number of hidden units in sequence models and intermediate layers. + num_lstm_layers (int): + Number of the stacked LSTM layers. + dropout_rate (float): + Dropout rate for linear layers, CNN and LSTM. + tf_d_model (int): + Dimension of the embedding vectors. + scale_n (int): + Number of scales in multi-scale system. + clamp_max (float): + Maximum value for limiting the scale weight values. + conv_repeat (int): + Number of CNN layers after the first CNN layer. + weighting_scheme (str): + Name of the methods for estimating the scale weights. + context_vector_type (str): + If 'cos_sim', cosine similarity values are used for the input of the sequence models. + If 'elem_prod', element-wise product values are used for the input of the sequence models. + """ + def init_weights(self, m): + if type(m) == nn.Linear: + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def __init__( + self, + num_spks: int = 4, + hidden_size: int = 192, + dropout_rate: float = 0.5, + fc_d_model: int = 512, + tf_d_model: int = 192, + ): + super().__init__() + self.fc_d_model = fc_d_model + self.tf_d_model = tf_d_model + self.hidden_size = tf_d_model + self.unit_n_spks: int = num_spks + self.hidden_to_spks = nn.Linear(2 * self.hidden_size, self.unit_n_spks) + self.first_hidden_to_hidden = nn.Linear(self.hidden_size, self.hidden_size) + self.single_hidden_to_spks = nn.Linear(self.hidden_size, self.unit_n_spks) + self.dropout = nn.Dropout(dropout_rate) + self.encoder_proj = nn.Linear(self.fc_d_model, self.tf_d_model) + + def length_to_mask(self, context_embs): + """ + Convert length values to encoder mask input tensor. 
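The comparison trick behind this length-to-mask conversion can be shown in a few lines; a toy sketch with made-up lengths:

# Frames past each sequence's valid length are zeroed so padding is ignored downstream.
import torch

lengths = torch.tensor([3, 5, 2])            # valid frame counts per batch item
max_len = int(lengths.max())
row = torch.arange(max_len).unsqueeze(0)     # (1, max_len)
mask = (row < lengths.unsqueeze(1)).float()  # (batch, max_len)
print(mask)
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.],
#         [1., 1., 0., 0., 0.]])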
+ + Args: + lengths (torch.Tensor): tensor containing lengths of sequences + max_len (int): maximum sequence length + + Returns: + mask (torch.Tensor): tensor of shape (batch_size, max_len) containing 0's + in the padded region and 1's elsewhere + """ + lengths = torch.tensor([context_embs.shape[1]] * context_embs.shape[0]) + batch_size = context_embs.shape[0] + max_len=context_embs.shape[1] + # create a tensor with the shape (batch_size, 1) filled with ones + row_vector = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) + # create a tensor with the shape (batch_size, max_len) filled with lengths + length_matrix = lengths.unsqueeze(1).expand(-1, max_len).to(lengths.device) + # create a mask by comparing the row vector and length matrix + mask = row_vector < length_matrix + return mask.float().to(context_embs.device) + + def forward_speaker_sigmoids(self, hidden_out): + hidden_out = self.dropout(F.relu(hidden_out)) + hidden_out = self.first_hidden_to_hidden(hidden_out) + hidden_out = self.dropout(F.relu(hidden_out)) + spk_preds = self.single_hidden_to_spks(hidden_out) + preds = nn.Sigmoid()(spk_preds) + return preds diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 5d3a0bf4274e..80b3e1f918b8 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -21,10 +21,11 @@ from typing import Dict, List, Tuple, Union import numpy as np -import omegaconf +from omegaconf import OmegaConf +from omegaconf.listconfig import ListConfig import soundfile as sf import torch -from pyannote.core import Annotation, Segment +from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm from nemo.collections.asr.data.audio_to_label import repeat_signal @@ -108,7 +109,10 @@ def audio_rttm_map(manifest, attach_dur=False): if attach_dur: uniqname = get_uniq_id_with_dur(meta) else: - uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) + if "uniq_id" in dic.keys(): + uniqname = dic['uniq_id'] + else: + uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) if uniqname not in AUDIO_RTTM_MAP: AUDIO_RTTM_MAP[uniqname] = meta @@ -144,7 +148,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ """ check_float_config = [isinstance(var, float) for var in (window_lengths_in_sec, shift_lengths_in_sec)] check_list_config = [ - isinstance(var, (omegaconf.listconfig.ListConfig, list, tuple)) + isinstance(var, (ListConfig, list, tuple)) for var in (window_lengths_in_sec, shift_lengths_in_sec, multiscale_weights) ] if all(check_list_config) or all(check_float_config): @@ -928,28 +932,61 @@ def segments_manifest_to_subsegments_manifest( return subsegments_manifest_file -def get_subsegments(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: - """ - Return subsegments from a segment of audio file +def get_subsegments( + offset: float, + window: float, + shift: float, + duration: float, + min_subsegment_duration: float = 0.01, + decimals: int = 2, + use_asr_style_frame_count: bool = False, + sample_rate: int = 16000, + feat_per_sec: int = 100, + ) -> List[List[float]]: + """ + Return subsegments from a segment of audio file. 
+ + Example: + (window, shift) = 1.5, 0.75 + Segment: [12.05, 14.45] + Subsegments: [[12.05, 13.55], [12.8, 14.3], [13.55, 14.45], [14.3, 14.45]] + Args: - offset (float): start time of audio segment - window (float): window length for segments to subsegments length - shift (float): hop length for subsegments shift - duration (float): duration of segment + offset (float): Start time of audio segment + window (float): Window length for segments to subsegments length + shift (float): Hop length for subsegments shift + duration (float): Duration of segment + min_subsegment_duration (float): Exclude subsegments smaller than this duration value + decimals (int): Number of decimal places to round to + use_asr_style_frame_count (bool): If True, use asr style frame count to generate subsegments. + For example, if duration is 10 secs and frame_shift is 0.08 secs, + it results in (10/0.08)+1 = 125 + 1 frames. + Returns: subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment """ - subsegments: List[List[float]] = [] + subsegments: List[List[float]] = [] start = offset slice_end = start + duration - base = math.ceil((duration - window) / shift) - slices = 1 if base < 0 else base + 1 - for slice_id in range(slices): - end = start + window - if end > slice_end: - end = slice_end - subsegments.append([start, end - start]) - start = offset + (slice_id + 1) * shift + if min_subsegment_duration <= duration < shift: + slices = 1 + elif use_asr_style_frame_count is True: + num_feat_frames = np.ceil((1+duration*sample_rate)/int(sample_rate/feat_per_sec)).astype(int) + slices = np.ceil(num_feat_frames/int(feat_per_sec*shift)).astype(int) + slice_end = start + shift * slices + else: + slices = np.ceil(1+ (duration-window)/shift).astype(int) + if slices == 1: + if min(duration, window) >= min_subsegment_duration: + subsegments.append([start, min(duration, window)]) + elif slices > 0: # What if slcies = 0 ? 
+ start_col = torch.arange(offset, slice_end, shift)[:slices] + dur_col = window * torch.ones(slices) + dur_col = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) + dur_col = torch.round(dur_col, decimals=decimals) + valid_mask = dur_col >= min_subsegment_duration + valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) + subsegments = valid_subsegments.tolist() return subsegments @@ -1000,6 +1037,15 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: return [[float(range_tensor[k][0]), float(range_tensor[k][1])] for k in range(range_tensor.shape[0])] +def generate_diarization_output_lines(speaker_timestamps, model_spk_num): + speaker_lines_total = [] + for spk_idx in range(model_spk_num): + ts_invervals = speaker_timestamps[spk_idx] + merged_ts_intervals = merge_float_intervals(ts_invervals) + for ts_interval in merged_ts_intervals: + speaker_lines_total.extend([f"{ts_interval[0]:.3f} {ts_interval[1]:.3f} speaker_{int(spk_idx)}"]) + return speaker_lines_total + def get_speech_labels_for_update( frame_start: float, buffer_end: float, @@ -1580,6 +1626,83 @@ def make_rttm_with_overlap( return all_reference, all_hypothesis +def timestamps_to_pyannote_object(speaker_timestamps: List[Tuple[float, float]], + uniq_id: str, + audio_rttm_values: Dict[str, str], + all_hypothesis: List[Tuple[str, Timeline]], + all_reference: List[Tuple[str, Timeline]], + all_uems: List[Tuple[str, Timeline]], + out_rttm_dir: str | None + ): + """ + Convert speaker timestamps to pyannote.core.Timeline object. + + Args: + speaker_timestamps (List[Tuple[float, float]]): + Timestamps of each speaker: start time and end time of each speaker. + uniq_id (str): + Unique ID of each speaker. + audio_rttm_values (Dict[str, str]): + Dictionary of manifest values. + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object. + out_rttm_dir (str | None): + Directory to save RTTMs + + Returns: + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object with an added Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object with an added Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object with an added Timeline object. 
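A small sketch of the label-line format produced by generate_diarization_output_lines above (the real function additionally merges overlapping intervals first); the timestamps below are invented:

# Turn per-speaker [start, end] lists into "start end speaker_k" label lines.
speaker_timestamps = [
    [[0.00, 3.04], [6.00, 6.08]],     # speaker_0
    [[3.10, 5.90]],                   # speaker_1
]
lines = []
for spk_idx, segments in enumerate(speaker_timestamps):
    for start, end in segments:
        lines.append(f"{start:.3f} {end:.3f} speaker_{spk_idx}")
print(lines)
# ['0.000 3.040 speaker_0', '6.000 6.080 speaker_0', '3.100 5.900 speaker_1']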
+ """ + offset, dur = float(audio_rttm_values.get('offset', None)), float(audio_rttm_values.get('duration', None)) + hyp_labels = generate_diarization_output_lines(speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps)) + hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=uniq_id) + if out_rttm_dir is not None and os.path.exists(out_rttm_dir): + with open(f'{out_rttm_dir}/{uniq_id}.rttm','w') as f: + hypothesis.write_rttm(f) + all_hypothesis.append([uniq_id, hypothesis]) + rttm_file = audio_rttm_values.get('rttm_filepath', None) + if rttm_file is not None and os.path.exists(rttm_file): + uem_lines = [[offset, dur+offset]] + org_ref_labels = rttm_to_labels(rttm_file) + ref_labels = org_ref_labels + reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) + uem_obj = get_uem_object(uem_lines, uniq_id=uniq_id) + all_uems.append(uem_obj) + all_reference.append([uniq_id, reference]) + return all_hypothesis, all_reference, all_uems + +def get_uem_object(uem_lines: List[List[float]], uniq_id: str): + """ + Generate pyannote timeline segments for uem file. + + file format + UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME + + Args: + uem_lines (list): list of session ID and start, end times. + Example: + [[0.0, 30.41], [60.04, 165.83]] + uniq_id (str): Unique session ID. + + Returns: + timeline (pyannote.core.Timeline): pyannote timeline object. + """ + timeline = Timeline(uri=uniq_id) + for uem_stt_end in uem_lines: + start_time, end_time = uem_stt_end + timeline.add(Segment(float(start_time), float(end_time))) + return timeline + + + def embedding_normalize(embs, use_std=False, eps=1e-10): """ Mean and l2 length normalize the input speaker embeddings diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index aea04b8cafcf..192c42375dca 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -23,31 +23,23 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union +import IPython.display as ipd import librosa import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pyannote.core import Annotation, Segment from pyannote.metrics import detection from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm - +from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging -HAVE_IPYTHON = False -try: - import IPython.display as ipd - - HAVE_IPYTHON = True -except: - HAVE_IPYTHON = False - - """ This file contains all the utility functions required for voice activity detection. 
""" @@ -74,8 +66,7 @@ def prepare_manifest(config: dict) -> str: input_list = config['input'] else: raise ValueError( - "The input for manifest preparation would either be a string of the filepath to \ - manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} " + "The input for manifest preparation would either be a string of the filepath to manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} " ) args_func = { @@ -204,8 +195,7 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list: def get_vad_stream_status(data: list) -> list: """ - Generate a list of status for each snippet in manifest. - A snippet should be in single, start, next or end status. + Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status. Used for concatenating to full audio file. Args: data (list): list of filepath of audio snippet @@ -256,8 +246,7 @@ def generate_overlap_vad_seq( out_dir: str = None, ) -> str: """ - Generate predictions with overlapping input windows/segments. - Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. + Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. Two common smoothing filters are supported: majority vote (median) and average (mean). This function uses multiprocessing to speed up. Args: @@ -321,8 +310,7 @@ def generate_overlap_vad_seq_per_tensor( frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str ) -> torch.Tensor: """ - Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) - to generate prediction with overlapping input window/segments + Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ @@ -484,8 +472,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: @@ -498,8 +485,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te frame_length_in_sec (float): length of frame. Returns: - speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) \ - format. + speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. """ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) @@ -550,8 +536,7 @@ def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: tor """ Remove speech segments list in to_be_removed_segments from original_segments. 
For example, - remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],\ - [start3, end3], [start4, end4]]), + remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), -> torch.Tensor([[start1, end1],[start3, end3]]) """ @@ -577,25 +562,21 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc Filter out short non_speech and speech segments. Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: - speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], \ - [start2, end2]]) format. + speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. per_args: min_duration_on (float): threshold for small non_speech deletion min_duration_off (float): threshold for short speech segment deletion - filter_speech_first (float): Whether to perform short speech segment deletion first. \ - Use 1.0 to represent True. + filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True. Returns: - speech_segments(torch.Tensor): A tensor of filtered speech segment in \ - torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments(torch.Tensor): A tensor of filtered speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. """ if speech_segments.shape == torch.Size([0]): return speech_segments - + min_duration_on = per_args.get('min_duration_on', 0.0) min_duration_off = per_args.get('min_duration_off', 0.0) filter_speech_first = per_args.get('filter_speech_first', 1.0) @@ -728,8 +709,7 @@ def generate_vad_segment_table( 17,18, speech Args: vad_pred_dir (str): directory of prediction files to be processed. - postprocessing_params (dict): dictionary of thresholds for prediction score. - See details in binarization and filtering. + postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering. frame_length_in_sec (float): frame length. out_dir (str): output dir of generated table/csv file. num_workers(float): number of process for multiprocessing @@ -840,12 +820,10 @@ def vad_tune_threshold_on_dev( num_workers: int = 20, ) -> Tuple[dict, dict]: """ - Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate - (DetER) in thresholds. + Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate (DetER) in thresholds. Args: params (dict): dictionary of parameters to be tuned on. - vad_pred_method (str): suffix of prediction file. Use to locate file. - Should be either in "frame", "mean" or "median". + vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median". groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them. focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" frame_length_in_sec (float): frame length. 
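A hedged sketch of the grid-search idea behind vad_tune_threshold_on_dev: enumerate candidate onset/offset values and keep the pair with the lowest error. The scoring function here is a synthetic stand-in for the real detection error rate:

# Enumerate threshold combinations with sklearn's ParameterGrid and pick the best.
from sklearn.model_selection import ParameterGrid

def fake_deter(onset: float, offset: float) -> float:
    return abs(onset - 0.6) + abs(offset - 0.45)      # pretend dev-set error

grid = ParameterGrid({"onset": [0.4, 0.5, 0.6, 0.7], "offset": [0.35, 0.45, 0.55]})
best = min(grid, key=lambda p: fake_deter(**p))
print(best)   # {'offset': 0.45, 'onset': 0.6}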
@@ -936,8 +914,7 @@ def check_if_param_valid(params: dict) -> bool: for j in params[i]: if not j >= 0: raise ValueError( - "Invalid inputs! All float parameters except pad_onset and pad_offset should be \ - larger than 0!" + "Invalid inputs! All float parameters except pad_onset and pad_offset should be larger than 0!" ) if not (all(i <= 1 for i in params['onset']) and all(i <= 1 for i in params['offset'])): @@ -995,7 +972,7 @@ def plot( unit_frame_len: float = 0.01, label_repeat: int = 1, xticks_step: int = 5, -) -> "ipd.Audio": +) -> ipd.Audio: """ Plot Audio and/or VAD output and/or groundtruth labels for visualization Args: @@ -1009,13 +986,9 @@ def plot( threshold (float): threshold for prediction score (from 0 to 1). per_args(dict): a dict that stores the thresholds for postprocessing. unit_frame_len (float): unit frame length in seconds for VAD predictions. - label_repeat (int): repeat the label for this number of times to match different \ - frame lengths in preds and labels. + label_repeat (int): repeat the label for this number of times to match different frame lengths in preds and labels. xticks_step (int): step size for xticks. """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load( @@ -1281,8 +1254,7 @@ def stitch_segmented_asr_output( fout.flush() logging.info( - f"Finish stitch segmented ASR output to {stitched_output_manifest}, \ - the speech segments info has been stored in directory {speech_segments_tensor_dir}" + f"Finish stitch segmented ASR output to {stitched_output_manifest}, the speech segments info has been stored in directory {speech_segments_tensor_dir}" ) return stitched_output_manifest @@ -1462,13 +1434,10 @@ def plot_sample_from_rttm( show: bool = True, offset: float = 0.0, unit_frame_len: float = 0.01, -) -> "ipd.Audio": +): """ Plot audio signal and frame-level labels from RTTM file """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load(path=audio_file, sr=16000, mono=True, offset=offset, duration=max_duration) @@ -1503,9 +1472,8 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms). - The threshold 0.2 is not important, since the actual ratio will always be close to an integer - unless using frame/label. lengths that are not multiples of each other - (e.g., 15ms frame length and 20ms label length), which is not valid. + The threshold 0.2 is not important, since the actual ratio will always be close to an integer unless using frame/label + lengths that are not multiples of each other (e.g., 15ms frame length and 20ms label length), which is not valid. The value 0.2 here is just for easier unit testing. 
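A toy sketch of the label-to-frame alignment described above, for the clean case where the frame rate is an exact multiple of the label rate:

# Each coarse label is repeated to match the finer frame resolution.
import torch

labels = torch.tensor([0, 1, 1, 0])        # e.g. 20 ms label resolution
frames_per_label = 2                       # e.g. 10 ms frame resolution
frame_labels = labels.repeat_interleave(frames_per_label)
print(frame_labels.tolist())               # [0, 0, 1, 1, 1, 1, 0, 0]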
Args: probs (List[float]): list of probabilities @@ -1543,13 +1511,11 @@ def align_labels_to_frames(probs, labels, threshold=0.2): ratio = frames_len / labels_len res = frames_len % labels_len if ceil(ratio) - ratio < threshold: - # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a - # multiple of 2, and discard the redundant labels + # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a multiple of 2, and discard the redundant labels labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist() labels = labels[:frames_len] else: - # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of - # 2 and add additional labels + # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of 2 and add additional labels labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist() if res > 0: labels += labels[-res:] @@ -1743,3 +1709,51 @@ def frame_vad_eval_detection_error( auroc = roc_auc_score(y_true=all_labels, y_score=all_probs) report = metric.report(display=False) return auroc, report + + +def ts_vad_post_processing( + ts_vad_binary_vec: torch.Tensor, + cfg_vad_params: OmegaConf, + unit_10ms_frame_count: int=8, + bypass_postprocessing: bool = False + ): + """ + Post-processing on diarization results using VAD style post-processing methods. + These post-processing methods are inspired by the following paper: + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). + + Args: + ts_vad_binary_vec (Tensor): + Sigmoid values of each frame and each speaker. + Dimension: (num_frames,) + cfg_vad_params (OmegaConf): + Configuration (omega config) of VAD parameters. + unit_10ms_frame_count (int, optional): + an integer indicating the number of 10ms frames in a unit. + For example, if unit_10ms_frame_count is 8, then each frame is 0.08 seconds. + bypass_postprocessing (bool, optional): + If True, diarization post-processing will be bypassed. + + Returns: + speech_segments (Tensor): + start and end of each speech segment. + Dimension: (num_segments, 2) + + Example: + tensor([[ 0.0000, 3.0400], + [ 6.0000, 6.0800], + ... 
+ [587.3600, 591.0400], + [591.1200, 597.7600]]) + """ + ts_vad_binary_frames = torch.repeat_interleave(ts_vad_binary_vec, unit_10ms_frame_count) + if not bypass_postprocessing: + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + speech_segments = filtering(speech_segments, cfg_vad_params) + else: + cfg_vad_params.onset=0.5 + cfg_vad_params.offset=0.5 + cfg_vad_params.pad_onset=0.0 + cfg_vad_params.pad_offset=0.0 + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + return speech_segments \ No newline at end of file From 9a468ac82e68458ceb9c5882975d887362a00c07 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 00:37:49 -0800 Subject: [PATCH 03/16] Name changes on yaml files and train example Signed-off-by: taejinp --- ...rtformer_diarizer_hybrid_loss_4spk-v1.yaml | 218 ++++++++++++++++++ ...rtformer_diar_4spk-v1_callhome-part1.yaml} | 1 - ...> sortformer_diar_4spk-v1_dihard-dev.yaml} | 2 +- .../sortformer_diar_encoder_infer.py | 132 ----------- ...oder_train.py => sortformer_diar_train.py} | 0 5 files changed, 219 insertions(+), 134 deletions(-) create mode 100644 examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml rename examples/speaker_tasks/diarization/conf/post_processing/{sortformer_diar_HL_callhome_part1.yaml => sortformer_diar_4spk-v1_callhome-part1.yaml} (85%) rename examples/speaker_tasks/diarization/conf/post_processing/{sortformer_diar_HL_dihard.yaml => sortformer_diar_4spk-v1_dihard-dev.yaml} (84%) delete mode 100644 examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py rename examples/speaker_tasks/diarization/neural_diarizer/{sortformer_diar_encoder_train.py => sortformer_diar_train.py} (100%) diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml new file mode 100644 index 000000000000..e44bae976729 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -0,0 +1,218 @@ +# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. +# Model name convention for Sortformer Diarizer: sortformer_diarizer____loss.yaml +# (Example) `sortformer_diarizer_FC18_TF18_hybrid_loss.yaml` has 18 layers for FastConformer and 18 layers of Transformer. +# Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. +# Example: a manifest line for training +# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"} +name: "SortFormerDiarizer" +sample_rate: 16000 +num_workers: 18 +batch_size: 8 + +model: + pil_weight: 0.5 + ats_weight: 0.5 + num_workers: ${num_workers} + fc_d_model: 512 + tf_d_model: 192 + max_num_of_spks: 4 # Number of speakers per model. This is currently fixed at 4. + session_len_sec: 90 + + train_ds: + manifest_filepath: ??? 
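For reference, the `manifest_filepath` entries expected by `train_ds`/`validation_ds` follow the JSON-lines format shown in the header comment of this config. A minimal sketch for producing such a manifest (all paths and durations below are placeholders):

import json

sessions = [
    {"audio_filepath": "/data/audio01.wav", "offset": 0.0, "duration": 90.0,
     "text": "-", "num_speakers": 2, "rttm_filepath": "/data/audio01.rttm"},
    {"audio_filepath": "/data/audio02.wav", "offset": 390.83, "duration": 90.0,
     "text": "-", "num_speakers": 4, "rttm_filepath": "/data/audio02.rttm"},
]

with open("diar_train_manifest.json", "w") as f:
    for entry in sessions:
        f.write(json.dumps(entry) + "\n")  # one JSON object per line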
+ sample_rate: ${sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: ${model.session_len_sec} + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: True + num_workers: ${num_workers} + validation_mode: False + # lhotse config + use_lhotse: False + use_bucketing: True + num_buckets: 10 + bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90] + pin_memory: True + min_duration: 80 + max_duration: 90 + batch_duration: 400 + quadratic_duration: 1200 + bucket_buffer_size: 20000 + shuffle_buffer_size: 10000 + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + validation_ds: + manifest_filepath: ??? + is_tarred: False + tarred_audio_filepaths: null + sample_rate: ${sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: ${model.session_len_sec} + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + test_ds: + manifest_filepath: null + is_tarred: False + tarred_audio_filepaths: null + sample_rate: 16000 + num_spks: ${model.max_num_of_spks} + session_len_sec: ${model.session_len_sec} + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + seq_eval_mode: True + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + normalize: "per_feature" + window_size: 0.025 + sample_rate: ${sample_rate} + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + + sortformer_modules: + _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules + num_spks: ${model.max_num_of_spks} # Number of speakers per model. This is currently fixed at 4. 
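The `_target_` entries in this config are resolved with Hydra. A rough sketch of how the preprocessor and the Sortformer modules could be built directly from the YAML, assuming a NeMo tree that already contains this PR (the model class normally does this internally, so this is only for inspection):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("sortformer_diarizer_hybrid_loss_4spk-v1.yaml")

# Interpolations such as ${model.fc_d_model} resolve against the root of the loaded config.
preprocessor = instantiate(cfg.model.preprocessor)              # AudioToMelSpectrogramPreprocessor
sortformer_modules = instantiate(cfg.model.sortformer_modules)  # SortformerModules added in this PR
encoder = instantiate(cfg.model.encoder)                        # 18-layer FastConformer encoder
print(type(preprocessor).__name__, type(sortformer_modules).__name__, type(encoder).__name__)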
+ dropout_rate: 0.5 # Dropout rate + fc_d_model: ${model.fc_d_model} + tf_d_model: ${model.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 + n_layers: 18 + d_model: ${model.fc_d_model} + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + transformer_encoder: + _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder + num_layers: 18 + hidden_size: ${model.tf_d_model} # Needs to be multiple of num_attention_heads + inner_size: 768 + num_attention_heads: 8 + attn_score_dropout: 0.5 + attn_layer_dropout: 0.5 + ffn_dropout: 0.5 + hidden_act: relu + pre_ln: False + pre_ln_final_layer_norm: True + + loss: + _target_: nemo.collections.asr.losses.bce_loss.BCELoss + weight: null # Weight for binary cross-entropy loss. Either `null` or list type input. (e.g. [0.5,0.5]) + reduction: mean + + lr: 0.0001 + optim: + name: adamw + lr: ${model.lr} + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + sched: + name: InverseSquareRootAnnealing + warmup_steps: 2500 + warmup_ratio: null + min_lr: 1e-06 + +trainer: + devices: 1 # number of gpus (devices) + accelerator: gpu + max_epochs: 800 + max_steps: -1 # computed at runtime if not set + num_nodes: 1 + strategy: ddp_find_unused_parameters_true # Could be "ddp" + accumulate_grad_batches: 1 + deterministic: True + enable_checkpointing: False + logger: False + log_every_n_steps: 1 # Interval of logging. + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + +exp_manager: + use_datetime_version: False + exp_dir: null + name: ${name} + resume_if_exists: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
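The trainer and exp_manager sections above are consumed by the training entry point. A rough outline of what that entry point looks like, based on the commented-out pattern in the removed inference script further below; the actual sortformer_diar_train.py may differ in details:

import pytorch_lightning as pl

from nemo.collections.asr.models import SortformerEncLabelModel
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(
    config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml"
)
def main(cfg):
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer)
    trainer.fit(model)


if __name__ == "__main__":
    main()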
+ resume_ignore_no_checkpoint: True + create_tensorboard_logger: True + create_checkpoint_callback: True + create_wandb_logger: False + checkpoint_callback_params: + monitor: "val_f1_acc" + mode: "max" + save_top_k: 9 + every_n_epochs: 1 + wandb_logger_kwargs: + resume: True + name: null + project: null \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml similarity index 85% rename from examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml rename to examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml index 6b960e2d5950..3733e1285b77 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_callhome_part1.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml @@ -3,7 +3,6 @@ # Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). # These parameters were optimized with with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. # These parameters were optimized on the development split of DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477. -# Trial 17903 finished with value: 0.10261257411949805 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.0, 'min_duration_on': 0.39, 'min_duration_off': 0.39}. Best is trial 17903 with value: 0.10261257411949805. # Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. parameters: window_length_in_sec: 0.0 # Not used diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml similarity index 84% rename from examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml rename to examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml index bb9f362ad619..275bc86db4cd 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_HL_dihard.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml @@ -3,7 +3,7 @@ # Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). # These parameters were optimized with with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. # These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the split2 specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2. -# Trial 180 finished with value: 0.12329626986650599 and parameters: {'onset': 0.56, 'offset': 0.81, 'pad_onset': 0.05, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.16}. Best is trial 180 with value: 0.12329626986650599. 
+# Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649. parameters: window_length_in_sec: 0.0 # Not used shift_length_in_sec: 0.0 # Not used diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py deleted file mode 100644 index aafd2b2cb6ed..000000000000 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_infer.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytorch_lightning as pl -from omegaconf import OmegaConf -from pytorch_lightning import seed_everything -import seaborn as sns -import numpy as np - -from nemo.collections.asr.models import SortformerEncLabelModel -from nemo.core.config import hydra_runner -from nemo.utils import logging -from nemo.utils.exp_manager import exp_manager -seed_everything(42) -import torch -import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.manifold import TSNE -import pandas as pd -from nemo.collections.asr.data.audio_to_msdd_mock_label import generate_mock_embs - -def plot_enc_tsne(x, targets, memo): - # x = enc_states_list[-1].squeeze(0).cpu().detach().numpy() - tsne = TSNE(n_components=2, verbose=False, random_state=100) - zembs = tsne.fit_transform(x) - - # Step 1: Create a new column filled with 0.5 - new_column = torch.full((targets.size(0), 1), 0.5) - # Step 2: Concatenate the new column with the original tensor - updated_targets = torch.cat((new_column, targets), dim=1) - - df = pd.DataFrame() - df["y"] = updated_targets.argmax(dim=1).detach().cpu().numpy() - df["comp-1"] = zembs[:,0] - df["comp-2"] = zembs[:,1] - - # Plotting using seaborn - plt.figure(figsize=(10, 8)) - sns.scatterplot(x="comp-1", y="comp-2", hue=df.y.tolist(), - palette=sns.color_palette("hls", 10), - data=df).set(title="SortFormer HiddenState T-SNE projection") - - # Save the plot as a PNG file in the specified directory - plt.savefig(f'/home/taejinp/Downloads/tsne_plots/tsne_sortformer_plot_{memo}.png') - -def remove_speaker_models(ckpt_path): - ckpt_instance = torch.load(ckpt_path) - _state_dict = ckpt_instance['state_dict'] - - key_list = list(_state_dict.keys()) - for key in key_list: - if '_speaker_model.' in key or '_speaker_model_decoder.' 
in key: - # import ipdb; ipdb.set_trace() - del _state_dict[key] - - target_path = ckpt_path.replace('.ckpt', '.removed.ckpt') - torch.save(ckpt_instance, target_path) - return target_path - - -# @hydra_runner(config_path="../conf/neural_diarizer", config_name="msdd_5scl_15_05_50Povl_256x3x32x2.yaml") -def main(): - # logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') - # trainer = pl.Trainer(**cfg.trainer) - # exp_manager(trainer, cfg.get("exp_manager", None)) - # ckpt_path = "/disk_c/taejinp_backup/msdd_model_train/NVB_SFmr_MixMockEmbsTest/version_18_f0:84/checkpoints/e613.ckpt" - ckpt_path = "/disk_c/taejinp_backup/msdd_model_train/SFmr_MixMockEmbsTest/version_21/checkpoints/ep2255.ckpt" - target_path = remove_speaker_models(ckpt_path) - sortformer_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=target_path) - unit_len = 25 - targets = torch.eye(4,4).repeat_interleave(unit_len,1).t() - targets[:,2:] = 0 - # targets[:,3:] = 0 - targets = targets[:2*unit_len, :] - new_column = torch.full((targets.size(0), 1), 0.5) - updated_targets = torch.cat((new_column, targets), dim=1) - mock_embs, audio_signal_length, targets = generate_mock_embs(targets=targets, seed=315, - mock_emb_noise_std=0.03, - mock_emb_degree_of_freedom=4, - min_noise_std=0.01,) - mock_embs = mock_embs.unsqueeze(0) - audio_signal = mock_embs - - audio_signal, audio_signal_length, targets - - audio_signal = audio_signal.cuda() - ms_seg_counts = torch.tensor([]).cuda() - ms_seg_timestamps = torch.tensor([]).cuda() - scale_mapping = torch.tensor([]).cuda() - sortformer_model.alpha = 0.0 - - _preds_mean, preds_, attn_score_stack, enc_states_list, preds_list = sortformer_model.forward( - audio_signal=audio_signal, - audio_signal_length=audio_signal_length, - ms_seg_timestamps=ms_seg_timestamps, - ms_seg_counts=ms_seg_counts, - scale_mapping=scale_mapping, - temp_targets=targets, - ) - - audio_signal_np = audio_signal.squeeze(0).cpu().detach().numpy() - plot_enc_tsne(audio_signal_np, targets, memo=f'input', ) - for layer_c in range(len(enc_states_list)): - print(f"Plotting TSNE for layer {layer_c} ...") - x = enc_states_list[layer_c].squeeze(0).cpu().detach().numpy() - plot_enc_tsne(x, targets, memo=f'layer{layer_c}', ) - preds = preds_.squeeze(0).cpu().detach().numpy() - plot_enc_tsne(preds, targets, memo=f'preds', ) - _preds_mean = _preds_mean.squeeze(0).cpu().detach().numpy() - plot_enc_tsne(_preds_mean, targets, memo=f'preds_mean', ) - - # Optionally, you can also show the plot if desired - plt.show() - import ipdb; ipdb.set_trace() - - # msdd_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) - # trainer.fit(msdd_model) - - -if __name__ == '__main__': - main() diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py similarity index 100% rename from examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_encoder_train.py rename to examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py From 2f44fe1fd6526ba1b691ccdd57b3dadc22cef4b0 Mon Sep 17 00:00:00 2001 From: tango4j Date: Thu, 14 Nov 2024 09:08:01 +0000 Subject: [PATCH 04/16] Apply isort and black reformatting Signed-off-by: tango4j --- .../neural_diarizer/e2e_diarize_speech.py | 238 ++++--- .../neural_diarizer/sortformer_diar_train.py | 2 +- .../asr/data/audio_to_diar_label.py | 182 ++--- .../asr/data/audio_to_diar_label_lhotse.py | 25 +- nemo/collections/asr/metrics/der.py | 32 +- 
.../asr/metrics/multi_binary_acc.py | 7 +- nemo/collections/asr/models/__init__.py | 2 +- .../asr/models/sortformer_diar_models.py | 156 +++-- .../asr/modules/sortformer_modules.py | 7 +- .../asr/parts/utils/asr_multispeaker_utils.py | 636 +++++++++++------- .../asr/parts/utils/speaker_utils.py | 128 ++-- nemo/collections/asr/parts/utils/vad_utils.py | 41 +- .../common/parts/preprocessing/collections.py | 16 +- 13 files changed, 850 insertions(+), 622 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 98f2ee10e523..40ed9fab7a64 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -19,32 +19,31 @@ dataset_manifest=/path/to/diarization_path_to_manifest.json """ +import logging +import os +import tempfile +from dataclasses import dataclass, is_dataclass +from typing import Dict, List, Optional, Tuple, Union + +import optuna import pytorch_lightning as pl +import torch +import yaml +from hydra.core.config_store import ConfigStore from omegaconf import OmegaConf from pytorch_lightning import seed_everything +from tqdm import tqdm -from nemo.collections.asr.models import SortformerEncLabelModel -from nemo.core.config import hydra_runner from nemo.collections.asr.metrics.der import score_labels -from hydra.core.config_store import ConfigStore - -import os -import yaml -from dataclasses import dataclass, is_dataclass -from typing import Optional, Union, List, Tuple, Dict - +from nemo.collections.asr.models import SortformerEncLabelModel from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing - -from tqdm import tqdm -import torch -import logging -import optuna -import tempfile +from nemo.core.config import hydra_runner seed_everything(42) torch.backends.cudnn.deterministic = True + @dataclass class PostProcessingParams: window_length_in_sec: float = 0.15 @@ -59,6 +58,7 @@ class PostProcessingParams: min_duration_off: float = 0.0 filter_speech_first: bool = True + @dataclass class DiarizationConfig: # Required configs @@ -66,50 +66,53 @@ class DiarizationConfig: pretrained_name: Optional[str] = None # Name of a pretrained model audio_dir: Optional[str] = None # Path to a directory which contains audio files dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest - + postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations no_der: bool = False out_rttm_dir: Optional[str] = None - + # General configs - session_len_sec: float = -1 # End-to-end diarization session length in seconds + session_len_sec: float = -1 # End-to-end diarization session length in seconds batch_size: int = 4 num_workers: int = 0 random_seed: Optional[int] = None # seed number going to be used in seed_everything() - bypass_postprocessing: bool = True # If True, postprocessing will be bypassed - + bypass_postprocessing: bool = True # If True, postprocessing will be bypassed + # Eval Settings: (0.25, False) should be default setting for sortformer eval. 
- collar: float = 0.25 # Collar in seconds for DER calculation - ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments + collar: float = 0.25 # Collar in seconds for DER calculation + ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments # If `cuda` is a negative number, inference will be on CPU only. cuda: Optional[int] = None matmul_precision: str = "highest" # Literal["highest", "high", "medium"] # Optuna Config - launch_pp_optim: bool = False # If True, launch optimization process for postprocessing parameters + launch_pp_optim: bool = False # If True, launch optimization process for postprocessing parameters optuna_study_name: str = "optim_postprocessing" optuna_temp_dir: str = "/tmp/optuna" optuna_storage: str = f"sqlite:///{optuna_study_name}.db" optuna_log_file: str = f"{optuna_study_name}.log" optuna_n_trials: int = 100000 + def load_postprocessing_from_yaml(postprocessing_yaml): - """ + """ Load postprocessing parameters from a YAML file. Args: - postprocessing_yaml (str): + postprocessing_yaml (str): Path to a YAML file for postprocessing configurations. Returns: - postprocessing_params (dataclass): + postprocessing_params (dataclass): Postprocessing parameters loaded from the YAML file. """ # Add PostProcessingParams as a field postprocessing_params = OmegaConf.structured(PostProcessingParams()) if postprocessing_yaml is None: - logging.info(f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied.") + logging.info( + f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied." + ) else: # Load postprocessing params from the provided YAML file with open(postprocessing_yaml, 'r') as file: @@ -121,6 +124,7 @@ def load_postprocessing_from_yaml(postprocessing_yaml): setattr(postprocessing_params, key, value) return postprocessing_params + def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: """ Suggests hyperparameters for postprocessing using Optuna. @@ -140,6 +144,7 @@ def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optun postprocessing_cfg.min_duration_off = trial.suggest_float("min_duration_off", 0.0, 0.75, step=0.01) return postprocessing_cfg + def get_tensor_path(cfg: DiarizationConfig) -> str: """ Constructs the file path for saving or loading prediction tensors based on the configuration. @@ -159,20 +164,21 @@ def get_tensor_path(cfg: DiarizationConfig) -> str: tensor_path = f"{bpath}/__{model_id}__{tensor_filename}.pt" return tensor_path + def diarization_objective( - trial: optuna.Trial, - postprocessing_cfg: PostProcessingParams, - temp_out_dir: str, - infer_audio_rttm_dict: Dict[str, Dict[str, str]], - diar_model_preds_total_list: List[torch.Tensor], - collar: float = 0.25, - ignore_overlap: bool = False + trial: optuna.Trial, + postprocessing_cfg: PostProcessingParams, + temp_out_dir: str, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + diar_model_preds_total_list: List[torch.Tensor], + collar: float = 0.25, + ignore_overlap: bool = False, ) -> float: """ Objective function for Optuna hyperparameter optimization in speaker diarization. This function evaluates the diarization performance using a set of postprocessing parameters - suggested by Optuna. It converts prediction matrices to time-stamp segments, scores the + suggested by Optuna. 
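For orientation, here is a self-contained Optuna loop in the same shape as the search wired up in this script. The objective below is a dummy stand-in for the real DER computation, and the onset/offset ranges are illustrative (only the min_duration_off range is taken from the code above):

import optuna

def toy_objective(trial: optuna.Trial) -> float:
    # In e2e_diarize_speech.py the suggested values populate PostProcessingParams and the
    # objective returns the DER from score_labels; here we return a dummy score instead.
    onset = trial.suggest_float("onset", 0.3, 0.9, step=0.01)
    offset = trial.suggest_float("offset", 0.3, 0.9, step=0.01)
    min_duration_off = trial.suggest_float("min_duration_off", 0.0, 0.75, step=0.01)
    return abs(onset - 0.64) + abs(offset - 0.74) + abs(min_duration_off - 0.15)

study = optuna.create_study(direction="minimize", study_name="optim_postprocessing")
study.optimize(toy_objective, n_trials=50)
print(study.best_params)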
It converts prediction matrices to time-stamp segments, scores the diarization results, and returns the Diarization Error Rate (DER) as the optimization metric. Args: @@ -192,42 +198,43 @@ def diarization_objective( """ with tempfile.TemporaryDirectory(dir=temp_out_dir, prefix="Diar_PostProcessing_") as local_temp_out_dir: if trial is not None: - postprocessing_cfg = optuna_suggest_params(postprocessing_cfg, trial) - all_hyps, all_refs, all_uems = convert_pred_mat_to_segments(audio_rttm_map_dict=infer_audio_rttm_dict, - postprocessing_cfg=postprocessing_cfg, - batch_preds_list=diar_model_preds_total_list, - unit_10ms_frame_count=8, - bypass_postprocessing=False) - metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, - all_reference=all_refs, - all_hypothesis=all_hyps, - all_uem=all_uems, - collar=collar, - ignore_overlap=ignore_overlap - ) + postprocessing_cfg = optuna_suggest_params(postprocessing_cfg, trial) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments( + audio_rttm_map_dict=infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=False, + ) + metric, mapping_dict, itemized_errors = score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=collar, + ignore_overlap=ignore_overlap, + ) der = abs(metric) return der + def run_optuna_hyperparam_search( cfg: DiarizationConfig, # type: DiarizationConfig postprocessing_cfg: PostProcessingParams, - infer_audio_rttm_dict: Dict[str, Dict[str, str]], - preds_list: List[torch.Tensor], - temp_out_dir: str, - ): + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + preds_list: List[torch.Tensor], + temp_out_dir: str, +): worker_function = lambda trial: diarization_objective( trial=trial, postprocessing_cfg=postprocessing_cfg, temp_out_dir=temp_out_dir, - infer_audio_rttm_dict=infer_audio_rttm_dict, + infer_audio_rttm_dict=infer_audio_rttm_dict, diar_model_preds_total_list=preds_list, collar=cfg.collar, ) study = optuna.create_study( - direction="minimize", - study_name=cfg.optuna_study_name, - storage=cfg.optuna_storage, - load_if_exists=True + direction="minimize", study_name=cfg.optuna_study_name, storage=cfg.optuna_storage, load_if_exists=True ) logger = logging.getLogger() logger.setLevel(logging.INFO) # Setup the root logger. @@ -235,17 +242,17 @@ def run_optuna_hyperparam_search( logger.addHandler(logging.FileHandler(cfg.optuna_log_file, mode="a")) logger.addHandler(logging.StreamHandler()) optuna.logging.enable_propagation() # Propagate logs to the root logger. - study.optimize(worker_function, n_trials=cfg.optuna_n_trials) + study.optimize(worker_function, n_trials=cfg.optuna_n_trials) def convert_pred_mat_to_segments( - audio_rttm_map_dict: Dict[str, Dict[str, str]], - postprocessing_cfg, - batch_preds_list: List[torch.Tensor], - unit_10ms_frame_count:int = 8, + audio_rttm_map_dict: Dict[str, Dict[str, str]], + postprocessing_cfg, + batch_preds_list: List[torch.Tensor], + unit_10ms_frame_count: int = 8, bypass_postprocessing: bool = False, out_rttm_dir: str | None = None, - ): +): """ Convert prediction matrix to time-stamp segments. 
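As a quick reference, this is roughly how a single speaker channel flows through ts_vad_post_processing with one of the post-processing YAMLs from this PR. Random scores stand in for real model output, and the "Not used" YAML fields are omitted:

import torch
from omegaconf import OmegaConf
from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing

# Values copied from sortformer_diar_4spk-v1_callhome-part1.yaml; filter_speech_first follows
# the PostProcessingParams default in e2e_diarize_speech.py.
cfg_vad_params = OmegaConf.create(
    {
        "onset": 0.53, "offset": 0.49,
        "pad_onset": 0.23, "pad_offset": 0.01,
        "min_duration_on": 0.42, "min_duration_off": 0.34,
        "filter_speech_first": True,
    }
)

# Random sigmoid scores for one speaker at 80 ms resolution (unit_10ms_frame_count=8),
# standing in for one column of a real Sortformer prediction matrix.
spk_scores = torch.rand(1250)  # roughly 100 seconds of audio

segments = ts_vad_post_processing(
    spk_scores,
    cfg_vad_params=cfg_vad_params,
    unit_10ms_frame_count=8,
    bypass_postprocessing=False,
)
print(segments)  # (num_segments, 2) tensor of [start, end] times in seconds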
@@ -263,32 +270,38 @@ def convert_pred_mat_to_segments( """ batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], [] cfg_vad_params = OmegaConf.structured(postprocessing_cfg) - for sample_idx, (uniq_id, audio_rttm_values) in tqdm(enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc="Running post-processing"): + for sample_idx, (uniq_id, audio_rttm_values) in tqdm( + enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc="Running post-processing" + ): spk_ts = [] offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration'] speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0) speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] for spk_id in range(speaker_assign_mat.shape[-1]): - ts_mat = ts_vad_post_processing(speaker_assign_mat[:, spk_id], - cfg_vad_params=cfg_vad_params, - unit_10ms_frame_count=unit_10ms_frame_count, - bypass_postprocessing=bypass_postprocessing) + ts_mat = ts_vad_post_processing( + speaker_assign_mat[:, spk_id], + cfg_vad_params=cfg_vad_params, + unit_10ms_frame_count=unit_10ms_frame_count, + bypass_postprocessing=bypass_postprocessing, + ) ts_mat = ts_mat + offset ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration)) ts_seg_list = ts_mat.tolist() speaker_timestamps[spk_id].extend(ts_seg_list) spk_ts.append(ts_seg_list) - all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object(speaker_timestamps, - uniq_id, - audio_rttm_values, - all_hypothesis, - all_reference, - all_uems, - out_rttm_dir, - ) - batch_pred_ts_segs.append(spk_ts) + all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object( + speaker_timestamps, + uniq_id, + audio_rttm_values, + all_hypothesis, + all_reference, + all_uems, + out_rttm_dir, + ) + batch_pred_ts_segs.append(spk_ts) return all_hypothesis, all_reference, all_uems + @hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: for key in cfg: @@ -299,7 +312,7 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: if cfg.random_seed: pl.seed_everything(cfg.random_seed) - + if cfg.model_path is None and cfg.pretrained_name is None: raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!") if cfg.audio_dir is None and cfg.dataset_manifest is None: @@ -322,65 +335,74 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: map_location = torch.device(f'cuda:{cfg.cuda}') if cfg.model_path.endswith(".ckpt"): - diar_model = SortformerEncLabelModel.load_from_checkpoint(checkpoint_path=cfg.model_path, map_location=map_location, strict=False) + diar_model = SortformerEncLabelModel.load_from_checkpoint( + checkpoint_path=cfg.model_path, map_location=map_location, strict=False + ) elif cfg.model_path.endswith(".nemo"): diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.model_path, map_location=map_location) else: raise ValueError("cfg.model_path must end with.ckpt or.nemo!") - + diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec trainer = pl.Trainer(devices=device, accelerator=accelerator) diar_model.set_trainer(trainer) - + diar_model = diar_model.eval() diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest) diar_model._cfg.test_ds.batch_size = cfg.batch_size - - # Model setup for inference + + # Model setup for inference diar_model._cfg.test_ds.num_workers = cfg.num_workers - 
diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds) - + diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds) + postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml) tensor_path = get_tensor_path(cfg) - + if os.path.exists(tensor_path): - logging.info(f"A saved prediction tensor has been found. Loading the saved prediction tensors from {tensor_path}...") + logging.info( + f"A saved prediction tensor has been found. Loading the saved prediction tensors from {tensor_path}..." + ) diar_model_preds_total_list = torch.load(tensor_path) else: logging.info(f"No saved prediction tensors found. Running inference on the dataset...") diar_model.test_batch() diar_model_preds_total_list = diar_model.preds_total_list torch.save(diar_model.preds_total_list, tensor_path) - + if cfg.launch_pp_optim: # Launch a hyperparameter optimization process if launch_pp_optim is True - run_optuna_hyperparam_search(cfg=cfg, - postprocessing_cfg=postprocessing_cfg, - infer_audio_rttm_dict=infer_audio_rttm_dict, - preds_list=diar_model_preds_total_list, - temp_out_dir=cfg.optuna_temp_dir) + run_optuna_hyperparam_search( + cfg=cfg, + postprocessing_cfg=postprocessing_cfg, + infer_audio_rttm_dict=infer_audio_rttm_dict, + preds_list=diar_model_preds_total_list, + temp_out_dir=cfg.optuna_temp_dir, + ) # Evaluation if not cfg.no_der: if cfg.out_rttm_dir is not None and not os.path.exists(cfg.out_rttm_dir): os.mkdir(cfg.out_rttm_dir) - all_hyps, all_refs, all_uems = convert_pred_mat_to_segments(infer_audio_rttm_dict, - postprocessing_cfg=postprocessing_cfg, - batch_preds_list=diar_model_preds_total_list, - unit_10ms_frame_count=8, - bypass_postprocessing=cfg.bypass_postprocessing, - out_rttm_dir=cfg.out_rttm_dir - ) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments( + infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=cfg.bypass_postprocessing, + out_rttm_dir=cfg.out_rttm_dir, + ) logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") - metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, - all_reference=all_refs, - all_hypothesis=all_hyps, - all_uem=all_uems, - collar=cfg.collar, - ignore_overlap=cfg.ignore_overlap - ) + metric, mapping_dict, itemized_errors = score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap, + ) logging.info(f"PostProcessingParams: {postprocessing_cfg}") + if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index fb350113d596..3ba0dbc3ed19 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -50,5 +50,5 @@ def main(cfg): if __name__ == '__main__': - + main() diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index ffad8e4fd072..b00338743a43 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -15,18 +15,23 @@ import os from collections import OrderedDict from statistics import mode -from 
typing import Dict, List, Tuple, Optional -import torch +from typing import Dict, List, Optional, Tuple + import numpy as np +import torch -from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat from nemo.collections.asr.parts.utils.asr_multispeaker_utils import find_first_nonzero -from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data, get_subsegments -from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel, EndtoEndDiarizationSpeechLabel +from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat +from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, get_subsegments, prepare_split_data +from nemo.collections.common.parts.preprocessing.collections import ( + DiarizationSpeechLabel, + EndtoEndDiarizationSpeechLabel, +) from nemo.core.classes import Dataset from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType from nemo.utils import logging + def get_scale_mapping_list(uniq_timestamps): """ Call get_argmin_mat function to find the index of the non-base-scale segment that is closest to the @@ -125,7 +130,7 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, return None else: sorted_speakers = sorted(list(set(speaker_list))) - total_fr_len = int(max(end_list) * (10 ** round_digits)) + total_fr_len = int(max(end_list) * (10**round_digits)) spk_num = max(len(sorted_speakers), min_spks) speaker_mapping_dict = {rttm_key: x_int for x_int, rttm_key in enumerate(sorted_speakers)} fr_level_target = torch.zeros(total_fr_len, spk_num) @@ -141,27 +146,24 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, def get_subsegments_to_timestamps( - subsegments: List[Tuple[float, float]], - feat_per_sec: int = 100, - max_end_ts: float=None, - decimals=2 - ): + subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2 +): """ Convert subsegment timestamps to scale timestamps by multiplying with the feature rate and rounding. All `ts` related tensors are dimensioned as (N, 2), where N is the number of subsegments. Args: - subsegments (List[Tuple[float, float]]): + subsegments (List[Tuple[float, float]]): A list of tuples where each tuple contains the start and end times of a subsegment. - feat_per_sec (int, optional): + feat_per_sec (int, optional): The number of feature frames per second. Defaults to 100. - max_end_ts (float, optional): + max_end_ts (float, optional): The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None. - decimals (int, optional): + decimals (int, optional): The number of decimal places to round the timestamps. Defaults to 2. Returns: - ts (torch.tensor): + ts (torch.tensor): A tensor containing the scaled and rounded timestamps for each subsegment. 
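A small numeric check of what this conversion produces, using hypothetical (start, duration) subsegments and assuming the patched NeMo tree so the function is importable:

from nemo.collections.asr.data.audio_to_diar_label import get_subsegments_to_timestamps

# Two overlapping 0.16 s subsegments starting at 0.00 s and 0.08 s, at 100 feature frames per second.
ts = get_subsegments_to_timestamps([(0.0, 0.16), (0.08, 0.16)], feat_per_sec=100)
print(ts)  # expected: [[0, 16], [8, 24]] -> [start_frame, end_frame] in 10 ms feature frames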
""" seg_ts = (torch.tensor(subsegments) * feat_per_sec).float() @@ -169,8 +171,9 @@ def get_subsegments_to_timestamps( ts = ts_round.long() ts[:, 1] = ts[:, 0] + ts[:, 1] if max_end_ts is not None: - ts = np.clip(ts, 0, int(max_end_ts*feat_per_sec)) - return ts + ts = np.clip(ts, 0, int(max_end_ts * feat_per_sec)) + return ts + def extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines, round_digits=3): """ @@ -190,42 +193,43 @@ def extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines, round_di rttm_stt, rttm_end = offset, offset + duration stt_list, end_list, speaker_list, speaker_set = [], [], [], [] sess_to_global_spkids = dict() - + for rttm_line in rttm_lines: start, end, speaker = convert_rttm_line(rttm_line) - + # Skip invalid RTTM lines where the start time is greater than the end time. if start > end: continue - + # Check if the RTTM segment overlaps with the specified segment of interest. if (end > rttm_stt and start < rttm_end) or (start < rttm_end and end > rttm_stt): # Adjust the start and end times to fit within the segment of interest. start, end = max(start, rttm_stt), min(end, rttm_end) else: continue - + # Round the start and end times to the specified number of decimal places. end_list.append(round(end, round_digits)) stt_list.append(round(start, round_digits)) - + # Assign a unique index to each speaker and maintain a mapping. if speaker not in speaker_set: speaker_set.append(speaker) speaker_list.append(speaker_set.index(speaker)) sess_to_global_spkids.update({speaker_set.index(speaker): speaker}) - + rttm_mat = (stt_list, end_list, speaker_list) return rttm_mat, sess_to_global_spkids + def get_frame_targets_from_rttm( - rttm_timestamps: list, - offset: float, - duration: float, - round_digits: int, - feat_per_sec: int, + rttm_timestamps: list, + offset: float, + duration: float, + round_digits: int, + feat_per_sec: int, max_spks: int, - ): +): """ Create a multi-dimensional vector sequence containing speaker timestamp information in RTTM. The unit-length is the frame shift length of the acoustic feature. The feature-level annotations @@ -249,15 +253,17 @@ def get_frame_targets_from_rttm( sorted_speakers = sorted(list(set(speaker_list))) total_fr_len = int(duration * feat_per_sec) if len(sorted_speakers) > max_spks: - logging.warning(f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: {max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!") - feat_level_target = torch.zeros(total_fr_len, max_spks) + logging.warning( + f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: {max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!" 
+ ) + feat_level_target = torch.zeros(total_fr_len, max_spks) for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)): if end < offset or stt > offset + duration: continue stt, end = max(offset, stt), min(offset + duration, end) spk = spk_rttm_key if spk < max_spks: - stt_fr, end_fr = int((stt - offset) * feat_per_sec), int((end - offset)* feat_per_sec) + stt_fr, end_fr = int((stt - offset) * feat_per_sec), int((end - offset) * feat_per_sec) feat_level_target[stt_fr:end_fr, spk] = 1 return feat_level_target @@ -337,7 +343,7 @@ def __init__( self.multiscale_args_dict = multiscale_args_dict self.emb_dir = emb_dir self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer self.max_spks = 2 @@ -347,7 +353,10 @@ def __init__( self.global_rank = global_rank self.manifest_filepath = manifest_filepath self.multiscale_timestamp_dict = prepare_split_data( - self.manifest_filepath, self.emb_dir, self.multiscale_args_dict, self.global_rank, + self.manifest_filepath, + self.emb_dir, + self.multiscale_args_dict, + self.global_rank, ) def __len__(self): @@ -364,7 +373,7 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): Unique sample ID for training. base_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for the base-scale segments. - + Returns: per_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for each segment in each scale. @@ -415,7 +424,7 @@ def get_diar_target_labels(self, uniq_id, sample, fr_level_target): seg_target_list, base_clus_label = [], [] self.scale_n = len(self.multiscale_timestamp_dict[uniq_id]['scale_dict']) subseg_time_stamp_list = self.multiscale_timestamp_dict[uniq_id]["scale_dict"][self.scale_n - 1]["time_stamps"] - for (seg_stt, seg_end) in subseg_time_stamp_list: + for seg_stt, seg_end in subseg_time_stamp_list: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec_sess = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -619,7 +628,7 @@ def __init__( self.emb_seq = emb_seq self.clus_label_dict = clus_label_dict self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.frame_per_sec = int(1 / window_stride) self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer @@ -685,7 +694,7 @@ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): return None else: seg_target_list = [] - for (seg_stt, seg_end, label_int) in self.clus_label_dict[uniq_id]: + for seg_stt, seg_end, label_int in self.clus_label_dict[uniq_id]: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -975,6 +984,7 @@ def __init__( def msdd_infer_collate_fn(self, batch): return _msdd_infer_collate_fn(self, batch) + class _AudioToSpeechE2ESpkDiarDataset(Dataset): """ Dataset class that loads a json file containing paths to audio files, @@ -1047,8 +1057,8 @@ def __init__( self.use_asr_style_frame_count = True self.soft_targets = soft_targets self.round_digits = 2 - self.floor_decimal = 10 ** self.round_digits - + self.floor_decimal = 10**self.round_digits + def __len__(self): return len(self.collection) @@ -1085,15 +1095,16 @@ def 
parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, rttm_lines = open(rttm_file).readlines() rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines) - fr_level_target = get_frame_targets_from_rttm(rttm_timestamps=rttm_timestamps, - offset=offset, - duration=duration, - round_digits=self.round_digits, - feat_per_sec=self.feat_per_sec, - max_spks=self.max_spks) + fr_level_target = get_frame_targets_from_rttm( + rttm_timestamps=rttm_timestamps, + offset=offset, + duration=duration, + round_digits=self.round_digits, + feat_per_sec=self.feat_per_sec, + max_spks=self.max_spks, + ) - soft_target_seg = self.get_soft_targets_seg(feat_level_target=fr_level_target, - target_len=target_len) + soft_target_seg = self.get_soft_targets_seg(feat_level_target=fr_level_target, target_len=target_len) if self.soft_targets: step_target = soft_target_seg else: @@ -1128,15 +1139,15 @@ def get_soft_targets_seg(self, feat_level_target, target_len): seg_end_feat = feat_level_target.shape[0] else: seg_end_feat = stride * index - 1 + int(stride / 2) - targets[index] = torch.mean(feat_level_target[seg_stt_feat:seg_end_feat+1, :], axis=0) + targets[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0) return targets def get_segment_timestamps( self, - duration: float, - offset: float = 0, + duration: float, + offset: float = 0, sample_rate: int = 16000, - ): + ): """ Get start and end time of segments in each scale. @@ -1150,22 +1161,28 @@ def get_segment_timestamps( Number of segments for each scale. This information is used for reshaping embedding batch during forward propagation. """ - subsegments = get_subsegments(offset=offset, - window=round(self.diar_frame_length * 2, self.round_digits), - shift=self.diar_frame_length, - duration=duration, - min_subsegment_duration=self.min_subsegment_duration, - use_asr_style_frame_count=self.use_asr_style_frame_count, - sample_rate=sample_rate, - feat_per_sec=self.feat_per_sec, + subsegments = get_subsegments( + offset=offset, + window=round(self.diar_frame_length * 2, self.round_digits), + shift=self.diar_frame_length, + duration=duration, + min_subsegment_duration=self.min_subsegment_duration, + use_asr_style_frame_count=self.use_asr_style_frame_count, + sample_rate=sample_rate, + feat_per_sec=self.feat_per_sec, ) if self.use_asr_style_frame_count: - effective_dur = np.ceil((1+duration*sample_rate)/int(sample_rate/self.feat_per_sec)).astype(int)/self.feat_per_sec + effective_dur = ( + np.ceil((1 + duration * sample_rate) / int(sample_rate / self.feat_per_sec)).astype(int) + / self.feat_per_sec + ) else: - effective_dur = duration - ts_tensor = get_subsegments_to_timestamps(subsegments, self.feat_per_sec, decimals=2, max_end_ts=(offset+effective_dur)) + effective_dur = duration + ts_tensor = get_subsegments_to_timestamps( + subsegments, self.feat_per_sec, decimals=2, max_end_ts=(offset + effective_dur) + ) target_len = torch.tensor([ts_tensor.shape[0]]) - return target_len + return target_len def __getitem__(self, index): sample = self.collection[index] @@ -1179,24 +1196,25 @@ def __getitem__(self, index): uniq_id = self.get_uniq_id_with_range(sample) audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) - + # We should resolve the length mis-match from the round-off errors: `session_len_sec` and `audio_signal.shape[0]` - session_len_sec = np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * 
self.floor_decimal)/self.floor_decimal - audio_signal = audio_signal[:round(self.featurizer.sample_rate*session_len_sec)] - + session_len_sec = ( + np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal + ) + audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)] + audio_signal_length = torch.tensor(audio_signal.shape[0]).long() audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) - targets = self.parse_rttm_for_targets_and_lens(uniq_id=uniq_id, - rttm_file=sample.rttm_file, - offset=offset, - duration=session_len_sec, - target_len=target_len) + targets = self.parse_rttm_for_targets_and_lens( + uniq_id=uniq_id, rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len + ) return audio_signal, audio_signal_length, targets, target_len + def _eesd_train_collate_fn(self, batch): """ - Collate a batch of variables needed for training the end-to-end speaker diarization (EESD) model + Collate a batch of variables needed for training the end-to-end speaker diarization (EESD) model from raw waveforms to diarization labels. The following variables are included in the training/validation batch: Args: @@ -1249,24 +1267,25 @@ def _eesd_train_collate_fn(self, batch): targets = torch.stack(targets_list) return audio_signal, feature_length, targets, target_lens + class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): """ Dataset class for loading a JSON file containing paths to audio files, RTTM (Rich Transcription Time Marked) files, and the number of speakers. - This class is designed for training or fine-tuning a speaker embedding + This class is designed for training or fine-tuning a speaker embedding extractor and diarization decoder simultaneously. The JSON manifest file should have entries in the following format: - + Example: { - "audio_filepath": "/path/to/audio_0.wav", + "audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2, "rttm_filepath": "/path/to/diar_label_0.rttm" } ... { - "audio_filepath": "/path/to/audio_n.wav", + "audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2, "rttm_filepath": "/path/to/diar_label_n.rttm" } @@ -1283,7 +1302,7 @@ class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): featurizer: Instance of a featurizer for generating features from the raw waveform. window_stride (float): - Window stride (in seconds) for extracting acoustic features, used to calculate + Window stride (in seconds) for extracting acoustic features, used to calculate the number of feature frames. global_rank (int): Global rank of the current process (used for distributed training). @@ -1294,6 +1313,7 @@ class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): eesd_train_collate_fn(batch): Collates a batch of data for end-to-end speaker diarization training. 
""" + def __init__( self, *, @@ -1318,4 +1338,4 @@ def __init__( ) def eesd_train_collate_fn(self, batch): - return _eesd_train_collate_fn(self, batch) \ No newline at end of file + return _eesd_train_collate_fn(self, batch) diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index e223e4ef2a56..8d11c4c1167d 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -18,11 +18,12 @@ from lhotse.dataset import AudioSamples from lhotse.dataset.collation import collate_matrices -from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( - speaker_to_target, - get_hidden_length_from_sample_length, + get_hidden_length_from_sample_length, + speaker_to_target, ) +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType + class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): """ @@ -43,16 +44,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'target_length': NeuralType(tuple('B'), LengthsType()), 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), } - + def __init__(self, cfg): super().__init__() self.load_audio = AudioSamples(fault_tolerant=True) self.cfg = cfg self.num_speakers = self.cfg.get('num_speakers', 4) - self.num_sample_per_mel_frame = int(self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000)) # 160 + self.num_sample_per_mel_frame = int( + self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000) + ) # 160 self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) - self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero',False) - + self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero', False) + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: audio, audio_lens, cuts = self.load_audio(cuts) speaker_activities = [] @@ -63,14 +66,16 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: num_sample_per_mel_frame=self.num_sample_per_mel_frame, num_mel_frame_per_asr_frame=self.num_mel_frame_per_target_frame, spk_tar_all_zero=self.spk_tar_all_zero, - boundary_segments=True + boundary_segments=True, ) speaker_activities.append(speaker_activity) targets = collate_matrices(speaker_activities).to(audio.dtype) target_lens_list = [] for audio_len in audio_lens: - target_fr_len = get_hidden_length_from_sample_length(audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame) + target_fr_len = get_hidden_length_from_sample_length( + audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame + ) target_lens_list.append([target_fr_len]) target_lens = torch.tensor(target_lens_list) - + return audio, audio_lens, targets, target_lens diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index 16f62bbe9e4c..000b839ceb46 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -36,12 +36,12 @@ def get_partial_ref_labels(pred_labels: List[str], ref_labels: List[str]) -> List[str]: """ - For evaluation of online diarization performance, generate partial reference labels + For evaluation of online diarization performance, generate partial reference labels from the last prediction time. 
Args: pred_labels (list[str]): list of partial prediction labels - ref_labels (list[str]): list of full reference labels + ref_labels (list[str]): list of full reference labels Returns: ref_labels_out (list[str]): list of partial reference labels @@ -84,8 +84,8 @@ def get_online_DER_stats( For evaluation of online diarization performance, add cumulative, average, and maximum DER/CER. Args: - DER (float): Diarization Error Rate from the start to the current point - CER (float): Confusion Error Rate from the start to the current point + DER (float): Diarization Error Rate from the start to the current point + CER (float): Confusion Error Rate from the start to the current point FA (float): False Alarm from the start to the current point MISS (float): Miss rate from the start to the current point diar_eval_count (int): Number of evaluation sessions @@ -130,13 +130,13 @@ def uem_timeline_from_file(uem_file, uniq_name=''): def score_labels( - AUDIO_RTTM_MAP, - all_reference, - all_hypothesis, - all_uem: List[List[float]]=None, - collar:float=0.25, - ignore_overlap: bool=True, - verbose: bool = True + AUDIO_RTTM_MAP, + all_reference, + all_hypothesis, + all_uem: List[List[float]] = None, + collar: float = 0.25, + ignore_overlap: bool = True, + verbose: bool = True, ) -> Optional[Tuple[DiarizationErrorRate, Dict]]: """ Calculate DER, CER, FA and MISS rate from hypotheses and references. Hypothesis results are @@ -170,7 +170,9 @@ def score_labels( if len(ref_labels.labels()) == len(hyp_labels.labels()): correct_spk_count += 1 if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): - logging.info(f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}") + logging.info( + f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" + ) uem_obj = None if all_uem is not None: metric(ref_labels, hyp_labels, uem=all_uem[idx], detailed=True) @@ -189,7 +191,7 @@ def score_labels( CER = metric['confusion'] / metric['total'] FA = metric['false alarm'] / metric['total'] MISS = metric['missed detection'] / metric['total'] - + itemized_errors = (DER, CER, FA, MISS) if verbose: @@ -386,7 +388,7 @@ def calculate_session_cpWER( # Calculate WER for each speaker in hypothesis with reference # There are (number of hyp speakers) x (number of ref speakers) combinations lsa_wer_list = [] - for (spk_hyp_trans, spk_ref_trans) in all_pairs: + for spk_hyp_trans, spk_ref_trans in all_pairs: spk_wer = word_error_rate(hypotheses=[spk_hyp_trans], references=[spk_ref_trans]) lsa_wer_list.append(spk_wer) @@ -440,7 +442,7 @@ def concat_perm_word_error_rate( f"{len(spk_hypotheses)} and {len(spk_references)} correspondingly" ) cpWER_values, hyps_spk, refs_spk = [], [], [] - for (spk_hypothesis, spk_reference) in zip(spk_hypotheses, spk_references): + for spk_hypothesis, spk_reference in zip(spk_hypotheses, spk_references): cpWER, min_hypothesis, concat_reference = calculate_session_cpWER(spk_hypothesis, spk_reference) cpWER_values.append(cpWER) hyps_spk.append(min_hypothesis) diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 72781143208b..13e57b43bb0b 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -68,6 +68,7 @@ def on_validation_epoch_end(self): f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. 
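A minimal usage sketch for this metric with dummy tensors; compute() returning the aggregated F1 is assumed from the torchmetrics-style interface used here:

import torch
from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy

metric = MultiBinaryAccuracy()
preds = torch.rand(2, 50, 4)                    # (batch, frames, speakers) sigmoid outputs
targets = (torch.rand(2, 50, 4) > 0.5).float()  # binary speaker-activity targets
signal_lengths = torch.tensor([50, 42])         # number of valid frames per sample

metric.update(preds=preds, targets=targets, signal_lengths=signal_lengths)
f1 = metric.compute()                           # aggregated F1 over the updated batches
print(f1)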
""" + full_state_update = False def __init__(self, dist_sync_on_step=False): @@ -80,7 +81,9 @@ def __init__(self, dist_sync_on_step=False): self.add_state("positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) self.eps = 1e-6 - def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False) -> torch.Tensor: + def update( + self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False + ) -> torch.Tensor: with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] targets_list = [targets[k, : signal_lengths[k], :] for k in range(targets.shape[0])] @@ -99,7 +102,7 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: tor self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) - else: + else: self.positive_count = torch.sum(self.preds.round().bool() == True) self.true_positive_count = torch.sum(torch.logical_and(self.true, self.positive)) self.false_positive_count = torch.sum(torch.logical_and(self.false, self.positive)) diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 31194d8849f0..2573a7ac84b4 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -19,7 +19,6 @@ EncDecClassificationModel, EncDecFrameClassificationModel, ) -from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel @@ -36,5 +35,6 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 50cdf6214d5b..665f439b0ad0 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -12,28 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import time import itertools import random -import torch +import time from collections import OrderedDict from typing import Dict, List, Optional, Union + +import torch from hydra.utils import instantiate from omegaconf import DictConfig from pytorch_lightning import Trainer from tqdm import tqdm -from nemo.core.classes import ModelPT -from nemo.core.classes.common import PretrainedModelInfo -from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType -from nemo.core.neural_types.elements import ProbsType -from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations -from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config -from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset + from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy from nemo.collections.asr.models.asr_model import ExportableEncDecModel from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer -from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_pil_targets, get_ats_targets +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.core.neural_types.elements import ProbsType from nemo.utils import logging try: @@ -45,10 +47,12 @@ def autocast(enabled=None): yield -# torch.backends.cudnn.enabled = False + +# torch.backends.cudnn.enabled = False __all__ = ['SortformerEncLabelModel'] + class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): """ Encoder class for Sortformer diarization model. 
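For orientation, a minimal usage sketch of the class declared above. The checkpoint path is hypothetical, restore_from is the generic NeMo ModelPT loader rather than anything added in this file, and the shapes follow the forward docstring further down:

    import torch

    from nemo.collections.asr.models import SortformerEncLabelModel

    # Hypothetical path; any .nemo checkpoint trained with this class would do.
    diar_model = SortformerEncLabelModel.restore_from(
        "/path/to/sortformer_diarizer.nemo", map_location=torch.device("cpu")
    )
    diar_model.eval()

    # A single 10-second, 16 kHz waveform (random data, for shape illustration only).
    audio_signal = torch.randn(1, 10 * 16000)
    audio_signal_length = torch.tensor([audio_signal.shape[1]])

    with torch.no_grad():
        preds = diar_model(audio_signal=audio_signal, audio_signal_length=audio_signal_length)

    # preds: (batch_size, diar_frame_count, num_speakers) per-frame speaker activities.
    print(preds.shape)

The returned activities are raw sigmoid values, so any binarization threshold is applied downstream of forward.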
@@ -80,7 +84,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): random.seed(42) self._trainer = trainer if trainer else None self._cfg = cfg - + if self._trainer: self.world_size = trainer.num_nodes * trainer.num_devices else: @@ -109,27 +113,27 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.streaming_mode = self._cfg.get("streaming_mode", False) self.save_hyperparameters("cfg") self._init_eval_metrics() - + speaker_inds = list(range(self._cfg.max_num_of_spks)) - self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations - + self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations + def _init_loss_weights(self): pil_weight = self._cfg.get("pil_weight", 0.0) ats_weight = self._cfg.get("ats_weight", 1.0) if pil_weight + ats_weight == 0: raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0") - self.pil_weight = pil_weight/(pil_weight + ats_weight) - self.ats_weight = ats_weight/(pil_weight + ats_weight) + self.pil_weight = pil_weight / (pil_weight + ats_weight) + self.ats_weight = ats_weight / (pil_weight + ats_weight) logging.info(f"Normalized weights for PIL {self.pil_weight} and ATS {self.ats_weight}") - + def _init_eval_metrics(self): - """ + """ If there is no label, then the evaluation metrics will be based on Permutation Invariant Loss (PIL). """ self._accuracy_test = MultiBinaryAccuracy() self._accuracy_train = MultiBinaryAccuracy() self._accuracy_valid = MultiBinaryAccuracy() - + self._accuracy_test_ats = MultiBinaryAccuracy() self._accuracy_train_ats = MultiBinaryAccuracy() self._accuracy_valid_ats = MultiBinaryAccuracy() @@ -137,11 +141,11 @@ def _init_eval_metrics(self): def _reset_train_metrics(self): self._accuracy_train.reset() self._accuracy_train_ats.reset() - + def _reset_valid_metrics(self): self._accuracy_valid.reset() self._accuracy_valid_ats.reset() - + def __setup_dataloader_from_config(self, config): # Switch to lhotse dataloader if specified in the config if config.get("use_lhotse"): @@ -168,7 +172,7 @@ def __setup_dataloader_from_config(self, config): global_rank = 0 time_flag = time.time() logging.info("AAB: Starting Dataloader Instance loading... 
Step A") - + dataset = AudioToSpeechE2ESpkDiarDataset( manifest_filepath=config.manifest_filepath, soft_label_thres=config.soft_label_thres, @@ -179,11 +183,13 @@ def __setup_dataloader_from_config(self, config): global_rank=global_rank, soft_targets=config.soft_targets if 'soft_targets' in config else False, ) - logging.info(f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader step B: {time.time() - time_flag}") + logging.info( + f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader step B: {time.time() - time_flag}" + ) self.data_collection = dataset.collection self.collate_ds = dataset - + dataloader_instance = torch.utils.data.DataLoader( dataset=dataset, batch_size=config.batch_size, @@ -195,15 +201,21 @@ def __setup_dataloader_from_config(self, config): ) logging.info(f"AAC: Dataloader Instance loading is done ETA Step B done: {time.time() - time_flag}") return dataloader_instance - + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): - self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,) + self._train_dl = self.__setup_dataloader_from_config( + config=train_data_config, + ) def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): - self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,) - + self._validation_dl = self.__setup_dataloader_from_config( + config=val_data_layer_config, + ) + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): - self._test_dl = self.__setup_dataloader_from_config(config=test_data_config,) + self._test_dl = self.__setup_dataloader_from_config( + config=test_data_config, + ) def test_dataloader(self): if self._test_dl is not None: @@ -227,11 +239,11 @@ def output_types(self) -> Dict[str, NeuralType]: "preds": NeuralType(('B', 'T', 'C'), ProbsType()), } ) - + def frontend_encoder(self, processed_signal, processed_signal_length): - """ + """ Generate encoder outputs from frontend encoder. - + Args: process_signal (torch.Tensor): tensor containing audio-feature (mel spectrogram, mfcc, etc.) processed_signal_length (torch.Tensor): tensor containing lengths of audio signal in integers @@ -248,7 +260,7 @@ def frontend_encoder(self, processed_signal, processed_signal_length): emb_seq = emb_seq.transpose(1, 2) if self._cfg.encoder.d_model != self._cfg.tf_d_model: self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) - emb_seq = self.sortformer_modules.encoder_proj(emb_seq) + emb_seq = self.sortformer_modules.encoder_proj(emb_seq) return emb_seq, emb_seq_length def forward_infer(self, emb_seq): @@ -258,7 +270,7 @@ def forward_infer(self, emb_seq): Args: emb_seq (torch.Tensor): tensor containing FastConformer encoder states (embedding vectors). Dimension: (batch_size, diar_frame_count, emb_dim) - + Returns: preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. Dimension: (batch_size, diar_frame_count, num_speakers) @@ -269,9 +281,9 @@ def forward_infer(self, emb_seq): trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask) preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) return preds - + def process_signal(self, audio_signal, audio_signal_length): - """ + """ Extract audio features from time-series signal for further processing in the model. 
This function performs the following steps: @@ -293,43 +305,49 @@ def process_signal(self, audio_signal, audio_signal_length): Shape: (batch_size,) """ audio_signal = audio_signal.to(self.device) - audio_signal = (1/(audio_signal.max()+self.eps)) * audio_signal - processed_signal, processed_signal_length = self.preprocessor(input_signal=audio_signal, length=audio_signal_length) + audio_signal = (1 / (audio_signal.max() + self.eps)) * audio_signal + processed_signal, processed_signal_length = self.preprocessor( + input_signal=audio_signal, length=audio_signal_length + ) return processed_signal, processed_signal_length - + def forward( - self, - audio_signal, - audio_signal_length, + self, + audio_signal, + audio_signal_length, ): """ Forward pass for training and inference. - + Args: audio_signal (torch.Tensor): tensor containing audio waveform Dimension: (batch_size, num_samples) audio_signal_length (torch.Tensor): tensor containing lengths of audio waveforms Dimension: (batch_size,) - + Returns: preds (torch.Tensor): Sorted tensor containing predicted speaker labels Dimension: (batch_size, diar_frame_count, num_speakers) encoder_states_list (list): List containing total speaker memory for each step for debugging purposes Dimension: [(batch_size, diar_frame_count, inner dim), ] """ - processed_signal, processed_signal_length = self.process_signal(audio_signal=audio_signal, audio_signal_length=audio_signal_length) - processed_signal = processed_signal[:, :, :processed_signal_length.max()] + processed_signal, processed_signal_length = self.process_signal( + audio_signal=audio_signal, audio_signal_length=audio_signal_length + ) + processed_signal = processed_signal[:, :, : processed_signal_length.max()] if self._cfg.get("streaming_mode", False): raise NotImplementedError("Streaming mode is not implemented yet.") else: - emb_seq, _ = self.frontend_encoder(processed_signal=processed_signal, processed_signal_length=processed_signal_length) + emb_seq, _ = self.frontend_encoder( + processed_signal=processed_signal, processed_signal_length=processed_signal_length + ) preds = self.forward_infer(emb_seq) return preds - + def _get_aux_train_evaluations(self, preds, targets, target_lens): - """ + """ Compute auxiliary training evaluations including losses and metrics. - + This function calculates various losses and metrics for the training process, including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) based evaluations. @@ -366,7 +384,7 @@ def _get_aux_train_evaluations(self, preds, targets, target_lens): 'train_precision': train_precision, 'train_recall': train_recall, 'train_f1_acc_ats': train_f1_acc_ats, - } + } return train_metrics def training_step(self, batch: list) -> dict: @@ -392,7 +410,7 @@ def training_step(self, batch: list) -> dict: return {'loss': train_metrics['loss']} def _get_aux_validation_evaluations(self, preds, targets, target_lens): - """ + """ Compute auxiliary validation evaluations including losses and metrics. 
This function calculates various losses and metrics for the validation process, including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) @@ -482,7 +500,7 @@ def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): val_f1_acc_ats_mean = torch.stack([x['val_f1_acc_ats'] for x in outputs]).mean() self._reset_valid_metrics() - + multi_val_metrics = { 'val_loss': val_loss_mean, 'val_ats_loss': val_ats_loss_mean, @@ -495,9 +513,9 @@ def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): return {'log': multi_val_metrics} def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target_lens): - """ + """ Compute auxiliary validation evaluations including losses and metrics. - + This function calculates various losses and metrics for the validation process, including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) based evaluations. @@ -525,19 +543,29 @@ def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target self._accuracy_test_ats(preds, targets_ats, target_lens) f1_acc_ats, precision_ats, recall_ats = self._accuracy_test_ats.compute() self.batch_f1_accs_ats_list.append(f1_acc_ats) - logging.info(f"batch {batch_idx}: f1_acc_ats={f1_acc_ats}, precision_ats={precision_ats}, recall_ats={recall_ats}") + logging.info( + f"batch {batch_idx}: f1_acc_ats={f1_acc_ats}, precision_ats={precision_ats}, recall_ats={recall_ats}" + ) self._accuracy_test.reset() self._accuracy_test_ats.reset() - def test_batch(self,): - """ + def test_batch( + self, + ): + """ Perform batch testing on the model. - + This method iterates through the test data loader, making predictions for each batch, and calculates various evaluation metrics. It handles both single and multi-sample batches. """ - self.preds_total_list, self.batch_f1_accs_list, self.batch_precision_list, self.batch_recall_list, self.batch_f1_accs_ats_list = [], [], [], [], [] + ( + self.preds_total_list, + self.batch_f1_accs_list, + self.batch_precision_list, + self.batch_recall_list, + self.batch_f1_accs_ats_list, + ) = ([], [], [], [], []) with torch.no_grad(): for batch_idx, batch in enumerate(tqdm(self._test_dl)): @@ -549,7 +577,7 @@ def test_batch(self,): audio_signal_length=audio_signal_length, ) preds = preds.detach().to('cpu') - if preds.shape[0] == 1: # batch size = 1 + if preds.shape[0] == 1: # batch size = 1 self.preds_total_list.append(preds) else: self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) @@ -562,5 +590,7 @@ def test_batch(self,): logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") logging.info(f"Batch ATS F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_ats_list))}") - def diarize(self,): + def diarize( + self, + ): raise NotImplementedError diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index 823cf98590e7..6ed29d3e6a70 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -55,6 +55,7 @@ class SortformerModules(NeuralModule, Exportable): If 'cos_sim', cosine similarity values are used for the input of the sequence models. If 'elem_prod', element-wise product values are used for the input of the sequence models. 
""" + def init_weights(self, m): if type(m) == nn.Linear: torch.nn.init.xavier_uniform_(m.weight) @@ -91,9 +92,9 @@ def length_to_mask(self, context_embs): mask (torch.Tensor): tensor of shape (batch_size, max_len) containing 0's in the padded region and 1's elsewhere """ - lengths = torch.tensor([context_embs.shape[1]] * context_embs.shape[0]) + lengths = torch.tensor([context_embs.shape[1]] * context_embs.shape[0]) batch_size = context_embs.shape[0] - max_len=context_embs.shape[1] + max_len = context_embs.shape[1] # create a tensor with the shape (batch_size, 1) filled with ones row_vector = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) # create a tensor with the shape (batch_size, max_len) filled with lengths @@ -101,7 +102,7 @@ def length_to_mask(self, context_embs): # create a mask by comparing the row vector and length matrix mask = row_vector < length_matrix return mask.float().to(context_embs.device) - + def forward_speaker_sigmoids(self, hidden_out): hidden_out = self.dropout(F.relu(hidden_out)) hidden_out = self.first_hidden_to_hidden(hidden_out) diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index a1d34e1f7480..a52271a5e83b 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -12,39 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import re +import concurrent.futures import copy +import itertools +import logging import math +import os import random -import logging -import itertools -from copy import deepcopy -import concurrent.futures -from cytoolz import groupby +import re from collections import defaultdict -from typing import Dict, Optional, Tuple, List +from copy import deepcopy +from typing import Dict, List, Optional, Tuple import numpy as np import soundfile -from tqdm import tqdm -from scipy.stats import norm - import torch.utils.data +from cytoolz import groupby +from lhotse import AudioSource, Recording, SupervisionSegment, SupervisionSet, dill_enabled +from lhotse.cut import CutSet, MixedCut, MixTrack, MonoCut from lhotse.cut.set import mix -from lhotse.cut import CutSet, MixedCut, MonoCut, MixTrack -from lhotse import SupervisionSet, SupervisionSegment, dill_enabled, AudioSource, Recording from lhotse.utils import uuid4 +from scipy.stats import norm +from tqdm import tqdm -def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres:float = 0.5) -> torch.Tensor: - """ + +def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> torch.Tensor: + """ Finds the first nonzero value in the matrix, discretizing it to the specified maximum capacity. - + Args: mat (Tensor): A torch tensor representing the matrix. max_cap_val (int): The maximum capacity to which the matrix values will be discretized. thres (float): The threshold value for discretizing the matrix values. - + Returns: mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first nonzero value in each row. """ @@ -61,6 +61,7 @@ def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres:float = 0.5) -> mask_max_indices[mask_max_values == 0] = max_cap_val return mask_max_indices + def find_best_permutation(match_score: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: """ Finds the best permutation indices based on the match score. 
@@ -78,9 +79,12 @@ def find_best_permutation(match_score: torch.Tensor, speaker_permutations: torch batch_best_perm = torch.argmax(match_score, axis=1) rep_speaker_permutations = speaker_permutations.repeat(batch_best_perm.shape[0], 1).to(match_score.device) perm_size = speaker_permutations.shape[0] - global_inds_vec = torch.arange(0, perm_size * batch_best_perm.shape[0], perm_size).to(batch_best_perm.device) + batch_best_perm + global_inds_vec = ( + torch.arange(0, perm_size * batch_best_perm.shape[0], perm_size).to(batch_best_perm.device) + batch_best_perm + ) return rep_speaker_permutations[global_inds_vec.to(rep_speaker_permutations.device), :] + def reconstruct_labels(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: """ Reconstructs the labels using the best permutation indices with matrix operations. @@ -103,12 +107,13 @@ def reconstruct_labels(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> t reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) return reconstructed_labels + def get_ats_targets( - labels: torch.Tensor, - preds: torch.Tensor, - speaker_permutations: torch.Tensor, - thres: float = 0.5, - tolerance: float = 0 + labels: torch.Tensor, + preds: torch.Tensor, + speaker_permutations: torch.Tensor, + thres: float = 0.5, + tolerance: float = 0, ) -> torch.Tensor: """ Sorts labels and predictions to get the optimal of all arrival-time ordered permutations. @@ -128,25 +133,36 @@ def get_ats_targets( Shape: (batch_size, num_frames, num_speakers) """ # Find the first nonzero frame index for each speaker in each batch - nonzero_ind = find_first_nonzero(mat=labels, max_cap_val=labels.shape[1], thres=thres) # (batch_size, num_speakers) - + nonzero_ind = find_first_nonzero( + mat=labels, max_cap_val=labels.shape[1], thres=thres + ) # (batch_size, num_speakers) + # Sort the first nonzero frame indices for arrival-time ordering sorted_values = torch.sort(nonzero_ind)[0] # (batch_size, num_speakers) perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_frames, num_permutations, num_speakers) - permed_nonzero_ind = find_first_nonzero(mat=permed_labels, max_cap_val=labels.shape[1]) # (batch_size, num_permutations, num_speakers) + permed_nonzero_ind = find_first_nonzero( + mat=permed_labels, max_cap_val=labels.shape[1] + ) # (batch_size, num_permutations, num_speakers) # Compare the first frame indices of sorted labels with those of the permuted labels using tolerance - perm_compare = torch.abs(sorted_values.unsqueeze(1) - permed_nonzero_ind) <= tolerance # (batch_size, num_permutations, num_speakers) + perm_compare = ( + torch.abs(sorted_values.unsqueeze(1) - permed_nonzero_ind) <= tolerance + ) # (batch_size, num_permutations, num_speakers) perm_mask = torch.all(perm_compare, dim=2).float() # (batch_size, num_permutations) - preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, perm_size, 1) # Exapnd the preds: (batch_size, num_frames, num_permutations, num_speakers) + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, perm_size, 1 + ) # Exapnd the preds: (batch_size, num_frames, num_permutations, num_speakers) # Compute the match score for each permutation by comparing permuted labels with preds - match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) * perm_mask # (batch_size, num_permutations) + match_score = ( + torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) * perm_mask + ) # (batch_size, num_permutations) batch_perm_inds = 
find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_frames, num_speakers) return max_score_permed_labels # (batch_size, num_frames, num_speakers) + def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: """ Sorts labels and predictions to get the optimal permutation based on the match score. @@ -166,23 +182,26 @@ def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutati perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_classes, num_permutations, num_speakers) # Repeat preds to match permutations for comparison - preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, speaker_permutations.shape[0], 1) # (batch_size, num_speakers, num_permutations, num_classes) + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, speaker_permutations.shape[0], 1 + ) # (batch_size, num_speakers, num_permutations, num_classes) match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) # (batch_size, num_permutations) batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) # Reconstruct labels based on the best permutation for each batch max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) return max_score_permed_labels # (batch_size, num_speakers, num_classes) + def apply_spk_mapping(diar_preds: torch.Tensor, spk_mappings: torch.Tensor) -> torch.Tensor: - """ + """ Applies a speaker mapping to diar predictions. Args: - diar_preds (Tensor): The diar predictions tensor. + diar_preds (Tensor): The diar predictions tensor. Dimension: (batch_size, num_frames, num_speakers) spk_mappings (Tensor): The speaker mappings tensor. Dimension: (batch_size, num_speakers) - + Returns: permuted_diar_preds (Tensor): The permuted diar predictions tensor with the given speaker mappings. """ @@ -190,15 +209,18 @@ def apply_spk_mapping(diar_preds: torch.Tensor, spk_mappings: torch.Tensor) -> t permuted_diar_preds = torch.gather(diar_preds, 2, expanded_mappings) return permuted_diar_preds -def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool = False, pattern= r'<\|spltoken\d+\|>') -> Tuple[CutSet, torch.Tensor]: - """ + +def shuffle_spk_mapping( + cuts: list, num_speakers: int, shuffle_spk_mapping: bool = False, pattern=r'<\|spltoken\d+\|>' +) -> Tuple[CutSet, torch.Tensor]: + """ Applies a shuffle mapping to speaker text labels in the cuts. Example: Original cut.text: - "<|spltoken0|> we do shuffle <|spltoken1|> and map speakers <|spltoken0|> yes <|spltoken2|> we keep dimensions" + "<|spltoken0|> we do shuffle <|spltoken1|> and map speakers <|spltoken0|> yes <|spltoken2|> we keep dimensions" Speaker Mapping: [3, 0, 1, 2] Shuffled cut.text: - "<|spltoken3|> we do shuffle <|spltoken0|> and map speakers <|spltoken3|> yes <|spltoken1|> we keep dimensions" + "<|spltoken3|> we do shuffle <|spltoken0|> and map speakers <|spltoken3|> yes <|spltoken1|> we keep dimensions" Args: cuts (List[MonoCut, MixedCut]): A list of Cut instances. @@ -208,11 +230,11 @@ def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool Returns: cuts (list): The updated CutSet with shuffled speaker mappings. 
- spk_mappings (Tensor): + spk_mappings (Tensor): If shuffle_speaker_mapping is True, shuffled speaker mappings in batch. If shuffle_speaker_mapping is False, speaker mappings in batch is not permuted and returns torch.arange() values. - """ - batch_size = len(cuts) + """ + batch_size = len(cuts) if shuffle_spk_mapping: permuted_indices = torch.rand(batch_size, num_speakers).argsort(dim=1) spk_mappings = torch.gather(torch.arange(num_speakers).repeat(batch_size, 1), 1, permuted_indices) @@ -220,9 +242,9 @@ def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool left_str, right_str = str_pattern.split('d+')[0], str_pattern.split('d+')[1] for idx, cut in enumerate(cuts): word_list = [] - for word in deepcopy(cut.text).split(): + for word in deepcopy(cut.text).split(): if len(re.findall(pattern, word)) > 0: - spk_token_int = int(word.replace(left_str,'').replace(right_str, '')) + spk_token_int = int(word.replace(left_str, '').replace(right_str, '')) new_spk = spk_mappings[idx][spk_token_int] word_list.append(f'{left_str}{new_spk}{right_str}') else: @@ -230,16 +252,18 @@ def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool cuts[idx].supervisions[0].text = ' '.join(word_list) else: spk_mappings = torch.arange(num_speakers).unsqueeze(0).repeat(batch_size, 1) - return cuts, spk_mappings + return cuts, spk_mappings + def find_segments_from_rttm( - recording_id: str, - rttms, - start_after: float, - end_before: float, - adjust_offset: bool=True, - tolerance: float=0.001): - """ + recording_id: str, + rttms, + start_after: float, + end_before: float, + adjust_offset: bool = True, + tolerance: float = 0.001, +): + """ Finds segments from the given rttm file. This function is designed to replace rttm @@ -250,35 +274,36 @@ def find_segments_from_rttm( end_before (float): The end time before which segments are selected. adjust_offset (bool): Whether to adjust the offset of the segments. tolerance (float): The tolerance for time matching. 0.001 by default. - + Returns: segments (List[SupervisionSegment]): A list of SupervisionSegment instances. """ segment_by_recording_id = rttms._segments_by_recording_id if segment_by_recording_id is None: from cytoolz import groupby + segment_by_recording_id = groupby(lambda seg: seg.recording_id, rttms) return [ - # We only modify the offset - the duration remains the same, as we're only shifting the segment - # relative to the Cut's start, and not truncating anything. - segment.with_offset(-start_after) if adjust_offset else segment - for segment in segment_by_recording_id.get(recording_id, []) - if segment.start < end_before + tolerance - and segment.end > start_after + tolerance - ] + # We only modify the offset - the duration remains the same, as we're only shifting the segment + # relative to the Cut's start, and not truncating anything. 
+ segment.with_offset(-start_after) if adjust_offset else segment + for segment in segment_by_recording_id.get(recording_id, []) + if segment.start < end_before + tolerance and segment.end > start_after + tolerance + ] + def speaker_to_target( a_cut, - num_speakers: int = 4, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, spk_tar_all_zero: bool = False, boundary_segments: bool = False, soft_label: bool = False, ignore_num_spk_mismatch: bool = True, soft_thres: float = 0.5, - ): +): ''' Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) This function is needed for speaker diarization with ASR model trainings. @@ -292,7 +317,7 @@ def speaker_to_target( boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. - + Returns: mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) ''' @@ -306,14 +331,18 @@ def speaker_to_target( offsets = [0] else: raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - + segments_total = [] for i, cut in enumerate(cut_list): rttms = SupervisionSet.from_rttm(cut.rttm_filepath) - if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included - segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) - else: # segments with seg_start > total_start and seg_end < total_end are included - segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm( + recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0 + ) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find( + recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True + ) for seg in segments_iterator: if seg.start < 0: @@ -323,28 +352,31 @@ def speaker_to_target( seg.duration -= seg.end - cut.duration seg.start += offsets[i] segments_total.append(seg) - + # apply arrival time sorting to the existing segments - segments_total.sort(key = lambda rttm_sup: rttm_sup.start) + segments_total.sort(key=lambda rttm_sup: rttm_sup.start) seen = set() seen_add = seen.add speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] - - speaker_to_idx_map = { - spk: idx - for idx, spk in enumerate(speaker_ats) - } + + speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers - raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}") - + raise ValueError( + f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers 
{num_speakers}" + ) + # initialize mask matrices (num_speaker, encoder_hidden_len) - feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default - num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) - if spk_tar_all_zero: + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length( + a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) + if spk_tar_all_zero: frame_mask = torch.zeros((num_samples, num_speakers)) else: - frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) + frame_mask = get_mask_from_segments( + segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch + ) soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) if soft_label: @@ -354,11 +386,19 @@ def speaker_to_target( return mask -def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tensor, num_speakers: int =4, feat_per_sec: int=100, ignore_num_spk_mismatch: bool = False): - """ + +def get_mask_from_segments( + segments: list, + a_cut, + speaker_to_idx_map: torch.Tensor, + num_speakers: int = 4, + feat_per_sec: int = 100, + ignore_num_spk_mismatch: bool = False, +): + """ Generate mask matrix from segments list. This function is needed for speaker diarization with ASR model trainings. - + Args: segments: A list of Lhotse Supervision segments iterator. cut (MonoCut, MixedCut): Lhotse MonoCut or MixedCut instance. @@ -366,13 +406,13 @@ def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tens num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. - + Returns: mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). Dimension: (num_speakers, num_frames) """ # get targets with 0.01s frame rate - num_samples = round(a_cut.duration * feat_per_sec) + num_samples = round(a_cut.duration * feat_per_sec) mask = torch.zeros((num_samples, num_speakers)) for rttm_sup in segments: speaker_idx = speaker_to_idx_map[rttm_sup.speaker] @@ -388,17 +428,18 @@ def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tens mask[stf:enf, speaker_idx] = 1.0 return mask + def get_soft_mask(feat_level_target, num_samples, stride): """ Get soft mask from feat_level_target with stride. This function is needed for speaker diarization with ASR model trainings. - + Args: feat_level_target (Tensor): A numpy array of shape (num_frames, num_speakers). Dimension: (num_frames, num_speakers) num_sample (int): The total number of samples. stride (int): The stride for the mask. 
- """ + """ num_speakers = feat_level_target.shape[1] mask = torch.zeros(num_samples, num_speakers) @@ -412,15 +453,14 @@ def get_soft_mask(feat_level_target, num_samples, stride): seg_end_feat = feat_level_target.shape[0] else: seg_end_feat = stride * index - 1 + int(stride / 2) - mask[index] = torch.mean(feat_level_target[seg_stt_feat:seg_end_feat+1, :], axis=0) + mask[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0) return mask + def get_hidden_length_from_sample_length( - num_samples: int, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8 + num_samples: int, num_sample_per_mel_frame: int = 160, num_mel_frame_per_asr_frame: int = 8 ) -> int: - """ + """ Calculate the hidden length from the given number of samples. This function is needed for speaker diarization with ASR model trainings. @@ -439,15 +479,16 @@ def get_hidden_length_from_sample_length( hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) return int(hidden_length) -class ConcatenationMeetingSimulator(): + +class ConcatenationMeetingSimulator: """ This simulator concatenates the segments from different/same sessions to create a - multi-speaker meeting. + multi-speaker meeting. """ def __init__( self, - intra_session_concat_prob: float|List[float] = [0, 1.0, 0.5, 0.2], + intra_session_concat_prob: float | List[float] = [0, 1.0, 0.5, 0.2], data_type: str = "msasr", min_duration: float = 30.0, max_duration: float = 40.0, @@ -460,7 +501,7 @@ def __init__( :param intra_session_concat_prob: the probability of concatenating segments from the same session. [Default: 1] :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', - the transcripts are included in the simulation,and the boundary segments are + the transcripts are included in the simulation,and the boundary segments are not included. [Default: 'msasr'] :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] """ @@ -470,7 +511,9 @@ def __init__( elif len(intra_session_concat_prob) == max_num_speakers: self.intra_session_concat_prob = intra_session_concat_prob else: - raise ValueError(f"intra_session_concat_prob must be either a float or a list of floats, but got {intra_session_concat_prob}") + raise ValueError( + f"intra_session_concat_prob must be either a float or a list of floats, but got {intra_session_concat_prob}" + ) if data_type not in ["msasr", "diar"]: raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") self.data_type = data_type @@ -478,7 +521,9 @@ def __init__( self.max_duration = max_duration self.max_num_speakers = max_num_speakers self.speaker_count_distribution = speaker_count_distribution - assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" + assert ( + len(speaker_count_distribution) == max_num_speakers + ), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" if skip_long_segments: self.skip_duration = max_duration / 2 @@ -489,7 +534,7 @@ def __init__( def fit(self, cuts) -> CutSet: """ - Read the manifest file and return a CutSet object. + Read the manifest file and return a CutSet object. Each line in the manifest file should be a JSON object representing a segment. 
""" @@ -500,7 +545,7 @@ def fit(self, cuts) -> CutSet: self.spk2cut_ids = defaultdict(list) self.data2num_spk2cut_ids = {} self.sess2num_spk2cut_ids = {} - self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} + self.num_spk2cut_ids = {i + 1: [] for i in range(self.max_num_speakers)} for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): if cut.duration > self.skip_duration: continue @@ -512,20 +557,20 @@ def fit(self, cuts) -> CutSet: self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) if cut.recording_id not in self.sess2num_spk2cut_ids: self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) - + speakers = cut.global_speaker_ids if self.data_type == "msasr": speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) - if len(speakers) != len(speaker_tokens): - # Lhotse automatically fixes the max duration of the cut, - # resulting in the mismatch of the number of speakers + if len(speakers) != len(speaker_tokens): + # Lhotse automatically fixes the max duration of the cut, + # resulting in the mismatch of the number of speakers # and speaker tokens for the last segment # TODO: need to fix the issue in Lhotse that automatically fixes the max duration continue for spk in speakers: self.spk2cut_ids[spk].append(cut.id) self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) - + self.id2cut[cut.id] = cut self.sess2cut_ids[cut.recording_id].append(cut.id) self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) @@ -533,23 +578,21 @@ def fit(self, cuts) -> CutSet: self.num_spk2cut_ids[len(speakers)].append(cut.id) if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: self.data2sess_ids[cut.dataset_id].append(cut.recording_id) - + self.cut_ids = list(self.id2cut.keys()) self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) - - self.data2global_speaker = { - dataset_id: True for dataset_id in self.data2sess_ids.keys() - } - + + self.data2global_speaker = {dataset_id: True for dataset_id in self.data2sess_ids.keys()} + def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: - db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data - + db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data + if is_intra_session_concat: # intra-dataset and intra-session concatenation tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) - else: + else: # intra-dataset but inter-session concatenation tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) @@ -557,44 +600,54 @@ def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> Mix if self.data_type == "msasr": cut = self.reorder_spk_mapping(cut) - assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" - assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" + assert ( + self.min_duration <= cut.duration <= self.max_duration + ), f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" + assert ( + n_speakers == num_speakers + ), f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" return cut - - def get_intra_session_tracks(self, 
n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + + def get_intra_session_tracks(self, n_speakers: int = 4, db_norm: float = -25) -> List[MixTrack]: """ Get the tracks for the MixedCut object. """ session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - + total_duration = 0.0 total_spk_set = set() tracks = [] while True: cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) + tracks.append( + MixTrack( + cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), + type=type(cut), + offset=total_duration, + ) + ) total_spk_set = total_spk_set.union(cut.global_speaker_ids) total_duration += cut.duration # break condition if total_duration >= self.min_duration: - if total_duration > self.max_duration: # exceed the maximum duration, starting over + if total_duration > self.max_duration: # exceed the maximum duration, starting over total_duration = 0.0 total_spk_set = set() tracks = [] session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break break else: total_duration = 0.0 total_spk_set = set() tracks = [] session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - + return tracks, len(total_spk_set) - def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + def get_inter_session_tracks(self, n_speakers: int = 4, db_norm: float = -25) -> List[MixTrack]: """ Get the tracks for the MixedCut object. """ @@ -604,7 +657,9 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) if min(sum_spk_list) > n_speakers: - raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") + raise ValueError( + f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers." 
+ ) n_spk_left = n_speakers total_duration = 0.0 @@ -612,7 +667,7 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis tracks = [] num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] while True: - #if n_spk_left == n_speakers: # for more speakers cases + # if n_spk_left == n_speakers: # for more speakers cases # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk < n_spk_left]) if n_spk_left >= 2: n_spk = 2 @@ -626,34 +681,44 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis if not spks.intersection(total_spk_set): break - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) + tracks.append( + MixTrack( + cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), + type=type(cut), + offset=total_duration, + ) + ) total_duration += cut.duration n_spk_left -= n_spk total_spk_set = total_spk_set.union(spks) # break condition - + if total_duration >= self.min_duration: - if total_duration > self.max_duration or len(total_spk_set) < n_speakers: # exceed the maximum duration, starting over + if ( + total_duration > self.max_duration or len(total_spk_set) < n_speakers + ): # exceed the maximum duration, starting over total_duration = 0.0 n_spk_left = n_speakers total_spk_set = set() tracks = [] - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break break else: - if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers + if ( + len(total_spk_set) == n_speakers + ): # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers total_duration = 0.0 n_spk_left = n_speakers total_spk_set = set() tracks = [] - + return tracks, len(total_spk_set) - + def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: """ Concatenate the texts of the input cuts. 
- + """ global_spk_mapping = {} str_pattern = pattern.replace("\\", '') @@ -667,12 +732,12 @@ def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> st if speaker not in local_spk_mapping: local_spk_mapping[speaker] = len(local_spk_mapping) local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker - + if i != 0: text = '' - for word in track.cut.text.split(): + for word in track.cut.text.split(): if len(re.findall(pattern, word)) > 0: - local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) + local_spk_idx = int(word.replace(left_str, '').replace(right_str, '')) spk = local_inverse_spk_mapping[local_spk_idx] global_spk_idx = global_spk_mapping[spk] text += f'{left_str}{global_spk_idx}{right_str}' @@ -682,12 +747,12 @@ def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> st cut.supervisions[i].text = text else: cut.supervisions[0].text = track.cut.text - # TODO: need to check the last speaker of last track and the first speaker of the current track + # TODO: need to check the last speaker of last track and the first speaker of the current track # if they are the same, we need to remove the the speaker token from the current track for segment-level # Do not need to remove the speaker token for word-level - + return cut - + def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: """ Balance the speaker distribution for the simulated meetings. @@ -700,13 +765,13 @@ def apply_speaker_distribution(self, num_meetings: int, speaker_count_distributi total_spk = sum(speaker_count_distribution) num_speakers2num_meetings = {} for i_spk in range(self.max_num_speakers): - num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + num_speakers2num_meetings[i_spk + 1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) return num_speakers2num_meetings - - + @dill_enabled(True) - def simulate(self, + def simulate( + self, cuts: CutSet, num_meetings: int = 10000, seed: int = 0, @@ -715,42 +780,59 @@ def simulate(self, random.seed(seed) self.fit(cuts) - num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) - logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") - num_speakers2num_meetings[1] = 0 # skip 1-speaker samples - logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') + logging.warn( + f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}." + ) + num_speakers2num_meetings[1] = 0 # skip 1-speaker samples + logging.warn( + f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.' 
+ ) # Step 0: Calculate the number of intra-session and inter-session concatentation samples n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] - valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples - n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} + valid_sim_n_spks = set( + [i + j for i in n_spks for j in n_spks] + ) # valid number of speakers for inter-session samples + n_spk2n_intra_mt, n_spk2n_inter_mt = {i + 1: 0 for i in range(self.max_num_speakers)}, { + i + 1: 0 for i in range(self.max_num_speakers) + } for n_spk, n_mt in num_speakers2num_meetings.items(): - logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) + logging.warn(f"==" * 16 + f"{n_spk}-speaker" + "==" * 16) if n_mt <= 0: - logging.warning(f"No concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") + logging.warning( + f"No concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers." + ) continue - n_intra_mt = int(n_mt * self.intra_session_concat_prob[n_spk-1]) + n_intra_mt = int(n_mt * self.intra_session_concat_prob[n_spk - 1]) n_inter_mt = n_mt - n_intra_mt if n_spk in self.num_spk2sess_ids: logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") n_spk2n_intra_mt[n_spk] = n_intra_mt else: - logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + logging.warning( + f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers." + ) n_spk2n_intra_mt[n_spk] = 0 n_inter_mt = n_mt if n_spk in valid_sim_n_spks: logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") n_spk2n_inter_mt[n_spk] = n_inter_mt else: - logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + logging.warning( + f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers." + ) if n_spk2n_intra_mt[n_spk] != 0: n_spk2n_intra_mt[n_spk] = n_mt - logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") + logging.warn( + f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead." + ) else: logging.warning(f"No samples for {n_spk} speakers. 
Will skip simulation for {n_spk} speakers.") - logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") + logging.warn( + f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""" + ) # Step 1: intra-session num_intra_meetings = 0 intra_mixtures = [] @@ -762,25 +844,30 @@ def simulate(self, for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) num_intra_meetings += n_mt - logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") - + logging.info( + f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}" + ) + # Steo 2: inter-session logging.info(f"Simulating inter-session concatentation samples.") - + num_inter_meetings = 0 inter_mixtures = [] for n_spk, n_mt in n_spk2n_inter_mt.items(): if n_mt <= 0: continue - + for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) num_inter_meetings += n_mt - logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") + logging.info( + f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}" + ) if num_inter_meetings + num_intra_meetings == 0: - logging.warning(f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration//2} and max {self.max_duration//2}, or the speaker count distribution is not correctly set.") - + logging.warning( + f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration//2} and max {self.max_duration//2}, or the speaker count distribution is not correctly set." + ) # Multi-processing gets slower, TODO # else: @@ -788,7 +875,7 @@ def simulate(self, # for n_spk, n_mt in num_speakers2num_meetings.items(): # tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_jobs) # futures.extend([tp.submit(self._create_mixture, n_spk) for _ in range(n_mt)]) - # pbar = tqdm(total=num_meetings, desc=f"Simulating mixtures", unit="line", ncols=128) + # pbar = tqdm(total=num_meetings, desc=f"Simulating mixtures", unit="line", ncols=128) # count = 0 # for f in concurrent.futures.as_completed(futures): # count += 1 @@ -798,17 +885,17 @@ def simulate(self, # pbar.close() return CutSet.from_cuts(intra_mixtures + inter_mixtures) - -class MixMeetingSimulator(): + +class MixMeetingSimulator: """ This simulator Mix the segments from different/same sessions to create a - multi-speaker meeting. + multi-speaker meeting. 
""" def __init__( self, - intra_session_mix_prob: float|List[float] = [0, 0, 0, 0], + intra_session_mix_prob: float | List[float] = [0, 0, 0, 0], data_type: str = "msasr", min_duration: float = 80.0, max_duration: float = 100.0, @@ -820,7 +907,7 @@ def __init__( :param intra_session_mix_prob: the probability of concatenating segments from the same session. [Default: 1] :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', - the transcripts are included in the simulation,and the boundary segments are + the transcripts are included in the simulation,and the boundary segments are not included. [Default: 'msasr'] :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] """ @@ -830,7 +917,9 @@ def __init__( elif len(intra_session_mix_prob) == max_num_speakers: self.intra_session_mix_prob = intra_session_mix_prob else: - raise ValueError(f"intra_session_mix_prob must be either a float or a list of floats, but got {intra_session_mix_prob}") + raise ValueError( + f"intra_session_mix_prob must be either a float or a list of floats, but got {intra_session_mix_prob}" + ) if data_type not in ["msasr", "diar"]: raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") self.data_type = data_type @@ -839,11 +928,13 @@ def __init__( self.max_num_speakers = max_num_speakers self.speaker_count_distribution = speaker_count_distribution self.valid_dataset_ids = valid_dataset_ids - assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" + assert ( + len(speaker_count_distribution) == max_num_speakers + ), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" def fit(self, cuts) -> CutSet: """ - Read the manifest file and return a CutSet object. + Read the manifest file and return a CutSet object. Each line in the manifest file should be a JSON object representing a segment. 
""" @@ -854,7 +945,7 @@ def fit(self, cuts) -> CutSet: self.spk2cut_ids = defaultdict(list) self.data2num_spk2cut_ids = {} self.sess2num_spk2cut_ids = {} - self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} + self.num_spk2cut_ids = {i + 1: [] for i in range(self.max_num_speakers)} for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): if not self.min_duration <= cut.duration <= self.max_duration: continue @@ -866,20 +957,20 @@ def fit(self, cuts) -> CutSet: self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) if cut.recording_id not in self.sess2num_spk2cut_ids: self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) - + speakers = cut.global_speaker_ids if self.data_type == "msasr": speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) - if len(speakers) != len(speaker_tokens): - # Lhotse automatically fixes the max duration of the cut, - # resulting in the mismatch of the number of speakers + if len(speakers) != len(speaker_tokens): + # Lhotse automatically fixes the max duration of the cut, + # resulting in the mismatch of the number of speakers # and speaker tokens for the last segment # TODO: need to fix the issue in Lhotse that automatically fixes the max duration continue for spk in speakers: self.spk2cut_ids[spk].append(cut.id) self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) - + self.id2cut[cut.id] = cut self.sess2cut_ids[cut.recording_id].append(cut.id) self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) @@ -887,23 +978,21 @@ def fit(self, cuts) -> CutSet: self.num_spk2cut_ids[len(speakers)].append(cut.id) if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: self.data2sess_ids[cut.dataset_id].append(cut.recording_id) - + self.cut_ids = list(self.id2cut.keys()) self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) - - self.data2global_speaker = { - dataset_id: True for dataset_id in self.data2sess_ids.keys() - } - + + self.data2global_speaker = {dataset_id: True for dataset_id in self.data2sess_ids.keys()} + def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: - db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data - + db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data + if is_intra_session_concat: # intra-dataset and intra-session concatenation tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) - else: + else: # intra-dataset but inter-session concatenation tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) @@ -911,43 +1000,51 @@ def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> Mix if self.data_type == "msasr": cut = self.reorder_spk_mapping(cut) - assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" - assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" + assert ( + self.min_duration <= cut.duration <= self.max_duration + ), f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" + assert ( + n_speakers == num_speakers + ), f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" return cut - - def 
get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + + def get_intra_session_tracks(self, n_speakers: int = 4, db_norm: float = -25) -> List[MixTrack]: """ Get the tracks for the MixedCut object. """ session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - + total_spk_set = set() tracks = [] while True: cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) + tracks.append( + MixTrack( + cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0 + ) + ) total_spk_set = total_spk_set.union(cut.global_speaker_ids) total_duration = max(total_duration, cut.duration) # break condition if total_duration >= self.min_duration: - if total_duration > self.max_duration: # exceed the maximum duration, starting over + if total_duration > self.max_duration: # exceed the maximum duration, starting over total_duration = 0.0 total_spk_set = set() tracks = [] session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break break else: total_duration = 0.0 total_spk_set = set() tracks = [] session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - + return tracks, len(total_spk_set) - def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: + def get_inter_session_tracks(self, n_speakers: int = 4, db_norm: float = -25) -> List[MixTrack]: """ Get the tracks for the MixedCut object. """ @@ -957,7 +1054,9 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) if min(sum_spk_list) > n_speakers: - raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") + raise ValueError( + f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers." 
+ ) n_spk_left = n_speakers total_duration = 0.0 @@ -977,34 +1076,40 @@ def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> Lis if not spks.intersection(total_spk_set): break - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) + tracks.append( + MixTrack( + cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0 + ) + ) total_duration = max(total_duration, cut.duration) n_spk_left -= n_spk total_spk_set = total_spk_set.union(spks) # break condition - + if total_duration >= self.min_duration: - if total_duration > self.max_duration or len(tracks) > 2: # exceed the maximum duration, starting over + if total_duration > self.max_duration or len(tracks) > 2: # exceed the maximum duration, starting over total_duration = 0.0 n_spk_left = n_speakers total_spk_set = set() tracks = [] - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break + if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break break else: - if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers + if ( + len(total_spk_set) == n_speakers + ): # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers total_duration = 0.0 n_spk_left = n_speakers total_spk_set = set() tracks = [] - + return tracks, len(total_spk_set) - + def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: """ Concatenate the texts of the input cuts. - + """ global_spk_mapping = {} str_pattern = pattern.replace("\\", '') @@ -1018,12 +1123,12 @@ def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> st if speaker not in local_spk_mapping: local_spk_mapping[speaker] = len(local_spk_mapping) local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker - + if i != 0: text = '' - for word in track.cut.text.split(): + for word in track.cut.text.split(): if len(re.findall(pattern, word)) > 0: - local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) + local_spk_idx = int(word.replace(left_str, '').replace(right_str, '')) spk = local_inverse_spk_mapping[local_spk_idx] global_spk_idx = global_spk_mapping[spk] text += f'{left_str}{global_spk_idx}{right_str}' @@ -1033,12 +1138,12 @@ def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> st cut.supervisions[i].text = text else: cut.supervisions[0].text = track.cut.text - # TODO: need to check the last speaker of last track and the first speaker of the current track + # TODO: need to check the last speaker of last track and the first speaker of the current track # if they are the same, we need to remove the the speaker token from the current track for segment-level # Do not need to remove the speaker token for word-level - + return cut - + def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: """ Balance the speaker distribution for the simulated meetings. 
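The proportional allocation performed by apply_speaker_distribution() (see the hunk that follows) can be summarized with a small standalone sketch. The helper name allocate_meetings() and the example numbers below are illustrative assumptions, not code or values from this PR.

def allocate_meetings(num_meetings: int, speaker_count_distribution: list) -> dict:
    # Distribute the requested number of meetings across speaker counts,
    # proportionally to the given weights: round(num_meetings * weight / total).
    total = sum(speaker_count_distribution)
    return {
        n_spk + 1: round(num_meetings * weight / total)
        for n_spk, weight in enumerate(speaker_count_distribution)
    }

# For example, with a distribution of [0, 2, 3, 4] and 9000 requested meetings:
# allocate_meetings(9000, [0, 2, 3, 4]) -> {1: 0, 2: 2000, 3: 3000, 4: 4000}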
@@ -1051,13 +1156,13 @@ def apply_speaker_distribution(self, num_meetings: int, speaker_count_distributi total_spk = sum(speaker_count_distribution) num_speakers2num_meetings = {} for i_spk in range(self.max_num_speakers): - num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + num_speakers2num_meetings[i_spk + 1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) return num_speakers2num_meetings - - + @dill_enabled(True) - def simulate(self, + def simulate( + self, cuts: CutSet, num_meetings: int = 10000, seed: int = 0, @@ -1068,39 +1173,57 @@ def simulate(self, self.fit(cuts) num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) - logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") - num_speakers2num_meetings[1] = 0 # skip 1-speaker samples - logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') + logging.warn( + f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}." + ) + num_speakers2num_meetings[1] = 0 # skip 1-speaker samples + logging.warn( + f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.' + ) # Step 0: Calculate the number of intra-session and inter-session concatentation samples n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] - valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples - n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} + valid_sim_n_spks = set( + [i + j for i in n_spks for j in n_spks] + ) # valid number of speakers for inter-session samples + n_spk2n_intra_mt, n_spk2n_inter_mt = {i + 1: 0 for i in range(self.max_num_speakers)}, { + i + 1: 0 for i in range(self.max_num_speakers) + } for n_spk, n_mt in num_speakers2num_meetings.items(): - logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) + logging.warn(f"==" * 16 + f"{n_spk}-speaker" + "==" * 16) if n_mt <= 0: - logging.warning(f"No intra-session concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") + logging.warning( + f"No intra-session concatentation samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers." 
+ ) continue - n_intra_mt = int(n_mt * self.intra_session_mix_prob[n_spk-1]) + n_intra_mt = int(n_mt * self.intra_session_mix_prob[n_spk - 1]) n_inter_mt = n_mt - n_intra_mt if n_spk in self.num_spk2sess_ids: logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") n_spk2n_intra_mt[n_spk] = n_intra_mt else: - logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + logging.warning( + f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers." + ) n_spk2n_intra_mt[n_spk] = 0 n_inter_mt = n_mt if n_spk in valid_sim_n_spks: logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") n_spk2n_inter_mt[n_spk] = n_inter_mt else: - logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") + logging.warning( + f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers." + ) if n_spk2n_intra_mt[n_spk] != 0: n_spk2n_intra_mt[n_spk] = n_mt - logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") + logging.warn( + f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead." + ) else: logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") - logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") + logging.warn( + f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""" + ) # Step 1: intra-session num_intra_meetings = 0 intra_mixtures = [] @@ -1112,28 +1235,35 @@ def simulate(self, for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) num_intra_meetings += n_mt - logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") - + logging.info( + f"Finished simulating intra-session concatentation samples. 
Total number of intra-session concatentation samples: {num_intra_meetings}" + ) + # Steo 2: inter-session logging.info(f"Simulating inter-session concatentation samples.") - + num_inter_meetings = 0 inter_mixtures = [] for n_spk, n_mt in n_spk2n_inter_mt.items(): if n_mt <= 0: continue - + for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) num_inter_meetings += n_mt - logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") + logging.info( + f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}" + ) if num_inter_meetings + num_intra_meetings == 0: - logging.warning(f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration} and max {self.max_duration}, or the speaker count distribution is not correctly set.") + logging.warning( + f"No samples are generated. Probably the duration of the segments is not within the range of min {self.min_duration} and max {self.max_duration}, or the speaker count distribution is not correctly set." + ) return CutSet.from_cuts(intra_mixtures + inter_mixtures) -class LibriSpeechMixSimulator(): + +class LibriSpeechMixSimulator: def __init__( self, @@ -1151,12 +1281,15 @@ def __init__( self.max_duration = max_duration self.n_mix_speakers = n_mix_speakers self.speaker_count_distribution = speaker_count_distribution - assert len(speaker_count_distribution) == len(n_mix_speakers), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {len(n_mix_speakers)}" + assert len(speaker_count_distribution) == len( + n_mix_speakers + ), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {len(n_mix_speakers)}" def fit(self, cuts) -> CutSet: pass - def simulate(self, + def simulate( + self, cuts: CutSet, num_meetings: int = 10000, seed: int = 0, @@ -1172,7 +1305,8 @@ def simulate(self, cut_set.append(self._create_mixture(n_speakers=n_speakers)) return CutSet.from_cuts(cut_set) -class LibriSpeechMixGenerator(): + +class LibriSpeechMixGenerator: def __init__(self): pass @@ -1201,18 +1335,12 @@ def generate(self, cuts): supervisions=[], recording=Recording( id=wav.split('/')[-1].replace('.wav', ''), - sources=[ - AudioSource( - type='file', - channels=[0], - source=wav - ) - ], - sampling_rate=16000, + sources=[AudioSource(type='file', channels=[0], source=wav)], + sampling_rate=16000, num_samples=wav_samples, - duration=wav_dur + duration=wav_dur, ), - custom=custom + custom=custom, ) tracks.append(MixTrack(cut=cut_1spk, type=type(cut_1spk), offset=offset)) @@ -1220,12 +1348,12 @@ def generate(self, cuts): id=cut.id, recording_id=cut.recording_id, start=0, - duration=offset+wav_dur, + duration=offset + wav_dur, text=cut.text, ) tracks[0].cut.supervisions.append(sup) cut_multi_spk = MixedCut(id=cut.id, tracks=tracks) - + cut_set.append(cut_multi_spk) - - return CutSet.from_cuts(cut_set) \ No newline at end of file + + return CutSet.from_cuts(cut_set) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 80b3e1f918b8..046f32c1d48f 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ 
b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -21,10 +21,10 @@ from typing import Dict, List, Tuple, Union import numpy as np -from omegaconf import OmegaConf -from omegaconf.listconfig import ListConfig import soundfile as sf import torch +from omegaconf import OmegaConf +from omegaconf.listconfig import ListConfig from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm @@ -589,7 +589,7 @@ def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, Number of decimals to round the offset and duration values. """ audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] - for (stt, end) in overlap_range_list: + for stt, end in overlap_range_list: meta = { "audio_filepath": audio_path, "offset": round(stt, decimals), @@ -749,14 +749,14 @@ def fl2int(x: float, decimals: int = 3) -> int: """ Convert floating point number to integer. """ - return torch.round(torch.tensor([x * (10 ** decimals)]), decimals=0).int().item() + return torch.round(torch.tensor([x * (10**decimals)]), decimals=0).int().item() def int2fl(x: int, decimals: int = 3) -> float: """ Convert integer to floating point number. """ - return torch.round(torch.tensor([x / (10 ** decimals)]), decimals=decimals).item() + return torch.round(torch.tensor([x / (10**decimals)]), decimals=decimals).item() def merge_float_intervals(ranges: List[List[float]], decimals: int = 5, margin: int = 2) -> List[List[float]]: @@ -902,9 +902,10 @@ def segments_manifest_to_subsegments_manifest( pwd = os.getcwd() subsegments_manifest_file = os.path.join(pwd, 'subsegments.json') - with open(segments_manifest_file, 'r') as segments_manifest, open( - subsegments_manifest_file, 'w' - ) as subsegments_manifest: + with ( + open(segments_manifest_file, 'r') as segments_manifest, + open(subsegments_manifest_file, 'w') as subsegments_manifest, + ): segments = segments_manifest.readlines() for segment in segments: segment = segment.strip() @@ -933,22 +934,22 @@ def segments_manifest_to_subsegments_manifest( def get_subsegments( - offset: float, - window: float, - shift: float, - duration: float, + offset: float, + window: float, + shift: float, + duration: float, min_subsegment_duration: float = 0.01, decimals: int = 2, use_asr_style_frame_count: bool = False, sample_rate: int = 16000, feat_per_sec: int = 100, - ) -> List[List[float]]: +) -> List[List[float]]: """ Return subsegments from a segment of audio file. - + Example: (window, shift) = 1.5, 0.75 - Segment: [12.05, 14.45] + Segment: [12.05, 14.45] Subsegments: [[12.05, 13.55], [12.8, 14.3], [13.55, 14.45], [14.3, 14.45]] Args: @@ -959,30 +960,30 @@ def get_subsegments( min_subsegment_duration (float): Exclude subsegments smaller than this duration value decimals (int): Number of decimal places to round to use_asr_style_frame_count (bool): If True, use asr style frame count to generate subsegments. - For example, if duration is 10 secs and frame_shift is 0.08 secs, + For example, if duration is 10 secs and frame_shift is 0.08 secs, it results in (10/0.08)+1 = 125 + 1 frames. 
- + Returns: subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment """ - subsegments: List[List[float]] = [] + subsegments: List[List[float]] = [] start = offset slice_end = start + duration if min_subsegment_duration <= duration < shift: slices = 1 - elif use_asr_style_frame_count is True: - num_feat_frames = np.ceil((1+duration*sample_rate)/int(sample_rate/feat_per_sec)).astype(int) - slices = np.ceil(num_feat_frames/int(feat_per_sec*shift)).astype(int) - slice_end = start + shift * slices + elif use_asr_style_frame_count is True: + num_feat_frames = np.ceil((1 + duration * sample_rate) / int(sample_rate / feat_per_sec)).astype(int) + slices = np.ceil(num_feat_frames / int(feat_per_sec * shift)).astype(int) + slice_end = start + shift * slices else: - slices = np.ceil(1+ (duration-window)/shift).astype(int) + slices = np.ceil(1 + (duration - window) / shift).astype(int) if slices == 1: if min(duration, window) >= min_subsegment_duration: subsegments.append([start, min(duration, window)]) - elif slices > 0: # What if slices = 0 ? + elif slices > 0: # What if slices = 0 ? start_col = torch.arange(offset, slice_end, shift)[:slices] dur_col = window * torch.ones(slices) - dur_col = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) + dur_col = torch.min(slice_end * torch.ones_like(start_col) - start_col, window * torch.ones_like(start_col)) dur_col = torch.round(dur_col, decimals=decimals) valid_mask = dur_col >= min_subsegment_duration valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) @@ -990,7 +991,13 @@ def get_subsegments( return subsegments -def get_target_sig(sig, start_sec: float, end_sec: float, slice_length: int, sample_rate: int,) -> torch.Tensor: +def get_target_sig( + sig, + start_sec: float, + end_sec: float, + slice_length: int, + sample_rate: int, +) -> torch.Tensor: """ Extract time-series signal from the given audio buffer based on the start and end timestamps. @@ -1037,15 +1044,16 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: return [[float(range_tensor[k][0]), float(range_tensor[k][1])] for k in range(range_tensor.shape[0])] -def generate_diarization_output_lines(speaker_timestamps, model_spk_num): - speaker_lines_total = [] +def generate_diarization_output_lines(speaker_timestamps, model_spk_num): + speaker_lines_total = [] for spk_idx in range(model_spk_num): ts_intervals = speaker_timestamps[spk_idx] merged_ts_intervals = merge_float_intervals(ts_intervals) for ts_interval in merged_ts_intervals: speaker_lines_total.extend([f"{ts_interval[0]:.3f} {ts_interval[1]:.3f} speaker_{int(spk_idx)}"]) return speaker_lines_total - + + def get_speech_labels_for_update( frame_start: float, buffer_end: float, @@ -1113,9 +1121,12 @@ def get_speech_labels_for_update( return speech_label_for_new_segments, cumulative_speech_labels -def get_new_cursor_for_update(frame_start: float, segment_range_ts: List[List[float]],) -> Tuple[float, int]: +def get_new_cursor_for_update( + frame_start: float, + segment_range_ts: List[List[float]], +) -> Tuple[float, int]: """ - Function for updating a cursor for online speaker diarization. Remove the old segments that overlap with the new frame (self.frame_start) cursor_for_old_segments is set to the onset of the t_range popped lastly. 
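As a sanity check of the windowing arithmetic documented in the get_subsegments() docstring above, the following standalone sketch reproduces the (window, shift) = (1.5, 0.75) example over the segment [12.05, 14.45]. It only computes [start, end] boundaries for illustration and is not a call into the NeMo helper itself (which, per its docstring, returns start/duration pairs).

def sliding_subsegment_bounds(offset: float, duration: float, window: float = 1.5, shift: float = 0.75):
    # Slide a fixed-length window over [offset, offset + duration] with the given hop,
    # clipping the final windows at the segment end.
    seg_end = offset + duration
    bounds, start = [], offset
    while start < seg_end:
        bounds.append([round(start, 2), round(min(start + window, seg_end), 2)])
        start += shift
    return bounds

# sliding_subsegment_bounds(12.05, 2.4) ->
# [[12.05, 13.55], [12.8, 14.3], [13.55, 14.45], [14.3, 14.45]]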
@@ -1273,7 +1284,10 @@ def get_online_subsegments_from_buffer( range_t = [max(0, range_offs[0]), range_offs[1]] subsegments = get_subsegments( - offset=range_t[0], window=window, shift=shift, duration=(range_t[1] - range_t[0]), + offset=range_t[0], + window=window, + shift=shift, + duration=(range_t[1] - range_t[0]), ) ind_offset, sigs, ranges, inds = get_online_segments_from_slices( sig=audio_buffer, @@ -1444,8 +1458,7 @@ def generate_speaker_timestamps( def get_uniq_id_list_from_manifest(manifest_file: str): - """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list. - """ + """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list.""" uniq_id_list = [] with open(manifest_file, 'r', encoding='utf-8') as manifest: for i, line in enumerate(manifest.readlines()): @@ -1626,21 +1639,22 @@ def make_rttm_with_overlap( return all_reference, all_hypothesis -def timestamps_to_pyannote_object(speaker_timestamps: List[Tuple[float, float]], - uniq_id: str, - audio_rttm_values: Dict[str, str], - all_hypothesis: List[Tuple[str, Timeline]], - all_reference: List[Tuple[str, Timeline]], - all_uems: List[Tuple[str, Timeline]], - out_rttm_dir: str | None - ): - """ +def timestamps_to_pyannote_object( + speaker_timestamps: List[Tuple[float, float]], + uniq_id: str, + audio_rttm_values: Dict[str, str], + all_hypothesis: List[Tuple[str, Timeline]], + all_reference: List[Tuple[str, Timeline]], + all_uems: List[Tuple[str, Timeline]], + out_rttm_dir: str | None, +): + """ Convert speaker timestamps to pyannote.core.Timeline object. - + Args: - speaker_timestamps (List[Tuple[float, float]]): + speaker_timestamps (List[Tuple[float, float]]): Timestamps of each speaker: start time and end time of each speaker. - uniq_id (str): + uniq_id (str): Unique ID of each speaker. audio_rttm_values (Dict[str, str]): Dictionary of manifest values. @@ -1652,7 +1666,7 @@ def timestamps_to_pyannote_object(speaker_timestamps: List[Tuple[float, float]], List of uems in pyannote.core.Timeline object. out_rttm_dir (str | None): Directory to save RTTMs - + Returns: all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): List of hypothesis in pyannote.core.Timeline object with an added Timeline object. @@ -1662,47 +1676,49 @@ def timestamps_to_pyannote_object(speaker_timestamps: List[Tuple[float, float]], List of uems in pyannote.core.Timeline object with an added Timeline object. 
""" offset, dur = float(audio_rttm_values.get('offset', None)), float(audio_rttm_values.get('duration', None)) - hyp_labels = generate_diarization_output_lines(speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps)) + hyp_labels = generate_diarization_output_lines( + speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps) + ) hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=uniq_id) if out_rttm_dir is not None and os.path.exists(out_rttm_dir): - with open(f'{out_rttm_dir}/{uniq_id}.rttm','w') as f: + with open(f'{out_rttm_dir}/{uniq_id}.rttm', 'w') as f: hypothesis.write_rttm(f) all_hypothesis.append([uniq_id, hypothesis]) rttm_file = audio_rttm_values.get('rttm_filepath', None) if rttm_file is not None and os.path.exists(rttm_file): - uem_lines = [[offset, dur+offset]] + uem_lines = [[offset, dur + offset]] org_ref_labels = rttm_to_labels(rttm_file) ref_labels = org_ref_labels reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) uem_obj = get_uem_object(uem_lines, uniq_id=uniq_id) all_uems.append(uem_obj) all_reference.append([uniq_id, reference]) - return all_hypothesis, all_reference, all_uems - + return all_hypothesis, all_reference, all_uems + + def get_uem_object(uem_lines: List[List[float]], uniq_id: str): """ Generate pyannote timeline segments for uem file. - + file format UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME - + Args: uem_lines (list): list of session ID and start, end times. Example: [[0.0, 30.41], [60.04, 165.83]] uniq_id (str): Unique session ID. - + Returns: timeline (pyannote.core.Timeline): pyannote timeline object. """ timeline = Timeline(uri=uniq_id) for uem_stt_end in uem_lines: - start_time, end_time = uem_stt_end + start_time, end_time = uem_stt_end timeline.add(Segment(float(start_time), float(end_time))) return timeline - def embedding_normalize(embs, use_std=False, eps=1e-10): """ Mean and l2 length normalize the input speaker embeddings diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 192c42375dca..f7374931bc45 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -35,8 +35,9 @@ from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm -from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object + from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel +from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging @@ -576,7 +577,7 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc """ if speech_segments.shape == torch.Size([0]): return speech_segments - + min_duration_on = per_args.get('min_duration_on', 0.0) min_duration_off = per_args.get('min_duration_off', 0.0) filter_speech_first = per_args.get('filter_speech_first', 1.0) @@ -1712,34 +1713,34 @@ def frame_vad_eval_detection_error( def ts_vad_post_processing( - ts_vad_binary_vec: torch.Tensor, - cfg_vad_params: OmegaConf, - unit_10ms_frame_count: int=8, - bypass_postprocessing: bool = False - ): + ts_vad_binary_vec: torch.Tensor, + cfg_vad_params: OmegaConf, + unit_10ms_frame_count: int = 8, + bypass_postprocessing: bool = False, +): """ Post-processing on diarization results using VAD style post-processing methods. 
These post-processing methods are inspired by the following paper: - Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). Args: - ts_vad_binary_vec (Tensor): + ts_vad_binary_vec (Tensor): Sigmoid values of each frame and each speaker. Dimension: (num_frames,) - cfg_vad_params (OmegaConf): + cfg_vad_params (OmegaConf): Configuration (omega config) of VAD parameters. - unit_10ms_frame_count (int, optional): + unit_10ms_frame_count (int, optional): an integer indicating the number of 10ms frames in a unit. For example, if unit_10ms_frame_count is 8, then each frame is 0.08 seconds. - bypass_postprocessing (bool, optional): + bypass_postprocessing (bool, optional): If True, diarization post-processing will be bypassed. Returns: - speech_segments (Tensor): + speech_segments (Tensor): start and end of each speech segment. Dimension: (num_segments, 2) - - Example: + + Example: tensor([[ 0.0000, 3.0400], [ 6.0000, 6.0800], ... @@ -1751,9 +1752,9 @@ def ts_vad_post_processing( speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) speech_segments = filtering(speech_segments, cfg_vad_params) else: - cfg_vad_params.onset=0.5 - cfg_vad_params.offset=0.5 - cfg_vad_params.pad_onset=0.0 - cfg_vad_params.pad_offset=0.0 + cfg_vad_params.onset = 0.5 + cfg_vad_params.offset = 0.5 + cfg_vad_params.pad_onset = 0.0 + cfg_vad_params.pad_offset = 0.0 speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) - return speech_segments \ No newline at end of file + return speech_segments diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 144ae405de52..632ec06bc647 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -1242,6 +1242,7 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: ) return item + class EndtoEndDiarizationLabel(_Collection): """List of diarization audio-label correspondence with preprocessing.""" @@ -1283,9 +1284,7 @@ def __init__( output_type = self.OUTPUT_TYPE data, duration_filtered = [], 0.0 - zipped_items = zip( - audio_files, uniq_ids, durations, rttm_files, offsets - ) + zipped_items = zip(audio_files, uniq_ids, durations, rttm_files, offsets) for ( audio_file, uniq_id, @@ -1328,7 +1327,8 @@ def __init__( data.sort(key=lambda entity: entity.duration) logging.info( - "Filtered duration for loading collection is %f.", duration_filtered, + "Filtered duration for loading collection is %f.", + duration_filtered, ) logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips") @@ -1346,8 +1346,8 @@ def __init__( **kwargs, ): """ - Parse lists of audio files, durations, RTTM (Diarization annotation) files. - Since diarization model infers only two speakers, speaker pairs are generated + Parse lists of audio files, durations, RTTM (Diarization annotation) files. + Since diarization model infers only two speakers, speaker pairs are generated from the total number of speakers in the session. 
Args: @@ -1404,12 +1404,12 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: raise ValueError( f"Manifest file has invalid json line " f"structure: {line} without proper audio file key." ) - if isinstance(item['audio_file'], list): + if isinstance(item['audio_file'], list): item['audio_file'] = [os.path.expanduser(audio_file_path) for audio_file_path in item['audio_file']] else: item['audio_file'] = os.path.expanduser(item['audio_file']) - if not isinstance(item['audio_file'], list): + if not isinstance(item['audio_file'], list): if 'uniq_id' not in item: item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0] elif 'uniq_id' not in item: From 4ddc59bc0d8fc606d8452ddcfcc5e5a12a3ed9e0 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 16:56:08 -0800 Subject: [PATCH 05/16] Reflecting comments and removing unnecessary parts for this PR Signed-off-by: taejinp --- ...rtformer_diarizer_hybrid_loss_4spk-v1.yaml | 35 +- ...ortformer_diar_4spk-v1_callhome-part1.yaml | 4 - .../sortformer_diar_4spk-v1_dihard-dev.yaml | 4 - .../neural_diarizer/e2e_diarize_speech.py | 35 +- nemo/collections/asr/models/__init__.py | 10 +- .../asr/models/sortformer_diar_models.py | 16 +- .../asr/modules/sortformer_modules.py | 16 +- .../asr/parts/utils/asr_multispeaker_utils.py | 1024 ++--------------- .../asr/parts/utils/speaker_utils.py | 6 +- nemo/collections/asr/parts/utils/vad_utils.py | 1 - 10 files changed, 122 insertions(+), 1029 deletions(-) diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml index e44bae976729..04409a4cd60a 100644 --- a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -1,6 +1,6 @@ -# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. -# Model name convention for Sortformer Diarizer: sortformer_diarizer____loss.yaml -# (Example) `sortformer_diarizer_FC18_TF18_hybrid_loss.yaml` has 18 layers for FastConformer and 18 layers of Transformer. +sortformer_diarizer_hybrid_loss_4spk-v1.yaml# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. +# Model name convention for Sortformer Diarizer: sortformer_diarizer___.yaml +# (Example) `sortformer_diarizer_hybrid_loss_4spk-v1.yaml`. # Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. # Example: a manifest line for training # {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"} @@ -10,21 +10,21 @@ num_workers: 18 batch_size: 8 model: - pil_weight: 0.5 - ats_weight: 0.5 - num_workers: ${num_workers} - fc_d_model: 512 - tf_d_model: 192 - max_num_of_spks: 4 # Number of speakers per model. This is currently fixed at 4. 
- session_len_sec: 90 + pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model + ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model + num_workers: ${num_workers} # Number of workers for data loading + fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder + tf_d_model: 192 # Hidden dimension size of the Transformer Encoder + max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4 + session_len_sec: 90 # Maximum session length in seconds train_ds: manifest_filepath: ??? sample_rate: ${sample_rate} num_spks: ${model.max_num_of_spks} session_len_sec: ${model.session_len_sec} - soft_label_thres: 0.5 - soft_targets: False + soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity. + soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss labels: null batch_size: ${batch_size} shuffle: True @@ -52,7 +52,7 @@ model: sample_rate: ${sample_rate} num_spks: ${model.max_num_of_spks} session_len_sec: ${model.session_len_sec} - soft_label_thres: 0.5 + soft_label_thres: 0.5 # A threshold value for setting up the binarized labels. The higher the more conservative the model becomes. soft_targets: False labels: null batch_size: ${batch_size} @@ -121,10 +121,8 @@ model: subsampling_factor: 8 # must be power of 2 for striding and vggnet subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model causal_downsampling: false - # Feed forward module's params ff_expansion_factor: 4 - # Multi-headed Attention Module's params self_attention_model: rel_pos # rel_pos or abs_pos n_heads: 8 # may need to be lower for smaller d_models @@ -134,19 +132,16 @@ model: xscaling: true # scales up the input embeddings by sqrt(d_model) untie_biases: true # unties the biases of the TransformerXL layers pos_emb_max_len: 5000 - # Convolution module's params conv_kernel_size: 9 conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) conv_context_size: null - - ### regularization + # Regularization dropout: 0.1 # The dropout used in most of the Conformer Modules dropout_pre_encoder: 0.1 # The dropout used before the encoder dropout_emb: 0.0 # The dropout used for embeddings dropout_att: 0.1 # The dropout for multi-headed attention modules - - # set to non-zero to enable stochastic depth + # Set to non-zero to enable stochastic depth stochastic_depth_drop_prob: 0.0 stochastic_depth_mode: linear # linear or uniform stochastic_depth_start_layer: 1 diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml index 3733e1285b77..ebed4a649730 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml @@ -5,10 +5,6 @@ # These parameters were optimized on the development split of DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477. # Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. 
parameters: - window_length_in_sec: 0.0 # Not used - shift_length_in_sec: 0.0 # Not used - smoothing: False # Not used - overlap: 0.5 # Not used onset: 0.53 # Onset threshold for detecting the beginning and end of a speech offset: 0.49 # Offset threshold for detecting the end of a speech pad_onset: 0.23 # Adding durations before each speech segment diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml index 275bc86db4cd..9beaff6e3c7c 100644 --- a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml @@ -5,10 +5,6 @@ # These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the split2 specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2. # Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649. parameters: - window_length_in_sec: 0.0 # Not used - shift_length_in_sec: 0.0 # Not used - smoothing: False # Not used - overlap: 0.5 # Not used onset: 0.64 # Onset threshold for detecting the beginning and end of a speech offset: 0.74 # Offset threshold for detecting the end of a speech pad_onset: 0.06 # Adding durations before each speech segment diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 98f2ee10e523..a2dcd15dbb71 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -47,17 +47,12 @@ @dataclass class PostProcessingParams: - window_length_in_sec: float = 0.15 - shift_length_in_sec: float = 0.01 - smoothing: bool = False - overlap: float = 0.5 - onset: float = 0.5 - offset: float = 0.5 - pad_onset: float = 0.0 - pad_offset: float = 0.0 - min_duration_on: float = 0.0 - min_duration_off: float = 0.0 - filter_speech_first: bool = True + onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech + offset: float = 0.5 # Offset threshold for detecting the end of a speech + pad_onset: float = 0.0 # Adding durations before each speech segment + pad_offset: float = 0.0 # Adding durations after each speech segment + min_duration_on: float = 0.0 # Threshold for small non-speech deletion + min_duration_off: float = 0.0 # Threshold for short speech segment deletion @dataclass class DiarizationConfig: @@ -124,7 +119,9 @@ def load_postprocessing_from_yaml(postprocessing_yaml): def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: """ Suggests hyperparameters for postprocessing using Optuna. - + See the following link for `trial` instance in Optuna framework. + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial + Args: postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. 
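A minimal sketch of how an optuna.Trial can sample the six postprocessing parameters tuned here is shown below. The search ranges are illustrative assumptions, not the ranges used in this PR, and run_eval() is a hypothetical scoring function.

import optuna

def suggest_postprocessing_params(trial: optuna.Trial) -> dict:
    # Each call draws one candidate postprocessing configuration from the search space.
    return {
        "onset": trial.suggest_float("onset", 0.4, 0.8, step=0.01),
        "offset": trial.suggest_float("offset", 0.4, 0.9, step=0.01),
        "pad_onset": trial.suggest_float("pad_onset", 0.0, 0.2, step=0.01),
        "pad_offset": trial.suggest_float("pad_offset", 0.0, 0.2, step=0.01),
        "min_duration_on": trial.suggest_float("min_duration_on", 0.0, 0.75, step=0.01),
        "min_duration_off": trial.suggest_float("min_duration_off", 0.0, 0.75, step=0.01),
    }

# A study would then minimize DER over trials, for example:
# study = optuna.create_study(direction="minimize")
# study.optimize(lambda trial: run_eval(suggest_postprocessing_params(trial)), n_trials=100)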
@@ -373,13 +370,13 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: out_rttm_dir=cfg.out_rttm_dir ) logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") - metric, mapping_dict, itemized_errors = score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, - all_reference=all_refs, - all_hypothesis=all_hyps, - all_uem=all_uems, - collar=cfg.collar, - ignore_overlap=cfg.ignore_overlap - ) + score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap + ) logging.info(f"PostProcessingParams: {postprocessing_cfg}") if __name__ == '__main__': diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 31194d8849f0..e85500593656 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -19,8 +19,8 @@ EncDecClassificationModel, EncDecFrameClassificationModel, ) -from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel @@ -36,5 +36,9 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel -from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel -from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE +from nemo.collections.asr.models.ssl_models import ( + EncDecDenoiseMaskedTokenPredModel, + EncDecMaskedTokenPredModel, + SpeechEncDecSelfSupervisedModel, +) +from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE \ No newline at end of file diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 50cdf6214d5b..7b2b5cf17793 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -36,17 +36,6 @@ from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_pil_targets, get_ats_targets from nemo.utils import logging -try: - from torch.cuda.amp import autocast -except ImportError: - from contextlib import contextmanager - - @contextmanager - def autocast(enabled=None): - yield - -# torch.backends.cudnn.enabled = False - __all__ = ['SortformerEncLabelModel'] class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): @@ -549,14 +538,13 @@ def test_batch(self,): audio_signal_length=audio_signal_length, ) preds = preds.detach().to('cpu') - if preds.shape[0] == 1: # batch size = 1 + if preds.shape[0] == 1: # If batch size is absolute 1 self.preds_total_list.append(preds) else: self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) torch.cuda.empty_cache() self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) - # except: - # import ipdb; ipdb.set_trace() + logging.info(f"Batch F1Acc. 
MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index 823cf98590e7..1805327ab69b 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -20,7 +20,6 @@ from nemo.core.classes.exportable import Exportable from nemo.core.classes.module import NeuralModule -from nemo.core.neural_types import EncodedRepresentation, LengthsType, NeuralType, SpectrogramType from nemo.core.neural_types.elements import ProbsType __all__ = ['SortformerModules'] @@ -37,23 +36,12 @@ class SortformerModules(NeuralModule, Exportable): Max number of speakers that are processed by the model. In `MSDD_module`, `num_spks=2` for pairwise inference. hidden_size (int): Number of hidden units in sequence models and intermediate layers. - num_lstm_layers (int): - Number of the stacked LSTM layers. dropout_rate (float): Dropout rate for linear layers, CNN and LSTM. + fc_d_model (int): + Dimension of the embedding vectors. tf_d_model (int): Dimension of the embedding vectors. - scale_n (int): - Number of scales in multi-scale system. - clamp_max (float): - Maximum value for limiting the scale weight values. - conv_repeat (int): - Number of CNN layers after the first CNN layer. - weighting_scheme (str): - Name of the methods for estimating the scale weights. - context_vector_type (str): - If 'cos_sim', cosine similarity values are used for the input of the sequence models. - If 'elem_prod', element-wise product values are used for the input of the sequence models. """ def init_weights(self, m): if type(m) == nn.Linear: diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index a1d34e1f7480..fed55730e7f1 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -12,29 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import re -import copy import math -import random -import logging -import itertools -from copy import deepcopy -import concurrent.futures -from cytoolz import groupby -from collections import defaultdict -from typing import Dict, Optional, Tuple, List - -import numpy as np -import soundfile +import torch from tqdm import tqdm -from scipy.stats import norm - -import torch.utils.data -from lhotse.cut.set import mix -from lhotse.cut import CutSet, MixedCut, MonoCut, MixTrack -from lhotse import SupervisionSet, SupervisionSegment, dill_enabled, AudioSource, Recording -from lhotse.utils import uuid4 +from lhotse import SupervisionSet +from lhotse.cut import MixedCut, MonoCut def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres:float = 0.5) -> torch.Tensor: """ @@ -163,7 +145,6 @@ def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutati torch.Tensor: A tensor of permuted labels that best match the predictions. 
Shape: (batch_size, num_speakers, num_classes) """ - perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_classes, num_permutations, num_speakers) # Repeat preds to match permutations for comparison preds_rep = torch.unsqueeze(preds, 2).repeat(1, 1, speaker_permutations.shape[0], 1) # (batch_size, num_speakers, num_permutations, num_classes) @@ -173,65 +154,6 @@ def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutati max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) return max_score_permed_labels # (batch_size, num_speakers, num_classes) -def apply_spk_mapping(diar_preds: torch.Tensor, spk_mappings: torch.Tensor) -> torch.Tensor: - """ - Applies a speaker mapping to diar predictions. - - Args: - diar_preds (Tensor): The diar predictions tensor. - Dimension: (batch_size, num_frames, num_speakers) - spk_mappings (Tensor): The speaker mappings tensor. - Dimension: (batch_size, num_speakers) - - Returns: - permuted_diar_preds (Tensor): The permuted diar predictions tensor with the given speaker mappings. - """ - expanded_mappings = spk_mappings.unsqueeze(1).expand(-1, diar_preds.size(1), -1) - permuted_diar_preds = torch.gather(diar_preds, 2, expanded_mappings) - return permuted_diar_preds - -def shuffle_spk_mapping(cuts: list, num_speakers: int, shuffle_spk_mapping: bool = False, pattern= r'<\|spltoken\d+\|>') -> Tuple[CutSet, torch.Tensor]: - """ - Applies a shuffle mapping to speaker text labels in the cuts. - Example: - Original cut.text: - "<|spltoken0|> we do shuffle <|spltoken1|> and map speakers <|spltoken0|> yes <|spltoken2|> we keep dimensions" - Speaker Mapping: [3, 0, 1, 2] - Shuffled cut.text: - "<|spltoken3|> we do shuffle <|spltoken0|> and map speakers <|spltoken3|> yes <|spltoken1|> we keep dimensions" - - Args: - cuts (List[MonoCut, MixedCut]): A list of Cut instances. - num_speakers (int): The total number of speakers. - shuffle_spk_mapping (bool): Whether to shuffle the speaker mappings. - pattern (str): A regular expression pattern for speaker tokens. - - Returns: - cuts (list): The updated CutSet with shuffled speaker mappings. - spk_mappings (Tensor): - If shuffle_speaker_mapping is True, shuffled speaker mappings in batch. - If shuffle_speaker_mapping is False, speaker mappings in batch is not permuted and returns torch.arange() values. 
- """ - batch_size = len(cuts) - if shuffle_spk_mapping: - permuted_indices = torch.rand(batch_size, num_speakers).argsort(dim=1) - spk_mappings = torch.gather(torch.arange(num_speakers).repeat(batch_size, 1), 1, permuted_indices) - str_pattern = pattern.replace("\\", '') - left_str, right_str = str_pattern.split('d+')[0], str_pattern.split('d+')[1] - for idx, cut in enumerate(cuts): - word_list = [] - for word in deepcopy(cut.text).split(): - if len(re.findall(pattern, word)) > 0: - spk_token_int = int(word.replace(left_str,'').replace(right_str, '')) - new_spk = spk_mappings[idx][spk_token_int] - word_list.append(f'{left_str}{new_spk}{right_str}') - else: - word_list.append(word) - cuts[idx].supervisions[0].text = ' '.join(word_list) - else: - spk_mappings = torch.arange(num_speakers).unsqueeze(0).repeat(batch_size, 1) - return cuts, spk_mappings - def find_segments_from_rttm( recording_id: str, rttms, @@ -268,91 +190,6 @@ def find_segments_from_rttm( and segment.end > start_after + tolerance ] -def speaker_to_target( - a_cut, - num_speakers: int = 4, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8, - spk_tar_all_zero: bool = False, - boundary_segments: bool = False, - soft_label: bool = False, - ignore_num_spk_mismatch: bool = True, - soft_thres: float = 0.5, - ): - ''' - Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) - This function is needed for speaker diarization with ASR model trainings. - - Args: - a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. - num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default - num_sample_per_mel_frame (int): number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) - num_mel_frame_per_asr_frame (int): encoder subsampling_factor, 8 by default - spk_tar_all_zero (Tensor): set to True gives all zero "mask" - boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training - soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. - ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. 
- - Returns: - mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) - ''' - # get cut-related segments from rttms - # basename = os.path.basename(a_cut.rttm_filepath).replace('.rttm', '') - if isinstance(a_cut, MixedCut): - cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] - offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] - elif isinstance(a_cut, MonoCut): - cut_list = [a_cut] - offsets = [0] - else: - raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - - segments_total = [] - for i, cut in enumerate(cut_list): - rttms = SupervisionSet.from_rttm(cut.rttm_filepath) - if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included - segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) - else: # segments with seg_start > total_start and seg_end < total_end are included - segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) - - for seg in segments_iterator: - if seg.start < 0: - seg.duration += seg.start - seg.start = 0 - if seg.end > cut.duration: - seg.duration -= seg.end - cut.duration - seg.start += offsets[i] - segments_total.append(seg) - - # apply arrival time sorting to the existing segments - segments_total.sort(key = lambda rttm_sup: rttm_sup.start) - - seen = set() - seen_add = seen.add - speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] - - speaker_to_idx_map = { - spk: idx - for idx, spk in enumerate(speaker_ats) - } - if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers - raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}") - - # initialize mask matrices (num_speaker, encoder_hidden_len) - feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default - num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) - if spk_tar_all_zero: - frame_mask = torch.zeros((num_samples, num_speakers)) - else: - frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) - soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) - - if soft_label: - mask = soft_mask - else: - mask = (soft_mask > soft_thres).float() - - return mask def get_mask_from_segments(segments: list, a_cut, speaker_to_idx_map: torch.Tensor, num_speakers: int =4, feat_per_sec: int=100, ignore_num_spk_mismatch: bool = False): """ @@ -439,793 +276,88 @@ def get_hidden_length_from_sample_length( hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) return int(hidden_length) -class ConcatenationMeetingSimulator(): - """ - This simulator concatenates the segments from different/same sessions to create a - multi-speaker meeting. 
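A rough worked example of the frame-count arithmetic in get_hidden_length_from_sample_length shown above, assuming 16 kHz audio with the default 160-sample mel hop and 8x encoder subsampling; the mel-frame formula below is an illustrative assumption, since only the final ceil-division appears in the hunk.

import math

def hidden_length_sketch(num_samples, num_sample_per_mel_frame=160, num_mel_frame_per_asr_frame=8):
    # Assumed mel frame count for a 10 ms hop front end
    mel_frame_count = math.ceil((num_samples + 1) / num_sample_per_mel_frame)
    # Mirrors the ceil-division shown in the diff context above
    return math.ceil(mel_frame_count / num_mel_frame_per_asr_frame)

print(hidden_length_sketch(160000))  # 10 s at 16 kHz -> 1001 mel frames -> 126 encoder frames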
- """ - - def __init__( - self, - intra_session_concat_prob: float|List[float] = [0, 1.0, 0.5, 0.2], - data_type: str = "msasr", - min_duration: float = 30.0, - max_duration: float = 40.0, - max_num_speakers: int = 4, - speaker_count_distribution: List[float] = [0, 2, 3, 4], - skip_long_segments: bool = True, - valid_dataset_ids: List[str] = [], - ): - """ - :param intra_session_concat_prob: the probability of concatenating segments from the same - session. [Default: 1] - :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', - the transcripts are included in the simulation,and the boundary segments are - not included. [Default: 'msasr'] - :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] - """ - super().__init__() - if isinstance(intra_session_concat_prob, float): - self.intra_session_concat_prob = [intra_session_concat_prob] * (max_num_speakers) - elif len(intra_session_concat_prob) == max_num_speakers: - self.intra_session_concat_prob = intra_session_concat_prob - else: - raise ValueError(f"intra_session_concat_prob must be either a float or a list of floats, but got {intra_session_concat_prob}") - if data_type not in ["msasr", "diar"]: - raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") - self.data_type = data_type - self.min_duration = min_duration - self.max_duration = max_duration - self.max_num_speakers = max_num_speakers - self.speaker_count_distribution = speaker_count_distribution - assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" - - if skip_long_segments: - self.skip_duration = max_duration / 2 - else: - self.skip_duration = max_duration - - self.valid_dataset_ids = valid_dataset_ids - - def fit(self, cuts) -> CutSet: - """ - Read the manifest file and return a CutSet object. - Each line in the manifest file should be a JSON object representing a segment. 
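A small standalone sketch of how the intra_session_concat_prob argument above is normalized to one probability per speaker count; the helper name is hypothetical and the checks are simplified.

from typing import List, Union

def normalize_concat_prob(p: Union[float, List[float]], max_num_speakers: int = 4) -> List[float]:
    # A single float applies the same probability to every speaker count;
    # a list must provide exactly one probability per speaker count.
    if isinstance(p, float):
        return [p] * max_num_speakers
    if len(p) == max_num_speakers:
        return list(p)
    raise ValueError(f"Expected a float or a list of length {max_num_speakers}, got {p}")

print(normalize_concat_prob(0.5))                 # [0.5, 0.5, 0.5, 0.5]
print(normalize_concat_prob([0, 1.0, 0.5, 0.2]))  # [0, 1.0, 0.5, 0.2]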
- """ - - self.id2cut = {} - self.sess2cut_ids = defaultdict(list) - self.sess2spks = defaultdict(set) - self.data2sess_ids = defaultdict(list) - self.spk2cut_ids = defaultdict(list) - self.data2num_spk2cut_ids = {} - self.sess2num_spk2cut_ids = {} - self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} - for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): - if cut.duration > self.skip_duration: - continue - if not hasattr(cut, 'dataset_id') or cut.dataset_id is None: - continue - if self.valid_dataset_ids and cut.dataset_id not in self.valid_dataset_ids: - continue - if cut.dataset_id not in self.data2num_spk2cut_ids: - self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) - if cut.recording_id not in self.sess2num_spk2cut_ids: - self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) - - speakers = cut.global_speaker_ids - if self.data_type == "msasr": - speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) - if len(speakers) != len(speaker_tokens): - # Lhotse automatically fixes the max duration of the cut, - # resulting in the mismatch of the number of speakers - # and speaker tokens for the last segment - # TODO: need to fix the issue in Lhotse that automatically fixes the max duration - continue - for spk in speakers: - self.spk2cut_ids[spk].append(cut.id) - self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) - - self.id2cut[cut.id] = cut - self.sess2cut_ids[cut.recording_id].append(cut.id) - self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) - self.sess2num_spk2cut_ids[cut.recording_id][len(speakers)].append(cut.id) - self.num_spk2cut_ids[len(speakers)].append(cut.id) - if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: - self.data2sess_ids[cut.dataset_id].append(cut.recording_id) - - self.cut_ids = list(self.id2cut.keys()) - self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) - - self.data2global_speaker = { - dataset_id: True for dataset_id in self.data2sess_ids.keys() - } - - def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: - - db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data - - if is_intra_session_concat: - # intra-dataset and intra-session concatenation - tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) - - else: - # intra-dataset but inter-session concatenation - tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) - - cut = MixedCut(id='concat_' + '_'.join([track.cut.id for track in tracks]), tracks=tracks) - if self.data_type == "msasr": - cut = self.reorder_spk_mapping(cut) - - assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" - assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" - - return cut - - def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: - """ - Get the tracks for the MixedCut object. 
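The _create_mixture code above draws a per-mixture loudness target from a normal distribution with Fisher-corpus statistics and then stacks cuts as offset tracks; below is a hedged sketch of that idea with invented durations (the Lhotse MixTrack calls are omitted).

from scipy.stats import norm

# Mean/std quoted in the code above (Fisher data loudness statistics)
db_norm = norm.rvs(loc=-32.06, scale=5.67)

# Tracks are concatenated back to back: each new track starts where the
# running total ends, and sampling stops once min_duration is reached.
min_duration = 30.0
durations = [7.5, 12.0, 9.0, 8.0, 6.0]  # toy cut durations in seconds
offsets, total = [], 0.0
for dur in durations:
    offsets.append(total)
    total += dur
    if total >= min_duration:
        break
print(round(db_norm, 1), offsets, total)  # e.g. -29.8 [0.0, 7.5, 19.5, 28.5] 36.5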
- """ - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - - total_duration = 0.0 - total_spk_set = set() - tracks = [] - while True: - cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) - total_spk_set = total_spk_set.union(cut.global_speaker_ids) - total_duration += cut.duration - - # break condition - if total_duration >= self.min_duration: - if total_duration > self.max_duration: # exceed the maximum duration, starting over - total_duration = 0.0 - total_spk_set = set() - tracks = [] - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break - break - else: - total_duration = 0.0 - total_spk_set = set() - tracks = [] - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - - return tracks, len(total_spk_set) - - def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: - """ - Get the tracks for the MixedCut object. - """ - sample_cut = self.id2cut[random.choice(self.cut_ids)] - dataset_id = sample_cut.dataset_id - n_spk_list = [n_spk for n_spk, cut_ids in self.data2num_spk2cut_ids[dataset_id].items() if len(cut_ids) > 0] - sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) - - if min(sum_spk_list) > n_speakers: - raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") - - n_spk_left = n_speakers - total_duration = 0.0 - total_spk_set = set() - tracks = [] - num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] - while True: - #if n_spk_left == n_speakers: # for more speakers cases - # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk < n_spk_left]) - if n_spk_left >= 2: - n_spk = 2 - else: - # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk <= n_spk_left]) - n_spk = 1 - - while True: - cut = self.id2cut[random.choice(num_spk2cut_ids[n_spk])] - spks = set(cut.global_speaker_ids) - if not spks.intersection(total_spk_set): - break - - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=total_duration)) - total_duration += cut.duration - n_spk_left -= n_spk - total_spk_set = total_spk_set.union(spks) - - # break condition - - if total_duration >= self.min_duration: - if total_duration > self.max_duration or len(total_spk_set) < n_speakers: # exceed the maximum duration, starting over - total_duration = 0.0 - n_spk_left = n_speakers - total_spk_set = set() - tracks = [] - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break - break - else: - if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers - total_duration = 0.0 - n_spk_left = n_speakers - total_spk_set = set() - tracks = [] - - return tracks, len(total_spk_set) - - def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: - """ - Concatenate the texts of the input cuts. 
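A toy illustration of the rejection loop in get_inter_session_tracks above: candidate cuts are resampled until one shares no speakers with the mixture built so far. The cut ids and speaker sets are invented; the real code samples Lhotse cuts grouped by dataset and speaker count.

import random

random.seed(0)
cut_speakers = {            # hypothetical cut id -> set of global speaker ids
    "cut_a": {"spk1", "spk2"},
    "cut_b": {"spk2", "spk3"},
    "cut_c": {"spk4", "spk5"},
}

total_spk_set = {"spk1", "spk2"}  # speakers already present in the mixture
while True:
    cand = random.choice(sorted(cut_speakers))
    if not cut_speakers[cand].intersection(total_spk_set):
        break  # found a cut whose speakers are all new
print(cand)  # cut_c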
- - """ - global_spk_mapping = {} - str_pattern = pattern.replace("\\", '') - left_str, right_str = str_pattern.split('d+') - for i, track in enumerate(cut.tracks): - local_inverse_spk_mapping = {} - local_spk_mapping = {} - for speaker in track.cut.global_speaker_ids: - if speaker not in global_spk_mapping: - global_spk_mapping[speaker] = len(global_spk_mapping) - if speaker not in local_spk_mapping: - local_spk_mapping[speaker] = len(local_spk_mapping) - local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker - - if i != 0: - text = '' - for word in track.cut.text.split(): - if len(re.findall(pattern, word)) > 0: - local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) - spk = local_inverse_spk_mapping[local_spk_idx] - global_spk_idx = global_spk_mapping[spk] - text += f'{left_str}{global_spk_idx}{right_str}' - else: - text += ' ' + word - track.cut.supervisions[0].text = text - cut.supervisions[i].text = text - else: - cut.supervisions[0].text = track.cut.text - # TODO: need to check the last speaker of last track and the first speaker of the current track - # if they are the same, we need to remove the the speaker token from the current track for segment-level - # Do not need to remove the speaker token for word-level - - return cut - - def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: - """ - Balance the speaker distribution for the simulated meetings. - Args: - num_meetings: The total number of simulated meetings. - speaker_count_distribution: The speaker count distribution for the simulated meetings. - For each number of speakers, calculate the number of meetings needed to balance the distribution. - """ - - total_spk = sum(speaker_count_distribution) - num_speakers2num_meetings = {} - for i_spk in range(self.max_num_speakers): - num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) - - return num_speakers2num_meetings - - - @dill_enabled(True) - def simulate(self, - cuts: CutSet, - num_meetings: int = 10000, - seed: int = 0, - num_jobs: int = 1, - ) -> CutSet: - random.seed(seed) - - self.fit(cuts) - - - num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) - logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") - num_speakers2num_meetings[1] = 0 # skip 1-speaker samples - logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') - - # Step 0: Calculate the number of intra-session and inter-session concatentation samples - n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] - valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples - n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} - for n_spk, n_mt in num_speakers2num_meetings.items(): - logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) - if n_mt <= 0: - logging.warning(f"No concatentation samples for {n_spk} speakers. 
Will skip simulation for {n_spk} speakers.") - continue - n_intra_mt = int(n_mt * self.intra_session_concat_prob[n_spk-1]) - n_inter_mt = n_mt - n_intra_mt - if n_spk in self.num_spk2sess_ids: - logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") - n_spk2n_intra_mt[n_spk] = n_intra_mt - else: - logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") - n_spk2n_intra_mt[n_spk] = 0 - n_inter_mt = n_mt - if n_spk in valid_sim_n_spks: - logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") - n_spk2n_inter_mt[n_spk] = n_inter_mt - else: - logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") - if n_spk2n_intra_mt[n_spk] != 0: - n_spk2n_intra_mt[n_spk] = n_mt - logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") - else: - logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") - logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") - # Step 1: intra-session - num_intra_meetings = 0 - intra_mixtures = [] - logging.info(f"Simulating intra-session concatentation samples.") - for n_spk, n_mt in n_spk2n_intra_mt.items(): - if n_mt <= 0: - continue - - for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): - intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) - num_intra_meetings += n_mt - logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") - - # Steo 2: inter-session - logging.info(f"Simulating inter-session concatentation samples.") - - num_inter_meetings = 0 - inter_mixtures = [] - for n_spk, n_mt in n_spk2n_inter_mt.items(): - if n_mt <= 0: - continue - - for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): - inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) - num_inter_meetings += n_mt - logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") - - if num_inter_meetings + num_intra_meetings == 0: - logging.warning(f"No samples are generated. 
Probably the duration of the segments is not within the range of min {self.min_duration//2} and max {self.max_duration//2}, or the speaker count distribution is not correctly set.") - - - # Multi-processing gets slower, TODO - # else: - # futures = [] - # for n_spk, n_mt in num_speakers2num_meetings.items(): - # tp = concurrent.futures.ProcessPoolExecutor(max_workers=num_jobs) - # futures.extend([tp.submit(self._create_mixture, n_spk) for _ in range(n_mt)]) - # pbar = tqdm(total=num_meetings, desc=f"Simulating mixtures", unit="line", ncols=128) - # count = 0 - # for f in concurrent.futures.as_completed(futures): - # count += 1 - # pbar.update() - # mixtures.append(f.result()) - # tp.shutdown() - # pbar.close() - - return CutSet.from_cuts(intra_mixtures + inter_mixtures) - - -class MixMeetingSimulator(): - """ - This simulator Mix the segments from different/same sessions to create a - multi-speaker meeting. - """ - - def __init__( - self, - intra_session_mix_prob: float|List[float] = [0, 0, 0, 0], - data_type: str = "msasr", - min_duration: float = 80.0, - max_duration: float = 100.0, - max_num_speakers: int = 4, - speaker_count_distribution: List[float] = [0, 0, 0.1, 4], - valid_dataset_ids: List[str] = [], +def speaker_to_target( + a_cut, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, + spk_tar_all_zero: bool = False, + boundary_segments: bool = False, + soft_label: bool = False, + ignore_num_spk_mismatch: bool = True, + soft_thres: float = 0.5, ): - """ - :param intra_session_mix_prob: the probability of concatenating segments from the same - session. [Default: 1] - :param data_type: the type of data to simulate. Either 'msasr' or 'diar'. If 'msasr', - the transcripts are included in the simulation,and the boundary segments are - not included. [Default: 'msasr'] - :param max_duration: the maximum duration of the simulated meeting. [Default: 40.0] - """ - super().__init__() - if isinstance(intra_session_mix_prob, float): - self.intra_session_mix_prob = [intra_session_mix_prob] * (max_num_speakers) - elif len(intra_session_mix_prob) == max_num_speakers: - self.intra_session_mix_prob = intra_session_mix_prob - else: - raise ValueError(f"intra_session_mix_prob must be either a float or a list of floats, but got {intra_session_mix_prob}") - if data_type not in ["msasr", "diar"]: - raise ValueError("data_type must be either 'msasr' or 'diar', but got {data_type}") - self.data_type = data_type - self.min_duration = min_duration - self.max_duration = max_duration - self.max_num_speakers = max_num_speakers - self.speaker_count_distribution = speaker_count_distribution - self.valid_dataset_ids = valid_dataset_ids - assert len(speaker_count_distribution) == max_num_speakers, f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {max_num_speakers}" - - def fit(self, cuts) -> CutSet: - """ - Read the manifest file and return a CutSet object. - Each line in the manifest file should be a JSON object representing a segment. 
- """ - - self.id2cut = {} - self.sess2cut_ids = defaultdict(list) - self.sess2spks = defaultdict(set) - self.data2sess_ids = defaultdict(list) - self.spk2cut_ids = defaultdict(list) - self.data2num_spk2cut_ids = {} - self.sess2num_spk2cut_ids = {} - self.num_spk2cut_ids = {i+1:[] for i in range(self.max_num_speakers)} - for i, cut in tqdm(enumerate(cuts), desc="Reading segments", ncols=100, total=len(cuts)): - if not self.min_duration <= cut.duration <= self.max_duration: - continue - if not hasattr(cut, 'dataset_id') or cut.dataset_id is None: - continue - if self.valid_dataset_ids and cut.dataset_id not in self.valid_dataset_ids: - continue - if cut.dataset_id not in self.data2num_spk2cut_ids: - self.data2num_spk2cut_ids[cut.dataset_id] = defaultdict(list) - if cut.recording_id not in self.sess2num_spk2cut_ids: - self.sess2num_spk2cut_ids[cut.recording_id] = defaultdict(list) - - speakers = cut.global_speaker_ids - if self.data_type == "msasr": - speaker_tokens = set(re.findall(r'<\|spltoken\d+\|>', cut.text)) - if len(speakers) != len(speaker_tokens): - # Lhotse automatically fixes the max duration of the cut, - # resulting in the mismatch of the number of speakers - # and speaker tokens for the last segment - # TODO: need to fix the issue in Lhotse that automatically fixes the max duration - continue - for spk in speakers: - self.spk2cut_ids[spk].append(cut.id) - self.sess2spks[cut.recording_id] = self.sess2spks[cut.recording_id].union(speakers) - - self.id2cut[cut.id] = cut - self.sess2cut_ids[cut.recording_id].append(cut.id) - self.data2num_spk2cut_ids[cut.dataset_id][len(speakers)].append(cut.id) - self.sess2num_spk2cut_ids[cut.recording_id][len(speakers)].append(cut.id) - self.num_spk2cut_ids[len(speakers)].append(cut.id) - if cut.recording_id not in self.data2sess_ids[cut.dataset_id]: - self.data2sess_ids[cut.dataset_id].append(cut.recording_id) - - self.cut_ids = list(self.id2cut.keys()) - self.num_spk2sess_ids = groupby(lambda x: len(self.sess2spks[x]), self.sess2spks.keys()) - - self.data2global_speaker = { - dataset_id: True for dataset_id in self.data2sess_ids.keys() - } - - def _create_mixture(self, n_speakers: int, is_intra_session_concat=False) -> MixedCut: - - db_norm = norm.rvs(-32.05957708631966, 5.66648411405886) # mean and std from Fisher data - - if is_intra_session_concat: - # intra-dataset and intra-session concatenation - tracks, num_speakers = self.get_intra_session_tracks(n_speakers, db_norm=db_norm) - - else: - # intra-dataset but inter-session concatenation - tracks, num_speakers = self.get_inter_session_tracks(n_speakers, db_norm=db_norm) - - cut = MixedCut(id='mix_' + '_'.join([track.cut.id for track in tracks]), tracks=tracks) - if self.data_type == "msasr": - cut = self.reorder_spk_mapping(cut) - - assert self.min_duration <= cut.duration <= self.max_duration, f"Total duration {cut.duration} is not within the range of min {self.min_duration} and max {self.max_duration}" - assert n_speakers == num_speakers, f"Total number of speakers {cut.num_speakers} is not equal to the number of speakers {n_speakers}" - - return cut - - def get_intra_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: - """ - Get the tracks for the MixedCut object. 
- """ - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - - total_spk_set = set() - tracks = [] - while True: - cut = self.id2cut[random.choice(self.sess2cut_ids[session_id])] - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) - total_spk_set = total_spk_set.union(cut.global_speaker_ids) - total_duration = max(total_duration, cut.duration) - - # break condition - if total_duration >= self.min_duration: - if total_duration > self.max_duration: # exceed the maximum duration, starting over - total_duration = 0.0 - total_spk_set = set() - tracks = [] - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break - break - else: - total_duration = 0.0 - total_spk_set = set() - tracks = [] - session_id = random.choice(self.num_spk2sess_ids[n_speakers]) - - return tracks, len(total_spk_set) - - def get_inter_session_tracks(self, n_speakers: int=4, db_norm: float=-25) -> List[MixTrack]: - """ - Get the tracks for the MixedCut object. - """ - sample_cut = self.id2cut[random.choice(self.cut_ids)] - dataset_id = sample_cut.dataset_id - n_spk_list = [n_spk for n_spk, cut_ids in self.data2num_spk2cut_ids[dataset_id].items() if len(cut_ids) > 0] - sum_spk_list = set([i + j for i in n_spk_list for j in n_spk_list]) - - if min(sum_spk_list) > n_speakers: - raise ValueError(f"Cannot generate {n_speakers}-speaker inter session samples by concatenating two samples since the dataset {dataset_id} only have {','.join([str(i) for i in n_spk_list])} speakers.") - - n_spk_left = n_speakers - total_duration = 0.0 - total_spk_set = set() - tracks = [] - num_spk2cut_ids = self.data2num_spk2cut_ids[dataset_id] - while True: - if n_spk_left >= 2: - n_spk = 2 - else: - # n_spk = random.choice([n_spk for n_spk in n_spk_list if n_spk <= n_spk_left]) - n_spk = 1 - - while True: - cut = self.id2cut[random.choice(num_spk2cut_ids[n_spk])] - spks = set(cut.global_speaker_ids) - if not spks.intersection(total_spk_set): - break - - tracks.append(MixTrack(cut=deepcopy(cut.normalize_loudness(target=db_norm, mix_first=False)), type=type(cut), offset=0)) - total_duration = max(total_duration, cut.duration) - n_spk_left -= n_spk - total_spk_set = total_spk_set.union(spks) + ''' + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) + This function is needed for speaker diarization with ASR model trainings. - # break condition - - if total_duration >= self.min_duration: - if total_duration > self.max_duration or len(tracks) > 2: # exceed the maximum duration, starting over - total_duration = 0.0 - n_spk_left = n_speakers - total_spk_set = set() - tracks = [] - if len(total_spk_set) == n_speakers: # meet the number of speakers and duration, break - break - else: - if len(total_spk_set) == n_speakers: # meet the number of speakers, but not the duration, starting over --- TODO: will try to find the segments that only contains those speakers - total_duration = 0.0 - n_spk_left = n_speakers - total_spk_set = set() - tracks = [] - - return tracks, len(total_spk_set) + Args: + a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. 
+ num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): encoder subsampling_factor, 8 by default + spk_tar_all_zero (Tensor): set to True gives all zero "mask" + boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training + soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. - def reorder_spk_mapping(self, cut: MixedCut, pattern=r'<\|spltoken\d+\|>') -> str: - """ - Concatenate the texts of the input cuts. - - """ - global_spk_mapping = {} - str_pattern = pattern.replace("\\", '') - left_str, right_str = str_pattern.split('d+') - for i, track in enumerate(cut.tracks): - local_inverse_spk_mapping = {} - local_spk_mapping = {} - for speaker in track.cut.global_speaker_ids: - if speaker not in global_spk_mapping: - global_spk_mapping[speaker] = len(global_spk_mapping) - if speaker not in local_spk_mapping: - local_spk_mapping[speaker] = len(local_spk_mapping) - local_inverse_spk_mapping[len(local_inverse_spk_mapping)] = speaker - - if i != 0: - text = '' - for word in track.cut.text.split(): - if len(re.findall(pattern, word)) > 0: - local_spk_idx = int(word.replace(left_str,'').replace(right_str, '')) - spk = local_inverse_spk_mapping[local_spk_idx] - global_spk_idx = global_spk_mapping[spk] - text += f'{left_str}{global_spk_idx}{right_str}' - else: - text += ' ' + word - track.cut.supervisions[0].text = text - cut.supervisions[i].text = text - else: - cut.supervisions[0].text = track.cut.text - # TODO: need to check the last speaker of last track and the first speaker of the current track - # if they are the same, we need to remove the the speaker token from the current track for segment-level - # Do not need to remove the speaker token for word-level - - return cut + Returns: + mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) + ''' + # get cut-related segments from rttms + # basename = os.path.basename(a_cut.rttm_filepath).replace('.rttm', '') + if isinstance(a_cut, MixedCut): + cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + elif isinstance(a_cut, MonoCut): + cut_list = [a_cut] + offsets = [0] + else: + raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - def apply_speaker_distribution(self, num_meetings: int, speaker_count_distribution) -> Dict[int, int]: - """ - Balance the speaker distribution for the simulated meetings. - Args: - num_meetings: The total number of simulated meetings. - speaker_count_distribution: The speaker count distribution for the simulated meetings. - For each number of speakers, calculate the number of meetings needed to balance the distribution. 
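A worked example of the balancing described in the apply_speaker_distribution docstring above; the rounding mirrors the ConcatenationMeetingSimulator version shown earlier in this file, and the numbers are only illustrative.

def meetings_per_speaker_count(num_meetings, speaker_count_distribution):
    total = sum(speaker_count_distribution)
    # index 0 corresponds to 1-speaker meetings, index 1 to 2-speaker meetings, etc.
    return {i + 1: round(num_meetings * w / total) for i, w in enumerate(speaker_count_distribution)}

print(meetings_per_speaker_count(10000, [0, 2, 3, 4]))
# {1: 0, 2: 2222, 3: 3333, 4: 4444}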
- """ - - total_spk = sum(speaker_count_distribution) - num_speakers2num_meetings = {} - for i_spk in range(self.max_num_speakers): - num_speakers2num_meetings[i_spk+1] = round(num_meetings * speaker_count_distribution[i_spk] / total_spk) + segments_total = [] + for i, cut in enumerate(cut_list): + rttms = SupervisionSet.from_rttm(cut.rttm_filepath) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) - return num_speakers2num_meetings - + for seg in segments_iterator: + if seg.start < 0: + seg.duration += seg.start + seg.start = 0 + if seg.end > cut.duration: + seg.duration -= seg.end - cut.duration + seg.start += offsets[i] + segments_total.append(seg) - @dill_enabled(True) - def simulate(self, - cuts: CutSet, - num_meetings: int = 10000, - seed: int = 0, - num_jobs: int = 1, - ) -> CutSet: - random.seed(seed) - - self.fit(cuts) - - num_speakers2num_meetings = self.apply_speaker_distribution(num_meetings, self.speaker_count_distribution) - logging.warn(f"Will be generating {(','.join([str(i) for i in num_speakers2num_meetings.values()]))} samples for {(','.join([str(i) for i in num_speakers2num_meetings.keys()]))} speakers given speaker count distribution of {str(self.speaker_count_distribution)}.") - num_speakers2num_meetings[1] = 0 # skip 1-speaker samples - logging.warn(f'But 1-speaker samples will be skipped. Will be generating {sum(num_speakers2num_meetings.values()) - num_speakers2num_meetings[1]} samples in total.') - - # Step 0: Calculate the number of intra-session and inter-session concatentation samples - n_spks = [k for k, v in self.num_spk2cut_ids.items() if len(v) > 0] - valid_sim_n_spks = set([i+j for i in n_spks for j in n_spks]) # valid number of speakers for inter-session samples - n_spk2n_intra_mt, n_spk2n_inter_mt = {i+1:0 for i in range(self.max_num_speakers)}, {i+1:0 for i in range(self.max_num_speakers)} - for n_spk, n_mt in num_speakers2num_meetings.items(): - logging.warn(f"=="*16 + f"{n_spk}-speaker" + "=="*16) - if n_mt <= 0: - logging.warning(f"No intra-session concatentation samples for {n_spk} speakers. 
Will skip simulation for {n_spk} speakers.") - continue - n_intra_mt = int(n_mt * self.intra_session_mix_prob[n_spk-1]) - n_inter_mt = n_mt - n_intra_mt - if n_spk in self.num_spk2sess_ids: - logging.warn(f"Will be genrating {n_intra_mt} {n_spk}-speaker intra-session concatentation samples.") - n_spk2n_intra_mt[n_spk] = n_intra_mt - else: - logging.warning(f"Cannot generate {n_intra_mt} {n_spk}-speaker intra-session samples by concatenating two samples from the same session since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") - n_spk2n_intra_mt[n_spk] = 0 - n_inter_mt = n_mt - if n_spk in valid_sim_n_spks: - logging.warn(f"Will be genrating {n_inter_mt} {n_spk}-speaker inter-session concatentation samples.") - n_spk2n_inter_mt[n_spk] = n_inter_mt - else: - logging.warning(f"Cannot generate {n_inter_mt} {n_spk}-speaker inter-session samples by concatenating two samples from different sessions since we only have samples for {','.join([str(i) for i in n_spks])} speakers.") - if n_spk2n_intra_mt[n_spk] != 0: - n_spk2n_intra_mt[n_spk] = n_mt - logging.warn(f"Will be genrating {n_spk2n_intra_mt[n_spk]} {n_spk}-speaker intra-session concatentation samples instead.") - else: - logging.warning(f"No samples for {n_spk} speakers. Will skip simulation for {n_spk} speakers.") - logging.warn(f"""Will be generating {','.join([str(i) for i in n_spk2n_intra_mt.values()])} intra-session concatentation samples and {','.join([str(i) for i in n_spk2n_inter_mt.values()])} inter-session concatentation samples for {','.join([str(i+1) for i in range(self.max_num_speakers)])} speakers.""") - # Step 1: intra-session - num_intra_meetings = 0 - intra_mixtures = [] - logging.info(f"Simulating intra-session concatentation samples.") - for n_spk, n_mt in n_spk2n_intra_mt.items(): - if n_mt <= 0: - continue + # apply arrival time sorting to the existing segments + segments_total.sort(key = lambda rttm_sup: rttm_sup.start) - for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker intra-session mixtures", ncols=128): - intra_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=True)) - num_intra_meetings += n_mt - logging.info(f"Finished simulating intra-session concatentation samples. Total number of intra-session concatentation samples: {num_intra_meetings}") - - # Steo 2: inter-session - logging.info(f"Simulating inter-session concatentation samples.") + seen = set() + seen_add = seen.add + speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] + + speaker_to_idx_map = { + spk: idx + for idx, spk in enumerate(speaker_ats) + } + if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers + raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}") - num_inter_meetings = 0 - inter_mixtures = [] - for n_spk, n_mt in n_spk2n_inter_mt.items(): - if n_mt <= 0: - continue - - for i in tqdm(range(n_mt), desc=f"Simulating {n_spk}-speaker inter-session mixtures", ncols=128): - inter_mixtures.append(self._create_mixture(n_speakers=n_spk, is_intra_session_concat=False)) - num_inter_meetings += n_mt - logging.info(f"Finished simulating inter-session concatentation samples. Total number of inter-session concatentation samples: {num_inter_meetings}") - - if num_inter_meetings + num_intra_meetings == 0: - logging.warning(f"No samples are generated. 
Probably the duration of the segments is not within the range of min {self.min_duration} and max {self.max_duration}, or the speaker count distribution is not correctly set.") - - return CutSet.from_cuts(intra_mixtures + inter_mixtures) - -class LibriSpeechMixSimulator(): - - def __init__( - self, - min_duration: float = 80.0, - max_duration: float = 100.0, - n_mix_speakers: List[int] = [1, 2, 3], - speaker_count_distribution: List[float] = [1, 1, 1], - ): - """ - :param min_duration: the minimum duration of the simulated meeting. [Default: 80.0] - :param max_duration: the maximum duration of the simulated meeting. [Default: 100.0] - """ - super().__init__() - self.min_duration = min_duration - self.max_duration = max_duration - self.n_mix_speakers = n_mix_speakers - self.speaker_count_distribution = speaker_count_distribution - assert len(speaker_count_distribution) == len(n_mix_speakers), f"Length of speaker_count_distribution {len(speaker_count_distribution)} must be equal to max_num_speakers {len(n_mix_speakers)}" - - def fit(self, cuts) -> CutSet: - pass - - def simulate(self, - cuts: CutSet, - num_meetings: int = 10000, - seed: int = 0, - num_jobs: int = 1, - ) -> CutSet: - random.seed(seed) - - cut_set = [] - for n_speakers, n_mt in zip(self.n_mix_speakers, self.speaker_count_distribution): - if n_mt <= 0: - continue - for i in tqdm(range(n_mt), desc=f"Simulating {n_speakers}-speaker mixtures", ncols=128): - cut_set.append(self._create_mixture(n_speakers=n_speakers)) - return CutSet.from_cuts(cut_set) - -class LibriSpeechMixGenerator(): - def __init__(self): - pass - - def generate(self, cuts): - cut_set = [] - for cut in tqdm(cuts): - offsets = cut.delays - durations = cut.durations - wavs = cut.wavs - texts = cut.texts - speakers = cut.speakers + # initialize mask matrices (num_speaker, encoder_hidden_len) + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) + if spk_tar_all_zero: + frame_mask = torch.zeros((num_samples, num_speakers)) + else: + frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) + soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) - tracks = [] - for i, (offset, duration, wav, text, speaker) in enumerate(zip(offsets, durations, wavs, texts, speakers)): - wav_dur = soundfile.info(wav).duration - wav_samples = soundfile.info(wav).frames - custom = { - 'speaker': speaker, - 'text': text, - } - cut_1spk = MonoCut( - id=wav.split('/')[-1].replace('.wav', ''), - start=0, - duration=duration, - channel=0, - supervisions=[], - recording=Recording( - id=wav.split('/')[-1].replace('.wav', ''), - sources=[ - AudioSource( - type='file', - channels=[0], - source=wav - ) - ], - sampling_rate=16000, - num_samples=wav_samples, - duration=wav_dur - ), - custom=custom - ) + if soft_label: + mask = soft_mask + else: + mask = (soft_mask > soft_thres).float() - tracks.append(MixTrack(cut=cut_1spk, type=type(cut_1spk), offset=offset)) - sup = SupervisionSegment( - id=cut.id, - recording_id=cut.recording_id, - start=0, - duration=offset+wav_dur, - text=cut.text, - ) - tracks[0].cut.supervisions.append(sup) - cut_multi_spk = MixedCut(id=cut.id, tracks=tracks) - - cut_set.append(cut_multi_spk) - - return CutSet.from_cuts(cut_set) \ No newline at end of file + return mask \ No newline at end of file diff --git 
a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 80b3e1f918b8..492041162cff 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -21,7 +21,6 @@ from typing import Dict, List, Tuple, Union import numpy as np -from omegaconf import OmegaConf from omegaconf.listconfig import ListConfig import soundfile as sf import torch @@ -981,9 +980,8 @@ def get_subsegments( subsegments.append([start, min(duration, window)]) elif slices > 0: # What if slcies = 0 ? start_col = torch.arange(offset, slice_end, shift)[:slices] - dur_col = window * torch.ones(slices) - dur_col = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) - dur_col = torch.round(dur_col, decimals=decimals) + dur_col_raw = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) + dur_col = torch.round(dur_col_raw, decimals=decimals) valid_mask = dur_col >= min_subsegment_duration valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) subsegments = valid_subsegments.tolist() diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 192c42375dca..0ccfef9b9e8b 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -35,7 +35,6 @@ from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm -from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging From 40e9f95d879ba9bc1fed25d04e56806e6d24fdec Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 15 Nov 2024 01:15:55 +0000 Subject: [PATCH 06/16] Apply isort and black reformatting Signed-off-by: tango4j --- .../neural_diarizer/e2e_diarize_speech.py | 17 +++--- nemo/collections/asr/models/__init__.py | 2 +- .../asr/models/sortformer_diar_models.py | 2 +- .../asr/parts/utils/asr_multispeaker_utils.py | 60 +++++++++++-------- .../asr/parts/utils/speaker_utils.py | 4 +- 5 files changed, 48 insertions(+), 37 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 5237b5c3c67b..72d7977840ce 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -125,7 +125,7 @@ def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optun Suggests hyperparameters for postprocessing using Optuna. See the following link for `trial` instance in Optuna framework. https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial - + Args: postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. 
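Only the docstring of optuna_suggest_params is visible in this hunk, so the following is a hedged, self-contained sketch of the kind of search it drives: suggest postprocessing thresholds with Optuna and minimize diarization error. The objective below is a dummy stand-in for the script's real scoring, and the parameter ranges are assumptions.

import optuna

def objective(trial: optuna.Trial) -> float:
    onset = trial.suggest_float("onset", 0.4, 0.8, step=0.01)
    offset = trial.suggest_float("offset", 0.4, 0.9, step=0.01)
    pad_onset = trial.suggest_float("pad_onset", 0.0, 0.3, step=0.01)
    min_duration_off = trial.suggest_float("min_duration_off", 0.0, 0.5, step=0.01)
    # Stand-in for re-running postprocessing + DER scoring with these thresholds
    return abs(onset - 0.55) + abs(offset - 0.75) + 0.1 * pad_onset + 0.1 * min_duration_off

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)
print(study.best_params)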
@@ -390,13 +390,14 @@ def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: out_rttm_dir=cfg.out_rttm_dir, ) logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") - score_labels(AUDIO_RTTM_MAP=infer_audio_rttm_dict, - all_reference=all_refs, - all_hypothesis=all_hyps, - all_uem=all_uems, - collar=cfg.collar, - ignore_overlap=cfg.ignore_overlap - ) + score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap, + ) logging.info(f"PostProcessingParams: {postprocessing_cfg}") diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index f27828a6b11e..34dead15b33d 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -20,7 +20,6 @@ EncDecFrameClassificationModel, ) from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer -from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel @@ -36,6 +35,7 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ssl_models import ( EncDecDenoiseMaskedTokenPredModel, EncDecMaskedTokenPredModel, diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index fd9d01f33f2b..939b03e7a5ac 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -571,7 +571,7 @@ def test_batch( self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) torch.cuda.empty_cache() self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) - + logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 5e19b7abeb38..3f40f5cd3e39 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -13,10 +13,11 @@ # limitations under the License. 
import math + import torch -from tqdm import tqdm from lhotse import SupervisionSet from lhotse.cut import MixedCut, MonoCut +from tqdm import tqdm def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> torch.Tensor: @@ -173,6 +174,7 @@ def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutati max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) return max_score_permed_labels # (batch_size, num_speakers, num_classes) + def find_segments_from_rttm( recording_id: str, rttms, @@ -211,8 +213,6 @@ def find_segments_from_rttm( ] - - def get_mask_from_segments( segments: list, a_cut, @@ -305,17 +305,18 @@ def get_hidden_length_from_sample_length( hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) return int(hidden_length) + def speaker_to_target( a_cut, - num_speakers: int = 4, - num_sample_per_mel_frame: int = 160, - num_mel_frame_per_asr_frame: int = 8, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, spk_tar_all_zero: bool = False, boundary_segments: bool = False, soft_label: bool = False, ignore_num_spk_mismatch: bool = True, soft_thres: float = 0.5, - ): +): ''' Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) This function is needed for speaker diarization with ASR model trainings. @@ -329,7 +330,7 @@ def speaker_to_target( boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. 
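A minimal sketch of the mask construction that speaker_to_target / get_mask_from_segments (reformatted above) perform: segments are sorted by start time, speakers are indexed in arrival order, and each speaker's frames are set to one. The 100 frames-per-second rate matches the defaults above; the toy segments are invented.

import torch

feat_per_sec = 100  # 10 ms frames
segments = [(0.0, 1.2, "B"), (0.8, 2.0, "A"), (2.5, 3.0, "B")]  # (start, end, speaker), start-sorted

# First appearance defines the speaker index (arrival-time ordering)
speaker_to_idx = {}
for _, _, spk in segments:
    speaker_to_idx.setdefault(spk, len(speaker_to_idx))

num_frames = int(3.0 * feat_per_sec)
mask = torch.zeros(num_frames, len(speaker_to_idx))
for start, end, spk in segments:
    mask[int(start * feat_per_sec): int(end * feat_per_sec), speaker_to_idx[spk]] = 1.0

print(speaker_to_idx)   # {'B': 0, 'A': 1}
print(mask.sum(dim=0))  # tensor([170., 120.]) -> speech frames per speaker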
- + Returns: mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) ''' @@ -343,14 +344,18 @@ def speaker_to_target( offsets = [0] else: raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") - + segments_total = [] for i, cut in enumerate(cut_list): rttms = SupervisionSet.from_rttm(cut.rttm_filepath) - if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included - segments_iterator = find_segments_from_rttm(recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0) - else: # segments with seg_start > total_start and seg_end < total_end are included - segments_iterator = rttms.find(recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm( + recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0 + ) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find( + recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True + ) for seg in segments_iterator: if seg.start < 0: @@ -360,28 +365,31 @@ def speaker_to_target( seg.duration -= seg.end - cut.duration seg.start += offsets[i] segments_total.append(seg) - + # apply arrival time sorting to the existing segments - segments_total.sort(key = lambda rttm_sup: rttm_sup.start) + segments_total.sort(key=lambda rttm_sup: rttm_sup.start) seen = set() seen_add = seen.add speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] - - speaker_to_idx_map = { - spk: idx - for idx, spk in enumerate(speaker_ats) - } + + speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers - raise ValueError(f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}") - + raise ValueError( + f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}" + ) + # initialize mask matrices (num_speaker, encoder_hidden_len) - feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default - num_samples = get_hidden_length_from_sample_length(a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame) - if spk_tar_all_zero: + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length( + a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) + if spk_tar_all_zero: frame_mask = torch.zeros((num_samples, num_speakers)) else: - frame_mask = get_mask_from_segments(segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch) + frame_mask = get_mask_from_segments( + segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch + ) soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) if soft_label: diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 09b395a28f8d..87ad7eda59d9 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ 
b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -982,7 +982,9 @@ def get_subsegments( subsegments.append([start, min(duration, window)]) elif slices > 0: # What if slcies = 0 ? start_col = torch.arange(offset, slice_end, shift)[:slices] - dur_col_raw = torch.min(slice_end*torch.ones_like(start_col)- start_col, window * torch.ones_like(start_col)) + dur_col_raw = torch.min( + slice_end * torch.ones_like(start_col) - start_col, window * torch.ones_like(start_col) + ) dur_col = torch.round(dur_col_raw, decimals=decimals) valid_mask = dur_col >= min_subsegment_duration valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) From f7f84bb386fc0d8c11012e6e246a5d058caa0a72 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 17:53:31 -0800 Subject: [PATCH 07/16] Adding docstrings to reflect the PR comments Signed-off-by: taejinp --- .../asr/data/audio_to_diar_label.py | 19 +++++++++++++------ .../asr/parts/utils/asr_multispeaker_utils.py | 3 --- .../asr/parts/utils/speaker_utils.py | 5 ----- .../common/parts/preprocessing/collections.py | 3 --- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index b00338743a43..f08788d2d231 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -149,18 +149,25 @@ def get_subsegments_to_timestamps( subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2 ): """ - Convert subsegment timestamps to scale timestamps by multiplying with the feature rate and rounding. - All `ts` related tensors are dimensioned as (N, 2), where N is the number of subsegments. - + Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) and rounding. + Segment is consisted of many subsegments and sugsegments are equivalent to `frames` in end-to-end speaker diarization models. + Args: subsegments (List[Tuple[float, float]]): - A list of tuples where each tuple contains the start and end times of a subsegment. + A list of tuples where each tuple contains the start and end times of a subsegment (frames in end-to-end models). + >>> subsegments = [[t0_start, t0_duration], [t1_start, t1_duration],..., [tN_start, tN_duration]] feat_per_sec (int, optional): The number of feature frames per second. Defaults to 100. max_end_ts (float, optional): The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None. decimals (int, optional): The number of decimal places to round the timestamps. Defaults to 2. + + Example: + Segments starting from 0.0 and ending at 69.2 seconds. + If hop-length is 0.08 and the subsegment (frame) length is 0.16 seconds, + there are 864 = (69.2 - 0.16)/0.08 + 1 subsegments (frames in end-to-end models) in this segment. + >>> subsegments = [[[0.0, 0.16], [0.08, 0.16], ..., [69.04, 0.16], [69.12, 0.08]] Returns: ts (torch.tensor): @@ -175,7 +182,7 @@ def get_subsegments_to_timestamps( return ts -def extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines, round_digits=3): +def extract_frame_info_from_rttm(offset, duration, rttm_lines, round_digits=3): """ Extracts RTTM lines containing speaker labels, start time, and end time for a given audio segment. 
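A small numeric sketch of the conversion described in the get_subsegments_to_timestamps docstring above: scale (start, duration) pairs by feat_per_sec, turn them into (start, end) frame indices, round, and optionally clip to a maximum end frame. The exact rounding/clipping order is an assumption based on the docstring.

import torch

feat_per_sec = 100
subsegments = [[0.0, 0.16], [0.08, 0.16], [69.04, 0.16], [69.12, 0.08]]

ts = torch.tensor(subsegments) * feat_per_sec  # (N, 2) as [start, duration] in frames
ts[:, 1] = ts[:, 0] + ts[:, 1]                 # duration -> end frame
ts = torch.round(ts, decimals=2)

max_end_ts = 69.2 * feat_per_sec               # clip to the segment end, if provided
ts = torch.clamp(ts, max=max_end_ts)
print(ts)  # [[0., 16.], [8., 24.], [6904., 6920.], [6912., 6920.]]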
@@ -1093,7 +1100,7 @@ def parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] """ rttm_lines = open(rttm_file).readlines() - rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(uniq_id, offset, duration, rttm_lines) + rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(offset, duration, rttm_lines) fr_level_target = get_frame_targets_from_rttm( rttm_timestamps=rttm_timestamps, diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 5e19b7abeb38..785a7c41d32c 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -210,9 +210,6 @@ def find_segments_from_rttm( if segment.start < end_before + tolerance and segment.end > start_after + tolerance ] - - - def get_mask_from_segments( segments: list, a_cut, diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 09b395a28f8d..cd831f054106 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -33,11 +33,6 @@ from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data from nemo.utils import logging -""" -This file contains all the utility functions required for speaker embeddings part in diarization scripts -""" - - def get_uniqname_from_filepath(filepath): """ Return base name from provided filepath diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 632ec06bc647..8acd9fc08743 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -1310,9 +1310,6 @@ def __init__( if isinstance(audio_file, list): if len(audio_file) == 0: raise ValueError(f"Empty audio file list: {audio_file}") - audio_file_name = sorted(audio_file)[0] - else: - audio_file_name = audio_file file_id, _ = os.path.splitext(os.path.basename(audio_file)) self.mapping[file_id] = len(data) - 1 From 4134e2533c80c76fab190a7ac62436a24a0016a2 Mon Sep 17 00:00:00 2001 From: taejinp Date: Thu, 14 Nov 2024 17:56:03 -0800 Subject: [PATCH 08/16] removed the unused find_first_nonzero Signed-off-by: taejinp --- nemo/collections/asr/data/audio_to_diar_label.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index f08788d2d231..0a40d832eaf0 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -20,7 +20,6 @@ import numpy as np import torch -from nemo.collections.asr.parts.utils.asr_multispeaker_utils import find_first_nonzero from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, get_subsegments, prepare_split_data from nemo.collections.common.parts.preprocessing.collections import ( From 5dd4d4c6d2d4eee96f40a0b34b9139f59157d208 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 15 Nov 2024 01:56:16 +0000 Subject: [PATCH 09/16] Apply isort and black reformatting Signed-off-by: tango4j --- nemo/collections/asr/data/audio_to_diar_label.py | 8 ++++---- 
.../collections/asr/parts/utils/asr_multispeaker_utils.py | 1 + nemo/collections/asr/parts/utils/speaker_utils.py | 1 + 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index f08788d2d231..535ae3301173 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -151,7 +151,7 @@ def get_subsegments_to_timestamps( """ Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) and rounding. Segment is consisted of many subsegments and sugsegments are equivalent to `frames` in end-to-end speaker diarization models. - + Args: subsegments (List[Tuple[float, float]]): A list of tuples where each tuple contains the start and end times of a subsegment (frames in end-to-end models). @@ -162,10 +162,10 @@ def get_subsegments_to_timestamps( The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None. decimals (int, optional): The number of decimal places to round the timestamps. Defaults to 2. - - Example: + + Example: Segments starting from 0.0 and ending at 69.2 seconds. - If hop-length is 0.08 and the subsegment (frame) length is 0.16 seconds, + If hop-length is 0.08 and the subsegment (frame) length is 0.16 seconds, there are 864 = (69.2 - 0.16)/0.08 + 1 subsegments (frames in end-to-end models) in this segment. >>> subsegments = [[[0.0, 0.16], [0.08, 0.16], ..., [69.04, 0.16], [69.12, 0.08]] diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index 6412d88f4c0f..e945439bf8fa 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -210,6 +210,7 @@ def find_segments_from_rttm( if segment.start < end_before + tolerance and segment.end > start_after + tolerance ] + def get_mask_from_segments( segments: list, a_cut, diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index cb1244eef660..fd6f71dc0502 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -33,6 +33,7 @@ from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data from nemo.utils import logging + def get_uniqname_from_filepath(filepath): """ Return base name from provided filepath From 037f61e85a9739ae8cadcf63e40b0122f218de6c Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 Nov 2024 12:33:06 -0800 Subject: [PATCH 10/16] Fixed all pylint issues Signed-off-by: taejinp --- ...rtformer_diarizer_hybrid_loss_4spk-v1.yaml | 2 +- .../neural_diarizer/e2e_diarize_speech.py | 3 +- .../neural_diarizer/sortformer_diar_train.py | 19 +-- .../asr/data/audio_to_diar_label.py | 96 +++++++---- nemo/collections/asr/metrics/der.py | 40 +++-- .../asr/metrics/multi_binary_acc.py | 12 ++ .../asr/models/sortformer_diar_models.py | 3 +- .../asr/modules/sortformer_modules.py | 1 - .../asr/parts/utils/speaker_utils.py | 152 +++++++++++------- nemo/collections/asr/parts/utils/vad_utils.py | 110 ++++++++----- .../common/parts/preprocessing/collections.py | 49 +++--- 11 files changed, 303 insertions(+), 184 deletions(-) diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml 
b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml index 04409a4cd60a..4a6d8f242d36 100644 --- a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -1,4 +1,4 @@ -sortformer_diarizer_hybrid_loss_4spk-v1.yaml# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. +# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. # Model name convention for Sortformer Diarizer: sortformer_diarizer___.yaml # (Example) `sortformer_diarizer_hybrid_loss_4spk-v1.yaml`. # Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 72d7977840ce..0f90e70eff80 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -23,13 +23,12 @@ import os import tempfile from dataclasses import dataclass, is_dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union import optuna import pytorch_lightning as pl import torch import yaml -from hydra.core.config_store import ConfigStore from omegaconf import OmegaConf from pytorch_lightning import seed_everything from tqdm import tqdm diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index 3ba0dbc3ed19..75980d342c65 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -22,33 +22,26 @@ from nemo.utils.exp_manager import exp_manager """ -Example training session (single GPU training on telephonic datasets) +Example training session (single node training) -python ./multiscale_diar_decoder.py --config-path='../conf/neural_diarizer' --config-name='msdd_5scl_15_05_50Povl_256x3x32x2.yaml' \ +python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' --config-name='' \ trainer.devices=1 \ - model.base.diarizer.speaker_embeddings.model_path="titanet_large" \ model.train_ds.manifest_filepath="" \ model.validation_ds.manifest_filepath="" \ - model.train_ds.emb_dir="" \ - model.validation_ds.emb_dir="" \ exp_manager.name='sample_train' \ - exp_manager.exp_dir='./msdd_exp' + exp_manager.exp_dir='./sortformer_diar_train' """ seed_everything(42) - -@hydra_runner(config_path="../conf/neural_diarizer", config_name="msdd_5scl_15_05_50Povl_256x3x32x2.yaml") +@hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) sortformer_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) - # Initialize the weights of the model from another model, if provided via config sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) trainer.fit(sortformer_model) - if __name__ == '__main__': - - main() + main() \ No newline at end of file diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index b6b398743198..f47b5ca11f43 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -81,7 +81,8 @@ def extract_seg_info_from_rttm(rttm_lines, mapping_dict=None, target_spks=None): mapping_dict (dict): Mapping between the estimated speakers and the speakers in the ground-truth annotation. `mapping_dict` variable is only provided when the inference mode is running in sequence-eval mode. - Sequence eval mode uses the mapping between the estimated speakers and the speakers in ground-truth annotation. + Sequence eval mode uses the mapping between the estimated speakers and the speakers + in ground-truth annotation. Returns: rttm_tup (tuple): Tuple containing lists of start time, end time and speaker labels. @@ -113,12 +114,14 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, Args: rttm_timestamps (list): List containing start and end time for each speaker segment label. - stt_list, end_list and speaker_list are contained. + `stt_list`, `end_list` and `speaker_list` are contained. frame_per_sec (int): - Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module. + Number of feature frames per second. This quantity is determined by + `window_stride` variable in preprocessing module. target_spks (tuple): - Speaker indices that are generated from combinations. If there are only one or two speakers, - only a single target_spks variable is generated. + Speaker indices that are generated from combinations. + If there are only one or two speakers, + only a single `target_spks` variable is generated. 
Returns: fr_level_target (torch.tensor): @@ -148,12 +151,14 @@ def get_subsegments_to_timestamps( subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2 ): """ - Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) and rounding. - Segment is consisted of many subsegments and sugsegments are equivalent to `frames` in end-to-end speaker diarization models. + Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) + and rounding. Segment is consisted of many subsegments and sugsegments are equivalent to `frames` + in end-to-end speaker diarization models. Args: subsegments (List[Tuple[float, float]]): - A list of tuples where each tuple contains the start and end times of a subsegment (frames in end-to-end models). + A list of tuples where each tuple contains the start and end times of a subsegment + (frames in end-to-end models). >>> subsegments = [[t0_start, t0_duration], [t1_start, t1_duration],..., [tN_start, tN_duration]] feat_per_sec (int, optional): The number of feature frames per second. Defaults to 100. @@ -246,7 +251,8 @@ def get_frame_targets_from_rttm( List containing start and end time for each speaker segment label. stt_list, end_list and speaker_list are contained. feat_per_sec (int): - Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module. + Number of feature frames per second. + This quantity is determined by window_stride variable in preprocessing module. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, only a single target_spks variable is generated. @@ -260,7 +266,8 @@ def get_frame_targets_from_rttm( total_fr_len = int(duration * feat_per_sec) if len(sorted_speakers) > max_spks: logging.warning( - f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: {max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!" + f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: " + f"{max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!" ) feat_level_target = torch.zeros(total_fr_len, max_spks) for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)): @@ -408,15 +415,17 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): def get_diar_target_labels(self, uniq_id, sample, fr_level_target): """ - Convert frame-level diarization target variable into segment-level target variable. Since the granularity is reduced - from frame level (10ms) to segment level (100ms~500ms), we need a threshold value, `soft_label_thres`, which determines - the label of each segment based on the overlap between a segment range (start and end time) and the frame-level target variable. + Convert frame-level diarization target variable into segment-level target variable. + Since the granularity is reduced from frame level (10ms) to segment level (100ms~500ms), + we need a threshold value, `soft_label_thres`, which determines the label of each segment + based on the overlap between a segment range (start and end time) and the frame-level target variable. Args: uniq_id (str): Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file. 
sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. fr_level_target (torch.tensor): Tensor containing label for each feature-level frame. @@ -424,7 +433,8 @@ def get_diar_target_labels(self, uniq_id, sample, fr_level_target): seg_target (torch.tensor): Tensor containing binary speaker labels for base-scale segments. base_clus_label (torch.tensor): - Representative speaker label for each segment. This variable only has one speaker label for each base-scale segment. + Representative speaker label for each segment. This variable only has one speaker label + for each base-scale segment. -1 means that there is no corresponding speaker in the target_spks tuple. """ seg_target_list, base_clus_label = [], [] @@ -459,7 +469,8 @@ def parse_rttm_for_ms_targets(self, sample): Args: sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, only a single target_spks tuple is generated. @@ -474,7 +485,8 @@ def parse_rttm_for_ms_targets(self, sample): multiscale embeddings to form an input matrix for the MSDD model. """ - rttm_lines = open(sample.rttm_file).readlines() + with open(sample.rttm_file, 'r') as file: + rttm_lines = file.readlines() uniq_id = self.get_uniq_id_with_range(sample) rttm_timestamps = extract_seg_info_from_rttm(rttm_lines) fr_level_target = assign_frame_level_spk_vector( @@ -579,7 +591,8 @@ class _AudioMSDDInferDataset(Dataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, + scale mapping and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): @@ -678,9 +691,9 @@ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): """ Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate - ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared with `soft_label_thres` - to determine whether a label vector should contain 0 or 1 for each speaker bin. Note that seg_target variable has - dimension of (number of base-scale segments x 2) dimension. + ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared + with `soft_label_thres` to determine whether a label vector should contain 0 or 1 for each speaker bin. + Note that seg_target variable has dimension of (number of base-scale segments x 2) dimension. 
Example of seg_target: [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] @@ -726,7 +739,8 @@ def __getitem__(self, index): if avg_embs.shape[2] > self.max_spks: raise ValueError( - f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to self.max_num_speakers {self.max_spks}" + f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to " + f"self.max_num_speakers {self.max_spks}" ) feats = [] @@ -820,7 +834,8 @@ def _msdd_train_collate_fn(self, batch): def _msdd_infer_collate_fn(self, batch): """ - Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings. + Collate batch of feats (speaker embeddings), feature lengths, target label sequences + and cluster-average embeddings. Args: batch (tuple): @@ -922,6 +937,7 @@ def __init__( ) def msdd_train_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for training.""" return _msdd_train_collate_fn(self, batch) @@ -943,11 +959,13 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, scale mapping + and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): - Threshold that determines speaker labels of segments depending on the overlap with groundtruth speaker timestamps. + Threshold that determines speaker labels of segments depending on the overlap + with groundtruth speaker timestamps. featurizer: Featurizer instance for generating features from raw waveform. use_single_scale_clus (bool): @@ -955,11 +973,12 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. window_stride (float): - Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. + Window stride for acoustic feature. This value is used for calculating the numbers of + feature-level frames. pairwise_infer (bool): - If True, this Dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the MSDD to merge the individual results. + If True, this Dataset class operates in inference mode. In inference mode, a set of speakers + in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the MSDD to merge the individual results. 
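The `pairwise_infer` behavior described above enumerates every two-speaker tuple from the estimated speaker set and merges the per-pair results; a minimal sketch of the pair enumeration (illustrative only, not the dataset code):

    from itertools import combinations

    def get_speaker_pairs(num_speakers: int):
        # 3 speakers -> [(0, 1), (0, 2), (1, 2)]; each pair is fed to MSDD separately
        # and the per-pair predictions are merged afterwards.
        return list(combinations(range(num_speakers), 2))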
""" def __init__( @@ -988,6 +1007,7 @@ def __init__( ) def msdd_infer_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for inference.""" return _msdd_infer_collate_fn(self, batch) @@ -1089,7 +1109,7 @@ def get_uniq_id_with_range(self, sample, deci=3): uniq_id = f"{bare_uniq_id}_{offset}_{endtime}" return uniq_id - def parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, target_len): + def parse_rttm_for_targets_and_lens(self, rttm_file, offset, duration, target_len): """ Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file. This function converts (start, end, speaker_id) format into base-scale (the finest scale) segment level @@ -1098,7 +1118,9 @@ def parse_rttm_for_targets_and_lens(self, uniq_id, rttm_file, offset, duration, Example of seg_target: [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] """ - rttm_lines = open(rttm_file).readlines() + with open(rttm_file, 'r') as f: + rttm_lines = f.readlines() + rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(offset, duration, rttm_lines) fr_level_target = get_frame_targets_from_rttm( @@ -1203,7 +1225,8 @@ def __getitem__(self, index): uniq_id = self.get_uniq_id_with_range(sample) audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) - # We should resolve the length mis-match from the round-off errors: `session_len_sec` and `audio_signal.shape[0]` + # We should resolve the length mis-match from the round-off errors between these two variables: + # `session_len_sec` and `audio_signal.shape[0]` session_len_sec = ( np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal ) @@ -1213,7 +1236,7 @@ def __getitem__(self, index): audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) targets = self.parse_rttm_for_targets_and_lens( - uniq_id=uniq_id, rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len + rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len ) return audio_signal, audio_signal_length, targets, target_len @@ -1229,13 +1252,15 @@ def _eesd_train_collate_fn(self, batch): Returns: audio_signal (torch.Tensor): - A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` in the input manifest file. + A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` + in the input manifest file. feature_length (torch.Tensor): A tensor containing the lengths of the raw waveform samples. targets (torch.Tensor): Groundtruth speaker labels for the given input embedding sequence. target_lens (torch.Tensor): - A tensor containing the number of segments for each sample in the batch, necessary for reshaping inputs to the EESD model. + A tensor containing the number of segments for each sample in the batch, necessary for + reshaping inputs to the EESD model. 
""" packed_batch = list(zip(*batch)) audio_signal, feature_length, targets, target_len = packed_batch @@ -1344,4 +1369,5 @@ def __init__( ) def eesd_train_collate_fn(self, batch): + """Collate a batch of data for end-to-end speaker diarization training.""" return _eesd_train_collate_fn(self, batch) diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index 000b839ceb46..22c9a76b7fc9 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -123,7 +123,7 @@ def uem_timeline_from_file(uem_file, uniq_name=''): lines = f.readlines() for line in lines: line = line.strip() - speaker_id, channel, start_time, end_time = line.split() + _, _, start_time, end_time = line.split() timeline.add(Segment(float(start_time), float(end_time))) return timeline @@ -145,14 +145,21 @@ def score_labels( Args: - AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath - all_reference (list[uniq_name,Annotation]): reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation - verbose (bool): Warns if RTTM file is not found. + AUDIO_RTTM_MAP (dict): + Dictionary containing information provided from manifestpath + all_reference (list[uniq_name,Annotation]): + Reference annotations for score calculation + all_hypothesis (list[uniq_name,Annotation]): + Hypothesis annotations for score calculation + verbose (bool): + Warns if RTTM file is not found. Returns: - metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. - mapping (dict): Mapping dict containing the mapping speaker label for each audio input + metric (pyannote.DiarizationErrorRate): + Pyannote Diarization Error Rate metric object. + This object contains detailed scores of each audiofile. + mapping (dict): + Mapping dict containing the mapping speaker label for each audio input < Caveat > Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of @@ -171,7 +178,8 @@ def score_labels( correct_spk_count += 1 if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): logging.info( - f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" + f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, " + f"Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" ) uem_obj = None if all_uem is not None: @@ -187,7 +195,7 @@ def score_labels( spk_count_acc = correct_spk_count / len(all_reference) DER = abs(metric) if metric['total'] == 0: - raise ValueError(f"Total evaluation time is 0. Abort.") + raise ValueError("Total evaluation time is 0. Abort.") CER = metric['confusion'] / metric['total'] FA = metric['false alarm'] / metric['total'] MISS = metric['missed detection'] / metric['total'] @@ -195,18 +203,18 @@ def score_labels( itemized_errors = (DER, CER, FA, MISS) if verbose: - # logging.info(f"\n{metric.report()}") - pass + logging.info(f"\n{metric.report()}") logging.info( - "Cumulative Results for collar {} sec and ignore_overlap {}: \n| FA: {:.4f} | MISS: {:.4f} | CER: {:.4f} | DER: {:.4f} | Spk. Count Acc. {:.4f}\n".format( - collar, ignore_overlap, FA, MISS, CER, DER, spk_count_acc - ) + f"Cumulative Results for collar {collar} sec and ignore_overlap {ignore_overlap}: \n" + f"| FA: {FA:.4f} | MISS: {MISS:.4f} | CER: {CER:.4f} | DER: {DER:.4f} | " + f"Spk. Count Acc. 
{spk_count_acc:.4f}\n" ) return metric, mapping_dict, itemized_errors elif verbose: logging.warning( - "Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate" + "Check if each ground truth RTTMs were present in the provided manifest file. " + "Skipping calculation of Diariazation Error Rate" ) return None @@ -447,4 +455,4 @@ def concat_perm_word_error_rate( cpWER_values.append(cpWER) hyps_spk.append(min_hypothesis) refs_spk.append(concat_reference) - return cpWER_values, hyps_spk, refs_spk + return cpWER_values, hyps_spk, refs_spk \ No newline at end of file diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 13e57b43bb0b..8ad09c842636 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -84,6 +84,18 @@ def __init__(self, dist_sync_on_step=False): def update( self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False ) -> torch.Tensor: + """ + Update the metric with the given predictions, targets, and signal lengths to the metric instance. + + Args: + preds (torch.Tensor): Predicted values. + targets (torch.Tensor): Target values. + signal_lengths (torch.Tensor): Length of each sequence in the batch input. + cumulative (bool): Whether to accumulate the values over time. + + Returns: + f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. + """ with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] targets_list = [targets[k, : signal_lengths[k], :] for k in range(targets.shape[0])] diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index 939b03e7a5ac..e3c14dd77c65 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -208,7 +208,8 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): def test_dataloader(self): if self._test_dl is not None: return self._test_dl - + return None + @property def input_types(self) -> Optional[Dict[str, NeuralType]]: if hasattr(self.preprocessor, '_sample_rate'): diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index fdbeee5235ea..e0b5b15094b6 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -20,7 +20,6 @@ from nemo.core.classes.exportable import Exportable from nemo.core.classes.module import NeuralModule -from nemo.core.neural_types.elements import ProbsType __all__ = ['SortformerModules'] diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index fd6f71dc0502..1e7dda91c9e7 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -14,7 +14,6 @@ import gc import json -import math import os import shutil from copy import deepcopy @@ -23,14 +22,13 @@ import numpy as np import soundfile as sf import torch -from omegaconf import OmegaConf from omegaconf.listconfig import ListConfig from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm from nemo.collections.asr.data.audio_to_label import repeat_signal from nemo.collections.asr.parts.utils.longform_clustering import 
LongFormSpeakerClustering -from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data +from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat, split_input_data from nemo.utils import logging @@ -78,10 +76,13 @@ def audio_rttm_map(manifest, attach_dur=False): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - Args: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists - - returns: - AUDIO_RTTM_MAP (dict) : A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files + + Args: + manifest (str): Path to the manifest file + attach_dur (bool, optional): If True, attach duration information to the unique name. Defaults to False. + + Returns: + AUDIO_RTTM_MAP (dict) : Dictionary with unique names as keys and corresponding metadata as values. """ AUDIO_RTTM_MAP = {} @@ -114,10 +115,9 @@ def audio_rttm_map(manifest, attach_dur=False): AUDIO_RTTM_MAP[uniqname] = meta else: raise KeyError( - "file {} is already part of AUDIO_RTTM_MAP, it might be duplicated, Note: file basename must be unique".format( - meta['audio_filepath'] + f"file {meta['audio_filepath']} is already part of AUDIO_RTTM_MAP, it might be duplicated, " + "Note: file basename must be unique" ) - ) return AUDIO_RTTM_MAP @@ -247,7 +247,8 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg def get_timestamps(multiscale_timestamps, multiscale_args_dict): """ The timestamps in `multiscale_timestamps` dictionary are indexed by scale index. - This function rearranges the extracted speaker embedding and timestamps by unique ID to make the further processing more convenient. + This function rearranges the extracted speaker embedding and timestamps by unique ID + to make the further processing more convenient. Args: multiscale_timestamps (dict): @@ -441,13 +442,20 @@ def perform_clustering( 'embeddings' : Tensor containing embeddings. Dimensions:(# of embs) x (emb. dimension) 'timestamps' : Tensor containing ime stamps list for each audio recording 'multiscale_segment_counts' : Tensor containing the number of segments for each scale - AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path - out_rttm_dir (str): Path to write predicted rttms - clustering_params (dict): clustering parameters provided through config that contains max_num_speakers (int), - oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) and enhance_count_threshold (int) - use_torch_script (bool): Boolean that determines whether to use torch.jit.script for speaker clustering - device (torch.device): Device we are running on ('cpu', 'cuda'). - verbose (bool): Enable TQDM progress bar. + AUDIO_RTTM_MAP (dict): + AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path + out_rttm_dir (str): + Path to write predicted rttms + clustering_params (dict): + Clustering parameters provided through config that contains max_num_speakers (int), + oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) + and enhance_count_threshold (int). + use_torch_script (bool): + Boolean that determines whether to use torch.jit.script for speaker clustering + device (torch.device): + Device we are running on ('cpu', 'cuda'). + verbose (bool): + Enable TQDM progress bar. 
Returns: all_reference (list[uniq_name,Annotation]): reference annotations for score calculation @@ -614,10 +622,9 @@ def read_rttm_lines(rttm_file_path): lines = f.readlines() else: raise FileNotFoundError( - "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( - rttm_file_path + "Requested to construct manifest from rttm with oracle VAD option " + f"or from NeMo VAD but received filename as {rttm_file_path}" ) - ) return lines @@ -886,7 +893,8 @@ def segments_manifest_to_subsegments_manifest( Generate subsegments manifest from segments manifest file Args: segments_manifest file (str): path to segments manifest file, typically from VAD output - subsegments_manifest_file (str): path to output subsegments manifest file (default (None) : writes to current working directory) + subsegments_manifest_file (str): path to output subsegments manifest file + (default (None) : writes to current working directory) window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift min_subsegments_duration (float): exclude subsegments smaller than this duration value @@ -960,7 +968,8 @@ def get_subsegments( it results in (10/0.08)+1 = 125 + 1 frames. Returns: - subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment + subsegments (List[tuple[float, float]]): subsegments generated for the segments as + list of tuple of start and duration of each subsegment """ subsegments: List[List[float]] = [] start = offset @@ -1041,7 +1050,25 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: return [[float(range_tensor[k][0]), float(range_tensor[k][1])] for k in range(range_tensor.shape[0])] -def generate_diarization_output_lines(speaker_timestamps, model_spk_num): +def generate_diarization_output_lines(speaker_timestamps: List[List[float]], model_spk_num: int) -> List[str]: + """ + Generate diarization output lines list from the speaker timestamps list by merging overlapping intervals. + + Args: + speaker_timestamps (list): + List containing the start and end time of the speech intervals for each speaker. + Example: + >>> speaker_timestamps = [[0.5, 3.12], [3.51, 7.26],... ] + model_spk_num (int): + Number of speakers in the model. + + Returns: + speaker_lines_total (list): + List containing the diarization output lines in the format: + "start_time end_time speaker_id" + Example: + >>> speaker_lines_total = ["0.5 3.12 speaker_0", "3.51 7.26 speaker_1",...] + """ speaker_lines_total = [] for spk_idx in range(model_spk_num): ts_invervals = speaker_timestamps[spk_idx] @@ -1334,20 +1361,22 @@ def get_scale_mapping_argmat(uniq_embs_and_timestamps: Dict[str, dict]) -> Dict[ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): """ - Generate timestamps that include overlap speech. Overlap-including timestamps are created based on the segments that are - created for clustering diarizer. Overlap speech is assigned to the existing speech segments in `cont_stamps`. + Generate timestamps that include overlap speech. Overlap-including timestamps are created based on + the segments that are created for clustering diarizer. Overlap speech is assigned to the existing + speech segments in `cont_stamps`. Args: cont_stamps (list): - Non-overlapping (single speaker per segment) diarization output in string format. - Each line contains the start and end time of segments and corresponding speaker labels. 
+ Non-overlapping (single speaker per segment) diarization output in string format. Each line + contains the start and end time of segments and corresponding speaker labels. ovl_spk_idx (list): - List containing segment index of the estimated overlapped speech. The start and end of segments are based on the - single-speaker (i.e., non-overlap-aware) RTTM generation. + List containing segment index of the estimated overlapped speech. The start and end of + segments are based on the single-speaker (i.e., non-overlap-aware) RTTM generation. + Returns: total_ovl_cont_list (list): - Rendered diarization output in string format. Each line contains the start and end time of segments and - corresponding speaker labels. This format is identical to `cont_stamps`. + Rendered diarization output in string format. Each line contains the start and end time of + segments and corresponding speaker labels. This format is identical to `cont_stamps`. """ ovl_spk_cont_list = [[] for _ in range(len(ovl_spk_idx))] for spk_idx in range(len(ovl_spk_idx)): @@ -1364,18 +1393,21 @@ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, overlap_infer_spk_limit: int): """ - This function controls the magnitude of the sigmoid threshold based on the estimated number of speakers. As the number of - speakers becomes larger, diarization error rate is very sensitive on overlap speech detection. This function linearly increases - the threshold in proportion to the estimated number of speakers so more confident overlap speech results are reflected when - the number of estimated speakers are relatively high. + This function controls the magnitude of the sigmoid threshold based on the estimated number of + speakers. As the number of speakers becomes larger, diarization error rate is very sensitive + to overlap speech detection. This function linearly increases the threshold in proportion to + the estimated number of speakers so more confident overlap speech results are reflected when + the number of estimated speakers is relatively high. Args: estimated_num_of_spks (int): Estimated number of speakers from the clustering result. min_threshold (float): - Sigmoid threshold value from the config file. This threshold value is minimum threshold value when `estimated_num_of_spks=2` + Sigmoid threshold value from the config file. This threshold value is the minimum + threshold when `estimated_num_of_spks=2`. overlap_infer_spk_limit (int): - If the `estimated_num_of_spks` is less then `overlap_infer_spk_limit`, overlap speech estimation is skipped. + If the `estimated_num_of_spks` is less than `overlap_infer_spk_limit`, overlap speech + estimation is skipped. Returns: adaptive_threshold (float): @@ -1390,37 +1422,41 @@ def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, ove def generate_speaker_timestamps( clus_labels: List[Union[float, int]], msdd_preds: List[torch.Tensor], **params ) -> Tuple[List[str], List[str]]: - ''' - Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use clustering result for main speaker - labels and use timestamps from the predicted sigmoid values. In this function, the main speaker labels in `maj_labels` exist for - every subsegment steps while overlap speaker labels in `ovl_labels` only exist for segments where overlap-speech is occuring. + """ + Generate speaker timestamps from the segmentation information. 
If `use_clus_as_main=True`, use + clustering result for main speaker labels and use timestamps from the predicted sigmoid values. + In this function, the main speaker labels in `maj_labels` exist for every subsegment step, while + overlap speaker labels in `ovl_labels` only exist for segments where overlap speech occurs. Args: clus_labels (list): List containing integer-valued speaker clustering results. msdd_preds (list): - List containing tensors of the predicted sigmoid values. - Each tensor has shape of: (Session length, estimated number of speakers). + List containing tensors of the predicted sigmoid values. Each tensor has shape of: + (Session length, estimated number of speakers). params: Parameters for generating RTTM output and evaluation. Parameters include: - infer_overlap (bool): If False, overlap-speech will not be detected. - use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. If False, only MSDD output - is used for constructing output RTTM files. + infer_overlap (bool): If False, overlap speech will not be detected. + use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. + If False, only MSDD output is used for constructing output + RTTM files. overlap_infer_spk_limit (int): Above this limit, overlap-speech detection is bypassed. - use_adaptive_thres (bool): Boolean that determines whehther to use adaptive_threshold depending on the estimated - number of speakers. + use_adaptive_thres (bool): Boolean that determines whether to use adaptive thresholds + depending on the estimated number of speakers. max_overlap_spks (int): Maximum number of overlap speakers detected. Default is 2. threshold (float): Sigmoid threshold for MSDD output. Returns: maj_labels (list): - List containing string-formated single-speaker speech segment timestamps and corresponding speaker labels. + List containing string-formatted single-speaker speech segment timestamps and corresponding + speaker labels. Example: [..., '551.685 552.77 speaker_1', '552.99 554.43 speaker_0', '554.97 558.19 speaker_0', ...] ovl_labels (list): - List containing string-formated additional overlapping speech segment timestamps and corresponding speaker labels. - Note that `ovl_labels` includes only overlapping speech that is not included in `maj_labels`. + List containing string-formatted additional overlapping speech segment timestamps and + corresponding speaker labels. Note that `ovl_labels` includes only overlapping speech that + is not included in `maj_labels`. Example: [..., '152.495 152.745 speaker_1', '372.71 373.085 speaker_0', '554.97 555.885 speaker_1', ...] - ''' + """ msdd_preds.squeeze(0) estimated_num_of_spks = msdd_preds.shape[-1] overlap_speaker_list = [[] for _ in range(estimated_num_of_spks)] @@ -1474,7 +1510,8 @@ def get_id_tup_dict(uniq_id_list: List[str], test_data_collection, preds_list: L uniq_id_list (list): List containing the `uniq_id` values. test_data_collection (collections.DiarizationLabelEntity): - Class instance that is containing session information such as targeted speaker indices, audio filepath and RTTM filepath. + Class instance that is containing session information such as targeted speaker indices, + audio filepath and RTTM filepath. preds_list (list): List containing tensors of predicted sigmoid values. 
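The linear scaling described in the `get_adaptive_threshold` docstring above can be written as an interpolation from `min_threshold` at two speakers toward 1.0 at `overlap_infer_spk_limit`. The exact formula below is an assumption consistent with that description, not necessarily the shipped implementation:

    def adaptive_threshold_sketch(estimated_num_of_spks: int, min_threshold: float, overlap_infer_spk_limit: int) -> float:
        # Illustrative sketch: more estimated speakers -> higher sigmoid threshold,
        # i.e. more conservative overlap-speech detection.
        spk_range = max(overlap_infer_spk_limit - 2, 1)
        return min_threshold + (1.0 - min_threshold) * (estimated_num_of_spks - 2) / spk_range

    # e.g. min_threshold=0.7, overlap_infer_spk_limit=5: 2 spks -> 0.7, 3 spks -> 0.8, 4 spks -> 0.9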
@@ -1503,11 +1540,14 @@ def prepare_split_data(manifest_filepath, _out_dir, multiscale_args_dict, global Returns: multiscale_args_dict (dict): - - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps for each data sample. + - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps + for each data sample. - Each data sample has two keys: `multiscale_weights` and `scale_dict`. - `multiscale_weights` key contains a list containing multiscale weights. - `scale_dict` is indexed by integer keys which are scale index. - - Each data sample is indexed by using the following naming convention: `__` + - Each data sample is indexed by using the following naming convention: + `__` + Example: `fe_03_00106_mixed_626310_642300` """ speaker_dir = os.path.join(_out_dir, 'speaker_outputs') diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index cffcfd1ae5a1..0fbda543ca11 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -36,7 +36,6 @@ from sklearn.model_selection import ParameterGrid from tqdm import tqdm from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel -from nemo.collections.asr.parts.utils.speaker_utils import timestamps_to_pyannote_object from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging @@ -66,7 +65,8 @@ def prepare_manifest(config: dict) -> str: input_list = config['input'] else: raise ValueError( - "The input for manifest preparation would either be a string of the filepath to manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} " + "The input for manifest preparation would either be a string of the filepath to manifest " + "or a list of {'audio_filepath': i, 'offset': 0, 'duration': null}." ) args_func = { @@ -246,7 +246,8 @@ def generate_overlap_vad_seq( out_dir: str = None, ) -> str: """ - Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. + Generate predictions with overlapping input windows/segments. + Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. Two common smoothing filters are supported: majority vote (median) and average (mean). This function uses multiprocessing to speed up. Args: @@ -310,8 +311,8 @@ def generate_overlap_vad_seq_per_tensor( frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str ) -> torch.Tensor: """ - Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments - See description in generate_overlap_vad_seq. + Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate + prediction with overlapping input window/segments. See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ # This function will be refactor for vectorization but this is okay for now @@ -472,7 +473,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. + Paper: Gregory Gelly and Jean-Luc Gauvain. 
"Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: @@ -485,7 +487,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te frame_length_in_sec (float): length of frame. Returns: - speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments(torch.Tensor): A tensor of speech segment in the form of: + `torch.Tensor([[start1, end1], [start2, end2]])`. """ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) @@ -535,10 +538,10 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: torch.Tensor) -> torch.Tensor: """ Remove speech segments list in to_be_removed_segments from original_segments. - For example, - remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), - -> - torch.Tensor([[start1, end1],[start3, end3]]) + (Example) Remove torch.Tensor([[start2, end2],[start4, end4]]) + from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), + -> + torch.Tensor([[start1, end1],[start3, end3]]) """ for y in to_be_removed_segments: original_segments = original_segments[original_segments.eq(y).all(dim=1).logical_not()] @@ -559,20 +562,30 @@ def get_gap_segments(segments: torch.Tensor) -> torch.Tensor: @torch.jit.script def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor: """ - Filter out short non_speech and speech segments. + Filter out short non-speech and speech segments. + + Reference: + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. + Implementation: + https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py - Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. - Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: - speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). per_args: - min_duration_on (float): threshold for small non_speech deletion - min_duration_off (float): threshold for short speech segment deletion - filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True. + min_duration_on (float): + Threshold for small non-speech deletion. + min_duration_off (float): + Threshold for short speech segment deletion. + filter_speech_first (float): + Whether to perform short speech segment deletion first. Use 1.0 to represent True. Returns: - speech_segments(torch.Tensor): A tensor of filtered speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of filtered speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). 
""" if speech_segments.shape == torch.Size([0]): return speech_segments @@ -709,7 +722,8 @@ def generate_vad_segment_table( 17,18, speech Args: vad_pred_dir (str): directory of prediction files to be processed. - postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering. + postprocessing_params (dict): dictionary of thresholds for prediction score. + See details in binarization and filtering. frame_length_in_sec (float): frame length. out_dir (str): output dir of generated table/csv file. num_workers(float): number of process for multiprocessing @@ -820,16 +834,19 @@ def vad_tune_threshold_on_dev( num_workers: int = 20, ) -> Tuple[dict, dict]: """ - Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate (DetER) in thresholds. + Tune thresholds on dev set. Return best thresholds which gives the lowest + detection error rate (DetER) in thresholds. + Args: params (dict): dictionary of parameters to be tuned on. - vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median". - groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them. - focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" - frame_length_in_sec (float): frame length. - num_workers (int): number of workers. + vad_pred_method (str): suffix of prediction file. Use to locate file. + Should be either in "frame", "mean" or "median". + groundtruth_RTTM_dir (str): Directory of ground-truth rttm files or a file contains the paths of them. + focus_metric (str): Metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" + frame_length_in_sec (float): Frame length. + num_workers (int): Number of workers. Returns: - best_threshold (float): threshold that gives lowest DetER. + best_threshold (float): Threshold that gives lowest DetER. """ min_score = 100 all_perf = {} @@ -986,7 +1003,8 @@ def plot( threshold (float): threshold for prediction score (from 0 to 1). per_args(dict): a dict that stores the thresholds for postprocessing. unit_frame_len (float): unit frame length in seconds for VAD predictions. - label_repeat (int): repeat the label for this number of times to match different frame lengths in preds and labels. + label_repeat (int): repeat the label for this number of times to match different + frame lengths in preds and labels. xticks_step (int): step size for xticks. """ plt.figure(figsize=[20, 2]) @@ -1254,7 +1272,8 @@ def stitch_segmented_asr_output( fout.flush() logging.info( - f"Finish stitch segmented ASR output to {stitched_output_manifest}, the speech segments info has been stored in directory {speech_segments_tensor_dir}" + f"Finish stitch segmented ASR output to {stitched_output_manifest}, " + f"the speech segments info has been stored in directory {speech_segments_tensor_dir}" ) return stitched_output_manifest @@ -1471,16 +1490,22 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ - Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms). - The threshold 0.2 is not important, since the actual ratio will always be close to an integer unless using frame/label - lengths that are not multiples of each other (e.g., 15ms frame length and 20ms label length), which is not valid. - The value 0.2 here is just for easier unit testing. 
+ Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length + (e.g., 20ms). The threshold 0.2 is not critical, as the actual ratio will always be close to an + integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame + length and 20ms label length), which is not valid. The value 0.2 is chosen for easier unit testing. + Args: - probs (List[float]): list of probabilities - labels (List[int]): list of labels - threshold (float): threshold for rounding ratio to integer + probs (List[float]): + List of probabilities. + labels (List[int]): + List of labels. + threshold (float): + Threshold for rounding the ratio to an integer. + Returns: - labels (List[int]): list of labels aligned to frames + labels (List[int]): + List of labels aligned to frames. """ frames_len = len(probs) labels_len = len(labels) @@ -1511,11 +1536,13 @@ def align_labels_to_frames(probs, labels, threshold=0.2): ratio = frames_len / labels_len res = frames_len % labels_len if ceil(ratio) - ratio < threshold: - # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a multiple of 2, and discard the redundant labels + # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels + # to make it a multiple of 2, and discard the redundant labels labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist() labels = labels[:frames_len] else: - # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of 2 and add additional labels + # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels + # to make it a multiple of 2 and add additional labels labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist() if res > 0: labels += labels[-res:] @@ -1720,7 +1747,8 @@ def ts_vad_post_processing( """ Post-processing on diarization results using VAD style post-processing methods. These post-processing methods are inspired by the following paper: - Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: + a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). 
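The repeat-and-trim idea behind `align_labels_to_frames` above can be summarized with a simplified sketch; it omits the residual-padding branch that the actual function handles:

    import math

    import torch

    def repeat_labels_to_frames(labels, num_frames):
        # e.g. 20 ms labels vs 10 ms frames: ratio ~= 2, so each label is repeated twice,
        # then the sequence is trimmed so its length matches the frame count exactly.
        ratio = math.ceil(num_frames / len(labels))
        repeated = torch.tensor(labels).repeat_interleave(ratio).tolist()
        return repeated[:num_frames]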
Args: ts_vad_binary_vec (Tensor): diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 8acd9fc08743..5edf1724dc2f 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -350,7 +350,9 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): class SpeechLLMAudioTextEntity(object): + """Class for SpeechLLM dataloader instance.""" def __init__(self, sid, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) -> None: + """Initialize the AudioTextEntity for a SpeechLLM dataloader instance.""" self.id = sid self.audio_file = audio_file self.duration = duration @@ -642,7 +644,8 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: elif 'question' in item: # compatability with old manifests that uses 'question' as context key logging.warning( - f"Neither `{self.context_key}` is found nor `context_file` is set, but found `question` in item: {item}", + f"Neither `{self.context_key}` is found nor" + f"`context_file` is set, but found `question` in item: {item}", mode=logging_mode.ONCE, ) item['context'] = item.pop('question') @@ -739,7 +742,8 @@ def __init__( else: logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.") logging.info( - f"Dataset successfully loaded with {len(data)} items and total duration provided from manifest is {total_duration / 3600: .2f} hours." + f"Dataset successfully loaded with {len(data)} items " + f"and total duration provided from manifest is {total_duration / 3600: .2f} hours." ) self.uniq_labels = sorted(set(map(lambda x: x.label, data))) @@ -880,13 +884,15 @@ def __init__( if len(data) == max_number: break - logging.info("# {} files loaded including # {} unique labels".format(len(data), len(self.uniq_labels))) + logging.info(f"# {len(data)} files loaded including # {len(self.uniq_labels)} unique labels") super().__init__(data) def relative_speaker_parser(self, seq_label): """Convert sequence of speaker labels to relative labels. Convert sequence of absolute speaker to sequence of relative speaker [E A C A E E C] -> [0 1 2 1 0 0 2] - In this seq of label , if label do not appear before, assign new relative labels len(pos); else reuse previous assigned relative labels. + In this seq of label , if label do not appear before, assign new relative labels len(pos); + else reuse previous assigned relative labels. + Args: seq_label (str): A string of a sequence of labels. @@ -923,10 +929,13 @@ def __init__( """Parse lists of feature files and sequences of labels. Args: - manifests_files: Either single string file or list of such - - manifests to yield items from. - max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. + manifests_files: + Either single string file or list of such manifests to yield items from. + max_number: + Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. + index_by_file_id: + If True, saves a mapping from filename base (ID) to index in data; + pass to `FeatureSequenceLabel` constructor. """ feature_files, seq_labels = [], [] @@ -1088,24 +1097,26 @@ def __init__( **kwargs, ): """ - Parse lists of audio files, durations, RTTM (Diarization annotation) files. 
Since diarization model infers only - two speakers, speaker pairs are generated from the total number of speakers in the session. + Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization + model infers only two speakers, speaker pairs are generated from the total number of speakers in + the session. Args: manifest_filepath (str): - Path to input manifest json files. + Path to input manifest JSON files. emb_dict (Dict): Dictionary containing cluster-average embeddings and speaker mapping information. clus_label_dict (Dict): Segment-level speaker labels from clustering results. round_digit (int): - Number of digits to be rounded. + Number of digits to round. seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. pairwise_infer (bool): - If True, this dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the diarization system to merge the individual results. + If True, this dataset class operates in inference mode. In inference mode, a set of + speakers in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to + merge the individual results. *args: Args to pass to `SpeechLabel` constructor. **kwargs: Kwargs to pass to `SpeechLabel` constructor. """ @@ -1244,7 +1255,7 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: class EndtoEndDiarizationLabel(_Collection): - """List of diarization audio-label correspondence with preprocessing.""" + """List of end-to-end diarization audio-label correspondence with preprocessing.""" OUTPUT_TYPE = collections.namedtuple( typename='DiarizationLabelEntity', @@ -1276,7 +1287,8 @@ def __init__( offsets (List[float]): List of offsets or None for each audio file. max_number (Optional[int]): Maximum number of samples to collect. Defaults to None. do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False. - index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. Defaults to False. + index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. + Defaults to False. """ if index_by_file_id: @@ -1694,7 +1706,8 @@ def __init__( manifests_files: Either single string file or list of such - manifests to yield items from. max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; + pass to `FeatureSequenceLabel` constructor. 
""" feature_files, labels, durations = [], [], [] From cb232682b9d142b086524d78d41cca10f6d55249 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 15 Nov 2024 20:34:41 +0000 Subject: [PATCH 11/16] Apply isort and black reformatting Signed-off-by: tango4j --- .../neural_diarizer/sortformer_diar_train.py | 4 +- .../asr/data/audio_to_diar_label.py | 52 ++++++------- nemo/collections/asr/metrics/der.py | 18 ++--- .../asr/metrics/multi_binary_acc.py | 6 +- .../asr/models/sortformer_diar_models.py | 4 +- .../asr/parts/utils/speaker_utils.py | 76 +++++++++---------- nemo/collections/asr/parts/utils/vad_utils.py | 58 +++++++------- .../common/parts/preprocessing/collections.py | 27 +++---- 8 files changed, 124 insertions(+), 121 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index 75980d342c65..5231f4822886 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -34,6 +34,7 @@ seed_everything(42) + @hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') @@ -43,5 +44,6 @@ def main(cfg): sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) trainer.fit(sortformer_model) + if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index f47b5ca11f43..34aa0989e564 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -81,7 +81,7 @@ def extract_seg_info_from_rttm(rttm_lines, mapping_dict=None, target_spks=None): mapping_dict (dict): Mapping between the estimated speakers and the speakers in the ground-truth annotation. `mapping_dict` variable is only provided when the inference mode is running in sequence-eval mode. - Sequence eval mode uses the mapping between the estimated speakers and the speakers + Sequence eval mode uses the mapping between the estimated speakers and the speakers in ground-truth annotation. Returns: rttm_tup (tuple): @@ -116,10 +116,10 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, List containing start and end time for each speaker segment label. `stt_list`, `end_list` and `speaker_list` are contained. frame_per_sec (int): - Number of feature frames per second. This quantity is determined by + Number of feature frames per second. This quantity is determined by `window_stride` variable in preprocessing module. target_spks (tuple): - Speaker indices that are generated from combinations. + Speaker indices that are generated from combinations. If there are only one or two speakers, only a single `target_spks` variable is generated. @@ -151,13 +151,13 @@ def get_subsegments_to_timestamps( subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2 ): """ - Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) - and rounding. Segment is consisted of many subsegments and sugsegments are equivalent to `frames` + Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) + and rounding. 
Segment is consisted of many subsegments and sugsegments are equivalent to `frames` in end-to-end speaker diarization models. Args: subsegments (List[Tuple[float, float]]): - A list of tuples where each tuple contains the start and end times of a subsegment + A list of tuples where each tuple contains the start and end times of a subsegment (frames in end-to-end models). >>> subsegments = [[t0_start, t0_duration], [t1_start, t1_duration],..., [tN_start, tN_duration]] feat_per_sec (int, optional): @@ -251,7 +251,7 @@ def get_frame_targets_from_rttm( List containing start and end time for each speaker segment label. stt_list, end_list and speaker_list are contained. feat_per_sec (int): - Number of feature frames per second. + Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, @@ -415,16 +415,16 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): def get_diar_target_labels(self, uniq_id, sample, fr_level_target): """ - Convert frame-level diarization target variable into segment-level target variable. - Since the granularity is reduced from frame level (10ms) to segment level (100ms~500ms), - we need a threshold value, `soft_label_thres`, which determines the label of each segment + Convert frame-level diarization target variable into segment-level target variable. + Since the granularity is reduced from frame level (10ms) to segment level (100ms~500ms), + we need a threshold value, `soft_label_thres`, which determines the label of each segment based on the overlap between a segment range (start and end time) and the frame-level target variable. Args: uniq_id (str): Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file. sample: - `DiarizationSpeechLabel` instance containing sample information such as + `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. fr_level_target (torch.tensor): Tensor containing label for each feature-level frame. @@ -433,7 +433,7 @@ def get_diar_target_labels(self, uniq_id, sample, fr_level_target): seg_target (torch.tensor): Tensor containing binary speaker labels for base-scale segments. base_clus_label (torch.tensor): - Representative speaker label for each segment. This variable only has one speaker label + Representative speaker label for each segment. This variable only has one speaker label for each base-scale segment. -1 means that there is no corresponding speaker in the target_spks tuple. """ @@ -469,7 +469,7 @@ def parse_rttm_for_ms_targets(self, sample): Args: sample: - `DiarizationSpeechLabel` instance containing sample information such as + `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, @@ -591,7 +591,7 @@ class _AudioMSDDInferDataset(Dataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, + Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. 
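# A minimal sketch of the subsegment-to-frame conversion that the
# `get_subsegments_to_timestamps` docstring above describes: [start, duration] pairs in
# seconds are scaled by `feat_per_sec` (feature frames per second) and rounded, yielding
# frame-index ranges. The helper name and defaults below are illustrative assumptions,
# not the NeMo implementation itself.
import torch


def subsegments_to_frame_ranges(subsegments, feat_per_sec=100, decimals=2):
    """Convert [[start, duration], ...] in seconds to [start_frame, end_frame] rows."""
    scaled = [
        [round(start * feat_per_sec, decimals), round((start + dur) * feat_per_sec, decimals)]
        for start, dur in subsegments
    ]
    return torch.tensor(scaled)


# Example: 0.5 s subsegments with a 0.25 s hop at a 10 ms feature frame rate.
print(subsegments_to_frame_ranges([[0.0, 0.5], [0.25, 0.5], [0.5, 0.5]]))
# tensor([[  0.,  50.], [ 25.,  75.], [ 50., 100.]])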
@@ -691,8 +691,8 @@ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): """ Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate - ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared - with `soft_label_thres` to determine whether a label vector should contain 0 or 1 for each speaker bin. + ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared + with `soft_label_thres` to determine whether a label vector should contain 0 or 1 for each speaker bin. Note that seg_target variable has dimension of (number of base-scale segments x 2) dimension. Example of seg_target: @@ -739,7 +739,7 @@ def __getitem__(self, index): if avg_embs.shape[2] > self.max_spks: raise ValueError( - f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to " + f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to " f"self.max_num_speakers {self.max_spks}" ) @@ -834,7 +834,7 @@ def _msdd_train_collate_fn(self, batch): def _msdd_infer_collate_fn(self, batch): """ - Collate batch of feats (speaker embeddings), feature lengths, target label sequences + Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings. Args: @@ -959,12 +959,12 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping + Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): - Threshold that determines speaker labels of segments depending on the overlap + Threshold that determines speaker labels of segments depending on the overlap with groundtruth speaker timestamps. featurizer: Featurizer instance for generating features from raw waveform. @@ -973,11 +973,11 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. window_stride (float): - Window stride for acoustic feature. This value is used for calculating the numbers of + Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. pairwise_infer (bool): - If True, this Dataset class operates in inference mode. In inference mode, a set of speakers - in the input audio is split into multiple pairs of speakers and speaker tuples + If True, this Dataset class operates in inference mode. In inference mode, a set of speakers + in the input audio is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the MSDD to merge the individual results. 
""" @@ -1225,7 +1225,7 @@ def __getitem__(self, index): uniq_id = self.get_uniq_id_with_range(sample) audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) - # We should resolve the length mis-match from the round-off errors between these two variables: + # We should resolve the length mis-match from the round-off errors between these two variables: # `session_len_sec` and `audio_signal.shape[0]` session_len_sec = ( np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal @@ -1252,14 +1252,14 @@ def _eesd_train_collate_fn(self, batch): Returns: audio_signal (torch.Tensor): - A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` + A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` in the input manifest file. feature_length (torch.Tensor): A tensor containing the lengths of the raw waveform samples. targets (torch.Tensor): Groundtruth speaker labels for the given input embedding sequence. target_lens (torch.Tensor): - A tensor containing the number of segments for each sample in the batch, necessary for + A tensor containing the number of segments for each sample in the batch, necessary for reshaping inputs to the EESD model. """ packed_batch = list(zip(*batch)) diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index 22c9a76b7fc9..7496f700341f 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -145,20 +145,20 @@ def score_labels( Args: - AUDIO_RTTM_MAP (dict): + AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath - all_reference (list[uniq_name,Annotation]): + all_reference (list[uniq_name,Annotation]): Reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): + all_hypothesis (list[uniq_name,Annotation]): Hypothesis annotations for score calculation - verbose (bool): + verbose (bool): Warns if RTTM file is not found. Returns: - metric (pyannote.DiarizationErrorRate): - Pyannote Diarization Error Rate metric object. + metric (pyannote.DiarizationErrorRate): + Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. - mapping (dict): + mapping (dict): Mapping dict containing the mapping speaker label for each audio input < Caveat > @@ -178,7 +178,7 @@ def score_labels( correct_spk_count += 1 if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): logging.info( - f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, " + f"Wrong Spk. 
Count with uniq_id:...{ref_key[-10:]}, " f"Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" ) uem_obj = None @@ -455,4 +455,4 @@ def concat_perm_word_error_rate( cpWER_values.append(cpWER) hyps_spk.append(min_hypothesis) refs_spk.append(concat_reference) - return cpWER_values, hyps_spk, refs_spk \ No newline at end of file + return cpWER_values, hyps_spk, refs_spk diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 8ad09c842636..3a99769ebd25 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -84,15 +84,15 @@ def __init__(self, dist_sync_on_step=False): def update( self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False ) -> torch.Tensor: - """ + """ Update the metric with the given predictions, targets, and signal lengths to the metric instance. - + Args: preds (torch.Tensor): Predicted values. targets (torch.Tensor): Target values. signal_lengths (torch.Tensor): Length of each sequence in the batch input. cumulative (bool): Whether to accumulate the values over time. - + Returns: f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. """ diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index e3c14dd77c65..2e15e095b77a 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -208,8 +208,8 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): def test_dataloader(self): if self._test_dl is not None: return self._test_dl - return None - + return None + @property def input_types(self) -> Optional[Dict[str, NeuralType]]: if hasattr(self.preprocessor, '_sample_rate'): diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 1e7dda91c9e7..15ec8a24a3bd 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -76,11 +76,11 @@ def audio_rttm_map(manifest, attach_dur=False): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - + Args: manifest (str): Path to the manifest file attach_dur (bool, optional): If True, attach duration information to the unique name. Defaults to False. - + Returns: AUDIO_RTTM_MAP (dict) : Dictionary with unique names as keys and corresponding metadata as values. """ @@ -117,7 +117,7 @@ def audio_rttm_map(manifest, attach_dur=False): raise KeyError( f"file {meta['audio_filepath']} is already part of AUDIO_RTTM_MAP, it might be duplicated, " "Note: file basename must be unique" - ) + ) return AUDIO_RTTM_MAP @@ -247,7 +247,7 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg def get_timestamps(multiscale_timestamps, multiscale_args_dict): """ The timestamps in `multiscale_timestamps` dictionary are indexed by scale index. - This function rearranges the extracted speaker embedding and timestamps by unique ID + This function rearranges the extracted speaker embedding and timestamps by unique ID to make the further processing more convenient. Args: @@ -442,19 +442,19 @@ def perform_clustering( 'embeddings' : Tensor containing embeddings. Dimensions:(# of embs) x (emb. 
dimension) 'timestamps' : Tensor containing ime stamps list for each audio recording 'multiscale_segment_counts' : Tensor containing the number of segments for each scale - AUDIO_RTTM_MAP (dict): + AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path - out_rttm_dir (str): + out_rttm_dir (str): Path to write predicted rttms - clustering_params (dict): - Clustering parameters provided through config that contains max_num_speakers (int), - oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) + clustering_params (dict): + Clustering parameters provided through config that contains max_num_speakers (int), + oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) and enhance_count_threshold (int). - use_torch_script (bool): + use_torch_script (bool): Boolean that determines whether to use torch.jit.script for speaker clustering - device (torch.device): + device (torch.device): Device we are running on ('cpu', 'cuda'). - verbose (bool): + verbose (bool): Enable TQDM progress bar. Returns: @@ -624,7 +624,7 @@ def read_rttm_lines(rttm_file_path): raise FileNotFoundError( "Requested to construct manifest from rttm with oracle VAD option " f"or from NeMo VAD but received filename as {rttm_file_path}" - ) + ) return lines @@ -893,7 +893,7 @@ def segments_manifest_to_subsegments_manifest( Generate subsegments manifest from segments manifest file Args: segments_manifest file (str): path to segments manifest file, typically from VAD output - subsegments_manifest_file (str): path to output subsegments manifest file + subsegments_manifest_file (str): path to output subsegments manifest file (default (None) : writes to current working directory) window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift @@ -968,7 +968,7 @@ def get_subsegments( it results in (10/0.08)+1 = 125 + 1 frames. Returns: - subsegments (List[tuple[float, float]]): subsegments generated for the segments as + subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment """ subsegments: List[List[float]] = [] @@ -1051,9 +1051,9 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: def generate_diarization_output_lines(speaker_timestamps: List[List[float]], model_spk_num: int) -> List[str]: - """ + """ Generate diarization output lines list from the speaker timestamps list by merging overlapping intervals. - + Args: speaker_timestamps (list): List containing the start and end time of the speech intervals for each speaker. @@ -1061,7 +1061,7 @@ def generate_diarization_output_lines(speaker_timestamps: List[List[float]], mod >>> speaker_timestamps = [[0.5, 3.12], [3.51, 7.26],... ] model_spk_num (int): Number of speakers in the model. - + Returns: speaker_lines_total (list): List containing the diarization output lines in the format: @@ -1393,20 +1393,20 @@ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, overlap_infer_spk_limit: int): """ - This function controls the magnitude of the sigmoid threshold based on the estimated number of - speakers. As the number of speakers becomes larger, diarization error rate is very sensitive - to overlap speech detection. 
This function linearly increases the threshold in proportion to - the estimated number of speakers so more confident overlap speech results are reflected when + This function controls the magnitude of the sigmoid threshold based on the estimated number of + speakers. As the number of speakers becomes larger, diarization error rate is very sensitive + to overlap speech detection. This function linearly increases the threshold in proportion to + the estimated number of speakers so more confident overlap speech results are reflected when the number of estimated speakers is relatively high. Args: estimated_num_of_spks (int): Estimated number of speakers from the clustering result. min_threshold (float): - Sigmoid threshold value from the config file. This threshold value is the minimum + Sigmoid threshold value from the config file. This threshold value is the minimum threshold when `estimated_num_of_spks=2`. overlap_infer_spk_limit (int): - If the `estimated_num_of_spks` is less than `overlap_infer_spk_limit`, overlap speech + If the `estimated_num_of_spks` is less than `overlap_infer_spk_limit`, overlap speech estimation is skipped. Returns: @@ -1423,37 +1423,37 @@ def generate_speaker_timestamps( clus_labels: List[Union[float, int]], msdd_preds: List[torch.Tensor], **params ) -> Tuple[List[str], List[str]]: """ - Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use - clustering result for main speaker labels and use timestamps from the predicted sigmoid values. - In this function, the main speaker labels in `maj_labels` exist for every subsegment step, while + Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use + clustering result for main speaker labels and use timestamps from the predicted sigmoid values. + In this function, the main speaker labels in `maj_labels` exist for every subsegment step, while overlap speaker labels in `ovl_labels` only exist for segments where overlap speech occurs. Args: clus_labels (list): List containing integer-valued speaker clustering results. msdd_preds (list): - List containing tensors of the predicted sigmoid values. Each tensor has shape of: + List containing tensors of the predicted sigmoid values. Each tensor has shape of: (Session length, estimated number of speakers). params: Parameters for generating RTTM output and evaluation. Parameters include: infer_overlap (bool): If False, overlap speech will not be detected. - use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. - If False, only MSDD output is used for constructing output + use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. + If False, only MSDD output is used for constructing output RTTM files. overlap_infer_spk_limit (int): Above this limit, overlap-speech detection is bypassed. - use_adaptive_thres (bool): Boolean that determines whether to use adaptive thresholds + use_adaptive_thres (bool): Boolean that determines whether to use adaptive thresholds depending on the estimated number of speakers. max_overlap_spks (int): Maximum number of overlap speakers detected. Default is 2. threshold (float): Sigmoid threshold for MSDD output. Returns: maj_labels (list): - List containing string-formatted single-speaker speech segment timestamps and corresponding + List containing string-formatted single-speaker speech segment timestamps and corresponding speaker labels. 
Example: [..., '551.685 552.77 speaker_1', '552.99 554.43 speaker_0', '554.97 558.19 speaker_0', ...] ovl_labels (list): - List containing string-formatted additional overlapping speech segment timestamps and - corresponding speaker labels. Note that `ovl_labels` includes only overlapping speech that + List containing string-formatted additional overlapping speech segment timestamps and + corresponding speaker labels. Note that `ovl_labels` includes only overlapping speech that is not included in `maj_labels`. Example: [..., '152.495 152.745 speaker_1', '372.71 373.085 speaker_0', '554.97 555.885 speaker_1', ...] """ @@ -1510,7 +1510,7 @@ def get_id_tup_dict(uniq_id_list: List[str], test_data_collection, preds_list: L uniq_id_list (list): List containing the `uniq_id` values. test_data_collection (collections.DiarizationLabelEntity): - Class instance that is containing session information such as targeted speaker indices, + Class instance that is containing session information such as targeted speaker indices, audio filepath and RTTM filepath. preds_list (list): List containing tensors of predicted sigmoid values. @@ -1540,14 +1540,14 @@ def prepare_split_data(manifest_filepath, _out_dir, multiscale_args_dict, global Returns: multiscale_args_dict (dict): - - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps + - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps for each data sample. - Each data sample has two keys: `multiscale_weights` and `scale_dict`. - `multiscale_weights` key contains a list containing multiscale weights. - `scale_dict` is indexed by integer keys which are scale index. - - Each data sample is indexed by using the following naming convention: + - Each data sample is indexed by using the following naming convention: `__` - + Example: `fe_03_00106_mixed_626310_642300` """ speaker_dir = os.path.join(_out_dir, 'speaker_outputs') diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 0fbda543ca11..83a811ee4adb 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -246,7 +246,7 @@ def generate_overlap_vad_seq( out_dir: str = None, ) -> str: """ - Generate predictions with overlapping input windows/segments. + Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows. Two common smoothing filters are supported: majority vote (median) and average (mean). This function uses multiprocessing to speed up. @@ -311,7 +311,7 @@ def generate_overlap_vad_seq_per_tensor( frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str ) -> torch.Tensor: """ - Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate + Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments. See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ @@ -473,7 +473,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. 
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py @@ -488,7 +488,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Returns: speech_segments(torch.Tensor): A tensor of speech segment in the form of: - `torch.Tensor([[start1, end1], [start2, end2]])`. + `torch.Tensor([[start1, end1], [start2, end2]])`. """ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) @@ -538,7 +538,7 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: torch.Tensor) -> torch.Tensor: """ Remove speech segments list in to_be_removed_segments from original_segments. - (Example) Remove torch.Tensor([[start2, end2],[start4, end4]]) + (Example) Remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), -> torch.Tensor([[start1, end1],[start3, end3]]) @@ -565,26 +565,26 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc Filter out short non-speech and speech segments. Reference: - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. - Implementation: + Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: - speech_segments (torch.Tensor): - A tensor of speech segments in the format + speech_segments (torch.Tensor): + A tensor of speech segments in the format torch.Tensor([[start1, end1], [start2, end2]]). per_args: - min_duration_on (float): + min_duration_on (float): Threshold for small non-speech deletion. - min_duration_off (float): + min_duration_off (float): Threshold for short speech segment deletion. - filter_speech_first (float): + filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True. Returns: - speech_segments (torch.Tensor): - A tensor of filtered speech segments in the format + speech_segments (torch.Tensor): + A tensor of filtered speech segments in the format torch.Tensor([[start1, end1], [start2, end2]]). """ if speech_segments.shape == torch.Size([0]): @@ -722,7 +722,7 @@ def generate_vad_segment_table( 17,18, speech Args: vad_pred_dir (str): directory of prediction files to be processed. - postprocessing_params (dict): dictionary of thresholds for prediction score. + postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering. frame_length_in_sec (float): frame length. out_dir (str): output dir of generated table/csv file. @@ -834,12 +834,12 @@ def vad_tune_threshold_on_dev( num_workers: int = 20, ) -> Tuple[dict, dict]: """ - Tune thresholds on dev set. Return best thresholds which gives the lowest + Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate (DetER) in thresholds. - + Args: params (dict): dictionary of parameters to be tuned on. - vad_pred_method (str): suffix of prediction file. Use to locate file. + vad_pred_method (str): suffix of prediction file. Use to locate file. Should be either in "frame", "mean" or "median". groundtruth_RTTM_dir (str): Directory of ground-truth rttm files or a file contains the paths of them. focus_metric (str): Metrics we care most when tuning threshold. 
Should be either in "DetER", "FA", "MISS" @@ -1003,7 +1003,7 @@ def plot( threshold (float): threshold for prediction score (from 0 to 1). per_args(dict): a dict that stores the thresholds for postprocessing. unit_frame_len (float): unit frame length in seconds for VAD predictions. - label_repeat (int): repeat the label for this number of times to match different + label_repeat (int): repeat the label for this number of times to match different frame lengths in preds and labels. xticks_step (int): step size for xticks. """ @@ -1490,21 +1490,21 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ - Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length - (e.g., 20ms). The threshold 0.2 is not critical, as the actual ratio will always be close to an - integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame + Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length + (e.g., 20ms). The threshold 0.2 is not critical, as the actual ratio will always be close to an + integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame length and 20ms label length), which is not valid. The value 0.2 is chosen for easier unit testing. Args: - probs (List[float]): + probs (List[float]): List of probabilities. - labels (List[int]): + labels (List[int]): List of labels. - threshold (float): + threshold (float): Threshold for rounding the ratio to an integer. Returns: - labels (List[int]): + labels (List[int]): List of labels aligned to frames. """ frames_len = len(probs) @@ -1536,12 +1536,12 @@ def align_labels_to_frames(probs, labels, threshold=0.2): ratio = frames_len / labels_len res = frames_len % labels_len if ceil(ratio) - ratio < threshold: - # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels + # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels # to make it a multiple of 2, and discard the redundant labels labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist() labels = labels[:frames_len] else: - # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels + # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels # to make it a multiple of 2 and add additional labels labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist() if res > 0: @@ -1747,7 +1747,7 @@ def ts_vad_post_processing( """ Post-processing on diarization results using VAD style post-processing methods. These post-processing methods are inspired by the following paper: - Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). 
Args: diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index a4728c29ff06..b6db109afa58 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -477,6 +477,7 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): class SpeechLLMAudioTextEntity(object): """Class for SpeechLLM dataloader instance.""" + def __init__(self, sid, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) -> None: """Initialize the AudioTextEntity for a SpeechLLM dataloader instance.""" self.id = sid @@ -1016,9 +1017,9 @@ def __init__( def relative_speaker_parser(self, seq_label): """Convert sequence of speaker labels to relative labels. Convert sequence of absolute speaker to sequence of relative speaker [E A C A E E C] -> [0 1 2 1 0 0 2] - In this seq of label , if label do not appear before, assign new relative labels len(pos); + In this seq of label , if label do not appear before, assign new relative labels len(pos); else reuse previous assigned relative labels. - + Args: seq_label (str): A string of a sequence of labels. @@ -1055,12 +1056,12 @@ def __init__( """Parse lists of feature files and sequences of labels. Args: - manifests_files: + manifests_files: Either single string file or list of such manifests to yield items from. - max_number: + max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: - If True, saves a mapping from filename base (ID) to index in data; + index_by_file_id: + If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. """ @@ -1223,8 +1224,8 @@ def __init__( **kwargs, ): """ - Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization - model infers only two speakers, speaker pairs are generated from the total number of speakers in + Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization + model infers only two speakers, speaker pairs are generated from the total number of speakers in the session. Args: @@ -1239,9 +1240,9 @@ def __init__( seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. pairwise_infer (bool): - If True, this dataset class operates in inference mode. In inference mode, a set of - speakers in the input audio is split into multiple pairs of speakers and speaker tuples - (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to + If True, this dataset class operates in inference mode. In inference mode, a set of + speakers in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to merge the individual results. *args: Args to pass to `SpeechLabel` constructor. **kwargs: Kwargs to pass to `SpeechLabel` constructor. @@ -1413,7 +1414,7 @@ def __init__( offsets (List[float]): List of offsets or None for each audio file. max_number (Optional[int]): Maximum number of samples to collect. Defaults to None. do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False. - index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. + index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. Defaults to False. 
""" @@ -1832,7 +1833,7 @@ def __init__( manifests_files: Either single string file or list of such - manifests to yield items from. max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. """ From 4a266b93a38384236f41a58c642170dd2c4fac03 Mon Sep 17 00:00:00 2001 From: taejinp Date: Fri, 15 Nov 2024 14:49:28 -0800 Subject: [PATCH 12/16] Resolving pylint issues Signed-off-by: taejinp --- .../neural_diarizer/e2e_diarize_speech.py | 22 ++++++++- .../neural_diarizer/sortformer_diar_train.py | 1 + .../asr/data/audio_to_diar_label.py | 2 +- .../asr/data/audio_to_diar_label_lhotse.py | 2 + .../asr/models/sortformer_diar_models.py | 6 ++- .../asr/modules/sortformer_modules.py | 45 ++++++++++-------- .../asr/parts/utils/asr_multispeaker_utils.py | 46 ++++++++++++------- .../common/parts/preprocessing/collections.py | 10 ++-- 8 files changed, 88 insertions(+), 46 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 0f90e70eff80..cb09b7df3100 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,6 +45,12 @@ @dataclass class PostProcessingParams: + """ + Postprocessing parameters for end-to-end speaker diarization models. + These parameters can significantly affect DER performance depending on the evaluation style and the dataset. + It is recommended to tune these parameters based on the evaluation style and the dataset + to achieve the desired DER performance. + """ onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech offset: float = 0.5 # Offset threshold for detecting the end of a speech pad_onset: float = 0.0 # Adding durations before each speech segment @@ -55,7 +61,7 @@ class PostProcessingParams: @dataclass class DiarizationConfig: - # Required configs + """Diarization configuration parameters for inference.""" model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model audio_dir: Optional[str] = None # Path to a directory which contains audio files @@ -221,6 +227,17 @@ def run_optuna_hyperparam_search( preds_list: List[torch.Tensor], temp_out_dir: str, ): + """ + Run Optuna hyperparameter optimization for speaker diarization. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + infer_audio_rttm_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. + Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + temp_out_dir (str): temporary directory for storing intermediate outputs. 
+ """ worker_function = lambda trial: diarization_objective( trial=trial, postprocessing_cfg=postprocessing_cfg, @@ -300,6 +317,7 @@ def convert_pred_mat_to_segments( @hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: + """Main function for end-to-end speaker diarization inference.""" for key in cfg: cfg[key] = None if cfg[key] == 'None' else cfg[key] diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py index 75980d342c65..bff2218e361f 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -36,6 +36,7 @@ @hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml") def main(cfg): + """Main function for training the sortformer diarizer model.""" logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index f47b5ca11f43..48454e310070 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index 8d11c4c1167d..14723a398fe8 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -37,6 +37,8 @@ class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Define the output types of the dataset. + """ return { 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), 'a_sig_length': NeuralType(tuple('B'), LengthsType()), diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index e3c14dd77c65..eadabe642779 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
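# A minimal sketch of the Optuna-based postprocessing search that
# `run_optuna_hyperparam_search` in e2e_diarize_speech.py wraps (see above): each trial
# samples PostProcessingParams-style thresholds and returns a score to minimize. The
# parameter ranges and the toy objective below are assumptions for illustration only;
# the real objective applies the thresholds to predictions and scores DER against
# reference RTTMs.
import optuna


def objective(trial: optuna.Trial) -> float:
    onset = trial.suggest_float("onset", 0.4, 0.8, step=0.01)
    offset = trial.suggest_float("offset", 0.4, 0.9, step=0.01)
    pad_onset = trial.suggest_float("pad_onset", 0.0, 0.2, step=0.01)
    min_duration_on = trial.suggest_float("min_duration_on", 0.0, 0.5, step=0.05)
    # Toy stand-in for the DER computation, centered on plausible optima.
    return (onset - 0.53) ** 2 + (offset - 0.49) ** 2 + 0.01 * pad_onset + 0.01 * min_duration_on


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=50)
print(study.best_params)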
@@ -172,7 +172,8 @@ def __setup_dataloader_from_config(self, config): soft_targets=config.soft_targets if 'soft_targets' in config else False, ) logging.info( - f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader step B: {time.time() - time_flag}" + f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader" + f"step B: {time.time() - time_flag}" ) self.data_collection = dataset.collection @@ -581,4 +582,5 @@ def test_batch( def diarize( self, ): + """One-clieck runner function for diarization.""" raise NotImplementedError diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py index e0b5b15094b6..36b7438c9a92 100644 --- a/nemo/collections/asr/modules/sortformer_modules.py +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import OrderedDict - import torch import torch.nn as nn import torch.nn.functional as F @@ -26,24 +24,13 @@ class SortformerModules(NeuralModule, Exportable): """ - Multi-scale Diarization Decoder (MSDD) for overlap-aware diarization and improved diarization accuracy from clustering diarizer. - Based on the paper: Taejin Park et. al, "Multi-scale Speaker Diarization with Dynamic Scale Weighting", Interspeech 2022. - Arxiv version: https://arxiv.org/pdf/2203.15974.pdf - - Args: - num_spks (int): - Max number of speakers that are processed by the model. In `MSDD_module`, `num_spks=2` for pairwise inference. - hidden_size (int): - Number of hidden units in sequence models and intermediate layers. - dropout_rate (float): - Dropout rate for linear layers, CNN and LSTM. - fc_d_model (int): - Dimension of the embedding vectors. - tf_d_model (int): - Dimension of the embedding vectors. + A class including auxiliary functions for Sortformer models. + This class contains and will contain the following functions that performs streaming features, + and any neural layers that are not included in the NeMo neural modules (e.g. Transformer, Fast-Conformer). """ def init_weights(self, m): + """Init weights for linear layers.""" if type(m) == nn.Linear: torch.nn.init.xavier_uniform_(m.weight) m.bias.data.fill_(0.01) @@ -56,6 +43,19 @@ def __init__( fc_d_model: int = 512, tf_d_model: int = 192, ): + """ + Args: + num_spks (int): + Max number of speakers that are processed by the model. + hidden_size (int): + Number of hidden units in sequence models and intermediate layers. + dropout_rate (float): + Dropout rate for linear layers, CNN and LSTM. + fc_d_model (int): + Dimension of the embedding vectors. + tf_d_model (int): + Dimension of the embedding vectors. + """ super().__init__() self.fc_d_model = fc_d_model self.tf_d_model = tf_d_model @@ -91,6 +91,15 @@ def length_to_mask(self, context_embs): return mask.float().to(context_embs.device) def forward_speaker_sigmoids(self, hidden_out): + """ + A set of layers for predicting speaker probabilities with a sigmoid activation function. 
+ + Args: + hidden_out (torch.Tensor): tensor of shape (batch_size, seq_len, hidden_size) + + Returns: + preds (torch.Tensor): tensor of shape (batch_size, num_spks) containing speaker probabilities + """ hidden_out = self.dropout(F.relu(hidden_out)) hidden_out = self.first_hidden_to_hidden(hidden_out) hidden_out = self.dropout(F.relu(hidden_out)) diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py index e945439bf8fa..46bcf2f1a8c6 100644 --- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -28,7 +28,8 @@ def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> thres (float): The threshold value for discretizing the matrix values. Returns: - mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first nonzero value in each row. + mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first + nonzero value in each row. """ # Discretize the matrix to the specified maximum capacity labels_discrete = mat.clone() @@ -229,7 +230,8 @@ def get_mask_from_segments( speaker_to_idx_map (dict): A dictionary mapping speaker names to indices. num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate - ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. + Will be removed in the future. Returns: mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). @@ -315,25 +317,34 @@ def speaker_to_target( ignore_num_spk_mismatch: bool = True, soft_thres: float = 0.5, ): - ''' - Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape (num_speaker, hidden_length) - This function is needed for speaker diarization with ASR model trainings. + """ + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape + (num_speaker, hidden_length). This function is needed for speaker diarization with ASR model trainings. Args: - a_cut (MonoCut, MixedCut): Lhotse Cut instance which is MonoCut or MixedCut instance. - num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default - num_sample_per_mel_frame (int): number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) - num_mel_frame_per_asr_frame (int): encoder subsampling_factor, 8 by default - spk_tar_all_zero (Tensor): set to True gives all zero "mask" - boundary_segments (bool): set to True to include segments containing the boundary of the cut, False by default for multi-speaker ASR training - soft_label (bool): set to True to use soft label that enables values in [0, 1] range, False by default and leads to binary labels. - ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. Will be removed in the future. + a_cut (MonoCut, MixedCut): + Lhotse Cut instance which is MonoCut or MixedCut instance. 
+ num_speakers (int): + Max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): + Number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): + Encoder subsampling_factor, 8 by default + spk_tar_all_zero (Tensor): + Set to True gives all zero "mask" + boundary_segments (bool): + Set to True to include segments containing the boundary of the cut, + False by default for multi-speaker ASR training + soft_label (bool): + Set to True to use soft label that enables values in [0, 1] range, + False by default and leads to binary labels. + ignore_num_spk_mismatch (bool): + This is a temporary solution to handle speaker mismatch. Will be removed in the future. Returns: - mask (Tensor): speaker mask with shape (num_speaker, hidden_lenght) - ''' + mask (Tensor): Speaker mask with shape (num_speaker, hidden_lenght) + """ # get cut-related segments from rttms - # basename = os.path.basename(a_cut.rttm_filepath).replace('.rttm', '') if isinstance(a_cut, MixedCut): cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] @@ -374,7 +385,8 @@ def speaker_to_target( speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers raise ValueError( - f"Number of speakers {len(speaker_to_idx_map)} is larger than the maximum number of speakers {num_speakers}" + f"Number of speakers {len(speaker_to_idx_map)} is larger than " + f"the maximum number of speakers {num_speakers}" ) # initialize mask matrices (num_speaker, encoder_hidden_len) diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index a4728c29ff06..13f9efe48a06 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -310,10 +310,9 @@ def __init__( class InstructionTuningAudioText(_Collection): """`AudioText` collector from asr structured json files.""" - OUTPUT_TYPE = collections.namedtuple( - typename='InstructionTuningText', - field_names='id context context_type context_duration question question_type answer answer_type answer_duration speaker', - ) + OUTPUT_TYPE = collections.namedtuple(typename='InstructionTuningText', + field_names=('id context context_type context_duration question ' + 'question_type answer answer_type answer_duration speaker'),) def __init__( self, @@ -559,7 +558,6 @@ def __init__( ): """Instantiates audio-context-answer manifest with filters and preprocessing. - Args: ids: List of examples positions. audio_files: List of audio files. 
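# A minimal sketch of the frame-level speaker target that `speaker_to_target` /
# `get_mask_from_segments` above produce: each (speaker, start, end) interval in seconds
# is painted into a (num_speakers, num_frames) binary matrix at `feat_per_sec` frames per
# second. The helper below is an illustrative assumption, not the Lhotse/NeMo API.
import torch


def segments_to_mask(segments, speaker_to_idx, num_speakers=4, feat_per_sec=100, num_frames=1000):
    """Paint (speaker, start_sec, end_sec) intervals into a binary speaker-by-frame mask."""
    mask = torch.zeros(num_speakers, num_frames)
    for spk, start, end in segments:
        stt_frame = int(round(start * feat_per_sec))
        end_frame = min(int(round(end * feat_per_sec)), num_frames)
        mask[speaker_to_idx[spk], stt_frame:end_frame] = 1.0
    return mask


# Two partially overlapping speakers in a 10-second window at a 10 ms frame rate.
mask = segments_to_mask([("spk0", 0.0, 4.2), ("spk1", 3.5, 9.8)], {"spk0": 0, "spk1": 1})
print(mask.shape, mask.sum(dim=1))  # torch.Size([4, 1000]) tensor([420., 630., 0., 0.])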
@@ -1471,7 +1469,7 @@ def __init__( class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel): - """`DiarizationLabel` diarization data sample collector from structured json files.""" + """End-to-end speaker diarization data sample collector from structured json files.""" def __init__( self, From 6e2225ef65552b6c6c557779bf3ee6b7b3f1f340 Mon Sep 17 00:00:00 2001 From: tango4j Date: Fri, 15 Nov 2024 22:50:57 +0000 Subject: [PATCH 13/16] Apply isort and black reformatting Signed-off-by: tango4j --- .../neural_diarizer/e2e_diarize_speech.py | 4 ++- .../asr/data/audio_to_diar_label_lhotse.py | 3 +-- .../asr/models/sortformer_diar_models.py | 2 +- .../asr/modules/sortformer_modules.py | 8 +++--- .../asr/parts/utils/asr_multispeaker_utils.py | 26 +++++++++---------- .../common/parts/preprocessing/collections.py | 10 ++++--- 6 files changed, 29 insertions(+), 24 deletions(-) diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index cb09b7df3100..65ba0226988a 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -48,9 +48,10 @@ class PostProcessingParams: """ Postprocessing parameters for end-to-end speaker diarization models. These parameters can significantly affect DER performance depending on the evaluation style and the dataset. - It is recommended to tune these parameters based on the evaluation style and the dataset + It is recommended to tune these parameters based on the evaluation style and the dataset to achieve the desired DER performance. """ + onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech offset: float = 0.5 # Offset threshold for detecting the end of a speech pad_onset: float = 0.0 # Adding durations before each speech segment @@ -62,6 +63,7 @@ class PostProcessingParams: @dataclass class DiarizationConfig: """Diarization configuration parameters for inference.""" + model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model audio_dir: Optional[str] = None # Path to a directory which contains audio files diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py index 14723a398fe8..0839b63954f0 100644 --- a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -37,8 +37,7 @@ class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): @property def output_types(self) -> Optional[Dict[str, NeuralType]]: - """Define the output types of the dataset. 
-        """
+        """Define the output types of the dataset."""
         return {
             'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
             'a_sig_length': NeuralType(tuple('B'), LengthsType()),
diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py
index b54bfcc0d05c..5a3c8e354f1b 100644
--- a/nemo/collections/asr/models/sortformer_diar_models.py
+++ b/nemo/collections/asr/models/sortformer_diar_models.py
@@ -172,7 +172,7 @@ def __setup_dataloader_from_config(self, config):
             soft_targets=config.soft_targets if 'soft_targets' in config else False,
         )
         logging.info(
-            f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader" 
+            f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader"
             f"step B: {time.time() - time_flag}"
         )
 
diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py
index 36b7438c9a92..193dae29c304 100644
--- a/nemo/collections/asr/modules/sortformer_modules.py
+++ b/nemo/collections/asr/modules/sortformer_modules.py
@@ -43,10 +43,10 @@ def __init__(
         fc_d_model: int = 512,
         tf_d_model: int = 192,
     ):
-        """ 
+        """
        Args:
            num_spks (int):
-                Max number of speakers that are processed by the model. 
+                Max number of speakers that are processed by the model.
            hidden_size (int):
                Number of hidden units in sequence models and intermediate layers.
            dropout_rate (float):
@@ -54,7 +54,7 @@ def __init__(
            fc_d_model (int):
                Dimension of the embedding vectors.
            tf_d_model (int):
-                Dimension of the embedding vectors. 
+                Dimension of the embedding vectors.
        """
        super().__init__()
        self.fc_d_model = fc_d_model
@@ -93,7 +93,7 @@ def length_to_mask(self, context_embs):
     def forward_speaker_sigmoids(self, hidden_out):
         """
         A set of layers for predicting speaker probabilities with a sigmoid activation function.
-        
+
         Args:
             hidden_out (torch.Tensor): tensor of shape (batch_size, seq_len, hidden_size)
diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py
index 46bcf2f1a8c6..eddfd3254adc 100644
--- a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py
+++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py
@@ -28,7 +28,7 @@ def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) ->
         thres (float): The threshold value for discretizing the matrix values.
 
     Returns:
-        mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first 
+        mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first
             nonzero value in each row.
     """
     # Discretize the matrix to the specified maximum capacity
@@ -230,7 +230,7 @@ def get_mask_from_segments(
         speaker_to_idx_map (dict): A dictionary mapping speaker names to indices.
         num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default
         feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate
-        ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. 
+        ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch.
            Will be removed in the future.
 
     Returns:
@@ -318,27 +318,27 @@ def speaker_to_target(
     soft_thres: float = 0.5,
 ):
     """
-    Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape 
+    Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape
     (num_speaker, hidden_length).
     This function is needed for speaker diarization with ASR model trainings.
 
     Args:
-        a_cut (MonoCut, MixedCut): 
+        a_cut (MonoCut, MixedCut):
             Lhotse Cut instance which is MonoCut or MixedCut instance.
-        num_speakers (int): 
+        num_speakers (int):
             Max number of speakers for all cuts ("mask" dim0), 4 by default
-        num_sample_per_mel_frame (int): 
+        num_sample_per_mel_frame (int):
             Number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride)
-        num_mel_frame_per_asr_frame (int): 
+        num_mel_frame_per_asr_frame (int):
             Encoder subsampling_factor, 8 by default
-        spk_tar_all_zero (Tensor): 
+        spk_tar_all_zero (Tensor):
             Set to True gives all zero "mask"
-        boundary_segments (bool): 
-            Set to True to include segments containing the boundary of the cut, 
+        boundary_segments (bool):
+            Set to True to include segments containing the boundary of the cut,
             False by default for multi-speaker ASR training
-        soft_label (bool): 
-            Set to True to use soft label that enables values in [0, 1] range, 
+        soft_label (bool):
+            Set to True to use soft label that enables values in [0, 1] range,
             False by default and leads to binary labels.
-        ignore_num_spk_mismatch (bool): 
+        ignore_num_spk_mismatch (bool):
             This is a temporary solution to handle speaker mismatch. Will be removed in the future.
 
     Returns:
diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py
index b3d96f17ce8f..5773ddf4b79b 100644
--- a/nemo/collections/common/parts/preprocessing/collections.py
+++ b/nemo/collections/common/parts/preprocessing/collections.py
@@ -310,9 +310,13 @@ def __init__(
 class InstructionTuningAudioText(_Collection):
     """`AudioText` collector from asr structured json files."""
 
-    OUTPUT_TYPE = collections.namedtuple(typename='InstructionTuningText',
-        field_names=('id context context_type context_duration question '
-            'question_type answer answer_type answer_duration speaker'),)
+    OUTPUT_TYPE = collections.namedtuple(
+        typename='InstructionTuningText',
+        field_names=(
+            'id context context_type context_duration question '
+            'question_type answer answer_type answer_duration speaker'
+        ),
+    )
 
     def __init__(
         self,

From ab93b176a3a21ded11dacbd9da10f720c170e063 Mon Sep 17 00:00:00 2001
From: taejinp
Date: Fri, 15 Nov 2024 14:57:15 -0800
Subject: [PATCH 14/16] Removing unused varialbe in audio_to_diar_label.py

Signed-off-by: taejinp
---
 nemo/collections/asr/data/audio_to_diar_label.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py
index 7217686d3168..568708dc8c7a 100644
--- a/nemo/collections/asr/data/audio_to_diar_label.py
+++ b/nemo/collections/asr/data/audio_to_diar_label.py
@@ -687,7 +687,7 @@ def parse_rttm_multiscale(self, sample):
         seg_target = self.get_diar_target_labels_from_fr_target(uniq_id, fr_level_target)
         return seg_target
 
-    def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target):
+    def get_diar_target_labels_from_fr_target(self, uniq_id: str, fr_level_target: torch.Tensor) -> torch.Tensor:
         """
         Generate base-scale level binary diarization label from frame-level target matrix.
         For the given frame-level speaker target matrix fr_level_target,
         we count the number of frames that belong to each speaker and calculate
@@ -1222,7 +1222,6 @@ def __getitem__(self, index):
         else:
             session_len_sec = min(sample.duration, self.session_len_sec)
 
-        uniq_id = self.get_uniq_id_with_range(sample)
         audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec)
 
         # We should resolve the length mis-match from the round-off errors between these two variables:

From 7dea01b4b14f83934717cb5504942d82ea35b169 Mon Sep 17 00:00:00 2001
From: taejinp
Date: Fri, 15 Nov 2024 18:13:58 -0800
Subject: [PATCH 15/16] Fixed docstrings in training script

Signed-off-by: taejinp
---
 .../diarization/neural_diarizer/sortformer_diar_train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
index 1b4376c4f2c4..78c7acbaa6c2 100644
--- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
+++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
@@ -24,7 +24,7 @@
 """
 Example training session (single node training)
 
-python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' --config-name='' \
+python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \
     trainer.devices=1 \
     model.train_ds.manifest_filepath="" \
     model.validation_ds.manifest_filepath="" \

From 71d515faa608870f7da8be2297a4ae7b78d99f4d Mon Sep 17 00:00:00 2001
From: taejinp
Date: Fri, 15 Nov 2024 18:18:21 -0800
Subject: [PATCH 16/16] Line-too-long issue from Pylint fixed

Signed-off-by: taejinp
---
 .../diarization/neural_diarizer/sortformer_diar_train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
index 78c7acbaa6c2..8719b6463f70 100644
--- a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
+++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
@@ -24,7 +24,8 @@
 """
 Example training session (single node training)
 
-python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \
+python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' \
+    --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \
     trainer.devices=1 \
     model.train_ds.manifest_filepath="" \
     model.validation_ds.manifest_filepath="" \
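
Reviewer note (illustrative, not part of the patch series): the speaker_to_target docstring above describes turning RTTM-style segments into a (num_speakers, num_frames) target mask, where one encoder frame covers num_sample_per_mel_frame / sample_rate * num_mel_frame_per_asr_frame seconds (160 / 16000 * 8 = 0.08 s with the defaults). The sketch below is not the NeMo implementation; the helper name, the hardcoded 16 kHz sample rate, and the toy segment list are assumptions made only for illustration.

# segments_to_mask_sketch.py -- illustrative sketch only, not part of NeMo
import math

import torch


def segments_to_mask(segments, speaker_to_idx, num_speakers=4, total_sec=10.0,
                     sample_rate=16000, num_sample_per_mel_frame=160,
                     num_mel_frame_per_asr_frame=8):
    """Rasterize (start_sec, end_sec, speaker) segments into a (num_speakers, num_frames) mask."""
    # Seconds covered by one encoder (ASR) frame: 160 / 16000 * 8 = 0.08 s with the defaults.
    sec_per_frame = num_sample_per_mel_frame / sample_rate * num_mel_frame_per_asr_frame
    num_frames = math.ceil(total_sec / sec_per_frame)
    mask = torch.zeros(num_speakers, num_frames)
    for start_sec, end_sec, speaker in segments:
        spk = speaker_to_idx[speaker]
        start = int(start_sec // sec_per_frame)
        end = min(num_frames, math.ceil(end_sec / sec_per_frame))
        mask[spk, start:end] = 1.0  # hard (binary) labels; a soft-label variant would keep fractional values
    return mask


# Hypothetical RTTM-like segments: (start_sec, end_sec, speaker_label)
segments = [(0.0, 3.2, "spk0"), (2.5, 6.1, "spk1"), (7.0, 9.5, "spk0")]
mask = segments_to_mask(segments, {"spk0": 0, "spk1": 1})
print(mask.shape)  # torch.Size([4, 125])

Overlapped speech (2.5-3.2 s in the toy segments) simply activates two speaker rows over the same frames, which is how such targets represent concurrent speakers.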
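
Reviewer note (illustrative, not part of the patch series): the PostProcessingParams fields touched in the e2e_diarize_speech.py hunk (onset, offset, pad_onset, pad_offset, min_duration_on, min_duration_off) follow the TS-VAD-style postprocessing cited in the YAML configs. The sketch below shows that style of thresholding on per-frame speaker probabilities under an assumed 0.08 s frame step; it is not the NeMo implementation, the function name and toy tensor are made up, and the min_duration handling follows the common VAD convention (fill short gaps, then drop short segments), which may differ from NeMo's exact semantics.

# probs_to_segments_sketch.py -- illustrative sketch only, not part of NeMo
import torch


def probs_to_segments(probs, frame_sec=0.08, onset=0.5, offset=0.5,
                      pad_onset=0.0, pad_offset=0.0,
                      min_duration_on=0.0, min_duration_off=0.0):
    """Convert one speaker's per-frame probabilities into (start_sec, end_sec) segments."""
    segments = []
    active, start = False, 0.0
    for i, p in enumerate(probs.tolist() + [0.0]):  # trailing sentinel closes an open segment
        t = i * frame_sec
        if not active and p >= onset:   # onset threshold opens a segment
            active, start = True, t
        elif active and p < offset:     # offset threshold closes it, with onset/offset padding
            segments.append([max(0.0, start - pad_onset), t + pad_offset])
            active = False
    # Merge segments separated by gaps shorter than min_duration_off (fills short non-speech)
    merged = []
    for seg in segments:
        if merged and seg[0] - merged[-1][1] < min_duration_off:
            merged[-1][1] = seg[1]
        else:
            merged.append(seg)
    # Drop segments shorter than min_duration_on (removes short speech segments)
    return [(s, e) for s, e in merged if e - s >= min_duration_on]


probs = torch.tensor([0.1, 0.7, 0.8, 0.6, 0.2, 0.1, 0.9, 0.9, 0.3])
print(probs_to_segments(probs, onset=0.53, offset=0.49, pad_onset=0.23,
                        pad_offset=0.01, min_duration_on=0.42, min_duration_off=0.34))
# approximately [(0.0, 0.65)], up to float rounding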