From dfe47ef5eaaeb2c96f7e8c1bc71f81c85303e5f4 Mon Sep 17 00:00:00 2001 From: Matthew Chantry Date: Sun, 28 Jul 2024 21:40:06 +0000 Subject: [PATCH 01/20] Added request info for ML --- src/ai_models/model.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/ai_models/model.py b/src/ai_models/model.py index a35e70e..d0bb163 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -374,6 +374,7 @@ def _requests_unfiltered(self): for date, time in self.datetimes(): # noqa F402 param, level = self.param_level_pl + # PL r = dict( levtype="pl", levelist=level, @@ -390,6 +391,25 @@ def _requests_unfiltered(self): result.append(dict(**r)) + # ML + param, level = self.param_level_ml + r = dict( + levtype="ml", + levelist=level, + param=param, + date=date, + time=time, + ) + r.update(first) + first = {} + + r.update(self._requests_extra) + + self.patch_retrieve_request(r) + + result.append(dict(**r)) + + # SFC r.update( dict( levtype="sfc", From a01b90c5a49ecd65de3514de8225bd9023fc3a6d Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 29 Jul 2024 06:21:11 +0000 Subject: [PATCH 02/20] finalisel ml bug fix --- src/ai_models/model.py | 71 +++++++++++++++++++++++++----------------- 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/src/ai_models/model.py b/src/ai_models/model.py index d0bb163..485588f 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -374,52 +374,67 @@ def _requests_unfiltered(self): for date, time in self.datetimes(): # noqa F402 param, level = self.param_level_pl - # PL r = dict( - levtype="pl", - levelist=level, - param=param, date=date, time=time, ) r.update(first) first = {} - r.update(self._requests_extra) + if param and level: - self.patch_retrieve_request(r) + # PL + r.update( + dict( + levtype="pl", + levelist=level, + param=param, + date=date, + time=time, + ) + ) + + r.update(self._requests_extra) + + self.patch_retrieve_request(r) - result.append(dict(**r)) + result.append(dict(**r)) # ML param, level = self.param_level_ml - r = dict( - levtype="ml", - levelist=level, - param=param, - date=date, - time=time, - ) - r.update(first) - first = {} - r.update(self._requests_extra) + if param and level: + r.update( + dict( + levtype="ml", + levelist=level, + param=param, + date=date, + time=time, + ) + ) + + r.update(self._requests_extra) - self.patch_retrieve_request(r) + self.patch_retrieve_request(r) - result.append(dict(**r)) + result.append(dict(**r)) - # SFC - r.update( - dict( - levtype="sfc", - param=self.param_sfc, + param = self.param_sfc + if param: + # SFC + r.update( + dict( + levtype="sfc", + param=self.param_sfc, + date=date, + time=time, + levelist="off", + ) ) - ) - r.pop("levelist", None) - self.patch_retrieve_request(r) - result.append(dict(**r)) + self.patch_retrieve_request(r) + result.append(dict(**r)) return result From 56297e1f21f21769eaa9efb2fca46ad6f08380ae Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Mon, 29 Jul 2024 11:27:32 +0000 Subject: [PATCH 03/20] Sort json archive request --- src/ai_models/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ai_models/model.py b/src/ai_models/model.py index 485588f..077b7c7 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -154,13 +154,13 @@ def finalise(self): def json_default(obj): if isinstance(obj, set): if len(obj) > 1: - return list(obj) + return sorted(list(obj)) else: return obj.pop() - return obj + raise TypeError print( - json.dumps(json_requests, separators=(",", ":"), default=json_default), + json.dumps(json_requests, separators=(",", ":"), default=json_default, sort_keys=True), file=f, ) From 5107c57e663e7676738698b86edbf4d478967b0d Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 10 Sep 2024 12:01:57 +0000 Subject: [PATCH 04/20] Replace climetlab with earthkit-data Co-authored-by: Sandor Kertesz --- src/ai_models/inputs/__init__.py | 33 ++++++++++++++++--------------- src/ai_models/model.py | 8 ++++---- src/ai_models/outputs/__init__.py | 4 ++-- src/ai_models/stepper.py | 2 +- 4 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/ai_models/inputs/__init__.py b/src/ai_models/inputs/__init__.py index 04e1d40..7380caa 100644 --- a/src/ai_models/inputs/__init__.py +++ b/src/ai_models/inputs/__init__.py @@ -8,7 +8,7 @@ import logging from functools import cached_property -import climetlab as cml +import earthkit.data as ekd import entrypoints LOG = logging.getLogger(__name__) @@ -27,10 +27,11 @@ def _patch(self, **kargs): def fields_sfc(self): param = self.owner.param_sfc if not param: - return cml.load_source("empty") + return ekd.from_source("empty") LOG.info(f"Loading surface fields from {self.WHERE}") - return cml.load_source( + + return ekd.from_source( "multi", [ self.sfc_load_source( @@ -51,10 +52,10 @@ def fields_sfc(self): def fields_pl(self): param, level = self.owner.param_level_pl if not (param and level): - return cml.load_source("empty") + return ekd.from_source("empty") LOG.info(f"Loading pressure fields from {self.WHERE}") - return cml.load_source( + return ekd.from_source( "multi", [ self.pl_load_source( @@ -75,10 +76,10 @@ def fields_pl(self): def fields_ml(self): param, level = self.owner.param_level_ml if not (param and level): - return cml.load_source("empty") + return ekd.from_source("empty") LOG.info(f"Loading model fields from {self.WHERE}") - return cml.load_source( + return ekd.from_source( "multi", [ self.ml_load_source( @@ -109,17 +110,17 @@ def __init__(self, owner, **kwargs): def pl_load_source(self, **kwargs): kwargs["levtype"] = "pl" logging.debug("load source mars %s", kwargs) - return cml.load_source("mars", kwargs) + return ekd.from_source("mars", kwargs) def sfc_load_source(self, **kwargs): kwargs["levtype"] = "sfc" logging.debug("load source mars %s", kwargs) - return cml.load_source("mars", kwargs) + return ekd.from_source("mars", kwargs) def ml_load_source(self, **kwargs): kwargs["levtype"] = "ml" logging.debug("load source mars %s", kwargs) - return cml.load_source("mars", kwargs) + return ekd.from_source("mars", kwargs) class CdsInput(RequestBasedInput): @@ -127,11 +128,11 @@ class CdsInput(RequestBasedInput): def pl_load_source(self, **kwargs): kwargs["product_type"] = "reanalysis" - return cml.load_source("cds", "reanalysis-era5-pressure-levels", kwargs) + return ekd.from_source("cds", "reanalysis-era5-pressure-levels", kwargs) def sfc_load_source(self, **kwargs): kwargs["product_type"] = "reanalysis" - return cml.load_source("cds", "reanalysis-era5-single-levels", kwargs) + return ekd.from_source("cds", "reanalysis-era5-single-levels", kwargs) def ml_load_source(self, **kwargs): raise NotImplementedError("CDS does not support model levels") @@ -163,19 +164,19 @@ def pl_load_source(self, **kwargs): self._adjust(kwargs) kwargs["levtype"] = "pl" logging.debug("load source ecmwf-open-data %s", kwargs) - return cml.load_source("ecmwf-open-data", **kwargs) + return ekd.from_source("ecmwf-open-data", **kwargs) def sfc_load_source(self, **kwargs): self._adjust(kwargs) kwargs["levtype"] = "sfc" logging.debug("load source ecmwf-open-data %s", kwargs) - return cml.load_source("ecmwf-open-data", **kwargs) + return ekd.from_source("ecmwf-open-data", **kwargs) def ml_load_source(self, **kwargs): self._adjust(kwargs) kwargs["levtype"] = "ml" logging.debug("load source ecmwf-open-data %s", kwargs) - return cml.load_source("ecmwf-open-data", **kwargs) + return ekd.from_source("ecmwf-open-data", **kwargs) class FileInput: @@ -197,7 +198,7 @@ def fields_ml(self): @cached_property def all_fields(self): - return cml.load_source("file", self.file) + return ekd.from_source("file", self.file) def get_input(name, *args, **kwargs): diff --git a/src/ai_models/model.py b/src/ai_models/model.py index 077b7c7..e143772 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -14,10 +14,10 @@ from collections import defaultdict from functools import cached_property -import climetlab as cml +import earthkit.data as ekd import entrypoints import numpy as np -from climetlab.utils.humanize import seconds +from earthkit.data.utils.humanize import seconds from multiurl import download from .checkpoint import peek @@ -506,8 +506,8 @@ def provenance(self): def forcing_and_constants(self, date, param): source = self.all_fields[:1] - ds = cml.load_source( - "constants", + ds = ekd.from_source( + "forcings", source, date=date, param=param, diff --git a/src/ai_models/outputs/__init__.py b/src/ai_models/outputs/__init__.py index 0e395bc..614fa64 100644 --- a/src/ai_models/outputs/__init__.py +++ b/src/ai_models/outputs/__init__.py @@ -10,7 +10,7 @@ import warnings from functools import cached_property -import climetlab as cml +import earthkit.data as ekd import entrypoints import numpy as np @@ -50,7 +50,7 @@ def grib_keys(self): @cached_property def output(self): - return cml.new_grib_output( + return ekd.new_grib_output( self.path, split_output=True, **self.grib_keys, diff --git a/src/ai_models/stepper.py b/src/ai_models/stepper.py index 5a0c545..c0f8c42 100644 --- a/src/ai_models/stepper.py +++ b/src/ai_models/stepper.py @@ -8,7 +8,7 @@ import logging import time -from climetlab.utils.humanize import seconds +from earthkit.data.utils.humanize import seconds LOG = logging.getLogger(__name__) From 1f973e41b02f3cbec5631a29ab3cecbbbe88bf1b Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 10 Sep 2024 13:15:29 +0000 Subject: [PATCH 05/20] Set cache policy --- src/ai_models/__main__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/ai_models/__main__.py b/src/ai_models/__main__.py index 57a1a6b..37b8ec4 100644 --- a/src/ai_models/__main__.py +++ b/src/ai_models/__main__.py @@ -11,12 +11,16 @@ import shlex import sys +import earthkit.data as ekd + from .inputs import available_inputs from .model import Timer from .model import available_models from .model import load_model from .outputs import available_outputs +ekd.settings.set("cache-policy", "user") + LOG = logging.getLogger(__name__) From 729fdd1bea1b15e4f85cd2941b5c3d7614ecf20e Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 10 Sep 2024 13:38:09 +0000 Subject: [PATCH 06/20] Fix valid datetime --- src/ai_models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ai_models/model.py b/src/ai_models/model.py index e143772..20d8e16 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -523,7 +523,7 @@ def gridpoints(self): @cached_property def start_datetime(self): - return self.all_fields.order_by(valid_datetime="ascending")[-1].datetime() + return self.all_fields.order_by(valid_datetime="ascending")[-1].datetime()["valid_time"] @property def constant_fields(self): @@ -545,7 +545,7 @@ def write_input_fields( if field.metadata("shortName") in ignore: continue - if field.valid_datetime() == self.start_datetime: + if field.datetime()["valid_time"] == self.start_datetime: self.write( None, template=field, From e6b6e9a1ad0532a212a25ef6ebb5750ddd2240c8 Mon Sep 17 00:00:00 2001 From: Gert Mertes Date: Tue, 10 Sep 2024 13:46:54 +0000 Subject: [PATCH 07/20] Update dependencies --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6974028..df6bb18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,9 +42,9 @@ classifiers = [ dependencies = [ "entrypoints", "requests", - "climetlab>=0.23.0", + "earthkit-data>=0.10.1", + "eccodes>=2.37", "multiurl", - "ecmwflibs>=0.6.1", "gputil", "earthkit-meteo", "pyyaml", From 084a2e7675e8733080f8111e4b7af6eb8622b7e7 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 13 Sep 2024 17:35:26 +0100 Subject: [PATCH 08/20] refactor input --- .gitignore | 1 + pyproject.toml | 8 +- src/ai_models/inputs/__init__.py | 189 +------------------------------ src/ai_models/inputs/base.py | 100 ++++++++++++++++ src/ai_models/inputs/cds.py | 29 +++++ src/ai_models/inputs/file.py | 47 ++++++++ src/ai_models/inputs/mars.py | 36 ++++++ src/ai_models/inputs/opendata.py | 138 ++++++++++++++++++++++ 8 files changed, 357 insertions(+), 191 deletions(-) create mode 100644 src/ai_models/inputs/base.py create mode 100644 src/ai_models/inputs/cds.py create mode 100644 src/ai_models/inputs/file.py create mode 100644 src/ai_models/inputs/mars.py create mode 100644 src/ai_models/inputs/opendata.py diff --git a/.gitignore b/.gitignore index a629b88..dc6992c 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,4 @@ bar dev/ *.out _version.py +*.tar diff --git a/pyproject.toml b/pyproject.toml index df6bb18..298f794 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,10 +64,10 @@ ai-models = "ai_models.__main__:main" version_file = "src/ai_models/_version.py" [project.entry-points."ai_models.input"] -file = "ai_models.inputs:FileInput" -mars = "ai_models.inputs:MarsInput" -cds = "ai_models.inputs:CdsInput" -opendata = "ai_models.inputs:OpenDataInput" +file = "ai_models.inputs.file:FileInput" +mars = "ai_models.inputs.mars:MarsInput" +cds = "ai_models.inputs.cds:CdsInput" +opendata = "ai_models.inputs.opendata:OpenDataInput" [project.entry-points."ai_models.output"] file = "ai_models.outputs:FileOutput" diff --git a/src/ai_models/inputs/__init__.py b/src/ai_models/inputs/__init__.py index 7380caa..005b5d6 100644 --- a/src/ai_models/inputs/__init__.py +++ b/src/ai_models/inputs/__init__.py @@ -9,198 +9,13 @@ from functools import cached_property import earthkit.data as ekd +import earthkit.regrid as ekr import entrypoints +from earthkit.data.indexing.fieldlist import FieldArray LOG = logging.getLogger(__name__) -class RequestBasedInput: - def __init__(self, owner, **kwargs): - self.owner = owner - - def _patch(self, **kargs): - r = dict(**kargs) - self.owner.patch_retrieve_request(r) - return r - - @cached_property - def fields_sfc(self): - param = self.owner.param_sfc - if not param: - return ekd.from_source("empty") - - LOG.info(f"Loading surface fields from {self.WHERE}") - - return ekd.from_source( - "multi", - [ - self.sfc_load_source( - **self._patch( - date=date, - time=time, - param=param, - grid=self.owner.grid, - area=self.owner.area, - **self.owner.retrieve, - ) - ) - for date, time in self.owner.datetimes() - ], - ) - - @cached_property - def fields_pl(self): - param, level = self.owner.param_level_pl - if not (param and level): - return ekd.from_source("empty") - - LOG.info(f"Loading pressure fields from {self.WHERE}") - return ekd.from_source( - "multi", - [ - self.pl_load_source( - **self._patch( - date=date, - time=time, - param=param, - level=level, - grid=self.owner.grid, - area=self.owner.area, - ) - ) - for date, time in self.owner.datetimes() - ], - ) - - @cached_property - def fields_ml(self): - param, level = self.owner.param_level_ml - if not (param and level): - return ekd.from_source("empty") - - LOG.info(f"Loading model fields from {self.WHERE}") - return ekd.from_source( - "multi", - [ - self.ml_load_source( - **self._patch( - date=date, - time=time, - param=param, - level=level, - grid=self.owner.grid, - area=self.owner.area, - ) - ) - for date, time in self.owner.datetimes() - ], - ) - - @cached_property - def all_fields(self): - return self.fields_sfc + self.fields_pl + self.fields_ml - - -class MarsInput(RequestBasedInput): - WHERE = "MARS" - - def __init__(self, owner, **kwargs): - self.owner = owner - - def pl_load_source(self, **kwargs): - kwargs["levtype"] = "pl" - logging.debug("load source mars %s", kwargs) - return ekd.from_source("mars", kwargs) - - def sfc_load_source(self, **kwargs): - kwargs["levtype"] = "sfc" - logging.debug("load source mars %s", kwargs) - return ekd.from_source("mars", kwargs) - - def ml_load_source(self, **kwargs): - kwargs["levtype"] = "ml" - logging.debug("load source mars %s", kwargs) - return ekd.from_source("mars", kwargs) - - -class CdsInput(RequestBasedInput): - WHERE = "CDS" - - def pl_load_source(self, **kwargs): - kwargs["product_type"] = "reanalysis" - return ekd.from_source("cds", "reanalysis-era5-pressure-levels", kwargs) - - def sfc_load_source(self, **kwargs): - kwargs["product_type"] = "reanalysis" - return ekd.from_source("cds", "reanalysis-era5-single-levels", kwargs) - - def ml_load_source(self, **kwargs): - raise NotImplementedError("CDS does not support model levels") - - -class OpenDataInput(RequestBasedInput): - WHERE = "OPENDATA" - - RESOLS = {(0.25, 0.25): "0p25"} - - def __init__(self, owner, **kwargs): - self.owner = owner - - def _adjust(self, kwargs): - if "level" in kwargs: - # OpenData uses levelist instead of level - kwargs["levelist"] = kwargs.pop("level") - - grid = kwargs.pop("grid") - if isinstance(grid, list): - grid = tuple(grid) - - kwargs["resol"] = self.RESOLS[grid] - r = dict(**kwargs) - r.update(self.owner.retrieve) - return r - - def pl_load_source(self, **kwargs): - self._adjust(kwargs) - kwargs["levtype"] = "pl" - logging.debug("load source ecmwf-open-data %s", kwargs) - return ekd.from_source("ecmwf-open-data", **kwargs) - - def sfc_load_source(self, **kwargs): - self._adjust(kwargs) - kwargs["levtype"] = "sfc" - logging.debug("load source ecmwf-open-data %s", kwargs) - return ekd.from_source("ecmwf-open-data", **kwargs) - - def ml_load_source(self, **kwargs): - self._adjust(kwargs) - kwargs["levtype"] = "ml" - logging.debug("load source ecmwf-open-data %s", kwargs) - return ekd.from_source("ecmwf-open-data", **kwargs) - - -class FileInput: - def __init__(self, owner, file, **kwargs): - self.file = file - self.owner = owner - - @cached_property - def fields_sfc(self): - return self.all_fields.sel(levtype="sfc") - - @cached_property - def fields_pl(self): - return self.all_fields.sel(levtype="pl") - - @cached_property - def fields_ml(self): - return self.all_fields.sel(levtype="ml") - - @cached_property - def all_fields(self): - return ekd.from_source("file", self.file) - - def get_input(name, *args, **kwargs): return available_inputs()[name].load()(*args, **kwargs) diff --git a/src/ai_models/inputs/base.py b/src/ai_models/inputs/base.py new file mode 100644 index 0000000..acdadb4 --- /dev/null +++ b/src/ai_models/inputs/base.py @@ -0,0 +1,100 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from functools import cached_property + +import earthkit.data as ekd + +LOG = logging.getLogger(__name__) + + +class RequestBasedInput: + def __init__(self, owner, **kwargs): + self.owner = owner + + def _patch(self, **kargs): + r = dict(**kargs) + self.owner.patch_retrieve_request(r) + return r + + @cached_property + def fields_sfc(self): + param = self.owner.param_sfc + if not param: + return ekd.from_source("empty") + + LOG.info(f"Loading surface fields from {self.WHERE}") + + return ekd.from_source( + "multi", + [ + self.sfc_load_source( + **self._patch( + date=date, + time=time, + param=param, + grid=self.owner.grid, + area=self.owner.area, + **self.owner.retrieve, + ) + ) + for date, time in self.owner.datetimes() + ], + ) + + @cached_property + def fields_pl(self): + param, level = self.owner.param_level_pl + if not (param and level): + return ekd.from_source("empty") + + LOG.info(f"Loading pressure fields from {self.WHERE}") + return ekd.from_source( + "multi", + [ + self.pl_load_source( + **self._patch( + date=date, + time=time, + param=param, + level=level, + grid=self.owner.grid, + area=self.owner.area, + ) + ) + for date, time in self.owner.datetimes() + ], + ) + + @cached_property + def fields_ml(self): + param, level = self.owner.param_level_ml + if not (param and level): + return ekd.from_source("empty") + + LOG.info(f"Loading model fields from {self.WHERE}") + return ekd.from_source( + "multi", + [ + self.ml_load_source( + **self._patch( + date=date, + time=time, + param=param, + level=level, + grid=self.owner.grid, + area=self.owner.area, + ) + ) + for date, time in self.owner.datetimes() + ], + ) + + @cached_property + def all_fields(self): + return self.fields_sfc + self.fields_pl + self.fields_ml diff --git a/src/ai_models/inputs/cds.py b/src/ai_models/inputs/cds.py new file mode 100644 index 0000000..9595c46 --- /dev/null +++ b/src/ai_models/inputs/cds.py @@ -0,0 +1,29 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd + +from .base import RequestBasedInput + +LOG = logging.getLogger(__name__) + + +class CdsInput(RequestBasedInput): + WHERE = "CDS" + + def pl_load_source(self, **kwargs): + kwargs["product_type"] = "reanalysis" + return ekd.from_source("cds", "reanalysis-era5-pressure-levels", kwargs) + + def sfc_load_source(self, **kwargs): + kwargs["product_type"] = "reanalysis" + return ekd.from_source("cds", "reanalysis-era5-single-levels", kwargs) + + def ml_load_source(self, **kwargs): + raise NotImplementedError("CDS does not support model levels") diff --git a/src/ai_models/inputs/file.py b/src/ai_models/inputs/file.py new file mode 100644 index 0000000..fba571b --- /dev/null +++ b/src/ai_models/inputs/file.py @@ -0,0 +1,47 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from functools import cached_property + +import earthkit.data as ekd +import entrypoints + +LOG = logging.getLogger(__name__) + + +class FileInput: + def __init__(self, owner, file, **kwargs): + self.file = file + self.owner = owner + + @cached_property + def fields_sfc(self): + return self.all_fields.sel(levtype="sfc") + + @cached_property + def fields_pl(self): + return self.all_fields.sel(levtype="pl") + + @cached_property + def fields_ml(self): + return self.all_fields.sel(levtype="ml") + + @cached_property + def all_fields(self): + return ekd.from_source("file", self.file) + + +def get_input(name, *args, **kwargs): + return available_inputs()[name].load()(*args, **kwargs) + + +def available_inputs(): + result = {} + for e in entrypoints.get_group_all("ai_models.input"): + result[e.name] = e + return result diff --git a/src/ai_models/inputs/mars.py b/src/ai_models/inputs/mars.py new file mode 100644 index 0000000..820a3eb --- /dev/null +++ b/src/ai_models/inputs/mars.py @@ -0,0 +1,36 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd + +from .base import RequestBasedInput + +LOG = logging.getLogger(__name__) + + +class MarsInput(RequestBasedInput): + WHERE = "MARS" + + def __init__(self, owner, **kwargs): + self.owner = owner + + def pl_load_source(self, **kwargs): + kwargs["levtype"] = "pl" + logging.debug("load source mars %s", kwargs) + return ekd.from_source("mars", kwargs) + + def sfc_load_source(self, **kwargs): + kwargs["levtype"] = "sfc" + logging.debug("load source mars %s", kwargs) + return ekd.from_source("mars", kwargs) + + def ml_load_source(self, **kwargs): + kwargs["levtype"] = "ml" + logging.debug("load source mars %s", kwargs) + return ekd.from_source("mars", kwargs) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py new file mode 100644 index 0000000..d12ecdb --- /dev/null +++ b/src/ai_models/inputs/opendata.py @@ -0,0 +1,138 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd +import earthkit.regrid as ekr +from earthkit.data.indexing.fieldlist import FieldArray + +from .base import RequestBasedInput + +LOG = logging.getLogger(__name__) + + +def _noop(x): + return x + + +class NewDataField: + def __init__(self, field, data, param): + self.field = field + self.data = data + self.param = param + + def to_numpy(self, flatten=False, dtype=None, index=None): + data = self.data + if dtype is not None: + data = data.astype(dtype) + if flatten: + data = data.flatten() + return data + + def metadata(self, key, *args, **kwargs): + if key == "param": + return self.param + return self.field.metadata(key, *args, **kwargs) + + def __getattr__(self, name): + return getattr(self.field, name) + + def __repr__(self) -> str: + return repr(self.field) + + +class Interpolate: + def __init__(self, grid, source): + self.grid = list(grid) if isinstance(grid, tuple) else grid + self.source = list(source) if isinstance(source, tuple) else source + + def __call__(self, ds): + result = [] + for f in ds: + data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid)) + result.append(NewDataField(f, data)) + return FieldArray(result) + + +def make_z_from_gh(previous): + g = 9.80665 # Same a pgen + + def _proc(ds): + + ds = previous(ds) + + result = [] + for f in ds: + if f.metadata("param") == "gh": + result.append(NewDataField(f, f.to_numpy() * g, param="z")) + else: + result.append(f) + return FieldArray(result) + + return _proc + + +class OpenDataInput(RequestBasedInput): + WHERE = "OPENDATA" + + RESOLS = { + (0.25, 0.25): ("0p25", (0.25, 0.25), False), + "N320": ("0p25", (0.25, 0.25), True), + "O96": ("0p25", (0.25, 0.25), True), + } + + def __init__(self, owner, **kwargs): + self.owner = owner + + def _adjust(self, kwargs): + if "level" in kwargs: + # OpenData uses levelist instead of level + kwargs["levelist"] = kwargs.pop("level") + + grid = kwargs.pop("grid") + if isinstance(grid, list): + grid = tuple(grid) + + kwargs["resol"], source, interp = self.RESOLS[grid] + r = dict(**kwargs) + r.update(self.owner.retrieve) + + if interp: + return Interpolate(grid, source) + else: + return _noop + + def pl_load_source(self, **kwargs): + pproc = self._adjust(kwargs) + kwargs["levtype"] = "pl" + + param = [p.lower() for p in kwargs["param"]] + assert isinstance(param, (list, tuple)) + + if "z" in param: + param = list(param) + param.remove("z") + if "gh" not in param: + param.append("gh") + kwargs["param"] = param + pproc = make_z_from_gh(pproc) + + logging.debug("load source ecmwf-open-data %s", kwargs) + return pproc(ekd.from_source("ecmwf-open-data", **kwargs)) + + def sfc_load_source(self, **kwargs): + pproc = self._adjust(kwargs) + kwargs["levtype"] = "sfc" + logging.debug("load source ecmwf-open-data %s", kwargs) + return pproc(ekd.from_source("ecmwf-open-data", **kwargs)) + + def ml_load_source(self, **kwargs): + pproc = self._adjust(kwargs) + kwargs["levtype"] = "ml" + logging.debug("load source ecmwf-open-data %s", kwargs) + return pproc(ekd.from_source("ecmwf-open-data", **kwargs)) From baba519772769a7a26acb30a2197d9b413b6d988 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 14 Sep 2024 09:03:28 +0000 Subject: [PATCH 09/20] add constants --- pyproject.toml | 1 + src/ai_models/inputs/compute.py | 33 ++++++ src/ai_models/inputs/interpolate.py | 28 +++++ src/ai_models/inputs/opendata.py | 165 +++++++++++++++++----------- src/ai_models/inputs/transform.py | 47 ++++++++ 5 files changed, 210 insertions(+), 64 deletions(-) create mode 100644 src/ai_models/inputs/compute.py create mode 100644 src/ai_models/inputs/interpolate.py create mode 100644 src/ai_models/inputs/transform.py diff --git a/pyproject.toml b/pyproject.toml index 298f794..0fb7f8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "multiurl", "gputil", "earthkit-meteo", + "earthkit-regrid", "pyyaml", "tqdm", ] diff --git a/src/ai_models/inputs/compute.py b/src/ai_models/inputs/compute.py new file mode 100644 index 0000000..41ece36 --- /dev/null +++ b/src/ai_models/inputs/compute.py @@ -0,0 +1,33 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +from earthkit.data.indexing.fieldlist import FieldArray + +from .transform import NewDataField +from .transform import NewMetadataField + +LOG = logging.getLogger(__name__) + + +def make_z_from_gh(previous): + g = 9.80665 # Same a pgen + + def _proc(ds): + + ds = previous(ds) + + result = [] + for f in ds: + if f.metadata("param") == "gh": + result.append(NewMetadataField(NewDataField(f, f.to_numpy() * g), param="z")) + else: + result.append(f) + return FieldArray(result) + + return _proc diff --git a/src/ai_models/inputs/interpolate.py b/src/ai_models/inputs/interpolate.py new file mode 100644 index 0000000..744ed18 --- /dev/null +++ b/src/ai_models/inputs/interpolate.py @@ -0,0 +1,28 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.regrid as ekr +from earthkit.data.indexing.fieldlist import FieldArray + +from .transform import NewDataField + +LOG = logging.getLogger(__name__) + + +class Interpolate: + def __init__(self, grid, source): + self.grid = list(grid) if isinstance(grid, tuple) else grid + self.source = list(source) if isinstance(source, tuple) else source + + def __call__(self, ds): + result = [] + for f in ds: + data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid)) + result.append(NewDataField(f, data)) + return FieldArray(result) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index d12ecdb..e0ec5f5 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -5,76 +5,26 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import datetime +import itertools import logging import earthkit.data as ekd -import earthkit.regrid as ekr from earthkit.data.indexing.fieldlist import FieldArray from .base import RequestBasedInput +from .compute import make_z_from_gh +from .interpolate import Interpolate +from .transform import NewMetadataField LOG = logging.getLogger(__name__) -def _noop(x): - return x - - -class NewDataField: - def __init__(self, field, data, param): - self.field = field - self.data = data - self.param = param - - def to_numpy(self, flatten=False, dtype=None, index=None): - data = self.data - if dtype is not None: - data = data.astype(dtype) - if flatten: - data = data.flatten() - return data - - def metadata(self, key, *args, **kwargs): - if key == "param": - return self.param - return self.field.metadata(key, *args, **kwargs) - - def __getattr__(self, name): - return getattr(self.field, name) - - def __repr__(self) -> str: - return repr(self.field) - - -class Interpolate: - def __init__(self, grid, source): - self.grid = list(grid) if isinstance(grid, tuple) else grid - self.source = list(source) if isinstance(source, tuple) else source - - def __call__(self, ds): - result = [] - for f in ds: - data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid)) - result.append(NewDataField(f, data)) - return FieldArray(result) - - -def make_z_from_gh(previous): - g = 9.80665 # Same a pgen - - def _proc(ds): - - ds = previous(ds) - - result = [] - for f in ds: - if f.metadata("param") == "gh": - result.append(NewDataField(f, f.to_numpy() * g, param="z")) - else: - result.append(f) - return FieldArray(result) - - return _proc +CONSTANTS = ( + "z", + "sdor", + "slor", +) class OpenDataInput(RequestBasedInput): @@ -84,6 +34,7 @@ class OpenDataInput(RequestBasedInput): (0.25, 0.25): ("0p25", (0.25, 0.25), False), "N320": ("0p25", (0.25, 0.25), True), "O96": ("0p25", (0.25, 0.25), True), + # (0.1, 0.1): ("0p25", (0.25, 0.25), False), } def __init__(self, owner, **kwargs): @@ -94,6 +45,9 @@ def _adjust(self, kwargs): # OpenData uses levelist instead of level kwargs["levelist"] = kwargs.pop("level") + if "area" in kwargs: + kwargs.pop("area") + grid = kwargs.pop("grid") if isinstance(grid, list): grid = tuple(grid) @@ -103,18 +57,21 @@ def _adjust(self, kwargs): r.update(self.owner.retrieve) if interp: + logging.debug("Interpolating from %s to %s", source, grid) return Interpolate(grid, source) else: - return _noop + return lambda x: x def pl_load_source(self, **kwargs): pproc = self._adjust(kwargs) kwargs["levtype"] = "pl" + request = kwargs.copy() param = [p.lower() for p in kwargs["param"]] assert isinstance(param, (list, tuple)) if "z" in param: + logging.warning("Parameter 'z' on pressure levels is not available in open data, using 'gh' instead") param = list(param) param.remove("z") if "gh" not in param: @@ -123,16 +80,96 @@ def pl_load_source(self, **kwargs): pproc = make_z_from_gh(pproc) logging.debug("load source ecmwf-open-data %s", kwargs) - return pproc(ekd.from_source("ecmwf-open-data", **kwargs)) + return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request) def sfc_load_source(self, **kwargs): pproc = self._adjust(kwargs) kwargs["levtype"] = "sfc" + request = kwargs.copy() + + param = [p.lower() for p in kwargs["param"]] + assert isinstance(param, (list, tuple)) + + constant_params = [] + param = list(param) + for c in CONSTANTS: + if c in param: + param.remove(c) + constant_params.append(c) + + constants = ekd.from_source("empty") + + if constant_params: + if len(constant_params) == 1: + logging.warning( + f"Single level parameter '{constant_params[0]}' is not available in open data, using constants.grib2 instead" + ) + else: + logging.warning( + f"Single level parameters {constant_params} are not available in open data, using constants.grib2 instead" + ) + constants = [] + ds = ekd.from_source("file", "constants.grib2") + ds = ds.sel(param=constant_params) + + date = int(kwargs["date"]) + time = int(kwargs["time"]) + if time < 100: + time *= 100 + step = int(kwargs.get("step", 0)) + valid = datetime.datetime( + date // 10000, date // 100 % 100, date % 100, time // 100, time % 100 + ) + datetime.timedelta(hours=step) + + for f in ds: + + # assert False, (date, time, step) + constants.append( + NewMetadataField(f, valid_datetime=str(valid), date=date, time="%4d" % (time,), step=step) + ) + + constants = FieldArray(constants) + + kwargs["param"] = param + logging.debug("load source ecmwf-open-data %s", kwargs) - return pproc(ekd.from_source("ecmwf-open-data", **kwargs)) + + return self.check_sfc(pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants), request) def ml_load_source(self, **kwargs): pproc = self._adjust(kwargs) kwargs["levtype"] = "ml" + request = kwargs.copy() + logging.debug("load source ecmwf-open-data %s", kwargs) - return pproc(ekd.from_source("ecmwf-open-data", **kwargs)) + return self.check_ml(pproc(ekd.from_source("ecmwf-open-data", kwargs)), request) + + def check_pl(self, ds, request): + self._check(ds, "PL", request, "param", "levelist") + return ds + + def check_sfc(self, ds, request): + self._check(ds, "SFC", request, "param") + return ds + + def check_ml(self, ds, request): + self._check(ds, "ML", request, "param", "levelist") + return ds + + def _check(self, ds, what, request, *keys): + print("CHECKING", what) + expected = set() + for p in itertools.product(*[request[key] for key in keys]): + expected.add(p) + + found = set() + for f in ds: + found.add(tuple(f.metadata(key) for key in keys)) + + missing = expected - found + if missing: + raise ValueError(f"The following {what} parameters {missing} are not available in open data") + + extra = found - expected + if extra: + raise ValueError(f"Unexpected {what} parameters {extra} from open data") diff --git a/src/ai_models/inputs/transform.py b/src/ai_models/inputs/transform.py new file mode 100644 index 0000000..a11aa62 --- /dev/null +++ b/src/ai_models/inputs/transform.py @@ -0,0 +1,47 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +LOG = logging.getLogger(__name__) + + +class NewDataField: + def __init__(self, field, data): + self._field = field + self._data = data + + def to_numpy(self, flatten=False, dtype=None, index=None): + data = self._data + if dtype is not None: + data = data.astype(dtype) + if flatten: + data = data.flatten() + return data + + def __getattr__(self, name): + return getattr(self._field, name) + + def __repr__(self) -> str: + return repr(self._field) + + +class NewMetadataField: + def __init__(self, field, **kwargs): + self._field = field + self._metadata = kwargs + + def __getattr__(self, name): + return getattr(self._field, name) + + def __repr__(self) -> str: + return repr(self._field) + + def metadata(self, name, **kwargs): + if name in self._metadata: + return self._metadata[name] + return self._field.metadata(name, **kwargs) From 1351d586d2e1862645ac042b51803e6c94060bbd Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 14 Sep 2024 15:56:29 +0000 Subject: [PATCH 10/20] download constants --- src/ai_models/inputs/opendata.py | 44 ++++++++++++++++++++++++++----- src/ai_models/inputs/transform.py | 8 +++--- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index e0ec5f5..5de4ce3 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -8,9 +8,11 @@ import datetime import itertools import logging +import os import earthkit.data as ekd from earthkit.data.indexing.fieldlist import FieldArray +from multiurl import download from .base import RequestBasedInput from .compute import make_z_from_gh @@ -26,6 +28,8 @@ "slor", ) +CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants.grib2" + class OpenDataInput(RequestBasedInput): WHERE = "OPENDATA" @@ -102,14 +106,30 @@ def sfc_load_source(self, **kwargs): if constant_params: if len(constant_params) == 1: logging.warning( - f"Single level parameter '{constant_params[0]}' is not available in open data, using constants.grib2 instead" + f"Single level parameter '{constant_params[0]}' is" + " not available in open data, using constants.grib2 instead" ) else: logging.warning( - f"Single level parameters {constant_params} are not available in open data, using constants.grib2 instead" + f"Single level parameters {constant_params} are" + " not available in open data, using constants.grib2 instead" ) constants = [] - ds = ekd.from_source("file", "constants.grib2") + + cachedir = os.path.expanduser("~/.cache/ai-models") + basename = os.path.basename(CONSTANTS_URL) + + if not os.path.exists(cachedir): + os.makedirs(cachedir) + + path = os.path.join(cachedir, basename) + + if not os.path.exists(path): + logging.info("Downloading %s to %s", CONSTANTS_URL, path) + download(CONSTANTS_URL, path + ".tmp") + os.rename(path + ".tmp", path) + + ds = ekd.from_source("file", path) ds = ds.sel(param=constant_params) date = int(kwargs["date"]) @@ -125,7 +145,13 @@ def sfc_load_source(self, **kwargs): # assert False, (date, time, step) constants.append( - NewMetadataField(f, valid_datetime=str(valid), date=date, time="%4d" % (time,), step=step) + NewMetadataField( + f, + valid_datetime=str(valid), + date=date, + time="%4d" % (time,), + step=step, + ) ) constants = FieldArray(constants) @@ -134,7 +160,13 @@ def sfc_load_source(self, **kwargs): logging.debug("load source ecmwf-open-data %s", kwargs) - return self.check_sfc(pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants), request) + fields = pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants) + + # Fix grib2/eccodes bug + + fields = FieldArray([NewMetadataField(f, levelist=None) for f in fields]) + + return self.check_sfc(fields, request) def ml_load_source(self, **kwargs): pproc = self._adjust(kwargs) @@ -157,7 +189,7 @@ def check_ml(self, ds, request): return ds def _check(self, ds, what, request, *keys): - print("CHECKING", what) + expected = set() for p in itertools.product(*[request[key] for key in keys]): expected.add(p) diff --git a/src/ai_models/inputs/transform.py b/src/ai_models/inputs/transform.py index a11aa62..29ee983 100644 --- a/src/ai_models/inputs/transform.py +++ b/src/ai_models/inputs/transform.py @@ -41,7 +41,7 @@ def __getattr__(self, name): def __repr__(self) -> str: return repr(self._field) - def metadata(self, name, **kwargs): - if name in self._metadata: - return self._metadata[name] - return self._field.metadata(name, **kwargs) + def metadata(self, *args, **kwargs): + if len(args) == 1 and args[0] in self._metadata: + return self._metadata[args[0]] + return self._field.metadata(*args, **kwargs) From 5eab7110ea8eabff579131acf41088b689f4d033 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 14 Sep 2024 16:21:14 +0000 Subject: [PATCH 11/20] missing depencies --- pyproject.toml | 13 ++++++++----- src/ai_models/model.py | 1 - src/ai_models/remote/model.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0fb7f8e..4c3dc6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,15 +40,18 @@ classifiers = [ ] dependencies = [ - "entrypoints", - "requests", + "cdsapi", "earthkit-data>=0.10.1", - "eccodes>=2.37", - "multiurl", - "gputil", "earthkit-meteo", "earthkit-regrid", + "eccodes>=2.37", + "ecmwf-api-client", + "ecmwf-opendata", + "entrypoints", + "gputil", + "multiurl", "pyyaml", + "requests", "tqdm", ] diff --git a/src/ai_models/model.py b/src/ai_models/model.py index 20d8e16..eb75435 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -88,7 +88,6 @@ def __init__(self, input, output, download_assets, **kwargs): LOG.debug("Asset directory is %s", self.assets) try: - # For CliMetLab, when date=-1 self.date = int(self.date) except ValueError: pass diff --git a/src/ai_models/remote/model.py b/src/ai_models/remote/model.py index 629a820..85377b9 100644 --- a/src/ai_models/remote/model.py +++ b/src/ai_models/remote/model.py @@ -4,7 +4,7 @@ import tempfile from functools import cached_property -import climetlab as cml +import earthkit.data as ekd from ..model import Model from .api import RemoteAPI @@ -45,7 +45,7 @@ def run(self): self.api.run(self.cfg) - ds = cml.load_source("file", output_file) + ds = ekd.from_source("file", output_file) for field in ds: self.write(None, template=field) From 3068f2fec1a9909982082a5eb86d1780ade09210 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 14 Sep 2024 16:45:18 +0000 Subject: [PATCH 12/20] implement index in field --- src/ai_models/inputs/transform.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ai_models/inputs/transform.py b/src/ai_models/inputs/transform.py index 29ee983..f5c4815 100644 --- a/src/ai_models/inputs/transform.py +++ b/src/ai_models/inputs/transform.py @@ -21,6 +21,8 @@ def to_numpy(self, flatten=False, dtype=None, index=None): data = data.astype(dtype) if flatten: data = data.flatten() + if index is not None: + data = data[index] return data def __getattr__(self, name): From c9c7545b4ed874b22a85ad52bebbc82fb3e21c40 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 14 Sep 2024 16:53:02 +0000 Subject: [PATCH 13/20] tidy --- src/ai_models/inputs/transform.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/ai_models/inputs/transform.py b/src/ai_models/inputs/transform.py index f5c4815..086cde0 100644 --- a/src/ai_models/inputs/transform.py +++ b/src/ai_models/inputs/transform.py @@ -10,9 +10,20 @@ LOG = logging.getLogger(__name__) -class NewDataField: - def __init__(self, field, data): +class WrappedField: + def __init__(self, field): self._field = field + + def __getattr__(self, name): + return getattr(self._field, name) + + def __repr__(self) -> str: + return repr(self._field) + + +class NewDataField(WrappedField): + def __init__(self, field, data): + super().__init__(field) self._data = data def to_numpy(self, flatten=False, dtype=None, index=None): @@ -25,24 +36,12 @@ def to_numpy(self, flatten=False, dtype=None, index=None): data = data[index] return data - def __getattr__(self, name): - return getattr(self._field, name) - def __repr__(self) -> str: - return repr(self._field) - - -class NewMetadataField: +class NewMetadataField(WrappedField): def __init__(self, field, **kwargs): - self._field = field + super().__init__(field) self._metadata = kwargs - def __getattr__(self, name): - return getattr(self._field, name) - - def __repr__(self) -> str: - return repr(self._field) - def metadata(self, *args, **kwargs): if len(args) == 1 and args[0] in self._metadata: return self._metadata[args[0]] From e27e6e7f14091dc449994746d9ae1b05abc9e0da Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 15 Sep 2024 16:56:15 +0000 Subject: [PATCH 14/20] change constants url --- src/ai_models/inputs/opendata.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 5de4ce3..9c3a1d9 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -28,7 +28,7 @@ "slor", ) -CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants.grib2" +CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants-{resol}.grib2" class OpenDataInput(RequestBasedInput): @@ -62,12 +62,12 @@ def _adjust(self, kwargs): if interp: logging.debug("Interpolating from %s to %s", source, grid) - return Interpolate(grid, source) + return (Interpolate(grid, source), source) else: - return lambda x: x + return (lambda x: x, source) def pl_load_source(self, **kwargs): - pproc = self._adjust(kwargs) + pproc, _ = self._adjust(kwargs) kwargs["levtype"] = "pl" request = kwargs.copy() @@ -87,7 +87,7 @@ def pl_load_source(self, **kwargs): return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request) def sfc_load_source(self, **kwargs): - pproc = self._adjust(kwargs) + pproc, resol = self._adjust(kwargs) kwargs["levtype"] = "sfc" request = kwargs.copy() @@ -117,7 +117,8 @@ def sfc_load_source(self, **kwargs): constants = [] cachedir = os.path.expanduser("~/.cache/ai-models") - basename = os.path.basename(CONSTANTS_URL) + constant_url = CONSTANTS_URL.format(resol=resol) + basename = os.path.basename(constant_url) if not os.path.exists(cachedir): os.makedirs(cachedir) @@ -125,8 +126,8 @@ def sfc_load_source(self, **kwargs): path = os.path.join(cachedir, basename) if not os.path.exists(path): - logging.info("Downloading %s to %s", CONSTANTS_URL, path) - download(CONSTANTS_URL, path + ".tmp") + logging.info("Downloading %s to %s", constant_url, path) + download(constant_url, path + ".tmp") os.rename(path + ".tmp", path) ds = ekd.from_source("file", path) @@ -169,7 +170,7 @@ def sfc_load_source(self, **kwargs): return self.check_sfc(fields, request) def ml_load_source(self, **kwargs): - pproc = self._adjust(kwargs) + pproc, _ = self._adjust(kwargs) kwargs["levtype"] = "ml" request = kwargs.copy() From 3d63cb09975f3fcf1018ec9fa08a94c415941f82 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 15 Sep 2024 17:39:39 +0000 Subject: [PATCH 15/20] bug fix --- src/ai_models/inputs/opendata.py | 36 ++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 9c3a1d9..16634c0 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -62,12 +62,12 @@ def _adjust(self, kwargs): if interp: logging.debug("Interpolating from %s to %s", source, grid) - return (Interpolate(grid, source), source) + return Interpolate(grid, source) else: - return (lambda x: x, source) + return lambda x: x def pl_load_source(self, **kwargs): - pproc, _ = self._adjust(kwargs) + pproc = self._adjust(kwargs) kwargs["levtype"] = "pl" request = kwargs.copy() @@ -75,7 +75,7 @@ def pl_load_source(self, **kwargs): assert isinstance(param, (list, tuple)) if "z" in param: - logging.warning("Parameter 'z' on pressure levels is not available in open data, using 'gh' instead") + logging.warning("Parameter 'z' on pressure levels is not available in ECMWF open data, using 'gh' instead") param = list(param) param.remove("z") if "gh" not in param: @@ -87,7 +87,7 @@ def pl_load_source(self, **kwargs): return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request) def sfc_load_source(self, **kwargs): - pproc, resol = self._adjust(kwargs) + pproc = self._adjust(kwargs) kwargs["levtype"] = "sfc" request = kwargs.copy() @@ -107,18 +107,18 @@ def sfc_load_source(self, **kwargs): if len(constant_params) == 1: logging.warning( f"Single level parameter '{constant_params[0]}' is" - " not available in open data, using constants.grib2 instead" + " not available in ECMWF open data, using constants.grib2 instead" ) else: logging.warning( f"Single level parameters {constant_params} are" - " not available in open data, using constants.grib2 instead" + " not available in ECMWF open data, using constants.grib2 instead" ) constants = [] cachedir = os.path.expanduser("~/.cache/ai-models") - constant_url = CONSTANTS_URL.format(resol=resol) - basename = os.path.basename(constant_url) + constants_url = CONSTANTS_URL.format(resol=request["resol"]) + basename = os.path.basename(constants_url) if not os.path.exists(cachedir): os.makedirs(cachedir) @@ -126,8 +126,8 @@ def sfc_load_source(self, **kwargs): path = os.path.join(cachedir, basename) if not os.path.exists(path): - logging.info("Downloading %s to %s", constant_url, path) - download(constant_url, path + ".tmp") + logging.info("Downloading %s to %s", constants_url, path) + download(constants_url, path + ".tmp") os.rename(path + ".tmp", path) ds = ekd.from_source("file", path) @@ -191,6 +191,10 @@ def check_ml(self, ds, request): def _check(self, ds, what, request, *keys): + def _(p): + if len(p) == 1: + return p[0] + expected = set() for p in itertools.product(*[request[key] for key in keys]): expected.add(p) @@ -201,8 +205,14 @@ def _check(self, ds, what, request, *keys): missing = expected - found if missing: - raise ValueError(f"The following {what} parameters {missing} are not available in open data") + missing = [_(p) for p in missing] + if len(missing) == 1: + raise ValueError(f"The following {what} parameter '{missing[0]}' is not available in ECMWF open data") + raise ValueError(f"The following {what} parameters {missing} are not available in ECMWF open data") extra = found - expected if extra: - raise ValueError(f"Unexpected {what} parameters {extra} from open data") + extra = [_(p) for p in extra] + if len(extra) == 1: + raise ValueError(f"Unexpected {what} parameter '{extra[0]}' from ECMWF open data") + raise ValueError(f"Unexpected {what} parameters {extra} from ECMWF open data") From 126fc5c708e0169edc31b3e9da3dd3c3592e0eae Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 16 Sep 2024 10:16:51 +0000 Subject: [PATCH 16/20] typo --- pyproject.toml | 2 +- src/ai_models/inputs/opendata.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4c3dc6d..86f9f1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ version_file = "src/ai_models/_version.py" file = "ai_models.inputs.file:FileInput" mars = "ai_models.inputs.mars:MarsInput" cds = "ai_models.inputs.cds:CdsInput" -opendata = "ai_models.inputs.opendata:OpenDataInput" +ecmwf-open-data = "ai_models.inputs.opendata:OpenDataInput" [project.entry-points."ai_models.output"] file = "ai_models.outputs:FileOutput" diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 16634c0..935a0f8 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -170,7 +170,7 @@ def sfc_load_source(self, **kwargs): return self.check_sfc(fields, request) def ml_load_source(self, **kwargs): - pproc, _ = self._adjust(kwargs) + pproc = self._adjust(kwargs) kwargs["levtype"] = "ml" request = kwargs.copy() From e5ee8d35d1a17de24c778ce8a8acff97c836b290 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sun, 22 Sep 2024 07:28:23 +0000 Subject: [PATCH 17/20] Fix call to as_mars() --- pyproject.toml | 3 ++- src/ai_models/inputs/interpolate.py | 5 ++++- src/ai_models/inputs/opendata.py | 15 +++++++++------ src/ai_models/inputs/transform.py | 1 + src/ai_models/model.py | 6 ++++-- src/ai_models/outputs/__init__.py | 12 ++++++++++++ 6 files changed, 32 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 86f9f1b..921dd81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ classifiers = [ dependencies = [ "cdsapi", - "earthkit-data>=0.10.1", + "earthkit-data>=0.10.3", "earthkit-meteo", "earthkit-regrid", "eccodes>=2.37", @@ -72,6 +72,7 @@ file = "ai_models.inputs.file:FileInput" mars = "ai_models.inputs.mars:MarsInput" cds = "ai_models.inputs.cds:CdsInput" ecmwf-open-data = "ai_models.inputs.opendata:OpenDataInput" +opendata = "ai_models.inputs.opendata:OpenDataInput" [project.entry-points."ai_models.output"] file = "ai_models.outputs:FileOutput" diff --git a/src/ai_models/inputs/interpolate.py b/src/ai_models/inputs/interpolate.py index 744ed18..3fec238 100644 --- a/src/ai_models/inputs/interpolate.py +++ b/src/ai_models/inputs/interpolate.py @@ -8,6 +8,7 @@ import logging import earthkit.regrid as ekr +import tqdm from earthkit.data.indexing.fieldlist import FieldArray from .transform import NewDataField @@ -22,7 +23,9 @@ def __init__(self, grid, source): def __call__(self, ds): result = [] - for f in ds: + for f in tqdm.tqdm(ds, delay=0.5, desc="Interpolating", leave=False): data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid)) result.append(NewDataField(f, data)) + + LOG.info("Interpolated %d fields. Input shape %s, output shape %s.", len(result), ds[0].shape, result[0].shape) return FieldArray(result) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 935a0f8..61d4e7b 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -35,10 +35,10 @@ class OpenDataInput(RequestBasedInput): WHERE = "OPENDATA" RESOLS = { - (0.25, 0.25): ("0p25", (0.25, 0.25), False), - "N320": ("0p25", (0.25, 0.25), True), - "O96": ("0p25", (0.25, 0.25), True), - # (0.1, 0.1): ("0p25", (0.25, 0.25), False), + (0.25, 0.25): ("0p25", (0.25, 0.25), False, False), + "N320": ("0p25", (0.25, 0.25), True, False), + "O96": ("0p25", (0.25, 0.25), True, False), + (0.1, 0.1): ("0p25", (0.25, 0.25), True, True), } def __init__(self, owner, **kwargs): @@ -56,12 +56,15 @@ def _adjust(self, kwargs): if isinstance(grid, list): grid = tuple(grid) - kwargs["resol"], source, interp = self.RESOLS[grid] + kwargs["resol"], source, interp, oversampling = self.RESOLS[grid] r = dict(**kwargs) r.update(self.owner.retrieve) if interp: - logging.debug("Interpolating from %s to %s", source, grid) + + logging.info("Interpolating input data from %s to %s.", source, grid) + if oversampling: + logging.warning("This will oversample the input data.") return Interpolate(grid, source) else: return lambda x: x diff --git a/src/ai_models/inputs/transform.py b/src/ai_models/inputs/transform.py index 086cde0..2a10cce 100644 --- a/src/ai_models/inputs/transform.py +++ b/src/ai_models/inputs/transform.py @@ -25,6 +25,7 @@ class NewDataField(WrappedField): def __init__(self, field, data): super().__init__(field) self._data = data + self.shape = data.shape def to_numpy(self, flatten=False, dtype=None, index=None): data = self._data diff --git a/src/ai_models/model.py b/src/ai_models/model.py index eb75435..25c4249 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -127,7 +127,7 @@ def collect_archive_requests(self, written): # does not return always return recently set keys handle = handle.clone() - self.archiving[path].add(handle.as_mars()) + self.archiving[path].add(handle.as_namespace("mars")) def finalise(self): self.output.flush() @@ -536,6 +536,8 @@ def write_input_fields( accumulations_shape=None, ignore=None, ): + LOG.info("Starting date is %s", self.start_datetime) + LOG.info("Writing input fields") if ignore is None: ignore = [] @@ -553,7 +555,7 @@ def write_input_fields( if accumulations is not None: if accumulations_template is None: - accumulations_template = fields.sel(param="2t")[0] + accumulations_template = fields.sel(param="msl")[0] if accumulations_shape is None: accumulations_shape = accumulations_template.shape diff --git a/src/ai_models/outputs/__init__.py b/src/ai_models/outputs/__init__.py index 614fa64..2e9a616 100644 --- a/src/ai_models/outputs/__init__.py +++ b/src/ai_models/outputs/__init__.py @@ -67,6 +67,18 @@ def write(self, data, *args, check=False, **kwargs): raise ValueError(f"NaN values found in field. args={args} kwargs={kwargs}") if np.isinf(data).any(): raise ValueError(f"Infinite values found in field. args={args} kwargs={kwargs}") + + options = {} + options.update(self.grib_keys) + options.update(kwargs) + LOG.error("Failed to write data to %s %s", args, options) + cmd = [] + for k, v in options.items(): + if isinstance(v, (int, str, float)): + cmd.append("%s=%s" % (k, v)) + + LOG.error("grib_set -s%s", ",".join(cmd)) + raise if check: From cf4b994871240791d8bf6301c5cde9fbc07d7c71 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 23 Sep 2024 13:01:47 +0000 Subject: [PATCH 18/20] recenter fields --- src/ai_models/inputs/compute.py | 36 +++--- src/ai_models/inputs/interpolate.py | 24 ++-- src/ai_models/inputs/opendata.py | 181 ++++++++++++++++------------ src/ai_models/inputs/recenter.py | 92 ++++++++++++++ 4 files changed, 233 insertions(+), 100 deletions(-) create mode 100644 src/ai_models/inputs/recenter.py diff --git a/src/ai_models/inputs/compute.py b/src/ai_models/inputs/compute.py index 41ece36..bd656b9 100644 --- a/src/ai_models/inputs/compute.py +++ b/src/ai_models/inputs/compute.py @@ -7,27 +7,33 @@ import logging +import earthkit.data as ekd +import tqdm +from earthkit.data.core.temporary import temp_file from earthkit.data.indexing.fieldlist import FieldArray -from .transform import NewDataField -from .transform import NewMetadataField - LOG = logging.getLogger(__name__) +G = 9.80665 # Same a pgen + + +def make_z_from_gh(ds): + + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + other = [] -def make_z_from_gh(previous): - g = 9.80665 # Same a pgen + for f in tqdm.tqdm(ds, delay=0.5, desc="GH to Z", leave=False): - def _proc(ds): + if f.metadata("param") == "gh": + out.write(f.to_numpy() * G, template=f, param="z") + else: + other.append(f) - ds = previous(ds) + out.close() - result = [] - for f in ds: - if f.metadata("param") == "gh": - result.append(NewMetadataField(NewDataField(f, f.to_numpy() * g), param="z")) - else: - result.append(f) - return FieldArray(result) + result = FieldArray(other) + ekd.from_source("file", tmp.path) + result._tmp = tmp - return _proc + return result diff --git a/src/ai_models/inputs/interpolate.py b/src/ai_models/inputs/interpolate.py index 3fec238..c5e3145 100644 --- a/src/ai_models/inputs/interpolate.py +++ b/src/ai_models/inputs/interpolate.py @@ -7,25 +7,35 @@ import logging +import earthkit.data as ekd import earthkit.regrid as ekr import tqdm -from earthkit.data.indexing.fieldlist import FieldArray - -from .transform import NewDataField +from earthkit.data.core.temporary import temp_file LOG = logging.getLogger(__name__) class Interpolate: - def __init__(self, grid, source): + def __init__(self, grid, source, metadata): self.grid = list(grid) if isinstance(grid, tuple) else grid self.source = list(source) if isinstance(source, tuple) else source + self.metadata = metadata def __call__(self, ds): + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + result = [] for f in tqdm.tqdm(ds, delay=0.5, desc="Interpolating", leave=False): data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid)) - result.append(NewDataField(f, data)) + out.write(data, template=f, **self.metadata) + + out.close() + + result = ekd.from_source("file", tmp.path) + result._tmp = tmp + + print("Interpolated data", tmp.path) - LOG.info("Interpolated %d fields. Input shape %s, output shape %s.", len(result), ds[0].shape, result[0].shape) - return FieldArray(result) + return result diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 61d4e7b..5ea9dbb 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -5,23 +5,23 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import datetime import itertools import logging import os import earthkit.data as ekd +from earthkit.data.core.temporary import temp_file from earthkit.data.indexing.fieldlist import FieldArray from multiurl import download from .base import RequestBasedInput from .compute import make_z_from_gh from .interpolate import Interpolate +from .recenter import recenter from .transform import NewMetadataField LOG = logging.getLogger(__name__) - CONSTANTS = ( "z", "sdor", @@ -30,17 +30,33 @@ CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants-{resol}.grib2" +RESOLS = { + (0.25, 0.25): ("0p25", (0.25, 0.25), False, False, {}), + (0.1, 0.1): ( + "0p25", + (0.25, 0.25), + True, + True, + dict( + longitudeOfLastGridPointInDegrees=359.9, + iDirectionIncrementInDegrees=0.1, + jDirectionIncrementInDegrees=0.1, + Ni=3600, + Nj=1801, + ), + ), + # "N320": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg')), + # "O96": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg', )), +} + + +def _identity(x): + return x + class OpenDataInput(RequestBasedInput): WHERE = "OPENDATA" - RESOLS = { - (0.25, 0.25): ("0p25", (0.25, 0.25), False, False), - "N320": ("0p25", (0.25, 0.25), True, False), - "O96": ("0p25", (0.25, 0.25), True, False), - (0.1, 0.1): ("0p25", (0.25, 0.25), True, True), - } - def __init__(self, owner, **kwargs): self.owner = owner @@ -56,7 +72,7 @@ def _adjust(self, kwargs): if isinstance(grid, list): grid = tuple(grid) - kwargs["resol"], source, interp, oversampling = self.RESOLS[grid] + kwargs["resol"], source, interp, oversampling, metadata = RESOLS[grid] r = dict(**kwargs) r.update(self.owner.retrieve) @@ -65,12 +81,15 @@ def _adjust(self, kwargs): logging.info("Interpolating input data from %s to %s.", source, grid) if oversampling: logging.warning("This will oversample the input data.") - return Interpolate(grid, source) + return Interpolate(grid, source, metadata) else: - return lambda x: x + return _identity def pl_load_source(self, **kwargs): - pproc = self._adjust(kwargs) + + gh_to_z = _identity + interpolate = self._adjust(kwargs) + kwargs["levtype"] = "pl" request = kwargs.copy() @@ -84,13 +103,68 @@ def pl_load_source(self, **kwargs): if "gh" not in param: param.append("gh") kwargs["param"] = param - pproc = make_z_from_gh(pproc) + gh_to_z = make_z_from_gh logging.debug("load source ecmwf-open-data %s", kwargs) - return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request) + + opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) + opendata = gh_to_z(opendata) + opendata = interpolate(opendata) + + return self.check_pl(opendata, request) + + def constants(self, constant_params, request, kwargs): + if len(constant_params) == 1: + logging.warning( + f"Single level parameter '{constant_params[0]}' is" + " not available in ECMWF open data, using constants.grib2 instead" + ) + else: + logging.warning( + f"Single level parameters {constant_params} are" + " not available in ECMWF open data, using constants.grib2 instead" + ) + + cachedir = os.path.expanduser("~/.cache/ai-models") + constants_url = CONSTANTS_URL.format(resol=request["resol"]) + basename = os.path.basename(constants_url) + + if not os.path.exists(cachedir): + os.makedirs(cachedir) + + path = os.path.join(cachedir, basename) + + if not os.path.exists(path): + logging.info("Downloading %s to %s", constants_url, path) + download(constants_url, path + ".tmp") + os.rename(path + ".tmp", path) + + ds = ekd.from_source("file", path) + ds = ds.sel(param=constant_params) + + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + + for f in ds: + out.write( + f.to_numpy(), + template=f, + date=kwargs["date"], + time=kwargs["time"], + step=kwargs.get("step", 0), + ) + + out.close() + + result = ekd.from_source("file", tmp.path) + result._tmp = tmp + + return result def sfc_load_source(self, **kwargs): - pproc = self._adjust(kwargs) + interpolate = self._adjust(kwargs) + kwargs["levtype"] = "sfc" request = kwargs.copy() @@ -104,81 +178,32 @@ def sfc_load_source(self, **kwargs): param.remove(c) constant_params.append(c) - constants = ekd.from_source("empty") - if constant_params: - if len(constant_params) == 1: - logging.warning( - f"Single level parameter '{constant_params[0]}' is" - " not available in ECMWF open data, using constants.grib2 instead" - ) - else: - logging.warning( - f"Single level parameters {constant_params} are" - " not available in ECMWF open data, using constants.grib2 instead" - ) - constants = [] - - cachedir = os.path.expanduser("~/.cache/ai-models") - constants_url = CONSTANTS_URL.format(resol=request["resol"]) - basename = os.path.basename(constants_url) - - if not os.path.exists(cachedir): - os.makedirs(cachedir) - - path = os.path.join(cachedir, basename) - - if not os.path.exists(path): - logging.info("Downloading %s to %s", constants_url, path) - download(constants_url, path + ".tmp") - os.rename(path + ".tmp", path) - - ds = ekd.from_source("file", path) - ds = ds.sel(param=constant_params) - - date = int(kwargs["date"]) - time = int(kwargs["time"]) - if time < 100: - time *= 100 - step = int(kwargs.get("step", 0)) - valid = datetime.datetime( - date // 10000, date // 100 % 100, date % 100, time // 100, time % 100 - ) + datetime.timedelta(hours=step) - - for f in ds: - - # assert False, (date, time, step) - constants.append( - NewMetadataField( - f, - valid_datetime=str(valid), - date=date, - time="%4d" % (time,), - step=step, - ) - ) - - constants = FieldArray(constants) + constants = self.constants(constant_params, request, kwargs) + else: + constants = ekd.from_source("empty") kwargs["param"] = param - logging.debug("load source ecmwf-open-data %s", kwargs) - - fields = pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants) + opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) + opendata = opendata + constants + opendata = interpolate(opendata) # Fix grib2/eccodes bug - fields = FieldArray([NewMetadataField(f, levelist=None) for f in fields]) + opendata = FieldArray([NewMetadataField(f, levelist=None) for f in opendata]) - return self.check_sfc(fields, request) + return self.check_sfc(opendata, request) def ml_load_source(self, **kwargs): - pproc = self._adjust(kwargs) + interpolate = self._adjust(kwargs) kwargs["levtype"] = "ml" request = kwargs.copy() - logging.debug("load source ecmwf-open-data %s", kwargs) - return self.check_ml(pproc(ekd.from_source("ecmwf-open-data", kwargs)), request) + opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) + opendata = interpolate(opendata) + + return self.check_ml(opendata, request) def check_pl(self, ds, request): self._check(ds, "PL", request, "param", "levelist") diff --git a/src/ai_models/inputs/recenter.py b/src/ai_models/inputs/recenter.py new file mode 100644 index 0000000..33bba3e --- /dev/null +++ b/src/ai_models/inputs/recenter.py @@ -0,0 +1,92 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd +import numpy as np +import tqdm +from earthkit.data.core.temporary import temp_file + +LOG = logging.getLogger(__name__) + +CHECKED = set() + + +def _init_recenter(ds, f): + + # For now, we only support the 0.25x0.25 grid from OPENDATA (centered on the greenwich meridian) + + latitudeOfFirstGridPointInDegrees = f.metadata("latitudeOfFirstGridPointInDegrees") + longitudeOfFirstGridPointInDegrees = f.metadata("longitudeOfFirstGridPointInDegrees") + latitudeOfLastGridPointInDegrees = f.metadata("latitudeOfLastGridPointInDegrees") + longitudeOfLastGridPointInDegrees = f.metadata("longitudeOfLastGridPointInDegrees") + iDirectionIncrementInDegrees = f.metadata("iDirectionIncrementInDegrees") + jDirectionIncrementInDegrees = f.metadata("jDirectionIncrementInDegrees") + scanningMode = f.metadata("scanningMode") + Ni = f.metadata("Ni") + Nj = f.metadata("Nj") + + assert scanningMode == 0 + assert latitudeOfFirstGridPointInDegrees == 90 + assert longitudeOfFirstGridPointInDegrees == 180 + assert latitudeOfLastGridPointInDegrees == -90 + assert longitudeOfLastGridPointInDegrees == 179.75 + assert iDirectionIncrementInDegrees == 0.25 + assert jDirectionIncrementInDegrees == 0.25 + + assert Ni == 1440 + assert Nj == 721 + + shape = (Nj, Ni) + roll = -Ni // 2 + axis = 1 + + key = ( + latitudeOfFirstGridPointInDegrees, + longitudeOfFirstGridPointInDegrees, + latitudeOfLastGridPointInDegrees, + longitudeOfLastGridPointInDegrees, + iDirectionIncrementInDegrees, + jDirectionIncrementInDegrees, + Ni, + Nj, + ) + + ############################ + + if key not in CHECKED: + lon = ekd.from_source("forcings", ds, param=["longitude"], date=f.metadata("date"))[0] + assert np.all(np.roll(lon.to_numpy(), roll, axis=axis)[:, 0] == 0) + CHECKED.add(key) + + return (shape, roll, axis, dict(longitudeOfFirstGridPointInDegrees=0, longitudeOfLastGridPointInDegrees=359.75)) + + +def recenter(ds): + + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + + for f in tqdm.tqdm(ds, delay=0.5, desc="Recentering", leave=False): + + shape, roll, axis, metadata = _init_recenter(ds, f) + + data = f.to_numpy() + assert data.shape == shape, (data.shape, shape) + + data = np.roll(data, roll, axis=axis) + + out.write(data, template=f, **metadata) + + out.close() + + result = ekd.from_source("file", tmp.path) + result._tmp = tmp + + return result From 33cae5bd54812a820600227990f9b754d05e947f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 23 Sep 2024 19:25:46 +0000 Subject: [PATCH 19/20] better grib2 support --- src/ai_models/inputs/opendata.py | 5 ++- src/ai_models/model.py | 58 +++++++++++++++++++++++++------- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 5ea9dbb..14813c5 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -61,6 +61,9 @@ def __init__(self, owner, **kwargs): self.owner = owner def _adjust(self, kwargs): + + kwargs.setdefault("step", 0) + if "level" in kwargs: # OpenData uses levelist instead of level kwargs["levelist"] = kwargs.pop("level") @@ -105,7 +108,7 @@ def pl_load_source(self, **kwargs): kwargs["param"] = param gh_to_z = make_z_from_gh - logging.debug("load source ecmwf-open-data %s", kwargs) + logging.info("ecmwf-open-data %s", kwargs) opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) opendata = gh_to_z(opendata) diff --git a/src/ai_models/model.py b/src/ai_models/model.py index 25c4249..6d98d8b 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -5,6 +5,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import base64 import datetime import json import logging @@ -541,6 +542,8 @@ def write_input_fields( if ignore is None: ignore = [] + fields.save("input.grib") + with self.timer("Writing step 0"): for field in fields: if field.metadata("shortName") in ignore: @@ -560,18 +563,49 @@ def write_input_fields( if accumulations_shape is None: accumulations_shape = accumulations_template.shape - for param in accumulations: - self.write( - np.zeros(accumulations_shape, dtype=np.float32), - stepType="accum", - template=accumulations_template, - param=param, - startStep=0, - endStep=0, - date=int(self.start_datetime.strftime("%Y%m%d")), - time=int(self.start_datetime.strftime("%H%M")), - check=True, - ) + if accumulations_template.metadata("edition") == 1: + for param in accumulations: + + self.write( + np.zeros(accumulations_shape, dtype=np.float32), + stepType="accum", + template=accumulations_template, + param=param, + startStep=0, + endStep=0, + date=int(self.start_datetime.strftime("%Y%m%d")), + time=int(self.start_datetime.strftime("%H%M")), + check=True, + ) + else: + # # TODO: Remove this when accumulations are supported for GRIB edition 2 + + template = """ + R1JJQv//AAIAAAAAAAAA3AAAABUBAGIAABsBAQfoCRYGAAAAAQAAABECAAEAAQAJBAIwMDAxAAAA + SAMAAA/XoAAAAAAG////////////////////AAAFoAAAAtEAAAAA/////wVdSoAAAAAAMIVdSoAV + cVlwAAPQkAAD0JAAAAAAOgQAAAAIAcEC//8AAAABAAAAAAH//////////////wfoCRYGAAABAAAA + AAECAQAAAAD/AAAAAAAAABUFAA/XoAAAAAAAAIAKAAAAAAAAAAYG/wAAAAUHNzc3N0dSSUL//wAC + AAAAAAAAANwAAAAVAQBiAAAbAQEH6AkWDAAAAAEAAAARAgABAAEACQQBMDAwMQAAAEgDAAAP16AA + AAAABv///////////////////wAABaAAAALRAAAAAP////8FXUqAAAAAADCFXUqAFXFZcAAD0JAA + A9CQAAAAADoEAAAACAHBAv//AAAAAQAAAAAB//////////////8H6AkWDAAAAQAAAAABAgEAAAAA + /wAAAAAAAAAVBQAP16AAAAAAAACACgAAAAAAAAAGBv8AAAAFBzc3Nzc= + """ + + template = base64.b64decode(template) + accumulations_template = ekd.from_source("memory", template)[0] + + for param in accumulations: + self.write( + np.zeros(accumulations_shape, dtype=np.float32), + stepType="accum", + template=accumulations_template, + param=param, + startStep=0, + endStep=0, + date=int(self.start_datetime.strftime("%Y%m%d")), + time=int(self.start_datetime.strftime("%H%M")), + check=True, + ) def load_model(name, **kwargs): From 889598db95d0e1eb971d8be8a07bbc2c2408dd3c Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 24 Sep 2024 06:52:28 +0000 Subject: [PATCH 20/20] update urls --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 921dd81..9cb5f54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,9 +57,9 @@ dependencies = [ [project.urls] -Homepage = "https://github.com/ecmwf/ai-models/" -Repository = "https://github.com/ecmwf/ai-models/" -Issues = "https://github.com/ecmwf/ai-models/issues" +Homepage = "https://github.com/ecmwf-lab/ai-models/" +Repository = "https://github.com/ecmwf-lab/ai-models/" +Issues = "https://github.com/ecmwf-lab/ai-models/issues" [project.scripts] ai-models = "ai_models.__main__:main"