Commit 6cfc090

Merge pull request #130 from spotify/bgenchel/add-maestro

Add Maestro

drubinstein authored Jul 31, 2024
2 parents a3fcd5f + c99353d

Showing 9 changed files with 436 additions and 11 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -1,6 +1,6 @@
 include *.txt tox.ini *.rst *.md LICENSE
 include catalog-info.yaml
 include Dockerfile .dockerignore
-recursive-include tests *.py *.wav *.npz *.jams *.zip
+recursive-include tests *.py *.wav *.npz *.jams *.zip *.midi *.csv *.json
 recursive-include basic_pitch *.py *.md
 recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin
223 changes: 223 additions & 0 deletions basic_pitch/data/datasets/maestro.py
@@ -0,0 +1,223 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import sys
import tempfile
import time
from typing import Any, Dict, List, TextIO, Tuple

import apache_beam as beam
import mirdata

from basic_pitch.data import commandline, pipeline


def read_in_chunks(file_object: TextIO, chunk_size: int = 1024) -> Any:
    """Lazy function (generator) to read a file piece by piece.
    Default chunk size: 1k."""
    while True:
        data = file_object.read(chunk_size)
        if not data:
            break
        yield data
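
# Usage sketch (illustrative; "filesystem" and both paths are stand-ins):
# stream a remote object to a local file without holding it all in memory,
# which is exactly how the DoFns below copy audio and MIDI out of the bucket:
#
#     with filesystem.open(remote_path) as s, open(local_path, "wb") as d:
#         for piece in read_in_chunks(s, chunk_size=4096):
#             d.write(piece)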


class MaestroInvalidTracks(beam.DoFn):
    DOWNLOAD_ATTRIBUTES = ["audio_path"]

    def __init__(self, source: str) -> None:
        self.source = source

    def setup(self) -> None:
        # Oddly enough we don't want to include the GCS bucket URI,
        # just the path within the bucket.
        self.maestro_remote = mirdata.initialize("maestro", data_home=self.source)
        self.filesystem = beam.io.filesystems.FileSystems()

    def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
        import tempfile
        import sox

        track_id, split = element
        logging.info(f"Processing (track_id, split): ({track_id}, {split})")

        track_remote = self.maestro_remote.track(track_id)
        with tempfile.TemporaryDirectory() as local_tmp_dir:
            maestro_local = mirdata.initialize("maestro", local_tmp_dir)
            track_local = maestro_local.track(track_id)

            for attribute in self.DOWNLOAD_ATTRIBUTES:
                source = getattr(track_remote, attribute)
                destination = getattr(track_local, attribute)
                os.makedirs(os.path.dirname(destination), exist_ok=True)
                with self.filesystem.open(source) as s, open(destination, "wb") as d:
                    for piece in read_in_chunks(s):
                        d.write(piece)

            # Skip overlong tracks: 15 minutes * 60 seconds/minute
            if sox.file_info.duration(track_local.audio_path) >= 15 * 60:
                return None

        yield beam.pvalue.TaggedOutput(split, track_id)
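
# Note: each surviving track id is emitted on a TaggedOutput named after its
# split, letting the pipeline fan out per split. A hypothetical wiring:
#
#     results = ids | beam.ParDo(MaestroInvalidTracks(source)).with_outputs(
#         "train", "validation", "test"
#     )
#     train_ids = results.train  # PCollection of valid train-split track ids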


class MaestroToTfExample(beam.DoFn):
    DOWNLOAD_ATTRIBUTES = ["audio_path", "midi_path"]

    def __init__(self, source: str, download: bool):
        self.source = source
        self.download = download

    def setup(self) -> None:
        import apache_beam as beam
        import mirdata

        # Oddly enough we don't want to include the GCS bucket URI,
        # just the path within the bucket.
        self.maestro_remote = mirdata.initialize("maestro", data_home=self.source)
        self.filesystem = beam.io.filesystems.FileSystems()
        if self.download:
            self.maestro_remote.download()

    def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> List[Any]:
        import tempfile

        import numpy as np
        import sox

        from basic_pitch.constants import (
            AUDIO_N_CHANNELS,
            AUDIO_SAMPLE_RATE,
            FREQ_BINS_CONTOURS,
            FREQ_BINS_NOTES,
            ANNOTATION_HOP,
            N_FREQ_BINS_NOTES,
            N_FREQ_BINS_CONTOURS,
        )
        from basic_pitch.data import tf_example_serialization

        logging.info(f"Processing {element}")
        batch = []

        for track_id in element:
            track_remote = self.maestro_remote.track(track_id)
            with tempfile.TemporaryDirectory() as local_tmp_dir:
                maestro_local = mirdata.initialize("maestro", local_tmp_dir)
                track_local = maestro_local.track(track_id)

                for attribute in self.DOWNLOAD_ATTRIBUTES:
                    source = getattr(track_remote, attribute)
                    destination = getattr(track_local, attribute)
                    os.makedirs(os.path.dirname(destination), exist_ok=True)
                    with self.filesystem.open(source) as s, open(destination, "wb") as d:
                        for piece in read_in_chunks(s):
                            d.write(piece)

                # Resample to the training sample rate and channel count.
                local_wav_path = f"{track_local.audio_path}_tmp.wav"

                tfm = sox.Transformer()
                tfm.rate(AUDIO_SAMPLE_RATE)
                tfm.channels(AUDIO_N_CHANNELS)
                tfm.build(track_local.audio_path, local_wav_path)

                duration = sox.file_info.duration(local_wav_path)
                time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
                n_time_frames = len(time_scale)

                note_indices, note_values = track_local.notes.to_sparse_index(time_scale, "s", FREQ_BINS_NOTES, "hz")
                onset_indices, onset_values = track_local.notes.to_sparse_index(
                    time_scale, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
                )
                contour_indices, contour_values = track_local.notes.to_sparse_index(
                    time_scale, "s", FREQ_BINS_CONTOURS, "hz"
                )

                batch.append(
                    tf_example_serialization.to_transcription_tfexample(
                        track_local.track_id,
                        "maestro",
                        local_wav_path,
                        note_indices,
                        note_values,
                        onset_indices,
                        onset_values,
                        contour_indices,
                        contour_values,
                        (n_time_frames, N_FREQ_BINS_NOTES),
                        (n_time_frames, N_FREQ_BINS_CONTOURS),
                    )
                )
        return [batch]
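
# Note: returning [batch] (a one-element list) yields a single downstream
# element per incoming batch of track ids, containing all of that batch's
# serialized tf.train.Examples, rather than one element per track.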


def create_input_data(source: str) -> List[Tuple[str, str]]:
    import apache_beam as beam

    filesystem = beam.io.filesystems.FileSystems()

    with tempfile.TemporaryDirectory() as tmpdir:
        maestro = mirdata.initialize("maestro", data_home=tmpdir)
        metadata_path = maestro._index["metadata"]["maestro-v2.0.0"][0]
        with filesystem.open(
            os.path.join(source, metadata_path),
        ) as s, open(os.path.join(tmpdir, metadata_path), "wb") as d:
            d.write(s.read())

        return [(track_id, track.split) for track_id, track in maestro.load_tracks().items()]
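
# Shape of the returned value (track ids elided here; splits come from the
# MAESTRO metadata):
#     [("<track_id>", "train"), ("<track_id>", "validation"), ...]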


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
    time_created = int(time.time())
    destination = commandline.resolve_destination(known_args, time_created)

    # TODO: Remove or abstract for FOSS
    pipeline_options = {
        "runner": known_args.runner,
        "job_name": f"maestro-tfrecords-{time_created}",
        "machine_type": "e2-highmem-4",
        "num_workers": 25,
        "disk_size_gb": 128,
        "experiments": ["use_runner_v2", "no_use_multiple_sdk_containers"],
        "save_main_session": True,
        "sdk_container_image": known_args.sdk_container_image,
        "job_endpoint": known_args.job_endpoint,
        "environment_type": "DOCKER",
        "environment_config": known_args.sdk_container_image,
    }
    input_data = create_input_data(known_args.source)
    pipeline.run(
        pipeline_options,
        pipeline_args,
        input_data,
        MaestroToTfExample(known_args.source, download=True),
        MaestroInvalidTracks(known_args.source),
        destination,
        known_args.batch_size,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
    commandline.add_split(parser)
    known_args, pipeline_args = parser.parse_known_args(sys.argv)

    main(known_args, pipeline_args)
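
For reference, a hypothetical invocation of this script (flag spellings are
assumptions inferred from the known_args attributes used above; the actual
flags come from commandline.add_default/add_split, which sit outside this
diff, and the bucket path is a stand-in):

    python -m basic_pitch.data.datasets.maestro \
        --source gs://<bucket>/maestro \
        --runner DirectRunner \
        --batch-size 8

create_input_data lists (track_id, split) pairs from the remote metadata,
MaestroInvalidTracks drops tracks of 15 minutes or longer while routing ids
to per-split tagged outputs, and MaestroToTfExample downloads each surviving
batch and serializes it to TFRecords at the resolved destination.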
9 changes: 8 additions & 1 deletion basic_pitch/data/download.py
@@ -20,12 +20,19 @@
 from basic_pitch.data import commandline
 from basic_pitch.data.datasets.guitarset import main as guitarset_main
 from basic_pitch.data.datasets.ikala import main as ikala_main
+from basic_pitch.data.datasets.maestro import main as maestro_main
 from basic_pitch.data.datasets.medleydb_pitch import main as medleydb_pitch_main

 logger = logging.getLogger()
 logger.setLevel(logging.INFO)

-DATASET_DICT = {"guitarset": guitarset_main, "ikala": ikala_main, "medleydb_pitch": medleydb_pitch_main}
+DATASET_DICT = {
+    "guitarset": guitarset_main,
+    "ikala": ikala_main,
+    "maestro": maestro_main,
+    "medleydb_pitch": medleydb_pitch_main,
+}


 def main() -> None:
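
With this entry in place, the MAESTRO pipeline is selectable by name from the
shared download entry point; a sketch of the dispatch (argument plumbing is
assumed, as main() is outside this hunk):

    dataset_main = DATASET_DICT["maestro"]
    dataset_main(known_args, pipeline_args)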
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -64,6 +64,8 @@ test = [
"coverage>=5.0.2",
"pytest>=6.1.1",
"pytest-mock",
"wave",
"mido"
]
tf = [
"tensorflow>=2.4.1,<2.15.1; platform_system != 'Darwin'",
Expand Down
8 changes: 4 additions & 4 deletions tests/data/test_guitarset.py
@@ -33,7 +33,7 @@
 TRACK_ID = "00_BN1-129-Eb_comp"


-def test_guitar_set_to_tf_example(tmpdir: str) -> None:
+def test_guitarset_to_tf_example(tmpdir: str) -> None:
     input_data: List[str] = [TRACK_ID]
     with TestPipeline() as p:
         (
@@ -51,7 +51,7 @@ def test_guitar_set_to_tf_example(tmpdir: str) -> None:
     assert len(data) != 0


-def test_guitar_set_invalid_tracks(tmpdir: str) -> None:
+def test_guitarset_invalid_tracks(tmpdir: str) -> None:
     split_labels = ["train", "test", "validation"]
     input_data = [(str(i), split) for i, split in enumerate(split_labels)]
     with TestPipeline() as p:
@@ -73,15 +73,15 @@ def test_guitar_set_invalid_tracks(tmpdir: str) -> None:
             assert fp.read().strip() == str(i)


-def test_create_input_data() -> None:
+def test_guitarset_create_input_data() -> None:
     data = create_input_data(train_percent=0.33, validation_percent=0.33)
     data.sort(key=lambda el: el[1])  # sort by split
     tolerance = 0.1
     for key, group in itertools.groupby(data, lambda el: el[1]):
         assert (0.33 - tolerance) * len(data) <= len(list(group)) <= (0.33 + tolerance) * len(data)


-def test_create_input_data_overallocate() -> None:
+def test_guitarset_create_input_data_overallocate() -> None:
     try:
         create_input_data(train_percent=0.6, validation_percent=0.6)
     except AssertionError:
8 changes: 4 additions & 4 deletions tests/data/test_ikala.py
@@ -51,15 +51,15 @@ def test_ikala_invalid_tracks(tmpdir: str) -> None:
             assert fp.read().strip() == str(i)


-def test_create_input_data() -> None:
+def test_ikala_create_input_data() -> None:
     data = create_input_data(train_percent=0.5)
     data.sort(key=lambda el: el[1])  # sort by split
-    tolerance = 0.05
-    for key, group in itertools.groupby(data, lambda el: el[1]):
+    tolerance = 0.1
+    for _, group in itertools.groupby(data, lambda el: el[1]):
         assert (0.5 - tolerance) * len(data) <= len(list(group)) <= (0.5 + tolerance) * len(data)


-def test_create_input_data_overallocate() -> None:
+def test_ikala_create_input_data_overallocate() -> None:
     try:
         create_input_data(train_percent=1.1)
     except AssertionError:
(The diff for the remaining 3 of the 9 changed files is not shown.)