From 435b5fa4a11445ec4309052a15d0774643e5de2e Mon Sep 17 00:00:00 2001 From: Benjie Genchel Date: Thu, 15 Aug 2024 18:00:09 -0400 Subject: [PATCH] Split enum --- basic_pitch/constants.py | 2 +- .../data/tf_example_deserialization.py | 20 +++++++++---------- tests/data/test_tf_example_deserialization.py | 11 +++++----- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/basic_pitch/constants.py b/basic_pitch/constants.py index 0123dec..69a2020 100644 --- a/basic_pitch/constants.py +++ b/basic_pitch/constants.py @@ -62,4 +62,4 @@ def _freq_bins(bins_per_semitone: int, base_frequency: float, n_semitones: int) FREQ_BINS_NOTES = _freq_bins(NOTES_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES) FREQ_BINS_CONTOURS = _freq_bins(CONTOURS_BINS_PER_SEMITONE, ANNOTATIONS_BASE_FREQUENCY, ANNOTATIONS_N_SEMITONES) -Splits = Enum("Splits", ["train", "validation", "test"]) +Split = Enum("Split", ["train", "validation", "test"]) diff --git a/basic_pitch/data/tf_example_deserialization.py b/basic_pitch/data/tf_example_deserialization.py index ffaeb73..b59667d 100644 --- a/basic_pitch/data/tf_example_deserialization.py +++ b/basic_pitch/data/tf_example_deserialization.py @@ -33,6 +33,7 @@ AUDIO_WINDOW_LENGTH, N_FREQ_BINS_NOTES, N_FREQ_BINS_CONTOURS, + Split, ) N_SAMPLES_PER_TRACK = 20 @@ -59,13 +60,13 @@ def prepare_datasets( # init both ds_train = sample_datasets( - "train", + Split.train, datasets_base_path, datasets=datasets_to_use, dataset_sampling_frequency=dataset_sampling_frequency, ) ds_validation = sample_datasets( - "validation", + Split.validation, datasets_base_path, datasets=datasets_to_use, dataset_sampling_frequency=dataset_sampling_frequency, @@ -118,14 +119,14 @@ def prepare_visualization_datasets( assert validation_steps is not None and validation_steps > 0 ds_train = sample_datasets( - "train", + Split.train, datasets_base_path, datasets=datasets_to_use, dataset_sampling_frequency=dataset_sampling_frequency, n_samples_per_track=1, ) ds_validation = sample_datasets( - "validation", + Split.validation, datasets_base_path, datasets=datasets_to_use, dataset_sampling_frequency=dataset_sampling_frequency, @@ -153,7 +154,7 @@ def prepare_visualization_datasets( def sample_datasets( - split: str, + split: Split, datasets_base_path: str, datasets: List[str], dataset_sampling_frequency: np.ndarray, @@ -161,8 +162,7 @@ def sample_datasets( n_samples_per_track: int = N_SAMPLES_PER_TRACK, pairs: bool = False, ) -> tf.data.Dataset: - assert split in ["train", "validation"] - if split == "validation": + if split == Split.validation: n_shuffle = 0 pairs = False if n_samples_per_track != 1: @@ -209,7 +209,7 @@ def sample_datasets( def transcription_file_generator( - split: str, + split: Split, dataset_names: List[str], datasets_base_path: str, sample_weights: np.ndarray, @@ -219,12 +219,12 @@ def transcription_file_generator( """ file_dict = { dataset_name: tf.data.Dataset.list_files( - os.path.join(datasets_base_path, dataset_name, "splits", split, "*tfrecord") + os.path.join(datasets_base_path, dataset_name, "splits", split.name, "*tfrecord") ) for dataset_name in dataset_names } - if split == "train": + if split == Split.train: return lambda: _train_file_generator(file_dict, sample_weights), False return lambda: _validation_file_generator(file_dict), True diff --git a/tests/data/test_tf_example_deserialization.py b/tests/data/test_tf_example_deserialization.py index 4a2fd0e..379f5bc 100644 --- a/tests/data/test_tf_example_deserialization.py +++ b/tests/data/test_tf_example_deserialization.py @@ -24,6 +24,7 @@ from apache_beam.testing.test_pipeline import TestPipeline from typing import List +from basic_pitch.constants import Split from basic_pitch.data.datasets.guitarset import GuitarSetToTfExample from basic_pitch.data.pipeline import WriteBatchToTfRecord from basic_pitch.data.tf_example_deserialization import ( @@ -135,7 +136,7 @@ def test_sample_datasets(tmp_path: pathlib.Path) -> None: datasets_home = setup_test_resources(tmp_path) ds = sample_datasets( - split="train", + split=Split.train, datasets_base_path=str(datasets_home), datasets=["guitarset"], dataset_sampling_frequency=np.array([1]), @@ -148,12 +149,12 @@ def test_sample_datasets(tmp_path: pathlib.Path) -> None: def test_transcription_file_generator_train(tmp_path: pathlib.Path) -> None: - dataset_path = tmp_path / "test_ds" / "splits" / "train" + dataset_path = tmp_path / "test_ds" / "splits" / Split.train.name dataset_path.mkdir(parents=True) create_empty_tfrecord(dataset_path / "test.tfrecord") file_gen, random_seed = transcription_file_generator( - "train", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1]) + Split.train, ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1]) ) assert random_seed is False @@ -167,12 +168,12 @@ def test_transcription_file_generator_train(tmp_path: pathlib.Path) -> None: def test_transcription_file_generator_valid(tmp_path: pathlib.Path) -> None: - dataset_path = tmp_path / "test_ds" / "splits" / "valid" + dataset_path = tmp_path / "test_ds" / "splits" / Split.validation.name dataset_path.mkdir(parents=True) create_empty_tfrecord(dataset_path / "test.tfrecord") file_gen, random_seed = transcription_file_generator( - "valid", ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1]) + Split.validation, ["test_ds"], datasets_base_path=str(tmp_path), sample_weights=np.array([1]) ) assert random_seed is True