diff --git a/.gitignore b/.gitignore index a629b88..dc6992c 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,4 @@ bar dev/ *.out _version.py +*.tar diff --git a/pyproject.toml b/pyproject.toml index 6974028..9cb5f54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,22 +40,26 @@ classifiers = [ ] dependencies = [ + "cdsapi", + "earthkit-data>=0.10.3", + "earthkit-meteo", + "earthkit-regrid", + "eccodes>=2.37", + "ecmwf-api-client", + "ecmwf-opendata", "entrypoints", - "requests", - "climetlab>=0.23.0", - "multiurl", - "ecmwflibs>=0.6.1", "gputil", - "earthkit-meteo", + "multiurl", "pyyaml", + "requests", "tqdm", ] [project.urls] -Homepage = "https://github.com/ecmwf/ai-models/" -Repository = "https://github.com/ecmwf/ai-models/" -Issues = "https://github.com/ecmwf/ai-models/issues" +Homepage = "https://github.com/ecmwf-lab/ai-models/" +Repository = "https://github.com/ecmwf-lab/ai-models/" +Issues = "https://github.com/ecmwf-lab/ai-models/issues" [project.scripts] ai-models = "ai_models.__main__:main" @@ -64,10 +68,11 @@ ai-models = "ai_models.__main__:main" version_file = "src/ai_models/_version.py" [project.entry-points."ai_models.input"] -file = "ai_models.inputs:FileInput" -mars = "ai_models.inputs:MarsInput" -cds = "ai_models.inputs:CdsInput" -opendata = "ai_models.inputs:OpenDataInput" +file = "ai_models.inputs.file:FileInput" +mars = "ai_models.inputs.mars:MarsInput" +cds = "ai_models.inputs.cds:CdsInput" +ecmwf-open-data = "ai_models.inputs.opendata:OpenDataInput" +opendata = "ai_models.inputs.opendata:OpenDataInput" [project.entry-points."ai_models.output"] file = "ai_models.outputs:FileOutput" diff --git a/src/ai_models/__main__.py b/src/ai_models/__main__.py index 57a1a6b..37b8ec4 100644 --- a/src/ai_models/__main__.py +++ b/src/ai_models/__main__.py @@ -11,12 +11,16 @@ import shlex import sys +import earthkit.data as ekd + from .inputs import available_inputs from .model import Timer from .model import available_models from .model import load_model from .outputs import available_outputs +ekd.settings.set("cache-policy", "user") + LOG = logging.getLogger(__name__) diff --git a/src/ai_models/inputs/__init__.py b/src/ai_models/inputs/__init__.py index 04e1d40..005b5d6 100644 --- a/src/ai_models/inputs/__init__.py +++ b/src/ai_models/inputs/__init__.py @@ -8,198 +8,14 @@ import logging from functools import cached_property -import climetlab as cml +import earthkit.data as ekd +import earthkit.regrid as ekr import entrypoints +from earthkit.data.indexing.fieldlist import FieldArray LOG = logging.getLogger(__name__) -class RequestBasedInput: - def __init__(self, owner, **kwargs): - self.owner = owner - - def _patch(self, **kargs): - r = dict(**kargs) - self.owner.patch_retrieve_request(r) - return r - - @cached_property - def fields_sfc(self): - param = self.owner.param_sfc - if not param: - return cml.load_source("empty") - - LOG.info(f"Loading surface fields from {self.WHERE}") - return cml.load_source( - "multi", - [ - self.sfc_load_source( - **self._patch( - date=date, - time=time, - param=param, - grid=self.owner.grid, - area=self.owner.area, - **self.owner.retrieve, - ) - ) - for date, time in self.owner.datetimes() - ], - ) - - @cached_property - def fields_pl(self): - param, level = self.owner.param_level_pl - if not (param and level): - return cml.load_source("empty") - - LOG.info(f"Loading pressure fields from {self.WHERE}") - return cml.load_source( - "multi", - [ - self.pl_load_source( - **self._patch( - date=date, - time=time, - param=param, - level=level, - grid=self.owner.grid, - area=self.owner.area, - ) - ) - for date, time in self.owner.datetimes() - ], - ) - - @cached_property - def fields_ml(self): - param, level = self.owner.param_level_ml - if not (param and level): - return cml.load_source("empty") - - LOG.info(f"Loading model fields from {self.WHERE}") - return cml.load_source( - "multi", - [ - self.ml_load_source( - **self._patch( - date=date, - time=time, - param=param, - level=level, - grid=self.owner.grid, - area=self.owner.area, - ) - ) - for date, time in self.owner.datetimes() - ], - ) - - @cached_property - def all_fields(self): - return self.fields_sfc + self.fields_pl + self.fields_ml - - -class MarsInput(RequestBasedInput): - WHERE = "MARS" - - def __init__(self, owner, **kwargs): - self.owner = owner - - def pl_load_source(self, **kwargs): - kwargs["levtype"] = "pl" - logging.debug("load source mars %s", kwargs) - return cml.load_source("mars", kwargs) - - def sfc_load_source(self, **kwargs): - kwargs["levtype"] = "sfc" - logging.debug("load source mars %s", kwargs) - return cml.load_source("mars", kwargs) - - def ml_load_source(self, **kwargs): - kwargs["levtype"] = "ml" - logging.debug("load source mars %s", kwargs) - return cml.load_source("mars", kwargs) - - -class CdsInput(RequestBasedInput): - WHERE = "CDS" - - def pl_load_source(self, **kwargs): - kwargs["product_type"] = "reanalysis" - return cml.load_source("cds", "reanalysis-era5-pressure-levels", kwargs) - - def sfc_load_source(self, **kwargs): - kwargs["product_type"] = "reanalysis" - return cml.load_source("cds", "reanalysis-era5-single-levels", kwargs) - - def ml_load_source(self, **kwargs): - raise NotImplementedError("CDS does not support model levels") - - -class OpenDataInput(RequestBasedInput): - WHERE = "OPENDATA" - - RESOLS = {(0.25, 0.25): "0p25"} - - def __init__(self, owner, **kwargs): - self.owner = owner - - def _adjust(self, kwargs): - if "level" in kwargs: - # OpenData uses levelist instead of level - kwargs["levelist"] = kwargs.pop("level") - - grid = kwargs.pop("grid") - if isinstance(grid, list): - grid = tuple(grid) - - kwargs["resol"] = self.RESOLS[grid] - r = dict(**kwargs) - r.update(self.owner.retrieve) - return r - - def pl_load_source(self, **kwargs): - self._adjust(kwargs) - kwargs["levtype"] = "pl" - logging.debug("load source ecmwf-open-data %s", kwargs) - return cml.load_source("ecmwf-open-data", **kwargs) - - def sfc_load_source(self, **kwargs): - self._adjust(kwargs) - kwargs["levtype"] = "sfc" - logging.debug("load source ecmwf-open-data %s", kwargs) - return cml.load_source("ecmwf-open-data", **kwargs) - - def ml_load_source(self, **kwargs): - self._adjust(kwargs) - kwargs["levtype"] = "ml" - logging.debug("load source ecmwf-open-data %s", kwargs) - return cml.load_source("ecmwf-open-data", **kwargs) - - -class FileInput: - def __init__(self, owner, file, **kwargs): - self.file = file - self.owner = owner - - @cached_property - def fields_sfc(self): - return self.all_fields.sel(levtype="sfc") - - @cached_property - def fields_pl(self): - return self.all_fields.sel(levtype="pl") - - @cached_property - def fields_ml(self): - return self.all_fields.sel(levtype="ml") - - @cached_property - def all_fields(self): - return cml.load_source("file", self.file) - - def get_input(name, *args, **kwargs): return available_inputs()[name].load()(*args, **kwargs) diff --git a/src/ai_models/inputs/base.py b/src/ai_models/inputs/base.py new file mode 100644 index 0000000..acdadb4 --- /dev/null +++ b/src/ai_models/inputs/base.py @@ -0,0 +1,100 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from functools import cached_property + +import earthkit.data as ekd + +LOG = logging.getLogger(__name__) + + +class RequestBasedInput: + def __init__(self, owner, **kwargs): + self.owner = owner + + def _patch(self, **kargs): + r = dict(**kargs) + self.owner.patch_retrieve_request(r) + return r + + @cached_property + def fields_sfc(self): + param = self.owner.param_sfc + if not param: + return ekd.from_source("empty") + + LOG.info(f"Loading surface fields from {self.WHERE}") + + return ekd.from_source( + "multi", + [ + self.sfc_load_source( + **self._patch( + date=date, + time=time, + param=param, + grid=self.owner.grid, + area=self.owner.area, + **self.owner.retrieve, + ) + ) + for date, time in self.owner.datetimes() + ], + ) + + @cached_property + def fields_pl(self): + param, level = self.owner.param_level_pl + if not (param and level): + return ekd.from_source("empty") + + LOG.info(f"Loading pressure fields from {self.WHERE}") + return ekd.from_source( + "multi", + [ + self.pl_load_source( + **self._patch( + date=date, + time=time, + param=param, + level=level, + grid=self.owner.grid, + area=self.owner.area, + ) + ) + for date, time in self.owner.datetimes() + ], + ) + + @cached_property + def fields_ml(self): + param, level = self.owner.param_level_ml + if not (param and level): + return ekd.from_source("empty") + + LOG.info(f"Loading model fields from {self.WHERE}") + return ekd.from_source( + "multi", + [ + self.ml_load_source( + **self._patch( + date=date, + time=time, + param=param, + level=level, + grid=self.owner.grid, + area=self.owner.area, + ) + ) + for date, time in self.owner.datetimes() + ], + ) + + @cached_property + def all_fields(self): + return self.fields_sfc + self.fields_pl + self.fields_ml diff --git a/src/ai_models/inputs/cds.py b/src/ai_models/inputs/cds.py new file mode 100644 index 0000000..9595c46 --- /dev/null +++ b/src/ai_models/inputs/cds.py @@ -0,0 +1,29 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd + +from .base import RequestBasedInput + +LOG = logging.getLogger(__name__) + + +class CdsInput(RequestBasedInput): + WHERE = "CDS" + + def pl_load_source(self, **kwargs): + kwargs["product_type"] = "reanalysis" + return ekd.from_source("cds", "reanalysis-era5-pressure-levels", kwargs) + + def sfc_load_source(self, **kwargs): + kwargs["product_type"] = "reanalysis" + return ekd.from_source("cds", "reanalysis-era5-single-levels", kwargs) + + def ml_load_source(self, **kwargs): + raise NotImplementedError("CDS does not support model levels") diff --git a/src/ai_models/inputs/compute.py b/src/ai_models/inputs/compute.py new file mode 100644 index 0000000..bd656b9 --- /dev/null +++ b/src/ai_models/inputs/compute.py @@ -0,0 +1,39 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd +import tqdm +from earthkit.data.core.temporary import temp_file +from earthkit.data.indexing.fieldlist import FieldArray + +LOG = logging.getLogger(__name__) + +G = 9.80665 # Same a pgen + + +def make_z_from_gh(ds): + + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + other = [] + + for f in tqdm.tqdm(ds, delay=0.5, desc="GH to Z", leave=False): + + if f.metadata("param") == "gh": + out.write(f.to_numpy() * G, template=f, param="z") + else: + other.append(f) + + out.close() + + result = FieldArray(other) + ekd.from_source("file", tmp.path) + result._tmp = tmp + + return result diff --git a/src/ai_models/inputs/file.py b/src/ai_models/inputs/file.py new file mode 100644 index 0000000..fba571b --- /dev/null +++ b/src/ai_models/inputs/file.py @@ -0,0 +1,47 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from functools import cached_property + +import earthkit.data as ekd +import entrypoints + +LOG = logging.getLogger(__name__) + + +class FileInput: + def __init__(self, owner, file, **kwargs): + self.file = file + self.owner = owner + + @cached_property + def fields_sfc(self): + return self.all_fields.sel(levtype="sfc") + + @cached_property + def fields_pl(self): + return self.all_fields.sel(levtype="pl") + + @cached_property + def fields_ml(self): + return self.all_fields.sel(levtype="ml") + + @cached_property + def all_fields(self): + return ekd.from_source("file", self.file) + + +def get_input(name, *args, **kwargs): + return available_inputs()[name].load()(*args, **kwargs) + + +def available_inputs(): + result = {} + for e in entrypoints.get_group_all("ai_models.input"): + result[e.name] = e + return result diff --git a/src/ai_models/inputs/interpolate.py b/src/ai_models/inputs/interpolate.py new file mode 100644 index 0000000..c5e3145 --- /dev/null +++ b/src/ai_models/inputs/interpolate.py @@ -0,0 +1,41 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd +import earthkit.regrid as ekr +import tqdm +from earthkit.data.core.temporary import temp_file + +LOG = logging.getLogger(__name__) + + +class Interpolate: + def __init__(self, grid, source, metadata): + self.grid = list(grid) if isinstance(grid, tuple) else grid + self.source = list(source) if isinstance(source, tuple) else source + self.metadata = metadata + + def __call__(self, ds): + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + + result = [] + for f in tqdm.tqdm(ds, delay=0.5, desc="Interpolating", leave=False): + data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid)) + out.write(data, template=f, **self.metadata) + + out.close() + + result = ekd.from_source("file", tmp.path) + result._tmp = tmp + + print("Interpolated data", tmp.path) + + return result diff --git a/src/ai_models/inputs/mars.py b/src/ai_models/inputs/mars.py new file mode 100644 index 0000000..820a3eb --- /dev/null +++ b/src/ai_models/inputs/mars.py @@ -0,0 +1,36 @@ +# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd + +from .base import RequestBasedInput + +LOG = logging.getLogger(__name__) + + +class MarsInput(RequestBasedInput): + WHERE = "MARS" + + def __init__(self, owner, **kwargs): + self.owner = owner + + def pl_load_source(self, **kwargs): + kwargs["levtype"] = "pl" + logging.debug("load source mars %s", kwargs) + return ekd.from_source("mars", kwargs) + + def sfc_load_source(self, **kwargs): + kwargs["levtype"] = "sfc" + logging.debug("load source mars %s", kwargs) + return ekd.from_source("mars", kwargs) + + def ml_load_source(self, **kwargs): + kwargs["levtype"] = "ml" + logging.debug("load source mars %s", kwargs) + return ekd.from_source("mars", kwargs) diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py new file mode 100644 index 0000000..14813c5 --- /dev/null +++ b/src/ai_models/inputs/opendata.py @@ -0,0 +1,249 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import itertools +import logging +import os + +import earthkit.data as ekd +from earthkit.data.core.temporary import temp_file +from earthkit.data.indexing.fieldlist import FieldArray +from multiurl import download + +from .base import RequestBasedInput +from .compute import make_z_from_gh +from .interpolate import Interpolate +from .recenter import recenter +from .transform import NewMetadataField + +LOG = logging.getLogger(__name__) + +CONSTANTS = ( + "z", + "sdor", + "slor", +) + +CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants-{resol}.grib2" + +RESOLS = { + (0.25, 0.25): ("0p25", (0.25, 0.25), False, False, {}), + (0.1, 0.1): ( + "0p25", + (0.25, 0.25), + True, + True, + dict( + longitudeOfLastGridPointInDegrees=359.9, + iDirectionIncrementInDegrees=0.1, + jDirectionIncrementInDegrees=0.1, + Ni=3600, + Nj=1801, + ), + ), + # "N320": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg')), + # "O96": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg', )), +} + + +def _identity(x): + return x + + +class OpenDataInput(RequestBasedInput): + WHERE = "OPENDATA" + + def __init__(self, owner, **kwargs): + self.owner = owner + + def _adjust(self, kwargs): + + kwargs.setdefault("step", 0) + + if "level" in kwargs: + # OpenData uses levelist instead of level + kwargs["levelist"] = kwargs.pop("level") + + if "area" in kwargs: + kwargs.pop("area") + + grid = kwargs.pop("grid") + if isinstance(grid, list): + grid = tuple(grid) + + kwargs["resol"], source, interp, oversampling, metadata = RESOLS[grid] + r = dict(**kwargs) + r.update(self.owner.retrieve) + + if interp: + + logging.info("Interpolating input data from %s to %s.", source, grid) + if oversampling: + logging.warning("This will oversample the input data.") + return Interpolate(grid, source, metadata) + else: + return _identity + + def pl_load_source(self, **kwargs): + + gh_to_z = _identity + interpolate = self._adjust(kwargs) + + kwargs["levtype"] = "pl" + request = kwargs.copy() + + param = [p.lower() for p in kwargs["param"]] + assert isinstance(param, (list, tuple)) + + if "z" in param: + logging.warning("Parameter 'z' on pressure levels is not available in ECMWF open data, using 'gh' instead") + param = list(param) + param.remove("z") + if "gh" not in param: + param.append("gh") + kwargs["param"] = param + gh_to_z = make_z_from_gh + + logging.info("ecmwf-open-data %s", kwargs) + + opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) + opendata = gh_to_z(opendata) + opendata = interpolate(opendata) + + return self.check_pl(opendata, request) + + def constants(self, constant_params, request, kwargs): + if len(constant_params) == 1: + logging.warning( + f"Single level parameter '{constant_params[0]}' is" + " not available in ECMWF open data, using constants.grib2 instead" + ) + else: + logging.warning( + f"Single level parameters {constant_params} are" + " not available in ECMWF open data, using constants.grib2 instead" + ) + + cachedir = os.path.expanduser("~/.cache/ai-models") + constants_url = CONSTANTS_URL.format(resol=request["resol"]) + basename = os.path.basename(constants_url) + + if not os.path.exists(cachedir): + os.makedirs(cachedir) + + path = os.path.join(cachedir, basename) + + if not os.path.exists(path): + logging.info("Downloading %s to %s", constants_url, path) + download(constants_url, path + ".tmp") + os.rename(path + ".tmp", path) + + ds = ekd.from_source("file", path) + ds = ds.sel(param=constant_params) + + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + + for f in ds: + out.write( + f.to_numpy(), + template=f, + date=kwargs["date"], + time=kwargs["time"], + step=kwargs.get("step", 0), + ) + + out.close() + + result = ekd.from_source("file", tmp.path) + result._tmp = tmp + + return result + + def sfc_load_source(self, **kwargs): + interpolate = self._adjust(kwargs) + + kwargs["levtype"] = "sfc" + request = kwargs.copy() + + param = [p.lower() for p in kwargs["param"]] + assert isinstance(param, (list, tuple)) + + constant_params = [] + param = list(param) + for c in CONSTANTS: + if c in param: + param.remove(c) + constant_params.append(c) + + if constant_params: + constants = self.constants(constant_params, request, kwargs) + else: + constants = ekd.from_source("empty") + + kwargs["param"] = param + + opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) + opendata = opendata + constants + opendata = interpolate(opendata) + + # Fix grib2/eccodes bug + + opendata = FieldArray([NewMetadataField(f, levelist=None) for f in opendata]) + + return self.check_sfc(opendata, request) + + def ml_load_source(self, **kwargs): + interpolate = self._adjust(kwargs) + kwargs["levtype"] = "ml" + request = kwargs.copy() + + opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) + opendata = interpolate(opendata) + + return self.check_ml(opendata, request) + + def check_pl(self, ds, request): + self._check(ds, "PL", request, "param", "levelist") + return ds + + def check_sfc(self, ds, request): + self._check(ds, "SFC", request, "param") + return ds + + def check_ml(self, ds, request): + self._check(ds, "ML", request, "param", "levelist") + return ds + + def _check(self, ds, what, request, *keys): + + def _(p): + if len(p) == 1: + return p[0] + + expected = set() + for p in itertools.product(*[request[key] for key in keys]): + expected.add(p) + + found = set() + for f in ds: + found.add(tuple(f.metadata(key) for key in keys)) + + missing = expected - found + if missing: + missing = [_(p) for p in missing] + if len(missing) == 1: + raise ValueError(f"The following {what} parameter '{missing[0]}' is not available in ECMWF open data") + raise ValueError(f"The following {what} parameters {missing} are not available in ECMWF open data") + + extra = found - expected + if extra: + extra = [_(p) for p in extra] + if len(extra) == 1: + raise ValueError(f"Unexpected {what} parameter '{extra[0]}' from ECMWF open data") + raise ValueError(f"Unexpected {what} parameters {extra} from ECMWF open data") diff --git a/src/ai_models/inputs/recenter.py b/src/ai_models/inputs/recenter.py new file mode 100644 index 0000000..33bba3e --- /dev/null +++ b/src/ai_models/inputs/recenter.py @@ -0,0 +1,92 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd +import numpy as np +import tqdm +from earthkit.data.core.temporary import temp_file + +LOG = logging.getLogger(__name__) + +CHECKED = set() + + +def _init_recenter(ds, f): + + # For now, we only support the 0.25x0.25 grid from OPENDATA (centered on the greenwich meridian) + + latitudeOfFirstGridPointInDegrees = f.metadata("latitudeOfFirstGridPointInDegrees") + longitudeOfFirstGridPointInDegrees = f.metadata("longitudeOfFirstGridPointInDegrees") + latitudeOfLastGridPointInDegrees = f.metadata("latitudeOfLastGridPointInDegrees") + longitudeOfLastGridPointInDegrees = f.metadata("longitudeOfLastGridPointInDegrees") + iDirectionIncrementInDegrees = f.metadata("iDirectionIncrementInDegrees") + jDirectionIncrementInDegrees = f.metadata("jDirectionIncrementInDegrees") + scanningMode = f.metadata("scanningMode") + Ni = f.metadata("Ni") + Nj = f.metadata("Nj") + + assert scanningMode == 0 + assert latitudeOfFirstGridPointInDegrees == 90 + assert longitudeOfFirstGridPointInDegrees == 180 + assert latitudeOfLastGridPointInDegrees == -90 + assert longitudeOfLastGridPointInDegrees == 179.75 + assert iDirectionIncrementInDegrees == 0.25 + assert jDirectionIncrementInDegrees == 0.25 + + assert Ni == 1440 + assert Nj == 721 + + shape = (Nj, Ni) + roll = -Ni // 2 + axis = 1 + + key = ( + latitudeOfFirstGridPointInDegrees, + longitudeOfFirstGridPointInDegrees, + latitudeOfLastGridPointInDegrees, + longitudeOfLastGridPointInDegrees, + iDirectionIncrementInDegrees, + jDirectionIncrementInDegrees, + Ni, + Nj, + ) + + ############################ + + if key not in CHECKED: + lon = ekd.from_source("forcings", ds, param=["longitude"], date=f.metadata("date"))[0] + assert np.all(np.roll(lon.to_numpy(), roll, axis=axis)[:, 0] == 0) + CHECKED.add(key) + + return (shape, roll, axis, dict(longitudeOfFirstGridPointInDegrees=0, longitudeOfLastGridPointInDegrees=359.75)) + + +def recenter(ds): + + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + + for f in tqdm.tqdm(ds, delay=0.5, desc="Recentering", leave=False): + + shape, roll, axis, metadata = _init_recenter(ds, f) + + data = f.to_numpy() + assert data.shape == shape, (data.shape, shape) + + data = np.roll(data, roll, axis=axis) + + out.write(data, template=f, **metadata) + + out.close() + + result = ekd.from_source("file", tmp.path) + result._tmp = tmp + + return result diff --git a/src/ai_models/inputs/transform.py b/src/ai_models/inputs/transform.py new file mode 100644 index 0000000..2a10cce --- /dev/null +++ b/src/ai_models/inputs/transform.py @@ -0,0 +1,49 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +LOG = logging.getLogger(__name__) + + +class WrappedField: + def __init__(self, field): + self._field = field + + def __getattr__(self, name): + return getattr(self._field, name) + + def __repr__(self) -> str: + return repr(self._field) + + +class NewDataField(WrappedField): + def __init__(self, field, data): + super().__init__(field) + self._data = data + self.shape = data.shape + + def to_numpy(self, flatten=False, dtype=None, index=None): + data = self._data + if dtype is not None: + data = data.astype(dtype) + if flatten: + data = data.flatten() + if index is not None: + data = data[index] + return data + + +class NewMetadataField(WrappedField): + def __init__(self, field, **kwargs): + super().__init__(field) + self._metadata = kwargs + + def metadata(self, *args, **kwargs): + if len(args) == 1 and args[0] in self._metadata: + return self._metadata[args[0]] + return self._field.metadata(*args, **kwargs) diff --git a/src/ai_models/model.py b/src/ai_models/model.py index a35e70e..6d98d8b 100644 --- a/src/ai_models/model.py +++ b/src/ai_models/model.py @@ -5,6 +5,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import base64 import datetime import json import logging @@ -14,10 +15,10 @@ from collections import defaultdict from functools import cached_property -import climetlab as cml +import earthkit.data as ekd import entrypoints import numpy as np -from climetlab.utils.humanize import seconds +from earthkit.data.utils.humanize import seconds from multiurl import download from .checkpoint import peek @@ -88,7 +89,6 @@ def __init__(self, input, output, download_assets, **kwargs): LOG.debug("Asset directory is %s", self.assets) try: - # For CliMetLab, when date=-1 self.date = int(self.date) except ValueError: pass @@ -128,7 +128,7 @@ def collect_archive_requests(self, written): # does not return always return recently set keys handle = handle.clone() - self.archiving[path].add(handle.as_mars()) + self.archiving[path].add(handle.as_namespace("mars")) def finalise(self): self.output.flush() @@ -154,13 +154,13 @@ def finalise(self): def json_default(obj): if isinstance(obj, set): if len(obj) > 1: - return list(obj) + return sorted(list(obj)) else: return obj.pop() - return obj + raise TypeError print( - json.dumps(json_requests, separators=(",", ":"), default=json_default), + json.dumps(json_requests, separators=(",", ":"), default=json_default, sort_keys=True), file=f, ) @@ -375,31 +375,66 @@ def _requests_unfiltered(self): param, level = self.param_level_pl r = dict( - levtype="pl", - levelist=level, - param=param, date=date, time=time, ) r.update(first) first = {} - r.update(self._requests_extra) + if param and level: - self.patch_retrieve_request(r) + # PL + r.update( + dict( + levtype="pl", + levelist=level, + param=param, + date=date, + time=time, + ) + ) + + r.update(self._requests_extra) + + self.patch_retrieve_request(r) - result.append(dict(**r)) + result.append(dict(**r)) - r.update( - dict( - levtype="sfc", - param=self.param_sfc, + # ML + param, level = self.param_level_ml + + if param and level: + r.update( + dict( + levtype="ml", + levelist=level, + param=param, + date=date, + time=time, + ) ) - ) - r.pop("levelist", None) - self.patch_retrieve_request(r) - result.append(dict(**r)) + r.update(self._requests_extra) + + self.patch_retrieve_request(r) + + result.append(dict(**r)) + + param = self.param_sfc + if param: + # SFC + r.update( + dict( + levtype="sfc", + param=self.param_sfc, + date=date, + time=time, + levelist="off", + ) + ) + + self.patch_retrieve_request(r) + result.append(dict(**r)) return result @@ -471,8 +506,8 @@ def provenance(self): def forcing_and_constants(self, date, param): source = self.all_fields[:1] - ds = cml.load_source( - "constants", + ds = ekd.from_source( + "forcings", source, date=date, param=param, @@ -488,7 +523,7 @@ def gridpoints(self): @cached_property def start_datetime(self): - return self.all_fields.order_by(valid_datetime="ascending")[-1].datetime() + return self.all_fields.order_by(valid_datetime="ascending")[-1].datetime()["valid_time"] @property def constant_fields(self): @@ -502,15 +537,19 @@ def write_input_fields( accumulations_shape=None, ignore=None, ): + LOG.info("Starting date is %s", self.start_datetime) + LOG.info("Writing input fields") if ignore is None: ignore = [] + fields.save("input.grib") + with self.timer("Writing step 0"): for field in fields: if field.metadata("shortName") in ignore: continue - if field.valid_datetime() == self.start_datetime: + if field.datetime()["valid_time"] == self.start_datetime: self.write( None, template=field, @@ -519,23 +558,54 @@ def write_input_fields( if accumulations is not None: if accumulations_template is None: - accumulations_template = fields.sel(param="2t")[0] + accumulations_template = fields.sel(param="msl")[0] if accumulations_shape is None: accumulations_shape = accumulations_template.shape - for param in accumulations: - self.write( - np.zeros(accumulations_shape, dtype=np.float32), - stepType="accum", - template=accumulations_template, - param=param, - startStep=0, - endStep=0, - date=int(self.start_datetime.strftime("%Y%m%d")), - time=int(self.start_datetime.strftime("%H%M")), - check=True, - ) + if accumulations_template.metadata("edition") == 1: + for param in accumulations: + + self.write( + np.zeros(accumulations_shape, dtype=np.float32), + stepType="accum", + template=accumulations_template, + param=param, + startStep=0, + endStep=0, + date=int(self.start_datetime.strftime("%Y%m%d")), + time=int(self.start_datetime.strftime("%H%M")), + check=True, + ) + else: + # # TODO: Remove this when accumulations are supported for GRIB edition 2 + + template = """ + R1JJQv//AAIAAAAAAAAA3AAAABUBAGIAABsBAQfoCRYGAAAAAQAAABECAAEAAQAJBAIwMDAxAAAA + SAMAAA/XoAAAAAAG////////////////////AAAFoAAAAtEAAAAA/////wVdSoAAAAAAMIVdSoAV + cVlwAAPQkAAD0JAAAAAAOgQAAAAIAcEC//8AAAABAAAAAAH//////////////wfoCRYGAAABAAAA + AAECAQAAAAD/AAAAAAAAABUFAA/XoAAAAAAAAIAKAAAAAAAAAAYG/wAAAAUHNzc3N0dSSUL//wAC + AAAAAAAAANwAAAAVAQBiAAAbAQEH6AkWDAAAAAEAAAARAgABAAEACQQBMDAwMQAAAEgDAAAP16AA + AAAABv///////////////////wAABaAAAALRAAAAAP////8FXUqAAAAAADCFXUqAFXFZcAAD0JAA + A9CQAAAAADoEAAAACAHBAv//AAAAAQAAAAAB//////////////8H6AkWDAAAAQAAAAABAgEAAAAA + /wAAAAAAAAAVBQAP16AAAAAAAACACgAAAAAAAAAGBv8AAAAFBzc3Nzc= + """ + + template = base64.b64decode(template) + accumulations_template = ekd.from_source("memory", template)[0] + + for param in accumulations: + self.write( + np.zeros(accumulations_shape, dtype=np.float32), + stepType="accum", + template=accumulations_template, + param=param, + startStep=0, + endStep=0, + date=int(self.start_datetime.strftime("%Y%m%d")), + time=int(self.start_datetime.strftime("%H%M")), + check=True, + ) def load_model(name, **kwargs): diff --git a/src/ai_models/outputs/__init__.py b/src/ai_models/outputs/__init__.py index 0e395bc..2e9a616 100644 --- a/src/ai_models/outputs/__init__.py +++ b/src/ai_models/outputs/__init__.py @@ -10,7 +10,7 @@ import warnings from functools import cached_property -import climetlab as cml +import earthkit.data as ekd import entrypoints import numpy as np @@ -50,7 +50,7 @@ def grib_keys(self): @cached_property def output(self): - return cml.new_grib_output( + return ekd.new_grib_output( self.path, split_output=True, **self.grib_keys, @@ -67,6 +67,18 @@ def write(self, data, *args, check=False, **kwargs): raise ValueError(f"NaN values found in field. args={args} kwargs={kwargs}") if np.isinf(data).any(): raise ValueError(f"Infinite values found in field. args={args} kwargs={kwargs}") + + options = {} + options.update(self.grib_keys) + options.update(kwargs) + LOG.error("Failed to write data to %s %s", args, options) + cmd = [] + for k, v in options.items(): + if isinstance(v, (int, str, float)): + cmd.append("%s=%s" % (k, v)) + + LOG.error("grib_set -s%s", ",".join(cmd)) + raise if check: diff --git a/src/ai_models/remote/model.py b/src/ai_models/remote/model.py index 629a820..85377b9 100644 --- a/src/ai_models/remote/model.py +++ b/src/ai_models/remote/model.py @@ -4,7 +4,7 @@ import tempfile from functools import cached_property -import climetlab as cml +import earthkit.data as ekd from ..model import Model from .api import RemoteAPI @@ -45,7 +45,7 @@ def run(self): self.api.run(self.cfg) - ds = cml.load_source("file", output_file) + ds = ekd.from_source("file", output_file) for field in ds: self.write(None, template=field) diff --git a/src/ai_models/stepper.py b/src/ai_models/stepper.py index 5a0c545..c0f8c42 100644 --- a/src/ai_models/stepper.py +++ b/src/ai_models/stepper.py @@ -8,7 +8,7 @@ import logging import time -from climetlab.utils.humanize import seconds +from earthkit.data.utils.humanize import seconds LOG = logging.getLogger(__name__)