Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 23, 2023
1 parent 82d3222 commit 95eeb96
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 88 deletions.
132 changes: 50 additions & 82 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,65 +21,11 @@
__all__ = ["open_dataset"]


def _expand_slice(s, length):
start, stop, step = s.start, s.stop, s.step

if step is None:
step = 1
if start is None:
start = 0
if stop is None:
stop = length

if start < 0:
start = length - start

if stop < 0:
stop = length - stop

stop = min(stop, length)

assert step > 0
assert stop > start
assert start >= 0

return slice(start, stop, step)


def _intersect_slices(s1, s2):
assert s1.step == s2.step
return slice(
max(s1.start, s2.start),
min(s1.stop, s2.stop),
s1.step,
)


class Dataset:
@cached_property
def _cached_length(self):
def _len(self):
return len(self)

def __getitem__(self, n):
if isinstance(n, int):
return self._getitem_int(n)
return self._getitem_slice(_expand_slice(n, self._cached_length))

def _getitem_slice(self, s):
return np.vstack(
[
self._getitem_int(i)
for i in range(
s.start,
s.stop,
s.step,
)
]
)

def _getitem_int(self, i):
raise NotImplementedError()

def _subset(self, **kwargs):
if not kwargs:
return self
Expand Down Expand Up @@ -160,6 +106,9 @@ def _reorder_to_columns(self, vars):

return indices

def _expand_slice(self, s):
return slice(*s.indices(self._len))


class Zarr(Dataset):
def __init__(self, path):
Expand All @@ -173,12 +122,9 @@ def __init__(self, path):
def __len__(self):
return self.z.data.shape[0]

def _getitem_int(self, n):
def __getitem__(self, n):
return self.z.data[n]

# def _getitem_slice(self, n):
# return self.z.data[n]

@property
def shape(self):
return self.z.data.shape
Expand Down Expand Up @@ -251,11 +197,8 @@ def __init__(self, forward):
def __len__(self):
return len(self.forward)

def _getitem_int(self, n):
return self.forward._getitem_int(n)

# def _getitem_slice(self, n):
# return self.forward._getitem_slice(n)
def __getitem__(self, n):
return self.forward[n]

@property
def dates(self):
Expand Down Expand Up @@ -330,24 +273,33 @@ class Concat(Combined):
def __len__(self):
return sum(len(i) for i in self.datasets)

def _getitem_int(self, n):
def __getitem__(self, n):
if isinstance(n, slice):
return self._get_slice(n)

# TODO: optimize
k = 0
while n >= self.datasets[k]._cached_length:
n -= self.datasets[k]._cached_length
while n >= self.datasets[k]._len:
n -= self.datasets[k]._len
k += 1
return self.datasets[k]._getitem_int(n)
return self.datasets[k][n]

def _getitem_sliceX(self, n):
def _get_slice(self, s):
result = []
k = 0

start, stop, step = s.indices(self._len)

for d in self.datasets:
s = slice(k, k + d._cached_length, n.step)
t = _intersect_slices(n, s)
if t.stop > t.start:
t = slice(t.start - k, t.stop - k, t.step)
result.append(d._getitem_slice(t))
k += d._cached_length
length = d._len
begin = start
while begin < 0:
begin += step

result.append(d[begin:stop:step])

start -= length
stop -= length

return np.vstack(result)

def check_compatibility(self, d1, d2):
Expand Down Expand Up @@ -383,7 +335,12 @@ def check_compatibility(self, d1, d2):
def __len__(self):
return len(self.datasets[0])

def _getitem_int(self, n):
def _get_slice(self, s):
return np.vstack([self[i] for i in range(*s.indices(self._len))])

def __getitem__(self, n):
if isinstance(n, slice):
return self._get_slice(n)
return np.concatenate([d[n] for d in self.datasets], axis=0)

@cached_property
Expand Down Expand Up @@ -454,9 +411,18 @@ def __init__(self, dataset, indices):
# Forward other properties to the super dataset
super().__init__(dataset)

def _getitem_int(self, n):
def __getitem__(self, n):
if isinstance(n, slice):
return self._get_slice(n)
n = self.indices[n]
return self.dataset._getitem_int(n)
return self.dataset[n]

def _get_slice(self, s):
# TODO: check if the indices can be simplified to a slice
# the time checking maybe be longer than the time saved
# using a slice
indices = [self.indices[i] for i in range(*s.indices(self._len))]
return np.vstack([self.dataset[i] for i in indices])

def __len__(self):
return len(self.indices)
Expand Down Expand Up @@ -486,11 +452,13 @@ def __init__(self, dataset, indices):
self.indices = list(indices)
assert len(self.indices) > 0

# Forward other properties to the global dataset
# Forward other properties to the main dataset
super().__init__(dataset)

def _getitem_int(self, n):
row = self.dataset._getitem_int(n)
def __getitem__(self, n):
row = self.dataset[n]
if isinstance(n, slice):
return row[:, self.indices]
return row[self.indices]

@cached_property
Expand Down
31 changes: 25 additions & 6 deletions tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,25 @@ def same_stats(ds1, ds2, vars1, vars2=None):
).all()


def slices(ds):
start = 5
end = len(ds) - 5
step = len(ds) // 10
def slices(ds, start=None, end=None, step=None):
    """Check that ``ds[start:end:step]`` matches row-by-row indexing of ``ds``.

    Defaults: ``start=5``, ``end=len(ds) - 5``, ``step=len(ds) // 10``.
    Prints some diagnostics, then asserts that the sliced result has the
    expected number of rows, that its rows have the same shape as a single
    row of ``ds``, and that every row equals the corresponding ``ds[n]``.
    """
    if start is None:
        start = 5
    if end is None:
        end = len(ds) - 5
    if step is None:
        step = len(ds) // 10

    # Resolve the bounds the same way slicing does. A plain
    # ``range(start, end, step)`` is empty when ``end`` is negative, which
    # would make the row-by-row comparison below check nothing at all.
    expected = range(*slice(start, end, step).indices(len(ds)))

    print(start, end, step)
    print(list(expected))
    s = ds[start:end:step]
    print(len(s))
    print(s.shape)

    assert len(s) == len(expected), (len(s), len(expected))

    assert s[0].shape == ds[0].shape, (
        s.shape,
        ds.shape,
        len(expected),
        list(expected),
    )

    for i, n in enumerate(expected):
        assert (s[i] == ds[n]).all()

Expand Down Expand Up @@ -930,5 +940,14 @@ def test_slice_5():
assert (s[n - 3] == ds[n - 1]).all()


def test_slice_6():
    """Run the slice checks on a concatenation of yearly datasets (1940-2022)."""
    names = [f"test-{year}-{year}-1h-o96-abcd" for year in range(1940, 2023)]
    ds = open_dataset(names)

    # Default bounds first, then full range, strided full range, and a
    # negative end with an odd stride.
    slices(ds)
    total = len(ds)
    for bounds in ((0, total, 1), (0, total, 10), (7, -123, 13)):
        slices(ds, *bounds)


if __name__ == "__main__":
test_concat()
test_select_1()

0 comments on commit 95eeb96

Please sign in to comment.