Sortformer Diarizer 4spk v1 model PR Part 1: models, modules and dataloaders #11282
@@ -0,0 +1,218 @@
# Sortformer Diarizer is an end-to-end speaker diarization model based solely on a Transformer-encoder style architecture.
# Model name convention for Sortformer Diarizer: sortformer_diarizer_<FC_layer>_<TF_layer>_<loss_type>_loss.yaml
# (Example) `sortformer_diarizer_FC18_TF18_hybrid_loss.yaml` has 18 FastConformer layers and 18 Transformer layers.
# The Sortformer Diarizer checkpoint (.ckpt) and NeMo file (.nemo) contain a Fast Conformer encoder (NEST encoder); the pre-trained NEST model is loaded along with the Transformer encoder layers.
# Example: a manifest line for training
# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"}
name: "SortFormerDiarizer"
sample_rate: 16000
num_workers: 18
batch_size: 8
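For convenience, here is a minimal sketch (plain Python, not a NeMo API) of writing one training-manifest line in the JSON-lines format shown in the header comment; all paths are placeholders.

```python
import json

# Placeholder paths and values; each manifest line is one JSON object.
entry = {
    "audio_filepath": "/path/to/audio01.wav",   # session audio file
    "offset": 390.83,                           # start time (s) within the file
    "duration": 90.00,                          # matches model.session_len_sec
    "text": "-",                                # unused for diarization
    "num_speakers": 2,                          # speakers in this segment
    "rttm_filepath": "/path/to/audio01.rttm",   # ground-truth RTTM annotation
}

with open("train_manifest.json", "w") as f:
    f.write(json.dumps(entry) + "\n")           # one object per line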
model:
  pil_weight: 0.5 # Weight for the permutation-invariant loss (PIL) term
  ats_weight: 0.5 # Weight for the arrival-time sort (ATS) loss term
  num_workers: ${num_workers}
  fc_d_model: 512 # Hidden dimension size of the Fast Conformer encoder
  tf_d_model: 192 # Hidden dimension size of the Transformer encoder
  max_num_of_spks: 4 # Number of speakers per model. This is currently fixed at 4.
  session_len_sec: 90 # Session length (seconds) seen by the model
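`pil_weight` and `ats_weight` weight the two terms of the hybrid objective referenced by the `hybrid_loss` filename convention above: a BCE term against arrival-time-sorted (ATS) targets and a BCE term against the best-permutation (PIL) targets. A hedged sketch of just the weighting, assuming the two target tensors are built upstream:

```python
import torch
import torch.nn.functional as F

def hybrid_loss(preds, ats_targets, pil_targets, ats_weight=0.5, pil_weight=0.5):
    # Weighted sum of two BCE terms: one against arrival-time-sorted (ATS)
    # targets, one against the best-permutation (PIL) targets. Constructing
    # those targets (sorting / permutation search) is assumed to happen
    # upstream and is not shown here.
    ats_loss = F.binary_cross_entropy(preds, ats_targets)
    pil_loss = F.binary_cross_entropy(preds, pil_targets)
    return ats_weight * ats_loss + pil_weight * pil_loss

# Shapes follow the config: (batch, frames, max_num_of_spks)
preds = torch.sigmoid(torch.randn(2, 1125, 4))
targets = torch.randint(0, 2, (2, 1125, 4)).float()
print(hybrid_loss(preds, targets, targets))
```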
  train_ds:
    manifest_filepath: ???
    sample_rate: ${sample_rate}
    num_spks: ${model.max_num_of_spks}
    session_len_sec: ${model.session_len_sec}
    soft_label_thres: 0.5 # Threshold for binarizing the frame-level soft speaker labels
    soft_targets: False # If True, train on the soft labels themselves instead of the binarized ones
    labels: null
    batch_size: ${batch_size}
    shuffle: True
    num_workers: ${num_workers}
    validation_mode: False
    # lhotse config
    use_lhotse: False
    use_bucketing: True
    num_buckets: 10
    bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90]
    pin_memory: True
    min_duration: 80
    max_duration: 90
    batch_duration: 400
    quadratic_duration: 1200
    bucket_buffer_size: 20000
    shuffle_buffer_size: 10000
    window_stride: ${model.preprocessor.window_stride}
    subsampling_factor: ${model.encoder.subsampling_factor}
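The last two entries tie the dataloader to the model's frame rate: with a window stride of 0.01 s and a subsampling factor of 8, each diarization frame covers 80 ms, so a 90 s session yields 9000 mel frames and 1125 encoder output frames. A quick back-of-the-envelope check (plain Python, not a NeMo API):

```python
session_len_sec = 90
window_stride = 0.01        # model.preprocessor.window_stride
subsampling_factor = 8      # model.encoder.subsampling_factor

feature_frames = int(session_len_sec / window_stride)   # 9000 mel-spectrogram frames
encoder_frames = feature_frames // subsampling_factor   # 1125 diarization output frames
print(encoder_frames, "frames of", window_stride * subsampling_factor, "s each")
```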
  validation_ds:
    manifest_filepath: ???
    is_tarred: False
    tarred_audio_filepaths: null
    sample_rate: ${sample_rate}
    num_spks: ${model.max_num_of_spks}
    session_len_sec: ${model.session_len_sec}
    soft_label_thres: 0.5
    soft_targets: False
    labels: null
    batch_size: ${batch_size}
    shuffle: False
    num_workers: ${num_workers}
    validation_mode: True
    # lhotse config
    use_lhotse: False
    use_bucketing: False
    drop_last: False
    pin_memory: True
    window_stride: ${model.preprocessor.window_stride}
    subsampling_factor: ${model.encoder.subsampling_factor}
  test_ds:
    manifest_filepath: null
    is_tarred: False
    tarred_audio_filepaths: null
    sample_rate: 16000
    num_spks: ${model.max_num_of_spks}
    session_len_sec: ${model.session_len_sec}
    soft_label_thres: 0.5
    soft_targets: False
    labels: null
    batch_size: ${batch_size}
    shuffle: False
    seq_eval_mode: True
    num_workers: ${num_workers}
    validation_mode: True
    # lhotse config
    use_lhotse: False
    use_bucketing: False
    drop_last: False
    pin_memory: True
    window_stride: ${model.preprocessor.window_stride}
    subsampling_factor: ${model.encoder.subsampling_factor}
  preprocessor:
    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
    normalize: "per_feature"
    window_size: 0.025
    sample_rate: ${sample_rate}
    window_stride: 0.01
    window: "hann"
    features: 80
    n_fft: 512
    frame_splicing: 1
    dither: 0.00001
  sortformer_modules:
    _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules
    num_spks: ${model.max_num_of_spks} # Number of speakers per model. This is currently fixed at 4.
    dropout_rate: 0.5 # Dropout rate
    fc_d_model: ${model.fc_d_model}
    tf_d_model: ${model.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module

Review thread on `sortformer_modules`: "Ivan and I decided to put every Sortformer module (trainable weights or functions for streaming) here, so that's why the name is plural with an 's'."
  encoder:
    _target_: nemo.collections.asr.modules.ConformerEncoder
    feat_in: ${model.preprocessor.features}
    feat_out: -1
    n_layers: 18
    d_model: ${model.fc_d_model}

    # Sub-sampling parameters
    subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
    subsampling_factor: 8 # must be a power of 2 for striding and vggnet
    subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
    causal_downsampling: false

    # Feed forward module's params
    ff_expansion_factor: 4

    # Multi-headed Attention Module's params
    self_attention_model: rel_pos # rel_pos or abs_pos
    n_heads: 8 # may need to be lower for smaller d_models
    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
    att_context_size: [-1, -1] # -1 means unlimited context
    att_context_style: regular # regular or chunked_limited
    xscaling: true # scales up the input embeddings by sqrt(d_model)
    untie_biases: true # unties the biases of the TransformerXL layers
    pos_emb_max_len: 5000

    # Convolution module's params
    conv_kernel_size: 9
    conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
    conv_context_size: null

    ### regularization
    dropout: 0.1 # The dropout used in most of the Conformer modules
    dropout_pre_encoder: 0.1 # The dropout used before the encoder
    dropout_emb: 0.0 # The dropout used for embeddings
    dropout_att: 0.1 # The dropout for multi-headed attention modules

    # set to non-zero to enable stochastic depth
    stochastic_depth_drop_prob: 0.0
    stochastic_depth_mode: linear # linear or uniform
    stochastic_depth_start_layer: 1
  transformer_encoder:
    _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder
    num_layers: 18
    hidden_size: ${model.tf_d_model} # Needs to be a multiple of num_attention_heads
    inner_size: 768
    num_attention_heads: 8
    attn_score_dropout: 0.5
    attn_layer_dropout: 0.5
    ffn_dropout: 0.5
    hidden_act: relu
    pre_ln: False
    pre_ln_final_layer_norm: True
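As the `hidden_size` comment notes, `tf_d_model` must be divisible by `num_attention_heads`; with 192 and 8 heads, each head gets 24 dimensions. A minimal sanity check:

```python
hidden_size = 192            # model.tf_d_model
num_attention_heads = 8
assert hidden_size % num_attention_heads == 0, \
    "hidden_size must be a multiple of num_attention_heads"
head_dim = hidden_size // num_attention_heads  # 24 dimensions per attention head
```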
  loss:
    _target_: nemo.collections.asr.losses.bce_loss.BCELoss
    weight: null # Weight for binary cross-entropy loss. Either `null` or a list, e.g. [0.5, 0.5]
    reduction: mean
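One plausible reading of the two-element `weight` list (an assumption, not confirmed from the BCELoss source) is a pair of per-class weights for negative- and positive-labeled frames. A hedged sketch under that assumption:

```python
import torch

def weighted_bce(preds, targets, weight=None):
    # Element-wise BCE; if `weight` is given, scale negative- and
    # positive-labeled elements by w_neg and w_pos respectively.
    # The [w_neg, w_pos] ordering is assumed, not taken from NeMo.
    eps = 1e-8
    loss = -(targets * torch.log(preds + eps)
             + (1 - targets) * torch.log(1 - preds + eps))
    if weight is not None:
        w_neg, w_pos = weight
        loss = loss * (targets * w_pos + (1 - targets) * w_neg)
    return loss.mean()  # reduction: mean, as in the config
```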
  lr: 0.0001
  optim:
    name: adamw
    lr: ${model.lr}
    # optimizer arguments
    betas: [0.9, 0.98]
    weight_decay: 1e-3

    sched:
      name: InverseSquareRootAnnealing
      warmup_steps: 2500
      warmup_ratio: null
      min_lr: 1e-06
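For intuition, here is a hedged sketch of the learning-rate shape implied by InverseSquareRootAnnealing with linear warmup; NeMo's exact implementation may differ in constants.

```python
def inverse_sqrt_lr(step, base_lr=1e-4, warmup_steps=2500, min_lr=1e-6):
    # Linear warmup to base_lr, then decay proportional to 1/sqrt(step),
    # floored at min_lr.
    if step < warmup_steps:
        lr = base_lr * step / max(1, warmup_steps)
    else:
        lr = base_lr * (warmup_steps / step) ** 0.5
    return max(lr, min_lr)

# e.g. inverse_sqrt_lr(2500) == 1e-4; inverse_sqrt_lr(10000) == 5e-5
```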
trainer:
  devices: 1 # number of gpus (devices)
  accelerator: gpu
  max_epochs: 800
  max_steps: -1 # computed at runtime if not set
  num_nodes: 1
  strategy: ddp_find_unused_parameters_true # Could be "ddp"
  accumulate_grad_batches: 1
  deterministic: True
  enable_checkpointing: False
  logger: False
  log_every_n_steps: 1 # Interval of logging.
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for the number of iterations
exp_manager:
  use_datetime_version: False
  exp_dir: null
  name: ${name}
  resume_if_exists: True
  resume_from_checkpoint: null # The path to a checkpoint file to continue training; restores the whole state including the epoch, step, LR schedulers, apex, etc.
  resume_ignore_no_checkpoint: True
  create_tensorboard_logger: True
  create_checkpoint_callback: True
  create_wandb_logger: False
  checkpoint_callback_params:
    monitor: "val_f1_acc"
    mode: "max"
    save_top_k: 9
    every_n_epochs: 1
  wandb_logger_kwargs:
    resume: True
    name: null
    project: null
@@ -0,0 +1,17 @@
# Postprocessing parameters for timestamp outputs from speaker diarization models.
# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
# These parameters were optimized with the hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
# These parameters were optimized on the development split of the DIHARD3 dataset. See https://arxiv.org/pdf/2012.01477.
# Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055.
parameters:
  window_length_in_sec: 0.0 # Not used
  shift_length_in_sec: 0.0 # Not used
  smoothing: False # Not used
  overlap: 0.5 # Not used
  onset: 0.53 # Onset threshold for detecting the beginning of a speech segment
  offset: 0.49 # Offset threshold for detecting the end of a speech segment
  pad_onset: 0.23 # Duration (seconds) added before each speech segment
  pad_offset: 0.01 # Duration (seconds) added after each speech segment
  min_duration_on: 0.42 # Threshold for small non-speech deletion
  min_duration_off: 0.34 # Threshold for short speech segment deletion
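A hedged sketch of how thresholds like these are commonly applied in hysteresis-style binarization (pyannote-style; not the NeMo implementation, and the exact mapping of min_duration_on/min_duration_off to the gap-filling and short-segment steps is assumed here): a segment opens when the probability crosses `onset` and closes when it falls below `offset`, segments are padded by `pad_onset`/`pad_offset`, short gaps are filled, and short segments are dropped. The 80 ms frame duration is taken from the main config above.

```python
def binarize(probs, frame_dur=0.08, onset=0.53, offset=0.49,
             pad_onset=0.23, pad_offset=0.01,
             min_duration_on=0.42, min_duration_off=0.34):
    """Turn one speaker's per-frame speech probabilities into [start, end]
    segments in seconds. Hypothetical helper; parameter semantics assumed."""
    segments, start, active = [], 0.0, False
    for i, p in enumerate(probs):
        t = i * frame_dur
        if not active and p >= onset:          # rising edge: open a segment
            start, active = t, True
        elif active and p < offset:            # falling edge: close the segment
            segments.append([max(0.0, start - pad_onset), t + pad_offset])
            active = False
    if active:                                 # close a segment left open at the end
        segments.append([max(0.0, start - pad_onset), len(probs) * frame_dur + pad_offset])
    merged = []                                # fill gaps shorter than min_duration_off
    for seg in segments:
        if merged and seg[0] - merged[-1][1] < min_duration_off:
            merged[-1][1] = seg[1]
        else:
            merged.append(seg)
    return [s for s in merged if s[1] - s[0] >= min_duration_on]  # drop short segments

print(binarize([0.1, 0.2, 0.9, 0.95, 0.8, 0.3, 0.2, 0.9, 0.9, 0.1]))
```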
@@ -0,0 +1,17 @@
# Postprocessing parameters for timestamp outputs from speaker diarization models.
# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper:
# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020).
# These parameters were optimized with the hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656.
# These parameters were optimized on the CallHome dataset from the NIST SRE 2000 Disc8, specifically the split2 defined in: Kaldi, "Kaldi x-vector recipe v2," https://github.com/kaldi-asr/kaldi/tree/master/egs/callhome_diarization/v2.
# Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649.
parameters:
  window_length_in_sec: 0.0 # Not used
  shift_length_in_sec: 0.0 # Not used
  smoothing: False # Not used
  overlap: 0.5 # Not used
  onset: 0.64 # Onset threshold for detecting the beginning of a speech segment
  offset: 0.74 # Offset threshold for detecting the end of a speech segment
  pad_onset: 0.06 # Duration (seconds) added before each speech segment
  pad_offset: 0.0 # Duration (seconds) added after each speech segment
  min_duration_on: 0.1 # Threshold for small non-speech deletion
  min_duration_off: 0.15 # Threshold for short speech segment deletion
Review thread: "Maybe add some short explanation for `soft_label_thres` and `soft_targets`?" — "Added some comments."