Skip to content

Commit

Permalink
add constants
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 14, 2024
1 parent 084a2e7 commit baba519
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 64 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dependencies = [
"multiurl",
"gputil",
"earthkit-meteo",
"earthkit-regrid",
"pyyaml",
"tqdm",
]
Expand Down
33 changes: 33 additions & 0 deletions src/ai_models/inputs/compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

from earthkit.data.indexing.fieldlist import FieldArray

from .transform import NewDataField
from .transform import NewMetadataField

LOG = logging.getLogger(__name__)


def make_z_from_gh(previous):
    """Chain a ``gh`` -> ``z`` conversion after the post-processor ``previous``.

    The returned callable applies ``previous`` to the dataset it is given and
    then walks the resulting fields. A field whose ``param`` metadata equals
    ``"gh"`` is replaced by a wrapper whose values are the original values
    multiplied by ``9.80665`` and whose ``param`` metadata reads ``"z"``.
    Every other field is passed through as is. The fields are returned as a
    ``FieldArray``.
    """
    gravity = 9.80665  # Same as pgen

    def _to_z(field):
        # Only 'gh' fields are rewritten; anything else is left alone.
        if field.metadata("param") != "gh":
            return field
        scaled = NewDataField(field, field.to_numpy() * gravity)
        return NewMetadataField(scaled, param="z")

    def _proc(ds):
        return FieldArray([_to_z(field) for field in previous(ds)])

    return _proc
28 changes: 28 additions & 0 deletions src/ai_models/inputs/interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

import earthkit.regrid as ekr
from earthkit.data.indexing.fieldlist import FieldArray

from .transform import NewDataField

LOG = logging.getLogger(__name__)


class Interpolate:
    """Callable that regrids every field of a dataset with ``earthkit.regrid``.

    ``source`` is passed to ``ekr.interpolate`` as the input grid and ``grid``
    as the output grid. Tuples are stored as lists; any other value is kept
    exactly as given.
    """

    def __init__(self, grid, source):
        self.grid = self._listify(grid)
        self.source = self._listify(source)

    @staticmethod
    def _listify(spec):
        # Only tuples are converted; named grids and lists are stored untouched.
        return list(spec) if isinstance(spec, tuple) else spec

    def __call__(self, ds):
        """Return a ``FieldArray`` holding the interpolated version of each field in ``ds``."""
        return FieldArray(
            [
                NewDataField(
                    field,
                    ekr.interpolate(field.to_numpy(), {"grid": self.source}, {"grid": self.grid}),
                )
                for field in ds
            ]
        )
165 changes: 101 additions & 64 deletions src/ai_models/inputs/opendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,76 +5,26 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import datetime
import itertools
import logging

import earthkit.data as ekd
import earthkit.regrid as ekr
from earthkit.data.indexing.fieldlist import FieldArray

from .base import RequestBasedInput
from .compute import make_z_from_gh
from .interpolate import Interpolate
from .transform import NewMetadataField

LOG = logging.getLogger(__name__)


def _noop(x):
return x


class NewDataField:
    """Wrap a field, replacing its data values and optionally its ``param``.

    Any attribute not defined here is looked up on the wrapped field.

    Args:
        field: the field being wrapped.
        data: the values returned by :meth:`to_numpy` instead of the field's own.
        param: value reported for the ``param`` metadata key. When omitted
            (``None``), the wrapped field's own ``param`` is reported, so the
            wrapper can be built with just ``(field, data)``.
    """

    def __init__(self, field, data, param=None):
        self.field = field
        self.data = data
        self.param = param

    def to_numpy(self, flatten=False, dtype=None, index=None):
        """Return the replacement data, cast to ``dtype`` and/or flattened on request.

        ``index`` is accepted but not used by this wrapper.
        """
        data = self.data
        if dtype is not None:
            data = data.astype(dtype)
        if flatten:
            data = data.flatten()
        return data

    def metadata(self, key, *args, **kwargs):
        """Return the overridden ``param`` if one was given, else defer to the wrapped field."""
        if key == "param" and self.param is not None:
            return self.param
        return self.field.metadata(key, *args, **kwargs)

    def __getattr__(self, name):
        # __getattr__ is only reached when normal lookup fails. Reading
        # ``self.field`` here would call __getattr__ again when ``field`` is
        # not set yet (e.g. on the bare instance created by copy/pickle), and
        # recurse without end. Go through __dict__ and fail cleanly instead.
        try:
            field = self.__dict__["field"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(field, name)

    def __repr__(self) -> str:
        return repr(self.field)


class Interpolate:
    """Post-processor that interpolates each field from ``source`` to ``grid``.

    Grid descriptions supplied as tuples are stored as lists; other values are
    stored unchanged.
    """

    def __init__(self, grid, source):
        if isinstance(grid, tuple):
            grid = list(grid)
        if isinstance(source, tuple):
            source = list(source)
        self.grid = grid
        self.source = source

    def __call__(self, ds):
        """Interpolate every field of ``ds`` and return the results as a ``FieldArray``."""

        def _regrid(field):
            values = ekr.interpolate(field.to_numpy(), dict(grid=self.source), dict(grid=self.grid))
            return NewDataField(field, values)

        return FieldArray([_regrid(field) for field in ds])


def make_z_from_gh(previous):
    """Return a post-processor that runs ``previous`` and then turns ``gh`` into ``z``.

    Each field whose ``param`` metadata is ``"gh"`` is wrapped in a
    ``NewDataField`` carrying the values multiplied by ``9.80665`` and
    reporting ``param="z"``. Other fields are forwarded unchanged. The output
    is a ``FieldArray``.
    """
    g = 9.80665  # Same as pgen

    def _proc(ds):
        converted = []
        for field in previous(ds):
            is_gh = field.metadata("param") == "gh"
            converted.append(NewDataField(field, field.to_numpy() * g, param="z") if is_gh else field)
        return FieldArray(converted)

    return _proc
# Single-level parameters that sfc_load_source reads from the local
# constants.grib2 file instead of requesting them from open data.
CONSTANTS = ("z", "sdor", "slor")


class OpenDataInput(RequestBasedInput):
Expand All @@ -84,6 +34,7 @@ class OpenDataInput(RequestBasedInput):
(0.25, 0.25): ("0p25", (0.25, 0.25), False),
"N320": ("0p25", (0.25, 0.25), True),
"O96": ("0p25", (0.25, 0.25), True),
# (0.1, 0.1): ("0p25", (0.25, 0.25), False),
}

def __init__(self, owner, **kwargs):
Expand All @@ -94,6 +45,9 @@ def _adjust(self, kwargs):
# OpenData uses levelist instead of level
kwargs["levelist"] = kwargs.pop("level")

if "area" in kwargs:
kwargs.pop("area")

grid = kwargs.pop("grid")
if isinstance(grid, list):
grid = tuple(grid)
Expand All @@ -103,18 +57,21 @@ def _adjust(self, kwargs):
r.update(self.owner.retrieve)

if interp:
logging.debug("Interpolating from %s to %s", source, grid)
return Interpolate(grid, source)
else:
return _noop
return lambda x: x

def pl_load_source(self, **kwargs):
pproc = self._adjust(kwargs)
kwargs["levtype"] = "pl"
request = kwargs.copy()

param = [p.lower() for p in kwargs["param"]]
assert isinstance(param, (list, tuple))

if "z" in param:
logging.warning("Parameter 'z' on pressure levels is not available in open data, using 'gh' instead")
param = list(param)
param.remove("z")
if "gh" not in param:
Expand All @@ -123,16 +80,96 @@ def pl_load_source(self, **kwargs):
pproc = make_z_from_gh(pproc)

logging.debug("load source ecmwf-open-data %s", kwargs)
return pproc(ekd.from_source("ecmwf-open-data", **kwargs))
return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request)

def sfc_load_source(self, **kwargs):
    """Load single-level fields, taking constant fields from ``constants.grib2``.

    Parameters listed in ``CONSTANTS`` are removed from the open-data request
    and read from the file ``constants.grib2`` (path used as given) instead.
    Those fields are wrapped so that their ``valid_datetime``, ``date``,
    ``time`` and ``step`` metadata match the request. The open-data fields
    and the constants are concatenated, post-processed, and validated against
    the request as it was originally made.

    Args:
        **kwargs: the retrieval request. ``param``, ``date`` and ``time`` are
            read here; ``step`` defaults to 0. The remaining keys are
            consumed by ``_adjust`` or forwarded to the data source.

    Returns:
        The post-processed dataset, after ``check_sfc`` has accepted it.

    Raises:
        ValueError: from ``check_sfc`` when the fields obtained do not match
            the request.
    """
    pproc = self._adjust(kwargs)
    kwargs["levtype"] = "sfc"
    # Snapshot taken before constants are stripped: validation must cover
    # everything the caller asked for.
    request = kwargs.copy()

    param = [p.lower() for p in kwargs["param"]]

    # Move the constant parameters out of the open-data request.
    constant_params = []
    for c in CONSTANTS:
        if c in param:
            param.remove(c)
            constant_params.append(c)

    constants = ekd.from_source("empty")

    if constant_params:
        if len(constant_params) == 1:
            logging.warning(
                "Single level parameter '%s' is not available in open data, using constants.grib2 instead",
                constant_params[0],
            )
        else:
            logging.warning(
                "Single level parameters %s are not available in open data, using constants.grib2 instead",
                constant_params,
            )

        ds = ekd.from_source("file", "constants.grib2").sel(param=constant_params)

        date = int(kwargs["date"])
        time = int(kwargs["time"])
        if time < 100:
            # Presumably the time was given in hours (e.g. 6 or 12): make it HHMM.
            time *= 100
        step = int(kwargs.get("step", 0))
        valid = datetime.datetime(
            date // 10000, date // 100 % 100, date % 100, time // 100, time % 100
        ) + datetime.timedelta(hours=step)

        # Zero-pad the time to four digits ("0000", "0600"); "%4d" would pad
        # with spaces instead.
        constants = FieldArray(
            [
                NewMetadataField(f, valid_datetime=str(valid), date=date, time="%04d" % (time,), step=step)
                for f in ds
            ]
        )

    kwargs["param"] = param

    logging.debug("load source ecmwf-open-data %s", kwargs)

    # Single exit point: the constants must be merged in and the result checked.
    return self.check_sfc(pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants), request)

def ml_load_source(self, **kwargs):
    """Load ``levtype="ml"`` fields from open data and validate them.

    Args:
        **kwargs: the retrieval request; adjusted in place by ``_adjust`` and
            then forwarded to the ``ecmwf-open-data`` source.

    Returns:
        The post-processed dataset, after ``check_ml`` has accepted it.

    Raises:
        ValueError: from ``check_ml`` when the fields obtained do not match
            the request.
    """
    pproc = self._adjust(kwargs)
    kwargs["levtype"] = "ml"
    request = kwargs.copy()

    logging.debug("load source ecmwf-open-data %s", kwargs)
    # The request must be expanded into keyword arguments, as in the other
    # loaders; passing the dict positionally hands the source a single
    # unnamed argument.
    return self.check_ml(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request)

def check_pl(self, ds, request):
    """Validate pressure-level fields against ``request`` and return ``ds`` unchanged.

    The comparison is made on the ``param`` and ``levelist`` keys; ``_check``
    raises ``ValueError`` on a mismatch.
    """
    keys = ("param", "levelist")
    self._check(ds, "PL", request, *keys)
    return ds

def check_sfc(self, ds, request):
    """Validate single-level fields against ``request`` and return ``ds`` unchanged.

    Only the ``param`` key is compared; ``_check`` raises ``ValueError`` on a
    mismatch.
    """
    keys = ("param",)
    self._check(ds, "SFC", request, *keys)
    return ds

def check_ml(self, ds, request):
    """Validate ``levtype="ml"`` fields against ``request`` and return ``ds`` unchanged.

    The comparison is made on the ``param`` and ``levelist`` keys; ``_check``
    raises ``ValueError`` on a mismatch.
    """
    keys = ("param", "levelist")
    self._check(ds, "ML", request, *keys)
    return ds

def _check(self, ds, what, request, *keys):
print("CHECKING", what)
expected = set()
for p in itertools.product(*[request[key] for key in keys]):
expected.add(p)

found = set()
for f in ds:
found.add(tuple(f.metadata(key) for key in keys))

missing = expected - found
if missing:
raise ValueError(f"The following {what} parameters {missing} are not available in open data")

extra = found - expected
if extra:
raise ValueError(f"Unexpected {what} parameters {extra} from open data")
47 changes: 47 additions & 0 deletions src/ai_models/inputs/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

LOG = logging.getLogger(__name__)


class NewDataField:
    """Wrap a field, replacing the values returned by ``to_numpy``.

    Any attribute not defined here (metadata access included) is looked up on
    the wrapped field.

    Args:
        field: the field being wrapped.
        data: the values to expose in place of the field's own data.
    """

    def __init__(self, field, data):
        self._field = field
        self._data = data

    def to_numpy(self, flatten=False, dtype=None, index=None):
        """Return the replacement data, cast to ``dtype`` and/or flattened on request.

        ``index`` is accepted but not used by this wrapper.
        """
        data = self._data
        if dtype is not None:
            data = data.astype(dtype)
        if flatten:
            data = data.flatten()
        return data

    def __getattr__(self, name):
        # __getattr__ is only reached when normal lookup fails. Reading
        # ``self._field`` here would call __getattr__ again when ``_field`` is
        # not set yet (e.g. on the bare instance created by copy/pickle), and
        # recurse without end. Go through __dict__ and fail cleanly instead.
        try:
            field = self.__dict__["_field"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(field, name)

    def __repr__(self) -> str:
        return repr(self._field)


class NewMetadataField:
    """Wrap a field, overriding selected metadata keys.

    Any attribute not defined here (data access included) is looked up on the
    wrapped field.

    Args:
        field: the field being wrapped.
        **kwargs: metadata overrides, as ``key=value`` pairs.
    """

    def __init__(self, field, **kwargs):
        self._field = field
        self._metadata = kwargs

    def __getattr__(self, name):
        # __getattr__ is only reached when normal lookup fails. Reading
        # ``self._field`` here would call __getattr__ again when ``_field`` is
        # not set yet (e.g. on the bare instance created by copy/pickle), and
        # recurse without end. Go through __dict__ and fail cleanly instead.
        try:
            field = self.__dict__["_field"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(field, name)

    def __repr__(self) -> str:
        return repr(self._field)

    def metadata(self, name, **kwargs):
        """Return the override for ``name`` if there is one, else ask the wrapped field.

        ``kwargs`` are only forwarded when the wrapped field is consulted.
        """
        if name in self._metadata:
            return self._metadata[name]
        return self._field.metadata(name, **kwargs)

0 comments on commit baba519

Please sign in to comment.