Skip to content

Commit

Permalink
add automatic MPS device support (#27)
Browse files (browse the repository at this point in the history)
* add MPS device support

* remove duplicate
  • Loading branch information
balintlaczko authored Mar 18, 2024
1 parent 99ab594 commit 1fd603a
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions frechet_audio_distance/fad.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -84,8 +84,17 @@ def __init__(
self.submodel_name = submodel_name
self.sample_rate = sample_rate
self.channels = channels
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.verbose = verbose
self.device = torch.device(
'cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
if self.device == torch.device('mps') and self.model_name == "clap":
if self.verbose:
print("[Frechet Audio Distance] CLAP does not support MPS device yet, because:")
print("[Frechet Audio Distance] The operator 'aten::upsample_bicubic2d.out' is not currently implemented for the MPS device.")
print("[Frechet Audio Distance] Using CPU device instead.")
self.device = torch.device('cpu')
if self.verbose:
print("[Frechet Audio Distance] Using device: {}".format(self.device))
self.audio_load_worker = audio_load_worker
self.enable_fusion = enable_fusion
if ckpt_dir is not None:
Expand All @@ -112,6 +121,7 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
self.model.postprocess = False
if not use_activation:
self.model.embeddings = nn.Sequential(*list(self.model.embeddings.children())[:-1])
self.model.device = self.device
# pann
elif model_name == "pann":
# Kong et al., "PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition", IEEE/ACM Transactions on Audio, Speech, and Language Processing 28 (2020)
Expand Down Expand Up @@ -218,6 +228,7 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
# these models use 32 residual quantizers
self.model.set_target_bandwidth(24.0)

self.model.to(self.device)
self.model.eval()

def get_embeddings(self, x, sr):
Expand All @@ -234,21 +245,23 @@ def get_embeddings(self, x, sr):
embd = self.model.forward(audio, sr)
elif self.model_name == "pann":
with torch.no_grad():
audio = torch.tensor(audio).float().unsqueeze(0)
audio = torch.tensor(audio).float().unsqueeze(0).to(self.device)
out = self.model(audio, None)
embd = out['embedding'].data[0]
elif self.model_name == "clap":
audio = torch.tensor(audio).float().unsqueeze(0)
embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
elif self.model_name == "encodec":
# add two dimensions
audio = torch.tensor(audio).float().unsqueeze(0).unsqueeze(0)
audio = torch.tensor(
audio).float().unsqueeze(0).unsqueeze(0).to(self.device)
# if SAMPLE_RATE is 48000, we need to make audio stereo
if self.model.sample_rate == 48000:
if audio.shape[-1] != 2:
print(
"[Frechet Audio Distance] Audio is mono, converting to stereo for 48khz model..."
)
if self.verbose:
print(
"[Frechet Audio Distance] Audio is mono, converting to stereo for 48khz model..."
)
audio = torch.cat((audio, audio), dim=1)
else:
# transpose to (batch, channels, samples)
Expand All @@ -272,7 +285,7 @@ def get_embeddings(self, x, sr):
embd.shape
)
)
if self.device == torch.device("cuda"):
if embd.device != torch.device("cpu"):
embd = embd.cpu()

if torch.is_tensor(embd):
Expand Down

0 comments on commit 1fd603a

Please sign in to comment.