From e5ee8d35d1a17de24c778ce8a8acff97c836b290 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 22 Sep 2024 07:28:23 +0000 Subject: [PATCH] Fix call to as_mars() --- pyproject.toml | 3 ++- src/ai_models/inputs/interpolate.py | 5 ++++- src/ai_models/inputs/opendata.py | 15 +++++++++------ src/ai_models/inputs/transform.py | 1 + src/ai_models/model.py | 6 ++++-- src/ai_models/outputs/__init__.py | 12 ++++++++++++ 6 files changed, 32 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 86f9f1b..921dd81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ classifiers = [ dependencies = [ "cdsapi", - "earthkit-data>=0.10.1", + "earthkit-data>=0.10.3", "earthkit-meteo", "earthkit-regrid", "eccodes>=2.37", @@ -72,6 +72,7 @@ file = "ai_models.inputs.file:FileInput" mars = "ai_models.inputs.mars:MarsInput" cds = "ai_models.inputs.cds:CdsInput" ecmwf-open-data = "ai_models.inputs.opendata:OpenDataInput" +opendata = "ai_models.inputs.opendata:OpenDataInput" [project.entry-points."ai_models.output"] file = "ai_models.outputs:FileOutput" diff --git a/src/ai_models/inputs/interpolate.py b/src/ai_models/inputs/interpolate.py index 744ed18..3fec238 100644 --- a/src/ai_models/inputs/interpolate.py +++ b/src/ai_models/inputs/interpolate.py @@ -8,6 +8,7 @@ import logging import earthkit.regrid as ekr +import tqdm from earthkit.data.indexing.fieldlist import FieldArray from .transform import NewDataField @@ -22,7 +23,9 @@ def __init__(self, grid, source): def __call__(self, ds): result = [] - for f in ds: + for f in tqdm.tqdm(ds, delay=0.5, desc="Interpolating", leave=False): data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid)) result.append(NewDataField(f, data)) + + LOG.info("Interpolated %d fields. Input shape %s, output shape %s.", len(result), ds[0].shape, result[0].shape) return FieldArray(result) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 935a0f8..61d4e7b 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -35,10 +35,10 @@ class OpenDataInput(RequestBasedInput): WHERE = "OPENDATA" RESOLS = { - (0.25, 0.25): ("0p25", (0.25, 0.25), False), - "N320": ("0p25", (0.25, 0.25), True), - "O96": ("0p25", (0.25, 0.25), True), - # (0.1, 0.1): ("0p25", (0.25, 0.25), False), + (0.25, 0.25): ("0p25", (0.25, 0.25), False, False), + "N320": ("0p25", (0.25, 0.25), True, False), + "O96": ("0p25", (0.25, 0.25), True, False), + (0.1, 0.1): ("0p25", (0.25, 0.25), True, True), } def __init__(self, owner, **kwargs): @@ -56,12 +56,15 @@ def _adjust(self, kwargs): if isinstance(grid, list): grid = tuple(grid) - kwargs["resol"], source, interp = self.RESOLS[grid] + kwargs["resol"], source, interp, oversampling = self.RESOLS[grid] r = dict(**kwargs) r.update(self.owner.retrieve) if interp: - logging.debug("Interpolating from %s to %s", source, grid) + + logging.info("Interpolating input data from %s to %s.", source, grid) + if oversampling: + logging.warning("This will oversample the input data.") return Interpolate(grid, source) else: return lambda x: x diff --git a/src/ai_models/inputs/transform.py b/src/ai_models/inputs/transform.py index 086cde0..2a10cce 100644 --- a/src/ai_models/inputs/transform.py +++ b/src/ai_models/inputs/transform.py @@ -25,6 +25,7 @@ class NewDataField(WrappedField): def __init__(self, field, data): super().__init__(field) self._data = data + self.shape = data.shape def to_numpy(self, flatten=False, dtype=None, index=None): data = self._data diff --git a/src/ai_models/model.py b/src/ai_models/model.py index eb75435..25c4249 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -127,7 +127,7 @@ def collect_archive_requests(self, written): # does not return always return recently set keys handle = handle.clone() - self.archiving[path].add(handle.as_mars()) + self.archiving[path].add(handle.as_namespace("mars")) def finalise(self): self.output.flush() @@ -536,6 +536,8 @@ def write_input_fields( accumulations_shape=None, ignore=None, ): + LOG.info("Starting date is %s", self.start_datetime) + LOG.info("Writing input fields") if ignore is None: ignore = [] @@ -553,7 +555,7 @@ def write_input_fields( if accumulations is not None: if accumulations_template is None: - accumulations_template = fields.sel(param="2t")[0] + accumulations_template = fields.sel(param="msl")[0] if accumulations_shape is None: accumulations_shape = accumulations_template.shape diff --git a/src/ai_models/outputs/__init__.py b/src/ai_models/outputs/__init__.py index 614fa64..2e9a616 100644 --- a/src/ai_models/outputs/__init__.py +++ b/src/ai_models/outputs/__init__.py @@ -67,6 +67,18 @@ def write(self, data, *args, check=False, **kwargs): raise ValueError(f"NaN values found in field. args={args} kwargs={kwargs}") if np.isinf(data).any(): raise ValueError(f"Infinite values found in field. args={args} kwargs={kwargs}") + + options = {} + options.update(self.grib_keys) + options.update(kwargs) + LOG.error("Failed to write data to %s %s", args, options) + cmd = [] + for k, v in options.items(): + if isinstance(v, (int, str, float)): + cmd.append("%s=%s" % (k, v)) + + LOG.error("grib_set -s%s", ",".join(cmd)) + raise if check: