diff --git a/frechet_audio_distance/fad.py b/frechet_audio_distance/fad.py
index 3095b67..1a38f02 100644
--- a/frechet_audio_distance/fad.py
+++ b/frechet_audio_distance/fad.py
@@ -84,8 +84,17 @@ def __init__(
         self.submodel_name = submodel_name
         self.sample_rate = sample_rate
         self.channels = channels
-        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
         self.verbose = verbose
+        self.device = torch.device(
+            'cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
+        if self.device == torch.device('mps') and self.model_name == "clap":
+            if self.verbose:
+                print("[Frechet Audio Distance] CLAP does not support MPS device yet, because:")
+                print("[Frechet Audio Distance] The operator 'aten::upsample_bicubic2d.out' is not currently implemented for the MPS device.")
+                print("[Frechet Audio Distance] Using CPU device instead.")
+            self.device = torch.device('cpu')
+        if self.verbose:
+            print("[Frechet Audio Distance] Using device: {}".format(self.device))
         self.audio_load_worker = audio_load_worker
         self.enable_fusion = enable_fusion
         if ckpt_dir is not None:
@@ -112,6 +121,7 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
                 self.model.postprocess = False
             if not use_activation:
                 self.model.embeddings = nn.Sequential(*list(self.model.embeddings.children())[:-1])
+            self.model.device = self.device
         # pann
         elif model_name == "pann":
             # Kong et al., "PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition", IEEE/ACM Transactions on Audio, Speech, and Language Processing 28 (2020)
@@ -218,6 +228,7 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
             # these models use 32 residual quantizers
             self.model.set_target_bandwidth(24.0)

+        self.model.to(self.device)
         self.model.eval()

     def get_embeddings(self, x, sr):
@@ -234,7 +245,7 @@ def get_embeddings(self, x, sr):
                     embd = self.model.forward(audio, sr)
                 elif self.model_name == "pann":
                     with torch.no_grad():
-                        audio = torch.tensor(audio).float().unsqueeze(0)
+                        audio = torch.tensor(audio).float().unsqueeze(0).to(self.device)
                         out = self.model(audio, None)
                         embd = out['embedding'].data[0]
                 elif self.model_name == "clap":
@@ -242,13 +253,15 @@ def get_embeddings(self, x, sr):
                     embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
                 elif self.model_name == "encodec":
                     # add two dimensions
-                    audio = torch.tensor(audio).float().unsqueeze(0).unsqueeze(0)
+                    audio = torch.tensor(
+                        audio).float().unsqueeze(0).unsqueeze(0).to(self.device)
                     # if SAMPLE_RATE is 48000, we need to make audio stereo
                     if self.model.sample_rate == 48000:
                         if audio.shape[-1] != 2:
-                            print(
-                                "[Frechet Audio Distance] Audio is mono, converting to stereo for 48khz model..."
-                            )
+                            if self.verbose:
+                                print(
+                                    "[Frechet Audio Distance] Audio is mono, converting to stereo for 48khz model..."
+                                )
                             audio = torch.cat((audio, audio), dim=1)
                         else:
                             # transpose to (batch, channels, samples)
@@ -272,7 +285,7 @@ def get_embeddings(self, x, sr):
                             embd.shape
                         )
                     )
-                if self.device == torch.device("cuda"):
+                if embd.device != torch.device("cpu"):
                     embd = embd.cpu()

                 if torch.is_tensor(embd):
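For context, the core change is the device-selection fallback in `__init__`: prefer CUDA, then Apple's MPS backend, then CPU, with a forced CPU fallback for CLAP because `aten::upsample_bicubic2d.out` is not implemented on MPS; the selected device is then propagated to the model and to the input tensors in `get_embeddings`. The sketch below restates that selection logic as a standalone snippet for reference; the `select_device` helper name is illustrative only and is not part of the patch.

```python
# Standalone sketch of the device-selection fallback introduced by this patch.
# `select_device` is an illustrative helper, not a function added to fad.py.
import torch

def select_device(model_name: str, verbose: bool = False) -> torch.device:
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # CLAP cannot run on MPS yet: aten::upsample_bicubic2d.out is missing there.
    if device.type == "mps" and model_name == "clap":
        if verbose:
            print("[Frechet Audio Distance] CLAP does not support MPS yet, falling back to CPU.")
        device = torch.device("cpu")

    if verbose:
        print(f"[Frechet Audio Distance] Using device: {device}")
    return device
```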