From ca8c520443639faaf111380bbb096070fead1fd2 Mon Sep 17 00:00:00 2001
From: Hagen Wierstorf
Date: Wed, 24 Jul 2024 15:27:47 +0200
Subject: [PATCH] Add Dataset.segments + Dataset.segment_durations (#31)

* Fix implementation of segments
* Add Dataset.segment_durations
* Add test
* Use less memory
* Make loading of tables optional
* Make self._load_tables private
* Improve comment
* Improve comment and variable name
* Fix test
* Add test for loading from cache
* Simplify test
* Simplify test docstring
* Use present tense in docstring
* Change setup() to classmethod
* Try to fix class method
* Improve docstring
* Rename setup to prepare
* Try to improve docstring
* Remove self._load_tables
---
 audbcards/core/dataset.py | 149 +++++++++++++++++++++++++++++++++++---
 tests/test_dataset.py     |  92 +++++++++++++++++++++++
 2 files changed, 231 insertions(+), 10 deletions(-)

diff --git a/audbcards/core/dataset.py b/audbcards/core/dataset.py
index daf08e11..57bed174 100644
--- a/audbcards/core/dataset.py
+++ b/audbcards/core/dataset.py
@@ -19,6 +19,38 @@ class _Dataset:
 
+    _table_related_cached_properties = ["segment_durations", "segments"]
+    """Cached properties relying on table data.
+
+    Most of the cached properties
+    rely on the dependency table,
+    the header of a dataset,
+    and misc tables used as scheme labels.
+    Some might also need to load filewise or segmented tables
+    to gather more information.
+
+    Persistence of table related cached properties
+    depends on the ``load_tables`` argument
+    of :class:`audbcards.Dataset`.
+    If ``load_tables`` is ``True``,
+    :meth:`audbcards.Dataset._cached_properties`
+    caches the table related cached properties as well.
+    If ``load_tables`` is ``False``,
+    :meth:`audbcards.Dataset._cached_properties`
+    excludes all cached properties
+    listed in ``_table_related_cached_properties``.
+    This means ``_table_related_cached_properties``
+    has to list all cached properties
+    that load filewise or segmented tables.
+
+    If a dataset exists in cache,
+    but does not store table related cached properties,
+    a call to :class:`audbcards.Dataset`
+    with ``load_tables=True``
+    will update the cache.
+
+    """
+
     @classmethod
     def create(
         cls,
         name: str,
@@ -26,6 +58,7 @@ def create(
         version: str,
         *,
         cache_root: str = None,
+        load_tables: bool = True,
     ):
         r"""Instantiate Dataset Object."""
         if cache_root is None:
@@ -34,11 +67,31 @@ def create(
 
         if os.path.exists(dataset_cache_filename):
             obj = cls._load_pickled(dataset_cache_filename)
+            # Load cached properties
+            # that require loading filewise or segmented tables,
+            # if they haven't been cached before.
+            if load_tables:
+                cache_again = False
+                for cached_property in cls._table_related_cached_properties:
+                    # Check if property has been cached,
+                    # see https://stackoverflow.com/a/59740750
+                    if cached_property not in obj.__dict__:
+                        cache_again = True
+                        # Request property to fill its cached value
+                        getattr(obj, cached_property)
+                if cache_again:
+                    # Update cache to store the table related cached properties
+                    cls._save_pickled(obj, dataset_cache_filename)
             return obj
 
-        obj = cls(name, version, cache_root)
-        _ = obj._cached_properties()
+        obj = cls(name, version, cache_root, load_tables)
+        # Visit cached properties to fill their cached values
+        if load_tables:
+            exclude = []
+        else:
+            exclude = cls._table_related_cached_properties
+        obj._cached_properties(exclude=exclude)
         cls._save_pickled(obj, dataset_cache_filename)
         return obj
 
@@ -48,10 +101,24 @@ def __init__(
         name: str,
         version: str,
         cache_root: str = None,
+        load_tables: bool = True,
     ):
         self.cache_root = audeer.mkdir(cache_root)
         r"""Cache root folder."""
 
+        # Define `__getstate__()` method,
+        # which selects the cached attributes
+        # to include in the pickled cache file
+        if load_tables:
+            exclude = []
+        else:
+            exclude = self._table_related_cached_properties
+
+        def getstate():
+            return self._cached_properties(exclude=exclude)
+
+        self.__getstate__ = getstate
+
         # Store name and version in private attributes here,
         # ``self.name`` and ``self.version``
         # are implemented as cached properties below
@@ -76,10 +143,6 @@ def __init__(
         for other_version in other_versions:
             audeer.rmdir(cache_root, name, other_version)
 
-    def __getstate__(self):
-        r"""Returns attributes to be pickled."""
-        return self._cached_properties()
-
     @staticmethod
     def _dataset_cache_path(name: str, version: str, cache_root: str) -> str:
         r"""Generate the name of the cache file."""
@@ -375,6 +438,24 @@ def schemes_table(self) -> typing.List[typing.List[str]]:
             scheme_data.insert(0, list(data))
         return scheme_data
 
+    @functools.cached_property
+    def segments(self) -> str:
+        r"""Number of segments in dataset."""
+        return str(len(self._segments))
+
+    @functools.cached_property
+    def segment_durations(self) -> typing.List[float]:
+        r"""Segment durations in dataset."""
+        if len(self._segments) == 0:
+            durations = []
+        else:
+            starts = self._segments.get_level_values("start")
+            ends = self._segments.get_level_values("end")
+            durations = [
+                (end - start).total_seconds() for start, end in zip(starts, ends)
+            ]
+        return durations
+
     @functools.cached_property
     def short_description(self) -> str:
         r"""Description of dataset shortened to 150 chars."""
@@ -424,13 +505,34 @@ def version(self) -> str:
         r"""Version of dataset."""
         return self._version
 
-    def _cached_properties(self):
-        """Get list of cached properties of the object."""
+    def _cached_properties(
+        self,
+        *,
+        exclude: typing.Sequence = [],
+    ) -> typing.Dict[str, typing.Any]:
+        """Get dictionary of cached properties of the object.
+
+        When collecting the cached properties,
+        it also executes their code
+        to generate the associated values.
+
+        Args:
+            exclude: list of cached properties
+                that should not be cached
+
+        Returns:
+            dictionary with property names and values
+
+        """
         class_items = self.__class__.__dict__.items()
         props = dict(
             (k, getattr(self, k))
             for k, v in class_items
-            if isinstance(v, functools.cached_property)
+            if (
+                isinstance(v, functools.cached_property)
+                and k not in exclude
+                and not k.startswith("_")
+            )
         )
         return props
@@ -558,6 +660,21 @@ def _scheme_to_list(self, scheme_id):
 
         return data_dict
 
+    @functools.cached_property
+    def _segments(self) -> pd.MultiIndex:
+        """Segments of dataset as combined index."""
+        index = audformat.segmented_index()
+        for table in self.header.tables:
+            if self.header.tables[table].is_segmented:
+                df = audb.load_table(
+                    self.name,
+                    table,
+                    version=self.version,
+                    verbose=False,
+                )
+                index = audformat.utils.union([index, df.index])
+        return index
+
     @staticmethod
     def _map_iso_languages(languages: typing.List[str]) -> typing.List[str]:
         r"""Calculate ISO languages for a list of languages.
@@ -598,6 +715,11 @@ class Dataset(object):
            the environmental variable ``AUDBCARDS_CACHE_ROOT``,
            or :attr:`audbcards.config.CACHE_ROOT`
            is used
+        load_tables: if ``True``,
+            it caches values extracted from tables.
+            Set this to ``False``
+            if loading the tables takes too long,
+            or they do not fit into memory
 
    """
 
@@ -607,9 +729,15 @@ def __new__(
        version: str,
        *,
        cache_root: str = None,
+        load_tables: bool = True,
    ):
        r"""Create Dataset Instance."""
-        instance = _Dataset.create(name, version, cache_root=cache_root)
+        instance = _Dataset.create(
+            name,
+            version,
+            cache_root=cache_root,
+            load_tables=load_tables,
+        )
        return instance
 
    # Add an __init__() function,
@@ -620,6 +748,7 @@ def __init__(
        version: str,
        *,
        cache_root: str = None,
+        load_tables: bool = True,
    ):
        self.cache_root = audeer.mkdir(cache_root)
        r"""Cache root folder."""
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 5943de86..daccbb19 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -184,6 +184,14 @@ def test_dataset(audb_cache, tmpdir, repository, db, request):
     ]
     assert dataset.schemes_table == expected_schemes_table
 
+    # segment_durations
+    expected_segment_durations = [0.5, 0.5, 150, 151]
+    assert dataset.segment_durations == expected_segment_durations
+
+    # segments
+    expected_segments = str(len(db.segments))
+    assert dataset.segments == expected_segments
+
     # short_description
     max_desc_length = 150
     expected_description = (
@@ -442,3 +450,87 @@ def test_dataset_cache_loading(audb_cache, tmpdir, repository, db, request):
     # to compare it.
     assert str(dataset.header) == str(header)
     assert dataset.repository_object == repository
+
+
+class TestDatasetLoadTables:
+    r"""Test load_tables argument of audbcards.Dataset."""
+
+    @classmethod
+    @pytest.fixture(autouse=True)
+    def prepare(cls, cache, medium_db):
+        r"""Provide test class with cache, database name and database version.
+
+        Args:
+            cache: cache fixture
+            medium_db: medium_db fixture
+
+        """
+        cls.name = medium_db.name
+        cls.version = pytest.VERSION
+        cls.cache_root = cache
+
+    def assert_has_table_properties(self, expected: bool):
+        r"""Assert dataset holds table related cached properties.
+
+        Args:
+            expected: if ``True``,
+                ``dataset`` is expected to contain table related properties
+
+        """
+        table_related_properties = [
+            "segment_durations",
+            "segments",
+        ]
+        for table_related_property in table_related_properties:
+            if expected:
+                assert table_related_property in self.dataset.__dict__
+            else:
+                assert table_related_property not in self.dataset.__dict__
+
+    def load_dataset(self, *, load_tables: bool):
+        r"""Load dataset.
+
+        Call ``audbcards.Dataset`` and assign result to ``self.dataset``.
+
+        Args:
+            load_tables: if ``True``,
+                it caches properties
+                that need to load filewise/segmented tables
+
+        """
+        self.dataset = audbcards.Dataset(
+            self.name,
+            self.version,
+            cache_root=self.cache_root,
+            load_tables=load_tables,
+        )
+
+    @pytest.mark.parametrize("load_tables_first", [True, False])
+    def test_load_tables(self, load_tables_first):
+        r"""Load dataset with/without table related properties.
+
+        This tests if the table related properties
+        are stored or omitted in the cache,
+        depending on the ``load_tables`` argument.
+
+        It also loads the dataset another two times from cache,
+        alternating the ``load_tables`` argument,
+        which should always result
+        in existing table related properties,
+        as a cache first stored with ``load_tables=False``
+        is updated when loading again with ``load_tables=True``.
+
+        Args:
+            load_tables_first: if ``True``,
+                it calls ``audbcards.Dataset``
+                with ``load_tables=True``
+                during its first call
+
+        """
+        self.load_dataset(load_tables=load_tables_first)
+        self.assert_has_table_properties(load_tables_first)
+        self.load_dataset(load_tables=not load_tables_first)
+        self.assert_has_table_properties(True)
+        self.load_dataset(load_tables=load_tables_first)
+        self.assert_has_table_properties(True)
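
Usage sketch (editor's addition, not part of the patch): a minimal example of the
new ``load_tables`` argument introduced above. The dataset name ``emodb`` and
version ``1.4.1`` are placeholders for a published audb dataset, not values taken
from this patch.

    import audbcards

    # First pass: skip loading filewise/segmented tables,
    # so ``segments`` and ``segment_durations`` are not computed or cached
    dataset = audbcards.Dataset(
        "emodb",  # hypothetical dataset name
        "1.4.1",  # hypothetical version
        load_tables=False,
    )

    # Loading again with load_tables=True fills the table related
    # cached properties and updates the pickled cache file
    dataset = audbcards.Dataset(
        "emodb",
        "1.4.1",
        load_tables=True,
    )
    print(dataset.segments)           # number of segments, as a string
    print(dataset.segment_durations)  # segment durations in seconds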