Skip to content

Commit

Permalink
refactor input
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 13, 2024
1 parent e6b6e9a commit 084a2e7
Show file tree
Hide file tree
Showing 8 changed files with 357 additions and 191 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,4 @@ bar
dev/
*.out
_version.py
*.tar
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ ai-models = "ai_models.__main__:main"
version_file = "src/ai_models/_version.py"

[project.entry-points."ai_models.input"]
file = "ai_models.inputs:FileInput"
mars = "ai_models.inputs:MarsInput"
cds = "ai_models.inputs:CdsInput"
opendata = "ai_models.inputs:OpenDataInput"
file = "ai_models.inputs.file:FileInput"
mars = "ai_models.inputs.mars:MarsInput"
cds = "ai_models.inputs.cds:CdsInput"
opendata = "ai_models.inputs.opendata:OpenDataInput"

[project.entry-points."ai_models.output"]
file = "ai_models.outputs:FileOutput"
Expand Down
189 changes: 2 additions & 187 deletions src/ai_models/inputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,198 +9,13 @@
from functools import cached_property

import earthkit.data as ekd
import earthkit.regrid as ekr
import entrypoints
from earthkit.data.indexing.fieldlist import FieldArray

LOG = logging.getLogger(__name__)


class RequestBasedInput:
    """Base class for inputs that obtain their fields by issuing requests.

    The three ``fields_*`` properties build one request per ``(date, time)``
    pair yielded by ``owner.datetimes()`` and combine the results with the
    earthkit-data ``multi`` source.

    This class does not define ``WHERE`` (the name used in log messages) nor
    the ``sfc_load_source``, ``pl_load_source`` and ``ml_load_source``
    methods it calls; subclasses must provide them.
    """

    def __init__(self, owner, **kwargs):
        """Remember ``owner``; extra keyword arguments are accepted and ignored."""
        self.owner = owner

    def _patch(self, **kargs):
        """Copy the request and let the owner amend the copy.

        ``owner.patch_retrieve_request`` receives the copy and is expected to
        modify it in place: its return value is ignored.
        """
        r = dict(**kargs)
        self.owner.patch_retrieve_request(r)
        return r

    def _load_fields(self, label, loader, selection, with_retrieve=False):
        """Load one source per owner date/time and combine them.

        Args:
            label: Word inserted in the log message ("surface", "pressure", ...).
            loader: Callable receiving the patched request as keyword arguments.
            selection: Request keys specific to the level type (``param`` and,
                where relevant, ``level``).
            with_retrieve: When true, ``owner.retrieve`` is merged into every
                request as well.

        Returns:
            The result of ``ekd.from_source("multi", [...])``.
        """
        LOG.info("Loading %s fields from %s", label, self.WHERE)
        owner = self.owner
        sources = []
        for date, time in owner.datetimes():
            request = self._patch(
                date=date,
                time=time,
                **selection,
                grid=owner.grid,
                area=owner.area,
                **(owner.retrieve if with_retrieve else {}),
            )
            sources.append(loader(**request))
        return ekd.from_source("multi", sources)

    @cached_property
    def fields_sfc(self):
        """Surface fields, or the ``empty`` source when ``owner.param_sfc`` is falsy."""
        param = self.owner.param_sfc
        if not param:
            return ekd.from_source("empty")

        # NOTE(review): only the surface request merges ``owner.retrieve``;
        # the pressure- and model-level requests do not. Kept as found --
        # confirm whether that asymmetry is intentional.
        return self._load_fields(
            "surface",
            self.sfc_load_source,
            {"param": param},
            with_retrieve=True,
        )

    @cached_property
    def fields_pl(self):
        """Pressure-level fields, or the ``empty`` source when param or level is falsy."""
        param, level = self.owner.param_level_pl
        if not (param and level):
            return ekd.from_source("empty")

        return self._load_fields(
            "pressure",
            self.pl_load_source,
            {"param": param, "level": level},
        )

    @cached_property
    def fields_ml(self):
        """Model-level fields, or the ``empty`` source when param or level is falsy."""
        param, level = self.owner.param_level_ml
        if not (param and level):
            return ekd.from_source("empty")

        return self._load_fields(
            "model",
            self.ml_load_source,
            {"param": param, "level": level},
        )

    @cached_property
    def all_fields(self):
        """Surface, pressure-level and model-level fields added together, in that order."""
        return self.fields_sfc + self.fields_pl + self.fields_ml


class MarsInput(RequestBasedInput):
    """Input that loads fields through the earthkit-data ``mars`` source."""

    WHERE = "MARS"

    def __init__(self, owner, **kwargs):
        """Delegate to the base class, which stores ``owner``."""
        super().__init__(owner, **kwargs)

    def _load(self, levtype, request):
        """Tag ``request`` with ``levtype``, log it and hand it to the ``mars`` source."""
        request["levtype"] = levtype
        # Use the module logger: the module-level ``logging.debug`` targets the
        # root logger and may implicitly configure it.
        LOG.debug("load source mars %s", request)
        return ekd.from_source("mars", request)

    def pl_load_source(self, **kwargs):
        """Load pressure-level fields (``levtype="pl"``)."""
        return self._load("pl", kwargs)

    def sfc_load_source(self, **kwargs):
        """Load surface fields (``levtype="sfc"``)."""
        return self._load("sfc", kwargs)

    def ml_load_source(self, **kwargs):
        """Load model-level fields (``levtype="ml"``)."""
        return self._load("ml", kwargs)


class CdsInput(RequestBasedInput):
    """Input that loads fields through the earthkit-data ``cds`` source.

    Every request is sent with ``product_type="reanalysis"``.
    """

    WHERE = "CDS"

    def sfc_load_source(self, **kwargs):
        """Load surface fields from the ``reanalysis-era5-single-levels`` dataset."""
        request = {**kwargs, "product_type": "reanalysis"}
        return ekd.from_source("cds", "reanalysis-era5-single-levels", request)

    def pl_load_source(self, **kwargs):
        """Load pressure-level fields from the ``reanalysis-era5-pressure-levels`` dataset."""
        request = {**kwargs, "product_type": "reanalysis"}
        return ekd.from_source("cds", "reanalysis-era5-pressure-levels", request)

    def ml_load_source(self, **kwargs):
        """Always raise: model levels are not available from this input."""
        raise NotImplementedError("CDS does not support model levels")


class OpenDataInput(RequestBasedInput):
    """Input that loads fields through the earthkit-data ``ecmwf-open-data`` source."""

    WHERE = "OPENDATA"

    # Maps a (lat, lon) grid increment to the ``resol`` keyword sent with the request.
    RESOLS = {(0.25, 0.25): "0p25"}

    def __init__(self, owner, **kwargs):
        """Delegate to the base class, which stores ``owner``."""
        super().__init__(owner, **kwargs)

    def _adjust(self, kwargs):
        """Translate a generic request into the keywords used for open data.

        ``kwargs`` is modified in place (``level`` becomes ``levelist``,
        ``grid`` is replaced by ``resol``) and a copy with ``owner.retrieve``
        merged on top is returned.

        Raises:
            KeyError: if ``grid`` is missing or is not a key of ``RESOLS``.
        """
        if "level" in kwargs:
            # OpenData uses levelist instead of level
            kwargs["levelist"] = kwargs.pop("level")

        grid = kwargs.pop("grid")
        if isinstance(grid, list):
            # Lists are unhashable; RESOLS is keyed by tuples.
            grid = tuple(grid)

        kwargs["resol"] = self.RESOLS[grid]
        r = dict(**kwargs)
        r.update(self.owner.retrieve)
        return r

    def _load(self, levtype, kwargs):
        """Adjust the request, tag it with ``levtype``, log it and load it."""
        # Use the adjusted copy: previously the return value of ``_adjust``
        # was dropped, so the ``owner.retrieve`` merge never took effect.
        request = self._adjust(kwargs)
        request["levtype"] = levtype
        LOG.debug("load source ecmwf-open-data %s", request)
        return ekd.from_source("ecmwf-open-data", **request)

    def pl_load_source(self, **kwargs):
        """Load pressure-level fields (``levtype="pl"``)."""
        return self._load("pl", kwargs)

    def sfc_load_source(self, **kwargs):
        """Load surface fields (``levtype="sfc"``)."""
        return self._load("sfc", kwargs)

    def ml_load_source(self, **kwargs):
        """Load model-level fields (``levtype="ml"``)."""
        return self._load("ml", kwargs)


class FileInput:
    """Input that reads all of its fields from a single file."""

    def __init__(self, owner, file, **kwargs):
        """Store ``owner`` and ``file``; extra keyword arguments are ignored."""
        self.owner = owner
        self.file = file

    @cached_property
    def all_fields(self):
        """Everything the earthkit-data ``file`` source yields for ``self.file``."""
        return ekd.from_source("file", self.file)

    def _by_levtype(self, levtype):
        """Select the fields of ``all_fields`` with the given ``levtype``."""
        return self.all_fields.sel(levtype=levtype)

    @cached_property
    def fields_sfc(self):
        """Fields selected with ``levtype="sfc"``."""
        return self._by_levtype("sfc")

    @cached_property
    def fields_pl(self):
        """Fields selected with ``levtype="pl"``."""
        return self._by_levtype("pl")

    @cached_property
    def fields_ml(self):
        """Fields selected with ``levtype="ml"``."""
        return self._by_levtype("ml")


def get_input(name, *args, **kwargs):
    """Instantiate the input registered under ``name``.

    The entry point is looked up in ``available_inputs()``, loaded, and the
    loaded object is called with the remaining arguments. An unknown ``name``
    raises ``KeyError`` from the lookup.
    """
    entry_point = available_inputs()[name]
    input_factory = entry_point.load()
    return input_factory(*args, **kwargs)

Expand Down
100 changes: 100 additions & 0 deletions src/ai_models/inputs/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging
from functools import cached_property

import earthkit.data as ekd

LOG = logging.getLogger(__name__)


class RequestBasedInput:
    """Base class for inputs that obtain their fields by issuing requests.

    The three ``fields_*`` properties build one request per ``(date, time)``
    pair yielded by ``owner.datetimes()`` and combine the results with the
    earthkit-data ``multi`` source.

    This class does not define ``WHERE`` (the name used in log messages) nor
    the ``sfc_load_source``, ``pl_load_source`` and ``ml_load_source``
    methods it calls; subclasses must provide them.
    """

    def __init__(self, owner, **kwargs):
        """Remember ``owner``; extra keyword arguments are accepted and ignored."""
        self.owner = owner

    def _patch(self, **kargs):
        """Copy the request and let the owner amend the copy.

        ``owner.patch_retrieve_request`` receives the copy and is expected to
        modify it in place: its return value is ignored.
        """
        r = dict(**kargs)
        self.owner.patch_retrieve_request(r)
        return r

    def _load_fields(self, label, loader, selection, with_retrieve=False):
        """Load one source per owner date/time and combine them.

        Args:
            label: Word inserted in the log message ("surface", "pressure", ...).
            loader: Callable receiving the patched request as keyword arguments.
            selection: Request keys specific to the level type (``param`` and,
                where relevant, ``level``).
            with_retrieve: When true, ``owner.retrieve`` is merged into every
                request as well.

        Returns:
            The result of ``ekd.from_source("multi", [...])``.
        """
        LOG.info("Loading %s fields from %s", label, self.WHERE)
        owner = self.owner
        sources = []
        for date, time in owner.datetimes():
            request = self._patch(
                date=date,
                time=time,
                **selection,
                grid=owner.grid,
                area=owner.area,
                **(owner.retrieve if with_retrieve else {}),
            )
            sources.append(loader(**request))
        return ekd.from_source("multi", sources)

    @cached_property
    def fields_sfc(self):
        """Surface fields, or the ``empty`` source when ``owner.param_sfc`` is falsy."""
        param = self.owner.param_sfc
        if not param:
            return ekd.from_source("empty")

        # NOTE(review): only the surface request merges ``owner.retrieve``;
        # the pressure- and model-level requests do not. Kept as found --
        # confirm whether that asymmetry is intentional.
        return self._load_fields(
            "surface",
            self.sfc_load_source,
            {"param": param},
            with_retrieve=True,
        )

    @cached_property
    def fields_pl(self):
        """Pressure-level fields, or the ``empty`` source when param or level is falsy."""
        param, level = self.owner.param_level_pl
        if not (param and level):
            return ekd.from_source("empty")

        return self._load_fields(
            "pressure",
            self.pl_load_source,
            {"param": param, "level": level},
        )

    @cached_property
    def fields_ml(self):
        """Model-level fields, or the ``empty`` source when param or level is falsy."""
        param, level = self.owner.param_level_ml
        if not (param and level):
            return ekd.from_source("empty")

        return self._load_fields(
            "model",
            self.ml_load_source,
            {"param": param, "level": level},
        )

    @cached_property
    def all_fields(self):
        """Surface, pressure-level and model-level fields added together, in that order."""
        return self.fields_sfc + self.fields_pl + self.fields_ml
29 changes: 29 additions & 0 deletions src/ai_models/inputs/cds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

import earthkit.data as ekd

from .base import RequestBasedInput

LOG = logging.getLogger(__name__)


class CdsInput(RequestBasedInput):
    """Input that loads fields through the earthkit-data ``cds`` source.

    Every request is sent with ``product_type="reanalysis"``.
    """

    WHERE = "CDS"

    def sfc_load_source(self, **kwargs):
        """Load surface fields from the ``reanalysis-era5-single-levels`` dataset."""
        request = {**kwargs, "product_type": "reanalysis"}
        return ekd.from_source("cds", "reanalysis-era5-single-levels", request)

    def pl_load_source(self, **kwargs):
        """Load pressure-level fields from the ``reanalysis-era5-pressure-levels`` dataset."""
        request = {**kwargs, "product_type": "reanalysis"}
        return ekd.from_source("cds", "reanalysis-era5-pressure-levels", request)

    def ml_load_source(self, **kwargs):
        """Always raise: model levels are not available from this input."""
        raise NotImplementedError("CDS does not support model levels")
47 changes: 47 additions & 0 deletions src/ai_models/inputs/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# (C) Copyright 2023 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging
from functools import cached_property

import earthkit.data as ekd
import entrypoints

LOG = logging.getLogger(__name__)


class FileInput:
    """Input that reads all of its fields from a single file."""

    def __init__(self, owner, file, **kwargs):
        """Store ``owner`` and ``file``; extra keyword arguments are ignored."""
        self.owner = owner
        self.file = file

    @cached_property
    def all_fields(self):
        """Everything the earthkit-data ``file`` source yields for ``self.file``."""
        return ekd.from_source("file", self.file)

    def _by_levtype(self, levtype):
        """Select the fields of ``all_fields`` with the given ``levtype``."""
        return self.all_fields.sel(levtype=levtype)

    @cached_property
    def fields_sfc(self):
        """Fields selected with ``levtype="sfc"``."""
        return self._by_levtype("sfc")

    @cached_property
    def fields_pl(self):
        """Fields selected with ``levtype="pl"``."""
        return self._by_levtype("pl")

    @cached_property
    def fields_ml(self):
        """Fields selected with ``levtype="ml"``."""
        return self._by_levtype("ml")


def get_input(name, *args, **kwargs):
    """Instantiate the input registered under ``name``.

    The entry point is looked up in ``available_inputs()``, loaded, and the
    loaded object is called with the remaining arguments. An unknown ``name``
    raises ``KeyError`` from the lookup.
    """
    entry_point = available_inputs()[name]
    input_factory = entry_point.load()
    return input_factory(*args, **kwargs)


def available_inputs():
    """Map each entry-point name in the ``ai_models.input`` group to its entry point.

    When several entry points share a name, the last one returned by
    ``entrypoints.get_group_all`` wins.
    """
    return {
        entry_point.name: entry_point
        for entry_point in entrypoints.get_group_all("ai_models.input")
    }
Loading

0 comments on commit 084a2e7

Please sign in to comment.