Skip to content

Commit

Permalink
testing
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 20, 2023
1 parent 0d1dbe8 commit bc88a33
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 3 deletions.
24 changes: 22 additions & 2 deletions ecml_tools/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import re
from functools import cached_property

import numpy as np
import yaml
import zarr

Expand Down Expand Up @@ -68,8 +69,12 @@ def years_to_indices(self, start, end):

class Dataset(Base):
def __init__(self, path):
self.path = path
self.z = zarr.convenience.open(path, "r")
if isinstance(path, zarr.hierarchy.Group):
self.path = '-'
self.z = path
else:
self.path = path
self.z = zarr.convenience.open(path, "r")

def __len__(self):
return self.z.data.shape[0]
Expand Down Expand Up @@ -164,10 +169,21 @@ def resolution(self):
def frequency(self):
return self.datasets[0].frequency

def __len__(self):
return len(self.datasets[0])

def __repr__(self):
lst = ", ".join(repr(d) for d in self.datasets)
return f"Join({lst})"

def __getitem__(self, n):
return np.concatenate([d[n] for d in self.datasets], axis=0)

@cached_property
def shape(self):
cols = sum(d.shape[1] for d in self.datasets)
return (len(self), cols) + self.datasets[0].shape[2:]


class Subset(Base):
def __init__(self, dataset, indices):
Expand Down Expand Up @@ -195,6 +211,10 @@ def dates(self):


def name_to_path(name):

if isinstance(name, zarr.hierarchy.Group):
return name

if name.endswith(".zarr"):
return name

Expand Down
10 changes: 10 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,23 @@
)

print(z)
print(len(z))

z = open_dataset(
"aifs-ea-an-oper-0001-mars-o96-2021-6h-v2-only-z",
"aifs-ea-an-oper-0001-mars-o96-2021-6h-v2-without-z",
)

# z.save('new-zarr.zarr')

print(z)
print(z.shape)
print(len(z))

for i, e in enumerate(z):
print(i, e)
if i > 10:
break

exit()

Expand Down
85 changes: 84 additions & 1 deletion tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,89 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import datetime

import numpy as np
import zarr

from ecml_tools.data import Concat, Join, open_dataset


def ij(i, j):
return i * 10000 + j


def create_zarr(vars=["2t", "msl", "10u", "10v"], start=2021, end=2021, frequency=6):
root = zarr.group()

dates = []
date = datetime.datetime(start, 1, 1)
while date.year <= end:
dates.append(date)
date += datetime.timedelta(hours=frequency)

dates = np.array(dates, dtype="datetime64")

data = np.zeros((len(dates), len(vars)))
for i in range(len(dates)):
for j in range(len(vars)):
data[i, j] = ij(i, j)

root.data = data
root.dates = dates
root.attrs["frequency"] = frequency
root.attrs["resolution"] = 0

return root


def test_code():
pass # Empty for now
root = create_zarr()
z = open_dataset(root)
print(len(z))
print(z.dates)

for i, e in enumerate(z):
print(i, e)
if i > 10:
break


def test_concat():
z = open_dataset(
create_zarr(start=2021, end=2021),
create_zarr(start=2022, end=2022),
)
assert isinstance(z, Concat)
assert len(z) == 365 * 2 * 4
for i, row in enumerate(z):
n = i if i < 365 * 4 else i - 365 * 4
expect = [ij(n, 0), ij(n, 1), ij(n, 2), ij(n, 3)]
assert (row == expect).all()


def test_join():
z = open_dataset(
create_zarr(vars=["a", "b", "c", "d"]),
create_zarr(vars=["e", "f", "g", "h"]),
)

assert isinstance(z, Join)
assert len(z) == 365 * 4
for i, row in enumerate(z):
n = i
expect = [
ij(n, 0),
ij(n, 1),
ij(n, 2),
ij(n, 3),
ij(n, 0),
ij(n, 1),
ij(n, 2),
ij(n, 3),
]
assert (row == expect).all()


if __name__ == "__main__":
test_join()

0 comments on commit bc88a33

Please sign in to comment.