From 730daf5285260eda65c0f37c5269496b0d6ed5cc Mon Sep 17 00:00:00 2001 From: Alexey Shmelev Date: Mon, 4 Nov 2024 13:42:17 -0500 Subject: [PATCH] minor bugfixes and code cleanup - removed optimizer steps left from an attempt to fix grading buckets - corrected the error message for incompatible model - simplified the train unpacking - fixed an incorrect move of cached tensors to a wrong device - grouped gradients under the same section for tensorboard output --- rvc/train/train.py | 170 +++++++++++++-------------------------------- 1 file changed, 48 insertions(+), 122 deletions(-) diff --git a/rvc/train/train.py b/rvc/train/train.py index cad1f9cb..cf1c3485 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -37,12 +37,6 @@ load_wav_to_torch, ) -from data_utils import ( - DistributedBucketSampler, - TextAudioCollateMultiNSFsid, - TextAudioLoaderMultiNSFsid, -) - from losses import ( discriminator_loss, feature_loss, @@ -54,9 +48,6 @@ from rvc.train.process.extract_model import extract_model from rvc.lib.algorithm import commons -from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator -from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminatorV2 -from rvc.lib.algorithm.synthesizers import Synthesizer # Parse command line arguments model_name = sys.argv[1] @@ -134,7 +125,7 @@ def verify_checkpoint_shapes(checkpoint_path, model): else: model_state_dict = model.load_state_dict(checkpoint_state_dict) except RuntimeError: - print("The sample rate of the pretrain doesn't match the selected one") + print("The parameters of the pretrain model such as the sample rate or architecture do not match the selected model.") sys.exit(1) else: del checkpoint @@ -313,6 +304,8 @@ def run( if rank == 0: writer = SummaryWriter(log_dir=experiment_dir) writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval")) + else: + writer, writer_eval = None, None dist.init_process_group( backend="gloo", @@ -327,6 +320,12 @@ def run( torch.cuda.set_device(rank) # Create datasets and dataloaders + from data_utils import ( + DistributedBucketSampler, + TextAudioCollateMultiNSFsid, + TextAudioLoaderMultiNSFsid, + ) + train_dataset = TextAudioLoaderMultiNSFsid(config.data) collate_fn = TextAudioCollateMultiNSFsid() train_sampler = DistributedBucketSampler( @@ -350,6 +349,10 @@ def run( ) # Initialize models and optimizers + from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminator + from rvc.lib.algorithm.discriminators import MultiPeriodDiscriminatorV2 + from rvc.lib.algorithm.synthesizers import Synthesizer + net_g = Synthesizer( config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, @@ -430,9 +433,6 @@ def run( optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2 ) - optim_g.step() - optim_d.step() - scaler = GradScaler(enabled=config.train.fp16_run and device.type == "cuda") cache = [] @@ -449,42 +449,25 @@ def run( break for epoch in range(epoch_str, total_epoch + 1): - if rank == 0: - train_and_evaluate( - rank, - epoch, - config, - [net_g, net_d], - [optim_g, optim_d], - scaler, - [train_loader, None], - [writer, writer_eval], - cache, - custom_save_every_weights, - custom_total_epoch, - device, - reference, - ) - else: - train_and_evaluate( - rank, - epoch, - config, - [net_g, net_d], - [optim_g, optim_d], - scaler, - [train_loader, None], - None, - cache, - custom_save_every_weights, - custom_total_epoch, - device, - reference, - ) + train_and_evaluate( + rank, + epoch, + config, + [net_g, net_d], + [optim_g, optim_d], + scaler, + [train_loader, None], + [writer, writer_eval], + cache, + custom_save_every_weights, + custom_total_epoch, + device, + reference, + ) + scheduler_g.step() scheduler_d.step() - def train_and_evaluate( rank, epoch, @@ -539,41 +522,9 @@ def train_and_evaluate( data_iterator = cache if cache == []: for batch_idx, info in enumerate(train_loader): - ( - phone, - phone_lengths, - pitch, - pitchf, - spec, - spec_lengths, - wave, - wave_lengths, - sid, - ) = info - cache.append( - ( - batch_idx, - ( - phone.cuda(rank, non_blocking=True), - phone_lengths.cuda(rank, non_blocking=True), - ( - pitch.cuda(rank, non_blocking=True) - if pitch_guidance - else None - ), - ( - pitchf.cuda(rank, non_blocking=True) - if pitch_guidance - else None - ), - spec.cuda(rank, non_blocking=True), - spec_lengths.cuda(rank, non_blocking=True), - wave.cuda(rank, non_blocking=True), - wave_lengths.cuda(rank, non_blocking=True), - sid.cuda(rank, non_blocking=True), - ), - ) - ) + # phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid + info = [tensor.cuda(rank, non_blocking=True) for tensor in info] + cache.append((batch_idx, info)) else: shuffle(cache) else: @@ -582,50 +533,22 @@ def train_and_evaluate( epoch_recorder = EpochRecorder() with tqdm(total=len(train_loader), leave=False) as pbar: for batch_idx, info in data_iterator: - ( - phone, - phone_lengths, - pitch, - pitchf, - spec, - spec_lengths, - wave, - wave_lengths, - sid, - ) = info if device.type == "cuda" and not cache_data_in_gpu: - phone = phone.cuda(rank, non_blocking=True) - phone_lengths = phone_lengths.cuda(rank, non_blocking=True) - pitch = pitch.cuda(rank, non_blocking=True) if pitch_guidance else None - pitchf = ( - pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None - ) - sid = sid.cuda(rank, non_blocking=True) - spec = spec.cuda(rank, non_blocking=True) - spec_lengths = spec_lengths.cuda(rank, non_blocking=True) - wave = wave.cuda(rank, non_blocking=True) - wave_lengths = wave_lengths.cuda(rank, non_blocking=True) - else: - phone = phone.to(device) - phone_lengths = phone_lengths.to(device) - pitch = pitch.to(device) if pitch_guidance else None - pitchf = pitchf.to(device) if pitch_guidance else None - sid = sid.to(device) - spec = spec.to(device) - spec_lengths = spec_lengths.to(device) - wave = wave.to(device) - wave_lengths = wave_lengths.to(device) + info = [tensor.cuda(rank, non_blocking=True) for tensor in info] + elif device.type != "cuda": + info = [tensor.to(device) for tensor in info] + # else iterator is going thru a cached list with a device already assigned + + phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid = info + pitch = pitch if pitch_guidance else None + pitchf = pitchf if pitch_guidance else None # Forward pass use_amp = config.train.fp16_run and device.type == "cuda" with autocast(enabled=use_amp): - ( - y_hat, - ids_slice, - x_mask, - z_mask, - (z, z_p, m_p, logs_p, m_q, logs_q), - ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid) + model_output = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid) + y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = model_output + # used for tensorboard chart - all/mel mel = spec_to_mel_torch( spec, config.data.filter_length, @@ -634,12 +557,14 @@ def train_and_evaluate( config.data.mel_fmin, config.data.mel_fmax, ) + # used for tensorboard chart - slice/mel_org y_mel = commons.slice_segments( mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3, ) + # used for tensorboard chart - slice/mel_gen with autocast(enabled=False): y_hat_mel = mel_spectrogram_torch( y_hat.float().squeeze(1), @@ -653,6 +578,7 @@ def train_and_evaluate( ) if use_amp: y_hat_mel = y_hat_mel.half() + # slice of the original waveform to match a generate slice wave = commons.slice_segments( wave, ids_slice * config.data.hop_length, @@ -714,8 +640,8 @@ def train_and_evaluate( "loss/g/total": loss_gen_all, "loss/d/total": loss_disc, "learning_rate": lr, - "grad_norm_d": grad_norm_d, - "grad_norm_g": grad_norm_g, + "grad/norm_d": grad_norm_d, + "grad/norm_g": grad_norm_g, "loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl,