
Commit

update
b8raoult committed Sep 22, 2023
1 parent b31dd89 commit aae717a
Showing 3 changed files with 237 additions and 51 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -284,3 +284,4 @@ ds = open_dataset(
)
```

93 changes: 63 additions & 30 deletions ecml_tools/data.py
@@ -17,16 +17,18 @@

LOG = logging.getLogger(__name__)

__all__ = ["open_dataset"]


class Dataset:
def subset(self, **kwargs):
def _subset(self, **kwargs):
if not kwargs:
return self

if "frequency" in kwargs:
frequency = kwargs.pop("frequency")

return Subset(self, self.frequency_to_indices(frequency)).subset(**kwargs)
return Subset(self, self._frequency_to_indices(frequency))._subset(**kwargs)

if "start" in kwargs or "end" in kwargs:
start = kwargs.pop("start")
@@ -37,27 +39,29 @@ def is_year(x):

if start is None or is_year(start):
if end is None or is_year(end):
return Subset(self, self.years_to_indices(start, end)).subset(
return Subset(self, self._years_to_indices(start, end))._subset(
**kwargs
)

raise NotImplementedError(f"Unsupported start/end: {start} {end}")

if "select" in kwargs:
select = kwargs.pop("select")
return Select(self, self.select_to_columns(select)).subset(**kwargs)
return Select(self, self._select_to_columns(select))._subset(**kwargs)

if "drop" in kwargs:
drop = kwargs.pop("drop")
return Select(self, self.drop_to_columns(drop)).subset(**kwargs)
return Select(self, self._drop_to_columns(drop))._subset(**kwargs)

if "reorder" in kwargs:
reorder = kwargs.pop("reorder")
return Select(self, self.reorder_to_columns(reorder)).subset(**kwargs)

return Select(self, self._reorder_to_columns(reorder))._subset(**kwargs)
if "rename" in kwargs:
rename = kwargs.pop("rename")
return Rename(self, rename)._subset(**kwargs)
raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs))

def frequency_to_indices(self, frequency):
def _frequency_to_indices(self, frequency):
requested_frequency = _frequency_to_hours(frequency)
dataset_frequency = _frequency_to_hours(self.frequency)
assert requested_frequency % dataset_frequency == 0
@@ -66,7 +70,7 @@ def frequency_to_indices(self, frequency):

return range(0, len(self), step)
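
The thinning is a plain stride over the date axis: requesting a coarser frequency keeps every `step`-th sample. A worked illustration, with hypothetical hour values, assuming `_frequency_to_hours` returns integer hours:

```python
dataset_frequency = 1      # hours, e.g. an hourly dataset
requested_frequency = 6    # hours, e.g. frequency="6h" requested
step = requested_frequency // dataset_frequency    # -> 6
indices = range(0, 240, step)                      # 240 stands in for len(self)
print(list(indices)[:4])                           # [0, 6, 12, 18]
```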

def years_to_indices(self, start, end):
def _years_to_indices(self, start, end):
# TODO: optimize
start = self.dates[0].astype(object).year if start is None else start
end = self.dates[-1].astype(object).year if end is None else end
@@ -77,27 +81,27 @@
if start <= date.astype(object).year <= end
]

def select_to_columns(self, vars):
def _select_to_columns(self, vars):
if isinstance(vars, set):
# We keep the order of the variables as they are in the zarr file
nvars = [v for v in self.name_to_index if v in vars]
assert len(nvars) == len(vars)
return self.select_to_columns(nvars)
return self._select_to_columns(nvars)

if not isinstance(vars, (list, tuple)):
vars = [vars]

return [self.name_to_index[v] for v in vars]

def drop_to_columns(self, vars):
def _drop_to_columns(self, vars):
if not isinstance(vars, (list, tuple, set)):
vars = [vars]

assert set(vars) <= set(self.name_to_index)

return sorted([v for k, v in self.name_to_index.items() if k not in vars])
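
Both helpers reduce to lists of column indices over `name_to_index`. A self-contained illustration with a hypothetical variable mapping:

```python
name_to_index = {"10u": 0, "10v": 1, "2t": 2, "msl": 3}   # hypothetical

# select: the requested names, in the requested order
select = ["2t", "msl"]
print([name_to_index[v] for v in select])                  # [2, 3]

# drop: everything else, in ascending column order
drop = {"10v"}
print(sorted(v for k, v in name_to_index.items() if k not in drop))  # [0, 2, 3]
```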

def reorder_to_columns(self, vars):
def _reorder_to_columns(self, vars):
if isinstance(vars, (list, tuple)):
vars = {k: i for i, k in enumerate(vars)}

@@ -109,9 +113,6 @@ def reorder_to_columns(self, vars):

return indices

def date_to_index(self, date):
raise NotImplementedError()


class Zarr(Dataset):
def __init__(self, path):
@@ -197,6 +198,12 @@ class Forwards(Dataset):
def __init__(self, forward):
self.forward = forward

def __len__(self):
return len(self.forward)

def __getitem__(self, n):
return self.forward[n]

@property
def dates(self):
return self.forward.dates
@@ -229,6 +236,10 @@ def variables(self):
def statistics(self):
return self.forward.statistics

@property
def shape(self):
return self.forward.shape


class Combined(Forwards):
def __init__(self, datasets):
@@ -290,6 +301,10 @@ def __repr__(self):
lst = ", ".join(repr(d) for d in self.datasets)
return f"Concat({lst})"

@property
def shape(self):
return (len(self),) + self.datasets[0].shape[1:]
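
With the forwarding added above, `Concat` reports a shape whose first axis is the combined number of dates while the remaining axes come from its first member. A hedged sketch, with hypothetical dataset names:

```python
from ecml_tools.data import open_dataset

ds1 = open_dataset("my-dataset-2020")                      # hypothetical
ds2 = open_dataset("my-dataset-2021")                      # hypothetical
ds = open_dataset("my-dataset-2020", "my-dataset-2021")    # consecutive dates -> Concat

assert len(ds) == len(ds1) + len(ds2)
assert ds.shape == (len(ds),) + ds1.shape[1:]
```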


class Join(Combined):
def check_compatibility(self, d1, d2):
@@ -314,7 +329,7 @@ def __repr__(self):
lst = ", ".join(repr(d) for d in self.datasets)
return f"Join({lst})"

def overlay(self):
def _overlay(self):
indices = {}
i = 0
for d in self.datasets:
@@ -340,16 +355,20 @@ def overlay(self):

return Select(self, indices)

@property
@cached_property
def variables(self):
seen = set()
result = []
for d in self.datasets:
for v in d.variables:
assert v not in result, "Duplicate variable: " + v
result.append(v)
for d in reversed(self.datasets):
for v in reversed(d.variables):
while v in seen:
v = f"({v})"
seen.add(v)
result.insert(0, v)

return result
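
The replacement loop walks the joined datasets from the right, so when the same variable is provided by several datasets the right-most occurrence keeps its plain name and earlier ones are wrapped in parentheses (instead of the previous hard assertion on duplicates). A self-contained illustration of just that logic, with hypothetical variable names:

```python
def dedup(variable_lists):
    # Same idea as the new Join.variables: the right-most occurrence wins,
    # earlier duplicates are parenthesised so every name stays unique.
    seen, result = set(), []
    for variables in reversed(variable_lists):
        for v in reversed(variables):
            while v in seen:
                v = f"({v})"
            seen.add(v)
            result.insert(0, v)
    return result

print(dedup([["2t", "msl"], ["2t", "10u"]]))   # ['(2t)', 'msl', '2t', '10u']
```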

@property
@cached_property
def name_to_index(self):
return {k: i for i, k in enumerate(self.variables)}

@@ -406,9 +425,6 @@ def __init__(self, dataset, indices):

super().__init__(dataset)

def __len__(self):
return len(self.dataset)

def __getitem__(self, n):
row = self.dataset[n]
return row[self.indices]
@@ -430,6 +446,22 @@ def statistics(self):
return {k: v[self.indices] for k, v in self.dataset.statistics.items()}


class Rename(Forwards):
def __init__(self, dataset, rename):
super().__init__(dataset)
for n in rename:
assert n in dataset.variables
self._variables = [rename.get(v, v) for v in dataset.variables]

@property
def variables(self):
return self._variables

@cached_property
def name_to_index(self):
return {k: i for i, k in enumerate(self.variables)}


def _name_to_path(name):
if name.endswith(".zarr"):
return name
@@ -463,7 +495,7 @@ def _concat_or_join(datasets):
ranges = [(d.dates[0].astype(object), d.dates[-1].astype(object)) for d in datasets]

if len(set(ranges)) == 1:
return Join(datasets).overlay()
return Join(datasets)._overlay()

# Make sure the dates are disjoint
for i in range(len(ranges)):
@@ -484,7 +516,8 @@ def _concat_or_join(datasets):
s = ranges[i + 1]
if r[1] + datetime.timedelta(hours=frequency) != s[0]:
raise ValueError(
f"Datasets must be sorted by dates, with no gaps: {r} and {s} ({datasets[i]} {datasets[i+1]})"
"Datasets must be sorted by dates, with no gaps: "
f"{r} and {s} ({datasets[i]} {datasets[i+1]})"
)

return Concat(datasets)
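
`_concat_or_join` chooses the combination automatically: identical date ranges are joined along the variable axis (then overlaid so the right-most copy of a duplicated variable wins), while back-to-back date ranges are concatenated along the date axis. A hedged sketch, with hypothetical dataset names:

```python
from ecml_tools.data import open_dataset

# Same dates, different variables -> Join, then _overlay
ds = open_dataset("surface-2020-2021", "pressure-levels-2020-2021")

# Consecutive dates -> Concat; a gap or unsorted dates raises ValueError
ds = open_dataset("my-dataset-2020", "my-dataset-2021")
```
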
@@ -525,6 +558,6 @@ def open_dataset(*args, **kwargs):
assert len(sets) > 0, (args, kwargs)

if len(sets) > 1:
return _concat_or_join(sets).subset(**kwargs)
return _concat_or_join(sets)._subset(**kwargs)

return sets[0].subset(**kwargs)
return sets[0]._subset(**kwargs)
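
`open_dataset` therefore combines first and subsets second: any keyword arguments are applied to the joined or concatenated result through `_subset`. A short hedged example, with hypothetical names:

```python
from ecml_tools.data import open_dataset

ds = open_dataset("my-dataset-2020", "my-dataset-2021",
                  frequency="6h", select=["2t", "msl"])
```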