diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml
new file mode 100644
index 000000000000..4a6d8f242d36
--- /dev/null
+++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml
@@ -0,0 +1,213 @@
+# Sortformer Diarizer is an end-to-end speaker diarization model that is based solely on a Transformer-encoder architecture.
+# Model name convention for Sortformer Diarizer: sortformer_diarizer_<loss_type>_<max_num_of_spks>spk-<version>.yaml
+# (Example) `sortformer_diarizer_hybrid_loss_4spk-v1.yaml`.
+# A Sortformer Diarizer checkpoint (.ckpt) or NeMo file (.nemo) contains a Fast Conformer encoder (NEST encoder); the pre-trained NEST model is loaded along with the Transformer encoder layers.
+# Example: a manifest line for training
+# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"}
+name: "SortFormerDiarizer"
+sample_rate: 16000
+num_workers: 18
+batch_size: 8
+
+model:
+  pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model
+  ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model
+  num_workers: ${num_workers} # Number of workers for data loading
+  fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder
+  tf_d_model: 192 # Hidden dimension size of the Transformer Encoder
+  max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4
+  session_len_sec: 90 # Maximum session length in seconds
+
+  train_ds:
+    manifest_filepath: ???
+    sample_rate: ${sample_rate}
+    num_spks: ${model.max_num_of_spks}
+    session_len_sec: ${model.session_len_sec}
+    soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity.
+    soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss
+    labels: null
+    batch_size: ${batch_size}
+    shuffle: True
+    num_workers: ${num_workers}
+    validation_mode: False
+    # lhotse config
+    use_lhotse: False
+    use_bucketing: True
+    num_buckets: 10
+    bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90]
+    pin_memory: True
+    min_duration: 80
+    max_duration: 90
+    batch_duration: 400
+    quadratic_duration: 1200
+    bucket_buffer_size: 20000
+    shuffle_buffer_size: 10000
+    window_stride: ${model.preprocessor.window_stride}
+    subsampling_factor: ${model.encoder.subsampling_factor}
+
+  validation_ds:
+    manifest_filepath: ???
+    is_tarred: False
+    tarred_audio_filepaths: null
+    sample_rate: ${sample_rate}
+    num_spks: ${model.max_num_of_spks}
+    session_len_sec: ${model.session_len_sec}
+    soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity.
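+    # Worked example (illustrative): targets are per-frame speaker-activity ratios; if 60% of a
+    # frame overlaps a speaker's ground-truth speech, its soft target is 0.6, and with
+    # soft_label_thres: 0.5 the binarized label for that frame becomes 1.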
+ soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + test_ds: + manifest_filepath: null + is_tarred: False + tarred_audio_filepaths: null + sample_rate: 16000 + num_spks: ${model.max_num_of_spks} + session_len_sec: ${model.session_len_sec} + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + seq_eval_mode: True + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + normalize: "per_feature" + window_size: 0.025 + sample_rate: ${sample_rate} + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + + sortformer_modules: + _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules + num_spks: ${model.max_num_of_spks} # Number of speakers per model. This is currently fixed at 4. + dropout_rate: 0.5 # Dropout rate + fc_d_model: ${model.fc_d_model} + tf_d_model: ${model.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 + n_layers: 18 + d_model: ${model.fc_d_model} + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + # Feed forward module's params + ff_expansion_factor: 4 + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + conv_context_size: null + # Regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + # Set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + transformer_encoder: + _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder + num_layers: 18 + hidden_size: ${model.tf_d_model} # Needs to be multiple of num_attention_heads + inner_size: 768 + num_attention_heads: 8 + attn_score_dropout: 0.5 + 
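+    # Note (illustrative): with window_stride 0.01 s and subsampling_factor 8 above, each
+    # Fast Conformer output step (and hence each Transformer encoder step / diarization frame)
+    # covers 0.08 s, i.e. 12.5 frames per second.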
+    attn_layer_dropout: 0.5
+    ffn_dropout: 0.5
+    hidden_act: relu
+    pre_ln: False
+    pre_ln_final_layer_norm: True
+
+  loss:
+    _target_: nemo.collections.asr.losses.bce_loss.BCELoss
+    weight: null # Weight for binary cross-entropy loss. Either `null` or a list input (e.g. [0.5,0.5])
+    reduction: mean
+
+  lr: 0.0001
+  optim:
+    name: adamw
+    lr: ${model.lr}
+    # optimizer arguments
+    betas: [0.9, 0.98]
+    weight_decay: 1e-3
+
+    sched:
+      name: InverseSquareRootAnnealing
+      warmup_steps: 2500
+      warmup_ratio: null
+      min_lr: 1e-06
+
+trainer:
+  devices: 1 # number of gpus (devices)
+  accelerator: gpu
+  max_epochs: 800
+  max_steps: -1 # computed at runtime if not set
+  num_nodes: 1
+  strategy: ddp_find_unused_parameters_true # Could be "ddp"
+  accumulate_grad_batches: 1
+  deterministic: True
+  enable_checkpointing: False
+  logger: False
+  log_every_n_steps: 1 # Interval of logging.
+  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
+
+exp_manager:
+  use_datetime_version: False
+  exp_dir: null
+  name: ${name}
+  resume_if_exists: True
+  resume_from_checkpoint: null # The path to a checkpoint file to continue the training; restores the whole state including the epoch, step, LR schedulers, apex, etc.
+  resume_ignore_no_checkpoint: True
+  create_tensorboard_logger: True
+  create_checkpoint_callback: True
+  create_wandb_logger: False
+  checkpoint_callback_params:
+    monitor: "val_f1_acc"
+    mode: "max"
+    save_top_k: 9
+    every_n_epochs: 1
+  wandb_logger_kwargs:
+    resume: True
+    name: null
+    project: null
\ No newline at end of file
diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml
new file mode 100644
index 000000000000..ebed4a649730
--- /dev/null
+++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml
@@ -0,0 +1,13 @@
+# Postprocessing parameters for timestamp outputs from speaker diarization models.
+# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
+# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
+# These parameters were optimized with the hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
+# These parameters were optimized on the development split of the DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477.
+# Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055.
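+# Worked example (illustrative): with the values below, speaker-activity probabilities are
+# binarized with onset 0.53 / offset 0.49, and a detected speech segment such as [1.20 s, 2.35 s]
+# is then padded by pad_onset / pad_offset to [1.20 - 0.23, 2.35 + 0.01] = [0.97 s, 2.36 s].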
+parameters:
+  onset: 0.53 # Onset threshold for detecting the beginning and end of speech
+  offset: 0.49 # Offset threshold for detecting the end of speech
+  pad_onset: 0.23 # Duration (in seconds) added before each speech segment
+  pad_offset: 0.01 # Duration (in seconds) added after each speech segment
+  min_duration_on: 0.42 # Threshold for small non-speech deletion
+  min_duration_off: 0.34 # Threshold for short speech segment deletion
\ No newline at end of file
diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml
new file mode 100644
index 000000000000..9beaff6e3c7c
--- /dev/null
+++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard-dev.yaml
@@ -0,0 +1,13 @@
+# Postprocessing parameters for timestamp outputs from speaker diarization models.
+# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
+# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
+# These parameters were optimized with the hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
+# These parameters were optimized on the CallHome dataset from NIST SRE 2000 Disc8, specifically on the split2 partition specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2.
+# Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649.
+parameters:
+  onset: 0.64 # Onset threshold for detecting the beginning and end of speech
+  offset: 0.74 # Offset threshold for detecting the end of speech
+  pad_onset: 0.06 # Duration (in seconds) added before each speech segment
+  pad_offset: 0.0 # Duration (in seconds) added after each speech segment
+  min_duration_on: 0.1 # Threshold for small non-speech deletion
+  min_duration_off: 0.15 # Threshold for short speech segment deletion
\ No newline at end of file
diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py
new file mode 100644
index 000000000000..65ba0226988a
--- /dev/null
+++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py
@@ -0,0 +1,424 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+""" + +python $BASEPATH/neural_diarizer/sortformer_diarization.py \ + model_path=/path/to/sortformer_model.nemo \ + batch_size=4 \ + dataset_manifest=/path/to/diarization_path_to_manifest.json + +""" +import logging +import os +import tempfile +from dataclasses import dataclass, is_dataclass +from typing import Dict, List, Optional, Union + +import optuna +import pytorch_lightning as pl +import torch +import yaml +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +from tqdm import tqdm + +from nemo.collections.asr.metrics.der import score_labels +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object +from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing +from nemo.core.config import hydra_runner + +seed_everything(42) +torch.backends.cudnn.deterministic = True + + +@dataclass +class PostProcessingParams: + """ + Postprocessing parameters for end-to-end speaker diarization models. + These parameters can significantly affect DER performance depending on the evaluation style and the dataset. + It is recommended to tune these parameters based on the evaluation style and the dataset + to achieve the desired DER performance. + """ + + onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech + offset: float = 0.5 # Offset threshold for detecting the end of a speech + pad_onset: float = 0.0 # Adding durations before each speech segment + pad_offset: float = 0.0 # Adding durations after each speech segment + min_duration_on: float = 0.0 # Threshold for small non-speech deletion + min_duration_off: float = 0.0 # Threshold for short speech segment deletion + + +@dataclass +class DiarizationConfig: + """Diarization configuration parameters for inference.""" + + model_path: Optional[str] = None # Path to a .nemo file + pretrained_name: Optional[str] = None # Name of a pretrained model + audio_dir: Optional[str] = None # Path to a directory which contains audio files + dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest + + postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations + no_der: bool = False + out_rttm_dir: Optional[str] = None + + # General configs + session_len_sec: float = -1 # End-to-end diarization session length in seconds + batch_size: int = 4 + num_workers: int = 0 + random_seed: Optional[int] = None # seed number going to be used in seed_everything() + bypass_postprocessing: bool = True # If True, postprocessing will be bypassed + + # Eval Settings: (0.25, False) should be default setting for sortformer eval. + collar: float = 0.25 # Collar in seconds for DER calculation + ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments + + # If `cuda` is a negative number, inference will be on CPU only. + cuda: Optional[int] = None + matmul_precision: str = "highest" # Literal["highest", "high", "medium"] + + # Optuna Config + launch_pp_optim: bool = False # If True, launch optimization process for postprocessing parameters + optuna_study_name: str = "optim_postprocessing" + optuna_temp_dir: str = "/tmp/optuna" + optuna_storage: str = f"sqlite:///{optuna_study_name}.db" + optuna_log_file: str = f"{optuna_study_name}.log" + optuna_n_trials: int = 100000 + + +def load_postprocessing_from_yaml(postprocessing_yaml): + """ + Load postprocessing parameters from a YAML file. 
+ + Args: + postprocessing_yaml (str): + Path to a YAML file for postprocessing configurations. + + Returns: + postprocessing_params (dataclass): + Postprocessing parameters loaded from the YAML file. + """ + # Add PostProcessingParams as a field + postprocessing_params = OmegaConf.structured(PostProcessingParams()) + if postprocessing_yaml is None: + logging.info( + f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied." + ) + else: + # Load postprocessing params from the provided YAML file + with open(postprocessing_yaml, 'r') as file: + yaml_params = yaml.safe_load(file)['parameters'] + # Update the postprocessing_params with the loaded values + logging.info(f"Postprocessing YAML file '{postprocessing_yaml}' has been loaded.") + for key, value in yaml_params.items(): + if hasattr(postprocessing_params, key): + setattr(postprocessing_params, key, value) + return postprocessing_params + + +def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: + """ + Suggests hyperparameters for postprocessing using Optuna. + See the following link for `trial` instance in Optuna framework. + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial + + Args: + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + + Returns: + PostProcessingParams: The updated postprocessing configuration with suggested hyperparameters. + """ + postprocessing_cfg.onset = trial.suggest_float("onset", 0.4, 0.8, step=0.01) + postprocessing_cfg.offset = trial.suggest_float("offset", 0.4, 0.9, step=0.01) + postprocessing_cfg.pad_onset = trial.suggest_float("pad_onset", 0.1, 0.5, step=0.01) + postprocessing_cfg.pad_offset = trial.suggest_float("pad_offset", 0.0, 0.2, step=0.01) + postprocessing_cfg.min_duration_on = trial.suggest_float("min_duration_on", 0.0, 0.75, step=0.01) + postprocessing_cfg.min_duration_off = trial.suggest_float("min_duration_off", 0.0, 0.75, step=0.01) + return postprocessing_cfg + + +def get_tensor_path(cfg: DiarizationConfig) -> str: + """ + Constructs the file path for saving or loading prediction tensors based on the configuration. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + + Returns: + str: The constructed file path for the prediction tensor. + """ + tensor_filename = os.path.basename(cfg.dataset_manifest).replace("manifest.", "").replace(".json", "") + model_base_path = os.path.dirname(cfg.model_path) + model_id = os.path.basename(cfg.model_path).replace(".ckpt", "").replace(".nemo", "") + bpath = f"{model_base_path}/pred_tensors" + if not os.path.exists(bpath): + os.makedirs(bpath) + tensor_path = f"{bpath}/__{model_id}__{tensor_filename}.pt" + return tensor_path + + +def diarization_objective( + trial: optuna.Trial, + postprocessing_cfg: PostProcessingParams, + temp_out_dir: str, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + diar_model_preds_total_list: List[torch.Tensor], + collar: float = 0.25, + ignore_overlap: bool = False, +) -> float: + """ + Objective function for Optuna hyperparameter optimization in speaker diarization. + + This function evaluates the diarization performance using a set of postprocessing parameters + suggested by Optuna. 
It converts prediction matrices to time-stamp segments, scores the + diarization results, and returns the Diarization Error Rate (DER) as the optimization metric. + + Args: + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + temp_out_dir (str): Temporary directory for storing intermediate outputs. + infer_audio_rttm_dict (Dict[str, Dict[str, str]]): Dictionary containing audio file paths, + offsets, durations, and RTTM file paths. + diar_model_preds_total_list (List[torch.Tensor]): List of prediction matrices containing + sigmoid values for each speaker. Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + collar (float, optional): Collar in seconds for DER calculation. Defaults to 0.25. + ignore_overlap (bool, optional): If True, DER will be calculated only for non-overlapping segments. + Defaults to False. + + Returns: + float: The Diarization Error Rate (DER) for the given set of postprocessing parameters. + """ + with tempfile.TemporaryDirectory(dir=temp_out_dir, prefix="Diar_PostProcessing_") as local_temp_out_dir: + if trial is not None: + postprocessing_cfg = optuna_suggest_params(postprocessing_cfg, trial) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments( + audio_rttm_map_dict=infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=False, + ) + metric, mapping_dict, itemized_errors = score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=collar, + ignore_overlap=ignore_overlap, + ) + der = abs(metric) + return der + + +def run_optuna_hyperparam_search( + cfg: DiarizationConfig, # type: DiarizationConfig + postprocessing_cfg: PostProcessingParams, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + preds_list: List[torch.Tensor], + temp_out_dir: str, +): + """ + Run Optuna hyperparameter optimization for speaker diarization. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + infer_audio_rttm_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. + Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + temp_out_dir (str): temporary directory for storing intermediate outputs. + """ + worker_function = lambda trial: diarization_objective( + trial=trial, + postprocessing_cfg=postprocessing_cfg, + temp_out_dir=temp_out_dir, + infer_audio_rttm_dict=infer_audio_rttm_dict, + diar_model_preds_total_list=preds_list, + collar=cfg.collar, + ) + study = optuna.create_study( + direction="minimize", study_name=cfg.optuna_study_name, storage=cfg.optuna_storage, load_if_exists=True + ) + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Setup the root logger. + if cfg.optuna_log_file is not None: + logger.addHandler(logging.FileHandler(cfg.optuna_log_file, mode="a")) + logger.addHandler(logging.StreamHandler()) + optuna.logging.enable_propagation() # Propagate logs to the root logger. 
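+    # Because the study uses a named SQLite storage with `load_if_exists=True`, the search can be
+    # resumed, or run in parallel from several processes that share the same `optuna_study_name`
+    # and `optuna_storage`, e.g. (illustrative command):
+    #   python e2e_diarize_speech.py model_path=/path/to/model.nemo \
+    #       dataset_manifest=/path/to/manifest.json launch_pp_optim=true optuna_n_trials=1000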
+ study.optimize(worker_function, n_trials=cfg.optuna_n_trials) + + +def convert_pred_mat_to_segments( + audio_rttm_map_dict: Dict[str, Dict[str, str]], + postprocessing_cfg, + batch_preds_list: List[torch.Tensor], + unit_10ms_frame_count: int = 8, + bypass_postprocessing: bool = False, + out_rttm_dir: str | None = None, +): + """ + Convert prediction matrix to time-stamp segments. + + Args: + audio_rttm_map_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + batch_preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. + Dimension: [(1, frames, num_speakers), ..., (1, frames, num_speakers)] + unit_10ms_frame_count (int, optional): number of 10ms segments in a frame. Defaults to 8. + bypass_postprocessing (bool, optional): if True, postprocessing will be bypassed. Defaults to False. + + Returns: + all_hypothesis (list): list of pyannote objects for each audio file. + all_reference (list): list of pyannote objects for each audio file. + all_uems (list): list of pyannote objects for each audio file. + """ + batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], [] + cfg_vad_params = OmegaConf.structured(postprocessing_cfg) + for sample_idx, (uniq_id, audio_rttm_values) in tqdm( + enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc="Running post-processing" + ): + spk_ts = [] + offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration'] + speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0) + speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] + for spk_id in range(speaker_assign_mat.shape[-1]): + ts_mat = ts_vad_post_processing( + speaker_assign_mat[:, spk_id], + cfg_vad_params=cfg_vad_params, + unit_10ms_frame_count=unit_10ms_frame_count, + bypass_postprocessing=bypass_postprocessing, + ) + ts_mat = ts_mat + offset + ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration)) + ts_seg_list = ts_mat.tolist() + speaker_timestamps[spk_id].extend(ts_seg_list) + spk_ts.append(ts_seg_list) + all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object( + speaker_timestamps, + uniq_id, + audio_rttm_values, + all_hypothesis, + all_reference, + all_uems, + out_rttm_dir, + ) + batch_pred_ts_segs.append(spk_ts) + return all_hypothesis, all_reference, all_uems + + +@hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) +def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: + """Main function for end-to-end speaker diarization inference.""" + for key in cfg: + cfg[key] = None if cfg[key] == 'None' else cfg[key] + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + if cfg.random_seed: + pl.seed_everything(cfg.random_seed) + + if cfg.model_path is None and cfg.pretrained_name is None: + raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!") + if cfg.audio_dir is None and cfg.dataset_manifest is None: + raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!") + + # setup GPU + torch.set_float32_matmul_precision(cfg.matmul_precision) + if cfg.cuda is None: + if torch.cuda.is_available(): + device = [0] # use 0th CUDA device + accelerator = 'gpu' + map_location = torch.device('cuda:0') + else: + device = 1 + accelerator = 'cpu' + map_location = torch.device('cpu') + else: + device = [cfg.cuda] + accelerator = 'gpu' + map_location = torch.device(f'cuda:{cfg.cuda}') + + if cfg.model_path.endswith(".ckpt"): + diar_model = 
SortformerEncLabelModel.load_from_checkpoint(
+            checkpoint_path=cfg.model_path, map_location=map_location, strict=False
+        )
+    elif cfg.model_path.endswith(".nemo"):
+        diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.model_path, map_location=map_location)
+    else:
+        raise ValueError("cfg.model_path must end with .ckpt or .nemo!")
+
+    diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec
+    trainer = pl.Trainer(devices=device, accelerator=accelerator)
+    diar_model.set_trainer(trainer)
+
+    diar_model = diar_model.eval()
+    diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest
+    infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest)
+    diar_model._cfg.test_ds.batch_size = cfg.batch_size
+
+    # Model setup for inference
+    diar_model._cfg.test_ds.num_workers = cfg.num_workers
+    diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds)
+
+    postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml)
+    tensor_path = get_tensor_path(cfg)
+
+    if os.path.exists(tensor_path):
+        logging.info(
+            f"A saved prediction tensor has been found. Loading the saved prediction tensors from {tensor_path}..."
+        )
+        diar_model_preds_total_list = torch.load(tensor_path)
+    else:
+        logging.info("No saved prediction tensors found. Running inference on the dataset...")
+        diar_model.test_batch()
+        diar_model_preds_total_list = diar_model.preds_total_list
+        torch.save(diar_model.preds_total_list, tensor_path)
+
+    if cfg.launch_pp_optim:
+        # Launch a hyperparameter optimization process if launch_pp_optim is True
+        run_optuna_hyperparam_search(
+            cfg=cfg,
+            postprocessing_cfg=postprocessing_cfg,
+            infer_audio_rttm_dict=infer_audio_rttm_dict,
+            preds_list=diar_model_preds_total_list,
+            temp_out_dir=cfg.optuna_temp_dir,
+        )
+
+    # Evaluation
+    if not cfg.no_der:
+        if cfg.out_rttm_dir is not None and not os.path.exists(cfg.out_rttm_dir):
+            os.mkdir(cfg.out_rttm_dir)
+        all_hyps, all_refs, all_uems = convert_pred_mat_to_segments(
+            infer_audio_rttm_dict,
+            postprocessing_cfg=postprocessing_cfg,
+            batch_preds_list=diar_model_preds_total_list,
+            unit_10ms_frame_count=8,
+            bypass_postprocessing=cfg.bypass_postprocessing,
+            out_rttm_dir=cfg.out_rttm_dir,
+        )
+        logging.info(f"Evaluating the model on {len(diar_model_preds_total_list)} audio segments...")
+        score_labels(
+            AUDIO_RTTM_MAP=infer_audio_rttm_dict,
+            all_reference=all_refs,
+            all_hypothesis=all_hyps,
+            all_uem=all_uems,
+            collar=cfg.collar,
+            ignore_overlap=cfg.ignore_overlap,
+        )
+        logging.info(f"PostProcessingParams: {postprocessing_cfg}")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
new file mode 100644
index 000000000000..8719b6463f70
--- /dev/null
+++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything + +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +""" +Example training session (single node training) + +python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' \ + --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \ + trainer.devices=1 \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ + exp_manager.name='sample_train' \ + exp_manager.exp_dir='./sortformer_diar_train' +""" + +seed_everything(42) + + +@hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml") +def main(cfg): + """Main function for training the sortformer diarizer model.""" + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + sortformer_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) + sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(sortformer_model) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index a1cb6d0f1bdc..568708dc8c7a 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -15,15 +15,20 @@ import os from collections import OrderedDict from statistics import mode -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple +import numpy as np import torch from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat -from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data -from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel +from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, get_subsegments, prepare_split_data +from nemo.collections.common.parts.preprocessing.collections import ( + DiarizationSpeechLabel, + EndtoEndDiarizationSpeechLabel, +) from nemo.core.classes import Dataset from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType +from nemo.utils import logging def get_scale_mapping_list(uniq_timestamps): @@ -62,7 +67,7 @@ def get_scale_mapping_list(uniq_timestamps): return scale_mapping_argmat -def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_spks=None): +def extract_seg_info_from_rttm(rttm_lines, mapping_dict=None, target_spks=None): """ Get RTTM lines containing speaker labels, start time and end time. target_spks contains two targeted speaker indices for creating groundtruth label files. Only speakers in target_spks variable will be @@ -76,7 +81,8 @@ def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_sp mapping_dict (dict): Mapping between the estimated speakers and the speakers in the ground-truth annotation. `mapping_dict` variable is only provided when the inference mode is running in sequence-eval mode. - Sequence eval mode uses the mapping between the estimated speakers and the speakers in ground-truth annotation. 
+            Sequence eval mode uses the mapping between the estimated speakers and the speakers
+            in ground-truth annotation.
 
     Returns:
         rttm_tup (tuple):
             Tuple containing lists of start time, end time and speaker labels.
@@ -108,12 +114,14 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec,
     Args:
         rttm_timestamps (list):
             List containing start and end time for each speaker segment label.
-            stt_list, end_list and speaker_list are contained.
+            `stt_list`, `end_list` and `speaker_list` are contained.
         frame_per_sec (int):
-            Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module.
+            Number of feature frames per second. This quantity is determined by
+            `window_stride` variable in preprocessing module.
         target_spks (tuple):
-            Speaker indices that are generated from combinations. If there are only one or two speakers,
-            only a single target_spks variable is generated.
+            Speaker indices that are generated from combinations.
+            If there are only one or two speakers,
+            only a single `target_spks` variable is generated.
 
     Returns:
         fr_level_target (torch.tensor):
@@ -124,7 +132,7 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec,
         return None
     else:
         sorted_speakers = sorted(list(set(speaker_list)))
-        total_fr_len = int(max(end_list) * (10 ** round_digits))
+        total_fr_len = int(max(end_list) * (10**round_digits))
         spk_num = max(len(sorted_speakers), min_spks)
         speaker_mapping_dict = {rttm_key: x_int for x_int, rttm_key in enumerate(sorted_speakers)}
         fr_level_target = torch.zeros(total_fr_len, spk_num)
@@ -139,6 +147,140 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec,
     return fr_level_target
 
 
+def get_subsegments_to_timestamps(
+    subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2
+):
+    """
+    Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`)
+    and rounding. A segment consists of many subsegments, and subsegments are equivalent to `frames`
+    in end-to-end speaker diarization models.
+
+    Args:
+        subsegments (List[Tuple[float, float]]):
+            A list of tuples where each tuple contains the start and end times of a subsegment
+            (frames in end-to-end models).
+            >>> subsegments = [[t0_start, t0_duration], [t1_start, t1_duration],..., [tN_start, tN_duration]]
+        feat_per_sec (int, optional):
+            The number of feature frames per second. Defaults to 100.
+        max_end_ts (float, optional):
+            The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None.
+        decimals (int, optional):
+            The number of decimal places to round the timestamps. Defaults to 2.
+
+    Example:
+        Segments starting from 0.0 and ending at 69.2 seconds.
+        If the hop length is 0.08 and the subsegment (frame) length is 0.16 seconds,
+        there are 864 = (69.2 - 0.16)/0.08 + 1 subsegments (frames in end-to-end models) in this segment.
+        >>> subsegments = [[0.0, 0.16], [0.08, 0.16], ..., [69.04, 0.16], [69.12, 0.08]]
+
+    Returns:
+        ts (torch.tensor):
+            A tensor containing the scaled and rounded timestamps for each subsegment.
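+    Example (illustrative, with the default feat_per_sec=100):
+        >>> get_subsegments_to_timestamps([[0.0, 0.16], [0.08, 0.16]])
+        tensor([[ 0, 16],
+                [ 8, 24]])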
+ """ + seg_ts = (torch.tensor(subsegments) * feat_per_sec).float() + ts_round = torch.round(seg_ts, decimals=decimals) + ts = ts_round.long() + ts[:, 1] = ts[:, 0] + ts[:, 1] + if max_end_ts is not None: + ts = np.clip(ts, 0, int(max_end_ts * feat_per_sec)) + return ts + + +def extract_frame_info_from_rttm(offset, duration, rttm_lines, round_digits=3): + """ + Extracts RTTM lines containing speaker labels, start time, and end time for a given audio segment. + + Args: + uniq_id (str): Unique identifier for the audio file and corresponding RTTM file. + offset (float): The starting time offset for the segment of interest. + duration (float): The duration of the segment of interest. + rttm_lines (list): List of RTTM lines in string format. + round_digits (int, optional): Number of decimal places to round the start and end times. Defaults to 3. + + Returns: + rttm_mat (tuple): A tuple containing lists of start times, end times, and speaker labels. + sess_to_global_spkids (dict): A mapping from session-specific speaker indices to global speaker identifiers. + """ + rttm_stt, rttm_end = offset, offset + duration + stt_list, end_list, speaker_list, speaker_set = [], [], [], [] + sess_to_global_spkids = dict() + + for rttm_line in rttm_lines: + start, end, speaker = convert_rttm_line(rttm_line) + + # Skip invalid RTTM lines where the start time is greater than the end time. + if start > end: + continue + + # Check if the RTTM segment overlaps with the specified segment of interest. + if (end > rttm_stt and start < rttm_end) or (start < rttm_end and end > rttm_stt): + # Adjust the start and end times to fit within the segment of interest. + start, end = max(start, rttm_stt), min(end, rttm_end) + else: + continue + + # Round the start and end times to the specified number of decimal places. + end_list.append(round(end, round_digits)) + stt_list.append(round(start, round_digits)) + + # Assign a unique index to each speaker and maintain a mapping. + if speaker not in speaker_set: + speaker_set.append(speaker) + speaker_list.append(speaker_set.index(speaker)) + sess_to_global_spkids.update({speaker_set.index(speaker): speaker}) + + rttm_mat = (stt_list, end_list, speaker_list) + return rttm_mat, sess_to_global_spkids + + +def get_frame_targets_from_rttm( + rttm_timestamps: list, + offset: float, + duration: float, + round_digits: int, + feat_per_sec: int, + max_spks: int, +): + """ + Create a multi-dimensional vector sequence containing speaker timestamp information in RTTM. + The unit-length is the frame shift length of the acoustic feature. The feature-level annotations + `feat_level_target` will later be converted to base-segment level diarization label. + + Args: + rttm_timestamps (list): + List containing start and end time for each speaker segment label. + stt_list, end_list and speaker_list are contained. + feat_per_sec (int): + Number of feature frames per second. + This quantity is determined by window_stride variable in preprocessing module. + target_spks (tuple): + Speaker indices that are generated from combinations. If there are only one or two speakers, + only a single target_spks variable is generated. + + Returns: + feat_level_target (torch.tensor): + Tensor containing label for each feature level frame. 
+ """ + stt_list, end_list, speaker_list = rttm_timestamps + sorted_speakers = sorted(list(set(speaker_list))) + total_fr_len = int(duration * feat_per_sec) + if len(sorted_speakers) > max_spks: + logging.warning( + f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: " + f"{max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!" + ) + feat_level_target = torch.zeros(total_fr_len, max_spks) + for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)): + if end < offset or stt > offset + duration: + continue + stt, end = max(offset, stt), min(offset + duration, end) + spk = spk_rttm_key + if spk < max_spks: + stt_fr, end_fr = int((stt - offset) * feat_per_sec), int((end - offset) * feat_per_sec) + feat_level_target[stt_fr:end_fr, spk] = 1 + return feat_level_target + + class _AudioMSDDTrainDataset(Dataset): """ Dataset class that loads a json file containing paths to audio files, @@ -214,7 +356,7 @@ def __init__( self.multiscale_args_dict = multiscale_args_dict self.emb_dir = emb_dir self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer self.max_spks = 2 @@ -224,7 +366,10 @@ def __init__( self.global_rank = global_rank self.manifest_filepath = manifest_filepath self.multiscale_timestamp_dict = prepare_split_data( - self.manifest_filepath, self.emb_dir, self.multiscale_args_dict, self.global_rank, + self.manifest_filepath, + self.emb_dir, + self.multiscale_args_dict, + self.global_rank, ) def __len__(self): @@ -241,7 +386,7 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): Unique sample ID for training. base_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for the base-scale segments. - + Returns: per_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for each segment in each scale. @@ -270,15 +415,17 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): def get_diar_target_labels(self, uniq_id, sample, fr_level_target): """ - Convert frame-level diarization target variable into segment-level target variable. Since the granularity is reduced - from frame level (10ms) to segment level (100ms~500ms), we need a threshold value, `soft_label_thres`, which determines - the label of each segment based on the overlap between a segment range (start and end time) and the frame-level target variable. + Convert frame-level diarization target variable into segment-level target variable. + Since the granularity is reduced from frame level (10ms) to segment level (100ms~500ms), + we need a threshold value, `soft_label_thres`, which determines the label of each segment + based on the overlap between a segment range (start and end time) and the frame-level target variable. Args: uniq_id (str): Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file. sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. fr_level_target (torch.tensor): Tensor containing label for each feature-level frame. @@ -286,13 +433,14 @@ def get_diar_target_labels(self, uniq_id, sample, fr_level_target): seg_target (torch.tensor): Tensor containing binary speaker labels for base-scale segments. 
base_clus_label (torch.tensor): - Representative speaker label for each segment. This variable only has one speaker label for each base-scale segment. + Representative speaker label for each segment. This variable only has one speaker label + for each base-scale segment. -1 means that there is no corresponding speaker in the target_spks tuple. """ seg_target_list, base_clus_label = [], [] self.scale_n = len(self.multiscale_timestamp_dict[uniq_id]['scale_dict']) subseg_time_stamp_list = self.multiscale_timestamp_dict[uniq_id]["scale_dict"][self.scale_n - 1]["time_stamps"] - for (seg_stt, seg_end) in subseg_time_stamp_list: + for seg_stt, seg_end in subseg_time_stamp_list: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec_sess = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -321,7 +469,8 @@ def parse_rttm_for_ms_targets(self, sample): Args: sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, only a single target_spks tuple is generated. @@ -336,9 +485,10 @@ def parse_rttm_for_ms_targets(self, sample): multiscale embeddings to form an input matrix for the MSDD model. """ - rttm_lines = open(sample.rttm_file).readlines() + with open(sample.rttm_file, 'r') as file: + rttm_lines = file.readlines() uniq_id = self.get_uniq_id_with_range(sample) - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, target_spks=sample.target_spks ) @@ -370,14 +520,14 @@ def get_uniq_id_with_range(self, sample, deci=3): def get_ms_seg_timestamps(self, sample): """ - Get start and end time of segments in each scale. + Get start and end time of each diarization frame. Args: sample: `DiarizationSpeechLabel` instance from preprocessing.collections Returns: ms_seg_timestamps (torch.tensor): - Tensor containing Multiscale segment timestamps. + Tensor containing timestamps for each frame. ms_seg_counts (torch.tensor): Number of segments for each scale. This information is used for reshaping embedding batch during forward propagation. @@ -441,7 +591,8 @@ class _AudioMSDDInferDataset(Dataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, + scale mapping and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. 
soft_label_thres (float): @@ -496,7 +647,7 @@ def __init__( self.emb_seq = emb_seq self.clus_label_dict = clus_label_dict self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.frame_per_sec = int(1 / window_stride) self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer @@ -529,20 +680,20 @@ def parse_rttm_multiscale(self, sample): rttm_lines = open(sample.rttm_file).readlines() uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] mapping_dict = self.emb_dict[max(self.emb_dict.keys())][uniq_id]['mapping'] - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict, sample.target_spks) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines, mapping_dict, sample.target_spks) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, sample.target_spks ) seg_target = self.get_diar_target_labels_from_fr_target(uniq_id, fr_level_target) return seg_target - def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): + def get_diar_target_labels_from_fr_target(self, uniq_id: str, fr_level_target: torch.Tensor) -> torch.Tensor: """ Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate - ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared with `soft_label_thres` - to determine whether a label vector should contain 0 or 1 for each speaker bin. Note that seg_target variable has - dimension of (number of base-scale segments x 2) dimension. + ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared + with `soft_label_thres` to determine whether a label vector should contain 0 or 1 for each speaker bin. + Note that seg_target variable has dimension of (number of base-scale segments x 2) dimension. Example of seg_target: [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] @@ -562,7 +713,7 @@ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): return None else: seg_target_list = [] - for (seg_stt, seg_end, label_int) in self.clus_label_dict[uniq_id]: + for seg_stt, seg_end, label_int in self.clus_label_dict[uniq_id]: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -588,7 +739,8 @@ def __getitem__(self, index): if avg_embs.shape[2] > self.max_spks: raise ValueError( - f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to self.max_num_speakers {self.max_spks}" + f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to " + f"self.max_num_speakers {self.max_spks}" ) feats = [] @@ -682,7 +834,8 @@ def _msdd_train_collate_fn(self, batch): def _msdd_infer_collate_fn(self, batch): """ - Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings. + Collate batch of feats (speaker embeddings), feature lengths, target label sequences + and cluster-average embeddings. 
Args: batch (tuple): @@ -784,6 +937,7 @@ def __init__( ) def msdd_train_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for training.""" return _msdd_train_collate_fn(self, batch) @@ -805,11 +959,13 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, scale mapping + and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): - Threshold that determines speaker labels of segments depending on the overlap with groundtruth speaker timestamps. + Threshold that determines speaker labels of segments depending on the overlap + with groundtruth speaker timestamps. featurizer: Featurizer instance for generating features from raw waveform. use_single_scale_clus (bool): @@ -817,11 +973,12 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. window_stride (float): - Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. + Window stride for acoustic feature. This value is used for calculating the numbers of + feature-level frames. pairwise_infer (bool): - If True, this Dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the MSDD to merge the individual results. + If True, this Dataset class operates in inference mode. In inference mode, a set of speakers + in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the MSDD to merge the individual results. """ def __init__( @@ -850,4 +1007,366 @@ def __init__( ) def msdd_infer_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for inference.""" return _msdd_infer_collate_fn(self, batch) + + +class _AudioToSpeechE2ESpkDiarDataset(Dataset): + """ + Dataset class that loads a json file containing paths to audio files, + RTTM files and number of speakers. This Dataset class is designed for + training or fine-tuning speaker embedding extractor and diarization decoder + at the same time. + + Example: + {"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm} + ... + {"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm} + + Args: + manifest_filepath (str): + Path to input manifest json files. + multiargs_dict (dict): + Dictionary containing the parameters for multiscale segmentation and clustering. + soft_label_thres (float): + Threshold that determines the label of each segment based on RTTM file information. + featurizer: + Featurizer instance for generating audio_signal from the raw waveform. + window_stride (float): + Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. 
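+        Note (illustrative):
+            With window_stride=0.01 and the default subsampling_factor=8, each diarization frame
+            produced by this dataset covers 8 * 0.01 = 0.08 seconds (`self.diar_frame_length`).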
+ """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports.""" + output_types = { + "audio_signal": NeuralType(('B', 'T'), AudioSignal()), + "audio_length": NeuralType(('B'), LengthsType()), + "targets": NeuralType(('B', 'T', 'C'), ProbsType()), + "target_len": NeuralType(('B', 'C'), LengthsType()), + } + + return output_types + + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride: float, + min_subsegment_duration: float = 0.03, + global_rank: int = 0, + dtype=torch.float16, + round_digits: int = 2, + soft_targets: bool = False, + subsampling_factor: int = 8, + ): + super().__init__() + self.collection = EndtoEndDiarizationSpeechLabel( + manifests_files=manifest_filepath.split(','), + round_digits=round_digits, + ) + self.featurizer = featurizer + self.round_digits = round_digits + self.feat_per_sec = int(1 / window_stride) + self.diar_frame_length = round(subsampling_factor * window_stride, round_digits) + self.session_len_sec = session_len_sec + self.soft_label_thres = soft_label_thres + self.max_spks = num_spks + self.min_subsegment_duration = min_subsegment_duration + self.dtype = dtype + self.use_asr_style_frame_count = True + self.soft_targets = soft_targets + self.round_digits = 2 + self.floor_decimal = 10**self.round_digits + + def __len__(self): + return len(self.collection) + + def get_uniq_id_with_range(self, sample, deci=3): + """ + Generate unique training sample ID from unique file ID, offset and duration. The start-end time added + unique ID is required for identifying the sample since multiple short audio samples are generated from a single + audio file. The start time and end time of the audio stream uses millisecond units if `deci=3`. + + Args: + sample: + `DiarizationSpeechLabel` instance from collections. + + Returns: + uniq_id (str): + Unique sample ID which includes start and end time of the audio stream. + Example: abc1001_3122_6458 + """ + bare_uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] + offset = str(int(round(sample.offset, deci) * pow(10, deci))) + endtime = str(int(round(sample.offset + sample.duration, deci) * pow(10, deci))) + uniq_id = f"{bare_uniq_id}_{offset}_{endtime}" + return uniq_id + + def parse_rttm_for_targets_and_lens(self, rttm_file, offset, duration, target_len): + """ + Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file. + This function converts (start, end, speaker_id) format into base-scale (the finest scale) segment level + diarization label in a matrix form. 
+
+        Example of seg_target:
+            [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]]
+        """
+        with open(rttm_file, 'r') as f:
+            rttm_lines = f.readlines()
+
+        rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(offset, duration, rttm_lines)
+
+        fr_level_target = get_frame_targets_from_rttm(
+            rttm_timestamps=rttm_timestamps,
+            offset=offset,
+            duration=duration,
+            round_digits=self.round_digits,
+            feat_per_sec=self.feat_per_sec,
+            max_spks=self.max_spks,
+        )
+
+        soft_target_seg = self.get_soft_targets_seg(feat_level_target=fr_level_target, target_len=target_len)
+        if self.soft_targets:
+            step_target = soft_target_seg
+        else:
+            step_target = (soft_target_seg >= self.soft_label_thres).float()
+        return step_target
+
+    def get_soft_targets_seg(self, feat_level_target, target_len):
+        """
+        Generate the final targets for the actual diarization step.
+        Here, "frame level" refers to the step level, which is also referred to as segments;
+        following the original paper, we call these steps "frames".
+
+        Args:
+            feat_level_target (torch.tensor):
+                Tensor variable containing hard-labels of speaker activity in each feature-level segment.
+            target_len (torch.tensor):
+                Number of step-level segments (diarization frames) in the session.
+
+        Returns:
+            soft_target_seg (torch.tensor):
+                Tensor variable containing soft-labels of speaker activity in each step-level segment.
+        """
+        num_seg = torch.max(target_len)
+        targets = torch.zeros(num_seg, self.max_spks)
+        stride = int(self.feat_per_sec * self.diar_frame_length)
+        for index in range(num_seg):
+            if index == 0:
+                seg_stt_feat = 0
+            else:
+                seg_stt_feat = stride * index - 1 - int(stride / 2)
+            if index == num_seg - 1:
+                seg_end_feat = feat_level_target.shape[0]
+            else:
+                seg_end_feat = stride * index - 1 + int(stride / 2)
+            targets[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0)
+        return targets
+
+    def get_segment_timestamps(
+        self,
+        duration: float,
+        offset: float = 0,
+        sample_rate: int = 16000,
+    ):
+        """
+        Compute the frame-level (subsegment) timestamps of a session and count them.
+
+        Args:
+            duration (float):
+                Duration of the session in seconds.
+            offset (float):
+                Start time of the session within the audio file, in seconds.
+            sample_rate (int):
+                Sample rate of the audio signal.
+
+        Returns:
+            target_len (torch.tensor):
+                Single-element tensor containing the number of diarization frames (subsegments)
+                in the session; used for reshaping the inputs to the end-to-end model
+                during forward propagation.
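+        Example (illustrative):
+            For a 90-second session with a 0.08-second frame shift, target_len is approximately
+            torch.tensor([1125]) (about 90 / 0.08 frames; the exact count depends on
+            `use_asr_style_frame_count`).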
+ """ + subsegments = get_subsegments( + offset=offset, + window=round(self.diar_frame_length * 2, self.round_digits), + shift=self.diar_frame_length, + duration=duration, + min_subsegment_duration=self.min_subsegment_duration, + use_asr_style_frame_count=self.use_asr_style_frame_count, + sample_rate=sample_rate, + feat_per_sec=self.feat_per_sec, + ) + if self.use_asr_style_frame_count: + effective_dur = ( + np.ceil((1 + duration * sample_rate) / int(sample_rate / self.feat_per_sec)).astype(int) + / self.feat_per_sec + ) + else: + effective_dur = duration + ts_tensor = get_subsegments_to_timestamps( + subsegments, self.feat_per_sec, decimals=2, max_end_ts=(offset + effective_dur) + ) + target_len = torch.tensor([ts_tensor.shape[0]]) + return target_len + + def __getitem__(self, index): + sample = self.collection[index] + if sample.offset is None: + sample.offset = 0 + offset = sample.offset + if self.session_len_sec < 0: + session_len_sec = sample.duration + else: + session_len_sec = min(sample.duration, self.session_len_sec) + + audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) + + # We should resolve the length mis-match from the round-off errors between these two variables: + # `session_len_sec` and `audio_signal.shape[0]` + session_len_sec = ( + np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal + ) + audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)] + + audio_signal_length = torch.tensor(audio_signal.shape[0]).long() + audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') + target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) + targets = self.parse_rttm_for_targets_and_lens( + rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len + ) + return audio_signal, audio_signal_length, targets, target_len + + +def _eesd_train_collate_fn(self, batch): + """ + Collate a batch of variables needed for training the end-to-end speaker diarization (EESD) model + from raw waveforms to diarization labels. The following variables are included in the training/validation batch: + + Args: + batch (tuple): + A tuple containing the variables for diarization training. + + Returns: + audio_signal (torch.Tensor): + A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` + in the input manifest file. + feature_length (torch.Tensor): + A tensor containing the lengths of the raw waveform samples. + targets (torch.Tensor): + Groundtruth speaker labels for the given input embedding sequence. + target_lens (torch.Tensor): + A tensor containing the number of segments for each sample in the batch, necessary for + reshaping inputs to the EESD model. 
+ """ + packed_batch = list(zip(*batch)) + audio_signal, feature_length, targets, target_len = packed_batch + audio_signal_list, feature_length_list = [], [] + target_len_list, targets_list = [], [] + + max_raw_feat_len = max([x.shape[0] for x in audio_signal]) + max_target_len = max([x.shape[0] for x in targets]) + if max([len(feat.shape) for feat in audio_signal]) > 1: + max_ch = max([feat.shape[1] for feat in audio_signal]) + else: + max_ch = 1 + for feat, feat_len, tgt, segment_ct in batch: + seq_len = tgt.shape[0] + if len(feat.shape) > 1: + pad_feat = (0, 0, 0, max_raw_feat_len - feat.shape[0]) + else: + pad_feat = (0, max_raw_feat_len - feat.shape[0]) + if feat.shape[0] < feat_len: + feat_len_pad = feat_len - feat.shape[0] + feat = torch.nn.functional.pad(feat, (0, feat_len_pad)) + pad_tgt = (0, 0, 0, max_target_len - seq_len) + padded_feat = torch.nn.functional.pad(feat, pad_feat) + padded_tgt = torch.nn.functional.pad(tgt, pad_tgt) + if max_ch > 1 and padded_feat.shape[1] < max_ch: + feat_ch_pad = max_ch - padded_feat.shape[1] + padded_feat = torch.nn.functional.pad(padded_feat, (0, feat_ch_pad)) + audio_signal_list.append(padded_feat) + feature_length_list.append(feat_len.clone().detach()) + target_len_list.append(segment_ct.clone().detach()) + targets_list.append(padded_tgt) + audio_signal = torch.stack(audio_signal_list) + feature_length = torch.stack(feature_length_list) + target_lens = torch.stack(target_len_list) + targets = torch.stack(targets_list) + return audio_signal, feature_length, targets, target_lens + + +class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): + """ + Dataset class for loading a JSON file containing paths to audio files, + RTTM (Rich Transcription Time Marked) files, and the number of speakers. + This class is designed for training or fine-tuning a speaker embedding + extractor and diarization decoder simultaneously. + + The JSON manifest file should have entries in the following format: + + Example: + { + "audio_filepath": "/path/to/audio_0.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm" + } + ... + { + "audio_filepath": "/path/to/audio_n.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm" + } + + Args: + manifest_filepath (str): + Path to the input manifest JSON file containing paths to audio and RTTM files. + soft_label_thres (float): + Threshold for assigning soft labels to segments based on RTTM file information. + session_len_sec (float): + Duration of each session (in seconds) for training or fine-tuning. + num_spks (int): + Number of speakers in the audio files. + featurizer: + Instance of a featurizer for generating features from the raw waveform. + window_stride (float): + Window stride (in seconds) for extracting acoustic features, used to calculate + the number of feature frames. + global_rank (int): + Global rank of the current process (used for distributed training). + soft_targets (bool): + Whether or not to use soft targets during training. + + Methods: + eesd_train_collate_fn(batch): + Collates a batch of data for end-to-end speaker diarization training. 
+ """ + + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride, + global_rank: int, + soft_targets: bool, + ): + super().__init__( + manifest_filepath=manifest_filepath, + soft_label_thres=soft_label_thres, + session_len_sec=session_len_sec, + num_spks=num_spks, + featurizer=featurizer, + window_stride=window_stride, + global_rank=global_rank, + soft_targets=soft_targets, + ) + + def eesd_train_collate_fn(self, batch): + """Collate a batch of data for end-to-end speaker diarization training.""" + return _eesd_train_collate_fn(self, batch) diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py new file mode 100644 index 000000000000..0839b63954f0 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple + +import torch.utils.data +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_matrices + +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( + get_hidden_length_from_sample_length, + speaker_to_target, +) +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType + + +class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): + """ + This dataset is based on diarization datasets from audio_to_eesd_label.py. + Unlike native NeMo datasets, Lhotse dataset defines only the mapping from + a CutSet (meta-data) to a mini-batch with PyTorch tensors. + Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any). + Managing data, sampling, de-duplication across workers/nodes etc. is all handled + by Lhotse samplers instead. 
+ """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Define the output types of the dataset.""" + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'targets': NeuralType(('B', 'T', 'N'), LabelsType()), + 'target_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + } + + def __init__(self, cfg): + super().__init__() + self.load_audio = AudioSamples(fault_tolerant=True) + self.cfg = cfg + self.num_speakers = self.cfg.get('num_speakers', 4) + self.num_sample_per_mel_frame = int( + self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000) + ) # 160 + self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) + self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero', False) + + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: + audio, audio_lens, cuts = self.load_audio(cuts) + speaker_activities = [] + for cut in cuts: + speaker_activity = speaker_to_target( + a_cut=cut, + num_speakers=self.num_speakers, + num_sample_per_mel_frame=self.num_sample_per_mel_frame, + num_mel_frame_per_asr_frame=self.num_mel_frame_per_target_frame, + spk_tar_all_zero=self.spk_tar_all_zero, + boundary_segments=True, + ) + speaker_activities.append(speaker_activity) + targets = collate_matrices(speaker_activities).to(audio.dtype) + target_lens_list = [] + for audio_len in audio_lens: + target_fr_len = get_hidden_length_from_sample_length( + audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame + ) + target_lens_list.append([target_fr_len]) + target_lens = torch.tensor(target_lens_list) + + return audio, audio_lens, targets, target_lens diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index fc5cded970d0..7496f700341f 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -36,12 +36,12 @@ def get_partial_ref_labels(pred_labels: List[str], ref_labels: List[str]) -> List[str]: """ - For evaluation of online diarization performance, generate partial reference labels + For evaluation of online diarization performance, generate partial reference labels from the last prediction time. Args: pred_labels (list[str]): list of partial prediction labels - ref_labels (list[str]): list of full reference labels + ref_labels (list[str]): list of full reference labels Returns: ref_labels_out (list[str]): list of partial reference labels @@ -84,8 +84,8 @@ def get_online_DER_stats( For evaluation of online diarization performance, add cumulative, average, and maximum DER/CER. 
Args: - DER (float): Diarization Error Rate from the start to the current point - CER (float): Confusion Error Rate from the start to the current point + DER (float): Diarization Error Rate from the start to the current point + CER (float): Confusion Error Rate from the start to the current point FA (float): False Alarm from the start to the current point MISS (float): Miss rate from the start to the current point diar_eval_count (int): Number of evaluation sessions @@ -123,14 +123,20 @@ def uem_timeline_from_file(uem_file, uniq_name=''): lines = f.readlines() for line in lines: line = line.strip() - speaker_id, channel, start_time, end_time = line.split() + _, _, start_time, end_time = line.split() timeline.add(Segment(float(start_time), float(end_time))) return timeline def score_labels( - AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True, verbose: bool = True + AUDIO_RTTM_MAP, + all_reference, + all_hypothesis, + all_uem: List[List[float]] = None, + collar: float = 0.25, + ignore_overlap: bool = True, + verbose: bool = True, ) -> Optional[Tuple[DiarizationErrorRate, Dict]]: """ Calculate DER, CER, FA and MISS rate from hypotheses and references. Hypothesis results are @@ -139,14 +145,21 @@ def score_labels( Args: - AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath - all_reference (list[uniq_name,Annotation]): reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation - verbose (bool): Warns if RTTM file is not found. + AUDIO_RTTM_MAP (dict): + Dictionary containing information provided from manifestpath + all_reference (list[uniq_name,Annotation]): + Reference annotations for score calculation + all_hypothesis (list[uniq_name,Annotation]): + Hypothesis annotations for score calculation + verbose (bool): + Warns if RTTM file is not found. Returns: - metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. - mapping (dict): Mapping dict containing the mapping speaker label for each audio input + metric (pyannote.DiarizationErrorRate): + Pyannote Diarization Error Rate metric object. + This object contains detailed scores of each audiofile. + mapping (dict): + Mapping dict containing the mapping speaker label for each audio input < Caveat > Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of @@ -157,33 +170,51 @@ def score_labels( if len(all_reference) == len(all_hypothesis): metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap) - mapping_dict = {} - for (reference, hypothesis) in zip(all_reference, all_hypothesis): + mapping_dict, correct_spk_count = {}, 0 + for idx, (reference, hypothesis) in enumerate(zip(all_reference, all_hypothesis)): ref_key, ref_labels = reference _, hyp_labels = hypothesis - uem = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) - if uem is not None: - uem = uem_timeline_from_file(uem_file=uem, uniq_name=ref_key) - metric(ref_labels, hyp_labels, uem=uem, detailed=True) + if len(ref_labels.labels()) == len(hyp_labels.labels()): + correct_spk_count += 1 + if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): + logging.info( + f"Wrong Spk. 
Count with uniq_id:...{ref_key[-10:]}, " + f"Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" + ) + uem_obj = None + if all_uem is not None: + metric(ref_labels, hyp_labels, uem=all_uem[idx], detailed=True) + elif AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) is not None: + uem_file = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) + uem_obj = uem_timeline_from_file(uem_file=uem_file, uniq_name=ref_key) + metric(ref_labels, hyp_labels, uem=uem_obj, detailed=True) + else: + metric(ref_labels, hyp_labels, detailed=True) mapping_dict[ref_key] = metric.optimal_mapping(ref_labels, hyp_labels) + spk_count_acc = correct_spk_count / len(all_reference) DER = abs(metric) + if metric['total'] == 0: + raise ValueError("Total evaluation time is 0. Abort.") CER = metric['confusion'] / metric['total'] FA = metric['false alarm'] / metric['total'] MISS = metric['missed detection'] / metric['total'] + itemized_errors = (DER, CER, FA, MISS) + if verbose: + logging.info(f"\n{metric.report()}") logging.info( - "Cumulative Results for collar {} sec and ignore_overlap {}: \n FA: {:.4f}\t MISS {:.4f}\t \ - Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format( - collar, ignore_overlap, FA, MISS, DER, CER - ) + f"Cumulative Results for collar {collar} sec and ignore_overlap {ignore_overlap}: \n" + f"| FA: {FA:.4f} | MISS: {MISS:.4f} | CER: {CER:.4f} | DER: {DER:.4f} | " + f"Spk. Count Acc. {spk_count_acc:.4f}\n" ) return metric, mapping_dict, itemized_errors elif verbose: logging.warning( - "Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate" + "Check if each ground truth RTTMs were present in the provided manifest file. " + "Skipping calculation of Diariazation Error Rate" ) return None @@ -365,7 +396,7 @@ def calculate_session_cpWER( # Calculate WER for each speaker in hypothesis with reference # There are (number of hyp speakers) x (number of ref speakers) combinations lsa_wer_list = [] - for (spk_hyp_trans, spk_ref_trans) in all_pairs: + for spk_hyp_trans, spk_ref_trans in all_pairs: spk_wer = word_error_rate(hypotheses=[spk_hyp_trans], references=[spk_ref_trans]) lsa_wer_list.append(spk_wer) @@ -419,7 +450,7 @@ def concat_perm_word_error_rate( f"{len(spk_hypotheses)} and {len(spk_references)} correspondingly" ) cpWER_values, hyps_spk, refs_spk = [], [], [] - for (spk_hypothesis, spk_reference) in zip(spk_hypotheses, spk_references): + for spk_hypothesis, spk_reference in zip(spk_hypotheses, spk_references): cpWER, min_hypothesis, concat_reference = calculate_session_cpWER(spk_hypothesis, spk_reference) cpWER_values.append(cpWER) hyps_spk.append(min_hypothesis) diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 8cc21c53ad82..3a99769ebd25 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -73,13 +73,29 @@ def on_validation_epoch_end(self): def __init__(self, dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) - self.total_correct_counts = 0 - self.total_sample_counts = 0 - self.true_positive_count = 0 - self.false_positive_count = 0 - self.false_negative_count = 0 + self.add_state("total_correct_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("total_sample_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("true_positive_count", default=torch.tensor(0), 
dist_reduce_fx='sum', persistent=False) + self.add_state("false_positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("false_negative_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.eps = 1e-6 + + def update( + self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False + ) -> torch.Tensor: + """ + Update the metric with the given predictions, targets, and signal lengths to the metric instance. + + Args: + preds (torch.Tensor): Predicted values. + targets (torch.Tensor): Target values. + signal_lengths (torch.Tensor): Length of each sequence in the batch input. + cumulative (bool): Whether to accumulate the values over time. - def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor) -> torch.Tensor: + Returns: + f1_score (torch.Tensor): F1 score calculated from the predicted value and binarized target values. + """ with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] targets_list = [targets[k, : signal_lengths[k], :] for k in range(targets.shape[0])] @@ -91,22 +107,30 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: tor self.positive = self.preds.round().bool() == 1 self.negative = self.preds.round().bool() == 0 - self.positive_count = torch.sum(self.preds.round().bool() == True) - self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) - self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) - self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) - - self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) - self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + if cumulative: + self.positive_count += torch.sum(self.preds.round().bool() == True) + self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + else: + self.positive_count = torch.sum(self.preds.round().bool() == True) + self.true_positive_count = torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count = torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count = torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts = torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts = torch.prod(torch.tensor(self.targets.shape)) def compute(self): """ Compute F1 score from the accumulated values. Return -1 if the F1 score is NaN. 
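
A usage sketch for the updated metric API above, with random placeholder tensors: `update()` is invoked through the torchmetrics `__call__`, and `compute()` now returns an `(f1, precision, recall)` tuple instead of a single F1 value.

```python
import torch
from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy

metric = MultiBinaryAccuracy()
preds = torch.rand(2, 120, 4)                    # (batch, frames, speakers) sigmoid outputs
targets = (torch.rand(2, 120, 4) > 0.5).float()  # binarized ground-truth speaker activities
lengths = torch.tensor([120, 100])               # valid frame count per sample

metric(preds, targets, lengths)                  # calls update() under the hood
f1, precision, recall = metric.compute()
metric.reset()
```
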
""" - self.precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count) - self.recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count) - self.f1_score = 2 * self.precision * self.recall / (self.precision + self.recall) - if torch.isnan(self.f1_score): + precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count + self.eps) + recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count + self.eps) + f1_score = (2 * precision * recall / (precision + recall + self.eps)).detach().clone() + + if torch.isnan(f1_score): logging.warn("self.f1_score contains NaN value. Returning -1 instead of NaN value.") - self.f1_score = -1 - return self.f1_score + f1_score = -1 + return f1_score.float(), precision.float(), recall.float() diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index e4a1342b9c36..34dead15b33d 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -35,6 +35,7 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ssl_models import ( EncDecDenoiseMaskedTokenPredModel, EncDecMaskedTokenPredModel, diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py new file mode 100644 index 000000000000..5a3c8e354f1b --- /dev/null +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -0,0 +1,586 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import itertools +import random +import time +from collections import OrderedDict +from typing import Dict, List, Optional, Union + +import torch +from hydra.utils import instantiate +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from tqdm import tqdm + +from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy +from nemo.collections.asr.models.asr_model import ExportableEncDecModel +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.core.neural_types.elements import ProbsType +from nemo.utils import logging + +__all__ = ['SortformerEncLabelModel'] + + +class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): + """ + Encoder class for Sortformer diarization model. + Model class creates training, validation methods for setting up data performing model forward pass. + + This model class expects config dict for: + * preprocessor + * Transformer Encoder + * FastConformer Encoder + * Sortformer Modules + """ + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + return result + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + """ + Initialize an Sortformer Diarizer model and a pretrained NEST encoder. + In this init function, training and validation datasets are prepared. 
+ """ + random.seed(42) + self._trainer = trainer if trainer else None + self._cfg = cfg + + if self._trainer: + self.world_size = trainer.num_nodes * trainer.num_devices + else: + self.world_size = 1 + + if self._trainer is not None and self._cfg.get('augmentor', None) is not None: + self.augmentor = process_augmentations(self._cfg.augmentor) + else: + self.augmentor = None + super().__init__(cfg=self._cfg, trainer=trainer) + self.preprocessor = SortformerEncLabelModel.from_config_dict(self._cfg.preprocessor) + + if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None: + self.spec_augmentation = SortformerEncLabelModel.from_config_dict(self._cfg.spec_augment) + else: + self.spec_augmentation = None + + self.encoder = SortformerEncLabelModel.from_config_dict(self._cfg.encoder) + self.sortformer_modules = SortformerEncLabelModel.from_config_dict(self._cfg.sortformer_modules) + self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder) + self._init_loss_weights() + + self.eps = 1e-3 + self.loss = instantiate(self._cfg.loss) + + self.streaming_mode = self._cfg.get("streaming_mode", False) + self.save_hyperparameters("cfg") + self._init_eval_metrics() + + speaker_inds = list(range(self._cfg.max_num_of_spks)) + self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations + + def _init_loss_weights(self): + pil_weight = self._cfg.get("pil_weight", 0.0) + ats_weight = self._cfg.get("ats_weight", 1.0) + if pil_weight + ats_weight == 0: + raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0") + self.pil_weight = pil_weight / (pil_weight + ats_weight) + self.ats_weight = ats_weight / (pil_weight + ats_weight) + logging.info(f"Normalized weights for PIL {self.pil_weight} and ATS {self.ats_weight}") + + def _init_eval_metrics(self): + """ + If there is no label, then the evaluation metrics will be based on Permutation Invariant Loss (PIL). + """ + self._accuracy_test = MultiBinaryAccuracy() + self._accuracy_train = MultiBinaryAccuracy() + self._accuracy_valid = MultiBinaryAccuracy() + + self._accuracy_test_ats = MultiBinaryAccuracy() + self._accuracy_train_ats = MultiBinaryAccuracy() + self._accuracy_valid_ats = MultiBinaryAccuracy() + + def _reset_train_metrics(self): + self._accuracy_train.reset() + self._accuracy_train_ats.reset() + + def _reset_valid_metrics(self): + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + def __setup_dataloader_from_config(self, config): + # Switch to lhotse dataloader if specified in the config + if config.get("use_lhotse"): + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), + ) + + featurizer = WaveformFeaturizer( + sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=self.augmentor + ) + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + logging.info(f"Loading dataset from {config.manifest_filepath}") + + if self._trainer is not None: + global_rank = self._trainer.global_rank + else: + global_rank = 0 + time_flag = time.time() + logging.info("AAB: Starting Dataloader Instance loading... 
Step A") + + dataset = AudioToSpeechE2ESpkDiarDataset( + manifest_filepath=config.manifest_filepath, + soft_label_thres=config.soft_label_thres, + session_len_sec=config.session_len_sec, + num_spks=config.num_spks, + featurizer=featurizer, + window_stride=self._cfg.preprocessor.window_stride, + global_rank=global_rank, + soft_targets=config.soft_targets if 'soft_targets' in config else False, + ) + logging.info( + f"AAB: Dataloader dataset is created, starting torch.utils.data.Dataloader" + f"step B: {time.time() - time_flag}" + ) + + self.data_collection = dataset.collection + self.collate_ds = dataset + + dataloader_instance = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config.batch_size, + collate_fn=self.collate_ds.eesd_train_collate_fn, + drop_last=config.get('drop_last', False), + shuffle=False, + num_workers=config.get('num_workers', 1), + pin_memory=config.get('pin_memory', False), + ) + logging.info(f"AAC: Dataloader Instance loading is done ETA Step B done: {time.time() - time_flag}") + return dataloader_instance + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + self._train_dl = self.__setup_dataloader_from_config( + config=train_data_config, + ) + + def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): + self._validation_dl = self.__setup_dataloader_from_config( + config=val_data_layer_config, + ) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + self._test_dl = self.__setup_dataloader_from_config( + config=test_data_config, + ) + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + return None + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + audio_eltype = AudioSignal() + return { + "audio_signal": NeuralType(('B', 'T'), audio_eltype), + "audio_signal_length": NeuralType(('B',), LengthsType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return OrderedDict( + { + "preds": NeuralType(('B', 'T', 'C'), ProbsType()), + } + ) + + def frontend_encoder(self, processed_signal, processed_signal_length): + """ + Generate encoder outputs from frontend encoder. + + Args: + process_signal (torch.Tensor): tensor containing audio-feature (mel spectrogram, mfcc, etc.) + processed_signal_length (torch.Tensor): tensor containing lengths of audio signal in integers + + Returns: + emb_seq (torch.Tensor): tensor containing encoder outputs + emb_seq_length (torch.Tensor): tensor containing lengths of encoder outputs + """ + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + self.encoder = self.encoder.to(self.device) + emb_seq, emb_seq_length = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + emb_seq = emb_seq.transpose(1, 2) + if self._cfg.encoder.d_model != self._cfg.tf_d_model: + self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) + emb_seq = self.sortformer_modules.encoder_proj(emb_seq) + return emb_seq, emb_seq_length + + def forward_infer(self, emb_seq): + """ + The main forward pass for diarization for offline diarization inference. 
+ + Args: + emb_seq (torch.Tensor): tensor containing FastConformer encoder states (embedding vectors). + Dimension: (batch_size, diar_frame_count, emb_dim) + + Returns: + preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. + Dimension: (batch_size, diar_frame_count, num_speakers) + encoder_states_list (list): List containing total speaker memory for each step for debugging purposes + Dimension: [(batch_size, diar_frame_count, inner dim), ... ] + """ + encoder_mask = self.sortformer_modules.length_to_mask(emb_seq) + trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask) + preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) + return preds + + def process_signal(self, audio_signal, audio_signal_length): + """ + Extract audio features from time-series signal for further processing in the model. + + This function performs the following steps: + 1. Moves the audio signal to the correct device. + 2. Normalizes the time-series audio signal. + 3. Extrac audio feature from from the time-series audio signal using the model's preprocessor. + + Args: + audio_signal (torch.Tensor): The input audio signal. + Shape: (batch_size, num_samples) + audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + Shape: (batch_size,) + + Returns: + tuple: A tuple containing: + - processed_signal (torch.Tensor): The preprocessed audio signal. + Shape: (batch_size, num_features, num_frames) + - processed_signal_length (torch.Tensor): The length of each processed signal. + Shape: (batch_size,) + """ + audio_signal = audio_signal.to(self.device) + audio_signal = (1 / (audio_signal.max() + self.eps)) * audio_signal + processed_signal, processed_signal_length = self.preprocessor( + input_signal=audio_signal, length=audio_signal_length + ) + return processed_signal, processed_signal_length + + def forward( + self, + audio_signal, + audio_signal_length, + ): + """ + Forward pass for training and inference. + + Args: + audio_signal (torch.Tensor): tensor containing audio waveform + Dimension: (batch_size, num_samples) + audio_signal_length (torch.Tensor): tensor containing lengths of audio waveforms + Dimension: (batch_size,) + + Returns: + preds (torch.Tensor): Sorted tensor containing predicted speaker labels + Dimension: (batch_size, diar_frame_count, num_speakers) + encoder_states_list (list): List containing total speaker memory for each step for debugging purposes + Dimension: [(batch_size, diar_frame_count, inner dim), ] + """ + processed_signal, processed_signal_length = self.process_signal( + audio_signal=audio_signal, audio_signal_length=audio_signal_length + ) + processed_signal = processed_signal[:, :, : processed_signal_length.max()] + if self._cfg.get("streaming_mode", False): + raise NotImplementedError("Streaming mode is not implemented yet.") + else: + emb_seq, _ = self.frontend_encoder( + processed_signal=processed_signal, processed_signal_length=processed_signal_length + ) + preds = self.forward_infer(emb_seq) + return preds + + def _get_aux_train_evaluations(self, preds, targets, target_lens): + """ + Compute auxiliary training evaluations including losses and metrics. + + This function calculates various losses and metrics for the training process, + including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. 
+ Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + (dict): A dictionary containing the following training metrics. + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + loss = self.ats_weight * ats_loss + self.pil_weight * pil_loss + + self._accuracy_train(preds, targets_pil, target_lens) + train_f1_acc, train_precision, train_recall = self._accuracy_train.compute() + + self._accuracy_train_ats(preds, targets_ats, target_lens) + train_f1_acc_ats, _, _ = self._accuracy_train_ats.compute() + + train_metrics = { + 'loss': loss, + 'ats_loss': ats_loss, + 'pil_loss': pil_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'train_f1_acc': train_f1_acc, + 'train_precision': train_precision, + 'train_recall': train_recall, + 'train_f1_acc_ats': train_f1_acc_ats, + } + return train_metrics + + def training_step(self, batch: list) -> dict: + """ + Performs a single training step. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal in time-series format. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. + + Returns: + (dict): A dictionary containing the 'loss' key with the calculated loss value. + """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward(audio_signal=audio_signal, audio_signal_length=audio_signal_length) + train_metrics = self._get_aux_train_evaluations(preds, targets, target_lens) + self._reset_train_metrics() + self.log_dict(train_metrics, sync_dist=True, on_step=True, on_epoch=False, logger=True) + return {'loss': train_metrics['loss']} + + def _get_aux_validation_evaluations(self, preds, targets, target_lens): + """ + Compute auxiliary validation evaluations including losses and metrics. + This function calculates various losses and metrics for the validation process, + including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. 
+ Shape: (batch_size,) + + Returns: + dict: A dictionary containing the following validation metrics + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + + val_ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + val_pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + val_loss = self.ats_weight * val_ats_loss + self.pil_weight * val_pil_loss + + self._accuracy_valid(preds, targets_pil, target_lens) + val_f1_acc, val_precision, val_recall = self._accuracy_valid.compute() + + self._accuracy_valid_ats(preds, targets_ats, target_lens) + valid_f1_acc_ats, _, _ = self._accuracy_valid_ats.compute() + + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + val_metrics = { + 'val_loss': val_loss, + 'val_ats_loss': val_ats_loss, + 'val_pil_loss': val_pil_loss, + 'val_f1_acc': val_f1_acc, + 'val_precision': val_precision, + 'val_recall': val_recall, + 'val_f1_acc_ats': valid_f1_acc_ats, + } + return val_metrics + + def validation_step(self, batch: list, dataloader_idx: int = 0): + """ + Performs a single validation step. + + This method processes a batch of data during the validation phase. It forward passes + the audio signal through the model, computes various validation metrics, and stores + these metrics for later aggregation. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. + dataloader_idx (int, optional): The index of the dataloader in case of multiple + validation dataloaders. Defaults to 0. + + Returns: + dict: A dictionary containing various validation metrics for this batch. 
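
To make the ATS/PIL target construction used in the train and validation evaluations above concrete, here is a toy sketch with 2 speakers for brevity (the model itself uses 4); the `bce` call in the trailing comment is a stand-in for the configured loss, not a function defined in this PR.

```python
import itertools
import torch
from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets

perms = torch.tensor(list(itertools.permutations(range(2))))    # all 2-speaker permutations
targets = torch.tensor([[[0., 1.], [0., 1.], [1., 1.]]])        # (1 batch, 3 frames, 2 speakers)
preds = torch.tensor([[[0.9, 0.1], [0.8, 0.2], [0.7, 0.8]]])    # model predicts the swapped speaker order

targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=perms)
targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=perms)
print(targets_ats[0])  # columns reordered so the first-arriving speaker sits in column 0
# loss = 0.5 * bce(preds, targets_ats) + 0.5 * bce(preds, targets_pil)   # hybrid ATS/PIL weighting as above
```
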
+ """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + val_metrics = self._get_aux_validation_evaluations(preds, targets, target_lens) + if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(val_metrics) + else: + self.validation_step_outputs.append(val_metrics) + return val_metrics + + def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): + if not outputs: + logging.warning(f"`outputs` is None; empty outputs for dataloader={dataloader_idx}") + return None + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + val_ats_loss_mean = torch.stack([x['val_ats_loss'] for x in outputs]).mean() + val_pil_loss_mean = torch.stack([x['val_pil_loss'] for x in outputs]).mean() + val_f1_acc_mean = torch.stack([x['val_f1_acc'] for x in outputs]).mean() + val_precision_mean = torch.stack([x['val_precision'] for x in outputs]).mean() + val_recall_mean = torch.stack([x['val_recall'] for x in outputs]).mean() + val_f1_acc_ats_mean = torch.stack([x['val_f1_acc_ats'] for x in outputs]).mean() + + self._reset_valid_metrics() + + multi_val_metrics = { + 'val_loss': val_loss_mean, + 'val_ats_loss': val_ats_loss_mean, + 'val_pil_loss': val_pil_loss_mean, + 'val_f1_acc': val_f1_acc_mean, + 'val_precision': val_precision_mean, + 'val_recall': val_recall_mean, + 'val_f1_acc_ats': val_f1_acc_ats_mean, + } + return {'log': multi_val_metrics} + + def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target_lens): + """ + Compute auxiliary validation evaluations including losses and metrics. + + This function calculates various losses and metrics for the validation process, + including ATS (Anchored Temporal Segmentation) and PIL (Permutation Invariant Loss) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + dict: A dictionary containing the following validation metrics + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + self._accuracy_test(preds, targets_pil, target_lens) + f1_acc, precision, recall = self._accuracy_test.compute() + self.batch_f1_accs_list.append(f1_acc) + self.batch_precision_list.append(precision) + self.batch_recall_list.append(recall) + logging.info(f"batch {batch_idx}: f1_acc={f1_acc}, precision={precision}, recall={recall}") + + self._accuracy_test_ats(preds, targets_ats, target_lens) + f1_acc_ats, precision_ats, recall_ats = self._accuracy_test_ats.compute() + self.batch_f1_accs_ats_list.append(f1_acc_ats) + logging.info( + f"batch {batch_idx}: f1_acc_ats={f1_acc_ats}, precision_ats={precision_ats}, recall_ats={recall_ats}" + ) + + self._accuracy_test.reset() + self._accuracy_test_ats.reset() + + def test_batch( + self, + ): + """ + Perform batch testing on the model. + + This method iterates through the test data loader, making predictions for each batch, + and calculates various evaluation metrics. It handles both single and multi-sample batches. 
+ """ + ( + self.preds_total_list, + self.batch_f1_accs_list, + self.batch_precision_list, + self.batch_recall_list, + self.batch_f1_accs_ats_list, + ) = ([], [], [], [], []) + + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(self._test_dl)): + audio_signal, audio_signal_length, targets, target_lens = batch + audio_signal = audio_signal.to(self.device) + audio_signal_length = audio_signal_length.to(self.device) + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + preds = preds.detach().to('cpu') + if preds.shape[0] == 1: # batch size = 1 + self.preds_total_list.append(preds) + else: + self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) + torch.cuda.empty_cache() + self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) + + logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") + logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") + logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") + logging.info(f"Batch ATS F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_ats_list))}") + + def diarize( + self, + ): + """One-clieck runner function for diarization.""" + raise NotImplementedError diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py new file mode 100644 index 000000000000..193dae29c304 --- /dev/null +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule + +__all__ = ['SortformerModules'] + + +class SortformerModules(NeuralModule, Exportable): + """ + A class including auxiliary functions for Sortformer models. + This class contains and will contain the following functions that performs streaming features, + and any neural layers that are not included in the NeMo neural modules (e.g. Transformer, Fast-Conformer). + """ + + def init_weights(self, m): + """Init weights for linear layers.""" + if type(m) == nn.Linear: + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def __init__( + self, + num_spks: int = 4, + hidden_size: int = 192, + dropout_rate: float = 0.5, + fc_d_model: int = 512, + tf_d_model: int = 192, + ): + """ + Args: + num_spks (int): + Max number of speakers that are processed by the model. + hidden_size (int): + Number of hidden units in sequence models and intermediate layers. + dropout_rate (float): + Dropout rate for linear layers, CNN and LSTM. + fc_d_model (int): + Dimension of the embedding vectors. + tf_d_model (int): + Dimension of the embedding vectors. 
+ """ + super().__init__() + self.fc_d_model = fc_d_model + self.tf_d_model = tf_d_model + self.hidden_size = tf_d_model + self.unit_n_spks: int = num_spks + self.hidden_to_spks = nn.Linear(2 * self.hidden_size, self.unit_n_spks) + self.first_hidden_to_hidden = nn.Linear(self.hidden_size, self.hidden_size) + self.single_hidden_to_spks = nn.Linear(self.hidden_size, self.unit_n_spks) + self.dropout = nn.Dropout(dropout_rate) + self.encoder_proj = nn.Linear(self.fc_d_model, self.tf_d_model) + + def length_to_mask(self, context_embs): + """ + Convert length values to encoder mask input tensor. + + Args: + lengths (torch.Tensor): tensor containing lengths of sequences + max_len (int): maximum sequence length + + Returns: + mask (torch.Tensor): tensor of shape (batch_size, max_len) containing 0's + in the padded region and 1's elsewhere + """ + lengths = torch.tensor([context_embs.shape[1]] * context_embs.shape[0]) + batch_size = context_embs.shape[0] + max_len = context_embs.shape[1] + # create a tensor with the shape (batch_size, 1) filled with ones + row_vector = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) + # create a tensor with the shape (batch_size, max_len) filled with lengths + length_matrix = lengths.unsqueeze(1).expand(-1, max_len).to(lengths.device) + # create a mask by comparing the row vector and length matrix + mask = row_vector < length_matrix + return mask.float().to(context_embs.device) + + def forward_speaker_sigmoids(self, hidden_out): + """ + A set of layers for predicting speaker probabilities with a sigmoid activation function. + + Args: + hidden_out (torch.Tensor): tensor of shape (batch_size, seq_len, hidden_size) + + Returns: + preds (torch.Tensor): tensor of shape (batch_size, num_spks) containing speaker probabilities + """ + hidden_out = self.dropout(F.relu(hidden_out)) + hidden_out = self.first_hidden_to_hidden(hidden_out) + hidden_out = self.dropout(F.relu(hidden_out)) + spk_preds = self.single_hidden_to_spks(hidden_out) + preds = nn.Sigmoid()(spk_preds) + return preds diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py new file mode 100644 index 000000000000..eddfd3254adc --- /dev/null +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -0,0 +1,410 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +from lhotse import SupervisionSet +from lhotse.cut import MixedCut, MonoCut + + +def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> torch.Tensor: + """ + Finds the first nonzero value in the matrix, discretizing it to the specified maximum capacity. + + Args: + mat (Tensor): A torch tensor representing the matrix. + max_cap_val (int): The maximum capacity to which the matrix values will be discretized. + thres (float): The threshold value for discretizing the matrix values. 
+ + Returns: + mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first + nonzero value in each row. + """ + # Discretize the matrix to the specified maximum capacity + labels_discrete = mat.clone() + labels_discrete[labels_discrete < thres] = 0 + labels_discrete[labels_discrete >= thres] = 1 + + # non zero values mask + non_zero_mask = labels_discrete != 0 + # operations on the mask to find first nonzero values in the rows + mask_max_values, mask_max_indices = torch.max(non_zero_mask, dim=1) + # if the max-mask is zero, there is no nonzero value in the row + mask_max_indices[mask_max_values == 0] = max_cap_val + return mask_max_indices + + +def find_best_permutation(match_score: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Finds the best permutation indices based on the match score. + + Args: + match_score (torch.Tensor): A tensor containing the match scores for each permutation. + Shape: (batch_size, num_permutations) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + """ + batch_best_perm = torch.argmax(match_score, axis=1) + rep_speaker_permutations = speaker_permutations.repeat(batch_best_perm.shape[0], 1).to(match_score.device) + perm_size = speaker_permutations.shape[0] + global_inds_vec = ( + torch.arange(0, perm_size * batch_best_perm.shape[0], perm_size).to(batch_best_perm.device) + batch_best_perm + ) + return rep_speaker_permutations[global_inds_vec.to(rep_speaker_permutations.device), :] + + +def reconstruct_labels(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: + """ + Reconstructs the labels using the best permutation indices with matrix operations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + batch_perm_inds (torch.Tensor): A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. + Shape: (batch_size, num_frames, num_speakers) + """ + # Expanding batch_perm_inds to align with labels dimensions + batch_size, num_frames, num_speakers = labels.shape + batch_perm_inds_exp = batch_perm_inds.unsqueeze(1).expand(-1, num_frames, -1) + + # Reconstructing the labels using advanced indexing + reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) + return reconstructed_labels + + +def get_ats_targets( + labels: torch.Tensor, + preds: torch.Tensor, + speaker_permutations: torch.Tensor, + thres: float = 0.5, + tolerance: float = 0, +) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal of all arrival-time ordered permutations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + preds (torch.Tensor): A tensor containing the predictions. + Shape: (batch_size, num_frames, num_speakers) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + thres (float): The threshold value for discretizing the matrix values. Default is 0.5. + tolerance (float): The tolerance for comparing the first speech frame indices. Default is 0. 
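
A tiny sketch of the two helpers defined above on a hand-made label tensor (1 batch, 3 frames, 2 speakers), showing how arrival-time information is extracted and how a permutation is applied.

```python
import torch
from nemo.collections.asr.parts.utils.asr_multispeaker_utils import find_first_nonzero, reconstruct_labels

labels = torch.tensor([[[0., 1.], [1., 1.], [1., 0.]]])  # speaker 1 starts at frame 0, speaker 0 at frame 1
print(find_first_nonzero(labels, max_cap_val=3))          # tensor([[1, 0]]): first active frame per speaker
perm = torch.tensor([[1, 0]])                             # put the earlier-arriving speaker first
print(reconstruct_labels(labels, perm))                   # speaker columns swapped accordingly
```
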
+ + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. + Shape: (batch_size, num_frames, num_speakers) + """ + # Find the first nonzero frame index for each speaker in each batch + nonzero_ind = find_first_nonzero( + mat=labels, max_cap_val=labels.shape[1], thres=thres + ) # (batch_size, num_speakers) + + # Sort the first nonzero frame indices for arrival-time ordering + sorted_values = torch.sort(nonzero_ind)[0] # (batch_size, num_speakers) + perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_frames, num_permutations, num_speakers) + permed_nonzero_ind = find_first_nonzero( + mat=permed_labels, max_cap_val=labels.shape[1] + ) # (batch_size, num_permutations, num_speakers) + + # Compare the first frame indices of sorted labels with those of the permuted labels using tolerance + perm_compare = ( + torch.abs(sorted_values.unsqueeze(1) - permed_nonzero_ind) <= tolerance + ) # (batch_size, num_permutations, num_speakers) + perm_mask = torch.all(perm_compare, dim=2).float() # (batch_size, num_permutations) + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, perm_size, 1 + ) # Exapnd the preds: (batch_size, num_frames, num_permutations, num_speakers) + + # Compute the match score for each permutation by comparing permuted labels with preds + match_score = ( + torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) * perm_mask + ) # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_frames, num_speakers) + return max_score_permed_labels # (batch_size, num_frames, num_speakers) + + +def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal permutation based on the match score. + + Args: + labels (torch.Tensor): A tensor containing the ground truth labels. + Shape: (batch_size, num_speakers, num_classes) + preds (torch.Tensor): A tensor containing the predicted values. + Shape: (batch_size, num_speakers, num_classes) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor of permuted labels that best match the predictions. 
+ Shape: (batch_size, num_speakers, num_classes) + """ + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_classes, num_permutations, num_speakers) + # Repeat preds to match permutations for comparison + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, speaker_permutations.shape[0], 1 + ) # (batch_size, num_speakers, num_permutations, num_classes) + match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + # Reconstruct labels based on the best permutation for each batch + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) + return max_score_permed_labels # (batch_size, num_speakers, num_classes) + + +def find_segments_from_rttm( + recording_id: str, + rttms, + start_after: float, + end_before: float, + adjust_offset: bool = True, + tolerance: float = 0.001, +): + """ + Finds segments from the given rttm file. + This function is designed to replace rttm + + Args: + recording_id (str): The recording ID in string format. + rttms (SupervisionSet): The SupervisionSet instance. + start_after (float): The start time after which segments are selected. + end_before (float): The end time before which segments are selected. + adjust_offset (bool): Whether to adjust the offset of the segments. + tolerance (float): The tolerance for time matching. 0.001 by default. + + Returns: + segments (List[SupervisionSegment]): A list of SupervisionSegment instances. + """ + segment_by_recording_id = rttms._segments_by_recording_id + if segment_by_recording_id is None: + from cytoolz import groupby + + segment_by_recording_id = groupby(lambda seg: seg.recording_id, rttms) + + return [ + # We only modify the offset - the duration remains the same, as we're only shifting the segment + # relative to the Cut's start, and not truncating anything. + segment.with_offset(-start_after) if adjust_offset else segment + for segment in segment_by_recording_id.get(recording_id, []) + if segment.start < end_before + tolerance and segment.end > start_after + tolerance + ] + + +def get_mask_from_segments( + segments: list, + a_cut, + speaker_to_idx_map: torch.Tensor, + num_speakers: int = 4, + feat_per_sec: int = 100, + ignore_num_spk_mismatch: bool = False, +): + """ + Generate mask matrix from segments list. + This function is needed for speaker diarization with ASR model trainings. + + Args: + segments: A list of Lhotse Supervision segments iterator. + cut (MonoCut, MixedCut): Lhotse MonoCut or MixedCut instance. + speaker_to_idx_map (dict): A dictionary mapping speaker names to indices. + num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. + Will be removed in the future. + + Returns: + mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). 
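# A minimal sketch of how the permutation helpers above fit together, assuming
# `get_pil_targets` is importable from this module. `speaker_permutations` is
# built here with itertools for two speakers; the toy labels/preds are made up.
from itertools import permutations
import torch

spk_perms = torch.tensor(list(permutations(range(2))))             # (2 perms, 2 speakers)
labels = torch.tensor([[[1., 0.], [1., 0.], [0., 1.], [0., 1.]]])  # (1 batch, 4 frames, 2 spks)
preds = torch.tensor([[[0., 1.], [0., 1.], [1., 0.], [1., 0.]]])   # speaker columns swapped
pil_labels = get_pil_targets(labels, preds, spk_perms)
# The swapped permutation maximizes the match score, so pil_labels equals preds here.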
+ Dimension: (num_speakers, num_frames) + """ + # get targets with 0.01s frame rate + num_samples = round(a_cut.duration * feat_per_sec) + mask = torch.zeros((num_samples, num_speakers)) + for rttm_sup in segments: + speaker_idx = speaker_to_idx_map[rttm_sup.speaker] + if speaker_idx >= num_speakers: + if ignore_num_spk_mismatch: + continue + else: + raise ValueError(f"Speaker Index {speaker_idx} exceeds the max index: {num_speakers-1}") + stt = max(rttm_sup.start, 0) + ent = min(rttm_sup.end, a_cut.duration) + stf = int(stt * feat_per_sec) + enf = int(ent * feat_per_sec) + mask[stf:enf, speaker_idx] = 1.0 + return mask + + +def get_soft_mask(feat_level_target, num_samples, stride): + """ + Get soft mask from feat_level_target with stride. + This function is needed for speaker diarization with ASR model trainings. + + Args: + feat_level_target (Tensor): A numpy array of shape (num_frames, num_speakers). + Dimension: (num_frames, num_speakers) + num_sample (int): The total number of samples. + stride (int): The stride for the mask. + """ + + num_speakers = feat_level_target.shape[1] + mask = torch.zeros(num_samples, num_speakers) + + for index in range(num_samples): + if index == 0: + seg_stt_feat = 0 + else: + seg_stt_feat = stride * index - 1 - int(stride / 2) + if index == num_samples - 1: + seg_end_feat = feat_level_target.shape[0] + else: + seg_end_feat = stride * index - 1 + int(stride / 2) + mask[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0) + return mask + + +def get_hidden_length_from_sample_length( + num_samples: int, num_sample_per_mel_frame: int = 160, num_mel_frame_per_asr_frame: int = 8 +) -> int: + """ + Calculate the hidden length from the given number of samples. + This function is needed for speaker diarization with ASR model trainings. + + This function computes the number of frames required for a given number of audio samples, + considering the number of samples per mel frame and the number of mel frames per ASR frame. + + Parameters: + num_samples (int): The total number of audio samples. + num_sample_per_mel_frame (int, optional): The number of samples per mel frame. Default is 160. + num_mel_frame_per_asr_frame (int, optional): The number of mel frames per ASR frame. Default is 8. + + Returns: + hidden_length (int): The calculated hidden length in terms of the number of frames. + """ + mel_frame_count = math.ceil((num_samples + 1) / num_sample_per_mel_frame) + hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) + return int(hidden_length) + + +def speaker_to_target( + a_cut, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, + spk_tar_all_zero: bool = False, + boundary_segments: bool = False, + soft_label: bool = False, + ignore_num_spk_mismatch: bool = True, + soft_thres: float = 0.5, +): + """ + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape + (num_speaker, hidden_length). This function is needed for speaker diarization with ASR model trainings. + + Args: + a_cut (MonoCut, MixedCut): + Lhotse Cut instance which is MonoCut or MixedCut instance. 
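# Quick arithmetic check of `get_hidden_length_from_sample_length` defined above,
# using its defaults (160 samples per mel frame, 8 mel frames per ASR frame).
# The 90 s / 16 kHz session length is only an illustrative figure.
num_samples = 90 * 16000                    # 1_440_000 samples
mel_frames = -(-(num_samples + 1) // 160)   # ceil division -> 9001 mel frames
encoder_frames = -(-mel_frames // 8)        # ceil division -> 1126 encoder frames
assert get_hidden_length_from_sample_length(num_samples) == encoder_frames == 1126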
+ num_speakers (int): + Max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): + Number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): + Encoder subsampling_factor, 8 by default + spk_tar_all_zero (Tensor): + Set to True gives all zero "mask" + boundary_segments (bool): + Set to True to include segments containing the boundary of the cut, + False by default for multi-speaker ASR training + soft_label (bool): + Set to True to use soft label that enables values in [0, 1] range, + False by default and leads to binary labels. + ignore_num_spk_mismatch (bool): + This is a temporary solution to handle speaker mismatch. Will be removed in the future. + + Returns: + mask (Tensor): Speaker mask with shape (num_speaker, hidden_lenght) + """ + # get cut-related segments from rttms + if isinstance(a_cut, MixedCut): + cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + elif isinstance(a_cut, MonoCut): + cut_list = [a_cut] + offsets = [0] + else: + raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") + + segments_total = [] + for i, cut in enumerate(cut_list): + rttms = SupervisionSet.from_rttm(cut.rttm_filepath) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm( + recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0 + ) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find( + recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True + ) + + for seg in segments_iterator: + if seg.start < 0: + seg.duration += seg.start + seg.start = 0 + if seg.end > cut.duration: + seg.duration -= seg.end - cut.duration + seg.start += offsets[i] + segments_total.append(seg) + + # apply arrival time sorting to the existing segments + segments_total.sort(key=lambda rttm_sup: rttm_sup.start) + + seen = set() + seen_add = seen.add + speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] + + speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} + if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers + raise ValueError( + f"Number of speakers {len(speaker_to_idx_map)} is larger than " + f"the maximum number of speakers {num_speakers}" + ) + + # initialize mask matrices (num_speaker, encoder_hidden_len) + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length( + a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) + if spk_tar_all_zero: + frame_mask = torch.zeros((num_samples, num_speakers)) + else: + frame_mask = get_mask_from_segments( + segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch + ) + soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) + + if soft_label: + mask = soft_mask + else: + mask = (soft_mask > soft_thres).float() + + return mask diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 5d3a0bf4274e..15ec8a24a3bd 100644 --- 
a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -14,28 +14,23 @@ import gc import json -import math import os import shutil from copy import deepcopy from typing import Dict, List, Tuple, Union import numpy as np -import omegaconf import soundfile as sf import torch -from pyannote.core import Annotation, Segment +from omegaconf.listconfig import ListConfig +from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm from nemo.collections.asr.data.audio_to_label import repeat_signal from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering -from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data +from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat, split_input_data from nemo.utils import logging -""" -This file contains all the utility functions required for speaker embeddings part in diarization scripts -""" - def get_uniqname_from_filepath(filepath): """ @@ -81,10 +76,13 @@ def audio_rttm_map(manifest, attach_dur=False): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - Args: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists - returns: - AUDIO_RTTM_MAP (dict) : A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files + Args: + manifest (str): Path to the manifest file + attach_dur (bool, optional): If True, attach duration information to the unique name. Defaults to False. + + Returns: + AUDIO_RTTM_MAP (dict) : Dictionary with unique names as keys and corresponding metadata as values. """ AUDIO_RTTM_MAP = {} @@ -108,15 +106,17 @@ def audio_rttm_map(manifest, attach_dur=False): if attach_dur: uniqname = get_uniq_id_with_dur(meta) else: - uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) + if "uniq_id" in dic.keys(): + uniqname = dic['uniq_id'] + else: + uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) if uniqname not in AUDIO_RTTM_MAP: AUDIO_RTTM_MAP[uniqname] = meta else: raise KeyError( - "file {} is already part of AUDIO_RTTM_MAP, it might be duplicated, Note: file basename must be unique".format( - meta['audio_filepath'] - ) + f"file {meta['audio_filepath']} is already part of AUDIO_RTTM_MAP, it might be duplicated, " + "Note: file basename must be unique" ) return AUDIO_RTTM_MAP @@ -144,7 +144,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ """ check_float_config = [isinstance(var, float) for var in (window_lengths_in_sec, shift_lengths_in_sec)] check_list_config = [ - isinstance(var, (omegaconf.listconfig.ListConfig, list, tuple)) + isinstance(var, (ListConfig, list, tuple)) for var in (window_lengths_in_sec, shift_lengths_in_sec, multiscale_weights) ] if all(check_list_config) or all(check_float_config): @@ -247,7 +247,8 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg def get_timestamps(multiscale_timestamps, multiscale_args_dict): """ The timestamps in `multiscale_timestamps` dictionary are indexed by scale index. - This function rearranges the extracted speaker embedding and timestamps by unique ID to make the further processing more convenient. 
+ This function rearranges the extracted speaker embedding and timestamps by unique ID + to make the further processing more convenient. Args: multiscale_timestamps (dict): @@ -441,13 +442,20 @@ def perform_clustering( 'embeddings' : Tensor containing embeddings. Dimensions:(# of embs) x (emb. dimension) 'timestamps' : Tensor containing ime stamps list for each audio recording 'multiscale_segment_counts' : Tensor containing the number of segments for each scale - AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path - out_rttm_dir (str): Path to write predicted rttms - clustering_params (dict): clustering parameters provided through config that contains max_num_speakers (int), - oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) and enhance_count_threshold (int) - use_torch_script (bool): Boolean that determines whether to use torch.jit.script for speaker clustering - device (torch.device): Device we are running on ('cpu', 'cuda'). - verbose (bool): Enable TQDM progress bar. + AUDIO_RTTM_MAP (dict): + AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path + out_rttm_dir (str): + Path to write predicted rttms + clustering_params (dict): + Clustering parameters provided through config that contains max_num_speakers (int), + oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) + and enhance_count_threshold (int). + use_torch_script (bool): + Boolean that determines whether to use torch.jit.script for speaker clustering + device (torch.device): + Device we are running on ('cpu', 'cuda'). + verbose (bool): + Enable TQDM progress bar. Returns: all_reference (list[uniq_name,Annotation]): reference annotations for score calculation @@ -585,7 +593,7 @@ def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, Number of decimals to round the offset and duration values. """ audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] - for (stt, end) in overlap_range_list: + for stt, end in overlap_range_list: meta = { "audio_filepath": audio_path, "offset": round(stt, decimals), @@ -614,9 +622,8 @@ def read_rttm_lines(rttm_file_path): lines = f.readlines() else: raise FileNotFoundError( - "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( - rttm_file_path - ) + "Requested to construct manifest from rttm with oracle VAD option " + f"or from NeMo VAD but received filename as {rttm_file_path}" ) return lines @@ -745,14 +752,14 @@ def fl2int(x: float, decimals: int = 3) -> int: """ Convert floating point number to integer. """ - return torch.round(torch.tensor([x * (10 ** decimals)]), decimals=0).int().item() + return torch.round(torch.tensor([x * (10**decimals)]), decimals=0).int().item() def int2fl(x: int, decimals: int = 3) -> float: """ Convert integer to floating point number. 
""" - return torch.round(torch.tensor([x / (10 ** decimals)]), decimals=decimals).item() + return torch.round(torch.tensor([x / (10**decimals)]), decimals=decimals).item() def merge_float_intervals(ranges: List[List[float]], decimals: int = 5, margin: int = 2) -> List[List[float]]: @@ -886,7 +893,8 @@ def segments_manifest_to_subsegments_manifest( Generate subsegments manifest from segments manifest file Args: segments_manifest file (str): path to segments manifest file, typically from VAD output - subsegments_manifest_file (str): path to output subsegments manifest file (default (None) : writes to current working directory) + subsegments_manifest_file (str): path to output subsegments manifest file + (default (None) : writes to current working directory) window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift min_subsegments_duration (float): exclude subsegments smaller than this duration value @@ -898,9 +906,10 @@ def segments_manifest_to_subsegments_manifest( pwd = os.getcwd() subsegments_manifest_file = os.path.join(pwd, 'subsegments.json') - with open(segments_manifest_file, 'r') as segments_manifest, open( - subsegments_manifest_file, 'w' - ) as subsegments_manifest: + with ( + open(segments_manifest_file, 'r') as segments_manifest, + open(subsegments_manifest_file, 'w') as subsegments_manifest, + ): segments = segments_manifest.readlines() for segment in segments: segment = segment.strip() @@ -928,32 +937,73 @@ def segments_manifest_to_subsegments_manifest( return subsegments_manifest_file -def get_subsegments(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: +def get_subsegments( + offset: float, + window: float, + shift: float, + duration: float, + min_subsegment_duration: float = 0.01, + decimals: int = 2, + use_asr_style_frame_count: bool = False, + sample_rate: int = 16000, + feat_per_sec: int = 100, +) -> List[List[float]]: """ - Return subsegments from a segment of audio file + Return subsegments from a segment of audio file. + + Example: + (window, shift) = 1.5, 0.75 + Segment: [12.05, 14.45] + Subsegments: [[12.05, 13.55], [12.8, 14.3], [13.55, 14.45], [14.3, 14.45]] + Args: - offset (float): start time of audio segment - window (float): window length for segments to subsegments length - shift (float): hop length for subsegments shift - duration (float): duration of segment + offset (float): Start time of audio segment + window (float): Window length for segments to subsegments length + shift (float): Hop length for subsegments shift + duration (float): Duration of segment + min_subsegment_duration (float): Exclude subsegments smaller than this duration value + decimals (int): Number of decimal places to round to + use_asr_style_frame_count (bool): If True, use asr style frame count to generate subsegments. + For example, if duration is 10 secs and frame_shift is 0.08 secs, + it results in (10/0.08)+1 = 125 + 1 frames. 
+ Returns: - subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment + subsegments (List[tuple[float, float]]): subsegments generated for the segments as + list of tuple of start and duration of each subsegment """ subsegments: List[List[float]] = [] start = offset slice_end = start + duration - base = math.ceil((duration - window) / shift) - slices = 1 if base < 0 else base + 1 - for slice_id in range(slices): - end = start + window - if end > slice_end: - end = slice_end - subsegments.append([start, end - start]) - start = offset + (slice_id + 1) * shift + if min_subsegment_duration <= duration < shift: + slices = 1 + elif use_asr_style_frame_count is True: + num_feat_frames = np.ceil((1 + duration * sample_rate) / int(sample_rate / feat_per_sec)).astype(int) + slices = np.ceil(num_feat_frames / int(feat_per_sec * shift)).astype(int) + slice_end = start + shift * slices + else: + slices = np.ceil(1 + (duration - window) / shift).astype(int) + if slices == 1: + if min(duration, window) >= min_subsegment_duration: + subsegments.append([start, min(duration, window)]) + elif slices > 0: # What if slcies = 0 ? + start_col = torch.arange(offset, slice_end, shift)[:slices] + dur_col_raw = torch.min( + slice_end * torch.ones_like(start_col) - start_col, window * torch.ones_like(start_col) + ) + dur_col = torch.round(dur_col_raw, decimals=decimals) + valid_mask = dur_col >= min_subsegment_duration + valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) + subsegments = valid_subsegments.tolist() return subsegments -def get_target_sig(sig, start_sec: float, end_sec: float, slice_length: int, sample_rate: int,) -> torch.Tensor: +def get_target_sig( + sig, + start_sec: float, + end_sec: float, + slice_length: int, + sample_rate: int, +) -> torch.Tensor: """ Extract time-series signal from the given audio buffer based on the start and end timestamps. @@ -1000,6 +1050,34 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: return [[float(range_tensor[k][0]), float(range_tensor[k][1])] for k in range(range_tensor.shape[0])] +def generate_diarization_output_lines(speaker_timestamps: List[List[float]], model_spk_num: int) -> List[str]: + """ + Generate diarization output lines list from the speaker timestamps list by merging overlapping intervals. + + Args: + speaker_timestamps (list): + List containing the start and end time of the speech intervals for each speaker. + Example: + >>> speaker_timestamps = [[0.5, 3.12], [3.51, 7.26],... ] + model_spk_num (int): + Number of speakers in the model. + + Returns: + speaker_lines_total (list): + List containing the diarization output lines in the format: + "start_time end_time speaker_id" + Example: + >>> speaker_lines_total = ["0.5 3.12 speaker_0", "3.51 7.26 speaker_1",...] 
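# Rough usage sketch for the vectorized `get_subsegments` above: a 4 s segment
# split with a 1.5 s window and 0.75 s shift. The numbers are illustrative; the
# function returns [start, duration] pairs.
subsegments = get_subsegments(offset=0.0, window=1.5, shift=0.75, duration=4.0)
# Expected: [[0.0, 1.5], [0.75, 1.5], [1.5, 1.5], [2.25, 1.5], [3.0, 1.0]]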
+ """ + speaker_lines_total = [] + for spk_idx in range(model_spk_num): + ts_invervals = speaker_timestamps[spk_idx] + merged_ts_intervals = merge_float_intervals(ts_invervals) + for ts_interval in merged_ts_intervals: + speaker_lines_total.extend([f"{ts_interval[0]:.3f} {ts_interval[1]:.3f} speaker_{int(spk_idx)}"]) + return speaker_lines_total + + def get_speech_labels_for_update( frame_start: float, buffer_end: float, @@ -1067,9 +1145,12 @@ def get_speech_labels_for_update( return speech_label_for_new_segments, cumulative_speech_labels -def get_new_cursor_for_update(frame_start: float, segment_range_ts: List[List[float]],) -> Tuple[float, int]: +def get_new_cursor_for_update( + frame_start: float, + segment_range_ts: List[List[float]], +) -> Tuple[float, int]: """ - Function for updating a cursor online speaker diarization. + Function for updating a cursor online speaker diarization. Remove the old segments that overlap with the new frame (self.frame_start) cursor_for_old_segments is set to the onset of the t_range popped lastly. @@ -1227,7 +1308,10 @@ def get_online_subsegments_from_buffer( range_t = [max(0, range_offs[0]), range_offs[1]] subsegments = get_subsegments( - offset=range_t[0], window=window, shift=shift, duration=(range_t[1] - range_t[0]), + offset=range_t[0], + window=window, + shift=shift, + duration=(range_t[1] - range_t[0]), ) ind_offset, sigs, ranges, inds = get_online_segments_from_slices( sig=audio_buffer, @@ -1277,20 +1361,22 @@ def get_scale_mapping_argmat(uniq_embs_and_timestamps: Dict[str, dict]) -> Dict[ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): """ - Generate timestamps that include overlap speech. Overlap-including timestamps are created based on the segments that are - created for clustering diarizer. Overlap speech is assigned to the existing speech segments in `cont_stamps`. + Generate timestamps that include overlap speech. Overlap-including timestamps are created based on + the segments that are created for clustering diarizer. Overlap speech is assigned to the existing + speech segments in `cont_stamps`. Args: cont_stamps (list): - Non-overlapping (single speaker per segment) diarization output in string format. - Each line contains the start and end time of segments and corresponding speaker labels. + Non-overlapping (single speaker per segment) diarization output in string format. Each line + contains the start and end time of segments and corresponding speaker labels. ovl_spk_idx (list): - List containing segment index of the estimated overlapped speech. The start and end of segments are based on the - single-speaker (i.e., non-overlap-aware) RTTM generation. + List containing segment index of the estimated overlapped speech. The start and end of + segments are based on the single-speaker (i.e., non-overlap-aware) RTTM generation. + Returns: total_ovl_cont_list (list): - Rendered diarization output in string format. Each line contains the start and end time of segments and - corresponding speaker labels. This format is identical to `cont_stamps`. + Rendered diarization output in string format. Each line contains the start and end time of + segments and corresponding speaker labels. This format is identical to `cont_stamps`. 
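# A small sketch of `generate_diarization_output_lines` defined above, assuming
# it (and `merge_float_intervals`) is importable from this module; the timestamps
# are made-up speech intervals for two speakers.
speaker_timestamps = [
    [[0.5, 3.12]],   # speaker_0
    [[3.51, 7.26]],  # speaker_1
]
lines = generate_diarization_output_lines(speaker_timestamps, model_spk_num=2)
# Expected: ['0.500 3.120 speaker_0', '3.510 7.260 speaker_1']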
""" ovl_spk_cont_list = [[] for _ in range(len(ovl_spk_idx))] for spk_idx in range(len(ovl_spk_idx)): @@ -1307,18 +1393,21 @@ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, overlap_infer_spk_limit: int): """ - This function controls the magnitude of the sigmoid threshold based on the estimated number of speakers. As the number of - speakers becomes larger, diarization error rate is very sensitive on overlap speech detection. This function linearly increases - the threshold in proportion to the estimated number of speakers so more confident overlap speech results are reflected when - the number of estimated speakers are relatively high. + This function controls the magnitude of the sigmoid threshold based on the estimated number of + speakers. As the number of speakers becomes larger, diarization error rate is very sensitive + to overlap speech detection. This function linearly increases the threshold in proportion to + the estimated number of speakers so more confident overlap speech results are reflected when + the number of estimated speakers is relatively high. Args: estimated_num_of_spks (int): Estimated number of speakers from the clustering result. min_threshold (float): - Sigmoid threshold value from the config file. This threshold value is minimum threshold value when `estimated_num_of_spks=2` + Sigmoid threshold value from the config file. This threshold value is the minimum + threshold when `estimated_num_of_spks=2`. overlap_infer_spk_limit (int): - If the `estimated_num_of_spks` is less then `overlap_infer_spk_limit`, overlap speech estimation is skipped. + If the `estimated_num_of_spks` is less than `overlap_infer_spk_limit`, overlap speech + estimation is skipped. Returns: adaptive_threshold (float): @@ -1333,37 +1422,41 @@ def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, ove def generate_speaker_timestamps( clus_labels: List[Union[float, int]], msdd_preds: List[torch.Tensor], **params ) -> Tuple[List[str], List[str]]: - ''' - Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use clustering result for main speaker - labels and use timestamps from the predicted sigmoid values. In this function, the main speaker labels in `maj_labels` exist for - every subsegment steps while overlap speaker labels in `ovl_labels` only exist for segments where overlap-speech is occuring. + """ + Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use + clustering result for main speaker labels and use timestamps from the predicted sigmoid values. + In this function, the main speaker labels in `maj_labels` exist for every subsegment step, while + overlap speaker labels in `ovl_labels` only exist for segments where overlap speech occurs. Args: clus_labels (list): List containing integer-valued speaker clustering results. msdd_preds (list): - List containing tensors of the predicted sigmoid values. - Each tensor has shape of: (Session length, estimated number of speakers). + List containing tensors of the predicted sigmoid values. Each tensor has shape of: + (Session length, estimated number of speakers). params: Parameters for generating RTTM output and evaluation. Parameters include: - infer_overlap (bool): If False, overlap-speech will not be detected. - use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. 
If False, only MSDD output - is used for constructing output RTTM files. + infer_overlap (bool): If False, overlap speech will not be detected. + use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. + If False, only MSDD output is used for constructing output + RTTM files. overlap_infer_spk_limit (int): Above this limit, overlap-speech detection is bypassed. - use_adaptive_thres (bool): Boolean that determines whehther to use adaptive_threshold depending on the estimated - number of speakers. + use_adaptive_thres (bool): Boolean that determines whether to use adaptive thresholds + depending on the estimated number of speakers. max_overlap_spks (int): Maximum number of overlap speakers detected. Default is 2. threshold (float): Sigmoid threshold for MSDD output. Returns: maj_labels (list): - List containing string-formated single-speaker speech segment timestamps and corresponding speaker labels. + List containing string-formatted single-speaker speech segment timestamps and corresponding + speaker labels. Example: [..., '551.685 552.77 speaker_1', '552.99 554.43 speaker_0', '554.97 558.19 speaker_0', ...] ovl_labels (list): - List containing string-formated additional overlapping speech segment timestamps and corresponding speaker labels. - Note that `ovl_labels` includes only overlapping speech that is not included in `maj_labels`. + List containing string-formatted additional overlapping speech segment timestamps and + corresponding speaker labels. Note that `ovl_labels` includes only overlapping speech that + is not included in `maj_labels`. Example: [..., '152.495 152.745 speaker_1', '372.71 373.085 speaker_0', '554.97 555.885 speaker_1', ...] - ''' + """ msdd_preds.squeeze(0) estimated_num_of_spks = msdd_preds.shape[-1] overlap_speaker_list = [[] for _ in range(estimated_num_of_spks)] @@ -1398,8 +1491,7 @@ def generate_speaker_timestamps( def get_uniq_id_list_from_manifest(manifest_file: str): - """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list. - """ + """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list.""" uniq_id_list = [] with open(manifest_file, 'r', encoding='utf-8') as manifest: for i, line in enumerate(manifest.readlines()): @@ -1418,7 +1510,8 @@ def get_id_tup_dict(uniq_id_list: List[str], test_data_collection, preds_list: L uniq_id_list (list): List containing the `uniq_id` values. test_data_collection (collections.DiarizationLabelEntity): - Class instance that is containing session information such as targeted speaker indices, audio filepath and RTTM filepath. + Class instance that is containing session information such as targeted speaker indices, + audio filepath and RTTM filepath. preds_list (list): List containing tensors of predicted sigmoid values. @@ -1447,11 +1540,14 @@ def prepare_split_data(manifest_filepath, _out_dir, multiscale_args_dict, global Returns: multiscale_args_dict (dict): - - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps for each data sample. + - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps + for each data sample. - Each data sample has two keys: `multiscale_weights` and `scale_dict`. - `multiscale_weights` key contains a list containing multiscale weights. - `scale_dict` is indexed by integer keys which are scale index. 
- - Each data sample is indexed by using the following naming convention: `__` + - Each data sample is indexed by using the following naming convention: + `__` + Example: `fe_03_00106_mixed_626310_642300` """ speaker_dir = os.path.join(_out_dir, 'speaker_outputs') @@ -1580,6 +1676,86 @@ def make_rttm_with_overlap( return all_reference, all_hypothesis +def timestamps_to_pyannote_object( + speaker_timestamps: List[Tuple[float, float]], + uniq_id: str, + audio_rttm_values: Dict[str, str], + all_hypothesis: List[Tuple[str, Timeline]], + all_reference: List[Tuple[str, Timeline]], + all_uems: List[Tuple[str, Timeline]], + out_rttm_dir: str | None, +): + """ + Convert speaker timestamps to pyannote.core.Timeline object. + + Args: + speaker_timestamps (List[Tuple[float, float]]): + Timestamps of each speaker: start time and end time of each speaker. + uniq_id (str): + Unique ID of each speaker. + audio_rttm_values (Dict[str, str]): + Dictionary of manifest values. + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object. + out_rttm_dir (str | None): + Directory to save RTTMs + + Returns: + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object with an added Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object with an added Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object with an added Timeline object. + """ + offset, dur = float(audio_rttm_values.get('offset', None)), float(audio_rttm_values.get('duration', None)) + hyp_labels = generate_diarization_output_lines( + speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps) + ) + hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=uniq_id) + if out_rttm_dir is not None and os.path.exists(out_rttm_dir): + with open(f'{out_rttm_dir}/{uniq_id}.rttm', 'w') as f: + hypothesis.write_rttm(f) + all_hypothesis.append([uniq_id, hypothesis]) + rttm_file = audio_rttm_values.get('rttm_filepath', None) + if rttm_file is not None and os.path.exists(rttm_file): + uem_lines = [[offset, dur + offset]] + org_ref_labels = rttm_to_labels(rttm_file) + ref_labels = org_ref_labels + reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) + uem_obj = get_uem_object(uem_lines, uniq_id=uniq_id) + all_uems.append(uem_obj) + all_reference.append([uniq_id, reference]) + return all_hypothesis, all_reference, all_uems + + +def get_uem_object(uem_lines: List[List[float]], uniq_id: str): + """ + Generate pyannote timeline segments for uem file. + + file format + UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME + + Args: + uem_lines (list): list of session ID and start, end times. + Example: + [[0.0, 30.41], [60.04, 165.83]] + uniq_id (str): Unique session ID. + + Returns: + timeline (pyannote.core.Timeline): pyannote timeline object. 
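# Illustrative call to `get_uem_object`; the session ID and scoring regions are
# made-up values. `Timeline` and `Segment` come from pyannote.core, as imported
# at the top of this module.
uem = get_uem_object([[0.0, 30.41], [60.04, 165.83]], uniq_id="session_0")
# `uem` is a pyannote.core.Timeline holding two Segments that restrict scoring
# to those regions.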
+ """ + timeline = Timeline(uri=uniq_id) + for uem_stt_end in uem_lines: + start_time, end_time = uem_stt_end + timeline.add(Segment(float(start_time), float(end_time))) + return timeline + + def embedding_normalize(embs, use_std=False, eps=1e-10): """ Mean and l2 length normalize the input speaker embeddings diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index aea04b8cafcf..83a811ee4adb 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -23,31 +23,22 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union +import IPython.display as ipd import librosa import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pyannote.core import Annotation, Segment from pyannote.metrics import detection from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm - from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging -HAVE_IPYTHON = False -try: - import IPython.display as ipd - - HAVE_IPYTHON = True -except: - HAVE_IPYTHON = False - - """ This file contains all the utility functions required for voice activity detection. """ @@ -74,8 +65,8 @@ def prepare_manifest(config: dict) -> str: input_list = config['input'] else: raise ValueError( - "The input for manifest preparation would either be a string of the filepath to \ - manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} " + "The input for manifest preparation would either be a string of the filepath to manifest " + "or a list of {'audio_filepath': i, 'offset': 0, 'duration': null}." ) args_func = { @@ -204,8 +195,7 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list: def get_vad_stream_status(data: list) -> list: """ - Generate a list of status for each snippet in manifest. - A snippet should be in single, start, next or end status. + Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status. Used for concatenating to full audio file. Args: data (list): list of filepath of audio snippet @@ -321,9 +311,8 @@ def generate_overlap_vad_seq_per_tensor( frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str ) -> torch.Tensor: """ - Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) - to generate prediction with overlapping input window/segments - See description in generate_overlap_vad_seq. + Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate + prediction with overlapping input window/segments. See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ # This function will be refactor for vectorization but this is okay for now @@ -484,8 +473,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. 
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: @@ -498,8 +487,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te frame_length_in_sec (float): length of frame. Returns: - speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) \ - format. + speech_segments(torch.Tensor): A tensor of speech segment in the form of: + `torch.Tensor([[start1, end1], [start2, end2]])`. """ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) @@ -549,11 +538,10 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: torch.Tensor) -> torch.Tensor: """ Remove speech segments list in to_be_removed_segments from original_segments. - For example, - remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],\ - [start3, end3], [start4, end4]]), - -> - torch.Tensor([[start1, end1],[start3, end3]]) + (Example) Remove torch.Tensor([[start2, end2],[start4, end4]]) + from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), + -> + torch.Tensor([[start1, end1],[start3, end3]]) """ for y in to_be_removed_segments: original_segments = original_segments[original_segments.eq(y).all(dim=1).logical_not()] @@ -574,24 +562,30 @@ def get_gap_segments(segments: torch.Tensor) -> torch.Tensor: @torch.jit.script def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor: """ - Filter out short non_speech and speech segments. + Filter out short non-speech and speech segments. + + Reference: + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. + Implementation: + https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py - Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. - Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: - speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], \ - [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). per_args: - min_duration_on (float): threshold for small non_speech deletion - min_duration_off (float): threshold for short speech segment deletion - filter_speech_first (float): Whether to perform short speech segment deletion first. \ - Use 1.0 to represent True. + min_duration_on (float): + Threshold for small non-speech deletion. + min_duration_off (float): + Threshold for short speech segment deletion. + filter_speech_first (float): + Whether to perform short speech segment deletion first. Use 1.0 to represent True. Returns: - speech_segments(torch.Tensor): A tensor of filtered speech segment in \ - torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of filtered speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). """ if speech_segments.shape == torch.Size([0]): return speech_segments @@ -840,18 +834,19 @@ def vad_tune_threshold_on_dev( num_workers: int = 20, ) -> Tuple[dict, dict]: """ - Tune thresholds on dev set. 
Return best thresholds which gives the lowest detection error rate - (DetER) in thresholds. + Tune thresholds on dev set. Return best thresholds which gives the lowest + detection error rate (DetER) in thresholds. + Args: params (dict): dictionary of parameters to be tuned on. vad_pred_method (str): suffix of prediction file. Use to locate file. - Should be either in "frame", "mean" or "median". - groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them. - focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" - frame_length_in_sec (float): frame length. - num_workers (int): number of workers. + Should be either in "frame", "mean" or "median". + groundtruth_RTTM_dir (str): Directory of ground-truth rttm files or a file contains the paths of them. + focus_metric (str): Metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" + frame_length_in_sec (float): Frame length. + num_workers (int): Number of workers. Returns: - best_threshold (float): threshold that gives lowest DetER. + best_threshold (float): Threshold that gives lowest DetER. """ min_score = 100 all_perf = {} @@ -936,8 +931,7 @@ def check_if_param_valid(params: dict) -> bool: for j in params[i]: if not j >= 0: raise ValueError( - "Invalid inputs! All float parameters except pad_onset and pad_offset should be \ - larger than 0!" + "Invalid inputs! All float parameters except pad_onset and pad_offset should be larger than 0!" ) if not (all(i <= 1 for i in params['onset']) and all(i <= 1 for i in params['offset'])): @@ -995,7 +989,7 @@ def plot( unit_frame_len: float = 0.01, label_repeat: int = 1, xticks_step: int = 5, -) -> "ipd.Audio": +) -> ipd.Audio: """ Plot Audio and/or VAD output and/or groundtruth labels for visualization Args: @@ -1009,13 +1003,10 @@ def plot( threshold (float): threshold for prediction score (from 0 to 1). per_args(dict): a dict that stores the thresholds for postprocessing. unit_frame_len (float): unit frame length in seconds for VAD predictions. - label_repeat (int): repeat the label for this number of times to match different \ - frame lengths in preds and labels. + label_repeat (int): repeat the label for this number of times to match different + frame lengths in preds and labels. xticks_step (int): step size for xticks. """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load( @@ -1281,8 +1272,8 @@ def stitch_segmented_asr_output( fout.flush() logging.info( - f"Finish stitch segmented ASR output to {stitched_output_manifest}, \ - the speech segments info has been stored in directory {speech_segments_tensor_dir}" + f"Finish stitch segmented ASR output to {stitched_output_manifest}, " + f"the speech segments info has been stored in directory {speech_segments_tensor_dir}" ) return stitched_output_manifest @@ -1462,13 +1453,10 @@ def plot_sample_from_rttm( show: bool = True, offset: float = 0.0, unit_frame_len: float = 0.01, -) -> "ipd.Audio": +): """ Plot audio signal and frame-level labels from RTTM file """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. 
Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load(path=audio_file, sr=16000, mono=True, offset=offset, duration=max_duration) @@ -1502,17 +1490,22 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ - Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms). - The threshold 0.2 is not important, since the actual ratio will always be close to an integer - unless using frame/label. lengths that are not multiples of each other - (e.g., 15ms frame length and 20ms label length), which is not valid. - The value 0.2 here is just for easier unit testing. + Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length + (e.g., 20ms). The threshold 0.2 is not critical, as the actual ratio will always be close to an + integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame + length and 20ms label length), which is not valid. The value 0.2 is chosen for easier unit testing. + Args: - probs (List[float]): list of probabilities - labels (List[int]): list of labels - threshold (float): threshold for rounding ratio to integer + probs (List[float]): + List of probabilities. + labels (List[int]): + List of labels. + threshold (float): + Threshold for rounding the ratio to an integer. + Returns: - labels (List[int]): list of labels aligned to frames + labels (List[int]): + List of labels aligned to frames. """ frames_len = len(probs) labels_len = len(labels) @@ -1543,13 +1536,13 @@ def align_labels_to_frames(probs, labels, threshold=0.2): ratio = frames_len / labels_len res = frames_len % labels_len if ceil(ratio) - ratio < threshold: - # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a - # multiple of 2, and discard the redundant labels + # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels + # to make it a multiple of 2, and discard the redundant labels labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist() labels = labels[:frames_len] else: - # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of - # 2 and add additional labels + # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels + # to make it a multiple of 2 and add additional labels labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist() if res > 0: labels += labels[-res:] @@ -1743,3 +1736,52 @@ def frame_vad_eval_detection_error( auroc = roc_auc_score(y_true=all_labels, y_score=all_probs) report = metric.report(display=False) return auroc, report + + +def ts_vad_post_processing( + ts_vad_binary_vec: torch.Tensor, + cfg_vad_params: OmegaConf, + unit_10ms_frame_count: int = 8, + bypass_postprocessing: bool = False, +): + """ + Post-processing on diarization results using VAD style post-processing methods. + These post-processing methods are inspired by the following paper: + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: + a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). + + Args: + ts_vad_binary_vec (Tensor): + Sigmoid values of each frame and each speaker. + Dimension: (num_frames,) + cfg_vad_params (OmegaConf): + Configuration (omega config) of VAD parameters. + unit_10ms_frame_count (int, optional): + an integer indicating the number of 10ms frames in a unit. + For example, if unit_10ms_frame_count is 8, then each frame is 0.08 seconds. 
+ bypass_postprocessing (bool, optional): + If True, diarization post-processing will be bypassed. + + Returns: + speech_segments (Tensor): + start and end of each speech segment. + Dimension: (num_segments, 2) + + Example: + tensor([[ 0.0000, 3.0400], + [ 6.0000, 6.0800], + ... + [587.3600, 591.0400], + [591.1200, 597.7600]]) + """ + ts_vad_binary_frames = torch.repeat_interleave(ts_vad_binary_vec, unit_10ms_frame_count) + if not bypass_postprocessing: + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + speech_segments = filtering(speech_segments, cfg_vad_params) + else: + cfg_vad_params.onset = 0.5 + cfg_vad_params.offset = 0.5 + cfg_vad_params.pad_onset = 0.0 + cfg_vad_params.pad_offset = 0.0 + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + return speech_segments diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 915f406a3e88..5773ddf4b79b 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -16,8 +16,7 @@ import json import os from itertools import combinations -from typing import Any, Callable, Dict, Iterable, List, Optional, Union - +from typing import Any, Dict, Iterable, List, Optional, Union import numpy as np import pandas as pd @@ -313,7 +312,10 @@ class InstructionTuningAudioText(_Collection): OUTPUT_TYPE = collections.namedtuple( typename='InstructionTuningText', - field_names='id context context_type context_duration question question_type answer answer_type answer_duration speaker', + field_names=( + 'id context context_type context_duration question ' + 'question_type answer answer_type answer_duration speaker' + ), ) def __init__( @@ -437,7 +439,7 @@ def _get_len(self, field_type, data, duration_data): class ASRAudioText(AudioText): """`AudioText` collector from asr structured json files.""" - def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[Callable] = None, *args, **kwargs): + def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): """Parse lists of audio files, durations and transcripts texts. Args: @@ -460,9 +462,8 @@ def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[ [], [], ) - speakers, orig_srs, token_labels, langs = [], [], [], [] - for item in manifest.item_iter(manifests_files, parse_func=parse_func): + for item in manifest.item_iter(manifests_files): ids.append(item['id']) audio_files.append(item['audio_file']) durations.append(item['duration']) @@ -478,7 +479,10 @@ def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[ class SpeechLLMAudioTextEntity(object): + """Class for SpeechLLM dataloader instance.""" + def __init__(self, sid, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) -> None: + """Initialize the AudioTextEntity for a SpeechLLM dataloader instance.""" self.id = sid self.audio_file = audio_file self.duration = duration @@ -559,7 +563,6 @@ def __init__( ): """Instantiates audio-context-answer manifest with filters and preprocessing. - Args: ids: List of examples positions. audio_files: List of audio files. 
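# Hypothetical configuration for `ts_vad_post_processing` above. The parameter
# names mirror the `binarization`/`filtering` arguments in this file, but the
# exact set and values here are placeholders rather than tuned settings.
from omegaconf import OmegaConf
import torch

cfg_vad_params = OmegaConf.create(
    {
        "onset": 0.5,             # threshold for starting a speech segment
        "offset": 0.5,            # threshold for ending a speech segment
        "pad_onset": 0.0,         # seconds of padding added before each segment
        "pad_offset": 0.0,        # seconds of padding added after each segment
        "min_duration_on": 0.2,   # threshold for removing short non-speech gaps
        "min_duration_off": 0.2,  # threshold for removing short speech segments
        "filter_speech_first": 1.0,
        "frame_length_in_sec": 0.01,
    }
)
# One speaker's frame-level sigmoid outputs at 80 ms resolution (unit_10ms_frame_count=8).
spk_sigmoids = torch.tensor([0.1, 0.8, 0.9, 0.85, 0.2, 0.1])
segments = ts_vad_post_processing(spk_sigmoids, cfg_vad_params, unit_10ms_frame_count=8)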
@@ -770,7 +773,8 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: elif 'question' in item: # compatability with old manifests that uses 'question' as context key logging.warning( - f"Neither `{self.context_key}` is found nor `context_file` is set, but found `question` in item: {item}", + f"Neither `{self.context_key}` is found nor" + f"`context_file` is set, but found `question` in item: {item}", mode=logging_mode.ONCE, ) item['context'] = item.pop('question') @@ -867,7 +871,8 @@ def __init__( else: logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.") logging.info( - f"Dataset successfully loaded with {len(data)} items and total duration provided from manifest is {total_duration / 3600: .2f} hours." + f"Dataset successfully loaded with {len(data)} items " + f"and total duration provided from manifest is {total_duration / 3600: .2f} hours." ) self.uniq_labels = sorted(set(map(lambda x: x.label, data))) @@ -1008,13 +1013,15 @@ def __init__( if len(data) == max_number: break - logging.info("# {} files loaded including # {} unique labels".format(len(data), len(self.uniq_labels))) + logging.info(f"# {len(data)} files loaded including # {len(self.uniq_labels)} unique labels") super().__init__(data) def relative_speaker_parser(self, seq_label): """Convert sequence of speaker labels to relative labels. Convert sequence of absolute speaker to sequence of relative speaker [E A C A E E C] -> [0 1 2 1 0 0 2] - In this seq of label , if label do not appear before, assign new relative labels len(pos); else reuse previous assigned relative labels. + In this seq of label , if label do not appear before, assign new relative labels len(pos); + else reuse previous assigned relative labels. + Args: seq_label (str): A string of a sequence of labels. @@ -1051,10 +1058,13 @@ def __init__( """Parse lists of feature files and sequences of labels. Args: - manifests_files: Either single string file or list of such - - manifests to yield items from. - max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. + manifests_files: + Either single string file or list of such manifests to yield items from. + max_number: + Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. + index_by_file_id: + If True, saves a mapping from filename base (ID) to index in data; + pass to `FeatureSequenceLabel` constructor. """ feature_files, seq_labels = [], [] @@ -1216,24 +1226,26 @@ def __init__( **kwargs, ): """ - Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since diarization model infers only - two speakers, speaker pairs are generated from the total number of speakers in the session. + Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization + model infers only two speakers, speaker pairs are generated from the total number of speakers in + the session. Args: manifest_filepath (str): - Path to input manifest json files. + Path to input manifest JSON files. emb_dict (Dict): Dictionary containing cluster-average embeddings and speaker mapping information. clus_label_dict (Dict): Segment-level speaker labels from clustering results. round_digit (int): - Number of digits to be rounded. + Number of digits to round. 
seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. pairwise_infer (bool): - If True, this dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the diarization system to merge the individual results. + If True, this dataset class operates in inference mode. In inference mode, a set of + speakers in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to + merge the individual results. *args: Args to pass to `SpeechLabel` constructor. **kwargs: Kwargs to pass to `SpeechLabel` constructor. """ @@ -1371,6 +1383,188 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: return item +class EndtoEndDiarizationLabel(_Collection): + """List of end-to-end diarization audio-label correspondence with preprocessing.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='DiarizationLabelEntity', + field_names='audio_file uniq_id duration rttm_file offset', + ) + + def __init__( + self, + audio_files: List[str], + uniq_ids: List[str], + durations: List[float], + rttm_files: List[str], + offsets: List[float], + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + ): + """ + Instantiates audio-label manifest with filters and preprocessing. + + This method initializes the EndtoEndDiarizationLabel object by processing the input data + and applying optional filters and sorting. + + Args: + audio_files (List[str]): List of audio file paths. + uniq_ids (List[str]): List of unique identifiers for each audio file. + durations (List[float]): List of float durations for each audio file. + rttm_files (List[str]): List of RTTM path strings (Groundtruth diarization annotation file). + offsets (List[float]): List of offsets or None for each audio file. + max_number (Optional[int]): Maximum number of samples to collect. Defaults to None. + do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False. + index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. + Defaults to False. + + """ + if index_by_file_id: + self.mapping = {} + output_type = self.OUTPUT_TYPE + data, duration_filtered = [], 0.0 + + zipped_items = zip(audio_files, uniq_ids, durations, rttm_files, offsets) + for ( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) in zipped_items: + + if duration is None: + duration = 0 + + data.append( + output_type( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) + ) + + if index_by_file_id: + if isinstance(audio_file, list): + if len(audio_file) == 0: + raise ValueError(f"Empty audio file list: {audio_file}") + file_id, _ = os.path.splitext(os.path.basename(audio_file)) + self.mapping[file_id] = len(data) - 1 + + # Max number of entities filter. 
+            if len(data) == max_number:
+                break
+
+        if do_sort_by_duration:
+            if index_by_file_id:
+                logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.")
+            else:
+                data.sort(key=lambda entity: entity.duration)
+
+        logging.info(
+            "Filtered duration for loading collection is %f.",
+            duration_filtered,
+        )
+        logging.info(f"Total {len(data)} session files loaded, accounting for {len(audio_files)} audio clips")
+
+        super().__init__(data)
+
+
+class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel):
+    """End-to-end speaker diarization data sample collector from structured json files."""
+
+    def __init__(
+        self,
+        manifests_files: Union[str, List[str]],
+        round_digits=2,
+        *args,
+        **kwargs,
+    ):
+        """
+        Parse lists of audio files, durations, and RTTM (diarization annotation) files
+        for end-to-end speaker diarization. Each manifest entry maps an audio file to
+        its ground-truth RTTM annotation.
+
+        Args:
+            manifests_files (Union[str, List[str]]):
+                Path(s) to input manifest JSON file(s).
+            round_digits (int):
+                Number of digits to round RTTM timestamps to. Defaults to 2.
+            *args: Args to pass to `EndtoEndDiarizationLabel` constructor.
+            **kwargs: Kwargs to pass to `EndtoEndDiarizationLabel` constructor.
+        """
+        self.round_digits = round_digits
+        audio_files, uniq_ids, durations, rttm_files, offsets = (
+            [],
+            [],
+            [],
+            [],
+            [],
+        )
+
+        for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item_rttm):
+            # Training mode
+            rttm_labels = []
+            with open(item['rttm_file'], 'r') as f:
+                for rttm_line in f.readlines():
+                    rttm = rttm_line.strip().split()
+                    start = round(float(rttm[3]), round_digits)
+                    end = round(float(rttm[4]), round_digits) + round(float(rttm[3]), round_digits)
+                    speaker = rttm[7]
+                    rttm_labels.append('{} {} {}'.format(start, end, speaker))
+            audio_files.append(item['audio_file'])
+            uniq_ids.append(item['uniq_id'])
+            durations.append(item['duration'])
+            rttm_files.append(item['rttm_file'])
+            offsets.append(item['offset'])
+
+        super().__init__(
+            audio_files,
+            uniq_ids,
+            durations,
+            rttm_files,
+            offsets,
+            *args,
+            **kwargs,
+        )
+
+    def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]:
+        """Parse each manifest line and return it as a Dict with the RTTM annotation path."""
+        item = json.loads(line)
+        if 'audio_filename' in item:
+            item['audio_file'] = item.pop('audio_filename')
+        elif 'audio_filepath' in item:
+            item['audio_file'] = item.pop('audio_filepath')
+        else:
+            raise ValueError(
+                f"Manifest file has invalid json line structure: {line} without proper audio file key."
+            )
+        if isinstance(item['audio_file'], list):
+            item['audio_file'] = [os.path.expanduser(audio_file_path) for audio_file_path in item['audio_file']]
+        else:
+            item['audio_file'] = os.path.expanduser(item['audio_file'])
+
+        if not isinstance(item['audio_file'], list):
+            if 'uniq_id' not in item:
+                item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0]
+        elif 'uniq_id' not in item:
+            raise ValueError(f"Manifest file has invalid json line structure: {line} without proper uniq_id key.")
+
+        if 'duration' not in item:
+            raise ValueError(f"Manifest file has invalid json line structure: {line} without proper duration key.")
+        item = dict(
+            audio_file=item['audio_file'],
+            uniq_id=item['uniq_id'],
+            duration=item['duration'],
+            rttm_file=item['rttm_filepath'],
+            offset=item.get('offset', None),
+        )
+        return item
+
+
 class Audio(_Collection):
     """Prepare a list of all audio items, filtered by duration."""
 
 
@@ -1641,7 +1836,8 @@ def __init__(
             manifests_files: Either single string file or list of such -
                 manifests to yield items from.
             max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor.
-            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor.
+            index_by_file_id: If True, saves a mapping from filename base (ID) to index in data;
+                pass to `FeatureSequenceLabel` constructor.
         """
 
         feature_files, labels, durations = [], [], []
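As a sanity check on the relative-label mapping described in the `relative_speaker_parser` docstring above ([E A C A E E C] -> [0 1 2 1 0 0 2]), here is a minimal standalone sketch; the function name `to_relative_labels` is illustrative and not part of the diff.

from typing import List


def to_relative_labels(seq_label: List[str]) -> List[int]:
    """Map absolute speaker labels to relative indices, in order of first appearance."""
    pos = {}  # absolute label -> relative index assigned at first appearance
    relative = []
    for label in seq_label:
        if label not in pos:
            pos[label] = len(pos)  # an unseen label gets the next available index, i.e. len(pos)
        relative.append(pos[label])
    return relative


assert to_relative_labels(list("EACAEEC")) == [0, 1, 2, 1, 0, 0, 2]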
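The `pairwise_infer` docstring describes splitting the speakers of a session into all two-speaker pairs before merging the per-pair results. A minimal sketch of that pair generation, using `itertools.combinations` purely for illustration (the diff itself does not use this helper):

from itertools import combinations

num_speakers = 3
speaker_pairs = list(combinations(range(num_speakers), 2))
print(speaker_pairs)  # [(0, 1), (0, 2), (1, 2)], the same pairs as the docstring example, up to ordering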
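For reference, the RTTM parsing loop in `EndtoEndDiarizationSpeechLabel` reads standard RTTM SPEAKER lines, where field 3 is the segment onset, field 4 the duration, and field 7 the speaker label. A small sketch with a made-up RTTM line (the session name, times, and speaker label below are illustrative only):

rttm_line = "SPEAKER sess01 1 12.34 2.45 <NA> <NA> speaker_0 <NA> <NA>"
round_digits = 2

rttm = rttm_line.strip().split()
start = round(float(rttm[3]), round_digits)  # onset in seconds -> 12.34
end = round(float(rttm[4]), round_digits) + round(float(rttm[3]), round_digits)  # onset + duration -> 14.79
speaker = rttm[7]  # 'speaker_0'

print('{} {} {}'.format(start, end, speaker))  # "12.34 14.79 speaker_0"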