From 1dc7b1136c80c7cb57c2066da0fd6021dbd3d6dc Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Thu, 29 Feb 2024 15:07:26 +0000 Subject: [PATCH 1/3] up --- ecml_tools/create/loaders.py | 12 ++- ecml_tools/create/statistics.py | 128 +++++++++++++----------- ecml_tools/create/writer.py | 166 ++++++++++---------------------- ecml_tools/create/zarr.py | 8 -- tests/create-missing.yaml | 5 +- 5 files changed, 134 insertions(+), 185 deletions(-) diff --git a/ecml_tools/create/loaders.py b/ecml_tools/create/loaders.py index eca0d71..5e15022 100644 --- a/ecml_tools/create/loaders.py +++ b/ecml_tools/create/loaders.py @@ -98,6 +98,14 @@ def read_dataset_metadata(self): self.dates[0], ) + def date_to_index(self, date): + if isinstance(date, str): + date = np.datetime64(date) + if isinstance(date, datetime.datetime): + date = np.datetime64(date) + assert type(date) is type(self.dates[0]), (type(date), type(self.dates[0])) + return np.where(self.dates == date)[0][0] + @cached_property def registry(self): return ZarrBuiltRegistry(self.path) @@ -377,7 +385,7 @@ def run(self): dates = [d for d in self.dates if d not in self.missing_dates] if self.missing_dates: - assert type(self.missing_dates[0]) == type(dates[0]), (type(self.missing_dates[0]), type(dates[0])) + assert type(self.missing_dates[0]) is type(dates[0]), (type(self.missing_dates[0]), type(dates[0])) dates_computed = self.statistics_registry.dates_computed for d in dates: @@ -392,7 +400,7 @@ def run(self): start = np.datetime64(start) end = np.datetime64(end) dates = [d for d in dates if d >= start and d <= end] - assert type(start) == type(dates[0]), (type(start), type(dates[0])) + assert type(start) is type(dates[0]), (type(start), type(dates[0])) stats = self.statistics_registry.get_aggregated(dates, self.variables_names) diff --git a/ecml_tools/create/statistics.py b/ecml_tools/create/statistics.py index 75d75bd..6d4cbea 100644 --- a/ecml_tools/create/statistics.py +++ b/ecml_tools/create/statistics.py @@ -6,17 +6,18 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
# +import datetime import glob +import hashlib import json import logging import os import pickle import shutil import socket -from collections import defaultdict, Counter +from collections import Counter, defaultdict from functools import cached_property - import numpy as np from ecml_tools.provenance import gather_provenance_info @@ -26,6 +27,18 @@ LOG = logging.getLogger(__name__) +def to_datetime(date): + if isinstance(date, str): + return np.datetime64(date) + if isinstance(date, datetime.datetime): + return np.datetime64(date) + return date + + +def to_datetimes(dates): + return [to_datetime(d) for d in dates] + + def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares): if (x >= 0).all(): return @@ -112,19 +125,13 @@ def delete(self): except FileNotFoundError: pass + def _hash_key(self, key): + return hashlib.sha256(str(key).encode("utf-8")).hexdigest() + def write(self, key, data, dates): self.create(exist_ok=True) - key_str = ( - str(key) - .replace("(", "") - .replace(")", "") - .replace(" ", "_") - .replace(",", "_") - .replace("None", "x") - .replace("__", "_") - .lower() - ) - path = os.path.join(self.dirname, f"{key_str}.npz") + h = self._hash_key(dates) + path = os.path.join(self.dirname, f"{h}.npz") if not self.overwrite: assert not os.path.exists(path), f"{path} already exists" @@ -134,34 +141,21 @@ def write(self, key, data, dates): pickle.dump((key, dates, data), f) shutil.move(tmp_path, path) - LOG.info(f"Written statistics data for {key} in {path} ({dates})") + LOG.info(f"Written statistics data for {len(dates)} dates in {path} ({dates})") def _gather_data(self): # use glob to read all pickles files = glob.glob(self.dirname + "/*.npz") LOG.info(f"Reading stats data, found {len(files)} in {self.dirname}") assert len(files) > 0, f"No files found in {self.dirname}" - - key_strs = dict() for f in files: with open(f, "rb") as f: - key, dates, data = pickle.load(f) - - key_str = str(key) - if key_str in key_strs: - raise Exception(f"Duplicate key {key}, found in {f} and {key_strs[key_str]}") - key_strs[key_str] = f - - yield key, dates, data - - @cached_property - def n_dates_computed(self): - return len(self.dates_computed) + yield pickle.load(f) @property def dates_computed(self): all_dates = [] - for key, dates, data in self._gather_data(): + for _, dates, data in self._gather_data(): all_dates += dates # assert no duplicates @@ -170,11 +164,11 @@ def dates_computed(self): raise StatisticsValueError(f"Duplicate dates found in statistics: {duplicates}") all_dates = normalise_dates(all_dates) + all_dates = sorted(all_dates) return all_dates def get_aggregated(self, dates, variables_names): - aggregator = StatAggregator(variables_names, self) - aggregator.read(dates) + aggregator = StatAggregator(dates, variables_names, self) return aggregator.aggregate() def __str__(self): @@ -194,12 +188,15 @@ def normalise_dates(dates): class StatAggregator: NAMES = ["minimum", "maximum", "sums", "squares", "count"] - def __init__(self, variables_names, owner): + def __init__(self, dates, variables_names, owner): + dates = sorted(dates) + dates = to_datetimes(dates) self.owner = owner - self.computed_dates = owner.dates_computed - self.shape = (len(self.computed_dates), len(variables_names)) + self.dates = dates self.variables_names = variables_names - print("Aggregating on ", self.shape, variables_names) + + self.shape = (len(self.dates), len(self.variables_names)) + print("Aggregating statistics on ", self.shape, self.variables_names) self.minimum = 
np.full(self.shape, np.nan, dtype=np.float64) self.maximum = np.full(self.shape, np.nan, dtype=np.float64) @@ -208,33 +205,52 @@ def __init__(self, variables_names, owner): self.count = np.full(self.shape, -1, dtype=np.int64) self.flags = np.full(self.shape, False, dtype=np.bool_) - def read(self, dates): - assert type(dates[0]) == type(self.computed_dates[0]), ( - dates[0], - self.computed_dates[0], - ) - - dates_bitmap = np.isin(self.computed_dates, dates) + self._read() + + def _date_to_index(self, date): + date = to_datetime(date) + assert type(date) is type(self.dates[0]), (type(date), type(self.dates[0])) + assert date in self.dates, f"Statistics for date {date} is not needed." + return np.where(self.dates == date)[0][0] + + def _read(self): + available_dates = [] + for _, dates, stats in self.owner._gather_data(): + assert isinstance(stats, dict), stats + for n in self.NAMES: + assert n in stats, (n, list(stats.keys())) + dates = to_datetimes(dates) + + indexes = [] + stats_indexes = [] + for i, d in enumerate(dates): + if d not in self.dates: + continue + stats_indexes.append(i) + indexes.append(self._date_to_index(d)) + available_dates.append(d) + + if not indexes: + continue - for key, dates, data in self.owner._gather_data(): - assert isinstance(data, dict), data - assert not np.any(self.flags[key]), f"Overlapping values for {key} {self.flags} ({dates})" - self.flags[key] = True + self.flags[indexes] = True for name in self.NAMES: array = getattr(self, name) - array[key] = data[name] - - if not np.all(self.flags[dates_bitmap]): - not_found = np.where(self.flags == False) # noqa: E712 - raise Exception(f"Missing statistics data for {not_found}", not_found) + data = stats[name] + data = data[stats_indexes] + array[indexes] = data - print(f"Selection statistics data from {self.minimum.shape[0]} to {self.minimum[dates_bitmap].shape[0]} dates.") - for name in self.NAMES: - array = getattr(self, name) - array = array[dates_bitmap] - setattr(self, name, array) + assert type(available_dates[0]) is type(self.dates[0]), (available_dates[0], self.dates[0]) + assert len(available_dates) == len(set(available_dates)), "Duplicate dates found in statistics" + for d in self.dates: + assert d in available_dates, f"Statistics for date {d} not precomputed." + assert len(available_dates) == len(self.dates) def aggregate(self): + if not np.all(self.flags): + not_found = np.where(self.flags == False) # noqa: E712 + raise Exception(f"Statistics not precomputed for {not_found}", not_found) + print(f"Aggregating statistics on {self.minimum.shape}") for name in self.NAMES: if name == "count": diff --git a/ecml_tools/create/writer.py b/ecml_tools/create/writer.py index 907d85a..322834e 100644 --- a/ecml_tools/create/writer.py +++ b/ecml_tools/create/writer.py @@ -15,6 +15,7 @@ import numpy as np from .check import check_data_values +from .statistics import compute_statistics from .utils import progress_bar, seconds LOG = logging.getLogger(__name__) @@ -59,98 +60,49 @@ def __call__(self, i): return i in self.parts -class ArrayLike: - def __init__(self, array, shape): - self.array = array - self.shape = shape - - def flush(): - pass - - def new_key(self, key, values_shape): - return key - - -class FastWriteArray(ArrayLike): +class ViewCacheArray: """A class that provides a caching mechanism for writing to a NumPy-like array. - The `FastWriteArray` instance is initialized with a NumPy-like array and its shape. + The is initialized with a NumPy-like array, a shape and a list to reindex the first dimension. 
The array is used to store the final data, while the cache is used to temporarily - store the data before flushing it to the array. The cache is a NumPy array of the same - shape as the final array, initialized with zeros. + store the data before flushing it to the array. The `flush` method copies the contents of the cache to the final array. """ - def __init__(self, array, shape): + def __init__(self, array, *, shape, indexes): + assert len(indexes) == shape[0], (len(indexes), shape[0]) self.array = array - self.shape = shape self.dtype = array.dtype - self.cache = np.zeros(shape, dtype=self.dtype) + self.cache = np.full(shape, np.nan, dtype=self.dtype) + self.indexes = indexes def __setitem__(self, key, value): self.cache[key] = value - def __getitem__(self, key): - return self.cache[key] - - def new_key(self, key, values_shape): - return self.array.new_key(key, values_shape) - def flush(self): - self.array[:] = self.cache[:] + for i in range(self.cache.shape[0]): + global_i = self.indexes[i] + self.array[global_i] = self.cache[i] - def compute_statistics_and_key(self, variables_names): - from .statistics import compute_statistics - - now = time.time() - stats = compute_statistics(self.cache, variables_names) - LOG.info(f"Computed statistics in {seconds(time.time()-now)}.") - new_key = self.new_key(slice(None, None), self.shape) +class ReindexFirst: + def __init__(self, indexes): + self.indexes = indexes - assert isinstance(self.array, OffsetView) - assert self.array.axis == 0, self.array.axis - new_key = (new_key[0], slice(None, None)) + def __call__(self, first, *others): + if isinstance(first, int): + return (self.indexes[first], *others) - return new_key, stats + if isinstance(first, slice): + start, stop, step = first.start, first.stop, first.step + start = self.indexes[start] + stop = self.indexes[stop] + return (slice(start, stop, step), *others) + if isinstance(first, tuple): + return ([self.indexes[_] for _ in first], *others) - -class OffsetView(ArrayLike): - """A view on a portion of the large_array. - - 'axis' is the axis along which the offset applies. 'shape' is the shape of the view. - """ - - def __init__(self, large_array, *, offset, axis, shape): - self.large_array = large_array - self.dtype = large_array.dtype - self.offset = offset - self.axis = axis - self.shape = shape - - def new_key(self, key, values_shape): - if isinstance(key, slice): - # Ensure that the slice covers the entire view along the axis. - assert key.start is None and key.stop is None, key - - # Create a new key for indexing the large array. - new_key = tuple( - (slice(self.offset, self.offset + values_shape[i]) if i == self.axis else slice(None)) - for i in range(len(self.shape)) - ) - else: - # For non-slice keys, adjust the key based on the offset and axis. 
- new_key = tuple(k + self.offset if i == self.axis else k for i, k in enumerate(key)) - return new_key - - def __setitem__(self, key, values): - new_key = self.new_key(key, values.shape) - - start = time.time() - LOG.info("Writing data to disk") - self.large_array[new_key] = values - LOG.info(f"Writing data done in {seconds(time.time()-start)}.") + raise NotImplementedError(type(first)) class DataWriter: @@ -164,47 +116,37 @@ def __init__(self, parts, full_array, parent, print=print): self.print = parent.print self.append_axis = parent.output.append_axis - self.n_cubes = len(parent.groups) + self.n_groups = len(parent.groups) + + @property + def variables_names(self): + return self.parent.variables_names def write(self, result, igroup, dates): cube = result.get_cube() - assert cube.extended_user_shape[0] == len(dates), ( - cube.extended_user_shape[0], - len(dates), - ) + assert cube.extended_user_shape[0] == len(dates), (cube.extended_user_shape[0], len(dates)) dates_in_data = cube.user_coords["valid_datetime"] dates_in_data = [datetime.datetime.fromisoformat(_) for _ in dates_in_data] assert dates_in_data == list(dates), (dates_in_data, list(dates)) - self.write_cube(cube, igroup) - - @property - def variables_names(self): - return self.parent.variables_names - def write_cube(self, cube, icube): - assert isinstance(icube, int), icube + assert isinstance(igroup, int), igroup shape = cube.extended_user_shape - dates = cube.user_coords["valid_datetime"] - slice = self.registry.get_slice_for(icube) - LOG.info( - f"Building dataset '{self.path}' i={icube} total={self.n_cubes} " - f"(total shape ={shape}) at {slice}, {self.full_array.chunks=}" - ) - self.print(f"Building dataset (total shape ={shape}) at {slice}, {self.full_array.chunks=}") + msg = f"Building data for group {igroup}/{self.n_groups} ({shape=} in {self.full_array.shape=})" + LOG.info(msg) + self.print(msg) - offset = slice.start - array = OffsetView(self.full_array, offset=offset, axis=self.append_axis, shape=shape) - array = FastWriteArray(array, shape=shape) + indexes = [self.parent.date_to_index(d) for d in dates_in_data] + array = ViewCacheArray(self.full_array, shape=shape, indexes=indexes) self.load_datacube(cube, array) - new_key, stats = array.compute_statistics_and_key(self.variables_names) - self.statistics_registry.write(new_key, stats, dates=dates) + stats = compute_statistics(array.cache, self.variables_names) + dates = cube.user_coords["valid_datetime"] + self.statistics_registry.write(indexes, stats, dates=dates) array.flush() - - self.registry.set_flag(icube) + self.registry.set_flag(igroup) def load_datacube(self, cube, array): start = time.time() @@ -222,31 +164,21 @@ def load_datacube(self, cube, array): for i, cubelet in enumerate(bar): now = time.time() data = cubelet.to_numpy() - cubelet_coords = cubelet.coords - - bar.set_description(f"Loading {i}/{total} {str(cubelet)} ({data.shape})") + local_indexes = cubelet.coords load += time.time() - now - j = cubelet_coords[1] + name = self.variables_names[local_indexes[1]] + check_data_values(data[:], name=name, log=[i, data.shape, local_indexes]) - check_data_values( - data[:], - name=self.variables_names[j], - log=[i, j, data.shape, cubelet_coords], - ) + bar.set_description(f"Loading {i}/{total} {name} {str(cubelet)} ({data.shape})") now = time.time() - array[cubelet_coords] = data + array[local_indexes] = data save += time.time() - now now = time.time() save += time.time() - now - LOG.info("Written.") - - self.print( - f"Elapsed: {seconds(time.time() - 
start)}," f" load time: {seconds(load)}," f" write time: {seconds(save)}." - ) - LOG.info( - f"Elapsed: {seconds(time.time() - start)}," f" load time: {seconds(load)}," f" write time: {seconds(save)}." - ) + msg = f"Elapsed: {seconds(time.time() - start)}, load time: {seconds(load)}, write time: {seconds(save)}." + self.print(msg) + LOG.info(msg) diff --git a/ecml_tools/create/zarr.py b/ecml_tools/create/zarr.py index 36d0bfe..5c2cb57 100644 --- a/ecml_tools/create/zarr.py +++ b/ecml_tools/create/zarr.py @@ -128,14 +128,6 @@ def add_to_history(self, action, **kwargs): history.append(new) z.attrs["history"] = history - def get_slice_for(self, i): - lengths = self.get_lengths() - assert i >= 0 and i < len(lengths) - - start = sum(lengths[:i]) - stop = sum(lengths[: (i + 1)]) - return slice(start, stop) - def get_lengths(self): z = self._open_read() return list(z["_build"][self.name_lengths][:]) diff --git a/tests/create-missing.yaml b/tests/create-missing.yaml index 0c05976..b45b815 100644 --- a/tests/create-missing.yaml +++ b/tests/create-missing.yaml @@ -15,7 +15,7 @@ dates: end: 2021-01-03 12:00:00 frequency: 12h group_by: monthly - missing: ['2021-01-03 00:00:00'] + missing: ['2020-12-30 12:00:00', '2021-01-03 00:00:00'] include: - mars: @@ -30,6 +30,7 @@ input: template: ${include.0.mars} param: - cos_latitude + #- sin_latitude output: chunking: { dates: 1, ensembles: 1 } @@ -39,4 +40,4 @@ output: remapping: param_level: "{param}_{levelist}" statistics: param_level - statistics_end: 2020 + statistics_end: 2021-01-02 From 10dcc645d7f785f1ac30fb2ec22e143cfc875c1a Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Thu, 29 Feb 2024 15:07:41 +0000 Subject: [PATCH 2/3] up --- ecml_tools/create/zarr.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/ecml_tools/create/zarr.py b/ecml_tools/create/zarr.py index 5c2cb57..79b6119 100644 --- a/ecml_tools/create/zarr.py +++ b/ecml_tools/create/zarr.py @@ -49,9 +49,7 @@ def add_zarr_dataset( if "fill_value" not in kwargs: if str(dtype).startswith("float") or str(dtype).startswith("numpy.float"): kwargs["fill_value"] = np.nan - elif str(dtype).startswith("datetime64") or str(dtype).startswith( - "numpy.datetime64" - ): + elif str(dtype).startswith("datetime64") or str(dtype).startswith("numpy.datetime64"): kwargs["fill_value"] = np.datetime64("NaT") # elif str(dtype).startswith("timedelta64") or str(dtype).startswith( # "numpy.timedelta64" @@ -147,9 +145,7 @@ def set_flag(self, i, value=True): def create(self, lengths, overwrite=False): self.new_dataset(name=self.name_lengths, array=np.array(lengths, dtype="i4")) - self.new_dataset( - name=self.name_flags, array=np.array([False] * len(lengths), dtype=bool) - ) + self.new_dataset(name=self.name_flags, array=np.array([False] * len(lengths), dtype=bool)) self.add_to_history("initialised") def reset(self, lengths): From 3e4a00ed20232be48735212824d878cf94cf7eb2 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Thu, 29 Feb 2024 17:39:14 +0000 Subject: [PATCH 3/3] statistics --- ecml_tools/create/__init__.py | 27 +----- ecml_tools/create/loaders.py | 166 ++++++++++---------------------- ecml_tools/create/statistics.py | 52 +++------- ecml_tools/create/writer.py | 29 +++--- 4 files changed, 88 insertions(+), 186 deletions(-) diff --git a/ecml_tools/create/__init__.py b/ecml_tools/create/__init__.py index 0c22af9..aefe64e 100644 --- a/ecml_tools/create/__init__.py +++ b/ecml_tools/create/__init__.py @@ -35,9 +35,7 @@ def init(self, check_name=False): from .loaders import 
InitialiseLoader if self._path_readable() and not self.overwrite: - raise Exception( - f"{self.path} already exists. Use overwrite=True to overwrite." - ) + raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.") with self._cache_context(): obj = InitialiseLoader.from_config( @@ -57,11 +55,7 @@ def load(self, parts=None): ) loader.load(parts=parts) - def statistics( - self, - force=False, - output=None, - ): + def statistics(self, force=False, output=None, start=None, end=None): from .loaders import StatisticsLoader loader = StatisticsLoader.from_dataset( @@ -71,25 +65,8 @@ def statistics( statistics_tmp=self.statistics_tmp, statistics_output=output, recompute=False, - ) - loader.run() - - def recompute_statistics( - self, - start=None, - end=None, - force=False, - ): - from .loaders import StatisticsLoader - - loader = StatisticsLoader.from_dataset( - path=self.path, - print=self.print, - force=force, - statistics_tmp=self.statistics_tmp, statistics_start=start, statistics_end=end, - recompute=True, ) loader.run() diff --git a/ecml_tools/create/loaders.py b/ecml_tools/create/loaders.py index 5e15022..244ff23 100644 --- a/ecml_tools/create/loaders.py +++ b/ecml_tools/create/loaders.py @@ -20,12 +20,7 @@ from .config import build_output, loader_config from .input import build_input from .statistics import TempStatistics -from .utils import ( - bytes, - compute_directory_sizes, - normalize_and_check_dates, - progress_bar, -) +from .utils import bytes, compute_directory_sizes, normalize_and_check_dates from .writer import CubesFilter, DataWriter from .zarr import ZarrBuiltRegistry, add_zarr_dataset @@ -82,6 +77,15 @@ def build_input(self): print(builder) return builder + def build_statistics_dates(self, start, end): + ds = open_dataset(self.path) + subset = ds.dates_interval_to_indices(start, end) + start, end = ds.dates[subset[0]], ds.dates[subset[-1]] + return ( + start.astype(datetime.datetime).isoformat(), + end.astype(datetime.datetime).isoformat(), + ) + def read_dataset_metadata(self): ds = open_dataset(self.path) self.dataset_shape = ds.shape @@ -90,21 +94,8 @@ def read_dataset_metadata(self): self.dates = ds.dates z = zarr.open(self.path, "r") - self.missing_dates = z.attrs.get("missing_dates") - if self.missing_dates: - self.missing_dates = [np.datetime64(d) for d in self.missing_dates] - assert type(self.missing_dates[0]) == type(self.dates[0]), ( - self.missing_dates[0], - self.dates[0], - ) - - def date_to_index(self, date): - if isinstance(date, str): - date = np.datetime64(date) - if isinstance(date, datetime.datetime): - date = np.datetime64(date) - assert type(date) is type(self.dates[0]), (type(date), type(self.dates[0])) - return np.where(self.dates == date)[0][0] + self.missing_dates = z.attrs.get("missing_dates", []) + self.missing_dates = [np.datetime64(d) for d in self.missing_dates] @cached_property def registry(self): @@ -284,7 +275,7 @@ def initialise(self, check_name=True): self.statistics_registry.create(exist_ok=False) self.registry.add_to_history("statistics_registry_initialised", version=self.statistics_registry.version) - statistics_start, statistics_end = self._build_statistics_dates( + statistics_start, statistics_end = self.build_statistics_dates( self.main_config.output.get("statistics_start"), self.main_config.output.get("statistics_end"), ) @@ -298,15 +289,6 @@ def initialise(self, check_name=True): assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks()) - def _build_statistics_dates(self, start, end): - ds = 
open_dataset(self.path) - subset = ds.dates_interval_to_indices(start, end) - start, end = ds.dates[subset[0]], ds.dates[subset[-1]] - return ( - start.astype(datetime.datetime).isoformat(), - end.astype(datetime.datetime).isoformat(), - ) - class ContentLoader(Loader): def __init__(self, config, **kwargs): @@ -322,7 +304,7 @@ def load(self, parts): self.registry.add_to_history("loading_data_start", parts=parts) z = zarr.open(self.path, mode="r+") - data_writer = DataWriter(parts, parent=self, full_array=z["data"], print=self.print) + data_writer = DataWriter(parts, full_array=z["data"], owner=self) total = len(self.registry.get_flags()) filter = CubesFilter(parts=parts, total=total) @@ -356,112 +338,70 @@ def __init__( statistics_start=None, statistics_end=None, force=False, - recompute=False, **kwargs, ): super().__init__(**kwargs) - assert statistics_start is None, statistics_start - assert statistics_end is None, statistics_end - - self.recompute = recompute - - self._write_to_dataset = True + self.user_statistics_start = statistics_start + self.user_statistics_end = statistics_end self.statistics_output = statistics_output + self.output_writer = { + None: self.write_stats_to_dataset, + "-": self.write_stats_to_stdout, + }.get(self.statistics_output, self.write_stats_to_file) + if config: self.main_config = loader_config(config) - self.check_complete(force=force) self.read_dataset_metadata() - def run(self): - # if requested, recompute statistics from data - # into the temporary statistics directory - # (this should have been done already when creating the dataset content) - if self.recompute: - self.recompute_temporary_statistics() - - dates = [d for d in self.dates if d not in self.missing_dates] + def _get_statistics_dates(self): + dates = self.dates + dtype = type(dates[0]) + # remove missing dates if self.missing_dates: - assert type(self.missing_dates[0]) is type(dates[0]), (type(self.missing_dates[0]), type(dates[0])) - - dates_computed = self.statistics_registry.dates_computed - for d in dates: - if d in self.missing_dates: - assert d not in dates_computed, (d, date_computed) - else: - assert d in dates_computed, (d, dates_computed) + assert type(self.missing_dates[0]) is dtype, (type(self.missing_dates[0]), dtype) + dates = [d for d in dates if d not in self.missing_dates] + # filter dates according the the start and end dates in the metadata z = zarr.open(self.path, mode="r") - start = z.attrs.get("statistics_start_date") - end = z.attrs.get("statistics_end_date") - start = np.datetime64(start) - end = np.datetime64(end) + start, end = z.attrs.get("statistics_start_date"), z.attrs.get("statistics_end_date") + start, end = np.datetime64(start), np.datetime64(end) + assert type(start) is dtype, (type(start), dtype) dates = [d for d in dates if d >= start and d <= end] - assert type(start) is type(dates[0]), (type(start), type(dates[0])) - stats = self.statistics_registry.get_aggregated(dates, self.variables_names) + # filter dates according the the user specified start and end dates + if self.user_statistics_start or self.user_statistics_end: + start, end = self.build_statistics_dates(self.user_statistics_start, self.user_statistics_end) + start, end = np.datetime64(start), np.datetime64(end) + assert type(start) is dtype, (type(start), dtype) + dates = [d for d in dates if d >= start and d <= end] - writer = { - None: self.write_stats_to_dataset, - "-": self.write_stats_to_stdout, - }.get(self.statistics_output, self.write_stats_to_file) - writer(stats) - - def 
check_complete(self, force): - if self._complete: - return - if not force: - raise Exception(f"❗Zarr {self.path} is not fully built. Use 'force' option.") - if self._write_to_dataset: - print(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.") - self._write_to_dataset = False - - @property - def _complete(self): - return all(self.registry.get_flags(sync=False)) - - def recompute_temporary_statistics(self): - raise NotImplementedError("Untested code") - self.statistics_registry.create(exist_ok=True) - - self.print( - f"Building temporary statistics from data {self.path}. " f"From {self.date_start} to {self.date_end}" - ) - - shape = (self.i_end + 1 - self.i_start, len(self.variables_names)) - detailed_stats = dict( - minimum=np.full(shape, np.nan, dtype=np.float64), - maximum=np.full(shape, np.nan, dtype=np.float64), - sums=np.full(shape, np.nan, dtype=np.float64), - squares=np.full(shape, np.nan, dtype=np.float64), - count=np.full(shape, -1, dtype=np.int64), - ) + return dates - ds = open_dataset(self.path) - key = (slice(self.i_start, self.i_end + 1), slice(None, None)) - for i in progress_bar( - desc="Computing Statistics", - iterable=range(self.i_start, self.i_end + 1), - ): - i_ = i - self.i_start - data = ds[slice(i, i + 1), :] - one = compute_statistics(data, self.variables_names) - for k, v in one.items(): - detailed_stats[k][i_] = v - - print(f"✅ Saving statistics for {key} shape={detailed_stats['count'].shape}") - self.statistics_registry[key] = detailed_stats - self.statistics_registry.add_provenance(name="provenance_recompute_statistics", config=self.main_config) + def run(self): + dates = self._get_statistics_dates() + stats = self.statistics_registry.get_aggregated(dates, self.variables_names) + self.output_writer(stats) def write_stats_to_file(self, stats): stats.save(self.statistics_output, provenance=dict(config=self.main_config)) print(f"✅ Statistics written in {self.statistics_output}") - return def write_stats_to_dataset(self, stats): + if self.user_statistics_start or self.user_statistics_end: + raise ValueError( + ( + "Cannot write statistics in dataset with user specified dates. " + "This would be conflicting with the dataset metadata." 
+ ) + ) + + if not all(self.registry.get_flags(sync=False)): + raise Exception(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.") + for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count"]: self._add_dataset(name=k, array=stats[k]) diff --git a/ecml_tools/create/statistics.py b/ecml_tools/create/statistics.py index 6d4cbea..9b0a58b 100644 --- a/ecml_tools/create/statistics.py +++ b/ecml_tools/create/statistics.py @@ -15,8 +15,7 @@ import pickle import shutil import socket -from collections import Counter, defaultdict -from functools import cached_property +from collections import defaultdict import numpy as np @@ -125,12 +124,9 @@ def delete(self): except FileNotFoundError: pass - def _hash_key(self, key): - return hashlib.sha256(str(key).encode("utf-8")).hexdigest() - def write(self, key, data, dates): self.create(exist_ok=True) - h = self._hash_key(dates) + h = hashlib.sha256(str(dates).encode("utf-8")).hexdigest() path = os.path.join(self.dirname, f"{h}.npz") if not self.overwrite: @@ -146,27 +142,12 @@ def write(self, key, data, dates): def _gather_data(self): # use glob to read all pickles files = glob.glob(self.dirname + "/*.npz") - LOG.info(f"Reading stats data, found {len(files)} in {self.dirname}") + LOG.info(f"Reading stats data, found {len(files)} files in {self.dirname}") assert len(files) > 0, f"No files found in {self.dirname}" for f in files: with open(f, "rb") as f: yield pickle.load(f) - @property - def dates_computed(self): - all_dates = [] - for _, dates, data in self._gather_data(): - all_dates += dates - - # assert no duplicates - duplicates = [item for item, count in Counter(all_dates).items() if count > 1] - if duplicates: - raise StatisticsValueError(f"Duplicate dates found in statistics: {duplicates}") - - all_dates = normalise_dates(all_dates) - all_dates = sorted(all_dates) - return all_dates - def get_aggregated(self, dates, variables_names): aggregator = StatAggregator(dates, variables_names, self) return aggregator.aggregate() @@ -245,13 +226,13 @@ def _read(self): for d in self.dates: assert d in available_dates, f"Statistics for date {d} not precomputed." assert len(available_dates) == len(self.dates) + print(f"Statistics for {len(available_dates)} dates found.") def aggregate(self): if not np.all(self.flags): not_found = np.where(self.flags == False) # noqa: E712 raise Exception(f"Statistics not precomputed for {not_found}", not_found) - print(f"Aggregating statistics on {self.minimum.shape}") for name in self.NAMES: if name == "count": continue @@ -273,7 +254,7 @@ def aggregate(self): check_variance(x, self.variables_names, minimum, maximum, mean, count, sums, squares) stdev = np.sqrt(x) - stats = Statistics( + return Statistics( minimum=minimum, maximum=maximum, mean=mean, @@ -284,16 +265,13 @@ def aggregate(self): variables_names=self.variables_names, ) - return stats - class Statistics(dict): STATS_NAMES = ["minimum", "maximum", "mean", "stdev"] # order matter for __str__. 
- def __init__(self, *args, check=True, **kwargs): - super().__init__(*args, **kwargs) - if check: - self.check() + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.check() @property def size(self): @@ -324,12 +302,14 @@ def check(self): raise def __str__(self): - header = ["Variables"] + [self[name] for name in self.STATS_NAMES] - out = " ".join(header) - - for i, v in enumerate(self["variables_names"]): - out += " ".join([v] + [f"{x[i]:.2f}" for x in self.values()]) - return out + header = ["Variables"] + self.STATS_NAMES + out = [" ".join(header)] + + out += [ + " ".join([v] + [f"{self[n][i]:.2f}" for n in self.STATS_NAMES]) + for i, v in enumerate(self["variables_names"]) + ] + return "\n".join(out) def save(self, filename, provenance=None): assert filename.endswith(".json"), filename diff --git a/ecml_tools/create/writer.py b/ecml_tools/create/writer.py index 322834e..21f0a7a 100644 --- a/ecml_tools/create/writer.py +++ b/ecml_tools/create/writer.py @@ -106,21 +106,26 @@ def __call__(self, first, *others): class DataWriter: - def __init__(self, parts, full_array, parent, print=print): - self.parent = parent + def __init__(self, parts, full_array, owner): self.full_array = full_array - self.path = parent.path - self.statistics_registry = parent.statistics_registry - self.registry = parent.registry - self.print = parent.print + self.path = owner.path + self.statistics_registry = owner.statistics_registry + self.registry = owner.registry + self.print = owner.print + self.dates = owner.dates + self.variables_names = owner.variables_names - self.append_axis = parent.output.append_axis - self.n_groups = len(parent.groups) + self.append_axis = owner.output.append_axis + self.n_groups = len(owner.groups) - @property - def variables_names(self): - return self.parent.variables_names + def date_to_index(self, date): + if isinstance(date, str): + date = np.datetime64(date) + if isinstance(date, datetime.datetime): + date = np.datetime64(date) + assert type(date) is type(self.dates[0]), (type(date), type(self.dates[0])) + return np.where(self.dates == date)[0][0] def write(self, result, igroup, dates): cube = result.get_cube() @@ -137,7 +142,7 @@ def write(self, result, igroup, dates): LOG.info(msg) self.print(msg) - indexes = [self.parent.date_to_index(d) for d in dates_in_data] + indexes = [self.date_to_index(d) for d in dates_in_data] array = ViewCacheArray(self.full_array, shape=shape, indexes=indexes) self.load_datacube(cube, array)