Add Dataset.segments + Dataset.segment_durations (#31)
* Fix implementation of segments

* Add Dataset.segment_durations

* Add test

* Use less memory

* Make loading of tables optional

* Make self._load_tables private

* Improve comment

* Improve comment and variable name

* Fix test

* Add test for loading from cache

* Simplify test

* Simplify test docstring

* Use present tense in docstring

* Change setup() to classmethod

* Try to fix class method

* Improve docstring

* Rename setup to prepare

* Try to improve docstring

* Remove self._load_tables
hagenw authored Jul 24, 2024
1 parent 4f139ba commit ca8c520
Showing 2 changed files with 231 additions and 10 deletions.
149 changes: 139 additions & 10 deletions audbcards/core/dataset.py
@@ -19,13 +19,46 @@


class _Dataset:
_table_related_cached_properties = ["segment_durations", "segments"]
"""Cached properties relying on table data.
Most of the cached properties
rely on the dependency table,
the header of a dataset,
and misc tables used as scheme labels.
Some might also need to load filewise or segmented tables,
to gather more information.
Persistence of table related cached properties
depends on the ``load_tables`` argument
of :class:`audbcards.Dataset`.
If ``load_tables`` is ``True``,
:meth:`audbcards.Dataset._cached_properties`
is asked to cache table related cached properties as well.
If ``load_tables`` is ``False``,
:meth:`audbcards.Dataset._cached_properties`
is asked to exclude all cached properties,
listed in ``_table_related_cached_properties``.
Which means,
``_table_related_cached_properties`` has to list all cached properties,
that will load filewise or segmented tables.
If a dataset exists in cache,
but does not store table related cached properties,
a call to :class:`audbcards.Dataset`
with ``load_tables`` is ``True``,
will update the cache.
"""

@classmethod
def create(
cls,
name: str,
version: str,
*,
cache_root: str = None,
load_tables: bool = True,
):
r"""Instantiate Dataset Object."""
if cache_root is None:
@@ -34,11 +34,67 @@ def create(

if os.path.exists(dataset_cache_filename):
obj = cls._load_pickled(dataset_cache_filename)
# Load cached properties
# that require loading filewise or segmented tables,
# if they haven't been cached before.
if load_tables:
cache_again = False
for cached_property in cls._table_related_cached_properties:
# Check if property has been cached,
# see https://stackoverflow.com/a/59740750
if cached_property not in obj.__dict__:
cache_again = True
# Request property to fill its cached value
getattr(obj, cached_property)
if cache_again:
# Update cache to store the table related cached properties
cls._save_pickled(obj, dataset_cache_filename)

return obj

obj = cls(name, version, cache_root)
_ = obj._cached_properties()
obj = cls(name, version, cache_root, load_tables)
# Visit cached properties to fill their cached values
if load_tables:
exclude = []
else:
exclude = cls._table_related_cached_properties
obj._cached_properties(exclude=exclude)

cls._save_pickled(obj, dataset_cache_filename)
return obj
Expand All @@ -48,10 +101,24 @@ def __init__(
name: str,
version: str,
cache_root: str = None,
load_tables: bool = True,
):
self.cache_root = audeer.mkdir(cache_root)
r"""Cache root folder."""

# Define `__getstate__()` method,
# which selects the cached attributes
# to include in the pickled cache file
if load_tables:
exclude = []
else:
exclude = self._table_related_cached_properties

def getstate():
return self._cached_properties(exclude=exclude)

self.__getstate__ = getstate

# Store name and version in private attributes here,
# ``self.name`` and ``self.version``
# are implemented as cached properties below
@@ -76,10 +143,6 @@ def __init__(
for other_version in other_versions:
audeer.rmdir(cache_root, name, other_version)

def __getstate__(self):
r"""Returns attributes to be pickled."""
return self._cached_properties()
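
The class-level ``__getstate__()`` above is removed in favor of the instance-level closure defined in ``__init__()``. A minimal, self-contained sketch of that pickling pattern (``Example`` and its attributes are made up):

import pickle

class Example:
    def __init__(self, keep_b: bool = True):
        self.a = 1
        self.b = 2

        # pickle looks __getstate__ up on the instance,
        # so a per-instance closure can decide
        # which attributes end up in the pickled state
        def getstate():
            state = {"a": self.a}
            if keep_b:
                state["b"] = self.b
            return state

        self.__getstate__ = getstate

obj = pickle.loads(pickle.dumps(Example(keep_b=False)))
print(obj.__dict__)  # {'a': 1}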

@staticmethod
def _dataset_cache_path(name: str, version: str, cache_root: str) -> str:
r"""Generate the name of the cache file."""
@@ -375,6 +438,24 @@ def schemes_table(self) -> typing.List[typing.List[str]]:
scheme_data.insert(0, list(data))
return scheme_data

@functools.cached_property
def segments(self) -> str:
r"""Number of segments in dataset."""
return str(len(self._segments))

@functools.cached_property
def segment_durations(self) -> typing.List[float]:
r"""Segment durations in dataset."""
if len(self._segments) == 0:
durations = []
else:
starts = self._segments.get_level_values("start")
ends = self._segments.get_level_values("end")
durations = [
(end - start).total_seconds() for start, end in zip(starts, ends)
]
return durations
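
To illustrate the computation, a toy segmented index with invented file names and times yields durations in seconds:

import audformat

index = audformat.segmented_index(
    files=["a.wav", "a.wav"],
    starts=["0s", "1s"],
    ends=["0.5s", "2s"],
)
starts = index.get_level_values("start")
ends = index.get_level_values("end")
# start and end values are pandas Timedelta objects
durations = [
    (end - start).total_seconds() for start, end in zip(starts, ends)
]
print(durations)  # [0.5, 1.0]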

@functools.cached_property
def short_description(self) -> str:
r"""Description of dataset shortened to 150 chars."""
@@ -424,13 +505,34 @@ def version(self) -> str:
r"""Version of dataset."""
return self._version

def _cached_properties(self):
"""Get list of cached properties of the object."""
def _cached_properties(
self,
*,
exclude: typing.Sequence = [],
) -> typing.Dict[str, typing.Any]:
"""Get list of cached properties of the object.
When collecting the cached properties,
it also executes their code
in order to generate the associated values.
Args:
exclude: list of cached properties,
that should not be cached
Returns:
dictionary with property name and value
"""
class_items = self.__class__.__dict__.items()
props = dict(
(k, getattr(self, k))
for k, v in class_items
if isinstance(v, functools.cached_property)
if (
isinstance(v, functools.cached_property)
and k not in exclude
and not k.startswith("_")
)
)
return props
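
A self-contained illustration of the mechanism used here and in ``create()`` above: ``functools.cached_property`` stores its value in the instance ``__dict__`` on first access, which makes cached values both detectable and collectable (``Example`` is a made-up class):

import functools

class Example:
    @functools.cached_property
    def answer(self):
        return 42

obj = Example()
assert "answer" not in obj.__dict__  # not computed yet
_ = obj.answer                       # first access fills the cache
assert obj.__dict__["answer"] == 42

# Collect all cached properties, as _cached_properties() does
props = {
    k: getattr(obj, k)
    for k, v in Example.__dict__.items()
    if isinstance(v, functools.cached_property)
}
print(props)  # {'answer': 42}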

@@ -558,6 +660,21 @@ def _scheme_to_list(self, scheme_id):

return data_dict

@functools.cached_property
def _segments(self) -> pd.MultiIndex:
"""Segments of dataset as combined index."""
index = audformat.segmented_index()
for table in self.header.tables:
if self.header.tables[table].is_segmented:
df = audb.load_table(
self.name,
table,
version=self.version,
verbose=False,
)
index = audformat.utils.union([index, df.index])
return index
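
A toy version of this union logic, with hand-built indices standing in for the ``audb.load_table`` results:

import audformat

index = audformat.segmented_index()  # empty segmented index
for files, starts, ends in [
    (["a.wav"], ["0s"], ["1s"]),
    (["b.wav"], ["0s"], ["2s"]),
]:
    table_index = audformat.segmented_index(files, starts, ends)
    index = audformat.utils.union([index, table_index])
print(len(index))  # 2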

@staticmethod
def _map_iso_languages(languages: typing.List[str]) -> typing.List[str]:
r"""Calculate ISO languages for a list of languages.
@@ -598,6 +715,11 @@ class Dataset(object):
the environmental variable ``AUDBCARDS_CACHE_ROOT``,
or :attr:`audbcards.config.CACHE_ROOT`
is used
load_tables: if ``True``,
values extracted from tables are cached.
Set this to ``False``
if loading the tables takes too long,
or they do not fit into memory
"""

@@ -607,9 +729,15 @@ def __new__(
version: str,
*,
cache_root: str = None,
load_tables: bool = True,
):
r"""Create Dataset Instance."""
instance = _Dataset.create(name, version, cache_root=cache_root)
instance = _Dataset.create(
name,
version,
cache_root=cache_root,
load_tables=load_tables,
)
return instance

# Add an __init__() function,
@@ -620,6 +748,7 @@ def __init__(
version: str,
*,
cache_root: str = None,
load_tables: bool = True,
):
self.cache_root = audeer.mkdir(cache_root)
r"""Cache root folder."""
92 changes: 92 additions & 0 deletions tests/test_dataset.py
@@ -184,6 +184,14 @@ def test_dataset(audb_cache, tmpdir, repository, db, request):
]
assert dataset.schemes_table == expected_schemes_table

# segment_durations
expected_segment_durations = [0.5, 0.5, 150, 151]
assert dataset.segment_durations == expected_segment_durations

# segments
expected_segments = str(len(db.segments))
assert dataset.segments == expected_segments

# short_description
max_desc_length = 150
expected_description = (
@@ -442,3 +450,87 @@ def test_dataset_cache_loading(audb_cache, tmpdir, repository, db, request):
# to compare it.
assert str(dataset.header) == str(header)
assert dataset.repository_object == repository


class TestDatasetLoadTables:
r"""Test load_tables argument of audbcards.Dataset."""

@classmethod
@pytest.fixture(autouse=True)
def prepare(cls, cache, medium_db):
r"""Provide test class with cache, database name and database version.
Args:
cache: cache fixture
medium_db: medium_db fixture
"""
cls.name = medium_db.name
cls.version = pytest.VERSION
cls.cache_root = cache

def assert_has_table_properties(self, expected: bool):
r"""Assert dataset holds table related cached properties.
Args:
expected: if ``True``,
``dataset`` is expected to contain table related properties
"""
table_related_properties = [
"segment_durations",
"segments",
]
for table_related_property in table_related_properties:
if expected:
assert table_related_property in self.dataset.__dict__
else:
assert table_related_property not in self.dataset.__dict__

def load_dataset(self, *, load_tables: bool):
r"""Load dataset.
Call ``audbcards.Dataset`` and assign result to ``self.dataset``.
Args:
load_tables: if ``True``,
it caches properties
that need to load filewise/segmented tables
"""
self.dataset = audbcards.Dataset(
self.name,
self.version,
cache_root=self.cache_root,
load_tables=load_tables,
)

@pytest.mark.parametrize("load_tables_first", [True, False])
def test_load_tables(self, load_tables_first):
r"""Load dataset with/without table related properties.
This tests if the table related properties
are stored or omitted in the cache,
depending on the ``load_tables`` argument.
It also loads the dataset another two times from cache,
with changing ``load_tables`` arguments,
which should always result
in existing table related properties,
as a cache stored first with ``load_tables=False``
is updated when loading again with ``load_tables=True``.
Args:
load_tables_first: if ``True``,
it calls ``audbcards.Dataset``
with ``load_tables=True``
during its first call
"""
self.load_dataset(load_tables=load_tables_first)
self.assert_has_table_properties(load_tables_first)
self.load_dataset(load_tables=not load_tables_first)
self.assert_has_table_properties(True)
self.load_dataset(load_tables=load_tables_first)
self.assert_has_table_properties(True)
