Sortformer Diarizer 4spk v1 model PR Part 1: models, modules and dataloaders #11282

Status: Open. Wants to merge 30 commits into main. Changes shown from 5 commits.

Commits (30):
e69ec8e  Adding the first pr files models and dataset (tango4j, Nov 14, 2024)
2914325  Tested all unit-test files (tango4j, Nov 14, 2024)
9a468ac  Name changes on yaml files and train example (tango4j, Nov 14, 2024)
a910d30  Merge branch 'main' into sortformer/pr_01 (tango4j, Nov 14, 2024)
2f44fe1  Apply isort and black reformatting (tango4j, Nov 14, 2024)
4ddc59b  Reflecting comments and removing unnecessary parts for this PR (tango4j, Nov 15, 2024)
43d95f0  Resolved conflicts (tango4j, Nov 15, 2024)
40e9f95  Apply isort and black reformatting (tango4j, Nov 15, 2024)
f7f84bb  Adding docstrings to reflect the PR comments (tango4j, Nov 15, 2024)
95acd79  Resolved the new conflict (tango4j, Nov 15, 2024)
919f4da  Merge branch 'main' into sortformer/pr_01 (tango4j, Nov 15, 2024)
4134e25  removed the unused find_first_nonzero (tango4j, Nov 15, 2024)
d3432e5  Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in… (tango4j, Nov 15, 2024)
5dd4d4c  Apply isort and black reformatting (tango4j, Nov 15, 2024)
ca5eea3  Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in… (tango4j, Nov 15, 2024)
9d493c0  Merge branch 'main' into sortformer/pr_01 (tango4j, Nov 15, 2024)
037f61e  Fixed all pylint issues (tango4j, Nov 15, 2024)
a8bc048  Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in… (tango4j, Nov 15, 2024)
cb23268  Apply isort and black reformatting (tango4j, Nov 15, 2024)
4a266b9  Resolving pylint issues (tango4j, Nov 15, 2024)
5e4e9c8  Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in… (tango4j, Nov 15, 2024)
c31c60c  Merge branch 'main' into sortformer/pr_01 (tango4j, Nov 15, 2024)
6e2225e  Apply isort and black reformatting (tango4j, Nov 15, 2024)
ab93b17  Removing unused varialbe in audio_to_diar_label.py (tango4j, Nov 15, 2024)
4f3ee66  Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in… (tango4j, Nov 15, 2024)
3f24b82  Merge branch 'main' into sortformer/pr_01 (tango4j, Nov 16, 2024)
f49e107  Merge branch 'main' into sortformer/pr_01 (tango4j, Nov 16, 2024)
7dea01b  Fixed docstrings in training script (tango4j, Nov 16, 2024)
2a99d53  Merge branch 'sortformer/pr_01' of https://github.com/tango4j/NeMo in… (tango4j, Nov 16, 2024)
71d515f  Line-too-long issue from Pylint fixed (tango4j, Nov 16, 2024)
@@ -0,0 +1,218 @@
# Sortformer Diarizer is an end-to-end speaker diarization model based solely on a Transformer-encoder architecture.
# Model name convention for Sortformer Diarizer: sortformer_diarizer_<FC_layer>_<TF_layer>_<loss_type>_loss.yaml
# Example: `sortformer_diarizer_FC18_TF18_hybrid_loss.yaml` has 18 FastConformer layers and 18 Transformer layers.
# The Sortformer Diarizer checkpoint (.ckpt) and NeMo file (.nemo) contain the Fast Conformer encoder (NEST encoder); the pre-trained NEST model is loaded along with the Transformer encoder layers.
# Example manifest line for training:
# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"}
name: "SortFormerDiarizer"
sample_rate: 16000
num_workers: 18
batch_size: 8

model:
pil_weight: 0.5
ats_weight: 0.5
num_workers: ${num_workers}
fc_d_model: 512
tf_d_model: 192
max_num_of_spks: 4 # Number of speakers per model. This is currently fixed at 4.
session_len_sec: 90

train_ds:
manifest_filepath: ???
sample_rate: ${sample_rate}
num_spks: ${model.max_num_of_spks}
session_len_sec: ${model.session_len_sec}
soft_label_thres: 0.5
soft_targets: False
Review comment (Collaborator):
Maybe add some short explanation for soft_label_thres and soft_targets?

Reply (Author):
Added some comments.

    labels: null
    batch_size: ${batch_size}
    shuffle: True
    num_workers: ${num_workers}
    validation_mode: False
    # lhotse config
    use_lhotse: False
    use_bucketing: True
    num_buckets: 10
    bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90]
    pin_memory: True
    min_duration: 80
    max_duration: 90
    batch_duration: 400
    quadratic_duration: 1200
    bucket_buffer_size: 20000
    shuffle_buffer_size: 10000
    window_stride: ${model.preprocessor.window_stride}
    subsampling_factor: ${model.encoder.subsampling_factor}

  validation_ds:
    manifest_filepath: ???
    is_tarred: False
    tarred_audio_filepaths: null
    sample_rate: ${sample_rate}
    num_spks: ${model.max_num_of_spks}
    session_len_sec: ${model.session_len_sec}
    soft_label_thres: 0.5
    soft_targets: False
    labels: null
    batch_size: ${batch_size}
    shuffle: False
    num_workers: ${num_workers}
    validation_mode: True
    # lhotse config
    use_lhotse: False
    use_bucketing: False
    drop_last: False
    pin_memory: True
    window_stride: ${model.preprocessor.window_stride}
    subsampling_factor: ${model.encoder.subsampling_factor}

  test_ds:
    manifest_filepath: null
    is_tarred: False
    tarred_audio_filepaths: null
    sample_rate: 16000
    num_spks: ${model.max_num_of_spks}
    session_len_sec: ${model.session_len_sec}
    soft_label_thres: 0.5
    soft_targets: False
    labels: null
    batch_size: ${batch_size}
    shuffle: False
    seq_eval_mode: True
    num_workers: ${num_workers}
    validation_mode: True
    # lhotse config
    use_lhotse: False
    use_bucketing: False
    drop_last: False
    pin_memory: True
    window_stride: ${model.preprocessor.window_stride}
    subsampling_factor: ${model.encoder.subsampling_factor}

  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    normalize: "per_feature"
    window_size: 0.025
    sample_rate: ${sample_rate}
    window_stride: 0.01
    window: "hann"
    features: 80
    n_fft: 512
    frame_splicing: 1
    dither: 0.00001

  sortformer_modules:
Review comment (stevehuang52, Collaborator, Nov 14, 2024):
Does sortformer_modules mean that we can have several different components under this section? If not, maybe just use sortformer_module to align with other fields (e.g., encoder). The SortformerModules name could also get rid of the "s", in my opinion.

Reply (tango4j, Author):
Ivan and I decided to put every Sortformer module (trainable weights and functions for streaming) under this section, which is why it is plural with an "s". To be more precise it should be "SortformerAuxiliaryModules", but for brevity it is "sortformer_modules".

    _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules
    num_spks: ${model.max_num_of_spks} # Number of speakers per model. This is currently fixed at 4.
    dropout_rate: 0.5 # Dropout rate
    fc_d_model: ${model.fc_d_model}
    tf_d_model: ${model.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module

  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1
    n_layers: 18
    d_model: ${model.fc_d_model}

    # Sub-sampling parameters
    subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
    subsampling_factor: 8 # must be power of 2 for striding and vggnet
    subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
    causal_downsampling: false

    # Feed forward module's params
    ff_expansion_factor: 4

    # Multi-headed Attention Module's params
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 8 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
    att_context_style: regular # regular or chunked_limited
    xscaling: true # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000

    # Convolution module's params
    conv_kernel_size: 9
    conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
    conv_context_size: null

    ### regularization
    dropout: 0.1 # The dropout used in most of the Conformer Modules
    dropout_pre_encoder: 0.1 # The dropout used before the encoder
    dropout_emb: 0.0 # The dropout used for embeddings
    dropout_att: 0.1 # The dropout for multi-headed attention modules

    # set to non-zero to enable stochastic depth
    stochastic_depth_drop_prob: 0.0
    stochastic_depth_mode: linear # linear or uniform
    stochastic_depth_start_layer: 1

  transformer_encoder:
    _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder
    num_layers: 18
    hidden_size: ${model.tf_d_model} # Needs to be a multiple of num_attention_heads
    inner_size: 768
    num_attention_heads: 8
    attn_score_dropout: 0.5
    attn_layer_dropout: 0.5
    ffn_dropout: 0.5
    hidden_act: relu
    pre_ln: False
    pre_ln_final_layer_norm: True

  loss:
    _target_: nemo.collections.asr.losses.bce_loss.BCELoss
    weight: null # Weight for binary cross-entropy loss. Either `null` or a list (e.g., [0.5, 0.5]).
    reduction: mean

  lr: 0.0001
  optim:
    name: adamw
    lr: ${model.lr}
    # optimizer arguments
    betas: [0.9, 0.98]
    weight_decay: 1e-3

    sched:
      name: InverseSquareRootAnnealing
      warmup_steps: 2500
      warmup_ratio: null
      min_lr: 1e-06

trainer:
  devices: 1 # number of gpus (devices)
  accelerator: gpu
  max_epochs: 800
  max_steps: -1 # computed at runtime if not set
  num_nodes: 1
  strategy: ddp_find_unused_parameters_true # Could be "ddp"
  accumulate_grad_batches: 1
  deterministic: True
  enable_checkpointing: False
  logger: False
  log_every_n_steps: 1 # Interval of logging.
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations

exp_manager:
  use_datetime_version: False
  exp_dir: null
  name: ${name}
  resume_if_exists: True
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  resume_ignore_no_checkpoint: True
  create_tensorboard_logger: True
  create_checkpoint_callback: True
  create_wandb_logger: False
  checkpoint_callback_params:
    monitor: "val_f1_acc"
    mode: "max"
    save_top_k: 9
    every_n_epochs: 1
  wandb_logger_kwargs:
    resume: True
    name: null
    project: null
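
For orientation, here is a minimal sketch of a training entry point that consumes a config like the one above, in the style of NeMo example scripts. The model class name (SortformerEncLabelModel), the config path, and the config name are assumptions based on this PR's description, not confirmed API:

# Minimal training-script sketch. Assumptions: SortformerEncLabelModel is the
# model class added in this PR series; config_path/config_name are illustrative.
import lightning.pytorch as pl  # on older NeMo versions: import pytorch_lightning as pl

from nemo.collections.asr.models import SortformerEncLabelModel  # assumed class name
from nemo.core.config import hydra_runner
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_FC18_TF18_hybrid_loss")
def main(cfg):
    trainer = pl.Trainer(**cfg.trainer)                  # built from the `trainer` section above
    exp_manager(trainer, cfg.get("exp_manager", None))   # checkpointing/logging from `exp_manager`
    model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer)  # builds train_ds/validation_ds dataloaders
    trainer.fit(model)


if __name__ == "__main__":
    main()

Manifest paths such as model.train_ds.manifest_filepath (marked ??? above) would be supplied as Hydra overrides on the command line.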
@@ -0,0 +1,17 @@
# Postprocessing parameters for timestamp outputs from speaker diarization models.
# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
# These parameters were optimized with the hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
# These parameters were optimized on the development split of the DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477.
# Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055.
parameters:
  window_length_in_sec: 0.0 # Not used
  shift_length_in_sec: 0.0 # Not used
  smoothing: False # Not used
  overlap: 0.5 # Not used
  onset: 0.53 # Onset threshold for detecting the beginning of a speech segment
  offset: 0.49 # Offset threshold for detecting the end of a speech segment
  pad_onset: 0.23 # Duration added before each speech segment
  pad_offset: 0.01 # Duration added after each speech segment
  min_duration_on: 0.42 # Threshold for small non-speech deletion
  min_duration_off: 0.34 # Threshold for short speech segment deletion
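
To make these parameters concrete, here is a small, self-contained sketch of the onset/offset binarization they describe. It is an illustration of the scheme, not NeMo's exact implementation; `probs` (per-frame speaker activity probabilities for one speaker) and `frame_hop` (frame step in seconds) are hypothetical inputs:

# Sketch of Medennikov-style postprocessing driven by the parameters above.
# Hypothetical inputs: `probs` holds per-frame activity probabilities for one
# speaker; `frame_hop` is the frame step in seconds (0.08 s for this model:
# window_stride 0.01 x subsampling_factor 8).
import numpy as np


def binarize(probs: np.ndarray, frame_hop: float, onset: float, offset: float,
             pad_onset: float, pad_offset: float,
             min_duration_on: float, min_duration_off: float):
    segments, start, active = [], 0.0, False
    for i, p in enumerate(probs):
        t = i * frame_hop
        if not active and p >= onset:      # speech starts when prob crosses `onset`
            start, active = t, True
        elif active and p < offset:        # speech ends when prob falls below `offset`
            segments.append([start - pad_onset, t + pad_offset])  # pad both ends
            active = False
    if active:
        segments.append([start - pad_onset, len(probs) * frame_hop + pad_offset])
    # Delete small non-speech gaps (shorter than `min_duration_on`) by merging.
    merged = []
    for seg in segments:
        if merged and seg[0] - merged[-1][1] < min_duration_on:
            merged[-1][1] = seg[1]
        else:
            merged.append(seg)
    # Delete short speech segments (shorter than `min_duration_off`).
    return [(max(0.0, s), e) for s, e in merged if e - s >= min_duration_off]

The exact ordering of the gap-merging and short-segment-deletion steps can change results slightly; NeMo's own implementation should be treated as the reference.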
@@ -0,0 +1,17 @@
# Postprocessing parameters for timestamp outputs from speaker diarization models.
# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
# These parameters were optimized with the hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
# These parameters were optimized on the CallHome dataset from NIST SRE 2000 Disc8, specifically on the split2 defined in: Kaldi, "Kaldi x-vector recipe v2," https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2.
# Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649.
parameters:
  window_length_in_sec: 0.0 # Not used
  shift_length_in_sec: 0.0 # Not used
  smoothing: False # Not used
  overlap: 0.5 # Not used
  onset: 0.64 # Onset threshold for detecting the beginning of a speech segment
  offset: 0.74 # Offset threshold for detecting the end of a speech segment
  pad_onset: 0.06 # Duration added before each speech segment
  pad_offset: 0.0 # Duration added after each speech segment
  min_duration_on: 0.1 # Threshold for small non-speech deletion
  min_duration_off: 0.15 # Threshold for short speech segment deletion
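
A hypothetical usage of the `binarize` sketch above with this CallHome parameter set, loading the YAML via OmegaConf (the file name is illustrative only):

# Hypothetical usage; the YAML file name is illustrative only.
import numpy as np
from omegaconf import OmegaConf

cfg = OmegaConf.load("sortformer_postprocessing_callhome.yaml").parameters
probs = np.random.rand(1125)  # e.g., 90 s of per-frame probabilities at 12.5 fps
segments = binarize(
    probs, frame_hop=0.08,
    onset=cfg.onset, offset=cfg.offset,
    pad_onset=cfg.pad_onset, pad_offset=cfg.pad_offset,
    min_duration_on=cfg.min_duration_on, min_duration_off=cfg.min_duration_off,
)
print(segments)  # [(start_sec, end_sec), ...] speech segments for one speaker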