
Commit

update
b8raoult committed Feb 29, 2024
2 parents 5bcdbf9 + 3e4a00e commit f3aadae
Showing 6 changed files with 210 additions and 406 deletions.
27 changes: 2 additions & 25 deletions ecml_tools/create/__init__.py
@@ -35,9 +35,7 @@ def init(self, check_name=False):
from .loaders import InitialiseLoader

if self._path_readable() and not self.overwrite:
raise Exception(
f"{self.path} already exists. Use overwrite=True to overwrite."
)
raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.")

with self._cache_context():
obj = InitialiseLoader.from_config(
@@ -57,11 +55,7 @@ def load(self, parts=None):
)
loader.load(parts=parts)

def statistics(
self,
force=False,
output=None,
):
def statistics(self, force=False, output=None, start=None, end=None):
from .loaders import StatisticsLoader

loader = StatisticsLoader.from_dataset(
@@ -71,25 +65,8 @@ def statistics(
statistics_tmp=self.statistics_tmp,
statistics_output=output,
recompute=False,
)
loader.run()

def recompute_statistics(
self,
start=None,
end=None,
force=False,
):
from .loaders import StatisticsLoader

loader = StatisticsLoader.from_dataset(
path=self.path,
print=self.print,
force=force,
statistics_tmp=self.statistics_tmp,
statistics_start=start,
statistics_end=end,
recompute=True,
)
loader.run()

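With this change, the former recompute_statistics entry point is folded into statistics, which now accepts start and end directly. A minimal usage sketch follows; the Creator class name, import path and constructor arguments are assumptions for illustration and do not appear in this diff.

# Hypothetical usage of the merged statistics() entry point.
# The Creator name and constructor arguments are assumed, not taken from this commit.
from ecml_tools.create import Creator

c = Creator(path="my-dataset.zarr")

# Statistics over the period recorded in the dataset metadata:
c.statistics()

# Statistics restricted to a user-chosen period, written to a separate file
# instead of back into the dataset (output="-" would print them to stdout):
c.statistics(start="2020-01-01", end="2020-12-31", output="stats.json")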
186 changes: 59 additions & 127 deletions ecml_tools/create/loaders.py
@@ -77,6 +77,15 @@ def build_input(self):
print(builder)
return builder

def build_statistics_dates(self, start, end):
ds = open_dataset(self.path)
subset = ds.dates_interval_to_indices(start, end)
start, end = ds.dates[subset[0]], ds.dates[subset[-1]]
return (
start.astype(datetime.datetime).isoformat(),
end.astype(datetime.datetime).isoformat(),
)

def read_dataset_metadata(self):
ds = open_dataset(self.path)
self.dataset_shape = ds.shape
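The new build_statistics_dates helper clamps a requested interval to the dates actually present in the dataset and returns the bounds as ISO strings. The sketch below reproduces that idea with plain numpy instead of open_dataset; the clamp_statistics_dates name and the synthetic dates array are illustrative only.

import datetime
import numpy as np

def clamp_statistics_dates(dates, start, end):
    # Keep only dataset dates inside [start, end] and return the first/last as ISO strings.
    start = np.datetime64(start) if start is not None else dates[0]
    end = np.datetime64(end) if end is not None else dates[-1]
    selected = dates[(dates >= start) & (dates <= end)]
    first, last = selected[0], selected[-1]
    return (
        first.astype(datetime.datetime).isoformat(),
        last.astype(datetime.datetime).isoformat(),
    )

dates = np.arange(
    np.datetime64("2020-01-01"), np.datetime64("2021-01-01"), np.timedelta64(6, "h")
)
print(clamp_statistics_dates(dates, "2020-02-01", "2020-03-01"))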
@@ -85,13 +94,8 @@ def read_dataset_metadata(self):
self.dates = ds.dates

z = zarr.open(self.path, "r")
self.missing_dates = z.attrs.get("missing_dates")
if self.missing_dates:
self.missing_dates = [np.datetime64(d) for d in self.missing_dates]
assert type(self.missing_dates[0]) is type(self.dates[0]), (
self.missing_dates[0],
self.dates[0],
)
self.missing_dates = z.attrs.get("missing_dates", [])
self.missing_dates = [np.datetime64(d) for d in self.missing_dates]

@cached_property
def registry(self):
@@ -170,18 +174,14 @@ def initialise(self, check_name=True):
)
print(f"Missing dates: {len(dates.missing)}")
lengths = [len(g) for g in self.groups]
self.print(
f"Found {len(dates)} datetimes {'+'.join([str(_) for _ in lengths])}."
)
self.print(f"Found {len(dates)} datetimes {'+'.join([str(_) for _ in lengths])}.")
print("-------------------------")

variables = self.minimal_input.variables
self.print(f"Found {len(variables)} variables : {','.join(variables)}.")

ensembles = self.minimal_input.ensembles
self.print(
f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}."
)
self.print(f"Found {len(ensembles)} ensembles : {','.join([str(_) for _ in ensembles])}.")

grid_points = self.minimal_input.grid_points
print(f"gridpoints size: {[len(i) for i in grid_points]}")
@@ -202,9 +202,7 @@
print(f"{chunks=}")
dtype = self.output.dtype

self.print(
f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}"
)
self.print(f"Creating Dataset '{self.path}', with {total_shape=}, {chunks=} and {dtype=}")

metadata = {}
metadata["uuid"] = str(uuid.uuid4())
@@ -275,11 +273,9 @@ def initialise(self, check_name=True):

self.registry.create(lengths=lengths)
self.statistics_registry.create(exist_ok=False)
self.registry.add_to_history(
"statistics_registry_initialised", version=self.statistics_registry.version
)
self.registry.add_to_history("statistics_registry_initialised", version=self.statistics_registry.version)

statistics_start, statistics_end = self._build_statistics_dates(
statistics_start, statistics_end = self.build_statistics_dates(
self.main_config.output.get("statistics_start"),
self.main_config.output.get("statistics_end"),
)
@@ -293,15 +289,6 @@ def initialise(self, check_name=True):

assert chunks == self.get_zarr_chunks(), (chunks, self.get_zarr_chunks())

def _build_statistics_dates(self, start, end):
ds = open_dataset(self.path)
subset = ds.dates_interval_to_indices(start, end)
start, end = ds.dates[subset[0]], ds.dates[subset[-1]]
return (
start.astype(datetime.datetime).isoformat(),
end.astype(datetime.datetime).isoformat(),
)


class ContentLoader(Loader):
def __init__(self, config, **kwargs):
@@ -317,17 +304,13 @@ def load(self, parts):
self.registry.add_to_history("loading_data_start", parts=parts)

z = zarr.open(self.path, mode="r+")
data_writer = DataWriter(
parts, parent=self, full_array=z["data"], print=self.print
)
data_writer = DataWriter(parts, full_array=z["data"], owner=self)

total = len(self.registry.get_flags())
filter = CubesFilter(parts=parts, total=total)
for igroup, group in enumerate(self.groups):
if self.registry.get_flag(igroup):
LOG.info(
f" -> Skipping {igroup} total={len(self.groups)} (already done)"
)
LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)")
continue
if not filter(igroup):
continue
@@ -340,9 +323,7 @@ def load(self, parts):

self.registry.add_to_history("loading_data_end", parts=parts)
self.registry.add_provenance(name="provenance_load")
self.statistics_registry.add_provenance(
name="provenance_load", config=self.main_config
)
self.statistics_registry.add_provenance(name="provenance_load", config=self.main_config)

self.print_info()

@@ -357,119 +338,70 @@ def __init__(
statistics_start=None,
statistics_end=None,
force=False,
recompute=False,
**kwargs,
):
super().__init__(**kwargs)
assert statistics_start is None, statistics_start
assert statistics_end is None, statistics_end

self.recompute = recompute

self._write_to_dataset = True
self.user_statistics_start = statistics_start
self.user_statistics_end = statistics_end

self.statistics_output = statistics_output

self.output_writer = {
None: self.write_stats_to_dataset,
"-": self.write_stats_to_stdout,
}.get(self.statistics_output, self.write_stats_to_file)

if config:
self.main_config = loader_config(config)

self.check_complete(force=force)
self.read_dataset_metadata()

def run(self):
# if requested, recompute statistics from data
# into the temporary statistics directory
# (this should have been done already when creating the dataset content)
if self.recompute:
self.recompute_temporary_statistics()

dates = [d for d in self.dates if d not in self.missing_dates]
def _get_statistics_dates(self):
dates = self.dates
dtype = type(dates[0])

# remove missing dates
if self.missing_dates:
assert type(self.missing_dates[0]) <= type(dates[0]), (
type(self.missing_dates[0]),
type(dates[0]),
)

dates_computed = self.statistics_registry.dates_computed
for d in dates:
if d in self.missing_dates:
assert d not in dates_computed, (d, dates_computed)
else:
assert d in dates_computed, (d, dates_computed)
assert type(self.missing_dates[0]) is dtype, (type(self.missing_dates[0]), dtype)
dates = [d for d in dates if d not in self.missing_dates]

# filter dates according to the start and end dates in the metadata
z = zarr.open(self.path, mode="r")
start = z.attrs.get("statistics_start_date")
end = z.attrs.get("statistics_end_date")
start = np.datetime64(start)
end = np.datetime64(end)
start, end = z.attrs.get("statistics_start_date"), z.attrs.get("statistics_end_date")
start, end = np.datetime64(start), np.datetime64(end)
assert type(start) is dtype, (type(start), dtype)
dates = [d for d in dates if d >= start and d <= end]
assert type(start) is type(dates[0]), (type(start), type(dates[0]))

stats = self.statistics_registry.get_aggregated(dates, self.variables_names)
# filter dates according to the user-specified start and end dates
if self.user_statistics_start or self.user_statistics_end:
start, end = self.build_statistics_dates(self.user_statistics_start, self.user_statistics_end)
start, end = np.datetime64(start), np.datetime64(end)
assert type(start) is dtype, (type(start), dtype)
dates = [d for d in dates if d >= start and d <= end]

writer = {
None: self.write_stats_to_dataset,
"-": self.write_stats_to_stdout,
}.get(self.statistics_output, self.write_stats_to_file)
writer(stats)

def check_complete(self, force):
if self._complete:
return
if not force:
raise Exception(
f"❗Zarr {self.path} is not fully built. Use 'force' option."
)
if self._write_to_dataset:
print(
f"❗Zarr {self.path} is not fully built, not writting statistics into dataset."
)
self._write_to_dataset = False

@property
def _complete(self):
return all(self.registry.get_flags(sync=False))

def recompute_temporary_statistics(self):
raise NotImplementedError("Untested code")
# self.statistics_registry.create(exist_ok=True)

# self.print(
# f"Building temporary statistics from data {self.path}. " f"From {self.date_start} to {self.date_end}"
# )

# shape = (self.i_end + 1 - self.i_start, len(self.variables_names))
# detailed_stats = dict(
# minimum=np.full(shape, np.nan, dtype=np.float64),
# maximum=np.full(shape, np.nan, dtype=np.float64),
# sums=np.full(shape, np.nan, dtype=np.float64),
# squares=np.full(shape, np.nan, dtype=np.float64),
# count=np.full(shape, -1, dtype=np.int64),
# )

# ds = open_dataset(self.path)
# key = (slice(self.i_start, self.i_end + 1), slice(None, None))
# for i in progress_bar(
# desc="Computing Statistics",
# iterable=range(self.i_start, self.i_end + 1),
# ):
# i_ = i - self.i_start
# data = ds[slice(i, i + 1), :]
# one = compute_statistics(data, self.variables_names)
# for k, v in one.items():
# detailed_stats[k][i_] = v

# print(f"✅ Saving statistics for {key} shape={detailed_stats['count'].shape}")
# self.statistics_registry[key] = detailed_stats
# self.statistics_registry.add_provenance(name="provenance_recompute_statistics", config=self.main_config)
return dates

def run(self):
dates = self._get_statistics_dates()
stats = self.statistics_registry.get_aggregated(dates, self.variables_names)
self.output_writer(stats)

def write_stats_to_file(self, stats):
stats.save(self.statistics_output, provenance=dict(config=self.main_config))
print(f"✅ Statistics written in {self.statistics_output}")
return

def write_stats_to_dataset(self, stats):
if self.user_statistics_start or self.user_statistics_end:
raise ValueError(
(
"Cannot write statistics in dataset with user specified dates. "
"This would be conflicting with the dataset metadata."
)
)

if not all(self.registry.get_flags(sync=False)):
raise Exception(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")

for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count"]:
self._add_dataset(name=k, array=stats[k])

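In the refactored StatisticsLoader, the output target is chosen once in the constructor: a statistics_output of None selects writing back into the dataset, "-" selects stdout, and any other value is treated as a file path. The stripped-down sketch below shows only that dispatch pattern; the class name and writer bodies are placeholders rather than the real implementations.

class StatsOutputDispatch:
    # Illustrative reduction of the output_writer selection in StatisticsLoader.
    def __init__(self, statistics_output=None):
        self.statistics_output = statistics_output
        self.output_writer = {
            None: self.write_stats_to_dataset,  # default: write back into the dataset
            "-": self.write_stats_to_stdout,    # "-": print to standard output
        }.get(self.statistics_output, self.write_stats_to_file)  # else: treat as a path

    def write_stats_to_dataset(self, stats):
        print("would store", stats, "as dataset attributes/arrays")

    def write_stats_to_stdout(self, stats):
        print(stats)

    def write_stats_to_file(self, stats):
        print("would save", stats, "to", self.statistics_output)

    def run(self, stats):
        self.output_writer(stats)

StatsOutputDispatch().run({"mean": 0.0})            # dataset writer
StatsOutputDispatch("-").run({"mean": 0.0})         # stdout writer
StatsOutputDispatch("out.json").run({"mean": 0.0})  # file writer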
