Skip to content

Commit

Permalink
for slakh, mock all data except metadata.yaml files, remove all store…
Browse files Browse the repository at this point in the history
…d test data. Create utils file in tests/data for create wav, midi, and flac mock files for testing, clean up maestro tests in accordance.
  • Loading branch information
bgenchel committed Aug 1, 2024
1 parent 9914d6a commit 5f60826
Show file tree
Hide file tree
Showing 92 changed files with 159 additions and 81 deletions.
59 changes: 6 additions & 53 deletions tests/data/test_maestro.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import os
import pathlib
import wave

from mido import MidiFile, MidiTrack, Message
from typing import List

import apache_beam as beam
Expand All @@ -33,6 +29,8 @@
)
from basic_pitch.data.pipeline import WriteBatchToTfRecord

from utils import create_mock_wav, create_mock_midi

RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
MAESTRO_TEST_DATA_PATH = RESOURCES_PATH / "data" / "maestro"

Expand All @@ -42,57 +40,13 @@
GT_15M_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav"


def create_mock_wav(output_fpath: str, duration_min: int) -> None:
duration_seconds = duration_min * 60
sample_rate = 44100
n_channels = 2 # Stereo
sampwidth = 2 # 2 bytes per sample (16-bit audio)

# Generate a silent audio data array
num_samples = duration_seconds * sample_rate
audio_data = np.zeros((num_samples, n_channels), dtype=np.int16)

# Create the WAV file
with wave.open(str(output_fpath), "w") as wav_file:
wav_file.setnchannels(n_channels)
wav_file.setsampwidth(sampwidth)
wav_file.setframerate(sample_rate)
wav_file.writeframes(audio_data.tobytes())

logging.info(f"Mock {duration_min}-minute WAV file '{output_fpath}' created successfully.")


def create_mock_midi(output_fpath: str) -> None:
# Create a new MIDI file with one track
mid = MidiFile()
track = MidiTrack()
mid.tracks.append(track)

# Define a sequence of notes (time, type, note, velocity)
notes = [
(0, "note_on", 60, 64), # C4
(500, "note_off", 60, 64),
(0, "note_on", 62, 64), # D4
(500, "note_off", 62, 64),
]

# Add the notes to the track
for time, type, note, velocity in notes:
track.append(Message(type, note=note, velocity=velocity, time=time))

# Save the MIDI file
mid.save(output_fpath)

logging.info(f"Mock MIDI file '{output_fpath}' created successfully.")


def test_maestro_to_tf_example(tmp_path: pathlib.Path) -> None:
mock_maestro_home = tmp_path / "maestro"
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)

create_mock_wav(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".wav")), 3)
create_mock_midi(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".midi")))
create_mock_wav(mock_maestro_ext / f"{TRAIN_TRACK_ID.split('/')[1]}.wav", 3)
create_mock_midi(mock_maestro_ext / f"{TRAIN_TRACK_ID.split('/')[1]}.midi")

output_dir = tmp_path / "outputs"
output_dir.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -122,7 +76,7 @@ def test_maestro_invalid_tracks(tmp_path: pathlib.Path) -> None:
input_data = [(TRAIN_TRACK_ID, "train"), (VALID_TRACK_ID, "validation"), (TEST_TRACK_ID, "test")]

for track_id, _ in input_data:
create_mock_wav(str(mock_maestro_ext / (track_id.split("/")[1] + ".wav")), 3)
create_mock_wav(mock_maestro_ext / f"{track_id.split('/')[1]}.wav", 3)

split_labels = set([e[1] for e in input_data])
with TestPipeline() as p:
Expand Down Expand Up @@ -154,8 +108,7 @@ def test_maestro_invalid_tracks_over_15_min(tmp_path: pathlib.Path) -> None:
mock_maestro_ext = mock_maestro_home / "2004"
mock_maestro_ext.mkdir(parents=True, exist_ok=True)

mock_fpath = mock_maestro_ext / (GT_15M_TRACK_ID.split("/")[1] + ".wav")
create_mock_wav(str(mock_fpath), 16)
create_mock_wav(mock_maestro_ext / f"{GT_15M_TRACK_ID.split('/')[1]}.wav", 16)

input_data = [(GT_15M_TRACK_ID, "train")]
split_labels = set([e[1] for e in input_data])
Expand Down
98 changes: 70 additions & 28 deletions tests/data/test_slakh.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
import itertools
import os
import pathlib
import shutil

from typing import List
from typing import List, Tuple

from apache_beam.testing.test_pipeline import TestPipeline

Expand All @@ -30,110 +31,151 @@
)
from basic_pitch.data.pipeline import WriteBatchToTfRecord

from utils import create_mock_flac, create_mock_midi

RESOURCES_PATH = pathlib.Path(__file__).parent.parent / "resources"
SLAKH_PATH = RESOURCES_PATH / "data" / "slakh" / "slakh2100_flac_redux"

TRAIN_PIANO_TRACK_ID = "Track00001-S02"
TRAIN_DRUMS_TRACK_ID = "Track00001-S01"

VALID_PIANO_TRACK_ID = "Track01501-S06"
VALID_DRUMS_TRACK_ID = "Track01501-S03"

TEST_PIANO_TRACK_ID = "Track01876-S01"
TEST_DRUMS_TRACK_ID = "Track01876-S08"

OMITTED_PIANO_TRACK_ID = "Track00049-S05"
OMITTED_DRUMS_TRACK_ID = "Track00049-S06"


def test_slakh_to_tf_example(tmpdir: str) -> None:
input_data: List[str] = [TRAIN_PIANO_TRACK_ID]
# Build a mock Slakh directory tree (stems, MIDI, metadata.yaml) for the given tracks
def create_mock_input_data(data_home: pathlib.Path, input_data: List[Tuple[str, str]]) -> None:
    """Populate ``data_home`` with mock stem/MIDI files and metadata for each track.

    For every ``(track_id, split)`` pair, ``track_id`` is split on ``"-"`` into a
    track part and an instrument part. Under ``data_home / split / <track>`` this
    creates:

    * ``stems/<instrument>.flac`` via ``create_mock_flac``
    * ``MIDI/<instrument>.mid`` via ``create_mock_midi``
    * ``metadata.yaml`` copied from the matching location under ``SLAKH_PATH``

    Args:
        data_home: Root directory in which the mock dataset tree is built.
        input_data: ``(track_id, split)`` pairs describing the tracks to mock.
    """
    # (sub-directory, file suffix, mock-file factory); order matches the on-disk
    # creation order: audio stem first, then MIDI.
    mock_layout = (
        ("stems", ".flac", create_mock_flac),
        ("MIDI", ".mid", create_mock_midi),
    )

    for track_id, split in input_data:
        track_name, stem_name = track_id.split("-")
        mock_track_root = data_home / split / track_name

        for subdir_name, suffix, make_mock_file in mock_layout:
            target_dir = mock_track_root / subdir_name
            target_dir.mkdir(parents=True, exist_ok=True)
            make_mock_file(target_dir / (stem_name + suffix))

        # The metadata file is not mocked; it is copied from the stored resources.
        real_metadata = SLAKH_PATH / split / track_name / "metadata.yaml"
        shutil.copy(real_metadata, mock_track_root / "metadata.yaml")


def test_slakh_to_tf_example(tmp_path: pathlib.Path) -> None:
mock_slakh_home = tmp_path / "slakh"
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"

input_data: List[Tuple[str, str]] = [(TRAIN_PIANO_TRACK_ID, "train")]
create_mock_input_data(mock_slakh_ext, input_data)

output_dir = tmp_path / "outputs"
output_dir.mkdir(parents=True, exist_ok=True)

with TestPipeline() as p:
(
p
| "Create PCollection of track IDs" >> beam.Create([input_data])
| "Create tf.Example"
>> beam.ParDo(SlakhToTfExample(str(RESOURCES_PATH / "data" / "slakh"), download=False))
| "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(tmpdir))
| "Create PCollection of track IDs" >> beam.Create([[track_id for track_id, _ in input_data]])
| "Create tf.Example" >> beam.ParDo(SlakhToTfExample(str(mock_slakh_home), download=False))
| "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(str(output_dir)))
)

assert len(os.listdir(tmpdir)) == 1
assert os.path.splitext(os.listdir(tmpdir)[0])[-1] == ".tfrecord"
with open(os.path.join(tmpdir, os.listdir(tmpdir)[0]), "rb") as fp:
listdir = os.listdir(output_dir)
assert len(listdir) == 1
assert os.path.splitext(listdir[0])[-1] == ".tfrecord"
with open(output_dir / listdir[0], "rb") as fp:
data = fp.read()
assert len(data) != 0


def test_slakh_invalid_tracks(tmpdir: str) -> None:
def test_slakh_invalid_tracks(tmp_path: pathlib.Path) -> None:
mock_slakh_home = tmp_path / "slakh"
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"

split_labels = ["train", "validation", "test"]
input_data = [(TRAIN_PIANO_TRACK_ID, "train"), (VALID_PIANO_TRACK_ID, "validation"), (TEST_PIANO_TRACK_ID, "test")]
create_mock_input_data(mock_slakh_ext, input_data)

with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it"
>> beam.ParDo(SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
| "Tag it" >> beam.ParDo(SlakhFilterInvalidTracks(str(mock_slakh_home))).with_outputs(*split_labels)
)

for split in split_labels:
(
getattr(splits, split)
| f"Write {split} to text"
>> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="")
>> beam.io.WriteToText(str(tmp_path / f"output_{split}.txt"), shard_name_template="")
)

for track_id, split in input_data:
with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp:
with open(tmp_path / f"output_{split}.txt", "r") as fp:
assert fp.read().strip() == track_id


def test_slakh_invalid_tracks_omitted(tmpdir: str) -> None:
def test_slakh_invalid_tracks_omitted(tmp_path: pathlib.Path) -> None:
mock_slakh_home = tmp_path / "slakh"
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"

split_labels = ["train", "omitted"]
input_data = [(TRAIN_PIANO_TRACK_ID, "train"), (OMITTED_PIANO_TRACK_ID, "omitted")]
create_mock_input_data(mock_slakh_ext, input_data)

with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it"
>> beam.ParDo(SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
| "Tag it" >> beam.ParDo(SlakhFilterInvalidTracks(str(mock_slakh_home))).with_outputs(*split_labels)
)

for split in split_labels:
(
getattr(splits, split)
| f"Write {split} to text"
>> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="")
>> beam.io.WriteToText(str(tmp_path / f"output_{split}.txt"), shard_name_template="")
)

with open(os.path.join(tmpdir, "output_train.txt"), "r") as fp:
with open(tmp_path / "output_train.txt", "r") as fp:
assert fp.read().strip() == TRAIN_PIANO_TRACK_ID

with open(os.path.join(tmpdir, "output_omitted.txt"), "r") as fp:
with open(tmp_path / "output_omitted.txt", "r") as fp:
assert fp.read().strip() == ""


def test_slakh_invalid_tracks_drums(tmpdir: str) -> None:
def test_slakh_invalid_tracks_drums(tmp_path: pathlib.Path) -> None:
mock_slakh_home = tmp_path / "slakh"
mock_slakh_ext = mock_slakh_home / "slakh2100_flac_redux"

split_labels = ["train", "validation", "test"]
input_data = [(TRAIN_DRUMS_TRACK_ID, "train"), (VALID_DRUMS_TRACK_ID, "validation"), (TEST_DRUMS_TRACK_ID, "test")]
create_mock_input_data(mock_slakh_ext, input_data)

with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it"
>> beam.ParDo(SlakhFilterInvalidTracks(str(RESOURCES_PATH / "data" / "slakh"))).with_outputs(*split_labels)
| "Tag it" >> beam.ParDo(SlakhFilterInvalidTracks(str(mock_slakh_home))).with_outputs(*split_labels)
)

for split in split_labels:
(
getattr(splits, split)
| f"Write {split} to text"
>> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="")
>> beam.io.WriteToText(str(tmp_path / f"output_{split}.txt"), shard_name_template="")
)

for track_id, split in input_data:
with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp:
for _, split in input_data:
with open(tmp_path / f"output_{split}.txt", "r") as fp:
assert fp.read().strip() == ""


def test_create_input_data() -> None:
data = create_input_data()
for key, group in itertools.groupby(data, lambda el: el[1]):
for _, group in itertools.groupby(data, lambda el: el[1]):
assert len(list(group))
83 changes: 83 additions & 0 deletions tests/data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy as np
import pathlib
import soundfile as sf
import wave

from mido import MidiFile, MidiTrack, Message


def create_mock_wav(output_fpath: pathlib.Path, duration_min: int) -> None:
    """Write a silent stereo 16-bit WAV file of the requested length.

    Args:
        output_fpath: Destination path of the WAV file.
        duration_min: Length of the generated audio, in minutes.
    """
    sample_rate = 44100
    n_channels = 2  # Stereo
    sampwidth = 2  # 2 bytes per sample (16-bit audio)

    # All-zero int16 frames, i.e. digital silence.
    num_samples = duration_min * 60 * sample_rate
    silence = np.zeros((num_samples, n_channels), dtype=np.int16)

    with wave.open(str(output_fpath), "w") as wav_file:
        # nframes / comptype / compname are given the writer's defaults; the
        # frame count in the header is fixed up when the frames are written.
        wav_file.setparams((n_channels, sampwidth, sample_rate, 0, "NONE", "not compressed"))
        wav_file.writeframes(silence.tobytes())

    logging.info(f"Mock {duration_min}-minute WAV file '{output_fpath}' created successfully.")


def create_mock_flac(output_fpath: pathlib.Path) -> None:
    """Write a short mono sine-wave FLAC file for use as mock audio in tests.

    The signal is a 440 Hz (A4) tone, 2 seconds long, sampled at 44100 Hz with
    a peak amplitude of 0.5.

    Args:
        output_fpath: Destination path of the FLAC file.
    """
    frequency = 440  # A4
    duration = 2  # seconds
    sample_rate = 44100  # standard
    amplitude = 0.5

    t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
    # Phase of a sinusoid is 2*pi*f*t. The previous code multiplied by
    # `duration`, which only produced the right tone because duration == 2.
    sin_wave = amplitude * np.sin(2 * np.pi * frequency * t)

    # Save as a FLAC file. The third argument is the sample rate of the data;
    # the tone frequency was previously passed here by mistake, which stored the
    # samples at a 440 Hz sample rate.
    sf.write(str(output_fpath), sin_wave, sample_rate, format="FLAC")

    logging.info(f"Mock FLAC file {str(output_fpath)} created successfully.")


def create_mock_midi(output_fpath: pathlib.Path) -> None:
    """Write a minimal single-track MIDI file containing two short notes.

    The track holds a note_on/note_off pair for note 60 (C4) followed by a
    pair for note 62 (D4), all with velocity 64.

    Args:
        output_fpath: Destination path of the MIDI file.
    """
    # Each entry is (time, message type, note, velocity).
    note_events = [
        (0, "note_on", 60, 64),  # C4
        (500, "note_off", 60, 64),
        (0, "note_on", 62, 64),  # D4
        (500, "note_off", 62, 64),
    ]

    # One-track MIDI file.
    midi_file = MidiFile()
    note_track = MidiTrack()
    midi_file.tracks.append(note_track)

    for event_time, msg_type, pitch, vel in note_events:
        message = Message(msg_type, note=pitch, velocity=vel, time=event_time)
        note_track.append(message)

    midi_file.save(output_fpath)

    logging.info(f"Mock MIDI file '{output_fpath}' created successfully.")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 5f60826

Please sign in to comment.